fix(channels): restore plugin discovery after merge

This commit is contained in:
Hua
2026-03-15 18:21:02 +08:00
parent 5a5587e39b
commit fc4cc5385a
3 changed files with 108 additions and 23 deletions

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import inspect
from typing import Any
from loguru import logger
@@ -31,18 +32,28 @@ class ChannelManager:
self._init_channels()
def _init_channels(self) -> None:
"""Initialize channels discovered via pkgutil scan."""
from nanobot.channels.registry import discover_channel_names, load_channel_class
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
from nanobot.channels.registry import discover_all
groq_key = self.config.providers.groq.api_key
for modname in discover_channel_names():
section = getattr(self.config.channels, modname, None)
if not section or not getattr(section, "enabled", False):
for name, cls in discover_all().items():
section = getattr(self.config.channels, name, None)
if section is None:
continue
enabled = (
section.get("enabled", False)
if isinstance(section, dict)
else getattr(section, "enabled", False)
)
if not enabled:
continue
try:
cls = load_channel_class(modname)
instances = getattr(section, "instances", None)
instances = (
section.get("instances")
if isinstance(section, dict)
else getattr(section, "instances", None)
)
if instances is not None:
if not instances:
logger.warning(
@@ -52,18 +63,22 @@ class ChannelManager:
continue
for inst in instances:
inst_name = getattr(inst, "name", None)
inst_name = (
inst.get("name")
if isinstance(inst, dict)
else getattr(inst, "name", None)
)
if not inst_name:
raise ValueError(
f'{modname}.instances item missing required field "name"'
f'{name}.instances item missing required field "name"'
)
# Session keys use "channel:chat_id", so instance names cannot use ":".
channel_name = f"{modname}/{inst_name}"
channel_name = f"{name}/{inst_name}"
if channel_name in self.channels:
raise ValueError(f"Duplicate channel instance name: {channel_name}")
channel = cls(inst, self.bus)
channel = self._instantiate_channel(cls, inst)
channel.name = channel_name
channel.transcription_api_key = groq_key
self.channels[channel_name] = channel
@@ -72,16 +87,36 @@ class ChannelManager:
cls.display_name,
channel_name,
)
else:
channel = cls(section, self.bus)
channel.transcription_api_key = groq_key
self.channels[modname] = channel
logger.info("{} channel enabled", cls.display_name)
except ImportError as e:
logger.warning("{} channel not available: {}", modname, e)
continue
channel = self._instantiate_channel(cls, section)
channel.name = name
channel.transcription_api_key = groq_key
self.channels[name] = channel
logger.info("{} channel enabled", cls.display_name)
except Exception as e:
logger.warning("{} channel not available: {}", name, e)
self._validate_allow_from()
def _instantiate_channel(self, cls: type[BaseChannel], section: Any) -> BaseChannel:
"""Instantiate a channel, passing optional supported kwargs when available."""
kwargs: dict[str, Any] = {}
try:
params = inspect.signature(cls.__init__).parameters
except (TypeError, ValueError):
params = {}
tools = getattr(self.config, "tools", None)
if "restrict_to_workspace" in params:
kwargs["restrict_to_workspace"] = bool(
getattr(tools, "restrict_to_workspace", False)
)
if "workspace" in params:
kwargs["workspace"] = getattr(self.config, "workspace_path", None)
return cls(section, self.bus, **kwargs)
def _validate_allow_from(self) -> None:
for name, ch in self.channels.items():
if getattr(ch.config, "allow_from", None) == []:

View File

@@ -6,6 +6,8 @@ import importlib
import pkgutil
from typing import TYPE_CHECKING
from loguru import logger
if TYPE_CHECKING:
from nanobot.channels.base import BaseChannel
@@ -13,7 +15,7 @@ _INTERNAL = frozenset({"base", "manager", "registry"})
def discover_channel_names() -> list[str]:
"""Return all channel module names by scanning the package (zero imports)."""
"""Return all built-in channel module names by scanning the package (zero imports)."""
import nanobot.channels as pkg
return [
@@ -33,3 +35,37 @@ def load_channel_class(module_name: str) -> type[BaseChannel]:
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
return obj
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
def discover_plugins() -> dict[str, type[BaseChannel]]:
"""Discover external channel plugins registered via entry_points."""
from importlib.metadata import entry_points
plugins: dict[str, type[BaseChannel]] = {}
for ep in entry_points(group="nanobot.channels"):
try:
cls = ep.load()
plugins[ep.name] = cls
except Exception as e:
logger.warning("Failed to load channel plugin '{}': {}", ep.name, e)
return plugins
def discover_all() -> dict[str, type[BaseChannel]]:
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
Built-in channels take priority — an external plugin cannot shadow a built-in name.
"""
builtin: dict[str, type[BaseChannel]] = {}
for modname in discover_channel_names():
try:
builtin[modname] = load_channel_class(modname)
except ImportError as e:
logger.debug("Skipping built-in channel '{}': {}", modname, e)
external = discover_plugins()
shadowed = set(external) & set(builtin)
if shadowed:
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
return {**external, **builtin}

View File

@@ -46,8 +46,10 @@ class _DummyChannel(BaseChannel):
def _patch_registry(monkeypatch: pytest.MonkeyPatch, channel_names: list[str]) -> None:
monkeypatch.setattr("nanobot.channels.registry.discover_channel_names", lambda: channel_names)
monkeypatch.setattr("nanobot.channels.registry.load_channel_class", lambda _: _DummyChannel)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {name: _DummyChannel for name in channel_names},
)
@pytest.mark.parametrize(
@@ -178,7 +180,7 @@ def test_config_parses_supported_single_instance_channels(
@pytest.mark.parametrize(
("field_name", "payload", "expected_cls", "attr_name", "attr_value"),
("field_name", "payload", "expected_cls", "expected_names", "attr_name", "attr_value"),
[
(
"whatsapp",
@@ -190,6 +192,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
WhatsAppMultiConfig,
["main", "backup"],
"bridge_url",
"ws://127.0.0.1:3002",
),
@@ -203,6 +206,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
TelegramMultiConfig,
["main", "backup"],
"token",
"tg-backup",
),
@@ -216,6 +220,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
DiscordMultiConfig,
["main", "backup"],
"token",
"dc-backup",
),
@@ -234,6 +239,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
FeishuMultiConfig,
["main", "backup"],
"app_id",
"fs-backup",
),
@@ -257,6 +263,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
DingTalkMultiConfig,
["main", "backup"],
"client_id",
"dt-backup",
),
@@ -282,6 +289,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
MatrixMultiConfig,
["main", "backup"],
"homeserver",
"https://matrix-2.example.com",
),
@@ -305,6 +313,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
EmailMultiConfig,
["work", "home"],
"imap_host",
"imap.home",
),
@@ -328,6 +337,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
MochatMultiConfig,
["main", "backup"],
"claw_token",
"claw-backup",
),
@@ -351,6 +361,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
SlackMultiConfig,
["main", "backup"],
"bot_token",
"xoxb-backup",
),
@@ -369,6 +380,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
QQMultiConfig,
["main", "backup"],
"app_id",
"qq-backup",
),
@@ -387,6 +399,7 @@ def test_config_parses_supported_single_instance_channels(
],
},
WecomMultiConfig,
["main", "backup"],
"bot_id",
"wc-backup",
),
@@ -396,6 +409,7 @@ def test_config_parses_supported_multi_instance_channels(
field_name: str,
payload: dict,
expected_cls: type,
expected_names: list[str],
attr_name: str,
attr_value: str,
) -> None:
@@ -403,7 +417,7 @@ def test_config_parses_supported_multi_instance_channels(
section = getattr(config.channels, field_name)
assert isinstance(section, expected_cls)
assert [inst.name for inst in section.instances] == ["main", "backup"]
assert [inst.name for inst in section.instances] == expected_names
assert getattr(section.instances[1], attr_name) == attr_value