fix(provider): accept canonical and alias provider names consistently
This commit is contained in:
@@ -165,12 +165,15 @@ class Config(BaseSettings):
|
|||||||
self, model: str | None = None
|
self, model: str | None = None
|
||||||
) -> tuple["ProviderConfig | None", str | None]:
|
) -> tuple["ProviderConfig | None", str | None]:
|
||||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||||
|
|
||||||
forced = self.agents.defaults.provider
|
forced = self.agents.defaults.provider
|
||||||
if forced != "auto":
|
if forced != "auto":
|
||||||
p = getattr(self.providers, forced, None)
|
spec = find_by_name(forced)
|
||||||
return (p, forced) if p else (None, None)
|
if spec:
|
||||||
|
p = getattr(self.providers, spec.name, None)
|
||||||
|
return (p, spec.name) if p else (None, None)
|
||||||
|
return None, None
|
||||||
|
|
||||||
model_lower = (model or self.agents.defaults.model).lower()
|
model_lower = (model or self.agents.defaults.model).lower()
|
||||||
model_normalized = model_lower.replace("-", "_")
|
model_normalized = model_lower.replace("-", "_")
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic.alias_generators import to_snake
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ProviderSpec:
|
class ProviderSpec:
|
||||||
@@ -545,7 +547,8 @@ def find_gateway(
|
|||||||
|
|
||||||
def find_by_name(name: str) -> ProviderSpec | None:
|
def find_by_name(name: str) -> ProviderSpec | None:
|
||||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||||
|
normalized = to_snake(name.replace("-", "_"))
|
||||||
for spec in PROVIDERS:
|
for spec in PROVIDERS:
|
||||||
if spec.name == name:
|
if spec.name == normalized:
|
||||||
return spec
|
return spec
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from nanobot.cli.commands import _make_provider, app
|
|||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||||
from nanobot.providers.registry import find_by_model
|
from nanobot.providers.registry import find_by_model, find_by_name
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
|
|
||||||
@@ -240,6 +240,34 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
|||||||
assert config.get_api_base() == "http://localhost:11434"
|
assert config.get_api_base() == "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
||||||
|
config = Config.model_validate(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"provider": "volcengineCodingPlan",
|
||||||
|
"model": "doubao-1-5-pro",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"providers": {
|
||||||
|
"volcengineCodingPlan": {
|
||||||
|
"apiKey": "test-key",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "volcengine_coding_plan"
|
||||||
|
assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_by_name_accepts_camel_case_and_hyphen_aliases():
|
||||||
|
assert find_by_name("volcengineCodingPlan") is not None
|
||||||
|
assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan"
|
||||||
|
assert find_by_name("github-copilot") is not None
|
||||||
|
assert find_by_name("github-copilot").name == "github_copilot"
|
||||||
|
|
||||||
|
|
||||||
def test_config_auto_detects_ollama_from_local_api_base():
|
def test_config_auto_detects_ollama_from_local_api_base():
|
||||||
config = Config.model_validate(
|
config = Config.model_validate(
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user