Merge remote-tracking branch 'upstream/main' into feature/codex-oauth
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@@ -7,6 +8,7 @@ import litellm
|
||||
from litellm import acompletion
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.registry import find_by_model, find_gateway
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
@@ -14,7 +16,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
LLM provider using LiteLLM for multi-provider support.
|
||||
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, and many other providers through
|
||||
a unified interface.
|
||||
a unified interface. Provider-specific logic is driven by the registry
|
||||
(see providers/registry.py) — no if-elif chains needed here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -28,47 +31,17 @@ class LiteLLMProvider(LLMProvider):
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Detect OpenRouter by api_key prefix or explicit api_base
|
||||
self.is_openrouter = (
|
||||
(api_key and api_key.startswith("sk-or-")) or
|
||||
(api_base and "openrouter" in api_base)
|
||||
)
|
||||
# Detect gateway / local deployment from api_key and api_base
|
||||
self._gateway = find_gateway(api_key, api_base)
|
||||
|
||||
# Detect AiHubMix by api_base
|
||||
self.is_aihubmix = bool(api_base and "aihubmix" in api_base)
|
||||
# Backwards-compatible flags (used by tests and possibly external code)
|
||||
self.is_openrouter = bool(self._gateway and self._gateway.name == "openrouter")
|
||||
self.is_aihubmix = bool(self._gateway and self._gateway.name == "aihubmix")
|
||||
self.is_vllm = bool(self._gateway and self._gateway.is_local)
|
||||
|
||||
# Track if using custom endpoint (vLLM, etc.)
|
||||
self.is_vllm = bool(api_base) and not self.is_openrouter and not self.is_aihubmix
|
||||
|
||||
# Configure LiteLLM based on provider
|
||||
# Configure environment variables
|
||||
if api_key:
|
||||
if self.is_openrouter:
|
||||
# OpenRouter mode - set key
|
||||
os.environ["OPENROUTER_API_KEY"] = api_key
|
||||
elif self.is_aihubmix:
|
||||
# AiHubMix gateway - OpenAI-compatible
|
||||
os.environ["OPENAI_API_KEY"] = api_key
|
||||
elif self.is_vllm:
|
||||
# vLLM/custom endpoint - uses OpenAI-compatible API
|
||||
os.environ["HOSTED_VLLM_API_KEY"] = api_key
|
||||
elif "deepseek" in default_model:
|
||||
os.environ.setdefault("DEEPSEEK_API_KEY", api_key)
|
||||
elif "anthropic" in default_model:
|
||||
os.environ.setdefault("ANTHROPIC_API_KEY", api_key)
|
||||
elif "openai" in default_model or "gpt" in default_model:
|
||||
os.environ.setdefault("OPENAI_API_KEY", api_key)
|
||||
elif "gemini" in default_model.lower():
|
||||
os.environ.setdefault("GEMINI_API_KEY", api_key)
|
||||
elif "zhipu" in default_model or "glm" in default_model or "zai" in default_model:
|
||||
os.environ.setdefault("ZAI_API_KEY", api_key)
|
||||
os.environ.setdefault("ZHIPUAI_API_KEY", api_key)
|
||||
elif "dashscope" in default_model or "qwen" in default_model.lower():
|
||||
os.environ.setdefault("DASHSCOPE_API_KEY", api_key)
|
||||
elif "groq" in default_model:
|
||||
os.environ.setdefault("GROQ_API_KEY", api_key)
|
||||
elif "moonshot" in default_model or "kimi" in default_model:
|
||||
os.environ.setdefault("MOONSHOT_API_KEY", api_key)
|
||||
os.environ.setdefault("MOONSHOT_API_BASE", api_base or "https://api.moonshot.cn/v1")
|
||||
self._setup_env(api_key, api_base, default_model)
|
||||
|
||||
if api_base:
|
||||
litellm.api_base = api_base
|
||||
@@ -76,6 +49,55 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Disable LiteLLM logging noise
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||
"""Set environment variables based on detected provider."""
|
||||
if self._gateway:
|
||||
# Gateway / local: direct set (not setdefault)
|
||||
os.environ[self._gateway.env_key] = api_key
|
||||
return
|
||||
|
||||
# Standard provider: match by model name
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""Resolve model name by applying provider/gateway prefixes."""
|
||||
if self._gateway:
|
||||
# Gateway mode: apply gateway prefix, skip provider-specific prefixes
|
||||
prefix = self._gateway.litellm_prefix
|
||||
if self._gateway.strip_model_prefix:
|
||||
model = model.split("/")[-1]
|
||||
if prefix and not model.startswith(f"{prefix}/"):
|
||||
model = f"{prefix}/{model}"
|
||||
return model
|
||||
|
||||
# Standard mode: auto-prefix for known providers
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
|
||||
return model
|
||||
|
||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||
"""Apply model-specific parameter overrides from the registry."""
|
||||
model_lower = model.lower()
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -97,35 +119,8 @@ class LiteLLMProvider(LLMProvider):
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
model = model or self.default_model
|
||||
model = self._resolve_model(model or self.default_model)
|
||||
|
||||
# Auto-prefix model names for known providers
|
||||
# (keywords, target_prefix, skip_if_starts_with)
|
||||
_prefix_rules = [
|
||||
(("glm", "zhipu"), "zai", ("zhipu/", "zai/", "openrouter/", "hosted_vllm/")),
|
||||
(("qwen", "dashscope"), "dashscope", ("dashscope/", "openrouter/")),
|
||||
(("moonshot", "kimi"), "moonshot", ("moonshot/", "openrouter/")),
|
||||
(("gemini",), "gemini", ("gemini/",)),
|
||||
]
|
||||
if not (self.is_vllm or self.is_openrouter or self.is_aihubmix):
|
||||
model_lower = model.lower()
|
||||
for keywords, prefix, skip in _prefix_rules:
|
||||
if any(kw in model_lower for kw in keywords) and not any(model.startswith(s) for s in skip):
|
||||
model = f"{prefix}/{model}"
|
||||
break
|
||||
|
||||
# Gateway/endpoint-specific prefixes (detected by api_base/api_key, not model name)
|
||||
if self.is_openrouter and not model.startswith("openrouter/"):
|
||||
model = f"openrouter/{model}"
|
||||
elif self.is_aihubmix:
|
||||
model = f"openai/{model.split('/')[-1]}"
|
||||
elif self.is_vllm:
|
||||
model = f"hosted_vllm/{model}"
|
||||
|
||||
# kimi-k2.5 only supports temperature=1.0
|
||||
if "kimi-k2.5" in model.lower():
|
||||
temperature = 1.0
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -133,6 +128,9 @@ class LiteLLMProvider(LLMProvider):
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
# Pass api_base directly for custom endpoints (vLLM, etc.)
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
@@ -166,7 +164,6 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
import json
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
|
||||
323
nanobot/providers/registry.py
Normal file
323
nanobot/providers/registry.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Provider Registry — single source of truth for LLM provider metadata.
|
||||
|
||||
Adding a new provider:
|
||||
1. Add a ProviderSpec to PROVIDERS below.
|
||||
2. Add a field to ProvidersConfig in config/schema.py.
|
||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
||||
|
||||
Order matters — it controls match priority and fallback. Gateways first.
|
||||
Every entry writes out all fields so you can copy-paste as a template.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSpec:
|
||||
"""One LLM provider's metadata. See PROVIDERS below for real examples.
|
||||
|
||||
Placeholders in env_extras values:
|
||||
{api_key} — the user's API key
|
||||
{api_base} — api_base from config, or this spec's default_api_base
|
||||
"""
|
||||
|
||||
# identity
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# model prefixing
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
|
||||
# gateway / local detection
|
||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
|
||||
# gateway behavior
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
|
||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PROVIDERS — the registry. Order = priority. Copy any entry as template.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
|
||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(
|
||||
("ZHIPUAI_API_KEY", "{api_key}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(
|
||||
("MOONSHOT_API_BASE", "{api_base}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(
|
||||
("kimi-k2.5", {"temperature": 1.0}),
|
||||
),
|
||||
),
|
||||
|
||||
# === Local deployment (fallback: unknown api_base → assume local) ======
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# If api_base is set but doesn't match a known gateway, we land here.
|
||||
# Placed before Groq so vLLM wins the fallback when both are configured.
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="", # user must provide in config
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||
model_lower = model.lower()
|
||||
for spec in PROVIDERS:
|
||||
if spec.is_gateway or spec.is_local:
|
||||
continue
|
||||
if any(kw in model_lower for kw in spec.keywords):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(api_key: str | None, api_base: str | None) -> ProviderSpec | None:
|
||||
"""Detect gateway/local by api_key prefix or api_base substring.
|
||||
Fallback: unknown api_base → treat as local (vLLM)."""
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
if api_base:
|
||||
return next((s for s in PROVIDERS if s.is_local), None)
|
||||
return None
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||
for spec in PROVIDERS:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
Reference in New Issue
Block a user