refactor(channels): formalize default config onboarding
This commit is contained in:
@@ -24,6 +24,11 @@ class BaseChannel(ABC):
|
||||
display_name: str = "Base"
|
||||
transcription_api_key: str = ""
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any] | None:
|
||||
"""Return the default config payload for onboarding, if the channel provides one."""
|
||||
return None
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
@@ -162,6 +162,10 @@ class DingTalkChannel(BaseChannel):
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return DingTalkConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: DingTalkConfig | DingTalkInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DingTalkConfig | DingTalkInstanceConfig = config
|
||||
@@ -262,9 +266,12 @@ class DingTalkChannel(BaseChannel):
|
||||
|
||||
def _guess_upload_type(self, media_ref: str) -> str:
|
||||
ext = Path(urlparse(media_ref).path).suffix.lower()
|
||||
if ext in self._IMAGE_EXTS: return "image"
|
||||
if ext in self._AUDIO_EXTS: return "voice"
|
||||
if ext in self._VIDEO_EXTS: return "video"
|
||||
if ext in self._IMAGE_EXTS:
|
||||
return "image"
|
||||
if ext in self._AUDIO_EXTS:
|
||||
return "voice"
|
||||
if ext in self._VIDEO_EXTS:
|
||||
return "video"
|
||||
return "file"
|
||||
|
||||
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
|
||||
@@ -385,8 +392,10 @@ class DingTalkChannel(BaseChannel):
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||
return False
|
||||
try: result = resp.json()
|
||||
except Exception: result = {}
|
||||
try:
|
||||
result = resp.json()
|
||||
except Exception:
|
||||
result = {}
|
||||
errcode = result.get("errcode")
|
||||
if errcode not in (None, 0):
|
||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||
|
||||
@@ -27,6 +27,10 @@ class DiscordChannel(BaseChannel):
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return DiscordConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: DiscordConfig | DiscordInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig | DiscordInstanceConfig = config
|
||||
|
||||
@@ -51,6 +51,10 @@ class EmailChannel(BaseChannel):
|
||||
"Dec",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return EmailConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: EmailConfig | EmailInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: EmailConfig | EmailInstanceConfig = config
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@@ -17,8 +18,6 @@ from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import FeishuConfig, FeishuInstanceConfig
|
||||
|
||||
import importlib.util
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
@@ -246,6 +245,10 @@ class FeishuChannel(BaseChannel):
|
||||
name = "feishu"
|
||||
display_name = "Feishu"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return FeishuConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: FeishuConfig | FeishuInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig | FeishuInstanceConfig = config
|
||||
@@ -314,8 +317,8 @@ class FeishuChannel(BaseChannel):
|
||||
# instead of the already-running main asyncio loop, which would cause
|
||||
# "This event loop is already running" errors.
|
||||
def run_ws():
|
||||
import time
|
||||
import lark_oapi.ws.client as _lark_ws_client
|
||||
|
||||
ws_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(ws_loop)
|
||||
# Patch the module-level loop used by lark's ws Client.start()
|
||||
@@ -375,7 +378,12 @@ class FeishuChannel(BaseChannel):
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
)
|
||||
|
||||
try:
|
||||
request = CreateMessageReactionRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
|
||||
@@ -219,6 +219,10 @@ class MochatChannel(BaseChannel):
|
||||
name = "mochat"
|
||||
display_name = "Mochat"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return MochatConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: MochatConfig | MochatInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig | MochatInstanceConfig = config
|
||||
|
||||
@@ -56,6 +56,10 @@ class QQChannel(BaseChannel):
|
||||
name = "qq"
|
||||
display_name = "QQ"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return QQConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: QQConfig | QQInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: QQConfig | QQInstanceConfig = config
|
||||
@@ -75,8 +79,8 @@ class QQChannel(BaseChannel):
|
||||
return
|
||||
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
bot_class = _make_bot_class(self)
|
||||
self._client = bot_class()
|
||||
logger.info("QQ bot started (C2C & Group supported)")
|
||||
await self._run_bot()
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -23,6 +22,10 @@ class SlackChannel(BaseChannel):
|
||||
name = "slack"
|
||||
display_name = "Slack"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return SlackConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: SlackConfig | SlackInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: SlackConfig | SlackInstanceConfig = config
|
||||
|
||||
@@ -166,6 +166,10 @@ class TelegramChannel(BaseChannel):
|
||||
|
||||
COMMAND_NAMES = ("start", "new", "lang", "persona", "skill", "stop", "help", "restart")
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return TelegramConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig | TelegramInstanceConfig = config
|
||||
|
||||
@@ -38,6 +38,10 @@ class WecomChannel(BaseChannel):
|
||||
name = "wecom"
|
||||
display_name = "WeCom"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return WecomConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: WecomConfig | WecomInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WecomConfig | WecomInstanceConfig = config
|
||||
|
||||
@@ -24,6 +24,10 @@ class WhatsAppChannel(BaseChannel):
|
||||
name = "whatsapp"
|
||||
display_name = "WhatsApp"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, object]:
|
||||
return WhatsAppConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: WhatsAppConfig | WhatsAppInstanceConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig | WhatsAppInstanceConfig = config
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -21,12 +21,11 @@ if sys.platform == "win32":
|
||||
pass
|
||||
|
||||
import typer
|
||||
from prompt_toolkit import print_formatted_text
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit import PromptSession, print_formatted_text
|
||||
from prompt_toolkit.application import run_in_terminal
|
||||
from prompt_toolkit.formatted_text import ANSI, HTML
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from prompt_toolkit.application import run_in_terminal
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.table import Table
|
||||
@@ -337,6 +336,30 @@ def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
|
||||
return merged
|
||||
|
||||
|
||||
def _resolve_channel_default_config(channel_cls: Any) -> dict[str, Any] | None:
|
||||
"""Return a channel's default config if it exposes a valid onboarding payload."""
|
||||
from loguru import logger
|
||||
|
||||
default_config = getattr(channel_cls, "default_config", None)
|
||||
if not callable(default_config):
|
||||
return None
|
||||
try:
|
||||
payload = default_config()
|
||||
except Exception as exc:
|
||||
logger.warning("Skipping channel default_config for {}: {}", channel_cls, exc)
|
||||
return None
|
||||
if payload is None:
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning(
|
||||
"Skipping channel default_config for {}: expected dict, got {}",
|
||||
channel_cls,
|
||||
type(payload).__name__,
|
||||
)
|
||||
return None
|
||||
return payload
|
||||
|
||||
|
||||
def _onboard_plugins(config_path: Path) -> None:
|
||||
"""Inject default config for all discovered channels (built-in + plugins)."""
|
||||
import json
|
||||
@@ -352,13 +375,13 @@ def _onboard_plugins(config_path: Path) -> None:
|
||||
|
||||
channels = data.setdefault("channels", {})
|
||||
for name, cls in all_channels.items():
|
||||
default_config = getattr(cls, "default_config", None)
|
||||
if not callable(default_config):
|
||||
payload = _resolve_channel_default_config(cls)
|
||||
if payload is None:
|
||||
continue
|
||||
if name not in channels:
|
||||
channels[name] = default_config()
|
||||
channels[name] = payload
|
||||
else:
|
||||
channels[name] = _merge_missing_defaults(channels[name], default_config())
|
||||
channels[name] = _merge_missing_defaults(channels[name], payload)
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
@@ -366,9 +389,9 @@ def _onboard_plugins(config_path: Path) -> None:
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config."""
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
|
||||
@@ -23,3 +23,7 @@ def test_is_allowed_requires_exact_match() -> None:
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
|
||||
|
||||
def test_default_config_returns_none_by_default() -> None:
|
||||
assert _DummyChannel.default_config() is None
|
||||
|
||||
9
tests/test_channel_default_config.py
Normal file
9
tests/test_channel_default_config.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
|
||||
|
||||
def test_builtin_channels_expose_default_config_dicts() -> None:
|
||||
for module_name in sorted(discover_channel_names()):
|
||||
channel_cls = load_channel_class(module_name)
|
||||
payload = channel_cls.default_config()
|
||||
assert isinstance(payload, dict), module_name
|
||||
assert "enabled" in payload, module_name
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.cli.commands import _resolve_channel_default_config, app
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
runner = CliRunner()
|
||||
@@ -130,3 +131,66 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
|
||||
assert result.exit_code == 0
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["channels"]["qq"]["msgFormat"] == "plain"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("channel_cls", "expected"),
|
||||
[
|
||||
(SimpleNamespace(), None),
|
||||
(SimpleNamespace(default_config="invalid"), None),
|
||||
(SimpleNamespace(default_config=lambda: None), None),
|
||||
(SimpleNamespace(default_config=lambda: ["invalid"]), None),
|
||||
(SimpleNamespace(default_config=lambda: {"enabled": False}), {"enabled": False}),
|
||||
],
|
||||
)
|
||||
def test_resolve_channel_default_config_validates_payload(channel_cls, expected) -> None:
|
||||
assert _resolve_channel_default_config(channel_cls) == expected
|
||||
|
||||
|
||||
def test_resolve_channel_default_config_skips_exceptions() -> None:
|
||||
def _raise() -> dict[str, object]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
assert _resolve_channel_default_config(SimpleNamespace(default_config=_raise)) is None
|
||||
|
||||
|
||||
def test_onboard_refresh_skips_invalid_channel_default_configs(tmp_path, monkeypatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(json.dumps({"channels": {}}), encoding="utf-8")
|
||||
|
||||
def _raise() -> dict[str, object]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {
|
||||
"missing": SimpleNamespace(),
|
||||
"noncallable": SimpleNamespace(default_config="invalid"),
|
||||
"none": SimpleNamespace(default_config=lambda: None),
|
||||
"wrong_type": SimpleNamespace(default_config=lambda: ["invalid"]),
|
||||
"raises": SimpleNamespace(default_config=_raise),
|
||||
"qq": SimpleNamespace(
|
||||
default_config=lambda: {
|
||||
"enabled": False,
|
||||
"appId": "",
|
||||
"secret": "",
|
||||
"allowFrom": [],
|
||||
"msgFormat": "plain",
|
||||
}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
assert "missing" not in saved["channels"]
|
||||
assert "noncallable" not in saved["channels"]
|
||||
assert "none" not in saved["channels"]
|
||||
assert "wrong_type" not in saved["channels"]
|
||||
assert "raises" not in saved["channels"]
|
||||
assert saved["channels"]["qq"]["msgFormat"] == "plain"
|
||||
|
||||
Reference in New Issue
Block a user