feat(agent): add streaming groundwork for future TUI
Preserve the provider and agent-loop streaming primitives plus the CLI experiment scaffolding so this work can be resumed later without blocking urgent bug fixes on main. Made-with: Cursor
This commit is contained in:
@@ -212,8 +212,16 @@ class AgentLoop:
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop."""
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
*on_stream*: called with each content delta during streaming.
|
||||
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
@@ -224,11 +232,20 @@ class AgentLoop:
|
||||
|
||||
tool_defs = self.tools.get_definitions()
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
)
|
||||
if on_stream:
|
||||
response = await self.provider.chat_stream_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
on_content_delta=on_stream,
|
||||
)
|
||||
else:
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
usage = response.usage or {}
|
||||
self._last_usage = {
|
||||
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
|
||||
@@ -236,10 +253,14 @@ class AgentLoop:
|
||||
}
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_stream and on_stream_end:
|
||||
await on_stream_end(resuming=True)
|
||||
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
if not on_stream:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = self._tool_hint(response.tool_calls)
|
||||
tool_hint = self._strip_think(tool_hint)
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
@@ -263,9 +284,10 @@ class AgentLoop:
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
if on_stream and on_stream_end:
|
||||
await on_stream_end(resuming=False)
|
||||
|
||||
clean = self._strip_think(response.content)
|
||||
# Don't persist error responses to session history — they can
|
||||
# poison the context and cause permanent 400 loops (#1303).
|
||||
if response.finish_reason == "error":
|
||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
@@ -400,6 +422,8 @@ class AgentLoop:
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
@@ -412,7 +436,6 @@ class AgentLoop:
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
# Subagent results should be assistant role, other system messages use user role
|
||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
@@ -486,7 +509,10 @@ class AgentLoop:
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
@@ -501,9 +527,13 @@ class AgentLoop:
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
if on_stream is not None:
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
metadata=msg.metadata or {},
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -592,8 +622,13 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a message directly and return the outbound payload."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
return await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
@@ -207,6 +207,10 @@ class _ThinkingSpinner:
|
||||
self._active = False
|
||||
if self._spinner:
|
||||
self._spinner.stop()
|
||||
# Force-clear the spinner line: Rich Live's transient cleanup
|
||||
# occasionally loses a race with its own render thread.
|
||||
console.file.write("\033[2K\r")
|
||||
console.file.flush()
|
||||
return False
|
||||
|
||||
@contextmanager
|
||||
@@ -214,6 +218,8 @@ class _ThinkingSpinner:
|
||||
"""Temporarily stop spinner while printing progress."""
|
||||
if self._spinner and self._active:
|
||||
self._spinner.stop()
|
||||
console.file.write("\033[2K\r")
|
||||
console.file.flush()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@@ -770,16 +776,25 @@ def agent(
|
||||
async def run_once():
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
with _thinking:
|
||||
|
||||
with _thinking or nullcontext():
|
||||
response = await agent_loop.process_direct(
|
||||
message, session_id, on_progress=_cli_progress,
|
||||
message, session_id,
|
||||
on_progress=_cli_progress,
|
||||
)
|
||||
_thinking = None
|
||||
_print_agent_response(
|
||||
response.content if response else "",
|
||||
render_markdown=markdown,
|
||||
metadata=response.metadata if response else None,
|
||||
)
|
||||
|
||||
if _thinking:
|
||||
_thinking.__exit__(None, None, None)
|
||||
_thinking = None
|
||||
|
||||
if response and response.content:
|
||||
_print_agent_response(
|
||||
response.content,
|
||||
render_markdown=markdown,
|
||||
metadata=response.metadata,
|
||||
)
|
||||
else:
|
||||
console.print()
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
@@ -820,6 +835,7 @@ def agent(
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||
ch = agent_loop.channels_config
|
||||
@@ -834,6 +850,7 @@ def agent(
|
||||
if msg.content:
|
||||
turn_response.append((msg.content, dict(msg.metadata or {})))
|
||||
turn_done.set()
|
||||
|
||||
elif msg.content:
|
||||
await _print_interactive_response(
|
||||
msg.content,
|
||||
@@ -872,11 +889,7 @@ def agent(
|
||||
content=user_input,
|
||||
))
|
||||
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
with _thinking:
|
||||
await turn_done.wait()
|
||||
_thinking = None
|
||||
await turn_done.wait()
|
||||
|
||||
if turn_response:
|
||||
content, meta = turn_response[0]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
|
||||
|
||||
Returns the same ``LLMResponse`` as :meth:`chat`. The default
|
||||
implementation falls back to a non-streaming call and delivers the
|
||||
full content as a single delta. Providers that support native
|
||||
streaming should override this method.
|
||||
"""
|
||||
response = await self.chat(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
if on_content_delta and response.content:
|
||||
await on_content_delta(response.content)
|
||||
return response
|
||||
|
||||
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||
"""Call chat_stream() and convert unexpected exceptions to error responses."""
|
||||
try:
|
||||
return await self.chat_stream(**kwargs)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||
|
||||
async def chat_stream_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = _SENTINEL,
|
||||
temperature: object = _SENTINEL,
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat_stream() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
temperature = self.generation.temperature
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
|
||||
kw: dict[str, Any] = dict(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat_stream(**kw)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return await self._safe_chat_stream(**kw)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@@ -4,6 +4,7 @@ import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
@@ -223,6 +224,64 @@ class LiteLLMProvider(LLMProvider):
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
def _build_chat_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
"""Build the kwargs dict for ``acompletion``.
|
||||
|
||||
Returns ``(kwargs, original_model)`` so callers can reuse the
|
||||
original model string for downstream logic.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
resolved = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": resolved,
|
||||
"messages": self._sanitize_messages(
|
||||
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
|
||||
),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if self._gateway:
|
||||
kwargs.update(self._gateway.litellm_kwargs)
|
||||
|
||||
self._apply_model_overrides(resolved, kwargs)
|
||||
|
||||
if self._langsmith_enabled:
|
||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return kwargs, original_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -233,71 +292,54 @@ class LiteLLMProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
model = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if self._gateway:
|
||||
kwargs.update(self._gateway.litellm_kwargs)
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
if self._langsmith_enabled:
|
||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||
|
||||
# Pass api_key directly — more reliable than env vars alone
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
"""Send a chat completion request via LiteLLM."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
# Return error as content for graceful handling
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
|
||||
try:
|
||||
stream = await acompletion(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
if on_content_delta:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
text = getattr(delta, "content", None) if delta else None
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
|
||||
full_response = litellm.stream_chunk_builder(
|
||||
chunks, messages=kwargs["messages"],
|
||||
)
|
||||
return self._parse_response(full_response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
|
||||
@@ -12,7 +12,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
_response = LLMResponse(content="ok", tool_calls=[])
|
||||
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
@@ -167,6 +169,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
||||
order.append("llm")
|
||||
return LLMResponse(content="ok", tool_calls=[])
|
||||
loop.provider.chat_with_retry = track_llm
|
||||
loop.provider.chat_stream_with_retry = track_llm
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
|
||||
Reference in New Issue
Block a user