feat(agent): add streaming groundwork for future TUI
Preserve the provider and agent-loop streaming primitives plus the CLI experiment scaffolding so this work can be resumed later without blocking urgent bug fixes on main. Made-with: Cursor
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
|
||||
|
||||
Returns the same ``LLMResponse`` as :meth:`chat`. The default
|
||||
implementation falls back to a non-streaming call and delivers the
|
||||
full content as a single delta. Providers that support native
|
||||
streaming should override this method.
|
||||
"""
|
||||
response = await self.chat(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
if on_content_delta and response.content:
|
||||
await on_content_delta(response.content)
|
||||
return response
|
||||
|
||||
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||
"""Call chat_stream() and convert unexpected exceptions to error responses."""
|
||||
try:
|
||||
return await self.chat_stream(**kwargs)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||
|
||||
async def chat_stream_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = _SENTINEL,
|
||||
temperature: object = _SENTINEL,
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat_stream() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
temperature = self.generation.temperature
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
|
||||
kw: dict[str, Any] = dict(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat_stream(**kw)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return await self._safe_chat_stream(**kw)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@@ -4,6 +4,7 @@ import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
@@ -223,6 +224,64 @@ class LiteLLMProvider(LLMProvider):
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
def _build_chat_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
"""Build the kwargs dict for ``acompletion``.
|
||||
|
||||
Returns ``(kwargs, original_model)`` so callers can reuse the
|
||||
original model string for downstream logic.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
resolved = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": resolved,
|
||||
"messages": self._sanitize_messages(
|
||||
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
|
||||
),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if self._gateway:
|
||||
kwargs.update(self._gateway.litellm_kwargs)
|
||||
|
||||
self._apply_model_overrides(resolved, kwargs)
|
||||
|
||||
if self._langsmith_enabled:
|
||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return kwargs, original_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -233,71 +292,54 @@ class LiteLLMProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
model = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if self._gateway:
|
||||
kwargs.update(self._gateway.litellm_kwargs)
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
if self._langsmith_enabled:
|
||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||
|
||||
# Pass api_key directly — more reliable than env vars alone
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
"""Send a chat completion request via LiteLLM."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
# Return error as content for graceful handling
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
|
||||
try:
|
||||
stream = await acompletion(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
if on_content_delta:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
text = getattr(delta, "content", None) if delta else None
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
|
||||
full_response = litellm.stream_chunk_builder(
|
||||
chunks, messages=kwargs["messages"],
|
||||
)
|
||||
return self._parse_response(full_response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
|
||||
Reference in New Issue
Block a user