feat: add streaming channel support with automatic fallback
Provider layer: add chat_stream / chat_stream_with_retry to all providers (base fallback, litellm, custom, azure, codex), and refactor the shared kwargs building in each provider.

Channel layer: BaseChannel gains send_delta (a no-op by default) and supports_streaming (checks the channel config and whether the subclass overrides send_delta). ChannelManager routes _stream_delta / _stream_end messages to send_delta and skips final messages marked _streamed.

Agent loop: AgentLoop._dispatch builds bus-backed on_stream / on_stream_end callbacks when the _wants_stream metadata flag is set. The non-streaming path is unchanged.

CLI: clean up the spinner ANSI workarounds and simplify the commands.py flow.

Made-with: Cursor
This commit is contained in:
@@ -376,7 +376,23 @@ class AgentLoop:
|
|||||||
"""Process a message under the global lock."""
|
"""Process a message under the global lock."""
|
||||||
async with self._processing_lock:
|
async with self._processing_lock:
|
||||||
try:
|
try:
|
||||||
response = await self._process_message(msg)
|
on_stream = on_stream_end = None
|
||||||
|
if msg.metadata.get("_wants_stream"):
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content=delta, metadata={"_stream_delta": True},
|
||||||
|
))
|
||||||
|
|
||||||
|
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata={"_stream_end": True, "_resuming": resuming},
|
||||||
|
))
|
||||||
|
|
||||||
|
response = await self._process_message(
|
||||||
|
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
|
|||||||
@@ -76,6 +76,17 @@ class BaseChannel(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    """Deliver one chunk of streamed text to *chat_id*.

    The base implementation deliberately does nothing. A subclass overrides
    this method to enable streaming for its channel.

    Args:
        chat_id: Identifier of the chat the chunk is addressed to.
        delta: The text chunk to deliver.
        metadata: Optional metadata accompanying the chunk.
    """
    return None
|
||||||
|
|
||||||
|
@property
def supports_streaming(self) -> bool:
    """Report whether streamed output should be requested for this channel.

    Returns:
        True only when the channel config carries a truthy ``streaming``
        value AND the concrete class overrides ``send_delta``.
    """
    config = self.config
    # The config may be a plain dict or an object with attributes.
    if isinstance(config, dict):
        enabled = config.get("streaming", False)
    else:
        enabled = getattr(config, "streaming", False)
    if not enabled:
        return False
    # An inherited (base-class) send_delta is a no-op, so the override is
    # what actually makes streaming usable.
    return type(self).send_delta is not BaseChannel.send_delta
|
||||||
|
|
||||||
def is_allowed(self, sender_id: str) -> bool:
|
def is_allowed(self, sender_id: str) -> bool:
|
||||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||||
allow_list = getattr(self.config, "allow_from", [])
|
allow_list = getattr(self.config, "allow_from", [])
|
||||||
@@ -116,13 +127,17 @@ class BaseChannel(ABC):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
meta = metadata or {}
|
||||||
|
if self.supports_streaming:
|
||||||
|
meta = {**meta, "_wants_stream": True}
|
||||||
|
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel=self.name,
|
channel=self.name,
|
||||||
sender_id=str(sender_id),
|
sender_id=str(sender_id),
|
||||||
chat_id=str(chat_id),
|
chat_id=str(chat_id),
|
||||||
content=content,
|
content=content,
|
||||||
media=media or [],
|
media=media or [],
|
||||||
metadata=metadata or {},
|
metadata=meta,
|
||||||
session_key_override=session_key,
|
session_key_override=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -130,6 +130,11 @@ class ChannelManager:
|
|||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
|
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||||
|
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||||
|
elif msg.metadata.get("_streamed"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||||
|
|||||||
@@ -207,10 +207,6 @@ class _ThinkingSpinner:
|
|||||||
self._active = False
|
self._active = False
|
||||||
if self._spinner:
|
if self._spinner:
|
||||||
self._spinner.stop()
|
self._spinner.stop()
|
||||||
# Force-clear the spinner line: Rich Live's transient cleanup
|
|
||||||
# occasionally loses a race with its own render thread.
|
|
||||||
console.file.write("\033[2K\r")
|
|
||||||
console.file.flush()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -218,8 +214,6 @@ class _ThinkingSpinner:
|
|||||||
"""Temporarily stop spinner while printing progress."""
|
"""Temporarily stop spinner while printing progress."""
|
||||||
if self._spinner and self._active:
|
if self._spinner and self._active:
|
||||||
self._spinner.stop()
|
self._spinner.stop()
|
||||||
console.file.write("\033[2K\r")
|
|
||||||
console.file.flush()
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
@@ -776,25 +770,16 @@ def agent(
|
|||||||
async def run_once():
|
async def run_once():
|
||||||
nonlocal _thinking
|
nonlocal _thinking
|
||||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
|
with _thinking:
|
||||||
with _thinking or nullcontext():
|
|
||||||
response = await agent_loop.process_direct(
|
response = await agent_loop.process_direct(
|
||||||
message, session_id,
|
message, session_id, on_progress=_cli_progress,
|
||||||
on_progress=_cli_progress,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if _thinking:
|
|
||||||
_thinking.__exit__(None, None, None)
|
|
||||||
_thinking = None
|
_thinking = None
|
||||||
|
|
||||||
if response and response.content:
|
|
||||||
_print_agent_response(
|
_print_agent_response(
|
||||||
response.content,
|
response.content if response else "",
|
||||||
render_markdown=markdown,
|
render_markdown=markdown,
|
||||||
metadata=response.metadata,
|
metadata=response.metadata if response else None,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
console.print()
|
|
||||||
await agent_loop.close_mcp()
|
await agent_loop.close_mcp()
|
||||||
|
|
||||||
asyncio.run(run_once())
|
asyncio.run(run_once())
|
||||||
@@ -835,7 +820,6 @@ def agent(
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
|
||||||
if msg.metadata.get("_progress"):
|
if msg.metadata.get("_progress"):
|
||||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||||
ch = agent_loop.channels_config
|
ch = agent_loop.channels_config
|
||||||
@@ -850,7 +834,6 @@ def agent(
|
|||||||
if msg.content:
|
if msg.content:
|
||||||
turn_response.append((msg.content, dict(msg.metadata or {})))
|
turn_response.append((msg.content, dict(msg.metadata or {})))
|
||||||
turn_done.set()
|
turn_done.set()
|
||||||
|
|
||||||
elif msg.content:
|
elif msg.content:
|
||||||
await _print_interactive_response(
|
await _print_interactive_response(
|
||||||
msg.content,
|
msg.content,
|
||||||
@@ -889,7 +872,11 @@ def agent(
|
|||||||
content=user_input,
|
content=user_input,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
nonlocal _thinking
|
||||||
|
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||||
|
with _thinking:
|
||||||
await turn_done.wait()
|
await turn_done.wait()
|
||||||
|
_thinking = None
|
||||||
|
|
||||||
if turn_response:
|
if turn_response:
|
||||||
content, meta = turn_response[0]
|
content, meta = turn_response[0]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class ChannelsConfig(Base):
|
|||||||
|
|
||||||
Built-in and plugin channel configs are stored as extra fields (dicts).
|
Built-in and plugin channel configs are stored as extra fields (dicts).
|
||||||
Each channel parses its own config in __init__.
|
Each channel parses its own config in __init__.
|
||||||
|
Per-channel "streaming": true enables streaming output (requires send_delta impl).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
@@ -208,6 +210,100 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
finish_reason="error",
|
finish_reason="error",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream a chat completion via Azure OpenAI SSE.

    Args:
        messages: Conversation messages passed to the payload builder.
        tools: Optional tool definitions passed to the payload builder.
        model: Deployment name; ``self.default_model`` is used when falsy.
        max_tokens: Passed to the payload builder.
        temperature: Passed to the payload builder.
        reasoning_effort: Passed to the payload builder.
        tool_choice: Passed to the payload builder.
        on_content_delta: Optional async callback handed to
            ``_consume_stream`` together with the open response.

    Returns:
        The ``LLMResponse`` produced by ``_consume_stream`` on HTTP 200.
        A non-200 status, or any exception raised while sending or reading
        the request, yields an ``LLMResponse`` with
        ``finish_reason="error"`` instead of raising.
    """
    deployment = model or self.default_model
    endpoint = self._build_chat_url(deployment)
    request_headers = self._build_headers()
    request_body = self._prepare_request_payload(
        deployment, messages, tools, max_tokens, temperature,
        reasoning_effort, tool_choice=tool_choice,
    )
    request_body["stream"] = True

    try:
        async with (
            httpx.AsyncClient(timeout=60.0, verify=True) as client,
            client.stream("POST", endpoint, headers=request_headers, json=request_body) as response,
        ):
            if response.status_code == 200:
                return await self._consume_stream(response, on_content_delta)
            # The body of a streamed response must be read explicitly
            # before it can be quoted in the error message.
            raw_error = await response.aread()
            return LLMResponse(
                content=f"Azure OpenAI API Error {response.status_code}: {raw_error.decode('utf-8', 'ignore')}",
                finish_reason="error",
            )
    except Exception as exc:
        return LLMResponse(content=f"Error calling Azure OpenAI: {exc!r}", finish_reason="error")
|
||||||
|
|
||||||
|
async def _consume_stream(
    self,
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
    """Parse an Azure OpenAI SSE stream into an LLMResponse.

    Reads the response line by line, accumulates text deltas and tool-call
    fragments from the first choice of every chunk, and stops at the
    ``[DONE]`` sentinel or when the stream ends.

    Args:
        response: Open streaming response whose lines are SSE fields.
        on_content_delta: Optional async callback awaited with each
            non-empty text delta as it arrives.

    Returns:
        An ``LLMResponse`` with the joined text (``None`` when no text was
        received), the reassembled tool calls, and the last non-empty
        ``finish_reason`` seen (``"stop"`` when none was sent).
    """
    content_parts: list[str] = []
    tool_call_buffers: dict[int, dict[str, str]] = {}
    finish_reason = "stop"

    async for line in response.aiter_lines():
        # The space after "data:" is optional in the SSE format, so match
        # on the bare field name and strip the value afterwards.
        if not line.startswith("data:"):
            continue
        data = line[5:].strip()
        if data == "[DONE]":
            break
        try:
            chunk = json.loads(data)
        except ValueError:
            # Malformed payload: skip it and keep what was streamed so far.
            continue
        if not isinstance(chunk, dict):
            # Valid JSON that is not an object (e.g. ``null``) carries no
            # choices; skipping it avoids aborting the whole stream.
            continue

        choices = chunk.get("choices") or []
        if not choices:
            continue
        choice = choices[0]
        if choice.get("finish_reason"):
            finish_reason = choice["finish_reason"]
        delta = choice.get("delta") or {}

        text = delta.get("content")
        if text:
            content_parts.append(text)
            if on_content_delta:
                await on_content_delta(text)

        # Tool calls arrive in fragments keyed by index: id and name are
        # overwritten when present, the argument string is concatenated.
        for tc in delta.get("tool_calls") or []:
            idx = tc.get("index", 0)
            buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
            if tc.get("id"):
                buf["id"] = tc["id"]
            fn = tc.get("function") or {}
            if fn.get("name"):
                buf["name"] = fn["name"]
            if fn.get("arguments"):
                buf["arguments"] += fn["arguments"]

    tool_calls = [
        ToolCallRequest(
            id=buf["id"], name=buf["name"],
            arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
        )
        for buf in tool_call_buffers.values()
    ]

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
    )
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
"""Get the default model (also used as default deployment name)."""
|
"""Get the default model (also used as default deployment name)."""
|
||||||
return self.default_model
|
return self.default_model
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
@@ -22,22 +23,20 @@ class CustomProvider(LLMProvider):
|
|||||||
):
|
):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
# Keep affinity stable for this provider instance to improve backend cache locality,
|
|
||||||
# while still letting users attach provider-specific headers for custom gateways.
|
|
||||||
default_headers = {
|
|
||||||
"x-session-affinity": uuid.uuid4().hex,
|
|
||||||
**(extra_headers or {}),
|
|
||||||
}
|
|
||||||
self._client = AsyncOpenAI(
|
self._client = AsyncOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
default_headers=default_headers,
|
default_headers={
|
||||||
|
"x-session-affinity": uuid.uuid4().hex,
|
||||||
|
**(extra_headers or {}),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
def _build_kwargs(
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
|
||||||
reasoning_effort: str | None = None,
|
model: str | None, max_tokens: int, temperature: float,
|
||||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"messages": self._sanitize_empty_content(messages),
|
"messages": self._sanitize_empty_content(messages),
|
||||||
@@ -48,37 +47,106 @@ class CustomProvider(LLMProvider):
|
|||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def _handle_error(self, e: Exception) -> LLMResponse:
    """Turn an exception from the API client into an error ``LLMResponse``.

    Args:
        e: The exception raised by the request.

    Returns:
        An ``LLMResponse`` with ``finish_reason="error"`` whose content is
        ``"Error: "`` followed by the raw body carried on the exception
        (stripped, at most 500 characters) or, when there is none, by the
        exception's own message.
    """
    # JSONDecodeError.doc / APIError.response.text may carry the raw body
    # (e.g. "unsupported model: xxx"), which is far more useful than the
    # generic "Expecting value ..." message.
    raw = getattr(e, "doc", None)
    if not raw:
        raw = getattr(getattr(e, "response", None), "text", None)
    text = raw.strip() if raw else ""
    if text:
        # Truncate to avoid echoing huge HTML error pages.
        return LLMResponse(content=f"Error: {text[:500]}", finish_reason="error")
    return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
||||||
|
|
||||||
|
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
||||||
try:
|
try:
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# JSONDecodeError.doc / APIError.response.text may carry the raw body
|
return self._handle_error(e)
|
||||||
# (e.g. "unsupported model: xxx") which is far more useful than the
|
|
||||||
# generic "Expecting value …" message. Truncate to avoid huge HTML pages.
|
async def chat_stream(
|
||||||
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
if body and body.strip():
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error")
|
reasoning_effort: str | None = None,
|
||||||
return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
||||||
|
kwargs["stream"] = True
|
||||||
|
try:
|
||||||
|
stream = await self._client.chat.completions.create(**kwargs)
|
||||||
|
chunks: list[Any] = []
|
||||||
|
async for chunk in stream:
|
||||||
|
chunks.append(chunk)
|
||||||
|
if on_content_delta and chunk.choices:
|
||||||
|
text = getattr(chunk.choices[0].delta, "content", None)
|
||||||
|
if text:
|
||||||
|
await on_content_delta(text)
|
||||||
|
return self._parse_chunks(chunks)
|
||||||
|
except Exception as e:
|
||||||
|
return self._handle_error(e)
|
||||||
|
|
||||||
def _parse(self, response: Any) -> LLMResponse:
|
def _parse(self, response: Any) -> LLMResponse:
|
||||||
if not response.choices:
|
if not response.choices:
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="Error: API returned empty choices. This may indicate a temporary service issue or an invalid model response.",
|
content="Error: API returned empty choices.",
|
||||||
finish_reason="error"
|
finish_reason="error",
|
||||||
)
|
)
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
msg = choice.message
|
msg = choice.message
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCallRequest(id=tc.id, name=tc.function.name,
|
ToolCallRequest(
|
||||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments)
|
id=tc.id, name=tc.function.name,
|
||||||
|
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
|
||||||
|
)
|
||||||
for tc in (msg.tool_calls or [])
|
for tc in (msg.tool_calls or [])
|
||||||
]
|
]
|
||||||
u = response.usage
|
u = response.usage
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
content=msg.content, tool_calls=tool_calls,
|
||||||
|
finish_reason=choice.finish_reason or "stop",
|
||||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
    """Reassemble streamed chunks into a single LLMResponse.

    Args:
        chunks: The chunk objects collected from the stream, in arrival
            order.

    Returns:
        An ``LLMResponse`` holding the concatenated text (``None`` when
        empty), the tool calls rebuilt from their fragments, the last
        non-empty ``finish_reason`` (``"stop"`` by default), the most
        recent usage block seen, and the concatenated ``reasoning_content``
        (``None`` when the stream carried none).
    """
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tc_bufs: dict[int, dict[str, str]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}

    for chunk in chunks:
        # Usage may sit on a trailing chunk with no choices or on a chunk
        # that also carries a choice, so look at every chunk.
        u = getattr(chunk, "usage", None)
        if u:
            usage = {
                "prompt_tokens": u.prompt_tokens or 0,
                "completion_tokens": u.completion_tokens or 0,
                "total_tokens": u.total_tokens or 0,
            }
        if not chunk.choices:
            continue

        choice = chunk.choices[0]
        if choice.finish_reason:
            finish_reason = choice.finish_reason
        delta = choice.delta
        if not delta:
            continue

        if delta.content:
            content_parts.append(delta.content)
        # Keep parity with the non-streaming _parse, which also returns
        # reasoning_content when the message carries it.
        reasoning = getattr(delta, "reasoning_content", None)
        if reasoning:
            reasoning_parts.append(reasoning)

        # Tool calls arrive in fragments keyed by index: id and name are
        # overwritten when present, the argument string is concatenated.
        for tc in delta.tool_calls or []:
            buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
            if tc.id:
                buf["id"] = tc.id
            if tc.function and tc.function.name:
                buf["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                buf["arguments"] += tc.function.arguments

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=[
            ToolCallRequest(
                id=b["id"], name=b["name"],
                arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
            )
            for b in tc_bufs.values()
        ],
        finish_reason=finish_reason,
        usage=usage,
        reasoning_content="".join(reasoning_parts) or None,
    )
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return self.default_model
|
return self.default_model
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -24,16 +25,16 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
super().__init__(api_key=None, api_base=None)
|
super().__init__(api_key=None, api_base=None)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
|
|
||||||
async def chat(
|
async def _call_codex(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None,
|
||||||
model: str | None = None,
|
model: str | None,
|
||||||
max_tokens: int = 4096,
|
reasoning_effort: str | None,
|
||||||
temperature: float = 0.7,
|
tool_choice: str | dict[str, Any] | None,
|
||||||
reasoning_effort: str | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
"""Shared request logic for both chat() and chat_stream()."""
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
system_prompt, input_items = _convert_messages(messages)
|
system_prompt, input_items = _convert_messages(messages)
|
||||||
|
|
||||||
@@ -52,33 +53,45 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
"tool_choice": tool_choice or "auto",
|
"tool_choice": tool_choice or "auto",
|
||||||
"parallel_tool_calls": True,
|
"parallel_tool_calls": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
body["reasoning"] = {"effort": reasoning_effort}
|
body["reasoning"] = {"effort": reasoning_effort}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
body["tools"] = _convert_tools(tools)
|
body["tools"] = _convert_tools(tools)
|
||||||
|
|
||||||
url = DEFAULT_CODEX_URL
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True)
|
content, tool_calls, finish_reason = await _request_codex(
|
||||||
|
DEFAULT_CODEX_URL, headers, body, verify=True,
|
||||||
|
on_content_delta=on_content_delta,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||||
raise
|
raise
|
||||||
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
|
logger.warning("SSL verification failed for Codex API; retrying with verify=False")
|
||||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False)
|
content, tool_calls, finish_reason = await _request_codex(
|
||||||
return LLMResponse(
|
DEFAULT_CODEX_URL, headers, body, verify=False,
|
||||||
content=content,
|
on_content_delta=on_content_delta,
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
)
|
)
|
||||||
|
return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return LLMResponse(
|
return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")
|
||||||
content=f"Error calling Codex: {str(e)}",
|
|
||||||
finish_reason="error",
|
async def chat(
|
||||||
)
|
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return self.default_model
|
return self.default_model
|
||||||
@@ -107,13 +120,14 @@ async def _request_codex(
|
|||||||
headers: dict[str, str],
|
headers: dict[str, str],
|
||||||
body: dict[str, Any],
|
body: dict[str, Any],
|
||||||
verify: bool,
|
verify: bool,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> tuple[str, list[ToolCallRequest], str]:
|
) -> tuple[str, list[ToolCallRequest], str]:
|
||||||
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
|
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
|
||||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
text = await response.aread()
|
text = await response.aread()
|
||||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||||
return await _consume_sse(response)
|
return await _consume_sse(response, on_content_delta)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
@@ -151,45 +165,28 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
# Handle text first.
|
|
||||||
if isinstance(content, str) and content:
|
if isinstance(content, str) and content:
|
||||||
input_items.append(
|
input_items.append({
|
||||||
{
|
"type": "message", "role": "assistant",
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{"type": "output_text", "text": content}],
|
"content": [{"type": "output_text", "text": content}],
|
||||||
"status": "completed",
|
"status": "completed", "id": f"msg_{idx}",
|
||||||
"id": f"msg_{idx}",
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
# Then handle tool calls.
|
|
||||||
for tool_call in msg.get("tool_calls", []) or []:
|
for tool_call in msg.get("tool_calls", []) or []:
|
||||||
fn = tool_call.get("function") or {}
|
fn = tool_call.get("function") or {}
|
||||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||||
call_id = call_id or f"call_{idx}"
|
input_items.append({
|
||||||
item_id = item_id or f"fc_{idx}"
|
|
||||||
input_items.append(
|
|
||||||
{
|
|
||||||
"type": "function_call",
|
"type": "function_call",
|
||||||
"id": item_id,
|
"id": item_id or f"fc_{idx}",
|
||||||
"call_id": call_id,
|
"call_id": call_id or f"call_{idx}",
|
||||||
"name": fn.get("name"),
|
"name": fn.get("name"),
|
||||||
"arguments": fn.get("arguments") or "{}",
|
"arguments": fn.get("arguments") or "{}",
|
||||||
}
|
})
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||||
input_items.append(
|
input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
|
||||||
{
|
|
||||||
"type": "function_call_output",
|
|
||||||
"call_id": call_id,
|
|
||||||
"output": output_text,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
return system_prompt, input_items
|
return system_prompt, input_items
|
||||||
|
|
||||||
@@ -247,7 +244,10 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
|
|||||||
buffer.append(line)
|
buffer.append(line)
|
||||||
|
|
||||||
|
|
||||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
async def _consume_sse(
|
||||||
|
response: httpx.Response,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> tuple[str, list[ToolCallRequest], str]:
|
||||||
content = ""
|
content = ""
|
||||||
tool_calls: list[ToolCallRequest] = []
|
tool_calls: list[ToolCallRequest] = []
|
||||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||||
@@ -267,7 +267,10 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
|
|||||||
"arguments": item.get("arguments") or "",
|
"arguments": item.get("arguments") or "",
|
||||||
}
|
}
|
||||||
elif event_type == "response.output_text.delta":
|
elif event_type == "response.output_text.delta":
|
||||||
content += event.get("delta") or ""
|
delta_text = event.get("delta") or ""
|
||||||
|
content += delta_text
|
||||||
|
if on_content_delta and delta_text:
|
||||||
|
await on_content_delta(delta_text)
|
||||||
elif event_type == "response.function_call_arguments.delta":
|
elif event_type == "response.function_call_arguments.delta":
|
||||||
call_id = event.get("call_id")
|
call_id = event.get("call_id")
|
||||||
if call_id and call_id in tool_call_buffers:
|
if call_id and call_id in tool_call_buffers:
|
||||||
|
|||||||
Reference in New Issue
Block a user