style: unify code formatting and import order

- Remove trailing whitespace and normalize blank lines
- Unify string quotes and line breaks for long lines
- Sort imports alphabetically across modules
This commit is contained in:
JK_Lu
2026-02-28 20:55:43 +08:00
parent bfc2fa88f3
commit 977ca725f2
33 changed files with 574 additions and 581 deletions

View File

@@ -21,7 +21,7 @@ class LLMResponse:
finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
@property
def has_tool_calls(self) -> bool:
"""Check if response contains tool calls."""
@@ -35,7 +35,7 @@ class LLMProvider(ABC):
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key
self.api_base = api_base
@@ -79,7 +79,7 @@ class LLMProvider(ABC):
result.append(msg)
return result
@abstractmethod
async def chat(
self,
@@ -103,7 +103,7 @@ class LLMProvider(ABC):
LLMResponse with content and/or tool calls.
"""
pass
@abstractmethod
def get_default_model(self) -> str:
"""Get the default model for this provider."""

View File

@@ -1,19 +1,17 @@
"""LiteLLM provider implementation for multi-provider support."""
import json
import json_repair
import os
import secrets
import string
from typing import Any
import json_repair
import litellm
from litellm import acompletion
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
# Standard OpenAI chat-completion message keys plus reasoning_content for
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
@@ -32,10 +30,10 @@ class LiteLLMProvider(LLMProvider):
a unified interface. Provider-specific logic is driven by the registry
(see providers/registry.py) — no if-elif chains needed here.
"""
def __init__(
self,
api_key: str | None = None,
self,
api_key: str | None = None,
api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None,
@@ -44,24 +42,24 @@ class LiteLLMProvider(LLMProvider):
super().__init__(api_key, api_base)
self.default_model = default_model
self.extra_headers = extra_headers or {}
# Detect gateway / local deployment.
# provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection.
self._gateway = find_gateway(provider_name, api_key, api_base)
# Configure environment variables
if api_key:
self._setup_env(api_key, api_base, default_model)
if api_base:
litellm.api_base = api_base
# Disable LiteLLM logging noise
litellm.suppress_debug_info = True
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
litellm.drop_params = True
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider."""
spec = self._gateway or find_by_model(model)
@@ -85,7 +83,7 @@ class LiteLLMProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key)
resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes."""
if self._gateway:
@@ -96,7 +94,7 @@ class LiteLLMProvider(LLMProvider):
if prefix and not model.startswith(f"{prefix}/"):
model = f"{prefix}/{model}"
return model
# Standard mode: auto-prefix for known providers
spec = find_by_model(model)
if spec and spec.litellm_prefix:
@@ -115,7 +113,7 @@ class LiteLLMProvider(LLMProvider):
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"
def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None:
@@ -158,7 +156,7 @@ class LiteLLMProvider(LLMProvider):
if pattern in model_lower:
kwargs.update(overrides)
return
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
@@ -181,14 +179,14 @@ class LiteLLMProvider(LLMProvider):
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
Returns:
LLMResponse with content and/or tool calls.
"""
@@ -201,33 +199,33 @@ class LiteLLMProvider(LLMProvider):
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = {
"model": model,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
"max_tokens": max_tokens,
"temperature": temperature,
}
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs)
# Pass api_key directly — more reliable than env vars alone
if self.api_key:
kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints
if self.api_base:
kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
try:
response = await acompletion(**kwargs)
return self._parse_response(response)
@@ -237,12 +235,12 @@ class LiteLLMProvider(LLMProvider):
content=f"Error calling LLM: {str(e)}",
finish_reason="error",
)
def _parse_response(self, response: Any) -> LLMResponse:
"""Parse LiteLLM response into our standard format."""
choice = response.choices[0]
message = choice.message
tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls:
@@ -250,13 +248,13 @@ class LiteLLMProvider(LLMProvider):
args = tc.function.arguments
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
))
usage = {}
if hasattr(response, "usage") and response.usage:
usage = {
@@ -264,9 +262,9 @@ class LiteLLMProvider(LLMProvider):
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
reasoning_content = getattr(message, "reasoning_content", None) or None
return LLMResponse(
content=message.content,
tool_calls=tool_calls,
@@ -274,7 +272,7 @@ class LiteLLMProvider(LLMProvider):
usage=usage,
reasoning_content=reasoning_content,
)
def get_default_model(self) -> str:
"""Get the default model."""
return self.default_model

View File

@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"

View File

@@ -2,7 +2,6 @@
import os
from pathlib import Path
from typing import Any
import httpx
from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider:
"""
Voice transcription provider using Groq's Whisper API.
Groq offers extremely fast transcription with a generous free tier.
"""
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
async def transcribe(self, file_path: str | Path) -> str:
"""
Transcribe an audio file using Groq.
Args:
file_path: Path to the audio file.
Returns:
Transcribed text.
"""
if not self.api_key:
logger.warning("Groq API key not configured for transcription")
return ""
path = Path(file_path)
if not path.exists():
logger.error("Audio file not found: {}", file_path)
return ""
try:
async with httpx.AsyncClient() as client:
with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
headers = {
"Authorization": f"Bearer {self.api_key}",
}
response = await client.post(
self.api_url,
headers=headers,
files=files,
timeout=60.0
)
response.raise_for_status()
data = response.json()
return data.get("text", "")
except Exception as e:
logger.error("Groq transcription error: {}", e)
return ""