diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 7d8f5c8..b31f306 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -165,12 +165,15 @@ class Config(BaseSettings): self, model: str | None = None ) -> tuple["ProviderConfig | None", str | None]: """Match provider config and its registry name. Returns (config, spec_name).""" - from nanobot.providers.registry import PROVIDERS + from nanobot.providers.registry import PROVIDERS, find_by_name forced = self.agents.defaults.provider if forced != "auto": - p = getattr(self.providers, forced, None) - return (p, forced) if p else (None, None) + spec = find_by_name(forced) + if spec: + p = getattr(self.providers, spec.name, None) + return (p, spec.name) if p else (None, None) + return None, None model_lower = (model or self.agents.defaults.model).lower() model_normalized = model_lower.replace("-", "_") diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 9cc430b..10e0fec 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -15,6 +15,8 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any +from pydantic.alias_generators import to_snake + @dataclass(frozen=True) class ProviderSpec: @@ -545,7 +547,8 @@ def find_gateway( def find_by_name(name: str) -> ProviderSpec | None: """Find a provider spec by config field name, e.g. "dashscope".""" + normalized = to_snake(name.replace("-", "_")) for spec in PROVIDERS: - if spec.name == name: + if spec.name == normalized: return spec return None diff --git a/tests/test_commands.py b/tests/test_commands.py index 68cc429..4e79fc7 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,7 +11,7 @@ from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix -from nanobot.providers.registry import find_by_model +from nanobot.providers.registry import find_by_model, find_by_name runner = CliRunner() @@ -240,6 +240,34 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): assert config.get_api_base() == "http://localhost:11434" +def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan(): + config = Config.model_validate( + { + "agents": { + "defaults": { + "provider": "volcengineCodingPlan", + "model": "doubao-1-5-pro", + } + }, + "providers": { + "volcengineCodingPlan": { + "apiKey": "test-key", + } + }, + } + ) + + assert config.get_provider_name() == "volcengine_coding_plan" + assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3" + + +def test_find_by_name_accepts_camel_case_and_hyphen_aliases(): + assert find_by_name("volcengineCodingPlan") is not None + assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan" + assert find_by_name("github-copilot") is not None + assert find_by_name("github-copilot").name == "github_copilot" + + def test_config_auto_detects_ollama_from_local_api_base(): config = Config.model_validate( {