refactor(channels): formalize default config onboarding
Some checks failed
Test Suite / test (3.11) (push) Failing after 1m32s
Test Suite / test (3.12) (push) Failing after 1m19s
Test Suite / test (3.13) (push) Failing after 1m27s

This commit is contained in:
Hua
2026-03-17 15:12:15 +08:00
parent bae0332af3
commit d31d6cdbe6
15 changed files with 176 additions and 23 deletions

View File

@@ -24,6 +24,11 @@ class BaseChannel(ABC):
display_name: str = "Base"
transcription_api_key: str = ""
@classmethod
def default_config(cls) -> dict[str, Any] | None:
"""Return the default config payload for onboarding, if the channel provides one."""
return None
def __init__(self, config: Any, bus: MessageBus):
"""
Initialize the channel.

View File

@@ -162,6 +162,10 @@ class DingTalkChannel(BaseChannel):
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
@classmethod
def default_config(cls) -> dict[str, object]:
return DingTalkConfig().model_dump(by_alias=True)
def __init__(self, config: DingTalkConfig | DingTalkInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: DingTalkConfig | DingTalkInstanceConfig = config
@@ -262,9 +266,12 @@ class DingTalkChannel(BaseChannel):
def _guess_upload_type(self, media_ref: str) -> str:
ext = Path(urlparse(media_ref).path).suffix.lower()
if ext in self._IMAGE_EXTS: return "image"
if ext in self._AUDIO_EXTS: return "voice"
if ext in self._VIDEO_EXTS: return "video"
if ext in self._IMAGE_EXTS:
return "image"
if ext in self._AUDIO_EXTS:
return "voice"
if ext in self._VIDEO_EXTS:
return "video"
return "file"
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
@@ -385,8 +392,10 @@ class DingTalkChannel(BaseChannel):
if resp.status_code != 200:
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
return False
try: result = resp.json()
except Exception: result = {}
try:
result = resp.json()
except Exception:
result = {}
errcode = result.get("errcode")
if errcode not in (None, 0):
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])

View File

@@ -27,6 +27,10 @@ class DiscordChannel(BaseChannel):
name = "discord"
display_name = "Discord"
@classmethod
def default_config(cls) -> dict[str, object]:
return DiscordConfig().model_dump(by_alias=True)
def __init__(self, config: DiscordConfig | DiscordInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: DiscordConfig | DiscordInstanceConfig = config

View File

@@ -51,6 +51,10 @@ class EmailChannel(BaseChannel):
"Dec",
)
@classmethod
def default_config(cls) -> dict[str, object]:
return EmailConfig().model_dump(by_alias=True)
def __init__(self, config: EmailConfig | EmailInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: EmailConfig | EmailInstanceConfig = config

View File

@@ -1,12 +1,13 @@
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
import asyncio
import importlib.util
import json
import os
import re
import threading
import time
from collections import OrderedDict
from pathlib import Path
from typing import Any
from loguru import logger
@@ -17,8 +18,6 @@ from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import FeishuConfig, FeishuInstanceConfig
import importlib.util
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
# Message type display mapping
@@ -246,6 +245,10 @@ class FeishuChannel(BaseChannel):
name = "feishu"
display_name = "Feishu"
@classmethod
def default_config(cls) -> dict[str, object]:
return FeishuConfig().model_dump(by_alias=True)
def __init__(self, config: FeishuConfig | FeishuInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: FeishuConfig | FeishuInstanceConfig = config
@@ -314,8 +317,8 @@ class FeishuChannel(BaseChannel):
# instead of the already-running main asyncio loop, which would cause
# "This event loop is already running" errors.
def run_ws():
import time
import lark_oapi.ws.client as _lark_ws_client
ws_loop = asyncio.new_event_loop()
asyncio.set_event_loop(ws_loop)
# Patch the module-level loop used by lark's ws Client.start()
@@ -375,7 +378,12 @@ class FeishuChannel(BaseChannel):
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
from lark_oapi.api.im.v1 import (
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
)
try:
request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \

View File

@@ -219,6 +219,10 @@ class MochatChannel(BaseChannel):
name = "mochat"
display_name = "Mochat"
@classmethod
def default_config(cls) -> dict[str, object]:
return MochatConfig().model_dump(by_alias=True)
def __init__(self, config: MochatConfig | MochatInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: MochatConfig | MochatInstanceConfig = config

View File

@@ -56,6 +56,10 @@ class QQChannel(BaseChannel):
name = "qq"
display_name = "QQ"
@classmethod
def default_config(cls) -> dict[str, object]:
return QQConfig().model_dump(by_alias=True)
def __init__(self, config: QQConfig | QQInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: QQConfig | QQInstanceConfig = config
@@ -75,8 +79,8 @@ class QQChannel(BaseChannel):
return
self._running = True
BotClass = _make_bot_class(self)
self._client = BotClass()
bot_class = _make_bot_class(self)
self._client = bot_class()
logger.info("QQ bot started (C2C & Group supported)")
await self._run_bot()

View File

@@ -2,7 +2,6 @@
import asyncio
import re
from typing import Any
from loguru import logger
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -23,6 +22,10 @@ class SlackChannel(BaseChannel):
name = "slack"
display_name = "Slack"
@classmethod
def default_config(cls) -> dict[str, object]:
return SlackConfig().model_dump(by_alias=True)
def __init__(self, config: SlackConfig | SlackInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: SlackConfig | SlackInstanceConfig = config

View File

@@ -166,6 +166,10 @@ class TelegramChannel(BaseChannel):
COMMAND_NAMES = ("start", "new", "lang", "persona", "skill", "stop", "help", "restart")
@classmethod
def default_config(cls) -> dict[str, object]:
return TelegramConfig().model_dump(by_alias=True)
def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: TelegramConfig | TelegramInstanceConfig = config

View File

@@ -38,6 +38,10 @@ class WecomChannel(BaseChannel):
name = "wecom"
display_name = "WeCom"
@classmethod
def default_config(cls) -> dict[str, object]:
return WecomConfig().model_dump(by_alias=True)
def __init__(self, config: WecomConfig | WecomInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: WecomConfig | WecomInstanceConfig = config

View File

@@ -24,6 +24,10 @@ class WhatsAppChannel(BaseChannel):
name = "whatsapp"
display_name = "WhatsApp"
@classmethod
def default_config(cls) -> dict[str, object]:
return WhatsAppConfig().model_dump(by_alias=True)
def __init__(self, config: WhatsAppConfig | WhatsAppInstanceConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: WhatsAppConfig | WhatsAppInstanceConfig = config

View File

@@ -1,11 +1,11 @@
"""CLI commands for nanobot."""
import asyncio
from contextlib import contextmanager, nullcontext
import os
import select
import signal
import sys
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any
@@ -21,12 +21,11 @@ if sys.platform == "win32":
pass
import typer
from prompt_toolkit import print_formatted_text
from prompt_toolkit import PromptSession
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import run_in_terminal
from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
@@ -337,6 +336,30 @@ def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
return merged
def _resolve_channel_default_config(channel_cls: Any) -> dict[str, Any] | None:
"""Return a channel's default config if it exposes a valid onboarding payload."""
from loguru import logger
default_config = getattr(channel_cls, "default_config", None)
if not callable(default_config):
return None
try:
payload = default_config()
except Exception as exc:
logger.warning("Skipping channel default_config for {}: {}", channel_cls, exc)
return None
if payload is None:
return None
if not isinstance(payload, dict):
logger.warning(
"Skipping channel default_config for {}: expected dict, got {}",
channel_cls,
type(payload).__name__,
)
return None
return payload
def _onboard_plugins(config_path: Path) -> None:
"""Inject default config for all discovered channels (built-in + plugins)."""
import json
@@ -352,13 +375,13 @@ def _onboard_plugins(config_path: Path) -> None:
channels = data.setdefault("channels", {})
for name, cls in all_channels.items():
default_config = getattr(cls, "default_config", None)
if not callable(default_config):
payload = _resolve_channel_default_config(cls)
if payload is None:
continue
if name not in channels:
channels[name] = default_config()
channels[name] = payload
else:
channels[name] = _merge_missing_defaults(channels[name], default_config())
channels[name] = _merge_missing_defaults(channels[name], payload)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
@@ -366,9 +389,9 @@ def _onboard_plugins(config_path: Path) -> None:
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)

View File

@@ -23,3 +23,7 @@ def test_is_allowed_requires_exact_match() -> None:
assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False
def test_default_config_returns_none_by_default() -> None:
assert _DummyChannel.default_config() is None

View File

@@ -0,0 +1,9 @@
from nanobot.channels.registry import discover_channel_names, load_channel_class
def test_builtin_channels_expose_default_config_dicts() -> None:
for module_name in sorted(discover_channel_names()):
channel_cls = load_channel_class(module_name)
payload = channel_cls.default_config()
assert isinstance(payload, dict), module_name
assert "enabled" in payload, module_name

View File

@@ -1,9 +1,10 @@
import json
from types import SimpleNamespace
import pytest
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.cli.commands import _resolve_channel_default_config, app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
@@ -130,3 +131,66 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain"
@pytest.mark.parametrize(
("channel_cls", "expected"),
[
(SimpleNamespace(), None),
(SimpleNamespace(default_config="invalid"), None),
(SimpleNamespace(default_config=lambda: None), None),
(SimpleNamespace(default_config=lambda: ["invalid"]), None),
(SimpleNamespace(default_config=lambda: {"enabled": False}), {"enabled": False}),
],
)
def test_resolve_channel_default_config_validates_payload(channel_cls, expected) -> None:
assert _resolve_channel_default_config(channel_cls) == expected
def test_resolve_channel_default_config_skips_exceptions() -> None:
def _raise() -> dict[str, object]:
raise RuntimeError("boom")
assert _resolve_channel_default_config(SimpleNamespace(default_config=_raise)) is None
def test_onboard_refresh_skips_invalid_channel_default_configs(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(json.dumps({"channels": {}}), encoding="utf-8")
def _raise() -> dict[str, object]:
raise RuntimeError("boom")
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
"missing": SimpleNamespace(),
"noncallable": SimpleNamespace(default_config="invalid"),
"none": SimpleNamespace(default_config=lambda: None),
"wrong_type": SimpleNamespace(default_config=lambda: ["invalid"]),
"raises": SimpleNamespace(default_config=_raise),
"qq": SimpleNamespace(
default_config=lambda: {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
"msgFormat": "plain",
}
),
},
)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert "missing" not in saved["channels"]
assert "noncallable" not in saved["channels"]
assert "none" not in saved["channels"]
assert "wrong_type" not in saved["channels"]
assert "raises" not in saved["channels"]
assert saved["channels"]["qq"]["msgFormat"] == "plain"