feat(agent): add streaming groundwork for future TUI

Preserve the provider-level and agent-loop streaming primitives, plus the CLI experiment scaffolding, so this work can be resumed later without blocking urgent bug fixes on main.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-22 02:38:34 +00:00
committed by Xubin Ren
parent 5fd66cae5c
commit e79b9f4a83
5 changed files with 268 additions and 90 deletions

View File

@@ -212,8 +212,16 @@ class AgentLoop:
self, self,
initial_messages: list[dict], initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None, on_progress: Callable[..., Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict]]: ) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop.""" """Run the agent iteration loop.
*on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
"""
messages = initial_messages messages = initial_messages
iteration = 0 iteration = 0
final_content = None final_content = None
@@ -224,11 +232,20 @@ class AgentLoop:
tool_defs = self.tools.get_definitions() tool_defs = self.tools.get_definitions()
response = await self.provider.chat_with_retry( if on_stream:
messages=messages, response = await self.provider.chat_stream_with_retry(
tools=tool_defs, messages=messages,
model=self.model, tools=tool_defs,
) model=self.model,
on_content_delta=on_stream,
)
else:
response = await self.provider.chat_with_retry(
messages=messages,
tools=tool_defs,
model=self.model,
)
usage = response.usage or {} usage = response.usage or {}
self._last_usage = { self._last_usage = {
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0), "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
@@ -236,10 +253,14 @@ class AgentLoop:
} }
if response.has_tool_calls: if response.has_tool_calls:
if on_stream and on_stream_end:
await on_stream_end(resuming=True)
if on_progress: if on_progress:
thought = self._strip_think(response.content) if not on_stream:
if thought: thought = self._strip_think(response.content)
await on_progress(thought) if thought:
await on_progress(thought)
tool_hint = self._tool_hint(response.tool_calls) tool_hint = self._tool_hint(response.tool_calls)
tool_hint = self._strip_think(tool_hint) tool_hint = self._strip_think(tool_hint)
await on_progress(tool_hint, tool_hint=True) await on_progress(tool_hint, tool_hint=True)
@@ -263,9 +284,10 @@ class AgentLoop:
messages, tool_call.id, tool_call.name, result messages, tool_call.id, tool_call.name, result
) )
else: else:
if on_stream and on_stream_end:
await on_stream_end(resuming=False)
clean = self._strip_think(response.content) clean = self._strip_think(response.content)
# Don't persist error responses to session history — they can
# poison the context and cause permanent 400 loops (#1303).
if response.finish_reason == "error": if response.finish_reason == "error":
logger.error("LLM returned error: {}", (clean or "")[:200]) logger.error("LLM returned error: {}", (clean or "")[:200])
final_content = clean or "Sorry, I encountered an error calling the AI model." final_content = clean or "Sorry, I encountered an error calling the AI model."
@@ -400,6 +422,8 @@ class AgentLoop:
msg: InboundMessage, msg: InboundMessage,
session_key: str | None = None, session_key: str | None = None,
on_progress: Callable[[str], Awaitable[None]] | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None: ) -> OutboundMessage | None:
"""Process a single inbound message and return the response.""" """Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id") # System messages: parse origin from chat_id ("channel:chat_id")
@@ -412,7 +436,6 @@ class AgentLoop:
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0) history = session.get_history(max_messages=0)
# Subagent results should be assistant role, other system messages use user role
current_role = "assistant" if msg.sender_id == "subagent" else "user" current_role = "assistant" if msg.sender_id == "subagent" else "user"
messages = self.context.build_messages( messages = self.context.build_messages(
history=history, history=history,
@@ -486,7 +509,10 @@ class AgentLoop:
)) ))
final_content, _, all_msgs = await self._run_agent_loop( final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress, initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
) )
if final_content is None: if final_content is None:
@@ -501,9 +527,13 @@ class AgentLoop:
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {})
if on_stream is not None:
meta["_streamed"] = True
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content, channel=msg.channel, chat_id=msg.chat_id, content=final_content,
metadata=msg.metadata or {}, metadata=meta,
) )
@staticmethod @staticmethod
@@ -592,8 +622,13 @@ class AgentLoop:
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
on_progress: Callable[[str], Awaitable[None]] | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None: ) -> OutboundMessage | None:
"""Process a message directly and return the outbound payload.""" """Process a message directly and return the outbound payload."""
await self._connect_mcp() await self._connect_mcp()
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
return await self._process_message(msg, session_key=session_key, on_progress=on_progress) return await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
on_stream=on_stream, on_stream_end=on_stream_end,
)

View File

@@ -207,6 +207,10 @@ class _ThinkingSpinner:
self._active = False self._active = False
if self._spinner: if self._spinner:
self._spinner.stop() self._spinner.stop()
# Force-clear the spinner line: Rich Live's transient cleanup
# occasionally loses a race with its own render thread.
console.file.write("\033[2K\r")
console.file.flush()
return False return False
@contextmanager @contextmanager
@@ -214,6 +218,8 @@ class _ThinkingSpinner:
"""Temporarily stop spinner while printing progress.""" """Temporarily stop spinner while printing progress."""
if self._spinner and self._active: if self._spinner and self._active:
self._spinner.stop() self._spinner.stop()
console.file.write("\033[2K\r")
console.file.flush()
try: try:
yield yield
finally: finally:
@@ -770,16 +776,25 @@ def agent(
async def run_once(): async def run_once():
nonlocal _thinking nonlocal _thinking
_thinking = _ThinkingSpinner(enabled=not logs) _thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
with _thinking or nullcontext():
response = await agent_loop.process_direct( response = await agent_loop.process_direct(
message, session_id, on_progress=_cli_progress, message, session_id,
on_progress=_cli_progress,
) )
_thinking = None
_print_agent_response( if _thinking:
response.content if response else "", _thinking.__exit__(None, None, None)
render_markdown=markdown, _thinking = None
metadata=response.metadata if response else None,
) if response and response.content:
_print_agent_response(
response.content,
render_markdown=markdown,
metadata=response.metadata,
)
else:
console.print()
await agent_loop.close_mcp() await agent_loop.close_mcp()
asyncio.run(run_once()) asyncio.run(run_once())
@@ -820,6 +835,7 @@ def agent(
while True: while True:
try: try:
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
if msg.metadata.get("_progress"): if msg.metadata.get("_progress"):
is_tool_hint = msg.metadata.get("_tool_hint", False) is_tool_hint = msg.metadata.get("_tool_hint", False)
ch = agent_loop.channels_config ch = agent_loop.channels_config
@@ -834,6 +850,7 @@ def agent(
if msg.content: if msg.content:
turn_response.append((msg.content, dict(msg.metadata or {}))) turn_response.append((msg.content, dict(msg.metadata or {})))
turn_done.set() turn_done.set()
elif msg.content: elif msg.content:
await _print_interactive_response( await _print_interactive_response(
msg.content, msg.content,
@@ -872,11 +889,7 @@ def agent(
content=user_input, content=user_input,
)) ))
nonlocal _thinking await turn_done.wait()
_thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
await turn_done.wait()
_thinking = None
if turn_response: if turn_response:
content, meta = turn_response[0] content, meta = turn_response[0]

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
except Exception as exc: except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_stream(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
Returns the same ``LLMResponse`` as :meth:`chat`. The default
implementation falls back to a non-streaming call and delivers the
full content as a single delta. Providers that support native
streaming should override this method.
"""
response = await self.chat(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
if on_content_delta and response.content:
await on_content_delta(response.content)
return response
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
"""Call chat_stream() and convert unexpected exceptions to error responses."""
try:
return await self.chat_stream(**kwargs)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_stream_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = _SENTINEL,
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat_stream() with retry on transient provider failures."""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat_stream(**kw)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat_stream(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat_stream(**kw)
async def chat_with_retry( async def chat_with_retry(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],

View File

@@ -4,6 +4,7 @@ import hashlib
import os import os
import secrets import secrets
import string import string
from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
import json_repair import json_repair
@@ -223,43 +224,35 @@ class LiteLLMProvider(LLMProvider):
clean["tool_call_id"] = map_id(clean["tool_call_id"]) clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized return sanitized
async def chat( def _build_chat_kwargs(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None,
model: str | None = None, model: str | None,
max_tokens: int = 4096, max_tokens: int,
temperature: float = 0.7, temperature: float,
reasoning_effort: str | None = None, reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None = None, tool_choice: str | dict[str, Any] | None,
) -> LLMResponse: ) -> tuple[dict[str, Any], str]:
""" """Build the kwargs dict for ``acompletion``.
Send a chat completion request via LiteLLM.
Args: Returns ``(kwargs, original_model)`` so callers can reuse the
messages: List of message dicts with 'role' and 'content'. original model string for downstream logic.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
Returns:
LLMResponse with content and/or tool calls.
""" """
original_model = model or self.default_model original_model = model or self.default_model
model = self._resolve_model(original_model) resolved = self._resolve_model(original_model)
extra_msg_keys = self._extra_msg_keys(original_model, model) extra_msg_keys = self._extra_msg_keys(original_model, resolved)
if self._supports_cache_control(original_model): if self._supports_cache_control(original_model):
messages, tools = self._apply_cache_control(messages, tools) messages, tools = self._apply_cache_control(messages, tools)
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens) max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model, "model": resolved,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys), "messages": self._sanitize_messages(
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
),
"max_tokens": max_tokens, "max_tokens": max_tokens,
"temperature": temperature, "temperature": temperature,
} }
@@ -267,21 +260,15 @@ class LiteLLMProvider(LLMProvider):
if self._gateway: if self._gateway:
kwargs.update(self._gateway.litellm_kwargs) kwargs.update(self._gateway.litellm_kwargs)
# Apply model-specific overrides (e.g. kimi-k2.5 temperature) self._apply_model_overrides(resolved, kwargs)
self._apply_model_overrides(model, kwargs)
if self._langsmith_enabled: if self._langsmith_enabled:
kwargs.setdefault("callbacks", []).append("langsmith") kwargs.setdefault("callbacks", []).append("langsmith")
# Pass api_key directly — more reliable than env vars alone
if self.api_key: if self.api_key:
kwargs["api_key"] = self.api_key kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints
if self.api_base: if self.api_base:
kwargs["api_base"] = self.api_base kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers: if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers kwargs["extra_headers"] = self.extra_headers
@@ -293,11 +280,66 @@ class LiteLLMProvider(LLMProvider):
kwargs["tools"] = tools kwargs["tools"] = tools
kwargs["tool_choice"] = tool_choice or "auto" kwargs["tool_choice"] = tool_choice or "auto"
return kwargs, original_model
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""Send a chat completion request via LiteLLM."""
kwargs, _ = self._build_chat_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
try: try:
response = await acompletion(**kwargs) response = await acompletion(**kwargs)
return self._parse_response(response) return self._parse_response(response)
except Exception as e: except Exception as e:
# Return error as content for graceful handling return LLMResponse(
content=f"Error calling LLM: {str(e)}",
finish_reason="error",
)
async def chat_stream(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
kwargs, _ = self._build_chat_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
kwargs["stream"] = True
try:
stream = await acompletion(**kwargs)
chunks: list[Any] = []
async for chunk in stream:
chunks.append(chunk)
if on_content_delta:
delta = chunk.choices[0].delta if chunk.choices else None
text = getattr(delta, "content", None) if delta else None
if text:
await on_content_delta(text)
full_response = litellm.stream_chunk_builder(
chunks, messages=kwargs["messages"],
)
return self._parse_response(full_response)
except Exception as e:
return LLMResponse( return LLMResponse(
content=f"Error calling LLM: {str(e)}", content=f"Error calling LLM: {str(e)}",
finish_reason="error", finish_reason="error",

View File

@@ -12,7 +12,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) _response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response)
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
loop = AgentLoop( loop = AgentLoop(
bus=MessageBus(), bus=MessageBus(),
@@ -167,6 +169,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
order.append("llm") order.append("llm")
return LLMResponse(content="ok", tool_calls=[]) return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm loop.provider.chat_with_retry = track_llm
loop.provider.chat_stream_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
session.messages = [ session.messages = [