feat(agent): add streaming groundwork for future TUI
Preserve the provider and agent-loop streaming primitives, plus the CLI experiment scaffolding, so this work can be resumed later without blocking urgent bug fixes on main.

Made-with: Cursor
This commit is contained in:
@@ -212,8 +212,16 @@ class AgentLoop:
|
|||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""Run the agent iteration loop."""
|
"""Run the agent iteration loop.
|
||||||
|
|
||||||
|
*on_stream*: called with each content delta during streaming.
|
||||||
|
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||||
|
``resuming=True`` means tool calls follow (spinner should restart);
|
||||||
|
``resuming=False`` means this is the final response.
|
||||||
|
"""
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_content = None
|
final_content = None
|
||||||
@@ -224,11 +232,20 @@ class AgentLoop:
|
|||||||
|
|
||||||
tool_defs = self.tools.get_definitions()
|
tool_defs = self.tools.get_definitions()
|
||||||
|
|
||||||
|
if on_stream:
|
||||||
|
response = await self.provider.chat_stream_with_retry(
|
||||||
|
messages=messages,
|
||||||
|
tools=tool_defs,
|
||||||
|
model=self.model,
|
||||||
|
on_content_delta=on_stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
response = await self.provider.chat_with_retry(
|
response = await self.provider.chat_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tool_defs,
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
usage = response.usage or {}
|
usage = response.usage or {}
|
||||||
self._last_usage = {
|
self._last_usage = {
|
||||||
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
|
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
|
||||||
@@ -236,7 +253,11 @@ class AgentLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
if on_stream and on_stream_end:
|
||||||
|
await on_stream_end(resuming=True)
|
||||||
|
|
||||||
if on_progress:
|
if on_progress:
|
||||||
|
if not on_stream:
|
||||||
thought = self._strip_think(response.content)
|
thought = self._strip_think(response.content)
|
||||||
if thought:
|
if thought:
|
||||||
await on_progress(thought)
|
await on_progress(thought)
|
||||||
@@ -263,9 +284,10 @@ class AgentLoop:
|
|||||||
messages, tool_call.id, tool_call.name, result
|
messages, tool_call.id, tool_call.name, result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if on_stream and on_stream_end:
|
||||||
|
await on_stream_end(resuming=False)
|
||||||
|
|
||||||
clean = self._strip_think(response.content)
|
clean = self._strip_think(response.content)
|
||||||
# Don't persist error responses to session history — they can
|
|
||||||
# poison the context and cause permanent 400 loops (#1303).
|
|
||||||
if response.finish_reason == "error":
|
if response.finish_reason == "error":
|
||||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||||
@@ -400,6 +422,8 @@ class AgentLoop:
|
|||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
@@ -412,7 +436,6 @@ class AgentLoop:
|
|||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
# Subagent results should be assistant role, other system messages use user role
|
|
||||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
@@ -486,7 +509,10 @@ class AgentLoop:
|
|||||||
))
|
))
|
||||||
|
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages,
|
||||||
|
on_progress=on_progress or _bus_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
@@ -501,9 +527,13 @@ class AgentLoop:
|
|||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
|
meta = dict(msg.metadata or {})
|
||||||
|
if on_stream is not None:
|
||||||
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
metadata=msg.metadata or {},
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -592,8 +622,13 @@ class AgentLoop:
|
|||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a message directly and return the outbound payload."""
|
"""Process a message directly and return the outbound payload."""
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
return await self._process_message(
|
||||||
|
msg, session_key=session_key, on_progress=on_progress,
|
||||||
|
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
)
|
||||||
|
|||||||
@@ -207,6 +207,10 @@ class _ThinkingSpinner:
|
|||||||
self._active = False
|
self._active = False
|
||||||
if self._spinner:
|
if self._spinner:
|
||||||
self._spinner.stop()
|
self._spinner.stop()
|
||||||
|
# Force-clear the spinner line: Rich Live's transient cleanup
|
||||||
|
# occasionally loses a race with its own render thread.
|
||||||
|
console.file.write("\033[2K\r")
|
||||||
|
console.file.flush()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -214,6 +218,8 @@ class _ThinkingSpinner:
|
|||||||
"""Temporarily stop spinner while printing progress."""
|
"""Temporarily stop spinner while printing progress."""
|
||||||
if self._spinner and self._active:
|
if self._spinner and self._active:
|
||||||
self._spinner.stop()
|
self._spinner.stop()
|
||||||
|
console.file.write("\033[2K\r")
|
||||||
|
console.file.flush()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@@ -770,16 +776,25 @@ def agent(
|
|||||||
async def run_once():
|
async def run_once():
|
||||||
nonlocal _thinking
|
nonlocal _thinking
|
||||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
with _thinking:
|
|
||||||
|
with _thinking or nullcontext():
|
||||||
response = await agent_loop.process_direct(
|
response = await agent_loop.process_direct(
|
||||||
message, session_id, on_progress=_cli_progress,
|
message, session_id,
|
||||||
|
on_progress=_cli_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _thinking:
|
||||||
|
_thinking.__exit__(None, None, None)
|
||||||
_thinking = None
|
_thinking = None
|
||||||
|
|
||||||
|
if response and response.content:
|
||||||
_print_agent_response(
|
_print_agent_response(
|
||||||
response.content if response else "",
|
response.content,
|
||||||
render_markdown=markdown,
|
render_markdown=markdown,
|
||||||
metadata=response.metadata if response else None,
|
metadata=response.metadata,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
console.print()
|
||||||
await agent_loop.close_mcp()
|
await agent_loop.close_mcp()
|
||||||
|
|
||||||
asyncio.run(run_once())
|
asyncio.run(run_once())
|
||||||
@@ -820,6 +835,7 @@ def agent(
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
|
||||||
if msg.metadata.get("_progress"):
|
if msg.metadata.get("_progress"):
|
||||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||||
ch = agent_loop.channels_config
|
ch = agent_loop.channels_config
|
||||||
@@ -834,6 +850,7 @@ def agent(
|
|||||||
if msg.content:
|
if msg.content:
|
||||||
turn_response.append((msg.content, dict(msg.metadata or {})))
|
turn_response.append((msg.content, dict(msg.metadata or {})))
|
||||||
turn_done.set()
|
turn_done.set()
|
||||||
|
|
||||||
elif msg.content:
|
elif msg.content:
|
||||||
await _print_interactive_response(
|
await _print_interactive_response(
|
||||||
msg.content,
|
msg.content,
|
||||||
@@ -872,11 +889,7 @@ def agent(
|
|||||||
content=user_input,
|
content=user_input,
|
||||||
))
|
))
|
||||||
|
|
||||||
nonlocal _thinking
|
|
||||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
|
||||||
with _thinking:
|
|
||||||
await turn_done.wait()
|
await turn_done.wait()
|
||||||
_thinking = None
|
|
||||||
|
|
||||||
if turn_response:
|
if turn_response:
|
||||||
content, meta = turn_response[0]
|
content, meta = turn_response[0]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||||
|
|
||||||
|
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream a chat completion, calling *on_content_delta* for each text chunk.

    Returns the same ``LLMResponse`` as :meth:`chat`. This default
    implementation falls back to a single non-streaming :meth:`chat` call
    and delivers the full content as one delta. Providers that support
    native streaming should override this method.

    Error responses (``finish_reason == "error"``) are returned to the
    caller but are never forwarded to *on_content_delta*: the error text is
    not model output, and a retrying caller would otherwise stream one
    error message per failed attempt ahead of the real answer.

    Args:
        messages: Conversation messages passed through to :meth:`chat`.
        tools: Optional tool definitions passed through to :meth:`chat`.
        model: Optional model identifier passed through to :meth:`chat`.
        max_tokens: Passed through to :meth:`chat`.
        temperature: Passed through to :meth:`chat`.
        reasoning_effort: Passed through to :meth:`chat`.
        tool_choice: Passed through to :meth:`chat`.
        on_content_delta: Optional async callback awaited with the response
            text; skipped when the content is empty or the response is an
            error.

    Returns:
        The ``LLMResponse`` produced by :meth:`chat`, unchanged.
    """
    response = await self.chat(
        messages=messages,
        tools=tools,
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        reasoning_effort=reasoning_effort,
        tool_choice=tool_choice,
    )

    # Do not leak provider error text into the content stream; the caller
    # inspects finish_reason on the returned response instead.
    if response.finish_reason == "error":
        return response

    if on_content_delta and response.content:
        await on_content_delta(response.content)
    return response
|
||||||
|
|
||||||
|
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
    """Invoke :meth:`chat_stream`, mapping unexpected failures to an error response.

    Cancellation is re-raised untouched. Any other ``Exception`` is turned
    into an ``LLMResponse`` whose ``finish_reason`` is ``"error"`` and whose
    content carries the exception text.
    """
    try:
        result = await self.chat_stream(**kwargs)
    except asyncio.CancelledError:
        # Task cancellation must propagate to the caller.
        raise
    except Exception as exc:
        result = LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
    return result
|
||||||
|
|
||||||
|
async def chat_stream_with_retry(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: object = _SENTINEL,
    temperature: object = _SENTINEL,
    reasoning_effort: object = _SENTINEL,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Run :meth:`chat_stream` and retry when the provider reports a transient error.

    ``max_tokens``, ``temperature`` and ``reasoning_effort`` fall back to the
    values on ``self.generation`` when left at the sentinel default.

    One attempt is made per entry in ``_CHAT_RETRY_DELAYS``, sleeping for that
    delay after each transient failure, followed by one final attempt. A
    non-transient error ends the loop immediately: if the messages contain
    image content that can be stripped, a single retry without images is
    made; otherwise the error response is returned as-is.

    NOTE(review): *on_content_delta* is passed to every attempt, so a consumer
    may receive deltas from a failed attempt before a retry streams again.
    Confirm consumers tolerate this.
    """
    if max_tokens is self._SENTINEL:
        max_tokens = self.generation.max_tokens
    if temperature is self._SENTINEL:
        temperature = self.generation.temperature
    if reasoning_effort is self._SENTINEL:
        reasoning_effort = self.generation.reasoning_effort

    request: dict[str, Any] = {
        "messages": messages,
        "tools": tools,
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "reasoning_effort": reasoning_effort,
        "tool_choice": tool_choice,
        "on_content_delta": on_content_delta,
    }

    attempt = 0
    for delay in self._CHAT_RETRY_DELAYS:
        attempt += 1
        response = await self._safe_chat_stream(**request)
        if response.finish_reason != "error":
            return response

        if self._is_transient_error(response.content):
            logger.warning(
                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
                attempt, len(self._CHAT_RETRY_DELAYS), delay,
                (response.content or "")[:120].lower(),
            )
            await asyncio.sleep(delay)
            continue

        # Non-transient failure: the only remaining recovery is dropping images.
        stripped = self._strip_image_content(messages)
        if stripped is None:
            return response
        logger.warning("Non-transient LLM error with image content, retrying without images")
        without_images = {**request, "messages": stripped}
        return await self._safe_chat_stream(**without_images)

    # Every scheduled delay was used; make one last attempt and return it as-is.
    return await self._safe_chat_stream(**request)
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import hashlib
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
@@ -223,43 +224,35 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
async def chat(
|
def _build_chat_kwargs(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None,
|
||||||
model: str | None = None,
|
model: str | None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int,
|
||||||
temperature: float = 0.7,
|
temperature: float,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None,
|
||||||
) -> LLMResponse:
|
) -> tuple[dict[str, Any], str]:
|
||||||
"""
|
"""Build the kwargs dict for ``acompletion``.
|
||||||
Send a chat completion request via LiteLLM.
|
|
||||||
|
|
||||||
Args:
|
Returns ``(kwargs, original_model)`` so callers can reuse the
|
||||||
messages: List of message dicts with 'role' and 'content'.
|
original model string for downstream logic.
|
||||||
tools: Optional list of tool definitions in OpenAI format.
|
|
||||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
|
||||||
max_tokens: Maximum tokens in response.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LLMResponse with content and/or tool calls.
|
|
||||||
"""
|
"""
|
||||||
original_model = model or self.default_model
|
original_model = model or self.default_model
|
||||||
model = self._resolve_model(original_model)
|
resolved = self._resolve_model(original_model)
|
||||||
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
|
||||||
|
|
||||||
if self._supports_cache_control(original_model):
|
if self._supports_cache_control(original_model):
|
||||||
messages, tools = self._apply_cache_control(messages, tools)
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
|
||||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
|
||||||
max_tokens = max(1, max_tokens)
|
max_tokens = max(1, max_tokens)
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model,
|
"model": resolved,
|
||||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
|
"messages": self._sanitize_messages(
|
||||||
|
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
|
||||||
|
),
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
@@ -267,21 +260,15 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if self._gateway:
|
if self._gateway:
|
||||||
kwargs.update(self._gateway.litellm_kwargs)
|
kwargs.update(self._gateway.litellm_kwargs)
|
||||||
|
|
||||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
self._apply_model_overrides(resolved, kwargs)
|
||||||
self._apply_model_overrides(model, kwargs)
|
|
||||||
|
|
||||||
if self._langsmith_enabled:
|
if self._langsmith_enabled:
|
||||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||||
|
|
||||||
# Pass api_key directly — more reliable than env vars alone
|
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
kwargs["api_key"] = self.api_key
|
kwargs["api_key"] = self.api_key
|
||||||
|
|
||||||
# Pass api_base for custom endpoints
|
|
||||||
if self.api_base:
|
if self.api_base:
|
||||||
kwargs["api_base"] = self.api_base
|
kwargs["api_base"] = self.api_base
|
||||||
|
|
||||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
|
||||||
if self.extra_headers:
|
if self.extra_headers:
|
||||||
kwargs["extra_headers"] = self.extra_headers
|
kwargs["extra_headers"] = self.extra_headers
|
||||||
|
|
||||||
@@ -293,11 +280,66 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
kwargs["tool_choice"] = tool_choice or "auto"
|
kwargs["tool_choice"] = tool_choice or "auto"
|
||||||
|
|
||||||
|
return kwargs, original_model
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Send a chat completion request via LiteLLM."""
|
||||||
|
kwargs, _ = self._build_chat_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response = await acompletion(**kwargs)
|
response = await acompletion(**kwargs)
|
||||||
return self._parse_response(response)
|
return self._parse_response(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Return error as content for graceful handling
|
return LLMResponse(
|
||||||
|
content=f"Error calling LLM: {str(e)}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
|
||||||
|
kwargs, _ = self._build_chat_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
kwargs["stream"] = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream = await acompletion(**kwargs)
|
||||||
|
chunks: list[Any] = []
|
||||||
|
async for chunk in stream:
|
||||||
|
chunks.append(chunk)
|
||||||
|
if on_content_delta:
|
||||||
|
delta = chunk.choices[0].delta if chunk.choices else None
|
||||||
|
text = getattr(delta, "content", None) if delta else None
|
||||||
|
if text:
|
||||||
|
await on_content_delta(text)
|
||||||
|
|
||||||
|
full_response = litellm.stream_chunk_builder(
|
||||||
|
chunks, messages=kwargs["messages"],
|
||||||
|
)
|
||||||
|
return self._parse_response(full_response)
|
||||||
|
except Exception as e:
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=f"Error calling LLM: {str(e)}",
|
content=f"Error calling LLM: {str(e)}",
|
||||||
finish_reason="error",
|
finish_reason="error",
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
|
|||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
_response = LLMResponse(content="ok", tool_calls=[])
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||||
|
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||||
|
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=MessageBus(),
|
bus=MessageBus(),
|
||||||
@@ -167,6 +169,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
|||||||
order.append("llm")
|
order.append("llm")
|
||||||
return LLMResponse(content="ok", tool_calls=[])
|
return LLMResponse(content="ok", tool_calls=[])
|
||||||
loop.provider.chat_with_retry = track_llm
|
loop.provider.chat_with_retry = track_llm
|
||||||
|
loop.provider.chat_stream_with_retry = track_llm
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
session.messages = [
|
session.messages = [
|
||||||
|
|||||||
Reference in New Issue
Block a user