diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 87a27a2..19c956d 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -7,6 +7,7 @@ import select import signal import sys from pathlib import Path +from typing import Any # Force UTF-8 encoding for Windows console if sys.platform == "win32": @@ -281,14 +282,15 @@ def onboard(): save_config(config) console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") else: - save_config(Config()) + config = Config() + save_config(config) console.print(f"[green]✓[/green] Created config at {config_path}") - console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]") - # Create workspace - workspace = get_workspace_path() + _onboard_plugins(config_path) + # Create workspace, preferring the configured workspace path. + workspace = get_workspace_path(config.workspace_path) if not workspace.exists(): workspace.mkdir(parents=True, exist_ok=True) console.print(f"[green]✓[/green] Created workspace at {workspace}") @@ -303,7 +305,45 @@ def onboard(): console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]") +def _merge_missing_defaults(existing: Any, defaults: Any) -> Any: + """Recursively fill in missing values from defaults without overwriting user config.""" + if not isinstance(existing, dict) or not isinstance(defaults, dict): + return existing + merged = dict(existing) + for key, value in defaults.items(): + if key not in merged: + merged[key] = value + else: + merged[key] = _merge_missing_defaults(merged[key], value) + return merged + + +def _onboard_plugins(config_path: Path) -> None: + """Inject default config for all discovered channels (built-in + plugins).""" + import json + + from nanobot.channels.registry import discover_all + + all_channels = discover_all() + if not all_channels: + return + + with open(config_path, encoding="utf-8") as f: + data = json.load(f) + + channels = data.setdefault("channels", {}) + for name, cls in all_channels.items(): + default_config = getattr(cls, "default_config", None) + if not callable(default_config): + continue + if name not in channels: + channels[name] = default_config() + else: + channels[name] = _merge_missing_defaults(channels[name], default_config()) + + with open(config_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) def _make_provider(config: Config): @@ -326,6 +366,7 @@ def _make_provider(config: Config): api_key=p.api_key if p else "no-key", api_base=config.get_api_base(model) or "http://localhost:8000/v1", default_model=model, + extra_headers=p.extra_headers if p else None, ) # Azure OpenAI: direct Azure OpenAI endpoint with deployment name elif provider_name == "azure_openai": diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py index f16c69b..e177e55 100644 --- a/nanobot/providers/custom_provider.py +++ b/nanobot/providers/custom_provider.py @@ -13,14 +13,25 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest class CustomProvider(LLMProvider): - def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"): + def __init__( + self, + api_key: str = "no-key", + api_base: str = "http://localhost:8000/v1", + default_model: str = "default", + extra_headers: dict[str, str] | None = None, + ): super().__init__(api_key, api_base) self.default_model = default_model - # Keep affinity stable for this provider instance to improve backend cache locality. + # Keep affinity stable for this provider instance to improve backend cache locality, + # while still letting users attach provider-specific headers for custom gateways. + default_headers = { + "x-session-affinity": uuid.uuid4().hex, + **(extra_headers or {}), + } self._client = AsyncOpenAI( api_key=api_key, base_url=api_base, - default_headers={"x-session-affinity": uuid.uuid4().hex}, + default_headers=default_headers, ) async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, diff --git a/tests/test_commands.py b/tests/test_commands.py index 8decf3e..ec4463d 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from typer.testing import CliRunner -from nanobot.cli.commands import app +from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix @@ -38,7 +38,7 @@ def mock_paths(): mock_ws.return_value = workspace_dir mock_sc.side_effect = lambda config: config_file.write_text("{}") - yield config_file, workspace_dir + yield config_file, workspace_dir, mock_ws if base_dir.exists(): shutil.rmtree(base_dir) @@ -46,7 +46,7 @@ def mock_paths(): def test_onboard_fresh_install(mock_paths): """No existing config — should create from scratch.""" - config_file, workspace_dir = mock_paths + config_file, workspace_dir, mock_ws = mock_paths result = runner.invoke(app, ["onboard"]) @@ -57,11 +57,13 @@ def test_onboard_fresh_install(mock_paths): assert config_file.exists() assert (workspace_dir / "AGENTS.md").exists() assert (workspace_dir / "memory" / "MEMORY.md").exists() + expected_workspace = Config().workspace_path + assert mock_ws.call_args.args == (expected_workspace,) def test_onboard_existing_config_refresh(mock_paths): """Config exists, user declines overwrite — should refresh (load-merge-save).""" - config_file, workspace_dir = mock_paths + config_file, workspace_dir, _ = mock_paths config_file.write_text('{"existing": true}') result = runner.invoke(app, ["onboard"], input="n\n") @@ -75,7 +77,7 @@ def test_onboard_existing_config_refresh(mock_paths): def test_onboard_existing_config_overwrite(mock_paths): """Config exists, user confirms overwrite — should reset to defaults.""" - config_file, workspace_dir = mock_paths + config_file, workspace_dir, _ = mock_paths config_file.write_text('{"existing": true}') result = runner.invoke(app, ["onboard"], input="y\n") @@ -88,7 +90,7 @@ def test_onboard_existing_config_overwrite(mock_paths): def test_onboard_existing_workspace_safe_create(mock_paths): """Workspace exists — should not recreate, but still add missing templates.""" - config_file, workspace_dir = mock_paths + config_file, workspace_dir, _ = mock_paths workspace_dir.mkdir(parents=True) config_file.write_text("{}") @@ -192,6 +194,33 @@ def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex" +def test_make_provider_passes_extra_headers_to_custom_provider(): + config = Config.model_validate( + { + "agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}}, + "providers": { + "custom": { + "apiKey": "test-key", + "apiBase": "https://example.com/v1", + "extraHeaders": { + "APP-Code": "demo-app", + "x-session-affinity": "sticky-session", + }, + } + }, + } + ) + + with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai: + _make_provider(config) + + kwargs = mock_async_openai.call_args.kwargs + assert kwargs["api_key"] == "test-key" + assert kwargs["base_url"] == "https://example.com/v1" + assert kwargs["default_headers"]["APP-Code"] == "demo-app" + assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session" + + @pytest.fixture def mock_agent_runtime(tmp_path): """Mock agent command dependencies for focused CLI tests.""" diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py index 62e601e..2a446b7 100644 --- a/tests/test_config_migration.py +++ b/tests/test_config_migration.py @@ -1,4 +1,5 @@ import json +from types import SimpleNamespace from typer.testing import CliRunner @@ -75,7 +76,7 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) ) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) - monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace) + monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) result = runner.invoke(app, ["onboard"], input="n\n") @@ -86,3 +87,46 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) assert defaults["maxTokens"] == 3333 assert defaults["contextWindowTokens"] == 65_536 assert "memoryWindow" not in defaults + + +def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + workspace = tmp_path / "workspace" + config_path.write_text( + json.dumps( + { + "channels": { + "qq": { + "enabled": False, + "appId": "", + "secret": "", + "allowFrom": [], + } + } + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) + monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: { + "qq": SimpleNamespace( + default_config=lambda: { + "enabled": False, + "appId": "", + "secret": "", + "allowFrom": [], + "msgFormat": "plain", + } + ) + }, + ) + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["channels"]["qq"]["msgFormat"] == "plain"