From fc4cc5385ab1100dd532127f4697e3b72b50c7df Mon Sep 17 00:00:00 2001 From: Hua Date: Sun, 15 Mar 2026 18:21:02 +0800 Subject: [PATCH] fix(channels): restore plugin discovery after merge --- nanobot/channels/manager.py | 71 ++++++++++++++++++++++-------- nanobot/channels/registry.py | 38 +++++++++++++++- tests/test_channel_multi_config.py | 22 +++++++-- 3 files changed, 108 insertions(+), 23 deletions(-) diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 80ec3cc..3d6c50b 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import inspect from typing import Any from loguru import logger @@ -31,18 +32,28 @@ class ChannelManager: self._init_channels() def _init_channels(self) -> None: - """Initialize channels discovered via pkgutil scan.""" - from nanobot.channels.registry import discover_channel_names, load_channel_class + """Initialize channels discovered via pkgutil scan + entry_points plugins.""" + from nanobot.channels.registry import discover_all groq_key = self.config.providers.groq.api_key - for modname in discover_channel_names(): - section = getattr(self.config.channels, modname, None) - if not section or not getattr(section, "enabled", False): + for name, cls in discover_all().items(): + section = getattr(self.config.channels, name, None) + if section is None: + continue + enabled = ( + section.get("enabled", False) + if isinstance(section, dict) + else getattr(section, "enabled", False) + ) + if not enabled: continue try: - cls = load_channel_class(modname) - instances = getattr(section, "instances", None) + instances = ( + section.get("instances") + if isinstance(section, dict) + else getattr(section, "instances", None) + ) if instances is not None: if not instances: logger.warning( @@ -52,18 +63,22 @@ class ChannelManager: continue for inst in instances: - inst_name = getattr(inst, "name", None) + inst_name = ( + inst.get("name") + if isinstance(inst, dict) + else getattr(inst, "name", None) + ) if not inst_name: raise ValueError( - f'{modname}.instances item missing required field "name"' + f'{name}.instances item missing required field "name"' ) # Session keys use "channel:chat_id", so instance names cannot use ":". - channel_name = f"{modname}/{inst_name}" + channel_name = f"{name}/{inst_name}" if channel_name in self.channels: raise ValueError(f"Duplicate channel instance name: {channel_name}") - channel = cls(inst, self.bus) + channel = self._instantiate_channel(cls, inst) channel.name = channel_name channel.transcription_api_key = groq_key self.channels[channel_name] = channel @@ -72,16 +87,36 @@ class ChannelManager: cls.display_name, channel_name, ) - else: - channel = cls(section, self.bus) - channel.transcription_api_key = groq_key - self.channels[modname] = channel - logger.info("{} channel enabled", cls.display_name) - except ImportError as e: - logger.warning("{} channel not available: {}", modname, e) + continue + + channel = self._instantiate_channel(cls, section) + channel.name = name + channel.transcription_api_key = groq_key + self.channels[name] = channel + logger.info("{} channel enabled", cls.display_name) + except Exception as e: + logger.warning("{} channel not available: {}", name, e) self._validate_allow_from() + def _instantiate_channel(self, cls: type[BaseChannel], section: Any) -> BaseChannel: + """Instantiate a channel, passing optional supported kwargs when available.""" + kwargs: dict[str, Any] = {} + try: + params = inspect.signature(cls.__init__).parameters + except (TypeError, ValueError): + params = {} + + tools = getattr(self.config, "tools", None) + if "restrict_to_workspace" in params: + kwargs["restrict_to_workspace"] = bool( + getattr(tools, "restrict_to_workspace", False) + ) + if "workspace" in params: + kwargs["workspace"] = getattr(self.config, "workspace_path", None) + + return cls(section, self.bus, **kwargs) + def _validate_allow_from(self) -> None: for name, ch in self.channels.items(): if getattr(ch.config, "allow_from", None) == []: diff --git a/nanobot/channels/registry.py b/nanobot/channels/registry.py index eb30ff7..bfe6c3e 100644 --- a/nanobot/channels/registry.py +++ b/nanobot/channels/registry.py @@ -6,6 +6,8 @@ import importlib import pkgutil from typing import TYPE_CHECKING +from loguru import logger + if TYPE_CHECKING: from nanobot.channels.base import BaseChannel @@ -13,7 +15,7 @@ _INTERNAL = frozenset({"base", "manager", "registry"}) def discover_channel_names() -> list[str]: - """Return all channel module names by scanning the package (zero imports).""" + """Return all built-in channel module names by scanning the package (zero imports).""" import nanobot.channels as pkg return [ @@ -33,3 +35,37 @@ def load_channel_class(module_name: str) -> type[BaseChannel]: if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base: return obj raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}") + + +def discover_plugins() -> dict[str, type[BaseChannel]]: + """Discover external channel plugins registered via entry_points.""" + from importlib.metadata import entry_points + + plugins: dict[str, type[BaseChannel]] = {} + for ep in entry_points(group="nanobot.channels"): + try: + cls = ep.load() + plugins[ep.name] = cls + except Exception as e: + logger.warning("Failed to load channel plugin '{}': {}", ep.name, e) + return plugins + + +def discover_all() -> dict[str, type[BaseChannel]]: + """Return all channels: built-in (pkgutil) merged with external (entry_points). + + Built-in channels take priority — an external plugin cannot shadow a built-in name. + """ + builtin: dict[str, type[BaseChannel]] = {} + for modname in discover_channel_names(): + try: + builtin[modname] = load_channel_class(modname) + except ImportError as e: + logger.debug("Skipping built-in channel '{}': {}", modname, e) + + external = discover_plugins() + shadowed = set(external) & set(builtin) + if shadowed: + logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed) + + return {**external, **builtin} diff --git a/tests/test_channel_multi_config.py b/tests/test_channel_multi_config.py index 46f2b6b..52f194c 100644 --- a/tests/test_channel_multi_config.py +++ b/tests/test_channel_multi_config.py @@ -46,8 +46,10 @@ class _DummyChannel(BaseChannel): def _patch_registry(monkeypatch: pytest.MonkeyPatch, channel_names: list[str]) -> None: - monkeypatch.setattr("nanobot.channels.registry.discover_channel_names", lambda: channel_names) - monkeypatch.setattr("nanobot.channels.registry.load_channel_class", lambda _: _DummyChannel) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: {name: _DummyChannel for name in channel_names}, + ) @pytest.mark.parametrize( @@ -178,7 +180,7 @@ def test_config_parses_supported_single_instance_channels( @pytest.mark.parametrize( - ("field_name", "payload", "expected_cls", "attr_name", "attr_value"), + ("field_name", "payload", "expected_cls", "expected_names", "attr_name", "attr_value"), [ ( "whatsapp", @@ -190,6 +192,7 @@ def test_config_parses_supported_single_instance_channels( ], }, WhatsAppMultiConfig, + ["main", "backup"], "bridge_url", "ws://127.0.0.1:3002", ), @@ -203,6 +206,7 @@ def test_config_parses_supported_single_instance_channels( ], }, TelegramMultiConfig, + ["main", "backup"], "token", "tg-backup", ), @@ -216,6 +220,7 @@ def test_config_parses_supported_single_instance_channels( ], }, DiscordMultiConfig, + ["main", "backup"], "token", "dc-backup", ), @@ -234,6 +239,7 @@ def test_config_parses_supported_single_instance_channels( ], }, FeishuMultiConfig, + ["main", "backup"], "app_id", "fs-backup", ), @@ -257,6 +263,7 @@ def test_config_parses_supported_single_instance_channels( ], }, DingTalkMultiConfig, + ["main", "backup"], "client_id", "dt-backup", ), @@ -282,6 +289,7 @@ def test_config_parses_supported_single_instance_channels( ], }, MatrixMultiConfig, + ["main", "backup"], "homeserver", "https://matrix-2.example.com", ), @@ -305,6 +313,7 @@ def test_config_parses_supported_single_instance_channels( ], }, EmailMultiConfig, + ["work", "home"], "imap_host", "imap.home", ), @@ -328,6 +337,7 @@ def test_config_parses_supported_single_instance_channels( ], }, MochatMultiConfig, + ["main", "backup"], "claw_token", "claw-backup", ), @@ -351,6 +361,7 @@ def test_config_parses_supported_single_instance_channels( ], }, SlackMultiConfig, + ["main", "backup"], "bot_token", "xoxb-backup", ), @@ -369,6 +380,7 @@ def test_config_parses_supported_single_instance_channels( ], }, QQMultiConfig, + ["main", "backup"], "app_id", "qq-backup", ), @@ -387,6 +399,7 @@ def test_config_parses_supported_single_instance_channels( ], }, WecomMultiConfig, + ["main", "backup"], "bot_id", "wc-backup", ), @@ -396,6 +409,7 @@ def test_config_parses_supported_multi_instance_channels( field_name: str, payload: dict, expected_cls: type, + expected_names: list[str], attr_name: str, attr_value: str, ) -> None: @@ -403,7 +417,7 @@ def test_config_parses_supported_multi_instance_channels( section = getattr(config.channels, field_name) assert isinstance(section, expected_cls) - assert [inst.name for inst in section.instances] == ["main", "backup"] + assert [inst.name for inst in section.instances] == expected_names assert getattr(section.instances[1], attr_name) == attr_value