merge: sync origin/main into local main

README.md | 77

@@ -1028,6 +1028,8 @@ Config file: `~/.nanobot/config.json`
 | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
 | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
 | `ollama` | LLM (local, Ollama) | — |
+| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
+| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
 | `vllm` | LLM (local, any OpenAI-compatible server) | — |
 | `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
 | `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
@@ -1163,6 +1165,81 @@ ollama run llama3.2
</details>

<details>
<summary><b>OpenVINO Model Server (local / OpenAI-compatible)</b></summary>

Run LLMs locally on Intel GPUs using [OpenVINO Model Server](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html). OVMS exposes an OpenAI-compatible API at `/v3`.

> Requires Docker and an Intel GPU with driver access (`/dev/dri`).

**1. Pull the model** (example):

```bash
mkdir -p ov/models && cd ov
docker run -d \
  --rm \
  --user $(id -u):$(id -g) \
  -v $(pwd)/models:/models \
  openvino/model_server:latest-gpu \
  --pull \
  --model_name openai/gpt-oss-20b \
  --model_repository_path /models \
  --source_model OpenVINO/gpt-oss-20b-int4-ov \
  --task text_generation \
  --tool_parser gptoss \
  --reasoning_parser gptoss \
  --enable_prefix_caching true \
  --target_device GPU
```

> This downloads the model weights. Wait for the container to finish before proceeding.

**2. Start the server** (example):

```bash
docker run -d \
  --rm \
  --name ovms \
  --user $(id -u):$(id -g) \
  -p 8000:8000 \
  -v $(pwd)/models:/models \
  --device /dev/dri \
  --group-add=$(stat -c "%g" /dev/dri/render* | head -n 1) \
  openvino/model_server:latest-gpu \
  --rest_port 8000 \
  --model_name openai/gpt-oss-20b \
  --model_repository_path /models \
  --source_model OpenVINO/gpt-oss-20b-int4-ov \
  --task text_generation \
  --tool_parser gptoss \
  --reasoning_parser gptoss \
  --enable_prefix_caching true \
  --target_device GPU
```

**3. Add to config** (partial — merge into `~/.nanobot/config.json`):

```json
{
  "providers": {
    "ovms": {
      "apiBase": "http://localhost:8000/v3"
    }
  },
  "agents": {
    "defaults": {
      "provider": "ovms",
      "model": "openai/gpt-oss-20b"
    }
  }
}
```

> OVMS is a local server — no API key required. Supports tool calling (`--tool_parser gptoss`), reasoning (`--reasoning_parser gptoss`), and streaming.

> See the [official OVMS docs](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) for more details.

</details>

<details>
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
docs/CHANNEL_PLUGIN_GUIDE.md | 352 (new file)

@@ -0,0 +1,352 @@
# Channel Plugin Guide

Build a custom nanobot channel in three steps: subclass, package, install.

## How It Works

nanobot discovers channel plugins via Python [entry points](https://packaging.python.org/en/latest/specifications/entry-points/). When `nanobot gateway` starts, it scans:

1. Built-in channels in `nanobot/channels/`
2. External packages registered under the `nanobot.channels` entry point group

If a matching config section has `"enabled": true`, the channel is instantiated and started.
## Quick Start

We'll build a minimal webhook channel that receives messages via HTTP POST and sends replies back.

### Project Structure

```
nanobot-channel-webhook/
├── nanobot_channel_webhook/
│   ├── __init__.py      # re-export WebhookChannel
│   └── channel.py       # channel implementation
└── pyproject.toml
```

### 1. Create Your Channel

```python
# nanobot_channel_webhook/__init__.py
from nanobot_channel_webhook.channel import WebhookChannel

__all__ = ["WebhookChannel"]
```

```python
# nanobot_channel_webhook/channel.py
import asyncio
from typing import Any

from aiohttp import web
from loguru import logger

from nanobot.channels.base import BaseChannel
from nanobot.bus.events import OutboundMessage


class WebhookChannel(BaseChannel):
    name = "webhook"
    display_name = "Webhook"

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        return {"enabled": False, "port": 9000, "allowFrom": []}

    async def start(self) -> None:
        """Start an HTTP server that listens for incoming messages.

        IMPORTANT: start() must block forever (or until stop() is called).
        If it returns, the channel is considered dead.
        """
        self._running = True
        port = self.config.get("port", 9000)

        app = web.Application()
        app.router.add_post("/message", self._on_request)
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, "0.0.0.0", port)
        await site.start()
        logger.info("Webhook listening on :{}", port)

        # Block until stopped
        while self._running:
            await asyncio.sleep(1)

        await runner.cleanup()

    async def stop(self) -> None:
        self._running = False

    async def send(self, msg: OutboundMessage) -> None:
        """Deliver an outbound message.

        msg.content  — markdown text (convert to platform format as needed)
        msg.media    — list of local file paths to attach
        msg.chat_id  — the recipient (same chat_id you passed to _handle_message)
        msg.metadata — may contain "_progress": True for streaming chunks
        """
        logger.info("[webhook] -> {}: {}", msg.chat_id, msg.content[:80])
        # In a real plugin: POST to a callback URL, send via SDK, etc.

    async def _on_request(self, request: web.Request) -> web.Response:
        """Handle an incoming HTTP POST."""
        body = await request.json()
        sender = body.get("sender", "unknown")
        chat_id = body.get("chat_id", sender)
        text = body.get("text", "")
        media = body.get("media", [])  # list of URLs

        # This is the key call: validates allowFrom, then puts the
        # message onto the bus for the agent to process.
        await self._handle_message(
            sender_id=sender,
            chat_id=chat_id,
            content=text,
            media=media,
        )

        return web.json_response({"ok": True})
```

### 2. Register the Entry Point

```toml
# pyproject.toml
[project]
name = "nanobot-channel-webhook"
version = "0.1.0"
dependencies = ["nanobot", "aiohttp"]

[project.entry-points."nanobot.channels"]
webhook = "nanobot_channel_webhook:WebhookChannel"

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
```

The key (`webhook`) becomes the config section name. The value points to your `BaseChannel` subclass.

### 3. Install & Configure

```bash
pip install -e .
nanobot plugins list   # verify "Webhook" shows as "plugin"
nanobot onboard        # auto-adds default config for detected plugins
```

Edit `~/.nanobot/config.json`:

```json
{
  "channels": {
    "webhook": {
      "enabled": true,
      "port": 9000,
      "allowFrom": ["*"]
    }
  }
}
```

### 4. Run & Test

```bash
nanobot gateway
```

In another terminal:

```bash
curl -X POST http://localhost:9000/message \
  -H "Content-Type: application/json" \
  -d '{"sender": "user1", "chat_id": "user1", "text": "Hello!"}'
```

The agent receives the message and processes it. Replies arrive in your `send()` method.

## BaseChannel API

### Required (abstract)

| Method | Description |
|--------|-------------|
| `async start()` | **Must block forever.** Connect to platform, listen for messages, call `_handle_message()` on each. If this returns, the channel is dead. |
| `async stop()` | Set `self._running = False` and clean up. Called when gateway shuts down. |
| `async send(msg: OutboundMessage)` | Deliver an outbound message to the platform. |

### Provided by Base

| Method / Property | Description |
|-------------------|-------------|
| `_handle_message(sender_id, chat_id, content, media?, metadata?, session_key?)` | **Call this when you receive a message.** Checks `is_allowed()`, then publishes to the bus. Automatically sets `_wants_stream` if `supports_streaming` is true. |
| `is_allowed(sender_id)` | Checks against `config["allowFrom"]`; `"*"` allows all, `[]` denies all. |
| `default_config()` (classmethod) | Returns default config dict for `nanobot onboard`. Override to declare your fields. |
| `transcribe_audio(file_path)` | Transcribes audio via Groq Whisper (if configured). |
| `supports_streaming` (property) | `True` when config has `"streaming": true` **and** subclass overrides `send_delta()`. |
| `is_running` | Returns `self._running`. |
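The `allowFrom` check described above reduces to a small predicate. A sketch of the documented semantics, not the actual implementation:

```python
def is_allowed(sender_id: str, allow_from: list[str]) -> bool:
    """Per the table: "*" allows everyone; an empty list denies everyone."""
    if "*" in allow_from:
        return True
    return str(sender_id) in allow_from
```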
### Optional (streaming)

| Method | Description |
|--------|-------------|
| `async send_delta(chat_id, delta, metadata?)` | Override to receive streaming chunks. See [Streaming Support](#streaming-support) for details. |

### Message Types

```python
@dataclass
class OutboundMessage:
    channel: str       # your channel name
    chat_id: str       # recipient (same value you passed to _handle_message)
    content: str       # markdown text — convert to platform format as needed
    media: list[str]   # local file paths to attach (images, audio, docs)
    metadata: dict     # may contain: "_progress" (bool) for streaming chunks,
                       # "message_id" for reply threading
```

## Streaming Support

Channels can opt into real-time streaming — the agent sends content token-by-token instead of one final message. This is entirely optional; channels work fine without it.

### How It Works

When **both** conditions are met, the agent streams content through your channel:

1. Config has `"streaming": true`
2. Your subclass overrides `send_delta()`

If either is missing, the agent falls back to the normal one-shot `send()` path.

### Implementing `send_delta`

Override `send_delta` to handle two types of calls:

```python
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    meta = metadata or {}

    if meta.get("_stream_end"):
        # Streaming finished — do final formatting, cleanup, etc.
        return

    # Regular delta — append text, update the message on screen
    # delta contains a small chunk of text (a few tokens)
```

**Metadata flags:**

| Flag | Meaning |
|------|---------|
| `_stream_delta: True` | A content chunk (delta contains the new text) |
| `_stream_end: True` | Streaming finished (delta is empty) |
| `_resuming: True` | More streaming rounds coming (e.g. tool call then another response) |
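A `send_delta` implementation can dispatch on these flags. This is an illustrative helper with a hypothetical name, assuming the combinations shown in the table above:

```python
def classify_stream_event(metadata: dict) -> str:
    """Name the kind of send_delta call the metadata describes."""
    if metadata.get("_stream_end"):
        # _resuming distinguishes "more rounds coming" from the final response
        return "resuming" if metadata.get("_resuming") else "end"
    if metadata.get("_stream_delta"):
        return "delta"
    return "unknown"
```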
### Example: Webhook with Streaming

```python
class WebhookChannel(BaseChannel):
    name = "webhook"
    display_name = "Webhook"

    def __init__(self, config, bus):
        super().__init__(config, bus)
        self._buffers: dict[str, str] = {}

    async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
        meta = metadata or {}
        if meta.get("_stream_end"):
            text = self._buffers.pop(chat_id, "")
            # Final delivery — format and send the complete message
            await self._deliver(chat_id, text, final=True)
            return

        self._buffers.setdefault(chat_id, "")
        self._buffers[chat_id] += delta
        # Incremental update — push partial text to the client
        await self._deliver(chat_id, self._buffers[chat_id], final=False)

    async def send(self, msg: OutboundMessage) -> None:
        # Non-streaming path — unchanged
        await self._deliver(msg.chat_id, msg.content, final=True)
```

### Config

Enable streaming per channel:

```json
{
  "channels": {
    "webhook": {
      "enabled": true,
      "streaming": true,
      "allowFrom": ["*"]
    }
  }
}
```

When `streaming` is `false` (default) or omitted, only `send()` is called — no streaming overhead.

### BaseChannel Streaming API

| Method / Property | Description |
|-------------------|-------------|
| `async send_delta(chat_id, delta, metadata?)` | Override to handle streaming chunks. No-op by default. |
| `supports_streaming` (property) | Returns `True` when config has `streaming: true` **and** subclass overrides `send_delta`. |
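The override half of that check works by comparing function objects on the class. A self-contained sketch of the mechanism with toy classes, not nanobot's own:

```python
class Base:
    async def send_delta(self, chat_id: str, delta: str) -> None:
        pass  # no-op default, like BaseChannel.send_delta


class Plain(Base):
    pass  # inherits the no-op


class Streaming(Base):
    async def send_delta(self, chat_id: str, delta: str) -> None:
        ...  # a real channel would push the chunk to the platform here


def overrides_send_delta(cls: type) -> bool:
    # True only when the subclass defines its own send_delta
    return cls.send_delta is not Base.send_delta
```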
## Config

Your channel receives config as a plain `dict`. Access fields with `.get()`:

```python
async def start(self) -> None:
    port = self.config.get("port", 9000)
    token = self.config.get("token", "")
```

`allowFrom` is handled automatically by `_handle_message()` — you don't need to check it yourself.

Override `default_config()` so `nanobot onboard` auto-populates `config.json`:

```python
@classmethod
def default_config(cls) -> dict[str, Any]:
    return {"enabled": False, "port": 9000, "allowFrom": []}
```

If not overridden, the base class returns `{"enabled": false}`.

## Naming Convention

| What | Format | Example |
|------|--------|---------|
| PyPI package | `nanobot-channel-{name}` | `nanobot-channel-webhook` |
| Entry point key | `{name}` | `webhook` |
| Config section | `channels.{name}` | `channels.webhook` |
| Python package | `nanobot_channel_{name}` | `nanobot_channel_webhook` |

## Local Development

```bash
git clone https://github.com/you/nanobot-channel-webhook
cd nanobot-channel-webhook
pip install -e .
nanobot plugins list   # should show "Webhook" as "plugin"
nanobot gateway        # test end-to-end
```

## Verify

```bash
$ nanobot plugins list

Name       Source    Enabled
telegram   builtin   yes
discord    builtin   no
webhook    plugin    yes
```
@@ -5,7 +5,6 @@ from __future__ import annotations
 import asyncio
 import json
 import os
-import re
 import shutil
 import sys
 import tempfile
@@ -642,7 +641,8 @@ class AgentLoop:
         """Remove <think>…</think> blocks that some models embed in content."""
         if not text:
             return None
-        return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
+        from nanobot.utils.helpers import strip_think
+        return strip_think(text) or None

     @staticmethod
     def _tool_hint(tool_calls: list) -> str:
@@ -812,23 +812,55 @@ class AgentLoop:
         self,
         initial_messages: list[dict],
         on_progress: Callable[..., Awaitable[None]] | None = None,
+        on_stream: Callable[[str], Awaitable[None]] | None = None,
+        on_stream_end: Callable[..., Awaitable[None]] | None = None,
     ) -> tuple[str | None, list[str], list[dict]]:
-        """Run the agent iteration loop."""
+        """Run the agent iteration loop.
+
+        *on_stream*: called with each content delta during streaming.
+        *on_stream_end(resuming)*: called when a streaming session finishes.
+        ``resuming=True`` means tool calls follow (spinner should restart);
+        ``resuming=False`` means this is the final response.
+        """
         messages = initial_messages
         iteration = 0
         final_content = None
         tools_used: list[str] = []
+
+        # Wrap on_stream with a stateful think-tag filter so downstream
+        # consumers (CLI, channels) never see <think> blocks.
+        _raw_stream = on_stream
+        _stream_buf = ""
+
+        async def _filtered_stream(delta: str) -> None:
+            nonlocal _stream_buf
+            from nanobot.utils.helpers import strip_think
+            prev_clean = strip_think(_stream_buf)
+            _stream_buf += delta
+            new_clean = strip_think(_stream_buf)
+            incremental = new_clean[len(prev_clean):]
+            if incremental and _raw_stream:
+                await _raw_stream(incremental)

         while iteration < self.max_iterations:
             iteration += 1

             tool_defs = self.tools.get_definitions()

-            response = await self.provider.chat_with_retry(
-                messages=messages,
-                tools=tool_defs,
-                model=self.model,
-            )
+            if on_stream:
+                response = await self.provider.chat_stream_with_retry(
+                    messages=messages,
+                    tools=tool_defs,
+                    model=self.model,
+                    on_content_delta=_filtered_stream,
+                )
+            else:
+                response = await self.provider.chat_with_retry(
+                    messages=messages,
+                    tools=tool_defs,
+                    model=self.model,
+                )

             usage = getattr(response, "usage", None) or {}
             self._last_usage = {
                 "prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
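The stateful `_filtered_stream` wrapper in this hunk can be exercised in isolation. The sketch below assumes a `strip_think` that also drops an unterminated trailing `<think>` block; the helper's exact behavior is not shown in this diff:

```python
import re


def strip_think(text: str) -> str:
    """Remove closed <think>…</think> blocks and any unterminated trailing one."""
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    return re.sub(r"<think>[\s\S]*$", "", text)


class ThinkFilter:
    """Accumulate raw deltas; emit only the newly visible clean text."""

    def __init__(self) -> None:
        self.buf = ""

    def feed(self, delta: str) -> str:
        prev_clean = strip_think(self.buf)
        self.buf += delta
        new_clean = strip_think(self.buf)
        # Clean output only ever grows, so the new suffix is the increment
        return new_clean[len(prev_clean):]
```

Feeding `"<think>plan"` yields nothing; feeding `"</think>Hello"` afterwards yields `"Hello"`, so consumers never see the think block.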
|
||||||
@@ -836,11 +868,18 @@ class AgentLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
if on_stream and on_stream_end:
|
||||||
|
await on_stream_end(resuming=True)
|
||||||
|
_stream_buf = ""
|
||||||
|
|
||||||
if on_progress:
|
if on_progress:
|
||||||
|
if not on_stream:
|
||||||
thought = self._strip_think(response.content)
|
thought = self._strip_think(response.content)
|
||||||
if thought:
|
if thought:
|
||||||
await on_progress(thought)
|
await on_progress(thought)
|
||||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
tool_hint = self._tool_hint(response.tool_calls)
|
||||||
|
tool_hint = self._strip_think(tool_hint)
|
||||||
|
await on_progress(tool_hint, tool_hint=True)
|
||||||
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
tc.to_openai_tool_call()
|
tc.to_openai_tool_call()
|
||||||
@@ -861,9 +900,11 @@ class AgentLoop:
|
|||||||
messages, tool_call.id, tool_call.name, result
|
messages, tool_call.id, tool_call.name, result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if on_stream and on_stream_end:
|
||||||
|
await on_stream_end(resuming=False)
|
||||||
|
_stream_buf = ""
|
||||||
|
|
||||||
clean = self._strip_think(response.content)
|
clean = self._strip_think(response.content)
|
||||||
# Don't persist error responses to session history — they can
|
|
||||||
# poison the context and cause permanent 400 loops (#1303).
|
|
||||||
if response.finish_reason == "error":
|
if response.finish_reason == "error":
|
||||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||||
@@ -956,7 +997,23 @@ class AgentLoop:
|
|||||||
"""Process a message under the global lock."""
|
"""Process a message under the global lock."""
|
||||||
async with self._processing_lock:
|
async with self._processing_lock:
|
||||||
try:
|
try:
|
||||||
response = await self._process_message(msg)
|
on_stream = on_stream_end = None
|
||||||
|
if msg.metadata.get("_wants_stream"):
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content=delta, metadata={"_stream_delta": True},
|
||||||
|
))
|
||||||
|
|
||||||
|
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata={"_stream_end": True, "_resuming": resuming},
|
||||||
|
))
|
||||||
|
|
||||||
|
response = await self._process_message(
|
||||||
|
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
@@ -1173,6 +1230,8 @@ class AgentLoop:
|
|||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
await self._reload_runtime_config_if_needed()
|
await self._reload_runtime_config_if_needed()
|
||||||
@@ -1190,7 +1249,6 @@ class AgentLoop:
|
|||||||
await self._run_preflight_token_consolidation(session)
|
await self._run_preflight_token_consolidation(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
# Subagent results should be assistant role, other system messages use user role
|
|
||||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
@@ -1280,7 +1338,10 @@ class AgentLoop:
|
|||||||
))
|
))
|
||||||
|
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages,
|
||||||
|
on_progress=on_progress or _bus_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
@@ -1295,7 +1356,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
return await self._maybe_attach_voice_reply(
|
outbound = await self._maybe_attach_voice_reply(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
@@ -1304,6 +1365,24 @@ class AgentLoop:
|
|||||||
),
|
),
|
||||||
persona=persona,
|
persona=persona,
|
||||||
)
|
)
|
||||||
|
if outbound is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
meta = dict(outbound.metadata or {})
|
||||||
|
content = outbound.content
|
||||||
|
if on_stream is not None:
|
||||||
|
if outbound.media:
|
||||||
|
content = ""
|
||||||
|
else:
|
||||||
|
meta["_streamed"] = True
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=outbound.channel,
|
||||||
|
chat_id=outbound.chat_id,
|
||||||
|
content=content,
|
||||||
|
reply_to=outbound.reply_to,
|
||||||
|
media=list(outbound.media or []),
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
||||||
@@ -1391,8 +1470,13 @@ class AgentLoop:
|
|||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a message directly and return the outbound payload."""
|
"""Process a message directly and return the outbound payload."""
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
return await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
return await self._process_message(
|
||||||
|
msg, session_key=session_key, on_progress=on_progress,
|
||||||
|
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
)
|
||||||
|
|||||||
@@ -81,6 +81,17 @@ class BaseChannel(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
||||||
|
"""Deliver a streaming text chunk. Override in subclass to enable streaming."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_streaming(self) -> bool:
|
||||||
|
"""True when config enables streaming AND this subclass implements send_delta."""
|
||||||
|
cfg = self.config
|
||||||
|
streaming = cfg.get("streaming", False) if isinstance(cfg, dict) else getattr(cfg, "streaming", False)
|
||||||
|
return bool(streaming) and type(self).send_delta is not BaseChannel.send_delta
|
||||||
|
|
||||||
def is_allowed(self, sender_id: str) -> bool:
|
def is_allowed(self, sender_id: str) -> bool:
|
||||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||||
allow_list = getattr(self.config, "allow_from", [])
|
allow_list = getattr(self.config, "allow_from", [])
|
||||||
@@ -121,13 +132,17 @@ class BaseChannel(ABC):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
meta = metadata or {}
|
||||||
|
if self.supports_streaming:
|
||||||
|
meta = {**meta, "_wants_stream": True}
|
||||||
|
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(
|
||||||
channel=self.name,
|
channel=self.name,
|
||||||
sender_id=str(sender_id),
|
sender_id=str(sender_id),
|
||||||
chat_id=str(chat_id),
|
chat_id=str(chat_id),
|
||||||
content=content,
|
content=content,
|
||||||
media=media or [],
|
media=media or [],
|
||||||
metadata=metadata or {},
|
metadata=meta,
|
||||||
session_key_override=session_key,
|
session_key_override=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -190,6 +190,11 @@ class ChannelManager:
|
|||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
|
if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
|
||||||
|
await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
|
||||||
|
elif msg.metadata.get("_streamed"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import asyncio
 import re
 import time
 import unicodedata
+from dataclasses import dataclass
+from typing import Any
 
 from loguru import logger
 from telegram import BotCommand, ReplyParameters, Update
@@ -157,6 +159,16 @@ def _markdown_to_telegram_html(text: str) -> str:
 
 _SEND_MAX_RETRIES = 3
 _SEND_RETRY_BASE_DELAY = 0.5  # seconds, doubled each retry
+
+
+@dataclass
+class _StreamBuf:
+    """Per-chat streaming accumulator for progressive message editing."""
+    text: str = ""
+    message_id: int | None = None
+    last_edit: float = 0.0
+
+
 class TelegramChannel(BaseChannel):
     """
     Telegram channel using long polling.
@@ -173,7 +185,11 @@ class TelegramChannel(BaseChannel):
     def default_config(cls) -> dict[str, object]:
         return TelegramConfig().model_dump(by_alias=True)
 
-    def __init__(self, config: TelegramConfig | TelegramInstanceConfig, bus: MessageBus):
+    _STREAM_EDIT_INTERVAL = 0.6  # min seconds between edit_message_text calls
+
+    def __init__(self, config: Any, bus: MessageBus):
+        if isinstance(config, dict):
+            config = TelegramConfig.model_validate(config)
         super().__init__(config, bus)
         self.config: TelegramConfig | TelegramInstanceConfig = config
         self._app: Application | None = None
@@ -184,6 +200,7 @@ class TelegramChannel(BaseChannel):
         self._message_threads: dict[tuple[str, int], int] = {}
         self._bot_user_id: int | None = None
         self._bot_username: str | None = None
+        self._stream_bufs: dict[str, _StreamBuf] = {}  # chat_id -> streaming state
 
     def is_allowed(self, sender_id: str) -> bool:
         """Preserve Telegram's legacy id|username allowlist matching."""
@@ -410,13 +427,7 @@ class TelegramChannel(BaseChannel):
 
         # Send text content
         if msg.content and msg.content != "[empty message]":
-            is_progress = msg.metadata.get("_progress", False)
-
             for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
-                # Final response: simulate streaming via draft, then persist.
-                if not is_progress:
-                    await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
-                else:
-                    await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
+                await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
 
     async def _call_with_retry(self, fn, *args, **kwargs):
@@ -463,29 +474,67 @@
         except Exception as e2:
             logger.error("Error sending Telegram message: {}", e2)
 
-    async def _send_with_streaming(
-        self,
-        chat_id: int,
-        text: str,
-        reply_params=None,
-        thread_kwargs: dict | None = None,
-    ) -> None:
-        """Simulate streaming via send_message_draft, then persist with send_message."""
-        draft_id = int(time.time() * 1000) % (2**31)
-        try:
-            step = max(len(text) // 8, 40)
-            for i in range(step, len(text), step):
-                await self._app.bot.send_message_draft(
-                    chat_id=chat_id, draft_id=draft_id, text=text[:i],
-                )
-                await asyncio.sleep(0.04)
-            await self._app.bot.send_message_draft(
-                chat_id=chat_id, draft_id=draft_id, text=text,
-            )
-            await asyncio.sleep(0.15)
-        except Exception:
-            pass
-        await self._send_text(chat_id, text, reply_params, thread_kwargs)
+    async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
+        """Progressive message editing: send on first delta, edit on subsequent ones."""
+        if not self._app:
+            return
+        meta = metadata or {}
+        int_chat_id = int(chat_id)
+
+        if meta.get("_stream_end"):
+            buf = self._stream_bufs.pop(chat_id, None)
+            if not buf or not buf.message_id or not buf.text:
+                return
+            self._stop_typing(chat_id)
+            try:
+                html = _markdown_to_telegram_html(buf.text)
+                await self._call_with_retry(
+                    self._app.bot.edit_message_text,
+                    chat_id=int_chat_id, message_id=buf.message_id,
+                    text=html, parse_mode="HTML",
+                )
+            except Exception as e:
+                logger.debug("Final stream edit failed (HTML), trying plain: {}", e)
+                try:
+                    await self._call_with_retry(
+                        self._app.bot.edit_message_text,
+                        chat_id=int_chat_id, message_id=buf.message_id,
+                        text=buf.text,
+                    )
+                except Exception:
+                    pass
+            return
+
+        buf = self._stream_bufs.get(chat_id)
+        if buf is None:
+            buf = _StreamBuf()
+            self._stream_bufs[chat_id] = buf
+        buf.text += delta
+
+        if not buf.text.strip():
+            return
+
+        now = time.monotonic()
+        if buf.message_id is None:
+            try:
+                sent = await self._call_with_retry(
+                    self._app.bot.send_message,
+                    chat_id=int_chat_id, text=buf.text,
+                )
+                buf.message_id = sent.message_id
+                buf.last_edit = now
+            except Exception as e:
+                logger.warning("Stream initial send failed: {}", e)
+        elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
+            try:
+                await self._call_with_retry(
+                    self._app.bot.edit_message_text,
+                    chat_id=int_chat_id, message_id=buf.message_id,
+                    text=buf.text,
+                )
+                buf.last_edit = now
+            except Exception:
+                pass
 
     async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Handle /start command."""
@@ -32,6 +32,7 @@ from rich.table import Table
 from rich.text import Text
 
 from nanobot import __logo__, __version__
+from nanobot.cli.stream import StreamRenderer, ThinkingSpinner
 from nanobot.config.paths import get_workspace_path
 from nanobot.config.schema import Config
 from nanobot.utils.helpers import sync_workspace_templates
@@ -187,46 +188,13 @@ async def _print_interactive_response(
     await run_in_terminal(_write)
 
 
-class _ThinkingSpinner:
-    """Spinner wrapper with pause support for clean progress output."""
-
-    def __init__(self, enabled: bool):
-        self._spinner = console.status(
-            "[dim]nanobot is thinking...[/dim]", spinner="dots"
-        ) if enabled else None
-        self._active = False
-
-    def __enter__(self):
-        if self._spinner:
-            self._spinner.start()
-        self._active = True
-        return self
-
-    def __exit__(self, *exc):
-        self._active = False
-        if self._spinner:
-            self._spinner.stop()
-        return False
-
-    @contextmanager
-    def pause(self):
-        """Temporarily stop spinner while printing progress."""
-        if self._spinner and self._active:
-            self._spinner.stop()
-        try:
-            yield
-        finally:
-            if self._spinner and self._active:
-                self._spinner.start()
-
-
-def _print_cli_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
+def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
     """Print a CLI progress line, pausing the spinner if needed."""
     with thinking.pause() if thinking else nullcontext():
         console.print(f" [dim]↳ {text}[/dim]")
 
 
-async def _print_interactive_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
+async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
     """Print an interactive progress line, pausing the spinner if needed."""
     with thinking.pause() if thinking else nullcontext():
         await _print_interactive_line(text)
@@ -467,6 +435,14 @@ def _make_provider(config: Config):
             api_base=p.api_base,
             default_model=model,
         )
+    # OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3
+    elif provider_name == "ovms":
+        from nanobot.providers.custom_provider import CustomProvider
+        provider = CustomProvider(
+            api_key=p.api_key if p else "no-key",
+            api_base=config.get_api_base(model) or "http://localhost:8000/v3",
+            default_model=model,
+        )
     else:
         from nanobot.providers.litellm_provider import LiteLLMProvider
         from nanobot.providers.registry import find_by_name
@@ -788,7 +764,7 @@ def agent(
     )
 
     # Shared reference for progress callbacks
-    _thinking: _ThinkingSpinner | None = None
+    _thinking: ThinkingSpinner | None = None
 
     async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
         ch = agent_loop.channels_config
@@ -801,13 +777,14 @@ def agent(
     if message:
         # Single message mode — direct call, no bus needed
        async def run_once():
-            nonlocal _thinking
-            _thinking = _ThinkingSpinner(enabled=not logs)
-            with _thinking:
-                response = await agent_loop.process_direct(
-                    message, session_id, on_progress=_cli_progress,
-                )
-            _thinking = None
+            renderer = StreamRenderer(render_markdown=markdown)
+            response = await agent_loop.process_direct(
+                message, session_id,
+                on_progress=_cli_progress,
+                on_stream=renderer.on_delta,
+                on_stream_end=renderer.on_end,
+            )
+            if not renderer.streamed:
                 _print_agent_response(
                     response.content if response else "",
                     render_markdown=markdown,
@@ -848,11 +825,27 @@ def agent(
         turn_done = asyncio.Event()
         turn_done.set()
         turn_response: list[tuple[str, dict]] = []
+        renderer: StreamRenderer | None = None
 
         async def _consume_outbound():
             while True:
                 try:
                     msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
+
+                    if msg.metadata.get("_stream_delta"):
+                        if renderer:
+                            await renderer.on_delta(msg.content)
+                        continue
+                    if msg.metadata.get("_stream_end"):
+                        if renderer:
+                            await renderer.on_end(
+                                resuming=msg.metadata.get("_resuming", False),
+                            )
+                        continue
+                    if msg.metadata.get("_streamed"):
+                        turn_done.set()
+                        continue
+
                     if msg.metadata.get("_progress"):
                         is_tool_hint = msg.metadata.get("_tool_hint", False)
                         ch = agent_loop.channels_config
@@ -862,8 +855,9 @@ def agent(
                             pass
                         else:
                             await _print_interactive_progress_line(msg.content, _thinking)
+                        continue
 
-                    elif not turn_done.is_set():
+                    if not turn_done.is_set():
                         if msg.content:
                             turn_response.append((msg.content, dict(msg.metadata or {})))
                         turn_done.set()
@@ -897,23 +891,24 @@ def agent(
 
                 turn_done.clear()
                 turn_response.clear()
+                renderer = StreamRenderer(render_markdown=markdown)
 
                 await bus.publish_inbound(InboundMessage(
                     channel=cli_channel,
                     sender_id="user",
                     chat_id=cli_chat_id,
                     content=user_input,
+                    metadata={"_wants_stream": True},
                 ))
 
-                nonlocal _thinking
-                _thinking = _ThinkingSpinner(enabled=not logs)
-                with _thinking:
-                    await turn_done.wait()
-                _thinking = None
+                await turn_done.wait()
 
                 if turn_response:
                     content, meta = turn_response[0]
-                    _print_agent_response(content, render_markdown=markdown, metadata=meta)
+                    if content and not meta.get("_streamed"):
+                        _print_agent_response(
+                            content, render_markdown=markdown, metadata=meta,
+                        )
     except KeyboardInterrupt:
         _restore_terminal()
         console.print("\nGoodbye!")
nanobot/cli/stream.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+"""Streaming renderer for CLI output.
+
+Uses Rich Live with auto_refresh=False for stable, flicker-free
+markdown rendering during streaming. Ellipsis mode handles overflow.
+"""
+
+from __future__ import annotations
+
+import sys
+import time
+
+from rich.console import Console
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.text import Text
+
+from nanobot import __logo__
+
+
+def _make_console() -> Console:
+    return Console(file=sys.stdout)
+
+
+class ThinkingSpinner:
+    """Spinner that shows 'nanobot is thinking...' with pause support."""
+
+    def __init__(self, console: Console | None = None):
+        c = console or _make_console()
+        self._spinner = c.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
+        self._active = False
+
+    def __enter__(self):
+        self._spinner.start()
+        self._active = True
+        return self
+
+    def __exit__(self, *exc):
+        self._active = False
+        self._spinner.stop()
+        return False
+
+    def pause(self):
+        """Context manager: temporarily stop spinner for clean output."""
+        from contextlib import contextmanager
+
+        @contextmanager
+        def _ctx():
+            if self._spinner and self._active:
+                self._spinner.stop()
+            try:
+                yield
+            finally:
+                if self._spinner and self._active:
+                    self._spinner.start()
+
+        return _ctx()
+
+
+class StreamRenderer:
+    """Rich Live streaming with markdown. auto_refresh=False avoids render races.
+
+    Deltas arrive pre-filtered (no <think> tags) from the agent loop.
+
+    Flow per round:
+        spinner -> first visible delta -> header + Live renders ->
+        on_end -> Live stops (content stays on screen)
+    """
+
+    def __init__(self, render_markdown: bool = True, show_spinner: bool = True):
+        self._md = render_markdown
+        self._show_spinner = show_spinner
+        self._buf = ""
+        self._live: Live | None = None
+        self._t = 0.0
+        self.streamed = False
+        self._spinner: ThinkingSpinner | None = None
+        self._start_spinner()
+
+    def _render(self):
+        return Markdown(self._buf) if self._md and self._buf else Text(self._buf or "")
+
+    def _start_spinner(self) -> None:
+        if self._show_spinner:
+            self._spinner = ThinkingSpinner()
+            self._spinner.__enter__()
+
+    def _stop_spinner(self) -> None:
+        if self._spinner:
+            self._spinner.__exit__(None, None, None)
+            self._spinner = None
+
+    async def on_delta(self, delta: str) -> None:
+        self.streamed = True
+        self._buf += delta
+        if self._live is None:
+            if not self._buf.strip():
+                return
+            self._stop_spinner()
+            c = _make_console()
+            c.print()
+            c.print(f"[cyan]{__logo__} nanobot[/cyan]")
+            self._live = Live(self._render(), console=c, auto_refresh=False)
+            self._live.start()
+        now = time.monotonic()
+        if "\n" in delta or (now - self._t) > 0.05:
+            self._live.update(self._render())
+            self._live.refresh()
+            self._t = now
+
+    async def on_end(self, *, resuming: bool = False) -> None:
+        if self._live:
+            self._live.update(self._render())
+            self._live.refresh()
+            self._live.stop()
+            self._live = None
+        self._stop_spinner()
+        if resuming:
+            self._buf = ""
+            self._start_spinner()
+        else:
+            _make_console().print()
@@ -49,6 +49,7 @@ class TelegramConfig(Base):
     group_policy: Literal["open", "mention"] = "mention"  # "mention" responds when @mentioned or replied to, "open" responds to all
     connection_pool_size: int = 32  # Outbound Telegram API HTTP pool size
     pool_timeout: float = 5.0  # Shared HTTP pool timeout for bot sends and getUpdates
+    streaming: bool = True  # Progressive edit-based streaming for final text replies
 
 
 class TelegramInstanceConfig(TelegramConfig):
@@ -387,7 +388,14 @@ def _coerce_multi_channel_config(
 
 
 class ChannelsConfig(Base):
-    """Configuration for chat channels."""
+    """Configuration for chat channels.
+
+    Built-in and plugin channel configs are stored as extra fields (dicts).
+    Each channel parses its own config in __init__.
+    Per-channel "streaming": true enables streaming output (requires send_delta impl).
+    """
+
+    model_config = ConfigDict(extra="allow")
 
     send_progress: bool = True  # stream agent's text progress to the channel
     send_tool_hints: bool = False  # stream tool-call hints (e.g. read_file("…"))
@@ -480,9 +488,11 @@ class ProvidersConfig(Base):
     dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
     vllm: ProviderConfig = Field(default_factory=ProviderConfig)
     ollama: ProviderConfig = Field(default_factory=ProviderConfig)  # Ollama local models
+    ovms: ProviderConfig = Field(default_factory=ProviderConfig)  # OpenVINO Model Server (OVMS)
     gemini: ProviderConfig = Field(default_factory=ProviderConfig)
     moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
     minimax: ProviderConfig = Field(default_factory=ProviderConfig)
+    mistral: ProviderConfig = Field(default_factory=ProviderConfig)
     aihubmix: ProviderConfig = Field(default_factory=ProviderConfig)  # AiHubMix API gateway
     siliconflow: ProviderConfig = Field(default_factory=ProviderConfig)  # SiliconFlow (硅基流动)
     volcengine: ProviderConfig = Field(default_factory=ProviderConfig)  # VolcEngine (火山引擎)
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import json
 import uuid
+from collections.abc import Awaitable, Callable
 from typing import Any
 from urllib.parse import urljoin
@@ -208,6 +210,100 @@ class AzureOpenAIProvider(LLMProvider):
             finish_reason="error",
         )
 
+    async def chat_stream(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None = None,
+        model: str | None = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        reasoning_effort: str | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
+    ) -> LLMResponse:
+        """Stream a chat completion via Azure OpenAI SSE."""
+        deployment_name = model or self.default_model
+        url = self._build_chat_url(deployment_name)
+        headers = self._build_headers()
+        payload = self._prepare_request_payload(
+            deployment_name, messages, tools, max_tokens, temperature,
+            reasoning_effort, tool_choice=tool_choice,
+        )
+        payload["stream"] = True
+
+        try:
+            async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
+                async with client.stream("POST", url, headers=headers, json=payload) as response:
+                    if response.status_code != 200:
+                        text = await response.aread()
+                        return LLMResponse(
+                            content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}",
+                            finish_reason="error",
+                        )
+                    return await self._consume_stream(response, on_content_delta)
+        except Exception as e:
+            return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error")
+
+    async def _consume_stream(
+        self,
+        response: httpx.Response,
+        on_content_delta: Callable[[str], Awaitable[None]] | None,
+    ) -> LLMResponse:
+        """Parse Azure OpenAI SSE stream into an LLMResponse."""
+        content_parts: list[str] = []
+        tool_call_buffers: dict[int, dict[str, str]] = {}
+        finish_reason = "stop"
+
+        async for line in response.aiter_lines():
+            if not line.startswith("data: "):
+                continue
+            data = line[6:].strip()
+            if data == "[DONE]":
+                break
+            try:
+                chunk = json.loads(data)
+            except Exception:
+                continue
+
+            choices = chunk.get("choices") or []
+            if not choices:
+                continue
+            choice = choices[0]
+            if choice.get("finish_reason"):
+                finish_reason = choice["finish_reason"]
+            delta = choice.get("delta") or {}
+
+            text = delta.get("content")
+            if text:
+                content_parts.append(text)
+                if on_content_delta:
+                    await on_content_delta(text)
+
+            for tc in delta.get("tool_calls") or []:
+                idx = tc.get("index", 0)
+                buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""})
+                if tc.get("id"):
+                    buf["id"] = tc["id"]
+                fn = tc.get("function") or {}
+                if fn.get("name"):
+                    buf["name"] = fn["name"]
+                if fn.get("arguments"):
+                    buf["arguments"] += fn["arguments"]
+
+        tool_calls = [
+            ToolCallRequest(
+                id=buf["id"], name=buf["name"],
+                arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {},
+            )
+            for buf in tool_call_buffers.values()
+        ]
+
+        return LLMResponse(
+            content="".join(content_parts) or None,
+            tool_calls=tool_calls,
+            finish_reason=finish_reason,
+        )
+
     def get_default_model(self) -> str:
         """Get the default model (also used as default deployment name)."""
         return self.default_model
@@ -3,6 +3,7 @@
 import asyncio
 import json
 from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from typing import Any
@@ -223,6 +224,90 @@ class LLMProvider(ABC):
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
|
||||||
|
|
||||||
|
Returns the same ``LLMResponse`` as :meth:`chat`. The default
|
||||||
|
implementation falls back to a non-streaming call and delivers the
|
||||||
|
full content as a single delta. Providers that support native
|
||||||
|
streaming should override this method.
|
||||||
|
"""
|
||||||
|
response = await self.chat(
|
||||||
|
messages=messages, tools=tools, model=model,
|
||||||
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
|
                reasoning_effort=reasoning_effort, tool_choice=tool_choice,
            )
            if on_content_delta and response.content:
                await on_content_delta(response.content)
            return response

    async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Call chat_stream() and convert unexpected exceptions to error responses."""
        try:
            return await self.chat_stream(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat_stream() with retry on transient provider failures."""
        if max_tokens is self._SENTINEL:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL:
            temperature = self.generation.temperature
        if reasoning_effort is self._SENTINEL:
            reasoning_effort = self.generation.reasoning_effort

        kw: dict[str, Any] = dict(
            messages=messages, tools=tools, model=model,
            max_tokens=max_tokens, temperature=temperature,
            reasoning_effort=reasoning_effort, tool_choice=tool_choice,
            on_content_delta=on_content_delta,
        )

        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
            response = await self._safe_chat_stream(**kw)

            if response.finish_reason != "error":
                return response

            if not self._is_transient_error(response.content):
                stripped = self._strip_image_content(messages)
                if stripped is not None:
                    logger.warning("Non-transient LLM error with image content, retrying without images")
                    return await self._safe_chat_stream(**{**kw, "messages": stripped})
                return response

            logger.warning(
                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
                attempt, len(self._CHAT_RETRY_DELAYS), delay,
                (response.content or "")[:120].lower(),
            )
            await asyncio.sleep(delay)

        return await self._safe_chat_stream(**kw)

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
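The retry loop above walks a fixed tuple of delays, sleeping only between transient failures and making one final attempt after the delays are exhausted. Stripped of the provider plumbing, the control flow can be sketched like this (names such as `RETRY_DELAYS`, `call_with_retry`, and the flaky callee are illustrative, not from the source):

```python
import asyncio

RETRY_DELAYS = (0.01, 0.02, 0.04)  # seconds between attempts; illustrative values

async def call_with_retry(fn, is_transient):
    """Try fn() once per delay slot, then make one final attempt."""
    for delay in RETRY_DELAYS:
        result = await fn()
        if not is_transient(result):
            return result  # success or a permanent error: stop retrying
        await asyncio.sleep(delay)
    return await fn()  # last attempt, returned as-is

async def demo():
    attempts = []

    async def flaky():
        attempts.append(1)
        # fail transiently twice, then succeed
        return "ok" if len(attempts) >= 3 else "transient"

    return await call_with_retry(flaky, lambda r: r == "transient")

print(asyncio.run(demo()))  # → ok
```

Returning early on non-transient errors keeps permanent failures (bad API key, unknown model) from burning the whole retry budget.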
@@ -3,6 +3,7 @@
from __future__ import annotations

import uuid
from collections.abc import Awaitable, Callable
from typing import Any

import json_repair
@@ -22,22 +23,20 @@ class CustomProvider(LLMProvider):
    ):
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self._client = AsyncOpenAI(
            api_key=api_key,
            base_url=api_base,
            default_headers={
                "x-session-affinity": uuid.uuid4().hex,
                **(extra_headers or {}),
            },
        )

    def _build_kwargs(
        self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
        model: str | None, max_tokens: int, temperature: float,
        reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        kwargs: dict[str, Any] = {
            "model": model or self.default_model,
            "messages": self._sanitize_empty_content(messages),
@@ -48,37 +47,106 @@ class CustomProvider(LLMProvider):
            kwargs["reasoning_effort"] = reasoning_effort
        if tools:
            kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
        return kwargs

    def _handle_error(self, e: Exception) -> LLMResponse:
        body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
        msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}"
        return LLMResponse(content=msg, finish_reason="error")

    async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
                   model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
                   reasoning_effort: str | None = None,
                   tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
        kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
        try:
            return self._parse(await self._client.chat.completions.create(**kwargs))
        except Exception as e:
            return self._handle_error(e)

    async def chat_stream(
        self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
        model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
        kwargs["stream"] = True
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            chunks: list[Any] = []
            async for chunk in stream:
                chunks.append(chunk)
                if on_content_delta and chunk.choices:
                    text = getattr(chunk.choices[0].delta, "content", None)
                    if text:
                        await on_content_delta(text)
            return self._parse_chunks(chunks)
        except Exception as e:
            return self._handle_error(e)

    def _parse(self, response: Any) -> LLMResponse:
        if not response.choices:
            return LLMResponse(
                content="Error: API returned empty choices.",
                finish_reason="error",
            )
        choice = response.choices[0]
        msg = choice.message
        tool_calls = [
            ToolCallRequest(
                id=tc.id, name=tc.function.name,
                arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
            )
            for tc in (msg.tool_calls or [])
        ]
        u = response.usage
        return LLMResponse(
            content=msg.content, tool_calls=tool_calls,
            finish_reason=choice.finish_reason or "stop",
            usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
            reasoning_content=getattr(msg, "reasoning_content", None) or None,
        )

    def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
        """Reassemble streamed chunks into a single LLMResponse."""
        content_parts: list[str] = []
        tc_bufs: dict[int, dict[str, str]] = {}
        finish_reason = "stop"
        usage: dict[str, int] = {}

        for chunk in chunks:
            if not chunk.choices:
                if hasattr(chunk, "usage") and chunk.usage:
                    u = chunk.usage
                    usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0,
                             "total_tokens": u.total_tokens or 0}
                continue
            choice = chunk.choices[0]
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            delta = choice.delta
            if delta and delta.content:
                content_parts.append(delta.content)
            for tc in (delta.tool_calls or []) if delta else []:
                buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
                if tc.id:
                    buf["id"] = tc.id
                if tc.function and tc.function.name:
                    buf["name"] = tc.function.name
                if tc.function and tc.function.arguments:
                    buf["arguments"] += tc.function.arguments

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=[
                ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {})
                for b in tc_bufs.values()
            ],
            finish_reason=finish_reason,
            usage=usage,
        )

    def get_default_model(self) -> str:
        return self.default_model
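`_parse_chunks` above accumulates tool-call fragments keyed by the stream `index`, because OpenAI-style streams send a call's `id` and `name` once and then deliver its JSON arguments in pieces. A minimal dict-based sketch of that buffering (the chunk shape is simplified to plain dicts; the function name is illustrative):

```python
def merge_tool_call_deltas(deltas):
    """Merge streamed tool-call fragments into complete calls, keyed by index."""
    bufs = {}
    for d in deltas:
        buf = bufs.setdefault(d["index"], {"id": "", "name": "", "arguments": ""})
        if d.get("id"):
            buf["id"] = d["id"]
        if d.get("name"):
            buf["name"] = d["name"]
        buf["arguments"] += d.get("arguments", "")  # argument JSON arrives in pieces
    return [bufs[i] for i in sorted(bufs)]

calls = merge_tool_call_deltas([
    {"index": 0, "id": "call_1", "name": "get_weather", "arguments": '{"ci'},
    {"index": 0, "arguments": 'ty": "Paris"}'},
])
print(calls[0]["arguments"])  # → {"city": "Paris"}
```

Keying on `index` rather than `id` matters: only the first fragment of a call carries the `id`, while every fragment carries the index.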
@@ -4,6 +4,7 @@ import hashlib
import os
import secrets
import string
from collections.abc import Awaitable, Callable
from typing import Any

import json_repair
@@ -222,59 +223,51 @@ class LiteLLMProvider(LLMProvider):
                clean["tool_call_id"] = map_id(clean["tool_call_id"])
        return sanitized

    def _build_chat_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> tuple[dict[str, Any], str]:
        """Build the kwargs dict for ``acompletion``.

        Returns ``(kwargs, original_model)`` so callers can reuse the
        original model string for downstream logic.
        """
        original_model = model or self.default_model
        resolved = self._resolve_model(original_model)
        extra_msg_keys = self._extra_msg_keys(original_model, resolved)

        if self._supports_cache_control(original_model):
            messages, tools = self._apply_cache_control(messages, tools)

        max_tokens = max(1, max_tokens)

        kwargs: dict[str, Any] = {
            "model": resolved,
            "messages": self._sanitize_messages(
                self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
            ),
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if self._gateway:
            kwargs.update(self._gateway.litellm_kwargs)

        self._apply_model_overrides(resolved, kwargs)

        if self._langsmith_enabled:
            kwargs.setdefault("callbacks", []).append("langsmith")

        if self.api_key:
            kwargs["api_key"] = self.api_key

        if self.api_base:
            kwargs["api_base"] = self.api_base

        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers

@@ -286,11 +279,66 @@ class LiteLLMProvider(LLMProvider):
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"

        return kwargs, original_model

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """Send a chat completion request via LiteLLM."""
        kwargs, _ = self._build_chat_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        try:
            response = await acompletion(**kwargs)
            return self._parse_response(response)
        except Exception as e:
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
            )

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Stream a chat completion via LiteLLM, forwarding text deltas."""
        kwargs, _ = self._build_chat_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        kwargs["stream"] = True

        try:
            stream = await acompletion(**kwargs)
            chunks: list[Any] = []
            async for chunk in stream:
                chunks.append(chunk)
                if on_content_delta:
                    delta = chunk.choices[0].delta if chunk.choices else None
                    text = getattr(delta, "content", None) if delta else None
                    if text:
                        await on_content_delta(text)

            full_response = litellm.stream_chunk_builder(
                chunks, messages=kwargs["messages"],
            )
            return self._parse_response(full_response)
        except Exception as e:
            return LLMResponse(
                content=f"Error calling LLM: {str(e)}",
                finish_reason="error",
@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import hashlib
import json
from collections.abc import Awaitable, Callable
from typing import Any, AsyncGenerator

import httpx
@@ -24,16 +25,16 @@ class OpenAICodexProvider(LLMProvider):
        super().__init__(api_key=None, api_base=None)
        self.default_model = default_model

    async def _call_codex(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Shared request logic for both chat() and chat_stream()."""
        model = model or self.default_model
        system_prompt, input_items = _convert_messages(messages)

@@ -52,33 +53,45 @@ class OpenAICodexProvider(LLMProvider):
            "tool_choice": tool_choice or "auto",
            "parallel_tool_calls": True,
        }

        if reasoning_effort:
            body["reasoning"] = {"effort": reasoning_effort}

        if tools:
            body["tools"] = _convert_tools(tools)

        try:
            try:
                content, tool_calls, finish_reason = await _request_codex(
                    DEFAULT_CODEX_URL, headers, body, verify=True,
                    on_content_delta=on_content_delta,
                )
            except Exception as e:
                if "CERTIFICATE_VERIFY_FAILED" not in str(e):
                    raise
                logger.warning("SSL verification failed for Codex API; retrying with verify=False")
                content, tool_calls, finish_reason = await _request_codex(
                    DEFAULT_CODEX_URL, headers, body, verify=False,
                    on_content_delta=on_content_delta,
                )
            return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
        except Exception as e:
            return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")

    async def chat(
        self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
        model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice)

    async def chat_stream(
        self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
        model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        return await self._call_codex(messages, tools, model, reasoning_effort, tool_choice, on_content_delta)

    def get_default_model(self) -> str:
        return self.default_model
@@ -107,13 +120,14 @@ async def _request_codex(
    headers: dict[str, str],
    body: dict[str, Any],
    verify: bool,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
        async with client.stream("POST", url, headers=headers, json=body) as response:
            if response.status_code != 200:
                text = await response.aread()
                raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
            return await _consume_sse(response, on_content_delta)


def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -151,45 +165,28 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
            continue

        if role == "assistant":
            if isinstance(content, str) and content:
                input_items.append({
                    "type": "message", "role": "assistant",
                    "content": [{"type": "output_text", "text": content}],
                    "status": "completed", "id": f"msg_{idx}",
                })
            for tool_call in msg.get("tool_calls", []) or []:
                fn = tool_call.get("function") or {}
                call_id, item_id = _split_tool_call_id(tool_call.get("id"))
                input_items.append({
                    "type": "function_call",
                    "id": item_id or f"fc_{idx}",
                    "call_id": call_id or f"call_{idx}",
                    "name": fn.get("name"),
                    "arguments": fn.get("arguments") or "{}",
                })
            continue

        if role == "tool":
            call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
            output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
            input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text})
            continue

    return system_prompt, input_items

@@ -247,7 +244,10 @@ async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any],
            buffer.append(line)


async def _consume_sse(
    response: httpx.Response,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> tuple[str, list[ToolCallRequest], str]:
    content = ""
    tool_calls: list[ToolCallRequest] = []
    tool_call_buffers: dict[str, dict[str, Any]] = {}
@@ -267,7 +267,10 @@ async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequ
                "arguments": item.get("arguments") or "",
            }
        elif event_type == "response.output_text.delta":
            delta_text = event.get("delta") or ""
            content += delta_text
            if on_content_delta and delta_text:
                await on_content_delta(delta_text)
        elif event_type == "response.function_call_arguments.delta":
            call_id = event.get("call_id")
            if call_id and call_id in tool_call_buffers:
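`_consume_sse` above folds typed server-sent events into text and tool-call buffers. On the wire, SSE is just `data: <payload>` lines separated by blank lines; a minimal parser sketch, independent of the Codex endpoint (the function name and the `[DONE]` end marker convention are illustrative):

```python
import json

def parse_sse(raw: str):
    """Yield decoded JSON payloads from an SSE stream body."""
    for block in raw.split("\n\n"):
        for line in block.splitlines():
            if line.startswith("data: "):
                payload = line[len("data: "):]
                if payload != "[DONE]":  # OpenAI-style end-of-stream marker
                    yield json.loads(payload)

events = list(parse_sse(
    'data: {"type": "response.output_text.delta", "delta": "Hel"}\n\n'
    'data: {"type": "response.output_text.delta", "delta": "lo"}\n\n'
    'data: [DONE]\n\n'
))
text = "".join(e["delta"] for e in events)
print(text)  # → Hello
```

The real code parses events incrementally from an `httpx` stream rather than a complete string, but the framing rules are the same.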
@@ -398,6 +398,23 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
        strip_model_prefix=False,
        model_overrides=(),
    ),
    # Mistral AI: OpenAI-compatible API at api.mistral.ai/v1.
    ProviderSpec(
        name="mistral",
        keywords=("mistral",),
        env_key="MISTRAL_API_KEY",
        display_name="Mistral",
        litellm_prefix="mistral",  # mistral-large-latest → mistral/mistral-large-latest
        skip_prefixes=("mistral/",),  # avoid double-prefix
        env_extras=(),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="https://api.mistral.ai/v1",
        strip_model_prefix=False,
        model_overrides=(),
    ),
    # === Local deployment (matched by config key, NOT by api_base) =========
    # vLLM / any OpenAI-compatible local server.
    # Detected when config key is "vllm" (provider_name="vllm").
@@ -434,6 +451,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
        strip_model_prefix=False,
        model_overrides=(),
    ),
    # === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) ===
    ProviderSpec(
        name="ovms",
        keywords=("openvino", "ovms"),
        env_key="",
        display_name="OpenVINO Model Server",
        litellm_prefix="",
        is_direct=True,
        is_local=True,
        default_api_base="http://localhost:8000/v3",
    ),
    # === Auxiliary (not a primary LLM provider) ============================
    # Groq: mainly used for Whisper voice transcription, also usable for LLM.
    # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|||||||
@@ -11,6 +11,13 @@ from typing import Any
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
def strip_think(text: str) -> str:
|
||||||
|
"""Remove <think>…</think> blocks and any unclosed trailing <think> tag."""
|
||||||
|
text = re.sub(r"<think>[\s\S]*?</think>", "", text)
|
||||||
|
text = re.sub(r"<think>[\s\S]*$", "", text)
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
def detect_image_mime(data: bytes) -> str | None:
|
def detect_image_mime(data: bytes) -> str | None:
|
||||||
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
||||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||||
|
|||||||
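The two regexes in `strip_think` above divide the work: the first removes well-formed `<think>…</think>` blocks (non-greedy, so multiple blocks each match separately), and the second drops an unclosed `<think>` that runs to the end of the text, which happens when a model's output is truncated mid-thought. Restated as a self-contained snippet:

```python
import re

def strip_think(text: str) -> str:
    """Remove <think>…</think> blocks and any unclosed trailing <think> tag."""
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)  # closed blocks
    text = re.sub(r"<think>[\s\S]*$", "", text)          # truncated trailing block
    return text.strip()

print(strip_think("Sure. <think>reasoning…</think>The answer is 4."))
# → Sure. The answer is 4.
print(strip_think("Partial output <think>model was cut off mid-thought"))
# → Partial output
```

`[\s\S]` is used instead of `.` so the match spans newlines without needing `re.DOTALL`.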
@@ -5,6 +5,7 @@ import pytest
from prompt_toolkit.formatted_text import HTML

from nanobot.cli import commands
from nanobot.cli import stream as stream_mod


@pytest.fixture
@@ -62,9 +63,10 @@ def test_init_prompt_session_creates_session():
def test_thinking_spinner_pause_stops_and_restarts():
    """Pause should stop the active spinner and restart it afterward."""
    spinner = MagicMock()
    mock_console = MagicMock()
    mock_console.status.return_value = spinner

    thinking = stream_mod.ThinkingSpinner(console=mock_console)
    with thinking:
        with thinking.pause():
            pass
@@ -83,10 +85,11 @@ def test_print_cli_progress_line_pauses_spinner_before_printing():
    spinner = MagicMock()
    spinner.start.side_effect = lambda: order.append("start")
    spinner.stop.side_effect = lambda: order.append("stop")
    mock_console = MagicMock()
    mock_console.status.return_value = spinner

    with patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
        thinking = stream_mod.ThinkingSpinner(console=mock_console)
        with thinking:
            commands._print_cli_progress_line("tool running", thinking)

@@ -100,13 +103,14 @@ async def test_print_interactive_progress_line_pauses_spinner_before_printing():
    spinner = MagicMock()
    spinner.start.side_effect = lambda: order.append("start")
    spinner.stop.side_effect = lambda: order.append("stop")
    mock_console = MagicMock()
    mock_console.status.return_value = spinner

    async def fake_print(_text: str) -> None:
        order.append("print")

    with patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
        thinking = stream_mod.ThinkingSpinner(console=mock_console)
        with thinking:
            await commands._print_interactive_progress_line("tool running", thinking)
@@ -13,7 +13,9 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
|
|||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
_response = LLMResponse(content="ok", tool_calls=[])
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||||
|
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||||
|
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=MessageBus(),
|
bus=MessageBus(),
|
||||||
@@ -168,6 +170,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
|||||||
order.append("llm")
|
order.append("llm")
|
||||||
return LLMResponse(content="ok", tool_calls=[])
|
return LLMResponse(content="ok", tool_calls=[])
|
||||||
loop.provider.chat_with_retry = track_llm
|
loop.provider.chat_with_retry = track_llm
|
||||||
|
loop.provider.chat_stream_with_retry = track_llm
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
session.messages = [
|
session.messages = [
|
||||||

22  tests/test_mistral_provider.py  Normal file
@@ -0,0 +1,22 @@
+"""Tests for the Mistral provider registration."""
+
+from nanobot.config.schema import ProvidersConfig
+from nanobot.providers.registry import PROVIDERS
+
+
+def test_mistral_config_field_exists():
+    """ProvidersConfig should have a mistral field."""
+    config = ProvidersConfig()
+    assert hasattr(config, "mistral")
+
+
+def test_mistral_provider_in_registry():
+    """Mistral should be registered in the provider registry."""
+    specs = {s.name: s for s in PROVIDERS}
+    assert "mistral" in specs
+
+    mistral = specs["mistral"]
+    assert mistral.env_key == "MISTRAL_API_KEY"
+    assert mistral.litellm_prefix == "mistral"
+    assert mistral.default_api_base == "https://api.mistral.ai/v1"
+    assert "mistral/" in mistral.skip_prefixes