feat: add streaming channel support with automatic fallback

Provider layer: add chat_stream / chat_stream_with_retry to all providers
(base fallback, litellm, custom, azure, codex). Refactor shared kwargs
building in each provider.

Channel layer: BaseChannel gains send_delta (no-op) and supports_streaming
(checks config + method override). ChannelManager routes _stream_delta /
_stream_end to send_delta, skips _streamed final messages.

AgentLoop._dispatch builds bus-backed on_stream/on_stream_end callbacks
when _wants_stream metadata is set. Non-streaming path unchanged.

CLI: clean up spinner ANSI workarounds, simplify commands.py flow.
Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-22 15:34:15 +00:00
committed by Xubin Ren
parent e79b9f4a83
commit bd621df57f
8 changed files with 300 additions and 109 deletions

View File

@@ -376,7 +376,23 @@ class AgentLoop:
"""Process a message under the global lock.""" """Process a message under the global lock."""
async with self._processing_lock: async with self._processing_lock:
try: try:
response = await self._process_message(msg) on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
async def on_stream(delta: str) -> None:
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta, metadata={"_stream_delta": True},
))
async def on_stream_end(*, resuming: bool = False) -> None:
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata={"_stream_end": True, "_resuming": resuming},
))
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
)
if response is not None: if response is not None:
await self.bus.publish_outbound(response) await self.bus.publish_outbound(response)
elif msg.channel == "cli": elif msg.channel == "cli":

View File

@@ -76,6 +76,17 @@ class BaseChannel(ABC):
""" """
pass pass
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    """Deliver a streaming text chunk. Override in subclass to enable streaming.

    Args:
        chat_id: Identifier of the conversation the chunk belongs to.
        delta: Incremental text fragment of the in-progress reply.
        metadata: Optional routing metadata (e.g. stream-control markers
            such as ``_stream_end`` — presumably set by the dispatcher;
            confirm against the caller).

    The base implementation is intentionally a no-op: ``supports_streaming``
    checks whether a subclass has overridden this method to decide if the
    channel can stream.
    """
    pass
@property
def supports_streaming(self) -> bool:
    """Whether this channel may stream: config opts in AND send_delta is overridden."""
    cfg = self.config
    # Config may be a plain dict or an attribute-style object; read the
    # "streaming" flag either way, defaulting to off.
    if isinstance(cfg, dict):
        enabled = cfg.get("streaming", False)
    else:
        enabled = getattr(cfg, "streaming", False)
    # Streaming only works when the concrete subclass actually implements
    # send_delta (the base class version is a no-op placeholder).
    overridden = type(self).send_delta is not BaseChannel.send_delta
    return bool(enabled) and overridden
def is_allowed(self, sender_id: str) -> bool: def is_allowed(self, sender_id: str) -> bool:
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all.""" """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
allow_list = getattr(self.config, "allow_from", []) allow_list = getattr(self.config, "allow_from", [])
@@ -116,13 +127,17 @@ class BaseChannel(ABC):
) )
return return
meta = metadata or {}
if self.supports_streaming:
meta = {**meta, "_wants_stream": True}
msg = InboundMessage( msg = InboundMessage(
channel=self.name, channel=self.name,
sender_id=str(sender_id), sender_id=str(sender_id),
chat_id=str(chat_id), chat_id=str(chat_id),
content=content, content=content,
media=media or [], media=media or [],
metadata=metadata or {}, metadata=meta,
session_key_override=session_key, session_key_override=session_key,
) )

View File

@@ -130,6 +130,11 @@ class ChannelManager:
channel = self.channels.get(msg.channel) channel = self.channels.get(msg.channel)
if channel: if channel:
try: try:
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
elif msg.metadata.get("_streamed"):
pass
else:
await channel.send(msg) await channel.send(msg)
except Exception as e: except Exception as e:
logger.error("Error sending to {}: {}", msg.channel, e) logger.error("Error sending to {}: {}", msg.channel, e)

View File

@@ -207,10 +207,6 @@ class _ThinkingSpinner:
self._active = False self._active = False
if self._spinner: if self._spinner:
self._spinner.stop() self._spinner.stop()
# Force-clear the spinner line: Rich Live's transient cleanup
# occasionally loses a race with its own render thread.
console.file.write("\033[2K\r")
console.file.flush()
return False return False
@contextmanager @contextmanager
@@ -218,8 +214,6 @@ class _ThinkingSpinner:
"""Temporarily stop spinner while printing progress.""" """Temporarily stop spinner while printing progress."""
if self._spinner and self._active: if self._spinner and self._active:
self._spinner.stop() self._spinner.stop()
console.file.write("\033[2K\r")
console.file.flush()
try: try:
yield yield
finally: finally:
@@ -776,25 +770,16 @@ def agent(
async def run_once(): async def run_once():
nonlocal _thinking nonlocal _thinking
_thinking = _ThinkingSpinner(enabled=not logs) _thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
with _thinking or nullcontext():
response = await agent_loop.process_direct( response = await agent_loop.process_direct(
message, session_id, message, session_id, on_progress=_cli_progress,
on_progress=_cli_progress,
) )
if _thinking:
_thinking.__exit__(None, None, None)
_thinking = None _thinking = None
if response and response.content:
_print_agent_response( _print_agent_response(
response.content, response.content if response else "",
render_markdown=markdown, render_markdown=markdown,
metadata=response.metadata, metadata=response.metadata if response else None,
) )
else:
console.print()
await agent_loop.close_mcp() await agent_loop.close_mcp()
asyncio.run(run_once()) asyncio.run(run_once())
@@ -835,7 +820,6 @@ def agent(
while True: while True:
try: try:
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
if msg.metadata.get("_progress"): if msg.metadata.get("_progress"):
is_tool_hint = msg.metadata.get("_tool_hint", False) is_tool_hint = msg.metadata.get("_tool_hint", False)
ch = agent_loop.channels_config ch = agent_loop.channels_config
@@ -850,7 +834,6 @@ def agent(
if msg.content: if msg.content:
turn_response.append((msg.content, dict(msg.metadata or {}))) turn_response.append((msg.content, dict(msg.metadata or {})))
turn_done.set() turn_done.set()
elif msg.content: elif msg.content:
await _print_interactive_response( await _print_interactive_response(
msg.content, msg.content,
@@ -889,7 +872,11 @@ def agent(
content=user_input, content=user_input,
)) ))
nonlocal _thinking
_thinking = _ThinkingSpinner(enabled=not logs)
with _thinking:
await turn_done.wait() await turn_done.wait()
_thinking = None
if turn_response: if turn_response:
content, meta = turn_response[0] content, meta = turn_response[0]

View File

@@ -18,6 +18,7 @@ class ChannelsConfig(Base):
Built-in and plugin channel configs are stored as extra fields (dicts). Built-in and plugin channel configs are stored as extra fields (dicts).
Each channel parses its own config in __init__. Each channel parses its own config in __init__.
Per-channel "streaming": true enables streaming output (requires send_delta impl).
""" """
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")

View File

@@ -2,7 +2,9 @@
from __future__ import annotations from __future__ import annotations
import json
import uuid import uuid
from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
from urllib.parse import urljoin from urllib.parse import urljoin
@@ -208,6 +210,100 @@ class AzureOpenAIProvider(LLMProvider):
finish_reason="error", finish_reason="error",
) )
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream a chat completion via Azure OpenAI SSE.

    Builds the same request as the non-streaming path, sets ``stream: true``,
    and hands the open HTTP response to ``_consume_stream``, which invokes
    *on_content_delta* for each text fragment. Any failure is converted into
    an error ``LLMResponse`` rather than raised.
    """
    deployment = model or self.default_model
    url = self._build_chat_url(deployment)
    headers = self._build_headers()
    body = self._prepare_request_payload(
        deployment, messages, tools, max_tokens, temperature,
        reasoning_effort, tool_choice=tool_choice,
    )
    body["stream"] = True
    try:
        async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
            async with client.stream("POST", url, headers=headers, json=body) as response:
                if response.status_code == 200:
                    return await self._consume_stream(response, on_content_delta)
                # Non-200: read the full error body for a useful message.
                text = await response.aread()
                return LLMResponse(
                    content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
                    finish_reason="error",
                )
    except Exception as e:
        return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
async def _consume_stream(
    self,
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
    """Parse Azure OpenAI SSE stream into an LLMResponse.

    Accumulates text deltas (forwarding each to *on_content_delta* when
    given) and stitches fragmented tool-call deltas together keyed by their
    ``index``. Malformed JSON events and non-``data:`` lines are skipped.
    """
    pieces: list[str] = []
    call_acc: dict[int, dict[str, str]] = {}
    finish = "stop"
    async for raw in response.aiter_lines():
        if not raw.startswith("data: "):
            continue
        payload = raw[6:].strip()
        if payload == "[DONE]":
            break
        try:
            event = json.loads(payload)
        except Exception:
            # Ignore keep-alives / partial or malformed events.
            continue
        choice_list = event.get("choices") or []
        if not choice_list:
            continue
        first = choice_list[0]
        reason = first.get("finish_reason")
        if reason:
            finish = reason
        delta = first.get("delta") or {}
        piece = delta.get("content")
        if piece:
            pieces.append(piece)
            if on_content_delta is not None:
                await on_content_delta(piece)
        # Tool-call arguments arrive fragmented; concatenate per index.
        for call in delta.get("tool_calls") or []:
            slot = call_acc.setdefault(call.get("index", 0), {"id": "", "name": "", "arguments": ""})
            if call.get("id"):
                slot["id"] = call["id"]
            fn = call.get("function") or {}
            if fn.get("name"):
                slot["name"] = fn["name"]
            if fn.get("arguments"):
                slot["arguments"] += fn["arguments"]
    calls = []
    for slot in call_acc.values():
        args = json_repair.loads(slot["arguments"]) if slot["arguments"] else {}
        calls.append(ToolCallRequest(id=slot["id"], name=slot["name"], arguments=args))
    return LLMResponse(
        content="".join(pieces) or None,
        tool_calls=calls,
        finish_reason=finish,
    )
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name).""" """Get the default model (also used as default deployment name)."""
return self.default_model return self.default_model

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from collections.abc import Awaitable, Callable
from typing import Any from typing import Any
import json_repair import json_repair
@@ -22,22 +23,20 @@ class CustomProvider(LLMProvider):
): ):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
# Keep affinity stable for this provider instance to improve backend cache locality,
# while still letting users attach provider-specific headers for custom gateways.
default_headers = {
"x-session-affinity": uuid.uuid4().hex,
**(extra_headers or {}),
}
self._client = AsyncOpenAI( self._client = AsyncOpenAI(
api_key=api_key, api_key=api_key,
base_url=api_base, base_url=api_base,
default_headers=default_headers, default_headers={
"x-session-affinity": uuid.uuid4().hex,
**(extra_headers or {}),
},
) )
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, def _build_kwargs(
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
reasoning_effort: str | None = None, model: str | None, max_tokens: int, temperature: float,
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model or self.default_model, "model": model or self.default_model,
"messages": self._sanitize_empty_content(messages), "messages": self._sanitize_empty_content(messages),
@@ -48,37 +47,106 @@ class CustomProvider(LLMProvider):
kwargs["reasoning_effort"] = reasoning_effort kwargs["reasoning_effort"] = reasoning_effort
if tools: if tools:
kwargs.update(tools=tools, tool_choice=tool_choice or "auto") kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
return kwargs
def _handle_error(self, e: Exception) -> LLMResponse:
    """Convert a provider exception into an error LLMResponse.

    Prefers the raw response body (``e.doc`` for JSON decode errors or
    ``e.response.text`` for API errors) over the generic exception message,
    truncated to 500 characters to avoid huge HTML error pages.
    """
    raw = getattr(e, "doc", None)
    if not raw:
        raw = getattr(getattr(e, "response", None), "text", None)
    if raw and raw.strip():
        message = f"Error: {raw.strip()[:500]}"
    else:
        message = f"Error: {e}"
    return LLMResponse(content=message, finish_reason="error")
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
try: try:
return self._parse(await self._client.chat.completions.create(**kwargs)) return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e: except Exception as e:
# JSONDecodeError.doc / APIError.response.text may carry the raw body return self._handle_error(e)
# (e.g. "unsupported model: xxx") which is far more useful than the
# generic "Expecting value …" message. Truncate to avoid huge HTML pages. async def chat_stream(
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
if body and body.strip(): model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error") reasoning_effort: str | None = None,
return LLMResponse(content=f"Error: {e}", finish_reason="error") tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
kwargs["stream"] = True
try:
stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = []
async for chunk in stream:
chunks.append(chunk)
if on_content_delta and chunk.choices:
text = getattr(chunk.choices[0].delta, "content", None)
if text:
await on_content_delta(text)
return self._parse_chunks(chunks)
except Exception as e:
return self._handle_error(e)
def _parse(self, response: Any) -> LLMResponse: def _parse(self, response: Any) -> LLMResponse:
if not response.choices: if not response.choices:
return LLMResponse( return LLMResponse(
content="Error: API returned empty choices. This may indicate a temporary service issue or an invalid model response.", content="Error: API returned empty choices.",
finish_reason="error" finish_reason="error",
) )
choice = response.choices[0] choice = response.choices[0]
msg = choice.message msg = choice.message
tool_calls = [ tool_calls = [
ToolCallRequest(id=tc.id, name=tc.function.name, ToolCallRequest(
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments) id=tc.id, name=tc.function.name,
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
)
for tc in (msg.tool_calls or []) for tc in (msg.tool_calls or [])
] ]
u = response.usage u = response.usage
return LLMResponse( return LLMResponse(
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop", content=msg.content, tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop",
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
reasoning_content=getattr(msg, "reasoning_content", None) or None, reasoning_content=getattr(msg, "reasoning_content", None) or None,
) )
def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
    """Reassemble streamed chunks into a single LLMResponse.

    Concatenates content deltas, merges fragmented tool-call deltas keyed
    by ``index``, keeps the last non-empty ``finish_reason``, and picks up
    usage stats from the trailing choices-less usage chunk when present.
    """
    texts: list[str] = []
    acc: dict[int, dict[str, str]] = {}
    reason = "stop"
    usage_info: dict[str, int] = {}
    for piece in chunks:
        if not piece.choices:
            # A chunk without choices may carry the final usage report.
            u = getattr(piece, "usage", None)
            if u:
                usage_info = {
                    "prompt_tokens": u.prompt_tokens or 0,
                    "completion_tokens": u.completion_tokens or 0,
                    "total_tokens": u.total_tokens or 0,
                }
            continue
        first = piece.choices[0]
        if first.finish_reason:
            reason = first.finish_reason
        d = first.delta
        if d and d.content:
            texts.append(d.content)
        if d:
            # Tool-call arguments stream in fragments; append per index.
            for tc in d.tool_calls or []:
                entry = acc.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
                if tc.id:
                    entry["id"] = tc.id
                fn = tc.function
                if fn and fn.name:
                    entry["name"] = fn.name
                if fn and fn.arguments:
                    entry["arguments"] += fn.arguments
    requests = [
        ToolCallRequest(
            id=e["id"],
            name=e["name"],
            arguments=json_repair.loads(e["arguments"]) if e["arguments"] else {},
        )
        for e in acc.values()
    ]
    return LLMResponse(
        content="".join(texts) or None,
        tool_calls=requests,
        finish_reason=reason,
        usage=usage_info,
    )
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import hashlib import hashlib
import json import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
import httpx import httpx
@@ -24,16 +25,16 @@ class OpenAICodexProvider(LLMProvider):
super().__init__(api_key=None, api_base=None) super().__init__(api_key=None, api_base=None)
self.default_model = default_model self.default_model = default_model
async def chat( async def _call_codex(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None,
model: str | None = None, model: str | None,
max_tokens: int = 4096, reasoning_effort: str | None,
temperature: float = 0.7, tool_choice: str | dict[str, Any] | None,
reasoning_effort: str | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
"""Shared request logic for both chat() and chat_stream()."""
model = model or self.default_model model = model or self.default_model
system_prompt, input_items = _convert_messages(messages) system_prompt, input_items = _convert_messages(messages)
@@ -52,33 +53,45 @@ class OpenAICodexProvider(LLMProvider):
"tool_choice": tool_choice or "auto", "tool_choice": tool_choice or "auto",
"parallel_tool_calls": True, "parallel_tool_calls": True,
} }
if reasoning_effort: if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort} body["reasoning"] = {"effort": reasoning_effort}
if tools: if tools:
body["tools"] = _convert_tools(tools) body["tools"] = _convert_tools(tools)
url = DEFAULT_CODEX_URL
try: try:
try: try:
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True) content, tool_calls, finish_reason = await _request_codex(
DEFAULT_CODEX_URL, headers, body, verify=True,
on_content_delta=on_content_delta,
)
except Exception as e: except Exception as e:
if "CERTIFICATE_VERIFY_FAILED" not in str(e): if "CERTIFICATE_VERIFY_FAILED" not in str(e):
raise raise
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False") logger.warning("SSL verification failed for Codex API; retrying with verify=False")
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False) content, tool_calls, finish_reason = await _request_codex(
return LLMResponse( DEFAULT_CODEX_URL, headers, body, verify=False,
content=content, on_content_delta=on_content_delta,
tool_calls=tool_calls,
finish_reason=finish_reason,
) )
return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
except Exception as e: except Exception as e:
return LLMResponse( return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")
content=f"Error calling Codex: {str(e)}",
finish_reason="error", async def chat(
) self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice)
async def chat_stream(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta)
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model
@@ -107,13 +120,14 @@ async def _request_codex(
headers: dict[str, str], headers: dict[str, str],
body: dict[str, Any], body: dict[str, Any],
verify: bool, verify: bool,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]: ) -> tuple[str, list[ToolCallRequest], str]:
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client: async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
async with client.stream("POST", url, headers=headers, json=body) as response: async with client.stream("POST", url, headers=headers, json=body) as response:
if response.status_code != 200: if response.status_code != 200:
text = await response.aread() text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
return await _consume_sse(response) return await _consume_sse(response, on_content_delta)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -151,45 +165,28 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
continue continue
if role == "assistant": if role == "assistant":
# Handle text first.
if isinstance(content, str) and content: if isinstance(content, str) and content:
input_items.append( input_items.append({
{ "type": "message", "role": "assistant",
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": content}], "content": [{"type": "output_text", "text": content}],
"status": "completed", "status": "completed", "id": f"msg_{idx}",
"id": f"msg_{idx}", })
}
)
# Then handle tool calls.
for tool_call in msg.get("tool_calls", []) or []: for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {} fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id")) call_id, item_id = _split_tool_call_id(tool_call.get("id"))
call_id = call_id or f"call_{idx}" input_items.append({
item_id = item_id or f"fc_{idx}"
input_items.append(
{
"type": "function_call", "type": "function_call",
"id": item_id, "id": item_id or f"fc_{idx}",
"call_id": call_id, "call_id": call_id or f"call_{idx}",
"name": fn.get("name"), "name": fn.get("name"),
"arguments": fn.get("arguments") or "{}", "arguments": fn.get("arguments") or "{}",
} })
)
continue continue
if role == "tool": if role == "tool":
call_id, _ = _split_tool_call_id(msg.get("tool_call_id")) call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append( input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
{
"type": "function_call_output",
"call_id": call_id,
"output": output_text,
}
)
continue
return system_prompt, input_items return system_prompt, input_items
@@ -247,7 +244,10 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
buffer.append(line) buffer.append(line)
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]: async def _consume_sse(
response: httpx.Response,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
content = "" content = ""
tool_calls: list[ToolCallRequest] = [] tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {} tool_call_buffers: dict[str, dict[str, Any]] = {}
@@ -267,7 +267,10 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
"arguments": item.get("arguments") or "", "arguments": item.get("arguments") or "",
} }
elif event_type == "response.output_text.delta": elif event_type == "response.output_text.delta":
content += event.get("delta") or "" delta_text = event.get("delta") or ""
content += delta_text
if on_content_delta and delta_text:
await on_content_delta(delta_text)
elif event_type == "response.function_call_arguments.delta": elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id") call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers: if call_id and call_id in tool_call_buffers: