feat: add streaming channel support with automatic fallback
Provider layer: add chat_stream / chat_stream_with_retry to all providers (base fallback, litellm, custom, azure, codex). Refactor shared kwargs building in each provider. Channel layer: BaseChannel gains send_delta (no-op) and supports_streaming (checks config + method override). ChannelManager routes _stream_delta / _stream_end to send_delta, skips _streamed final messages. AgentLoop._dispatch builds bus-backed on_stream/on_stream_end callbacks when _wants_stream metadata is set. Non-streaming path unchanged. CLI: clean up spinner ANSI workarounds, simplify commands.py flow. Made-with: Cursor
This commit is contained in:
@@ -376,7 +376,23 @@ class AgentLoop:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
async def on_stream(delta: str) -> None:
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta, metadata={"_stream_delta": True},
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata={"_stream_end": True, "_resuming": resuming},
|
||||
))
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
|
||||
@@ -76,6 +76,17 @@ class BaseChannel(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    """Deliver one streamed text chunk to *chat_id*.

    The base implementation is a deliberate no-op. A subclass enables
    streaming by overriding this method; ``supports_streaming`` detects the
    override by comparing against ``BaseChannel.send_delta``.

    Args:
        chat_id: Destination chat identifier.
        delta: Text fragment to deliver.
        metadata: Optional metadata of the outbound message carrying the chunk.
    """
    return None
|
||||
|
||||
@property
def supports_streaming(self) -> bool:
    """Report whether streamed output should be routed to this channel.

    Both conditions must hold:

    * the channel config has a truthy ``streaming`` entry (looked up as a
      dict key when the config is a dict, otherwise as an attribute), and
    * the concrete class overrides ``send_delta`` instead of inheriting the
      no-op defined on ``BaseChannel``.
    """
    config = self.config
    if isinstance(config, dict):
        enabled = config.get("streaming", False)
    else:
        enabled = getattr(config, "streaming", False)

    # Check the flag first so the override comparison only runs when needed.
    if not enabled:
        return False
    return type(self).send_delta is not BaseChannel.send_delta
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
@@ -116,13 +127,17 @@ class BaseChannel(ABC):
|
||||
)
|
||||
return
|
||||
|
||||
meta = metadata or {}
|
||||
if self.supports_streaming:
|
||||
meta = {**meta, "_wants_stream": True}
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {},
|
||||
metadata=meta,
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
|
||||
@@ -130,6 +130,11 @@ class ChannelManager:
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||
elif msg.metadata.get("_streamed"):
|
||||
pass
|
||||
else:
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
|
||||
@@ -207,10 +207,6 @@ class _ThinkingSpinner:
|
||||
self._active = False
|
||||
if self._spinner:
|
||||
self._spinner.stop()
|
||||
# Force-clear the spinner line: Rich Live's transient cleanup
|
||||
# occasionally loses a race with its own render thread.
|
||||
console.file.write("\033[2K\r")
|
||||
console.file.flush()
|
||||
return False
|
||||
|
||||
@contextmanager
|
||||
@@ -218,8 +214,6 @@ class _ThinkingSpinner:
|
||||
"""Temporarily stop spinner while printing progress."""
|
||||
if self._spinner and self._active:
|
||||
self._spinner.stop()
|
||||
console.file.write("\033[2K\r")
|
||||
console.file.flush()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@@ -776,25 +770,16 @@ def agent(
|
||||
async def run_once():
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
|
||||
with _thinking or nullcontext():
|
||||
with _thinking:
|
||||
response = await agent_loop.process_direct(
|
||||
message, session_id,
|
||||
on_progress=_cli_progress,
|
||||
message, session_id, on_progress=_cli_progress,
|
||||
)
|
||||
|
||||
if _thinking:
|
||||
_thinking.__exit__(None, None, None)
|
||||
_thinking = None
|
||||
|
||||
if response and response.content:
|
||||
_print_agent_response(
|
||||
response.content,
|
||||
response.content if response else "",
|
||||
render_markdown=markdown,
|
||||
metadata=response.metadata,
|
||||
metadata=response.metadata if response else None,
|
||||
)
|
||||
else:
|
||||
console.print()
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
@@ -835,7 +820,6 @@ def agent(
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||
ch = agent_loop.channels_config
|
||||
@@ -850,7 +834,6 @@ def agent(
|
||||
if msg.content:
|
||||
turn_response.append((msg.content, dict(msg.metadata or {})))
|
||||
turn_done.set()
|
||||
|
||||
elif msg.content:
|
||||
await _print_interactive_response(
|
||||
msg.content,
|
||||
@@ -889,7 +872,11 @@ def agent(
|
||||
content=user_input,
|
||||
))
|
||||
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
with _thinking:
|
||||
await turn_done.wait()
|
||||
_thinking = None
|
||||
|
||||
if turn_response:
|
||||
content, meta = turn_response[0]
|
||||
|
||||
@@ -18,6 +18,7 @@ class ChannelsConfig(Base):
|
||||
|
||||
Built-in and plugin channel configs are stored as extra fields (dicts).
|
||||
Each channel parses its own config in __init__.
|
||||
Per-channel "streaming": true enables streaming output (requires send_delta impl).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
@@ -208,6 +210,100 @@ class AzureOpenAIProvider(LLMProvider):
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream a chat completion from Azure OpenAI over SSE.

    The request payload is the same one the non-streaming path builds, with
    ``stream`` switched on. Text fragments are forwarded to
    ``on_content_delta`` as they arrive and the assembled result is returned.

    Returns:
        The parsed response. A non-200 status, or any exception raised while
        the request is in flight, is reported as an ``LLMResponse`` with
        ``finish_reason="error"`` instead of being raised.
    """
    deployment = model or self.default_model
    endpoint = self._build_chat_url(deployment)
    request_headers = self._build_headers()
    body = self._prepare_request_payload(
        deployment, messages, tools, max_tokens, temperature,
        reasoning_effort, tool_choice=tool_choice,
    )
    body["stream"] = True

    try:
        async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
            async with client.stream("POST", endpoint, headers=request_headers, json=body) as response:
                if response.status_code == 200:
                    return await self._consume_stream(response, on_content_delta)
                # Error bodies are small; read them fully so the message can be surfaced.
                raw = await response.aread()
                detail = raw.decode('utf-8', 'ignore')
                return LLMResponse(
                    content=f"Azure OpenAI API Error {response.status_code}: {detail}",
                    finish_reason="error",
                )
    except Exception as exc:
        return LLMResponse(content=f"Error calling Azure OpenAI: {repr(exc)}", finish_reason="error")
|
||||
|
||||
async def _consume_stream(
    self,
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
    """Parse an Azure OpenAI SSE stream into a single ``LLMResponse``.

    Args:
        response: Open streaming HTTP response whose lines are SSE records.
        on_content_delta: Optional coroutine invoked with every non-empty
            text fragment as soon as it is received.

    Returns:
        The concatenated text (``None`` when no text arrived), the tool
        calls reassembled from their per-index fragments, the last finish
        reason seen (``"stop"`` when none was sent) and, when a chunk
        carried one, the token usage.
    """
    content_parts: list[str] = []
    tool_call_buffers: dict[int, dict[str, str]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}

    async for line in response.aiter_lines():
        # The SSE format allows "data:<payload>" as well as "data: <payload>";
        # the payload is stripped below, so accept both spellings.
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if data == "[DONE]":
            break
        try:
            chunk = json.loads(data)
        except ValueError:
            # Malformed record: skip it rather than abort the whole stream.
            continue
        if not isinstance(chunk, dict):
            # Valid JSON but not an object (e.g. null); nothing to extract.
            continue

        # Usage may arrive on a chunk with an empty "choices" list, so read
        # it before the choices check.
        raw_usage = chunk.get("usage")
        if isinstance(raw_usage, dict):
            usage = {
                "prompt_tokens": raw_usage.get("prompt_tokens") or 0,
                "completion_tokens": raw_usage.get("completion_tokens") or 0,
                "total_tokens": raw_usage.get("total_tokens") or 0,
            }

        choices = chunk.get("choices") or []
        if not choices:
            continue
        choice = choices[0]
        if choice.get("finish_reason"):
            finish_reason = choice["finish_reason"]
        delta = choice.get("delta") or {}

        text = delta.get("content")
        if text:
            content_parts.append(text)
            if on_content_delta:
                await on_content_delta(text)

        # Tool calls are streamed in fragments keyed by "index": id and name
        # overwrite, argument text accumulates.
        for tc in delta.get("tool_calls") or []:
            idx = tc.get("index", 0)
            buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
            if tc.get("id"):
                buf["id"] = tc["id"]
            fn = tc.get("function") or {}
            if fn.get("name"):
                buf["name"] = fn["name"]
            if fn.get("arguments"):
                buf["arguments"] += fn["arguments"]

    tool_calls = [
        ToolCallRequest(
            id=buf["id"], name=buf["name"],
            arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
        )
        for buf in tool_call_buffers.values()
    ]

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
        usage=usage,
    )
|
||||
|
||||
def get_default_model(self) -> str:
    """Return the configured default model, which also serves as the default deployment name."""
    default = self.default_model
    return default
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
@@ -22,22 +23,20 @@ class CustomProvider(LLMProvider):
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
# Keep affinity stable for this provider instance to improve backend cache locality,
|
||||
# while still letting users attach provider-specific headers for custom gateways.
|
||||
default_headers = {
|
||||
"x-session-affinity": uuid.uuid4().hex,
|
||||
**(extra_headers or {}),
|
||||
}
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
default_headers=default_headers,
|
||||
default_headers={
|
||||
"x-session-affinity": uuid.uuid4().hex,
|
||||
**(extra_headers or {}),
|
||||
},
|
||||
)
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||
def _build_kwargs(
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
|
||||
model: str | None, max_tokens: int, temperature: float,
|
||||
reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self._sanitize_empty_content(messages),
|
||||
@@ -48,37 +47,106 @@ class CustomProvider(LLMProvider):
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
||||
return kwargs
|
||||
|
||||
def _handle_error(self, e: Exception) -> LLMResponse:
    """Convert an exception from the OpenAI client into an error ``LLMResponse``.

    ``JSONDecodeError.doc`` / ``APIError.response.text`` may carry the raw
    response body (e.g. "unsupported model: xxx"), which is far more useful
    than a generic "Expecting value ..." message. When such a body is
    available it is reported, truncated to 500 characters to avoid dumping
    huge HTML pages; otherwise ``str(e)`` is used.

    This helper runs inside ``except`` handlers, so it must never raise.
    """
    body = getattr(e, "doc", None)
    if not body:
        try:
            body = getattr(getattr(e, "response", None), "text", None)
        except Exception:
            # Reading ``.text`` on a streamed response whose body was never
            # read raises instead of returning; fall back to str(e).
            body = None

    # Only a real string can be stripped and sliced safely.
    snippet = body.strip() if isinstance(body, str) else ""
    if snippet:
        return LLMResponse(content=f"Error: {snippet[:500]}", finish_reason="error")
    return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
# JSONDecodeError.doc / APIError.response.text may carry the raw body
|
||||
# (e.g. "unsupported model: xxx") which is far more useful than the
|
||||
# generic "Expecting value …" message. Truncate to avoid huge HTML pages.
|
||||
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
||||
if body and body.strip():
|
||||
return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error")
|
||||
return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
    self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
    model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream a chat completion, forwarding text fragments as they arrive.

    Request arguments are built exactly as for ``chat`` with ``stream``
    enabled. Every chunk is retained and handed to ``_parse_chunks`` once the
    stream ends; non-empty text deltas are also passed to
    ``on_content_delta``. Any exception raised while streaming is converted
    to an error response by ``_handle_error``.
    """
    request = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
    request["stream"] = True
    received: list[Any] = []
    try:
        stream = await self._client.chat.completions.create(**request)
        async for part in stream:
            received.append(part)
            # Nothing to forward without a listener or without a choice.
            if not on_content_delta or not part.choices:
                continue
            piece = getattr(part.choices[0].delta, "content", None)
            if piece:
                await on_content_delta(piece)
        return self._parse_chunks(received)
    except Exception as exc:
        return self._handle_error(exc)
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if not response.choices:
|
||||
return LLMResponse(
|
||||
content="Error: API returned empty choices. This may indicate a temporary service issue or an invalid model response.",
|
||||
finish_reason="error"
|
||||
content="Error: API returned empty choices.",
|
||||
finish_reason="error",
|
||||
)
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
tool_calls = [
|
||||
ToolCallRequest(id=tc.id, name=tc.function.name,
|
||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments)
|
||||
ToolCallRequest(
|
||||
id=tc.id, name=tc.function.name,
|
||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
|
||||
)
|
||||
for tc in (msg.tool_calls or [])
|
||||
]
|
||||
u = response.usage
|
||||
return LLMResponse(
|
||||
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
||||
content=msg.content, tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
    """Reassemble streamed chunks into a single ``LLMResponse``.

    Args:
        chunks: Chunk objects in arrival order, as yielded by the client's
            streaming iterator.

    Returns:
        Concatenated text (``None`` when empty), tool calls rebuilt from
        their per-index fragments, the last finish reason seen (``"stop"``
        when none was sent), the most recent usage figures, and any streamed
        ``reasoning_content`` so the result matches what ``_parse`` returns
        for a non-streaming response.
    """
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tc_bufs: dict[int, dict[str, str]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}

    for chunk in chunks:
        # Usage can arrive on a dedicated chunk with no choices or ride along
        # with the final choice, so read it before the choices check.
        u = getattr(chunk, "usage", None)
        if u:
            usage = {
                "prompt_tokens": u.prompt_tokens or 0,
                "completion_tokens": u.completion_tokens or 0,
                "total_tokens": u.total_tokens or 0,
            }
        if not chunk.choices:
            continue

        choice = chunk.choices[0]
        if choice.finish_reason:
            finish_reason = choice.finish_reason
        delta = choice.delta
        if not delta:
            continue

        if delta.content:
            content_parts.append(delta.content)
        # Not a declared delta field everywhere, hence the getattr lookup.
        reasoning = getattr(delta, "reasoning_content", None)
        if isinstance(reasoning, str) and reasoning:
            reasoning_parts.append(reasoning)

        # id and name overwrite; argument text accumulates across fragments.
        for tc in delta.tool_calls or []:
            buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
            if tc.id:
                buf["id"] = tc.id
            if tc.function and tc.function.name:
                buf["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                buf["arguments"] += tc.function.arguments

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=[
            ToolCallRequest(
                id=b["id"], name=b["name"],
                arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
            )
            for b in tc_bufs.values()
        ],
        finish_reason=finish_reason,
        usage=usage,
        reasoning_content="".join(reasoning_parts) or None,
    )
|
||||
|
||||
def get_default_model(self) -> str:
    """Return the model name used when a call does not specify one."""
    default = self.default_model
    return default
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
@@ -24,16 +25,16 @@ class OpenAICodexProvider(LLMProvider):
|
||||
super().__init__(api_key=None, api_base=None)
|
||||
self.default_model = default_model
|
||||
|
||||
async def chat(
|
||||
async def _call_codex(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Shared request logic for both chat() and chat_stream()."""
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
|
||||
@@ -52,33 +53,45 @@ class OpenAICodexProvider(LLMProvider):
|
||||
"tool_choice": tool_choice or "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
|
||||
if reasoning_effort:
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
|
||||
url = DEFAULT_CODEX_URL
|
||||
|
||||
try:
|
||||
try:
|
||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True)
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL, headers, body, verify=True,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
except Exception as e:
|
||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||
raise
|
||||
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
|
||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False)
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
logger.warning("SSL verification failed for Codex API; retrying with verify=False")
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL, headers, body, verify=False,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Codex: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")
|
||||
|
||||
async def chat(
    self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
    model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
    """Run a Codex completion without delta callbacks.

    Delegates to ``_call_codex``. ``max_tokens`` and ``temperature`` are part
    of the signature but are not forwarded to the request.
    """
    return await self._call_codex(
        messages=messages,
        tools=tools,
        model=model,
        reasoning_effort=reasoning_effort,
        tool_choice=tool_choice,
    )
|
||||
|
||||
async def chat_stream(
    self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
    model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Run a Codex completion, passing ``on_content_delta`` through for text deltas.

    Delegates to ``_call_codex``, the same request path ``chat`` uses; the
    only difference is the callback. ``max_tokens`` and ``temperature`` are
    part of the signature but are not forwarded to the request.
    """
    return await self._call_codex(
        messages=messages,
        tools=tools,
        model=model,
        reasoning_effort=reasoning_effort,
        tool_choice=tool_choice,
        on_content_delta=on_content_delta,
    )
|
||||
|
||||
def get_default_model(self) -> str:
    """Return the Codex model used when a call does not name one."""
    default = self.default_model
    return default
|
||||
@@ -107,13 +120,14 @@ async def _request_codex(
|
||||
headers: dict[str, str],
|
||||
body: dict[str, Any],
|
||||
verify: bool,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response)
|
||||
return await _consume_sse(response, on_content_delta)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
@@ -151,45 +165,28 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Handle text first.
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed",
|
||||
"id": f"msg_{idx}",
|
||||
}
|
||||
)
|
||||
# Then handle tool calls.
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
call_id = call_id or f"call_{idx}"
|
||||
item_id = item_id or f"fc_{idx}"
|
||||
input_items.append(
|
||||
{
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id,
|
||||
"call_id": call_id,
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
}
|
||||
)
|
||||
})
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_text,
|
||||
}
|
||||
)
|
||||
continue
|
||||
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
@@ -247,7 +244,10 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async def _consume_sse(
|
||||
response: httpx.Response,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
@@ -267,7 +267,10 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
content += event.get("delta") or ""
|
||||
delta_text = event.get("delta") or ""
|
||||
content += delta_text
|
||||
if on_content_delta and delta_text:
|
||||
await on_content_delta(delta_text)
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
|
||||
Reference in New Issue
Block a user