feat(agent): add streaming groundwork for future TUI

Preserve the provider and agent-loop streaming primitives plus the CLI experiment scaffolding so this work can be resumed later without blocking urgent bug fixes on main.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-22 02:38:34 +00:00
committed by Xubin Ren
parent 5fd66cae5c
commit e79b9f4a83
5 changed files with 268 additions and 90 deletions

View File

@@ -212,8 +212,16 @@ class AgentLoop:
self,
initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop."""
"""Run the agent iteration loop.
*on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
"""
messages = initial_messages
iteration = 0
final_content = None
@@ -224,11 +232,20 @@ class AgentLoop:
tool_defs = self.tools.get_definitions()
response = await self.provider.chat_with_retry(
messages=messages,
tools=tool_defs,
model=self.model,
)
if on_stream:
response = await self.provider.chat_stream_with_retry(
messages=messages,
tools=tool_defs,
model=self.model,
on_content_delta=on_stream,
)
else:
response = await self.provider.chat_with_retry(
messages=messages,
tools=tool_defs,
model=self.model,
)
usage = response.usage or {}
self._last_usage = {
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
@@ -236,10 +253,14 @@ class AgentLoop:
}
if response.has_tool_calls:
if on_stream and on_stream_end:
await on_stream_end(resuming=True)
if on_progress:
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
if not on_stream:
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
tool_hint = self._tool_hint(response.tool_calls)
tool_hint = self._strip_think(tool_hint)
await on_progress(tool_hint, tool_hint=True)
@@ -263,9 +284,10 @@ class AgentLoop:
messages, tool_call.id, tool_call.name, result
)
else:
if on_stream and on_stream_end:
await on_stream_end(resuming=False)
clean = self._strip_think(response.content)
# Don't persist error responses to session history — they can
# poison the context and cause permanent 400 loops (#1303).
if response.finish_reason == "error":
logger.error("LLM returned error: {}", (clean or "")[:200])
final_content = clean or "Sorry, I encountered an error calling the AI model."
@@ -400,6 +422,8 @@ class AgentLoop:
msg: InboundMessage,
session_key: str | None = None,
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
@@ -412,7 +436,6 @@ class AgentLoop:
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0)
# Subagent results should be assistant role, other system messages use user role
current_role = "assistant" if msg.sender_id == "subagent" else "user"
messages = self.context.build_messages(
history=history,
@@ -486,7 +509,10 @@ class AgentLoop:
))
final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
)
if final_content is None:
@@ -501,9 +527,13 @@ class AgentLoop:
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {})
if on_stream is not None:
meta["_streamed"] = True
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
metadata=msg.metadata or {},
metadata=meta,
)
@staticmethod
@@ -592,8 +622,13 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None:
"""Process a message directly and return the outbound payload."""
await self._connect_mcp()
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
return await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
on_stream=on_stream, on_stream_end=on_stream_end,
)

View File

@@ -207,6 +207,10 @@ class _ThinkingSpinner:
self._active = False
if self._spinner:
self._spinner.stop()
# Force-clear the spinner line: Rich Live's transient cleanup
# occasionally loses a race with its own render thread.
console.file.write("\033[2K\r")
console.file.flush()
return False
@contextmanager
@@ -214,6 +218,8 @@ class _ThinkingSpinner:
"""Temporarily stop spinner while printing progress."""
if self._spinner and self._active:
self._spinner.stop()
console.file.write("\033[2K\r")
console.file.flush()
try:
yield
finally:
@@ -770,16 +776,25 @@ def agent(
async def run_once():
nonlocal _thinking
_thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
with _thinking or nullcontext():
response = await agent_loop.process_direct(
message, session_id, on_progress=_cli_progress,
message, session_id,
on_progress=_cli_progress,
)
_thinking = None
_print_agent_response(
response.content if response else "",
render_markdown=markdown,
metadata=response.metadata if response else None,
)
if _thinking:
_thinking.__exit__(None, None, None)
_thinking = None
if response and response.content:
_print_agent_response(
response.content,
render_markdown=markdown,
metadata=response.metadata,
)
else:
console.print()
await agent_loop.close_mcp()
asyncio.run(run_once())
@@ -820,6 +835,7 @@ def agent(
while True:
try:
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
if msg.metadata.get("_progress"):
is_tool_hint = msg.metadata.get("_tool_hint", False)
ch = agent_loop.channels_config
@@ -834,6 +850,7 @@ def agent(
if msg.content:
turn_response.append((msg.content, dict(msg.metadata or {})))
turn_done.set()
elif msg.content:
await _print_interactive_response(
msg.content,
@@ -872,11 +889,7 @@ def agent(
content=user_input,
))
nonlocal _thinking
_thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
await turn_done.wait()
_thinking = None
await turn_done.wait()
if turn_response:
content, meta = turn_response[0]

View File

@@ -3,6 +3,7 @@
import asyncio
import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_stream(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
Returns the same ``LLMResponse`` as :meth:`chat`. The default
implementation falls back to a non-streaming call and delivers the
full content as a single delta. Providers that support native
streaming should override this method.
"""
response = await self.chat(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
if on_content_delta and response.content:
await on_content_delta(response.content)
return response
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
"""Call chat_stream() and convert unexpected exceptions to error responses."""
try:
return await self.chat_stream(**kwargs)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_stream_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = _SENTINEL,
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat_stream() with retry on transient provider failures."""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat_stream(**kw)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat_stream(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat_stream(**kw)
async def chat_with_retry(
self,
messages: list[dict[str, Any]],

View File

@@ -4,6 +4,7 @@ import hashlib
import os
import secrets
import string
from collections.abc import Awaitable, Callable
from typing import Any
import json_repair
@@ -223,43 +224,35 @@ class LiteLLMProvider(LLMProvider):
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
def _build_chat_kwargs(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
tools: list[dict[str, Any]] | None,
model: str | None,
max_tokens: int,
temperature: float,
reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None,
) -> tuple[dict[str, Any], str]:
"""Build the kwargs dict for ``acompletion``.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
Returns:
LLMResponse with content and/or tool calls.
Returns ``(kwargs, original_model)`` so callers can reuse the
original model string for downstream logic.
"""
original_model = model or self.default_model
model = self._resolve_model(original_model)
extra_msg_keys = self._extra_msg_keys(original_model, model)
resolved = self._resolve_model(original_model)
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
if self._supports_cache_control(original_model):
messages, tools = self._apply_cache_control(messages, tools)
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = {
"model": model,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
"model": resolved,
"messages": self._sanitize_messages(
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
),
"max_tokens": max_tokens,
"temperature": temperature,
}
@@ -267,21 +260,15 @@ class LiteLLMProvider(LLMProvider):
if self._gateway:
kwargs.update(self._gateway.litellm_kwargs)
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs)
self._apply_model_overrides(resolved, kwargs)
if self._langsmith_enabled:
kwargs.setdefault("callbacks", []).append("langsmith")
# Pass api_key directly — more reliable than env vars alone
if self.api_key:
kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints
if self.api_base:
kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
@@ -293,11 +280,66 @@ class LiteLLMProvider(LLMProvider):
kwargs["tools"] = tools
kwargs["tool_choice"] = tool_choice or "auto"
return kwargs, original_model
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""Send a chat completion request via LiteLLM."""
kwargs, _ = self._build_chat_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
try:
response = await acompletion(**kwargs)
return self._parse_response(response)
except Exception as e:
# Return error as content for graceful handling
return LLMResponse(
content=f"Error calling LLM: {str(e)}",
finish_reason="error",
)
async def chat_stream(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
kwargs, _ = self._build_chat_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
kwargs["stream"] = True
try:
stream = await acompletion(**kwargs)
chunks: list[Any] = []
async for chunk in stream:
chunks.append(chunk)
if on_content_delta:
delta = chunk.choices[0].delta if chunk.choices else None
text = getattr(delta, "content", None) if delta else None
if text:
await on_content_delta(text)
full_response = litellm.stream_chunk_builder(
chunks, messages=kwargs["messages"],
)
return self._parse_response(full_response)
except Exception as e:
return LLMResponse(
content=f"Error calling LLM: {str(e)}",
finish_reason="error",

View File

@@ -12,7 +12,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
_response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response)
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
loop = AgentLoop(
bus=MessageBus(),
@@ -167,6 +169,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
loop.provider.chat_stream_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [