style: unify code formatting and import order
- Remove trailing whitespace and normalize blank lines
- Unify string quotes and line breaks for long lines
- Sort imports alphabetically across modules
@@ -21,7 +21,7 @@ class LLMResponse:
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    reasoning_content: str | None = None  # Kimi, DeepSeek-R1 etc.

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
@@ -35,7 +35,7 @@ class LLMProvider(ABC):
    Implementations should handle the specifics of each provider's API
    while maintaining a consistent interface.
    """

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        self.api_key = api_key
        self.api_base = api_base
@@ -79,7 +79,7 @@ class LLMProvider(ABC):
            result.append(msg)
        return result

    @abstractmethod
    async def chat(
        self,
@@ -103,7 +103,7 @@ class LLMProvider(ABC):
            LLMResponse with content and/or tool calls.
        """
        pass

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
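
For orientation, a minimal sketch of a concrete subclass of this interface. The class name, canned reply, and the exact chat() signature (inferred from the docstring above) are assumptions, not project code:

class EchoProvider(LLMProvider):
    """Hypothetical provider that echoes the last user message back."""

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        # Walk backwards to find the most recent user turn.
        last_user = next(
            (m["content"] for m in reversed(messages) if m.get("role") == "user"),
            "",
        )
        return LLMResponse(content=last_user)

    def get_default_model(self) -> str:
        return "echo-1"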
@@ -1,19 +1,17 @@
"""LiteLLM provider implementation for multi-provider support."""

import json
import os
import secrets
import string
from typing import Any

import json_repair
import litellm
from litellm import acompletion

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway


# Standard OpenAI chat-completion message keys plus reasoning_content for
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
@@ -32,10 +30,10 @@ class LiteLLMProvider(LLMProvider):
    a unified interface. Provider-specific logic is driven by the registry
    (see providers/registry.py) — no if-elif chains needed here.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "anthropic/claude-opus-4-5",
        extra_headers: dict[str, str] | None = None,
@@ -44,24 +42,24 @@ class LiteLLMProvider(LLMProvider):
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}

        # Detect gateway / local deployment.
        # provider_name (from config key) is the primary signal;
        # api_key / api_base are fallback for auto-detection.
        self._gateway = find_gateway(provider_name, api_key, api_base)

        # Configure environment variables
        if api_key:
            self._setup_env(api_key, api_base, default_model)

        if api_base:
            litellm.api_base = api_base

        # Disable LiteLLM logging noise
        litellm.suppress_debug_info = True
        # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
        litellm.drop_params = True

    def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
        """Set environment variables based on detected provider."""
        spec = self._gateway or find_by_model(model)
@@ -85,7 +83,7 @@ class LiteLLMProvider(LLMProvider):
        resolved = env_val.replace("{api_key}", api_key)
        resolved = resolved.replace("{api_base}", effective_base)
        os.environ.setdefault(env_name, resolved)
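
For intuition (the env names and values here are invented, the registry specs are not shown in this diff), the loop above would expand a template like so:

# Hypothetical registry entry: env_name = "OPENAI_API_KEY", env_val = "{api_key}"
# With api_key="sk-placeholder", the code above performs:
#   os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")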

    def _resolve_model(self, model: str) -> str:
        """Resolve model name by applying provider/gateway prefixes."""
        if self._gateway:
@@ -96,7 +94,7 @@ class LiteLLMProvider(LLMProvider):
            if prefix and not model.startswith(f"{prefix}/"):
                model = f"{prefix}/{model}"
            return model

        # Standard mode: auto-prefix for known providers
        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
@@ -115,7 +113,7 @@ class LiteLLMProvider(LLMProvider):
        if prefix.lower().replace("-", "_") != spec_name:
            return model
        return f"{canonical_prefix}/{remainder}"
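
As a reading aid, and assuming the registry maps Claude model names to an "anthropic" prefix (an assumption; the registry contents are not part of this diff), resolution would behave roughly like:

provider._resolve_model("claude-sonnet-4-5")            # -> "anthropic/claude-sonnet-4-5"
provider._resolve_model("anthropic/claude-sonnet-4-5")  # -> unchanged, prefix already present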

    def _supports_cache_control(self, model: str) -> bool:
        """Return True when the provider supports cache_control on content blocks."""
        if self._gateway is not None:
@@ -158,7 +156,7 @@ class LiteLLMProvider(LLMProvider):
            if pattern in model_lower:
                kwargs.update(overrides)
            return

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Strip non-standard keys and ensure assistant messages have a content key."""
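
The hunk cuts off before the method body. A minimal sketch of a sanitizer consistent with the docstring and with _ALLOWED_MSG_KEYS above (an illustration, not the project's actual implementation):

@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    result = []
    for msg in messages:
        # Drop keys that OpenAI-style chat endpoints do not accept.
        clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
        # Some providers reject assistant messages that lack a content key.
        if clean.get("role") == "assistant":
            clean.setdefault("content", "")
        result.append(clean)
    return result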
@@ -181,14 +179,14 @@ class LiteLLMProvider(LLMProvider):
    ) -> LLMResponse:
        """
        Send a chat completion request via LiteLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions in OpenAI format.
            model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
@@ -201,33 +199,33 @@ class LiteLLMProvider(LLMProvider):
        # Clamp max_tokens to at least 1 — negative or zero values cause
        # LiteLLM to reject the request with "max_tokens must be at least 1".
        max_tokens = max(1, max_tokens)

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
        self._apply_model_overrides(model, kwargs)

        # Pass api_key directly — more reliable than env vars alone
        if self.api_key:
            kwargs["api_key"] = self.api_key

        # Pass api_base for custom endpoints
        if self.api_base:
            kwargs["api_base"] = self.api_base

        # Pass extra headers (e.g. APP-Code for AiHubMix)
        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        try:
            response = await acompletion(**kwargs)
            return self._parse_response(response)
@@ -237,12 +235,12 @@ class LiteLLMProvider(LLMProvider):
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    def _parse_response(self, response: Any) -> LLMResponse:
        """Parse LiteLLM response into our standard format."""
        choice = response.choices[0]
        message = choice.message

        tool_calls = []
        if hasattr(message, "tool_calls") and message.tool_calls:
            for tc in message.tool_calls:
@@ -250,13 +248,13 @@ class LiteLLMProvider(LLMProvider):
                args = tc.function.arguments
                if isinstance(args, str):
                    args = json_repair.loads(args)

                tool_calls.append(ToolCallRequest(
                    id=_short_tool_id(),
                    name=tc.function.name,
                    arguments=args,
                ))

        usage = {}
        if hasattr(response, "usage") and response.usage:
            usage = {
@@ -264,9 +262,9 @@ class LiteLLMProvider(LLMProvider):
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }

        reasoning_content = getattr(message, "reasoning_content", None) or None

        return LLMResponse(
            content=message.content,
            tool_calls=tool_calls,
@@ -274,7 +272,7 @@ class LiteLLMProvider(LLMProvider):
            usage=usage,
            reasoning_content=reasoning_content,
        )
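
_short_tool_id() is called above but defined outside the shown hunks. Given the module-level secrets and string imports, a plausible sketch (the alphabet and length are guesses, not the actual helper):

def _short_tool_id(length: int = 8) -> str:
    # Cryptographically random short identifier for a tool call.
    alphabet = string.ascii_lowercase + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))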

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
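
Putting the pieces together, a hypothetical call site (the key, model, and message are placeholders):

provider = LiteLLMProvider(
    api_key="sk-placeholder",
    default_model="anthropic/claude-opus-4-5",
)

async def demo() -> None:
    response = await provider.chat(
        messages=[{"role": "user", "content": "ping"}],
        model=provider.get_default_model(),
    )
    print(response.content, response.usage)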
@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator

import httpx
from loguru import logger

from oauth_cli_kit import get_token as get_codex_token

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest

DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
@@ -2,7 +2,6 @@

import os
from pathlib import Path
from typing import Any

import httpx
from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider:
    """
    Voice transcription provider using Groq's Whisper API.

    Groq offers extremely fast transcription with a generous free tier.
    """

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.environ.get("GROQ_API_KEY")
        self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"

    async def transcribe(self, file_path: str | Path) -> str:
        """
        Transcribe an audio file using Groq.

        Args:
            file_path: Path to the audio file.

        Returns:
            Transcribed text.
        """
        if not self.api_key:
            logger.warning("Groq API key not configured for transcription")
            return ""

        path = Path(file_path)
        if not path.exists():
            logger.error("Audio file not found: {}", file_path)
            return ""

        try:
            async with httpx.AsyncClient() as client:
                with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
                    headers = {
                        "Authorization": f"Bearer {self.api_key}",
                    }

                    response = await client.post(
                        self.api_url,
                        headers=headers,
                        files=files,
                        timeout=60.0
                    )

                    response.raise_for_status()
                    data = response.json()
                    return data.get("text", "")

        except Exception as e:
            logger.error("Groq transcription error: {}", e)
            return ""
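
A hypothetical call site for the transcription provider (the file name is a placeholder; with no key argument it falls back to the GROQ_API_KEY environment variable, as shown in __init__ above):

import asyncio

async def demo() -> None:
    provider = GroqTranscriptionProvider()
    text = await provider.transcribe("note.ogg")
    print(text or "<transcription failed>")

asyncio.run(demo())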