refactor(channels): formalize default config onboarding
Some checks failed
Test Suite / test (3.11) (push) Failing after 1m32s
Test Suite / test (3.12) (push) Failing after 1m19s
Test Suite / test (3.13) (push) Failing after 1m27s

This commit is contained in:
Hua
2026-03-17 15:12:15 +08:00
parent bae0332af3
commit d31d6cdbe6
15 changed files with 176 additions and 23 deletions

View File

@@ -24,6 +24,11 @@ class BaseChannel(ABC):
display_name: str = "Base" display_name: str = "Base"
transcription_api_key: str = "" transcription_api_key: str = ""
@classmethod
def default_config(cls) -> dict[str, Any] | None:
"""Return the default config payload for onboarding, if the channel provides one."""
return None
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """
Initialize the channel. Initialize the channel.

View File

@@ -162,6 +162,10 @@ class DingTalkChannel(BaseChannel):
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
@classmethod
def default_config(cls) -> dict[str, object]:
return DingTalkConfig().model_dump(by_alias=True)
def __init__(self, config: DingTalkConfig | DingTalkInstanceConfig, bus: MessageBus): def __init__(self, config: DingTalkConfig | DingTalkInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: DingTalkConfig | DingTalkInstanceConfig = config self.config: DingTalkConfig | DingTalkInstanceConfig = config
@@ -262,9 +266,12 @@ class DingTalkChannel(BaseChannel):
def _guess_upload_type(self, media_ref: str) -> str: def _guess_upload_type(self, media_ref: str) -> str:
ext = Path(urlparse(media_ref).path).suffix.lower() ext = Path(urlparse(media_ref).path).suffix.lower()
if ext in self._IMAGE_EXTS: return "image" if ext in self._IMAGE_EXTS:
if ext in self._AUDIO_EXTS: return "voice" return "image"
if ext in self._VIDEO_EXTS: return "video" if ext in self._AUDIO_EXTS:
return "voice"
if ext in self._VIDEO_EXTS:
return "video"
return "file" return "file"
def _guess_filename(self, media_ref: str, upload_type: str) -> str: def _guess_filename(self, media_ref: str, upload_type: str) -> str:
@@ -385,8 +392,10 @@ class DingTalkChannel(BaseChannel):
if resp.status_code != 200: if resp.status_code != 200:
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500]) logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
return False return False
try: result = resp.json() try:
except Exception: result = {} result = resp.json()
except Exception:
result = {}
errcode = result.get("errcode") errcode = result.get("errcode")
if errcode not in (None, 0): if errcode not in (None, 0):
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500]) logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])

View File

@@ -27,6 +27,10 @@ class DiscordChannel(BaseChannel):
name = "discord" name = "discord"
display_name = "Discord" display_name = "Discord"
@classmethod
def default_config(cls) -> dict[str, object]:
return DiscordConfig().model_dump(by_alias=True)
def __init__(self, config: DiscordConfig | DiscordInstanceConfig, bus: MessageBus): def __init__(self, config: DiscordConfig | DiscordInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: DiscordConfig | DiscordInstanceConfig = config self.config: DiscordConfig | DiscordInstanceConfig = config

View File

@@ -51,6 +51,10 @@ class EmailChannel(BaseChannel):
"Dec", "Dec",
) )
@classmethod
def default_config(cls) -> dict[str, object]:
return EmailConfig().model_dump(by_alias=True)
def __init__(self, config: EmailConfig | EmailInstanceConfig, bus: MessageBus): def __init__(self, config: EmailConfig | EmailInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: EmailConfig | EmailInstanceConfig = config self.config: EmailConfig | EmailInstanceConfig = config

View File

@@ -1,12 +1,13 @@
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection.""" """Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
import asyncio import asyncio
import importlib.util
import json import json
import os import os
import re import re
import threading import threading
import time
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -17,8 +18,6 @@ from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir from nanobot.config.paths import get_media_dir
from nanobot.config.schema import FeishuConfig, FeishuInstanceConfig from nanobot.config.schema import FeishuConfig, FeishuInstanceConfig
import importlib.util
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
# Message type display mapping # Message type display mapping
@@ -246,6 +245,10 @@ class FeishuChannel(BaseChannel):
name = "feishu" name = "feishu"
display_name = "Feishu" display_name = "Feishu"
@classmethod
def default_config(cls) -> dict[str, object]:
return FeishuConfig().model_dump(by_alias=True)
def __init__(self, config: FeishuConfig | FeishuInstanceConfig, bus: MessageBus): def __init__(self, config: FeishuConfig | FeishuInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: FeishuConfig | FeishuInstanceConfig = config self.config: FeishuConfig | FeishuInstanceConfig = config
@@ -314,8 +317,8 @@ class FeishuChannel(BaseChannel):
# instead of the already-running main asyncio loop, which would cause # instead of the already-running main asyncio loop, which would cause
# "This event loop is already running" errors. # "This event loop is already running" errors.
def run_ws(): def run_ws():
import time
import lark_oapi.ws.client as _lark_ws_client import lark_oapi.ws.client as _lark_ws_client
ws_loop = asyncio.new_event_loop() ws_loop = asyncio.new_event_loop()
asyncio.set_event_loop(ws_loop) asyncio.set_event_loop(ws_loop)
# Patch the module-level loop used by lark's ws Client.start() # Patch the module-level loop used by lark's ws Client.start()
@@ -375,7 +378,12 @@ class FeishuChannel(BaseChannel):
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool).""" """Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji from lark_oapi.api.im.v1 import (
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
)
try: try:
request = CreateMessageReactionRequest.builder() \ request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \ .message_id(message_id) \

View File

@@ -219,6 +219,10 @@ class MochatChannel(BaseChannel):
name = "mochat" name = "mochat"
display_name = "Mochat" display_name = "Mochat"
@classmethod
def default_config(cls) -> dict[str, object]:
return MochatConfig().model_dump(by_alias=True)
def __init__(self, config: MochatConfig | MochatInstanceConfig, bus: MessageBus): def __init__(self, config: MochatConfig | MochatInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: MochatConfig | MochatInstanceConfig = config self.config: MochatConfig | MochatInstanceConfig = config

View File

@@ -56,6 +56,10 @@ class QQChannel(BaseChannel):
name = "qq" name = "qq"
display_name = "QQ" display_name = "QQ"
@classmethod
def default_config(cls) -> dict[str, object]:
return QQConfig().model_dump(by_alias=True)
def __init__(self, config: QQConfig | QQInstanceConfig, bus: MessageBus): def __init__(self, config: QQConfig | QQInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: QQConfig | QQInstanceConfig = config self.config: QQConfig | QQInstanceConfig = config
@@ -75,8 +79,8 @@ class QQChannel(BaseChannel):
return return
self._running = True self._running = True
BotClass = _make_bot_class(self) bot_class = _make_bot_class(self)
self._client = BotClass() self._client = bot_class()
logger.info("QQ bot started (C2C & Group supported)") logger.info("QQ bot started (C2C & Group supported)")
await self._run_bot() await self._run_bot()

View File

@@ -2,7 +2,6 @@
import asyncio import asyncio
import re import re
from typing import Any
from loguru import logger from loguru import logger
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
@@ -23,6 +22,10 @@ class SlackChannel(BaseChannel):
name = "slack" name = "slack"
display_name = "Slack" display_name = "Slack"
@classmethod
def default_config(cls) -> dict[str, object]:
return SlackConfig().model_dump(by_alias=True)
def __init__(self, config: SlackConfig | SlackInstanceConfig, bus: MessageBus): def __init__(self, config: SlackConfig | SlackInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: SlackConfig | SlackInstanceConfig = config self.config: SlackConfig | SlackInstanceConfig = config

View File

@@ -166,6 +166,10 @@ class TelegramChannel(BaseChannel):
COMMAND_NAMES = ("start", "new", "lang", "persona", "skill", "stop", "help", "restart") COMMAND_NAMES = ("start", "new", "lang", "persona", "skill", "stop", "help", "restart")
@classmethod
def default_config(cls) -> dict[str, object]:
return TelegramConfig().model_dump(by_alias=True)
def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus): def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: TelegramConfig | TelegramInstanceConfig = config self.config: TelegramConfig | TelegramInstanceConfig = config

View File

@@ -38,6 +38,10 @@ class WecomChannel(BaseChannel):
name = "wecom" name = "wecom"
display_name = "WeCom" display_name = "WeCom"
@classmethod
def default_config(cls) -> dict[str, object]:
return WecomConfig().model_dump(by_alias=True)
def __init__(self, config: WecomConfig | WecomInstanceConfig, bus: MessageBus): def __init__(self, config: WecomConfig | WecomInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: WecomConfig | WecomInstanceConfig = config self.config: WecomConfig | WecomInstanceConfig = config

View File

@@ -24,6 +24,10 @@ class WhatsAppChannel(BaseChannel):
name = "whatsapp" name = "whatsapp"
display_name = "WhatsApp" display_name = "WhatsApp"
@classmethod
def default_config(cls) -> dict[str, object]:
return WhatsAppConfig().model_dump(by_alias=True)
def __init__(self, config: WhatsAppConfig | WhatsAppInstanceConfig, bus: MessageBus): def __init__(self, config: WhatsAppConfig | WhatsAppInstanceConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: WhatsAppConfig | WhatsAppInstanceConfig = config self.config: WhatsAppConfig | WhatsAppInstanceConfig = config

View File

@@ -1,11 +1,11 @@
"""CLI commands for nanobot.""" """CLI commands for nanobot."""
import asyncio import asyncio
from contextlib import contextmanager, nullcontext
import os import os
import select import select
import signal import signal
import sys import sys
from contextlib import contextmanager, nullcontext
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -21,12 +21,11 @@ if sys.platform == "win32":
pass pass
import typer import typer
from prompt_toolkit import print_formatted_text from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit import PromptSession from prompt_toolkit.application import run_in_terminal
from prompt_toolkit.formatted_text import ANSI, HTML from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
@@ -337,6 +336,30 @@ def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
return merged return merged
def _resolve_channel_default_config(channel_cls: Any) -> dict[str, Any] | None:
"""Return a channel's default config if it exposes a valid onboarding payload."""
from loguru import logger
default_config = getattr(channel_cls, "default_config", None)
if not callable(default_config):
return None
try:
payload = default_config()
except Exception as exc:
logger.warning("Skipping channel default_config for {}: {}", channel_cls, exc)
return None
if payload is None:
return None
if not isinstance(payload, dict):
logger.warning(
"Skipping channel default_config for {}: expected dict, got {}",
channel_cls,
type(payload).__name__,
)
return None
return payload
def _onboard_plugins(config_path: Path) -> None: def _onboard_plugins(config_path: Path) -> None:
"""Inject default config for all discovered channels (built-in + plugins).""" """Inject default config for all discovered channels (built-in + plugins)."""
import json import json
@@ -352,13 +375,13 @@ def _onboard_plugins(config_path: Path) -> None:
channels = data.setdefault("channels", {}) channels = data.setdefault("channels", {})
for name, cls in all_channels.items(): for name, cls in all_channels.items():
default_config = getattr(cls, "default_config", None) payload = _resolve_channel_default_config(cls)
if not callable(default_config): if payload is None:
continue continue
if name not in channels: if name not in channels:
channels[name] = default_config() channels[name] = payload
else: else:
channels[name] = _merge_missing_defaults(channels[name], default_config()) channels[name] = _merge_missing_defaults(channels[name], payload)
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)
@@ -366,9 +389,9 @@ def _onboard_plugins(config_path: Path) -> None:
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import GenerationSettings from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)

View File

@@ -23,3 +23,7 @@ def test_is_allowed_requires_exact_match() -> None:
assert channel.is_allowed("allow@email.com") is True assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False assert channel.is_allowed("attacker|allow@email.com") is False
def test_default_config_returns_none_by_default() -> None:
assert _DummyChannel.default_config() is None

View File

@@ -0,0 +1,9 @@
from nanobot.channels.registry import discover_channel_names, load_channel_class
def test_builtin_channels_expose_default_config_dicts() -> None:
for module_name in sorted(discover_channel_names()):
channel_cls = load_channel_class(module_name)
payload = channel_cls.default_config()
assert isinstance(payload, dict), module_name
assert "enabled" in payload, module_name

View File

@@ -1,9 +1,10 @@
import json import json
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from typer.testing import CliRunner from typer.testing import CliRunner
from nanobot.cli.commands import app from nanobot.cli.commands import _resolve_channel_default_config, app
from nanobot.config.loader import load_config, save_config from nanobot.config.loader import load_config, save_config
runner = CliRunner() runner = CliRunner()
@@ -130,3 +131,66 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
assert result.exit_code == 0 assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8")) saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain" assert saved["channels"]["qq"]["msgFormat"] == "plain"
@pytest.mark.parametrize(
("channel_cls", "expected"),
[
(SimpleNamespace(), None),
(SimpleNamespace(default_config="invalid"), None),
(SimpleNamespace(default_config=lambda: None), None),
(SimpleNamespace(default_config=lambda: ["invalid"]), None),
(SimpleNamespace(default_config=lambda: {"enabled": False}), {"enabled": False}),
],
)
def test_resolve_channel_default_config_validates_payload(channel_cls, expected) -> None:
assert _resolve_channel_default_config(channel_cls) == expected
def test_resolve_channel_default_config_skips_exceptions() -> None:
def _raise() -> dict[str, object]:
raise RuntimeError("boom")
assert _resolve_channel_default_config(SimpleNamespace(default_config=_raise)) is None
def test_onboard_refresh_skips_invalid_channel_default_configs(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(json.dumps({"channels": {}}), encoding="utf-8")
def _raise() -> dict[str, object]:
raise RuntimeError("boom")
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
"missing": SimpleNamespace(),
"noncallable": SimpleNamespace(default_config="invalid"),
"none": SimpleNamespace(default_config=lambda: None),
"wrong_type": SimpleNamespace(default_config=lambda: ["invalid"]),
"raises": SimpleNamespace(default_config=_raise),
"qq": SimpleNamespace(
default_config=lambda: {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
"msgFormat": "plain",
}
),
},
)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert "missing" not in saved["channels"]
assert "noncallable" not in saved["channels"]
assert "none" not in saved["channels"]
assert "wrong_type" not in saved["channels"]
assert "raises" not in saved["channels"]
assert saved["channels"]["qq"]["msgFormat"] == "plain"