diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 74c540a..e1b673f 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -24,6 +24,11 @@ class BaseChannel(ABC): display_name: str = "Base" transcription_api_key: str = "" + @classmethod + def default_config(cls) -> dict[str, Any] | None: + """Return the default config payload for onboarding, if the channel provides one.""" + return None + def __init__(self, config: Any, bus: MessageBus): """ Initialize the channel. diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 10b6b98..404f09d 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -162,6 +162,10 @@ class DingTalkChannel(BaseChannel): _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} + @classmethod + def default_config(cls) -> dict[str, object]: + return DingTalkConfig().model_dump(by_alias=True) + def __init__(self, config: DingTalkConfig | DingTalkInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: DingTalkConfig | DingTalkInstanceConfig = config @@ -262,9 +266,12 @@ class DingTalkChannel(BaseChannel): def _guess_upload_type(self, media_ref: str) -> str: ext = Path(urlparse(media_ref).path).suffix.lower() - if ext in self._IMAGE_EXTS: return "image" - if ext in self._AUDIO_EXTS: return "voice" - if ext in self._VIDEO_EXTS: return "video" + if ext in self._IMAGE_EXTS: + return "image" + if ext in self._AUDIO_EXTS: + return "voice" + if ext in self._VIDEO_EXTS: + return "video" return "file" def _guess_filename(self, media_ref: str, upload_type: str) -> str: @@ -385,8 +392,10 @@ class DingTalkChannel(BaseChannel): if resp.status_code != 200: logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500]) return False - try: result = resp.json() - except Exception: result = {} + try: + result = resp.json() + except Exception: + result = {} errcode = result.get("errcode") if errcode not in (None, 0): logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500]) diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index a11101f..c72f7b0 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -27,6 +27,10 @@ class DiscordChannel(BaseChannel): name = "discord" display_name = "Discord" + @classmethod + def default_config(cls) -> dict[str, object]: + return DiscordConfig().model_dump(by_alias=True) + def __init__(self, config: DiscordConfig | DiscordInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: DiscordConfig | DiscordInstanceConfig = config diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py index cc1e774..8070418 100644 --- a/nanobot/channels/email.py +++ b/nanobot/channels/email.py @@ -51,6 +51,10 @@ class EmailChannel(BaseChannel): "Dec", ) + @classmethod + def default_config(cls) -> dict[str, object]: + return EmailConfig().model_dump(by_alias=True) + def __init__(self, config: EmailConfig | EmailInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: EmailConfig | EmailInstanceConfig = config diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 52f5eda..2876b48 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1,12 +1,13 @@ """Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection.""" import asyncio +import importlib.util import json import os import re import threading +import time from collections import OrderedDict -from pathlib import Path from typing import Any from loguru import logger @@ -17,8 +18,6 @@ from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import FeishuConfig, FeishuInstanceConfig -import importlib.util - FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None # Message type display mapping @@ -246,6 +245,10 @@ class FeishuChannel(BaseChannel): name = "feishu" display_name = "Feishu" + @classmethod + def default_config(cls) -> dict[str, object]: + return FeishuConfig().model_dump(by_alias=True) + def __init__(self, config: FeishuConfig | FeishuInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: FeishuConfig | FeishuInstanceConfig = config @@ -314,8 +317,8 @@ class FeishuChannel(BaseChannel): # instead of the already-running main asyncio loop, which would cause # "This event loop is already running" errors. def run_ws(): - import time import lark_oapi.ws.client as _lark_ws_client + ws_loop = asyncio.new_event_loop() asyncio.set_event_loop(ws_loop) # Patch the module-level loop used by lark's ws Client.start() @@ -375,7 +378,12 @@ class FeishuChannel(BaseChannel): def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: """Sync helper for adding reaction (runs in thread pool).""" - from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji + from lark_oapi.api.im.v1 import ( + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + Emoji, + ) + try: request = CreateMessageReactionRequest.builder() \ .message_id(message_id) \ diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py index 79220b2..c1d0055 100644 --- a/nanobot/channels/mochat.py +++ b/nanobot/channels/mochat.py @@ -219,6 +219,10 @@ class MochatChannel(BaseChannel): name = "mochat" display_name = "Mochat" + @classmethod + def default_config(cls) -> dict[str, object]: + return MochatConfig().model_dump(by_alias=True) + def __init__(self, config: MochatConfig | MochatInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: MochatConfig | MochatInstanceConfig = config diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 4b20d80..3d7c8e4 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -56,6 +56,10 @@ class QQChannel(BaseChannel): name = "qq" display_name = "QQ" + @classmethod + def default_config(cls) -> dict[str, object]: + return QQConfig().model_dump(by_alias=True) + def __init__(self, config: QQConfig | QQInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: QQConfig | QQInstanceConfig = config @@ -75,8 +79,8 @@ class QQChannel(BaseChannel): return self._running = True - BotClass = _make_bot_class(self) - self._client = BotClass() + bot_class = _make_bot_class(self) + self._client = bot_class() logger.info("QQ bot started (C2C & Group supported)") await self._run_bot() diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index a13620e..c1e8d6d 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -2,7 +2,6 @@ import asyncio import re -from typing import Any from loguru import logger from slack_sdk.socket_mode.request import SocketModeRequest @@ -23,6 +22,10 @@ class SlackChannel(BaseChannel): name = "slack" display_name = "Slack" + @classmethod + def default_config(cls) -> dict[str, object]: + return SlackConfig().model_dump(by_alias=True) + def __init__(self, config: SlackConfig | SlackInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: SlackConfig | SlackInstanceConfig = config diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index b6f5433..594f23a 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -166,6 +166,10 @@ class TelegramChannel(BaseChannel): COMMAND_NAMES = ("start", "new", "lang", "persona", "skill", "stop", "help", "restart") + @classmethod + def default_config(cls) -> dict[str, object]: + return TelegramConfig().model_dump(by_alias=True) + def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: TelegramConfig | TelegramInstanceConfig = config diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 6c3e90b..73dcb9a 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -38,6 +38,10 @@ class WecomChannel(BaseChannel): name = "wecom" display_name = "WeCom" + @classmethod + def default_config(cls) -> dict[str, object]: + return WecomConfig().model_dump(by_alias=True) + def __init__(self, config: WecomConfig | WecomInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: WecomConfig | WecomInstanceConfig = config diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index 4360a9c..dc43b85 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -24,6 +24,10 @@ class WhatsAppChannel(BaseChannel): name = "whatsapp" display_name = "WhatsApp" + @classmethod + def default_config(cls) -> dict[str, object]: + return WhatsAppConfig().model_dump(by_alias=True) + def __init__(self, config: WhatsAppConfig | WhatsAppInstanceConfig, bus: MessageBus): super().__init__(config, bus) self.config: WhatsAppConfig | WhatsAppInstanceConfig = config diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 07b97a7..7ac8d0a 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1,11 +1,11 @@ """CLI commands for nanobot.""" import asyncio -from contextlib import contextmanager, nullcontext import os import select import signal import sys +from contextlib import contextmanager, nullcontext from pathlib import Path from typing import Any @@ -21,12 +21,11 @@ if sys.platform == "win32": pass import typer -from prompt_toolkit import print_formatted_text -from prompt_toolkit import PromptSession +from prompt_toolkit import PromptSession, print_formatted_text +from prompt_toolkit.application import run_in_terminal from prompt_toolkit.formatted_text import ANSI, HTML from prompt_toolkit.history import FileHistory from prompt_toolkit.patch_stdout import patch_stdout -from prompt_toolkit.application import run_in_terminal from rich.console import Console from rich.markdown import Markdown from rich.table import Table @@ -337,6 +336,30 @@ def _merge_missing_defaults(existing: Any, defaults: Any) -> Any: return merged +def _resolve_channel_default_config(channel_cls: Any) -> dict[str, Any] | None: + """Return a channel's default config if it exposes a valid onboarding payload.""" + from loguru import logger + + default_config = getattr(channel_cls, "default_config", None) + if not callable(default_config): + return None + try: + payload = default_config() + except Exception as exc: + logger.warning("Skipping channel default_config for {}: {}", channel_cls, exc) + return None + if payload is None: + return None + if not isinstance(payload, dict): + logger.warning( + "Skipping channel default_config for {}: expected dict, got {}", + channel_cls, + type(payload).__name__, + ) + return None + return payload + + def _onboard_plugins(config_path: Path) -> None: """Inject default config for all discovered channels (built-in + plugins).""" import json @@ -352,13 +375,13 @@ def _onboard_plugins(config_path: Path) -> None: channels = data.setdefault("channels", {}) for name, cls in all_channels.items(): - default_config = getattr(cls, "default_config", None) - if not callable(default_config): + payload = _resolve_channel_default_config(cls) + if payload is None: continue if name not in channels: - channels[name] = default_config() + channels[name] = payload else: - channels[name] = _merge_missing_defaults(channels[name], default_config()) + channels[name] = _merge_missing_defaults(channels[name], payload) with open(config_path, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) @@ -366,9 +389,9 @@ def _onboard_plugins(config_path: Path) -> None: def _make_provider(config: Config): """Create the appropriate LLM provider from config.""" + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.base import GenerationSettings from nanobot.providers.openai_codex_provider import OpenAICodexProvider - from nanobot.providers.azure_openai_provider import AzureOpenAIProvider model = config.agents.defaults.model provider_name = config.get_provider_name(model) diff --git a/tests/test_base_channel.py b/tests/test_base_channel.py index 5d10d4e..aeb7ea8 100644 --- a/tests/test_base_channel.py +++ b/tests/test_base_channel.py @@ -23,3 +23,7 @@ def test_is_allowed_requires_exact_match() -> None: assert channel.is_allowed("allow@email.com") is True assert channel.is_allowed("attacker|allow@email.com") is False + + +def test_default_config_returns_none_by_default() -> None: + assert _DummyChannel.default_config() is None diff --git a/tests/test_channel_default_config.py b/tests/test_channel_default_config.py new file mode 100644 index 0000000..31171fb --- /dev/null +++ b/tests/test_channel_default_config.py @@ -0,0 +1,9 @@ +from nanobot.channels.registry import discover_channel_names, load_channel_class + + +def test_builtin_channels_expose_default_config_dicts() -> None: + for module_name in sorted(discover_channel_names()): + channel_cls = load_channel_class(module_name) + payload = channel_cls.default_config() + assert isinstance(payload, dict), module_name + assert "enabled" in payload, module_name diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py index 2a446b7..86e08aa 100644 --- a/tests/test_config_migration.py +++ b/tests/test_config_migration.py @@ -1,9 +1,10 @@ import json from types import SimpleNamespace +import pytest from typer.testing import CliRunner -from nanobot.cli.commands import app +from nanobot.cli.commands import _resolve_channel_default_config, app from nanobot.config.loader import load_config, save_config runner = CliRunner() @@ -130,3 +131,66 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) assert result.exit_code == 0 saved = json.loads(config_path.read_text(encoding="utf-8")) assert saved["channels"]["qq"]["msgFormat"] == "plain" + + +@pytest.mark.parametrize( + ("channel_cls", "expected"), + [ + (SimpleNamespace(), None), + (SimpleNamespace(default_config="invalid"), None), + (SimpleNamespace(default_config=lambda: None), None), + (SimpleNamespace(default_config=lambda: ["invalid"]), None), + (SimpleNamespace(default_config=lambda: {"enabled": False}), {"enabled": False}), + ], +) +def test_resolve_channel_default_config_validates_payload(channel_cls, expected) -> None: + assert _resolve_channel_default_config(channel_cls) == expected + + +def test_resolve_channel_default_config_skips_exceptions() -> None: + def _raise() -> dict[str, object]: + raise RuntimeError("boom") + + assert _resolve_channel_default_config(SimpleNamespace(default_config=_raise)) is None + + +def test_onboard_refresh_skips_invalid_channel_default_configs(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + workspace = tmp_path / "workspace" + config_path.write_text(json.dumps({"channels": {}}), encoding="utf-8") + + def _raise() -> dict[str, object]: + raise RuntimeError("boom") + + monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) + monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: { + "missing": SimpleNamespace(), + "noncallable": SimpleNamespace(default_config="invalid"), + "none": SimpleNamespace(default_config=lambda: None), + "wrong_type": SimpleNamespace(default_config=lambda: ["invalid"]), + "raises": SimpleNamespace(default_config=_raise), + "qq": SimpleNamespace( + default_config=lambda: { + "enabled": False, + "appId": "", + "secret": "", + "allowFrom": [], + "msgFormat": "plain", + } + ), + }, + ) + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert "missing" not in saved["channels"] + assert "noncallable" not in saved["channels"] + assert "none" not in saved["channels"] + assert "wrong_type" not in saved["channels"] + assert "raises" not in saved["channels"] + assert saved["channels"]["qq"]["msgFormat"] == "plain"