fix(provider): accept canonical and alias provider names consistently

This commit is contained in:
Xubin Ren
2026-03-24 03:03:59 +00:00
parent d454386f32
commit 14763a6ad1
3 changed files with 39 additions and 5 deletions

View File

@@ -165,12 +165,15 @@ class Config(BaseSettings):
self, model: str | None = None self, model: str | None = None
) -> tuple["ProviderConfig | None", str | None]: ) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name).""" """Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS, find_by_name
forced = self.agents.defaults.provider forced = self.agents.defaults.provider
if forced != "auto": if forced != "auto":
p = getattr(self.providers, forced, None) spec = find_by_name(forced)
return (p, forced) if p else (None, None) if spec:
p = getattr(self.providers, spec.name, None)
return (p, spec.name) if p else (None, None)
return None, None
model_lower = (model or self.agents.defaults.model).lower() model_lower = (model or self.agents.defaults.model).lower()
model_normalized = model_lower.replace("-", "_") model_normalized = model_lower.replace("-", "_")

View File

@@ -15,6 +15,8 @@ from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from pydantic.alias_generators import to_snake
@dataclass(frozen=True) @dataclass(frozen=True)
class ProviderSpec: class ProviderSpec:
@@ -545,7 +547,8 @@ def find_gateway(
def find_by_name(name: str) -> ProviderSpec | None: def find_by_name(name: str) -> ProviderSpec | None:
"""Find a provider spec by config field name, e.g. "dashscope".""" """Find a provider spec by config field name, e.g. "dashscope"."""
normalized = to_snake(name.replace("-", "_"))
for spec in PROVIDERS: for spec in PROVIDERS:
if spec.name == name: if spec.name == normalized:
return spec return spec
return None return None

View File

@@ -11,7 +11,7 @@ from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model from nanobot.providers.registry import find_by_model, find_by_name
runner = CliRunner() runner = CliRunner()
@@ -240,6 +240,34 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
assert config.get_api_base() == "http://localhost:11434" assert config.get_api_base() == "http://localhost:11434"
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
config = Config.model_validate(
{
"agents": {
"defaults": {
"provider": "volcengineCodingPlan",
"model": "doubao-1-5-pro",
}
},
"providers": {
"volcengineCodingPlan": {
"apiKey": "test-key",
}
},
}
)
assert config.get_provider_name() == "volcengine_coding_plan"
assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3"
def test_find_by_name_accepts_camel_case_and_hyphen_aliases():
assert find_by_name("volcengineCodingPlan") is not None
assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan"
assert find_by_name("github-copilot") is not None
assert find_by_name("github-copilot").name == "github_copilot"
def test_config_auto_detects_ollama_from_local_api_base(): def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate( config = Config.model_validate(
{ {