fix(channels): restore plugin discovery after merge
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -31,18 +32,28 @@ class ChannelManager:
|
|||||||
self._init_channels()
|
self._init_channels()
|
||||||
|
|
||||||
def _init_channels(self) -> None:
|
def _init_channels(self) -> None:
|
||||||
"""Initialize channels discovered via pkgutil scan."""
|
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
|
||||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
from nanobot.channels.registry import discover_all
|
||||||
|
|
||||||
groq_key = self.config.providers.groq.api_key
|
groq_key = self.config.providers.groq.api_key
|
||||||
|
|
||||||
for modname in discover_channel_names():
|
for name, cls in discover_all().items():
|
||||||
section = getattr(self.config.channels, modname, None)
|
section = getattr(self.config.channels, name, None)
|
||||||
if not section or not getattr(section, "enabled", False):
|
if section is None:
|
||||||
|
continue
|
||||||
|
enabled = (
|
||||||
|
section.get("enabled", False)
|
||||||
|
if isinstance(section, dict)
|
||||||
|
else getattr(section, "enabled", False)
|
||||||
|
)
|
||||||
|
if not enabled:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
cls = load_channel_class(modname)
|
instances = (
|
||||||
instances = getattr(section, "instances", None)
|
section.get("instances")
|
||||||
|
if isinstance(section, dict)
|
||||||
|
else getattr(section, "instances", None)
|
||||||
|
)
|
||||||
if instances is not None:
|
if instances is not None:
|
||||||
if not instances:
|
if not instances:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -52,18 +63,22 @@ class ChannelManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for inst in instances:
|
for inst in instances:
|
||||||
inst_name = getattr(inst, "name", None)
|
inst_name = (
|
||||||
|
inst.get("name")
|
||||||
|
if isinstance(inst, dict)
|
||||||
|
else getattr(inst, "name", None)
|
||||||
|
)
|
||||||
if not inst_name:
|
if not inst_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'{modname}.instances item missing required field "name"'
|
f'{name}.instances item missing required field "name"'
|
||||||
)
|
)
|
||||||
|
|
||||||
# Session keys use "channel:chat_id", so instance names cannot use ":".
|
# Session keys use "channel:chat_id", so instance names cannot use ":".
|
||||||
channel_name = f"{modname}/{inst_name}"
|
channel_name = f"{name}/{inst_name}"
|
||||||
if channel_name in self.channels:
|
if channel_name in self.channels:
|
||||||
raise ValueError(f"Duplicate channel instance name: {channel_name}")
|
raise ValueError(f"Duplicate channel instance name: {channel_name}")
|
||||||
|
|
||||||
channel = cls(inst, self.bus)
|
channel = self._instantiate_channel(cls, inst)
|
||||||
channel.name = channel_name
|
channel.name = channel_name
|
||||||
channel.transcription_api_key = groq_key
|
channel.transcription_api_key = groq_key
|
||||||
self.channels[channel_name] = channel
|
self.channels[channel_name] = channel
|
||||||
@@ -72,16 +87,36 @@ class ChannelManager:
|
|||||||
cls.display_name,
|
cls.display_name,
|
||||||
channel_name,
|
channel_name,
|
||||||
)
|
)
|
||||||
else:
|
continue
|
||||||
channel = cls(section, self.bus)
|
|
||||||
|
channel = self._instantiate_channel(cls, section)
|
||||||
|
channel.name = name
|
||||||
channel.transcription_api_key = groq_key
|
channel.transcription_api_key = groq_key
|
||||||
self.channels[modname] = channel
|
self.channels[name] = channel
|
||||||
logger.info("{} channel enabled", cls.display_name)
|
logger.info("{} channel enabled", cls.display_name)
|
||||||
except ImportError as e:
|
except Exception as e:
|
||||||
logger.warning("{} channel not available: {}", modname, e)
|
logger.warning("{} channel not available: {}", name, e)
|
||||||
|
|
||||||
self._validate_allow_from()
|
self._validate_allow_from()
|
||||||
|
|
||||||
|
def _instantiate_channel(self, cls: type[BaseChannel], section: Any) -> BaseChannel:
|
||||||
|
"""Instantiate a channel, passing optional supported kwargs when available."""
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
params = inspect.signature(cls.__init__).parameters
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
tools = getattr(self.config, "tools", None)
|
||||||
|
if "restrict_to_workspace" in params:
|
||||||
|
kwargs["restrict_to_workspace"] = bool(
|
||||||
|
getattr(tools, "restrict_to_workspace", False)
|
||||||
|
)
|
||||||
|
if "workspace" in params:
|
||||||
|
kwargs["workspace"] = getattr(self.config, "workspace_path", None)
|
||||||
|
|
||||||
|
return cls(section, self.bus, **kwargs)
|
||||||
|
|
||||||
def _validate_allow_from(self) -> None:
|
def _validate_allow_from(self) -> None:
|
||||||
for name, ch in self.channels.items():
|
for name, ch in self.channels.items():
|
||||||
if getattr(ch.config, "allow_from", None) == []:
|
if getattr(ch.config, "allow_from", None) == []:
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import importlib
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
|
|
||||||
@@ -13,7 +15,7 @@ _INTERNAL = frozenset({"base", "manager", "registry"})
|
|||||||
|
|
||||||
|
|
||||||
def discover_channel_names() -> list[str]:
|
def discover_channel_names() -> list[str]:
|
||||||
"""Return all channel module names by scanning the package (zero imports)."""
|
"""Return all built-in channel module names by scanning the package (zero imports)."""
|
||||||
import nanobot.channels as pkg
|
import nanobot.channels as pkg
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@@ -33,3 +35,37 @@ def load_channel_class(module_name: str) -> type[BaseChannel]:
|
|||||||
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
||||||
return obj
|
return obj
|
||||||
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def discover_plugins() -> dict[str, type[BaseChannel]]:
|
||||||
|
"""Discover external channel plugins registered via entry_points."""
|
||||||
|
from importlib.metadata import entry_points
|
||||||
|
|
||||||
|
plugins: dict[str, type[BaseChannel]] = {}
|
||||||
|
for ep in entry_points(group="nanobot.channels"):
|
||||||
|
try:
|
||||||
|
cls = ep.load()
|
||||||
|
plugins[ep.name] = cls
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to load channel plugin '{}': {}", ep.name, e)
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
|
def discover_all() -> dict[str, type[BaseChannel]]:
|
||||||
|
"""Return all channels: built-in (pkgutil) merged with external (entry_points).
|
||||||
|
|
||||||
|
Built-in channels take priority — an external plugin cannot shadow a built-in name.
|
||||||
|
"""
|
||||||
|
builtin: dict[str, type[BaseChannel]] = {}
|
||||||
|
for modname in discover_channel_names():
|
||||||
|
try:
|
||||||
|
builtin[modname] = load_channel_class(modname)
|
||||||
|
except ImportError as e:
|
||||||
|
logger.debug("Skipping built-in channel '{}': {}", modname, e)
|
||||||
|
|
||||||
|
external = discover_plugins()
|
||||||
|
shadowed = set(external) & set(builtin)
|
||||||
|
if shadowed:
|
||||||
|
logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
|
||||||
|
|
||||||
|
return {**external, **builtin}
|
||||||
|
|||||||
@@ -46,8 +46,10 @@ class _DummyChannel(BaseChannel):
|
|||||||
|
|
||||||
|
|
||||||
def _patch_registry(monkeypatch: pytest.MonkeyPatch, channel_names: list[str]) -> None:
|
def _patch_registry(monkeypatch: pytest.MonkeyPatch, channel_names: list[str]) -> None:
|
||||||
monkeypatch.setattr("nanobot.channels.registry.discover_channel_names", lambda: channel_names)
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr("nanobot.channels.registry.load_channel_class", lambda _: _DummyChannel)
|
"nanobot.channels.registry.discover_all",
|
||||||
|
lambda: {name: _DummyChannel for name in channel_names},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -178,7 +180,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("field_name", "payload", "expected_cls", "attr_name", "attr_value"),
|
("field_name", "payload", "expected_cls", "expected_names", "attr_name", "attr_value"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"whatsapp",
|
"whatsapp",
|
||||||
@@ -190,6 +192,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
WhatsAppMultiConfig,
|
WhatsAppMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"bridge_url",
|
"bridge_url",
|
||||||
"ws://127.0.0.1:3002",
|
"ws://127.0.0.1:3002",
|
||||||
),
|
),
|
||||||
@@ -203,6 +206,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
TelegramMultiConfig,
|
TelegramMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"token",
|
"token",
|
||||||
"tg-backup",
|
"tg-backup",
|
||||||
),
|
),
|
||||||
@@ -216,6 +220,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
DiscordMultiConfig,
|
DiscordMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"token",
|
"token",
|
||||||
"dc-backup",
|
"dc-backup",
|
||||||
),
|
),
|
||||||
@@ -234,6 +239,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
FeishuMultiConfig,
|
FeishuMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"app_id",
|
"app_id",
|
||||||
"fs-backup",
|
"fs-backup",
|
||||||
),
|
),
|
||||||
@@ -257,6 +263,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
DingTalkMultiConfig,
|
DingTalkMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"client_id",
|
"client_id",
|
||||||
"dt-backup",
|
"dt-backup",
|
||||||
),
|
),
|
||||||
@@ -282,6 +289,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
MatrixMultiConfig,
|
MatrixMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"homeserver",
|
"homeserver",
|
||||||
"https://matrix-2.example.com",
|
"https://matrix-2.example.com",
|
||||||
),
|
),
|
||||||
@@ -305,6 +313,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
EmailMultiConfig,
|
EmailMultiConfig,
|
||||||
|
["work", "home"],
|
||||||
"imap_host",
|
"imap_host",
|
||||||
"imap.home",
|
"imap.home",
|
||||||
),
|
),
|
||||||
@@ -328,6 +337,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
MochatMultiConfig,
|
MochatMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"claw_token",
|
"claw_token",
|
||||||
"claw-backup",
|
"claw-backup",
|
||||||
),
|
),
|
||||||
@@ -351,6 +361,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
SlackMultiConfig,
|
SlackMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"bot_token",
|
"bot_token",
|
||||||
"xoxb-backup",
|
"xoxb-backup",
|
||||||
),
|
),
|
||||||
@@ -369,6 +380,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
QQMultiConfig,
|
QQMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"app_id",
|
"app_id",
|
||||||
"qq-backup",
|
"qq-backup",
|
||||||
),
|
),
|
||||||
@@ -387,6 +399,7 @@ def test_config_parses_supported_single_instance_channels(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
WecomMultiConfig,
|
WecomMultiConfig,
|
||||||
|
["main", "backup"],
|
||||||
"bot_id",
|
"bot_id",
|
||||||
"wc-backup",
|
"wc-backup",
|
||||||
),
|
),
|
||||||
@@ -396,6 +409,7 @@ def test_config_parses_supported_multi_instance_channels(
|
|||||||
field_name: str,
|
field_name: str,
|
||||||
payload: dict,
|
payload: dict,
|
||||||
expected_cls: type,
|
expected_cls: type,
|
||||||
|
expected_names: list[str],
|
||||||
attr_name: str,
|
attr_name: str,
|
||||||
attr_value: str,
|
attr_value: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -403,7 +417,7 @@ def test_config_parses_supported_multi_instance_channels(
|
|||||||
|
|
||||||
section = getattr(config.channels, field_name)
|
section = getattr(config.channels, field_name)
|
||||||
assert isinstance(section, expected_cls)
|
assert isinstance(section, expected_cls)
|
||||||
assert [inst.name for inst in section.instances] == ["main", "backup"]
|
assert [inst.name for inst in section.instances] == expected_names
|
||||||
assert getattr(section.instances[1], attr_name) == attr_value
|
assert getattr(section.instances[1], attr_name) == attr_value
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user