From e79b9f4a831ab265639cfc95dbbbb5a6152d5cfc Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 22 Mar 2026 02:38:34 +0000 Subject: [PATCH] feat(agent): add streaming groundwork for future TUI Preserve the provider and agent-loop streaming primitives plus the CLI experiment scaffolding so this work can be resumed later without blocking urgent bug fixes on main. Made-with: Cursor --- nanobot/agent/loop.py | 65 +++++++--- nanobot/cli/commands.py | 39 ++++-- nanobot/providers/base.py | 85 ++++++++++++ nanobot/providers/litellm_provider.py | 164 +++++++++++++++--------- tests/test_loop_consolidation_tokens.py | 5 +- 5 files changed, 268 insertions(+), 90 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b8d1647..093f0e2 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -212,8 +212,16 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, ) -> tuple[str | None, list[str], list[dict]]: - """Run the agent iteration loop.""" + """Run the agent iteration loop. + + *on_stream*: called with each content delta during streaming. + *on_stream_end(resuming)*: called when a streaming session finishes. + ``resuming=True`` means tool calls follow (spinner should restart); + ``resuming=False`` means this is the final response. + """ messages = initial_messages iteration = 0 final_content = None @@ -224,11 +232,20 @@ class AgentLoop: tool_defs = self.tools.get_definitions() - response = await self.provider.chat_with_retry( - messages=messages, - tools=tool_defs, - model=self.model, - ) + if on_stream: + response = await self.provider.chat_stream_with_retry( + messages=messages, + tools=tool_defs, + model=self.model, + on_content_delta=on_stream, + ) + else: + response = await self.provider.chat_with_retry( + messages=messages, + tools=tool_defs, + model=self.model, + ) + usage = response.usage or {} self._last_usage = { "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0), @@ -236,10 +253,14 @@ class AgentLoop: } if response.has_tool_calls: + if on_stream and on_stream_end: + await on_stream_end(resuming=True) + if on_progress: - thought = self._strip_think(response.content) - if thought: - await on_progress(thought) + if not on_stream: + thought = self._strip_think(response.content) + if thought: + await on_progress(thought) tool_hint = self._tool_hint(response.tool_calls) tool_hint = self._strip_think(tool_hint) await on_progress(tool_hint, tool_hint=True) @@ -263,9 +284,10 @@ class AgentLoop: messages, tool_call.id, tool_call.name, result ) else: + if on_stream and on_stream_end: + await on_stream_end(resuming=False) + clean = self._strip_think(response.content) - # Don't persist error responses to session history — they can - # poison the context and cause permanent 400 loops (#1303). if response.finish_reason == "error": logger.error("LLM returned error: {}", (clean or "")[:200]) final_content = clean or "Sorry, I encountered an error calling the AI model." @@ -400,6 +422,8 @@ class AgentLoop: msg: InboundMessage, session_key: str | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -412,7 +436,6 @@ class AgentLoop: await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) - # Subagent results should be assistant role, other system messages use user role current_role = "assistant" if msg.sender_id == "subagent" else "user" messages = self.context.build_messages( history=history, @@ -486,7 +509,10 @@ class AgentLoop: )) final_content, _, all_msgs = await self._run_agent_loop( - initial_messages, on_progress=on_progress or _bus_progress, + initial_messages, + on_progress=on_progress or _bus_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, ) if final_content is None: @@ -501,9 +527,13 @@ class AgentLoop: preview = final_content[:120] + "..." if len(final_content) > 120 else final_content logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) + + meta = dict(msg.metadata or {}) + if on_stream is not None: + meta["_streamed"] = True return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, - metadata=msg.metadata or {}, + metadata=meta, ) @staticmethod @@ -592,8 +622,13 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", on_progress: Callable[[str], Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, ) -> OutboundMessage | None: """Process a message directly and return the outbound payload.""" await self._connect_mcp() msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) - return await self._process_message(msg, session_key=session_key, on_progress=on_progress) + return await self._process_message( + msg, session_key=session_key, on_progress=on_progress, + on_stream=on_stream, on_stream_end=on_stream_end, + ) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index ea06acb..7639b3d 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -207,6 +207,10 @@ class _ThinkingSpinner: self._active = False if self._spinner: self._spinner.stop() + # Force-clear the spinner line: Rich Live's transient cleanup + # occasionally loses a race with its own render thread. + console.file.write("\033[2K\r") + console.file.flush() return False @contextmanager @@ -214,6 +218,8 @@ class _ThinkingSpinner: """Temporarily stop spinner while printing progress.""" if self._spinner and self._active: self._spinner.stop() + console.file.write("\033[2K\r") + console.file.flush() try: yield finally: @@ -770,16 +776,25 @@ def agent( async def run_once(): nonlocal _thinking _thinking = _ThinkingSpinner(enabled=not logs) - with _thinking: + + with _thinking or nullcontext(): response = await agent_loop.process_direct( - message, session_id, on_progress=_cli_progress, + message, session_id, + on_progress=_cli_progress, ) - _thinking = None - _print_agent_response( - response.content if response else "", - render_markdown=markdown, - metadata=response.metadata if response else None, - ) + + if _thinking: + _thinking.__exit__(None, None, None) + _thinking = None + + if response and response.content: + _print_agent_response( + response.content, + render_markdown=markdown, + metadata=response.metadata, + ) + else: + console.print() await agent_loop.close_mcp() asyncio.run(run_once()) @@ -820,6 +835,7 @@ def agent( while True: try: msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + if msg.metadata.get("_progress"): is_tool_hint = msg.metadata.get("_tool_hint", False) ch = agent_loop.channels_config @@ -834,6 +850,7 @@ def agent( if msg.content: turn_response.append((msg.content, dict(msg.metadata or {}))) turn_done.set() + elif msg.content: await _print_interactive_response( msg.content, @@ -872,11 +889,7 @@ def agent( content=user_input, )) - nonlocal _thinking - _thinking = _ThinkingSpinner(enabled=not logs) - with _thinking: - await turn_done.wait() - _thinking = None + await turn_done.wait() if turn_response: content, meta = turn_response[0] diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 8f9b2ba..046458d 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -3,6 +3,7 @@ import asyncio import json from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any @@ -223,6 +224,90 @@ class LLMProvider(ABC): except Exception as exc: return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + """Stream a chat completion, calling *on_content_delta* for each text chunk. + + Returns the same ``LLMResponse`` as :meth:`chat`. The default + implementation falls back to a non-streaming call and delivers the + full content as a single delta. Providers that support native + streaming should override this method. + """ + response = await self.chat( + messages=messages, tools=tools, model=model, + max_tokens=max_tokens, temperature=temperature, + reasoning_effort=reasoning_effort, tool_choice=tool_choice, + ) + if on_content_delta and response.content: + await on_content_delta(response.content) + return response + + async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse: + """Call chat_stream() and convert unexpected exceptions to error responses.""" + try: + return await self.chat_stream(**kwargs) + except asyncio.CancelledError: + raise + except Exception as exc: + return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") + + async def chat_stream_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: object = _SENTINEL, + temperature: object = _SENTINEL, + reasoning_effort: object = _SENTINEL, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + """Call chat_stream() with retry on transient provider failures.""" + if max_tokens is self._SENTINEL: + max_tokens = self.generation.max_tokens + if temperature is self._SENTINEL: + temperature = self.generation.temperature + if reasoning_effort is self._SENTINEL: + reasoning_effort = self.generation.reasoning_effort + + kw: dict[str, Any] = dict( + messages=messages, tools=tools, model=model, + max_tokens=max_tokens, temperature=temperature, + reasoning_effort=reasoning_effort, tool_choice=tool_choice, + on_content_delta=on_content_delta, + ) + + for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): + response = await self._safe_chat_stream(**kw) + + if response.finish_reason != "error": + return response + + if not self._is_transient_error(response.content): + stripped = self._strip_image_content(messages) + if stripped is not None: + logger.warning("Non-transient LLM error with image content, retrying without images") + return await self._safe_chat_stream(**{**kw, "messages": stripped}) + return response + + logger.warning( + "LLM transient error (attempt {}/{}), retrying in {}s: {}", + attempt, len(self._CHAT_RETRY_DELAYS), delay, + (response.content or "")[:120].lower(), + ) + await asyncio.sleep(delay) + + return await self._safe_chat_stream(**kw) + async def chat_with_retry( self, messages: list[dict[str, Any]], diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 20c3d25..9aa0ba6 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -4,6 +4,7 @@ import hashlib import os import secrets import string +from collections.abc import Awaitable, Callable from typing import Any import json_repair @@ -223,6 +224,64 @@ class LiteLLMProvider(LLMProvider): clean["tool_call_id"] = map_id(clean["tool_call_id"]) return sanitized + def _build_chat_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + ) -> tuple[dict[str, Any], str]: + """Build the kwargs dict for ``acompletion``. + + Returns ``(kwargs, original_model)`` so callers can reuse the + original model string for downstream logic. + """ + original_model = model or self.default_model + resolved = self._resolve_model(original_model) + extra_msg_keys = self._extra_msg_keys(original_model, resolved) + + if self._supports_cache_control(original_model): + messages, tools = self._apply_cache_control(messages, tools) + + max_tokens = max(1, max_tokens) + + kwargs: dict[str, Any] = { + "model": resolved, + "messages": self._sanitize_messages( + self._sanitize_empty_content(messages), extra_keys=extra_msg_keys, + ), + "max_tokens": max_tokens, + "temperature": temperature, + } + + if self._gateway: + kwargs.update(self._gateway.litellm_kwargs) + + self._apply_model_overrides(resolved, kwargs) + + if self._langsmith_enabled: + kwargs.setdefault("callbacks", []).append("langsmith") + + if self.api_key: + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + kwargs["drop_params"] = True + + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice or "auto" + + return kwargs, original_model + async def chat( self, messages: list[dict[str, Any]], @@ -233,71 +292,54 @@ class LiteLLMProvider(LLMProvider): reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: - """ - Send a chat completion request via LiteLLM. - - Args: - messages: List of message dicts with 'role' and 'content'. - tools: Optional list of tool definitions in OpenAI format. - model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5'). - max_tokens: Maximum tokens in response. - temperature: Sampling temperature. - - Returns: - LLMResponse with content and/or tool calls. - """ - original_model = model or self.default_model - model = self._resolve_model(original_model) - extra_msg_keys = self._extra_msg_keys(original_model, model) - - if self._supports_cache_control(original_model): - messages, tools = self._apply_cache_control(messages, tools) - - # Clamp max_tokens to at least 1 — negative or zero values cause - # LiteLLM to reject the request with "max_tokens must be at least 1". - max_tokens = max(1, max_tokens) - - kwargs: dict[str, Any] = { - "model": model, - "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys), - "max_tokens": max_tokens, - "temperature": temperature, - } - - if self._gateway: - kwargs.update(self._gateway.litellm_kwargs) - - # Apply model-specific overrides (e.g. kimi-k2.5 temperature) - self._apply_model_overrides(model, kwargs) - - if self._langsmith_enabled: - kwargs.setdefault("callbacks", []).append("langsmith") - - # Pass api_key directly — more reliable than env vars alone - if self.api_key: - kwargs["api_key"] = self.api_key - - # Pass api_base for custom endpoints - if self.api_base: - kwargs["api_base"] = self.api_base - - # Pass extra headers (e.g. APP-Code for AiHubMix) - if self.extra_headers: - kwargs["extra_headers"] = self.extra_headers - - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - kwargs["drop_params"] = True - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - + """Send a chat completion request via LiteLLM.""" + kwargs, _ = self._build_chat_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) try: response = await acompletion(**kwargs) return self._parse_response(response) except Exception as e: - # Return error as content for graceful handling + return LLMResponse( + content=f"Error calling LLM: {str(e)}", + finish_reason="error", + ) + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + """Stream a chat completion via LiteLLM, forwarding text deltas.""" + kwargs, _ = self._build_chat_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + kwargs["stream"] = True + + try: + stream = await acompletion(**kwargs) + chunks: list[Any] = [] + async for chunk in stream: + chunks.append(chunk) + if on_content_delta: + delta = chunk.choices[0].delta if chunk.choices else None + text = getattr(delta, "content", None) if delta else None + if text: + await on_content_delta(text) + + full_response = litellm.stream_chunk_builder( + chunks, messages=kwargs["messages"], + ) + return self._parse_response(full_response) + except Exception as e: return LLMResponse( content=f"Error calling LLM: {str(e)}", finish_reason="error", diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py index b0f3dda..87d8d29 100644 --- a/tests/test_loop_consolidation_tokens.py +++ b/tests/test_loop_consolidation_tokens.py @@ -12,7 +12,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) - provider = MagicMock() provider.get_default_model.return_value = "test-model" provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") - provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + _response = LLMResponse(content="ok", tool_calls=[]) + provider.chat_with_retry = AsyncMock(return_value=_response) + provider.chat_stream_with_retry = AsyncMock(return_value=_response) loop = AgentLoop( bus=MessageBus(), @@ -167,6 +169,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> order.append("llm") return LLMResponse(content="ok", tool_calls=[]) loop.provider.chat_with_retry = track_llm + loop.provider.chat_stream_with_retry = track_llm session = loop.sessions.get_or_create("cli:test") session.messages = [