From bd621df57f7b4ab4122d57bf04d797eb1e523690 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 22 Mar 2026 15:34:15 +0000 Subject: [PATCH] feat: add streaming channel support with automatic fallback Provider layer: add chat_stream / chat_stream_with_retry to all providers (base fallback, litellm, custom, azure, codex). Refactor shared kwargs building in each provider. Channel layer: BaseChannel gains send_delta (no-op) and supports_streaming (checks config + method override). ChannelManager routes _stream_delta / _stream_end to send_delta, skips _streamed final messages. AgentLoop._dispatch builds bus-backed on_stream/on_stream_end callbacks when _wants_stream metadata is set. Non-streaming path unchanged. CLI: clean up spinner ANSI workarounds, simplify commands.py flow. Made-with: Cursor --- nanobot/agent/loop.py | 18 +++- nanobot/channels/base.py | 17 ++- nanobot/channels/manager.py | 7 +- nanobot/cli/commands.py | 39 +++---- nanobot/config/schema.py | 1 + nanobot/providers/azure_openai_provider.py | 96 +++++++++++++++++ nanobot/providers/custom_provider.py | 116 ++++++++++++++++----- nanobot/providers/openai_codex_provider.py | 115 ++++++++++---------- 8 files changed, 300 insertions(+), 109 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 093f0e2..1bbb7cf 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -376,7 +376,23 @@ class AgentLoop: """Process a message under the global lock.""" async with self._processing_lock: try: - response = await self._process_message(msg) + on_stream = on_stream_end = None + if msg.metadata.get("_wants_stream"): + async def on_stream(delta: str) -> None: + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content=delta, metadata={"_stream_delta": True}, + )) + + async def on_stream_end(*, resuming: bool = False) -> None: + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="", metadata={"_stream_end": True, "_resuming": resuming}, + )) + + response = await self._process_message( + msg, on_stream=on_stream, on_stream_end=on_stream_end, + ) if response is not None: await self.bus.publish_outbound(response) elif msg.channel == "cli": diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 81f0751..49be390 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -76,6 +76,17 @@ class BaseChannel(ABC): """ pass + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + """Deliver a streaming text chunk. Override in subclass to enable streaming.""" + pass + + @property + def supports_streaming(self) -> bool: + """True when config enables streaming AND this subclass implements send_delta.""" + cfg = self.config + streaming = cfg.get("streaming", False) if isinstance(cfg, dict) else getattr(cfg, "streaming", False) + return bool(streaming) and type(self).send_delta is not BaseChannel.send_delta + def is_allowed(self, sender_id: str) -> bool: """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all.""" allow_list = getattr(self.config, "allow_from", []) @@ -116,13 +127,17 @@ class BaseChannel(ABC): ) return + meta = metadata or {} + if self.supports_streaming: + meta = {**meta, "_wants_stream": True} + msg = InboundMessage( channel=self.name, sender_id=str(sender_id), chat_id=str(chat_id), content=content, media=media or [], - metadata=metadata or {}, + metadata=meta, session_key_override=session_key, ) diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 3820c10..3a53b63 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -130,7 +130,12 @@ class ChannelManager: channel = self.channels.get(msg.channel) if channel: try: - await channel.send(msg) + if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"): + await channel.send_delta(msg.chat_id, msg.content, msg.metadata) + elif msg.metadata.get("_streamed"): + pass + else: + await channel.send(msg) except Exception as e: logger.error("Error sending to {}: {}", msg.channel, e) else: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 7639b3d..ea06acb 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -207,10 +207,6 @@ class _ThinkingSpinner: self._active = False if self._spinner: self._spinner.stop() - # Force-clear the spinner line: Rich Live's transient cleanup - # occasionally loses a race with its own render thread. - console.file.write("\033[2K\r") - console.file.flush() return False @contextmanager @@ -218,8 +214,6 @@ class _ThinkingSpinner: """Temporarily stop spinner while printing progress.""" if self._spinner and self._active: self._spinner.stop() - console.file.write("\033[2K\r") - console.file.flush() try: yield finally: @@ -776,25 +770,16 @@ def agent( async def run_once(): nonlocal _thinking _thinking = _ThinkingSpinner(enabled=not logs) - - with _thinking or nullcontext(): + with _thinking: response = await agent_loop.process_direct( - message, session_id, - on_progress=_cli_progress, + message, session_id, on_progress=_cli_progress, ) - - if _thinking: - _thinking.__exit__(None, None, None) - _thinking = None - - if response and response.content: - _print_agent_response( - response.content, - render_markdown=markdown, - metadata=response.metadata, - ) - else: - console.print() + _thinking = None + _print_agent_response( + response.content if response else "", + render_markdown=markdown, + metadata=response.metadata if response else None, + ) await agent_loop.close_mcp() asyncio.run(run_once()) @@ -835,7 +820,6 @@ def agent( while True: try: msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) - if msg.metadata.get("_progress"): is_tool_hint = msg.metadata.get("_tool_hint", False) ch = agent_loop.channels_config @@ -850,7 +834,6 @@ def agent( if msg.content: turn_response.append((msg.content, dict(msg.metadata or {}))) turn_done.set() - elif msg.content: await _print_interactive_response( msg.content, @@ -889,7 +872,11 @@ def agent( content=user_input, )) - await turn_done.wait() + nonlocal _thinking + _thinking = _ThinkingSpinner(enabled=not logs) + with _thinking: + await turn_done.wait() + _thinking = None if turn_response: content, meta = turn_response[0] diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c884433..5937b2e 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -18,6 +18,7 @@ class ChannelsConfig(Base): Built-in and plugin channel configs are stored as extra fields (dicts). Each channel parses its own config in __init__. + Per-channel "streaming": true enables streaming output (requires send_delta impl). """ model_config = ConfigDict(extra="allow") diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index 05fbac4..d71dae9 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -2,7 +2,9 @@ from __future__ import annotations +import json import uuid +from collections.abc import Awaitable, Callable from typing import Any from urllib.parse import urljoin @@ -208,6 +210,100 @@ class AzureOpenAIProvider(LLMProvider): finish_reason="error", ) + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + """Stream a chat completion via Azure OpenAI SSE.""" + deployment_name = model or self.default_model + url = self._build_chat_url(deployment_name) + headers = self._build_headers() + payload = self._prepare_request_payload( + deployment_name, messages, tools, max_tokens, temperature, + reasoning_effort, tool_choice=tool_choice, + ) + payload["stream"] = True + + try: + async with httpx.AsyncClient(timeout=60.0, verify=True) as client: + async with client.stream("POST", url, headers=headers, json=payload) as response: + if response.status_code != 200: + text = await response.aread() + return LLMResponse( + content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", + finish_reason="error", + ) + return await self._consume_stream(response, on_content_delta) + except Exception as e: + return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error") + + async def _consume_stream( + self, + response: httpx.Response, + on_content_delta: Callable[[str], Awaitable[None]] | None, + ) -> LLMResponse: + """Parse Azure OpenAI SSE stream into an LLMResponse.""" + content_parts: list[str] = [] + tool_call_buffers: dict[int, dict[str, str]] = {} + finish_reason = "stop" + + async for line in response.aiter_lines(): + if not line.startswith("data: "): + continue + data = line[6:].strip() + if data == "[DONE]": + break + try: + chunk = json.loads(data) + except Exception: + continue + + choices = chunk.get("choices") or [] + if not choices: + continue + choice = choices[0] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + delta = choice.get("delta") or {} + + text = delta.get("content") + if text: + content_parts.append(text) + if on_content_delta: + await on_content_delta(text) + + for tc in delta.get("tool_calls") or []: + idx = tc.get("index", 0) + buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""}) + if tc.get("id"): + buf["id"] = tc["id"] + fn = tc.get("function") or {} + if fn.get("name"): + buf["name"] = fn["name"] + if fn.get("arguments"): + buf["arguments"] += fn["arguments"] + + tool_calls = [ + ToolCallRequest( + id=buf["id"], name=buf["name"], + arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {}, + ) + for buf in tool_call_buffers.values() + ] + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + ) + def get_default_model(self) -> str: """Get the default model (also used as default deployment name).""" return self.default_model \ No newline at end of file diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py index 3daa0cc..a47dae7 100644 --- a/nanobot/providers/custom_provider.py +++ b/nanobot/providers/custom_provider.py @@ -3,6 +3,7 @@ from __future__ import annotations import uuid +from collections.abc import Awaitable, Callable from typing import Any import json_repair @@ -22,22 +23,20 @@ class CustomProvider(LLMProvider): ): super().__init__(api_key, api_base) self.default_model = default_model - # Keep affinity stable for this provider instance to improve backend cache locality, - # while still letting users attach provider-specific headers for custom gateways. - default_headers = { - "x-session-affinity": uuid.uuid4().hex, - **(extra_headers or {}), - } self._client = AsyncOpenAI( api_key=api_key, base_url=api_base, - default_headers=default_headers, + default_headers={ + "x-session-affinity": uuid.uuid4().hex, + **(extra_headers or {}), + }, ) - async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: + def _build_kwargs( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None, + model: str | None, max_tokens: int, temperature: float, + reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, + ) -> dict[str, Any]: kwargs: dict[str, Any] = { "model": model or self.default_model, "messages": self._sanitize_empty_content(messages), @@ -48,37 +47,106 @@ class CustomProvider(LLMProvider): kwargs["reasoning_effort"] = reasoning_effort if tools: kwargs.update(tools=tools, tool_choice=tool_choice or "auto") + return kwargs + + def _handle_error(self, e: Exception) -> LLMResponse: + body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}" + return LLMResponse(content=msg, finish_reason="error") + + async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, + model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: + kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) try: return self._parse(await self._client.chat.completions.create(**kwargs)) except Exception as e: - # JSONDecodeError.doc / APIError.response.text may carry the raw body - # (e.g. "unsupported model: xxx") which is far more useful than the - # generic "Expecting value …" message. Truncate to avoid huge HTML pages. - body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) - if body and body.strip(): - return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error") - return LLMResponse(content=f"Error: {e}", finish_reason="error") + return self._handle_error(e) + + async def chat_stream( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, + model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) + kwargs["stream"] = True + try: + stream = await self._client.chat.completions.create(**kwargs) + chunks: list[Any] = [] + async for chunk in stream: + chunks.append(chunk) + if on_content_delta and chunk.choices: + text = getattr(chunk.choices[0].delta, "content", None) + if text: + await on_content_delta(text) + return self._parse_chunks(chunks) + except Exception as e: + return self._handle_error(e) def _parse(self, response: Any) -> LLMResponse: if not response.choices: return LLMResponse( - content="Error: API returned empty choices. This may indicate a temporary service issue or an invalid model response.", - finish_reason="error" + content="Error: API returned empty choices.", + finish_reason="error", ) choice = response.choices[0] msg = choice.message tool_calls = [ - ToolCallRequest(id=tc.id, name=tc.function.name, - arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments) + ToolCallRequest( + id=tc.id, name=tc.function.name, + arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments, + ) for tc in (msg.tool_calls or []) ] u = response.usage return LLMResponse( - content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop", + content=msg.content, tool_calls=tool_calls, + finish_reason=choice.finish_reason or "stop", usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, reasoning_content=getattr(msg, "reasoning_content", None) or None, ) + def _parse_chunks(self, chunks: list[Any]) -> LLMResponse: + """Reassemble streamed chunks into a single LLMResponse.""" + content_parts: list[str] = [] + tc_bufs: dict[int, dict[str, str]] = {} + finish_reason = "stop" + usage: dict[str, int] = {} + + for chunk in chunks: + if not chunk.choices: + if hasattr(chunk, "usage") and chunk.usage: + u = chunk.usage + usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0, + "total_tokens": u.total_tokens or 0} + continue + choice = chunk.choices[0] + if choice.finish_reason: + finish_reason = choice.finish_reason + delta = choice.delta + if delta and delta.content: + content_parts.append(delta.content) + for tc in (delta.tool_calls or []) if delta else []: + buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) + if tc.id: + buf["id"] = tc.id + if tc.function and tc.function.name: + buf["name"] = tc.function.name + if tc.function and tc.function.arguments: + buf["arguments"] += tc.function.arguments + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=[ + ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}) + for b in tc_bufs.values() + ], + finish_reason=finish_reason, + usage=usage, + ) + def get_default_model(self) -> str: return self.default_model - diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index c8f2155..1c6bc70 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import hashlib import json +from collections.abc import Awaitable, Callable from typing import Any, AsyncGenerator import httpx @@ -24,16 +25,16 @@ class OpenAICodexProvider(LLMProvider): super().__init__(api_key=None, api_base=None) self.default_model = default_model - async def chat( + async def _call_codex( self, messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None, + model: str | None, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: + """Shared request logic for both chat() and chat_stream().""" model = model or self.default_model system_prompt, input_items = _convert_messages(messages) @@ -52,33 +53,45 @@ class OpenAICodexProvider(LLMProvider): "tool_choice": tool_choice or "auto", "parallel_tool_calls": True, } - if reasoning_effort: body["reasoning"] = {"effort": reasoning_effort} - if tools: body["tools"] = _convert_tools(tools) - url = DEFAULT_CODEX_URL - try: try: - content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True) + content, tool_calls, finish_reason = await _request_codex( + DEFAULT_CODEX_URL, headers, body, verify=True, + on_content_delta=on_content_delta, + ) except Exception as e: if "CERTIFICATE_VERIFY_FAILED" not in str(e): raise - logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False") - content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False) - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + logger.warning("SSL verification failed for Codex API; retrying with verify=False") + content, tool_calls, finish_reason = await _request_codex( + DEFAULT_CODEX_URL, headers, body, verify=False, + on_content_delta=on_content_delta, + ) + return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason) except Exception as e: - return LLMResponse( - content=f"Error calling Codex: {str(e)}", - finish_reason="error", - ) + return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error") + + async def chat( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, + model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice) + + async def chat_stream( + self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, + model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta) def get_default_model(self) -> str: return self.default_model @@ -107,13 +120,14 @@ async def _request_codex( headers: dict[str, str], body: dict[str, Any], verify: bool, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> tuple[str, list[ToolCallRequest], str]: async with httpx.AsyncClient(timeout=60.0, verify=verify) as client: async with client.stream("POST", url, headers=headers, json=body) as response: if response.status_code != 200: text = await response.aread() raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) - return await _consume_sse(response) + return await _consume_sse(response, on_content_delta) def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: @@ -151,45 +165,28 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st continue if role == "assistant": - # Handle text first. if isinstance(content, str) and content: - input_items.append( - { - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": content}], - "status": "completed", - "id": f"msg_{idx}", - } - ) - # Then handle tool calls. + input_items.append({ + "type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": content}], + "status": "completed", "id": f"msg_{idx}", + }) for tool_call in msg.get("tool_calls", []) or []: fn = tool_call.get("function") or {} call_id, item_id = _split_tool_call_id(tool_call.get("id")) - call_id = call_id or f"call_{idx}" - item_id = item_id or f"fc_{idx}" - input_items.append( - { - "type": "function_call", - "id": item_id, - "call_id": call_id, - "name": fn.get("name"), - "arguments": fn.get("arguments") or "{}", - } - ) + input_items.append({ + "type": "function_call", + "id": item_id or f"fc_{idx}", + "call_id": call_id or f"call_{idx}", + "name": fn.get("name"), + "arguments": fn.get("arguments") or "{}", + }) continue if role == "tool": call_id, _ = _split_tool_call_id(msg.get("tool_call_id")) output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) - input_items.append( - { - "type": "function_call_output", - "call_id": call_id, - "output": output_text, - } - ) - continue + input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text}) return system_prompt, input_items @@ -247,7 +244,10 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], buffer.append(line) -async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]: +async def _consume_sse( + response: httpx.Response, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str]: content = "" tool_calls: list[ToolCallRequest] = [] tool_call_buffers: dict[str, dict[str, Any]] = {} @@ -267,7 +267,10 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ "arguments": item.get("arguments") or "", } elif event_type == "response.output_text.delta": - content += event.get("delta") or "" + delta_text = event.get("delta") or "" + content += delta_text + if on_content_delta and delta_text: + await on_content_delta(delta_text) elif event_type == "response.function_call_arguments.delta": call_id = event.get("call_id") if call_id and call_id in tool_call_buffers: