"""Base LLM provider interface.""" import asyncio import json from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any from loguru import logger @dataclass class ToolCallRequest: """A tool call request from the LLM.""" id: str name: str arguments: dict[str, Any] provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None def to_openai_tool_call(self) -> dict[str, Any]: """Serialize to an OpenAI-style tool_call payload.""" tool_call = { "id": self.id, "type": "function", "function": { "name": self.name, "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } if self.provider_specific_fields: tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields return tool_call @dataclass class LLMResponse: """Response from an LLM provider.""" content: str | None tool_calls: list[ToolCallRequest] = field(default_factory=list) finish_reason: str = "stop" usage: dict[str, int] = field(default_factory=dict) reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc. thinking_blocks: list[dict] | None = None # Anthropic extended thinking @property def has_tool_calls(self) -> bool: """Check if response contains tool calls.""" return len(self.tool_calls) > 0 @dataclass(frozen=True) class GenerationSettings: """Default generation parameters for LLM calls. Stored on the provider so every call site inherits the same defaults without having to pass temperature / max_tokens / reasoning_effort through every layer. Individual call sites can still override by passing explicit keyword arguments to chat() / chat_with_retry(). """ temperature: float = 0.7 max_tokens: int = 4096 reasoning_effort: str | None = None class LLMProvider(ABC): """ Abstract base class for LLM providers. Implementations should handle the specifics of each provider's API while maintaining a consistent interface. """ _CHAT_RETRY_DELAYS = (1, 2, 4) _TRANSIENT_ERROR_MARKERS = ( "429", "rate limit", "500", "502", "503", "504", "overloaded", "timeout", "timed out", "connection", "server error", "temporarily unavailable", ) _SENTINEL = object() def __init__(self, api_key: str | None = None, api_base: str | None = None): self.api_key = api_key self.api_base = api_base self.generation: GenerationSettings = GenerationSettings() @staticmethod def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Sanitize message content: fix empty blocks, strip internal _meta fields.""" result: list[dict[str, Any]] = [] for msg in messages: content = msg.get("content") if isinstance(content, str) and not content: clean = dict(msg) clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)" result.append(clean) continue if isinstance(content, list): new_items: list[Any] = [] changed = False for item in content: if ( isinstance(item, dict) and item.get("type") in ("text", "input_text", "output_text") and not item.get("text") ): changed = True continue if isinstance(item, dict) and "_meta" in item: new_items.append({k: v for k, v in item.items() if k != "_meta"}) changed = True else: new_items.append(item) if changed: clean = dict(msg) if new_items: clean["content"] = new_items elif msg.get("role") == "assistant" and msg.get("tool_calls"): clean["content"] = None else: clean["content"] = "(empty)" result.append(clean) continue if isinstance(content, dict): clean = dict(msg) clean["content"] = [content] result.append(clean) continue result.append(msg) return result @staticmethod def _sanitize_request_messages( messages: list[dict[str, Any]], allowed_keys: frozenset[str], ) -> list[dict[str, Any]]: """Keep only provider-safe message keys and normalize assistant content.""" sanitized = [] for msg in messages: clean = {k: v for k, v in msg.items() if k in allowed_keys} if clean.get("role") == "assistant" and "content" not in clean: clean["content"] = None sanitized.append(clean) return sanitized @abstractmethod async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: """ Send a chat completion request. Args: messages: List of message dicts with 'role' and 'content'. tools: Optional list of tool definitions. model: Model identifier (provider-specific). max_tokens: Maximum tokens in response. temperature: Sampling temperature. tool_choice: Tool selection strategy ("auto", "required", or specific tool dict). Returns: LLMResponse with content and/or tool calls. """ pass @classmethod def _is_transient_error(cls, content: str | None) -> bool: err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) @staticmethod def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Replace image_url blocks with text placeholder. Returns None if no images found.""" found = False result = [] for msg in messages: content = msg.get("content") if isinstance(content, list): new_content = [] for b in content: if isinstance(b, dict) and b.get("type") == "image_url": path = (b.get("_meta") or {}).get("path", "") placeholder = f"[image: {path}]" if path else "[image omitted]" new_content.append({"type": "text", "text": placeholder}) found = True else: new_content.append(b) result.append({**msg, "content": new_content}) else: result.append(msg) return result if found else None async def _safe_chat(self, **kwargs: Any) -> LLMResponse: """Call chat() and convert unexpected exceptions to error responses.""" try: return await self.chat(**kwargs) except asyncio.CancelledError: raise except Exception as exc: return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") async def chat_with_retry( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: object = _SENTINEL, temperature: object = _SENTINEL, reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: """Call chat() with retry on transient provider failures. Parameters default to ``self.generation`` when not explicitly passed, so callers no longer need to thread temperature / max_tokens / reasoning_effort through every layer. """ if max_tokens is self._SENTINEL: max_tokens = self.generation.max_tokens if temperature is self._SENTINEL: temperature = self.generation.temperature if reasoning_effort is self._SENTINEL: reasoning_effort = self.generation.reasoning_effort kw: dict[str, Any] = dict( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, ) for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): response = await self._safe_chat(**kw) if response.finish_reason != "error": return response if not self._is_transient_error(response.content): stripped = self._strip_image_content(messages) if stripped is not None: logger.warning("Non-transient LLM error with image content, retrying without images") return await self._safe_chat(**{**kw, "messages": stripped}) return response logger.warning( "LLM transient error (attempt {}/{}), retrying in {}s: {}", attempt, len(self._CHAT_RETRY_DELAYS), delay, (response.content or "")[:120].lower(), ) await asyncio.sleep(delay) return await self._safe_chat(**kw) @abstractmethod def get_default_model(self) -> str: """Get the default model for this provider.""" pass