Merge remote-tracking branch 'origin/main' into pr-1109

Resolve conflict in context.py: keep main's build_messages which already
merges runtime context into user message (achieving the same cache goal).
The real value-add from this PR is the second cache breakpoint in
litellm_provider.py.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-22 06:14:18 +00:00
119 changed files with 17320 additions and 2647 deletions

View File

@@ -1,7 +1,30 @@
"""LLM provider abstraction module."""
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from __future__ import annotations
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
from importlib import import_module
from typing import TYPE_CHECKING
from nanobot.providers.base import LLMProvider, LLMResponse
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
_LAZY_IMPORTS = {
"LiteLLMProvider": ".litellm_provider",
"OpenAICodexProvider": ".openai_codex_provider",
"AzureOpenAIProvider": ".azure_openai_provider",
}
if TYPE_CHECKING:
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
def __getattr__(name: str):
"""Lazily expose provider implementations without importing all backends up front."""
module_name = _LAZY_IMPORTS.get(name)
if module_name is None:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module = import_module(module_name, __name__)
return getattr(module, name)

View File

@@ -0,0 +1,213 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = tool_choice or "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
tool_choice=tool_choice,
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@@ -1,9 +1,13 @@
"""Base LLM provider interface."""
import asyncio
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
@dataclass
class ToolCallRequest:
@@ -11,6 +15,24 @@ class ToolCallRequest:
id: str
name: str
arguments: dict[str, Any]
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.provider_specific_fields:
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
return tool_call
@dataclass
@@ -21,6 +43,7 @@ class LLMResponse:
finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@property
def has_tool_calls(self) -> bool:
@@ -28,6 +51,21 @@ class LLMResponse:
return len(self.tool_calls) > 0
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
@@ -35,18 +73,33 @@ class LLMProvider(ABC):
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
_CHAT_RETRY_DELAYS = (1, 2, 4)
_TRANSIENT_ERROR_MARKERS = (
"429",
"rate limit",
"500",
"502",
"503",
"504",
"overloaded",
"timeout",
"timed out",
"connection",
"server error",
"temporarily unavailable",
)
_SENTINEL = object()
def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key
self.api_base = api_base
self.generation: GenerationSettings = GenerationSettings()
@staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Replace empty text content that causes provider 400 errors.
Empty content can appear when MCP tools return nothing. Most providers
reject empty-string content or empty text blocks in list content.
"""
"""Sanitize message content: fix empty blocks, strip internal _meta fields."""
result: list[dict[str, Any]] = []
for msg in messages:
content = msg.get("content")
@@ -58,18 +111,25 @@ class LLMProvider(ABC):
continue
if isinstance(content, list):
filtered = [
item for item in content
if not (
new_items: list[Any] = []
changed = False
for item in content:
if (
isinstance(item, dict)
and item.get("type") in ("text", "input_text", "output_text")
and not item.get("text")
)
]
if len(filtered) != len(content):
):
changed = True
continue
if isinstance(item, dict) and "_meta" in item:
new_items.append({k: v for k, v in item.items() if k != "_meta"})
changed = True
else:
new_items.append(item)
if changed:
clean = dict(msg)
if filtered:
clean["content"] = filtered
if new_items:
clean["content"] = new_items
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
clean["content"] = None
else:
@@ -77,9 +137,29 @@ class LLMProvider(ABC):
result.append(clean)
continue
if isinstance(content, dict):
clean = dict(msg)
clean["content"] = [content]
result.append(clean)
continue
result.append(msg)
return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod
async def chat(
self,
@@ -88,6 +168,8 @@ class LLMProvider(ABC):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request.
@@ -98,12 +180,100 @@ class LLMProvider(ABC):
model: Model identifier (provider-specific).
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns:
LLMResponse with content and/or tool calls.
"""
pass
@classmethod
def _is_transient_error(cls, content: str | None) -> bool:
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
found = False
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
new_content = []
for b in content:
if isinstance(b, dict) and b.get("type") == "image_url":
path = (b.get("_meta") or {}).get("path", "")
placeholder = f"[image: {path}]" if path else "[image omitted]"
new_content.append({"type": "text", "text": placeholder})
found = True
else:
new_content.append(b)
result.append({**msg, "content": new_content})
else:
result.append(msg)
return result if found else None
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
"""Call chat() and convert unexpected exceptions to error responses."""
try:
return await self.chat(**kwargs)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = _SENTINEL,
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""Call chat() with retry on transient provider failures.
Parameters default to ``self.generation`` when not explicitly passed,
so callers no longer need to thread temperature / max_tokens /
reasoning_effort through every layer.
"""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat(**kw)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat(**kw)
@abstractmethod
def get_default_model(self) -> str:
"""Get the default model for this provider."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import uuid
from typing import Any
import json_repair
@@ -12,27 +13,58 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class CustomProvider(LLMProvider):
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
def __init__(
self,
api_key: str = "no-key",
api_base: str = "http://localhost:8000/v1",
default_model: str = "default",
extra_headers: dict[str, str] | None = None,
):
super().__init__(api_key, api_base)
self.default_model = default_model
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
# Keep affinity stable for this provider instance to improve backend cache locality,
# while still letting users attach provider-specific headers for custom gateways.
default_headers = {
"x-session-affinity": uuid.uuid4().hex,
**(extra_headers or {}),
}
self._client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
default_headers=default_headers,
)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"messages": self._sanitize_empty_content(messages),
"max_tokens": max(1, max_tokens),
"temperature": temperature,
}
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
if tools:
kwargs.update(tools=tools, tool_choice="auto")
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
try:
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
# JSONDecodeError.doc / APIError.response.text may carry the raw body
# (e.g. "unsupported model: xxx") which is far more useful than the
# generic "Expecting value …" message. Truncate to avoid huge HTML pages.
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
if body and body.strip():
return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error")
return LLMResponse(content=f"Error: {e}", finish_reason="error")
def _parse(self, response: Any) -> LLMResponse:
if not response.choices:
return LLMResponse(
content="Error: API returned empty choices. This may indicate a temporary service issue or an invalid model response.",
finish_reason="error"
)
choice = response.choices[0]
msg = choice.message
tool_calls = [

View File

@@ -1,22 +1,22 @@
"""LiteLLM provider implementation for multi-provider support."""
import json
import json_repair
import hashlib
import os
import secrets
import string
from typing import Any
import json_repair
import litellm
from litellm import acompletion
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
# Standard OpenAI chat-completion message keys plus reasoning_content for
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
# Standard chat-completion message keys.
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
_ALNUM = string.ascii_letters + string.digits
def _short_tool_id() -> str:
@@ -32,10 +32,10 @@ class LiteLLMProvider(LLMProvider):
a unified interface. Provider-specific logic is driven by the registry
(see providers/registry.py) — no if-elif chains needed here.
"""
def __init__(
self,
api_key: str | None = None,
self,
api_key: str | None = None,
api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None,
@@ -44,24 +44,26 @@ class LiteLLMProvider(LLMProvider):
super().__init__(api_key, api_base)
self.default_model = default_model
self.extra_headers = extra_headers or {}
# Detect gateway / local deployment.
# provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection.
self._gateway = find_gateway(provider_name, api_key, api_base)
# Configure environment variables
if api_key:
self._setup_env(api_key, api_base, default_model)
if api_base:
litellm.api_base = api_base
# Disable LiteLLM logging noise
litellm.suppress_debug_info = True
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
litellm.drop_params = True
self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider."""
spec = self._gateway or find_by_model(model)
@@ -85,18 +87,17 @@ class LiteLLMProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key)
resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes."""
if self._gateway:
# Gateway mode: apply gateway prefix, skip provider-specific prefixes
prefix = self._gateway.litellm_prefix
if self._gateway.strip_model_prefix:
model = model.split("/")[-1]
if prefix and not model.startswith(f"{prefix}/"):
if prefix:
model = f"{prefix}/{model}"
return model
# Standard mode: auto-prefix for known providers
spec = find_by_model(model)
if spec and spec.litellm_prefix:
@@ -115,7 +116,7 @@ class LiteLLMProvider(LLMProvider):
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"
def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None:
@@ -174,17 +175,52 @@ class LiteLLMProvider(LLMProvider):
if pattern in model_lower:
kwargs.update(overrides)
return
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
"""Return provider-specific extra keys to preserve in request messages."""
spec = find_by_model(original_model) or find_by_model(resolved_model)
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
# Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {}
def map_id(value: Any) -> Any:
if not isinstance(value, str):
return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
@@ -194,22 +230,25 @@ class LiteLLMProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
Returns:
LLMResponse with content and/or tool calls.
"""
original_model = model or self.default_model
model = self._resolve_model(original_model)
extra_msg_keys = self._extra_msg_keys(original_model, model)
if self._supports_cache_control(original_model):
messages, tools = self._apply_cache_control(messages, tools)
@@ -217,33 +256,43 @@ class LiteLLMProvider(LLMProvider):
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = {
"model": model,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
"max_tokens": max_tokens,
"temperature": temperature,
}
if self._gateway:
kwargs.update(self._gateway.litellm_kwargs)
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs)
if self._langsmith_enabled:
kwargs.setdefault("callbacks", []).append("langsmith")
# Pass api_key directly — more reliable than env vars alone
if self.api_key:
kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints
if self.api_base:
kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
kwargs["drop_params"] = True
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
kwargs["tool_choice"] = tool_choice or "auto"
try:
response = await acompletion(**kwargs)
return self._parse_response(response)
@@ -253,26 +302,50 @@ class LiteLLMProvider(LLMProvider):
content=f"Error calling LLM: {str(e)}",
finish_reason="error",
)
def _parse_response(self, response: Any) -> LLMResponse:
"""Parse LiteLLM response into our standard format."""
choice = response.choices[0]
message = choice.message
content = message.content
finish_reason = choice.finish_reason
# Some providers (e.g. GitHub Copilot) split content and tool_calls
# across multiple choices. Merge them so tool_calls are not lost.
raw_tool_calls = []
for ch in response.choices:
msg = ch.message
if hasattr(msg, "tool_calls") and msg.tool_calls:
raw_tool_calls.extend(msg.tool_calls)
if ch.finish_reason in ("tool_calls", "stop"):
finish_reason = ch.finish_reason
if not content and msg.content:
content = msg.content
if len(response.choices) > 1:
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
len(response.choices), len(raw_tool_calls))
tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls:
# Parse arguments from JSON string if needed
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
))
for tc in raw_tool_calls:
# Parse arguments from JSON string if needed
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
function_provider_specific_fields = (
getattr(tc.function, "provider_specific_fields", None) or None
)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
))
usage = {}
if hasattr(response, "usage") and response.usage:
usage = {
@@ -280,17 +353,19 @@ class LiteLLMProvider(LLMProvider):
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse(
content=message.content,
content=content,
tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop",
finish_reason=finish_reason or "stop",
usage=usage,
reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
)
def get_default_model(self) -> str:
"""Get the default model."""
return self.default_model

View File

@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
@@ -31,6 +31,8 @@ class OpenAICodexProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
@@ -47,10 +49,13 @@ class OpenAICodexProvider(LLMProvider):
"text": {"verbosity": "medium"},
"include": ["reasoning.encrypted_content"],
"prompt_cache_key": _prompt_cache_key(messages),
"tool_choice": "auto",
"tool_choice": tool_choice or "auto",
"parallel_tool_calls": True,
}
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
body["tools"] = _convert_tools(tools)

View File

@@ -12,7 +12,7 @@ Every entry writes out all fields so you can copy-paste as a template.
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
@@ -26,33 +26,34 @@ class ProviderSpec:
"""
# identity
name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status`
name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status`
# model prefixing
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL
# gateway behavior
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
is_oauth: bool = False # if True, uses OAuth flow instead of API key
is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False
@@ -70,7 +71,6 @@ class ProviderSpec:
# ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = (
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec(
name="custom",
@@ -81,16 +81,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
is_direct=True,
),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -102,16 +110,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
# AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
env_key="OPENAI_API_KEY", # OpenAI-compatible
env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix",
litellm_prefix="openai", # → openai/{model}
litellm_prefix="openai", # → openai/{model}
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -119,10 +126,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="",
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(),
),
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec(
name="siliconflow",
@@ -141,7 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
# VolcEngine (火山引擎): OpenAI-compatible gateway
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
@@ -159,8 +165,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) ===============
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
ProviderSpec(
name="volcengine_coding_plan",
keywords=("volcengine-plan",),
env_key="OPENAI_API_KEY",
display_name="VolcEngine Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus: VolcEngine international, pay-per-use models
ProviderSpec(
name="byteplus",
keywords=("byteplus",),
env_key="OPENAI_API_KEY",
display_name="BytePlus",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="bytepluses",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus Coding Plan: same key as byteplus
ProviderSpec(
name="byteplus_coding_plan",
keywords=("byteplus-plan",),
env_key="OPENAI_API_KEY",
display_name="BytePlus Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
name="anthropic",
@@ -179,7 +239,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec(
name="openai",
@@ -197,14 +256,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
keywords=("openai-codex",),
env_key="", # OAuth-based, no API key
env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM
litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
@@ -214,16 +272,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
is_oauth=True, # OAuth-based authentication
),
# Github Copilot: uses OAuth, not API key.
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
env_key="", # OAuth-based, no API key
env_key="", # OAuth-based, no API key
display_name="Github Copilot",
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",),
env_extras=(),
is_gateway=False,
@@ -233,17 +290,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
is_oauth=True, # OAuth-based authentication
),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek",
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -253,15 +309,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -271,7 +326,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +334,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4
litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=(
("ZHIPUAI_API_KEY", "{api_key}"),
),
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
@@ -293,14 +345,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -311,7 +362,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +370,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"),
env_extras=(
("MOONSHOT_API_BASE", "{api_base}"),
),
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False,
model_overrides=(
("kimi-k2.5", {"temperature": 1.0}),
),
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
),
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec(
@@ -343,7 +388,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -354,9 +399,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
# === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec(
@@ -364,20 +407,35 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local",
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="", # user must provide in config
default_api_base="", # user must provide in config
strip_model_prefix=False,
model_overrides=(),
),
# === Ollama (local, OpenAI-compatible) ===================================
ProviderSpec(
name="ollama",
keywords=("ollama", "nemotron"),
env_key="OLLAMA_API_KEY",
display_name="Ollama",
litellm_prefix="ollama_chat", # model → ollama_chat/model
skip_prefixes=("ollama/", "ollama_chat/"),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="11434",
default_api_base="http://localhost:11434",
strip_model_prefix=False,
model_overrides=(),
),
# === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec(
@@ -385,8 +443,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -403,6 +461,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers
# ---------------------------------------------------------------------------
def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -418,7 +477,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec
for spec in std_specs:
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
if any(
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
):
return spec
return None

View File

@@ -2,7 +2,6 @@
import os
from pathlib import Path
from typing import Any
import httpx
from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider:
"""
Voice transcription provider using Groq's Whisper API.
Groq offers extremely fast transcription with a generous free tier.
"""
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
async def transcribe(self, file_path: str | Path) -> str:
"""
Transcribe an audio file using Groq.
Args:
file_path: Path to the audio file.
Returns:
Transcribed text.
"""
if not self.api_key:
logger.warning("Groq API key not configured for transcription")
return ""
path = Path(file_path)
if not path.exists():
logger.error("Audio file not found: {}", file_path)
return ""
try:
async with httpx.AsyncClient() as client:
with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
headers = {
"Authorization": f"Bearer {self.api_key}",
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
timeout=60.0
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except Exception as e:
logger.error("Groq transcription error: {}", e)
return ""