fix(channels): restore plugin discovery after merge
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@@ -31,18 +32,28 @@ class ChannelManager:
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels discovered via pkgutil scan."""
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
||||
from nanobot.channels.registry import discover_all
|
||||
|
||||
groq_key = self.config.providers.groq.api_key
|
||||
|
||||
for modname in discover_channel_names():
|
||||
section = getattr(self.config.channels, modname, None)
|
||||
if not section or not getattr(section, "enabled", False):
|
||||
for name, cls in discover_all().items():
|
||||
section = getattr(self.config.channels, name, None)
|
||||
if section is None:
|
||||
continue
|
||||
enabled = (
|
||||
section.get("enabled", False)
|
||||
if isinstance(section, dict)
|
||||
else getattr(section, "enabled", False)
|
||||
)
|
||||
if not enabled:
|
||||
continue
|
||||
try:
|
||||
cls = load_channel_class(modname)
|
||||
instances = getattr(section, "instances", None)
|
||||
instances = (
|
||||
section.get("instances")
|
||||
if isinstance(section, dict)
|
||||
else getattr(section, "instances", None)
|
||||
)
|
||||
if instances is not None:
|
||||
if not instances:
|
||||
logger.warning(
|
||||
@@ -52,18 +63,22 @@ class ChannelManager:
|
||||
continue
|
||||
|
||||
for inst in instances:
|
||||
inst_name = getattr(inst, "name", None)
|
||||
inst_name = (
|
||||
inst.get("name")
|
||||
if isinstance(inst, dict)
|
||||
else getattr(inst, "name", None)
|
||||
)
|
||||
if not inst_name:
|
||||
raise ValueError(
|
||||
f'{modname}.instances item missing required field "name"'
|
||||
f'{name}.instances item missing required field "name"'
|
||||
)
|
||||
|
||||
# Session keys use "channel:chat_id", so instance names cannot use ":".
|
||||
channel_name = f"{modname}/{inst_name}"
|
||||
channel_name = f"{name}/{inst_name}"
|
||||
if channel_name in self.channels:
|
||||
raise ValueError(f"Duplicate channel instance name: {channel_name}")
|
||||
|
||||
channel = cls(inst, self.bus)
|
||||
channel = self._instantiate_channel(cls, inst)
|
||||
channel.name = channel_name
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[channel_name] = channel
|
||||
@@ -72,16 +87,36 @@ class ChannelManager:
|
||||
cls.display_name,
|
||||
channel_name,
|
||||
)
|
||||
else:
|
||||
channel = cls(section, self.bus)
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[modname] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except ImportError as e:
|
||||
logger.warning("{} channel not available: {}", modname, e)
|
||||
continue
|
||||
|
||||
channel = self._instantiate_channel(cls, section)
|
||||
channel.name = name
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[name] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except Exception as e:
|
||||
logger.warning("{} channel not available: {}", name, e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
def _instantiate_channel(self, cls: type[BaseChannel], section: Any) -> BaseChannel:
|
||||
"""Instantiate a channel, passing optional supported kwargs when available."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
try:
|
||||
params = inspect.signature(cls.__init__).parameters
|
||||
except (TypeError, ValueError):
|
||||
params = {}
|
||||
|
||||
tools = getattr(self.config, "tools", None)
|
||||
if "restrict_to_workspace" in params:
|
||||
kwargs["restrict_to_workspace"] = bool(
|
||||
getattr(tools, "restrict_to_workspace", False)
|
||||
)
|
||||
if "workspace" in params:
|
||||
kwargs["workspace"] = getattr(self.config, "workspace_path", None)
|
||||
|
||||
return cls(section, self.bus, **kwargs)
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
|
||||
@@ -6,6 +6,8 @@ import importlib
|
||||
import pkgutil
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
@@ -13,7 +15,7 @@ _INTERNAL = frozenset({"base", "manager", "registry"})
|
||||
|
||||
|
||||
def discover_channel_names() -> list[str]:
|
||||
"""Return all channel module names by scanning the package (zero imports)."""
|
||||
"""Return all built-in channel module names by scanning the package (zero imports)."""
|
||||
import nanobot.channels as pkg
|
||||
|
||||
return [
|
||||
@@ -33,3 +35,37 @@ def load_channel_class(module_name: str) -> type[BaseChannel]:
|
||||
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
||||
return obj
|
||||
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
||||
|
||||
|
||||
def discover_plugins() -> dict[str, type[BaseChannel]]:
|
||||
"""Discover external channel plugins registered via entry_points."""
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
plugins: dict[str, type[BaseChannel]] = {}
|
||||
for ep in entry_points(group="nanobot.channels"):
|
||||
try:
|
||||
cls = ep.load()
|
||||
plugins[ep.name] = cls
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load channel plugin '{}': {}", ep.name, e)
|
||||
return plugins
|
||||
|
||||
|
||||
def discover_all() -> dict[str, type[BaseChannel]]:
|
||||
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
|
||||
|
||||
Built-in channels take priority — an external plugin cannot shadow a built-in name.
|
||||
"""
|
||||
builtin: dict[str, type[BaseChannel]] = {}
|
||||
for modname in discover_channel_names():
|
||||
try:
|
||||
builtin[modname] = load_channel_class(modname)
|
||||
except ImportError as e:
|
||||
logger.debug("Skipping built-in channel '{}': {}", modname, e)
|
||||
|
||||
external = discover_plugins()
|
||||
shadowed = set(external) & set(builtin)
|
||||
if shadowed:
|
||||
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
|
||||
|
||||
return {**external, **builtin}
|
||||
|
||||
@@ -46,8 +46,10 @@ class _DummyChannel(BaseChannel):
|
||||
|
||||
|
||||
def _patch_registry(monkeypatch: pytest.MonkeyPatch, channel_names: list[str]) -> None:
|
||||
monkeypatch.setattr("nanobot.channels.registry.discover_channel_names", lambda: channel_names)
|
||||
monkeypatch.setattr("nanobot.channels.registry.load_channel_class", lambda _: _DummyChannel)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {name: _DummyChannel for name in channel_names},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -178,7 +180,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "payload", "expected_cls", "attr_name", "attr_value"),
|
||||
("field_name", "payload", "expected_cls", "expected_names", "attr_name", "attr_value"),
|
||||
[
|
||||
(
|
||||
"whatsapp",
|
||||
@@ -190,6 +192,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
WhatsAppMultiConfig,
|
||||
["main", "backup"],
|
||||
"bridge_url",
|
||||
"ws://127.0.0.1:3002",
|
||||
),
|
||||
@@ -203,6 +206,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
TelegramMultiConfig,
|
||||
["main", "backup"],
|
||||
"token",
|
||||
"tg-backup",
|
||||
),
|
||||
@@ -216,6 +220,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
DiscordMultiConfig,
|
||||
["main", "backup"],
|
||||
"token",
|
||||
"dc-backup",
|
||||
),
|
||||
@@ -234,6 +239,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
FeishuMultiConfig,
|
||||
["main", "backup"],
|
||||
"app_id",
|
||||
"fs-backup",
|
||||
),
|
||||
@@ -257,6 +263,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
DingTalkMultiConfig,
|
||||
["main", "backup"],
|
||||
"client_id",
|
||||
"dt-backup",
|
||||
),
|
||||
@@ -282,6 +289,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
MatrixMultiConfig,
|
||||
["main", "backup"],
|
||||
"homeserver",
|
||||
"https://matrix-2.example.com",
|
||||
),
|
||||
@@ -305,6 +313,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
EmailMultiConfig,
|
||||
["work", "home"],
|
||||
"imap_host",
|
||||
"imap.home",
|
||||
),
|
||||
@@ -328,6 +337,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
MochatMultiConfig,
|
||||
["main", "backup"],
|
||||
"claw_token",
|
||||
"claw-backup",
|
||||
),
|
||||
@@ -351,6 +361,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
SlackMultiConfig,
|
||||
["main", "backup"],
|
||||
"bot_token",
|
||||
"xoxb-backup",
|
||||
),
|
||||
@@ -369,6 +380,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
QQMultiConfig,
|
||||
["main", "backup"],
|
||||
"app_id",
|
||||
"qq-backup",
|
||||
),
|
||||
@@ -387,6 +399,7 @@ def test_config_parses_supported_single_instance_channels(
|
||||
],
|
||||
},
|
||||
WecomMultiConfig,
|
||||
["main", "backup"],
|
||||
"bot_id",
|
||||
"wc-backup",
|
||||
),
|
||||
@@ -396,6 +409,7 @@ def test_config_parses_supported_multi_instance_channels(
|
||||
field_name: str,
|
||||
payload: dict,
|
||||
expected_cls: type,
|
||||
expected_names: list[str],
|
||||
attr_name: str,
|
||||
attr_value: str,
|
||||
) -> None:
|
||||
@@ -403,7 +417,7 @@ def test_config_parses_supported_multi_instance_channels(
|
||||
|
||||
section = getattr(config.channels, field_name)
|
||||
assert isinstance(section, expected_cls)
|
||||
assert [inst.name for inst in section.instances] == ["main", "backup"]
|
||||
assert [inst.name for inst in section.instances] == expected_names
|
||||
assert getattr(section.instances[1], attr_name) == attr_value
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user