Merge branch 'main' into pr-1581
This commit is contained in:
21
README.md
21
README.md
@@ -20,9 +20,20 @@
|
|||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
|
||||||
|
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
|
||||||
|
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
|
||||||
|
- **2026-03-04** 🛠️ Dependency cleanup, safer file reads, and another round of test and Cron fixes.
|
||||||
|
- **2026-03-03** 🧠 Cleaner user-message merging, safer multimodal saves, and stronger Cron guards.
|
||||||
|
- **2026-03-02** 🛡️ Safer default access control, sturdier Cron reloads, and cleaner Matrix media handling.
|
||||||
|
- **2026-03-01** 🌐 Web proxy support, smarter Cron reminders, and Feishu rich-text parsing improvements.
|
||||||
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
|
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
|
||||||
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
|
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
|
||||||
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
|
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Earlier news</summary>
|
||||||
|
|
||||||
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
|
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
|
||||||
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
||||||
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
||||||
@@ -30,10 +41,6 @@
|
|||||||
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
||||||
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
||||||
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>Earlier news</summary>
|
|
||||||
|
|
||||||
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
|
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
|
||||||
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
|
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
|
||||||
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
|
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
|
||||||
@@ -420,6 +427,10 @@ nanobot channels login
|
|||||||
nanobot gateway
|
nanobot gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> WhatsApp bridge updates are not applied automatically for existing installations.
|
||||||
|
> If you upgrade nanobot and need the latest WhatsApp bridge, run:
|
||||||
|
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -664,12 +675,14 @@ Config file: `~/.nanobot/config.json`
|
|||||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||||
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
||||||
|
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
|
||||||
|
|
||||||
| Provider | Purpose | Get API Key |
|
| Provider | Purpose | Get API Key |
|
||||||
|----------|---------|-------------|
|
|----------|---------|-------------|
|
||||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||||
|
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||||
|
|||||||
@@ -9,11 +9,17 @@ import makeWASocket, {
|
|||||||
useMultiFileAuthState,
|
useMultiFileAuthState,
|
||||||
fetchLatestBaileysVersion,
|
fetchLatestBaileysVersion,
|
||||||
makeCacheableSignalKeyStore,
|
makeCacheableSignalKeyStore,
|
||||||
|
downloadMediaMessage,
|
||||||
|
extractMessageContent as baileysExtractMessageContent,
|
||||||
} from '@whiskeysockets/baileys';
|
} from '@whiskeysockets/baileys';
|
||||||
|
|
||||||
import { Boom } from '@hapi/boom';
|
import { Boom } from '@hapi/boom';
|
||||||
import qrcode from 'qrcode-terminal';
|
import qrcode from 'qrcode-terminal';
|
||||||
import pino from 'pino';
|
import pino from 'pino';
|
||||||
|
import { writeFile, mkdir } from 'fs/promises';
|
||||||
|
import { join } from 'path';
|
||||||
|
import { homedir } from 'os';
|
||||||
|
import { randomBytes } from 'crypto';
|
||||||
|
|
||||||
const VERSION = '0.1.0';
|
const VERSION = '0.1.0';
|
||||||
|
|
||||||
@@ -24,6 +30,7 @@ export interface InboundMessage {
|
|||||||
content: string;
|
content: string;
|
||||||
timestamp: number;
|
timestamp: number;
|
||||||
isGroup: boolean;
|
isGroup: boolean;
|
||||||
|
media?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface WhatsAppClientOptions {
|
export interface WhatsAppClientOptions {
|
||||||
@@ -110,14 +117,33 @@ export class WhatsAppClient {
|
|||||||
if (type !== 'notify') return;
|
if (type !== 'notify') return;
|
||||||
|
|
||||||
for (const msg of messages) {
|
for (const msg of messages) {
|
||||||
// Skip own messages
|
|
||||||
if (msg.key.fromMe) continue;
|
if (msg.key.fromMe) continue;
|
||||||
|
|
||||||
// Skip status updates
|
|
||||||
if (msg.key.remoteJid === 'status@broadcast') continue;
|
if (msg.key.remoteJid === 'status@broadcast') continue;
|
||||||
|
|
||||||
const content = this.extractMessageContent(msg);
|
const unwrapped = baileysExtractMessageContent(msg.message);
|
||||||
if (!content) continue;
|
if (!unwrapped) continue;
|
||||||
|
|
||||||
|
const content = this.getTextContent(unwrapped);
|
||||||
|
let fallbackContent: string | null = null;
|
||||||
|
const mediaPaths: string[] = [];
|
||||||
|
|
||||||
|
if (unwrapped.imageMessage) {
|
||||||
|
fallbackContent = '[Image]';
|
||||||
|
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
|
||||||
|
if (path) mediaPaths.push(path);
|
||||||
|
} else if (unwrapped.documentMessage) {
|
||||||
|
fallbackContent = '[Document]';
|
||||||
|
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
|
||||||
|
unwrapped.documentMessage.fileName ?? undefined);
|
||||||
|
if (path) mediaPaths.push(path);
|
||||||
|
} else if (unwrapped.videoMessage) {
|
||||||
|
fallbackContent = '[Video]';
|
||||||
|
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
|
||||||
|
if (path) mediaPaths.push(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
|
||||||
|
if (!finalContent && mediaPaths.length === 0) continue;
|
||||||
|
|
||||||
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
||||||
|
|
||||||
@@ -125,18 +151,45 @@ export class WhatsAppClient {
|
|||||||
id: msg.key.id || '',
|
id: msg.key.id || '',
|
||||||
sender: msg.key.remoteJid || '',
|
sender: msg.key.remoteJid || '',
|
||||||
pn: msg.key.remoteJidAlt || '',
|
pn: msg.key.remoteJidAlt || '',
|
||||||
content,
|
content: finalContent,
|
||||||
timestamp: msg.messageTimestamp as number,
|
timestamp: msg.messageTimestamp as number,
|
||||||
isGroup,
|
isGroup,
|
||||||
|
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private extractMessageContent(msg: any): string | null {
|
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
|
||||||
const message = msg.message;
|
try {
|
||||||
if (!message) return null;
|
const mediaDir = join(homedir(), '.nanobot', 'media');
|
||||||
|
await mkdir(mediaDir, { recursive: true });
|
||||||
|
|
||||||
|
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
|
||||||
|
|
||||||
|
let outFilename: string;
|
||||||
|
if (fileName) {
|
||||||
|
// Documents have a filename — use it with a unique prefix to avoid collisions
|
||||||
|
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
|
||||||
|
outFilename = prefix + fileName;
|
||||||
|
} else {
|
||||||
|
const mime = mimetype || 'application/octet-stream';
|
||||||
|
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
|
||||||
|
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
|
||||||
|
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const filepath = join(mediaDir, outFilename);
|
||||||
|
await writeFile(filepath, buffer);
|
||||||
|
|
||||||
|
return filepath;
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to download media:', err);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private getTextContent(message: any): string | null {
|
||||||
// Text message
|
// Text message
|
||||||
if (message.conversation) {
|
if (message.conversation) {
|
||||||
return message.conversation;
|
return message.conversation;
|
||||||
@@ -147,19 +200,19 @@ export class WhatsAppClient {
|
|||||||
return message.extendedTextMessage.text;
|
return message.extendedTextMessage.text;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Image with caption
|
// Image with optional caption
|
||||||
if (message.imageMessage?.caption) {
|
if (message.imageMessage) {
|
||||||
return `[Image] ${message.imageMessage.caption}`;
|
return message.imageMessage.caption || '';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Video with caption
|
// Video with optional caption
|
||||||
if (message.videoMessage?.caption) {
|
if (message.videoMessage) {
|
||||||
return `[Video] ${message.videoMessage.caption}`;
|
return message.videoMessage.caption || '';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Document with caption
|
// Document with optional caption
|
||||||
if (message.documentMessage?.caption) {
|
if (message.documentMessage) {
|
||||||
return `[Document] ${message.documentMessage.caption}`;
|
return message.documentMessage.caption || '';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Voice/Audio message
|
// Voice/Audio message
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
|
from nanobot.utils.helpers import detect_image_mime
|
||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
@@ -136,10 +137,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
images = []
|
images = []
|
||||||
for path in media:
|
for path in media:
|
||||||
p = Path(path)
|
p = Path(path)
|
||||||
mime, _ = mimetypes.guess_type(path)
|
if not p.is_file():
|
||||||
if not p.is_file() or not mime or not mime.startswith("image/"):
|
|
||||||
continue
|
continue
|
||||||
b64 = base64.b64encode(p.read_bytes()).decode()
|
raw = p.read_bytes()
|
||||||
|
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||||
|
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||||
|
if not mime or not mime.startswith("image/"):
|
||||||
|
continue
|
||||||
|
b64 = base64.b64encode(raw).decode()
|
||||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||||
|
|
||||||
if not images:
|
if not images:
|
||||||
|
|||||||
@@ -202,18 +202,9 @@ class AgentLoop:
|
|||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
if on_progress:
|
if on_progress:
|
||||||
thoughts = [
|
thought = self._strip_think(response.content)
|
||||||
self._strip_think(response.content),
|
if thought:
|
||||||
response.reasoning_content,
|
await on_progress(thought)
|
||||||
*(
|
|
||||||
f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}"
|
|
||||||
for b in (response.thinking_blocks or [])
|
|
||||||
if isinstance(b, dict) and "signature" in b
|
|
||||||
),
|
|
||||||
]
|
|
||||||
combined_thoughts = "\n\n".join(filter(None, thoughts))
|
|
||||||
if combined_thoughts:
|
|
||||||
await on_progress(combined_thoughts)
|
|
||||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||||
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
|
|||||||
@@ -128,6 +128,13 @@ class MemoryStore:
|
|||||||
# Some providers return arguments as a JSON string instead of dict
|
# Some providers return arguments as a JSON string instead of dict
|
||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json.loads(args)
|
args = json.loads(args)
|
||||||
|
# Some providers return arguments as a list (handle edge case)
|
||||||
|
if isinstance(args, list):
|
||||||
|
if args and isinstance(args[0], dict):
|
||||||
|
args = args[0]
|
||||||
|
else:
|
||||||
|
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
||||||
|
return False
|
||||||
if not isinstance(args, dict):
|
if not isinstance(args, dict):
|
||||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -52,6 +52,75 @@ class Tool(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Apply safe schema-driven casts before validation."""
|
||||||
|
schema = self.parameters or {}
|
||||||
|
if schema.get("type", "object") != "object":
|
||||||
|
return params
|
||||||
|
|
||||||
|
return self._cast_object(params, schema)
|
||||||
|
|
||||||
|
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Cast an object (dict) according to schema."""
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for key, value in obj.items():
|
||||||
|
if key in props:
|
||||||
|
result[key] = self._cast_value(value, props[key])
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||||
|
"""Cast a single value according to schema."""
|
||||||
|
target_type = schema.get("type")
|
||||||
|
|
||||||
|
if target_type == "boolean" and isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||||
|
expected = self._TYPE_MAP[target_type]
|
||||||
|
if isinstance(val, expected):
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "integer" and isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return int(val)
|
||||||
|
except ValueError:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "number" and isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return float(val)
|
||||||
|
except ValueError:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "string":
|
||||||
|
return val if val is None else str(val)
|
||||||
|
|
||||||
|
if target_type == "boolean" and isinstance(val, str):
|
||||||
|
val_lower = val.lower()
|
||||||
|
if val_lower in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
if val_lower in ("false", "0", "no"):
|
||||||
|
return False
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "array" and isinstance(val, list):
|
||||||
|
item_schema = schema.get("items")
|
||||||
|
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||||
|
|
||||||
|
if target_type == "object" and isinstance(val, dict):
|
||||||
|
return self._cast_object(val, schema)
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||||
if not isinstance(params, dict):
|
if not isinstance(params, dict):
|
||||||
@@ -63,7 +132,13 @@ class Tool(ABC):
|
|||||||
|
|
||||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
t, label = schema.get("type"), path or "parameter"
|
t, label = schema.get("type"), path or "parameter"
|
||||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||||
|
return [f"{label} should be integer"]
|
||||||
|
if t == "number" and (
|
||||||
|
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||||
|
):
|
||||||
|
return [f"{label} should be number"]
|
||||||
|
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||||
return [f"{label} should be {t}"]
|
return [f"{label} should be {t}"]
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class MessageTool(Tool):
|
|||||||
media=media or [],
|
media=media or [],
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ class ToolRegistry:
|
|||||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Attempt to cast parameters to match schema types
|
||||||
|
params = tool.cast_params(params)
|
||||||
|
|
||||||
|
# Validate parameters
|
||||||
errors = tool.validate_params(params)
|
errors = tool.validate_params(params)
|
||||||
if errors:
|
if errors:
|
||||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||||
|
|||||||
@@ -66,10 +66,7 @@ class BaseChannel(ABC):
|
|||||||
return False
|
return False
|
||||||
if "*" in allow_list:
|
if "*" in allow_list:
|
||||||
return True
|
return True
|
||||||
sender_str = str(sender_id)
|
return str(sender_id) in allow_list
|
||||||
return sender_str in allow_list or any(
|
|
||||||
p in allow_list for p in sender_str.split("|") if p
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_message(
|
async def _handle_message(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -70,12 +70,24 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||||
|
|
||||||
|
conversation_type = message.data.get("conversationType")
|
||||||
|
conversation_id = (
|
||||||
|
message.data.get("conversationId")
|
||||||
|
or message.data.get("openConversationId")
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||||
|
|
||||||
# Forward to Nanobot via _on_message (non-blocking).
|
# Forward to Nanobot via _on_message (non-blocking).
|
||||||
# Store reference to prevent GC before task completes.
|
# Store reference to prevent GC before task completes.
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
self.channel._on_message(content, sender_id, sender_name)
|
self.channel._on_message(
|
||||||
|
content,
|
||||||
|
sender_id,
|
||||||
|
sender_name,
|
||||||
|
conversation_type,
|
||||||
|
conversation_id,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.channel._background_tasks.add(task)
|
self.channel._background_tasks.add(task)
|
||||||
task.add_done_callback(self.channel._background_tasks.discard)
|
task.add_done_callback(self.channel._background_tasks.discard)
|
||||||
@@ -95,8 +107,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||||
|
|
||||||
Note: Currently only supports private (1:1) chat. Group messages are
|
Supports both private (1:1) and group chats.
|
||||||
received but replies are sent back as private messages to the sender.
|
Group chat_id is stored with a "group:" prefix to route replies back.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "dingtalk"
|
name = "dingtalk"
|
||||||
@@ -301,14 +313,25 @@ class DingTalkChannel(BaseChannel):
|
|||||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
|
||||||
headers = {"x-acs-dingtalk-access-token": token}
|
headers = {"x-acs-dingtalk-access-token": token}
|
||||||
payload = {
|
if chat_id.startswith("group:"):
|
||||||
"robotCode": self.config.client_id,
|
# Group chat
|
||||||
"userIds": [chat_id],
|
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||||
"msgKey": msg_key,
|
payload = {
|
||||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
"robotCode": self.config.client_id,
|
||||||
}
|
"openConversationId": chat_id[6:], # Remove "group:" prefix,
|
||||||
|
"msgKey": msg_key,
|
||||||
|
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Private chat
|
||||||
|
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||||
|
payload = {
|
||||||
|
"robotCode": self.config.client_id,
|
||||||
|
"userIds": [chat_id],
|
||||||
|
"msgKey": msg_key,
|
||||||
|
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._http.post(url, json=payload, headers=headers)
|
resp = await self._http.post(url, json=payload, headers=headers)
|
||||||
@@ -417,7 +440,14 @@ class DingTalkChannel(BaseChannel):
|
|||||||
f"[Attachment send failed: {filename}]",
|
f"[Attachment send failed: {filename}]",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
async def _on_message(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
sender_id: str,
|
||||||
|
sender_name: str,
|
||||||
|
conversation_type: str | None = None,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||||
|
|
||||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||||
@@ -425,13 +455,16 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||||
|
is_group = conversation_type == "2" and conversation_id
|
||||||
|
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
chat_id=chat_id,
|
||||||
content=str(content),
|
content=str(content),
|
||||||
metadata={
|
metadata={
|
||||||
"sender_name": sender_name,
|
"sender_name": sender_name,
|
||||||
"platform": "dingtalk",
|
"platform": "dingtalk",
|
||||||
|
"conversation_type": conversation_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,34 +13,13 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import DiscordConfig
|
from nanobot.config.schema import DiscordConfig
|
||||||
|
from nanobot.utils.helpers import split_message
|
||||||
|
|
||||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||||
|
|
||||||
|
|
||||||
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
|
|
||||||
"""Split content into chunks within max_len, preferring line breaks."""
|
|
||||||
if not content:
|
|
||||||
return []
|
|
||||||
if len(content) <= max_len:
|
|
||||||
return [content]
|
|
||||||
chunks: list[str] = []
|
|
||||||
while content:
|
|
||||||
if len(content) <= max_len:
|
|
||||||
chunks.append(content)
|
|
||||||
break
|
|
||||||
cut = content[:max_len]
|
|
||||||
pos = cut.rfind('\n')
|
|
||||||
if pos <= 0:
|
|
||||||
pos = cut.rfind(' ')
|
|
||||||
if pos <= 0:
|
|
||||||
pos = max_len
|
|
||||||
chunks.append(content[:pos])
|
|
||||||
content = content[pos:].lstrip()
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(BaseChannel):
|
class DiscordChannel(BaseChannel):
|
||||||
"""Discord channel using Gateway websocket."""
|
"""Discord channel using Gateway websocket."""
|
||||||
|
|
||||||
@@ -96,7 +75,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
self._http = None
|
self._http = None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through Discord REST API."""
|
"""Send a message through Discord REST API, including file attachments."""
|
||||||
if not self._http:
|
if not self._http:
|
||||||
logger.warning("Discord HTTP client not initialized")
|
logger.warning("Discord HTTP client not initialized")
|
||||||
return
|
return
|
||||||
@@ -105,15 +84,31 @@ class DiscordChannel(BaseChannel):
|
|||||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chunks = _split_message(msg.content or "")
|
sent_media = False
|
||||||
|
failed_media: list[str] = []
|
||||||
|
|
||||||
|
# Send file attachments first
|
||||||
|
for media_path in msg.media or []:
|
||||||
|
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||||
|
sent_media = True
|
||||||
|
else:
|
||||||
|
failed_media.append(Path(media_path).name)
|
||||||
|
|
||||||
|
# Send text content
|
||||||
|
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||||
|
if not chunks and failed_media and not sent_media:
|
||||||
|
chunks = split_message(
|
||||||
|
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||||
|
MAX_MESSAGE_LEN,
|
||||||
|
)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return
|
return
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
payload: dict[str, Any] = {"content": chunk}
|
payload: dict[str, Any] = {"content": chunk}
|
||||||
|
|
||||||
# Only set reply reference on the first chunk
|
# Let the first successful attachment carry the reply if present.
|
||||||
if i == 0 and msg.reply_to:
|
if i == 0 and msg.reply_to and not sent_media:
|
||||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||||
payload["allowed_mentions"] = {"replied_user": False}
|
payload["allowed_mentions"] = {"replied_user": False}
|
||||||
|
|
||||||
@@ -144,6 +139,54 @@ class DiscordChannel(BaseChannel):
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _send_file(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, str],
|
||||||
|
file_path: str,
|
||||||
|
reply_to: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||||
|
path = Path(file_path)
|
||||||
|
if not path.is_file():
|
||||||
|
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||||
|
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
payload_json: dict[str, Any] = {}
|
||||||
|
if reply_to:
|
||||||
|
payload_json["message_reference"] = {"message_id": reply_to}
|
||||||
|
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||||
|
|
||||||
|
for attempt in range(3):
|
||||||
|
try:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
if payload_json:
|
||||||
|
data["payload_json"] = json.dumps(payload_json)
|
||||||
|
response = await self._http.post(
|
||||||
|
url, headers=headers, files=files, data=data
|
||||||
|
)
|
||||||
|
if response.status_code == 429:
|
||||||
|
resp_data = response.json()
|
||||||
|
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||||
|
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
continue
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.info("Discord file sent: {}", path.name)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
if attempt == 2:
|
||||||
|
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _gateway_loop(self) -> None:
|
async def _gateway_loop(self) -> None:
|
||||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||||
if not self._ws:
|
if not self._ws:
|
||||||
|
|||||||
@@ -244,15 +244,22 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
name = "feishu"
|
name = "feishu"
|
||||||
|
|
||||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
self.config: FeishuConfig = config
|
self.config: FeishuConfig = config
|
||||||
|
self.groq_api_key = groq_api_key
|
||||||
self._client: Any = None
|
self._client: Any = None
|
||||||
self._ws_client: Any = None
|
self._ws_client: Any = None
|
||||||
self._ws_thread: threading.Thread | None = None
|
self._ws_thread: threading.Thread | None = None
|
||||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||||
|
"""Register an event handler only when the SDK supports it."""
|
||||||
|
method = getattr(builder, method_name, None)
|
||||||
|
return method(handler) if callable(method) else builder
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Feishu bot with WebSocket long connection."""
|
"""Start the Feishu bot with WebSocket long connection."""
|
||||||
if not FEISHU_AVAILABLE:
|
if not FEISHU_AVAILABLE:
|
||||||
@@ -273,14 +280,24 @@ class FeishuChannel(BaseChannel):
|
|||||||
.app_secret(self.config.app_secret) \
|
.app_secret(self.config.app_secret) \
|
||||||
.log_level(lark.LogLevel.INFO) \
|
.log_level(lark.LogLevel.INFO) \
|
||||||
.build()
|
.build()
|
||||||
|
builder = lark.EventDispatcherHandler.builder(
|
||||||
# Create event handler (only register message receive, ignore other events)
|
|
||||||
event_handler = lark.EventDispatcherHandler.builder(
|
|
||||||
self.config.encrypt_key or "",
|
self.config.encrypt_key or "",
|
||||||
self.config.verification_token or "",
|
self.config.verification_token or "",
|
||||||
).register_p2_im_message_receive_v1(
|
).register_p2_im_message_receive_v1(
|
||||||
self._on_message_sync
|
self._on_message_sync
|
||||||
).build()
|
)
|
||||||
|
builder = self._register_optional_event(
|
||||||
|
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
|
||||||
|
)
|
||||||
|
builder = self._register_optional_event(
|
||||||
|
builder, "register_p2_im_message_message_read_v1", self._on_message_read
|
||||||
|
)
|
||||||
|
builder = self._register_optional_event(
|
||||||
|
builder,
|
||||||
|
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
|
||||||
|
self._on_bot_p2p_chat_entered,
|
||||||
|
)
|
||||||
|
event_handler = builder.build()
|
||||||
|
|
||||||
# Create WebSocket client for long connection
|
# Create WebSocket client for long connection
|
||||||
self._ws_client = lark.ws.Client(
|
self._ws_client = lark.ws.Client(
|
||||||
@@ -472,8 +489,124 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
return elements or [{"tag": "markdown", "content": content}]
|
return elements or [{"tag": "markdown", "content": content}]
|
||||||
|
|
||||||
|
# ── Smart format detection ──────────────────────────────────────────
|
||||||
|
# Patterns that indicate "complex" markdown needing card rendering
|
||||||
|
_COMPLEX_MD_RE = re.compile(
|
||||||
|
r"```" # fenced code block
|
||||||
|
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
|
||||||
|
r"|^#{1,6}\s+" # headings
|
||||||
|
, re.MULTILINE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simple markdown patterns (bold, italic, strikethrough)
|
||||||
|
_SIMPLE_MD_RE = re.compile(
|
||||||
|
r"\*\*.+?\*\*" # **bold**
|
||||||
|
r"|__.+?__" # __bold__
|
||||||
|
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
|
||||||
|
r"|~~.+?~~" # ~~strikethrough~~
|
||||||
|
, re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Markdown link: [text](url)
|
||||||
|
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
|
||||||
|
|
||||||
|
# Unordered list items
|
||||||
|
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
|
||||||
|
|
||||||
|
# Ordered list items
|
||||||
|
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
|
||||||
|
|
||||||
|
# Max length for plain text format
|
||||||
|
_TEXT_MAX_LEN = 200
|
||||||
|
|
||||||
|
# Max length for post (rich text) format; beyond this, use card
|
||||||
|
_POST_MAX_LEN = 2000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _detect_msg_format(cls, content: str) -> str:
|
||||||
|
"""Determine the optimal Feishu message format for *content*.
|
||||||
|
|
||||||
|
Returns one of:
|
||||||
|
- ``"text"`` – plain text, short and no markdown
|
||||||
|
- ``"post"`` – rich text (links only, moderate length)
|
||||||
|
- ``"interactive"`` – card with full markdown rendering
|
||||||
|
"""
|
||||||
|
stripped = content.strip()
|
||||||
|
|
||||||
|
# Complex markdown (code blocks, tables, headings) → always card
|
||||||
|
if cls._COMPLEX_MD_RE.search(stripped):
|
||||||
|
return "interactive"
|
||||||
|
|
||||||
|
# Long content → card (better readability with card layout)
|
||||||
|
if len(stripped) > cls._POST_MAX_LEN:
|
||||||
|
return "interactive"
|
||||||
|
|
||||||
|
# Has bold/italic/strikethrough → card (post format can't render these)
|
||||||
|
if cls._SIMPLE_MD_RE.search(stripped):
|
||||||
|
return "interactive"
|
||||||
|
|
||||||
|
# Has list items → card (post format can't render list bullets well)
|
||||||
|
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
|
||||||
|
return "interactive"
|
||||||
|
|
||||||
|
# Has links → post format (supports <a> tags)
|
||||||
|
if cls._MD_LINK_RE.search(stripped):
|
||||||
|
return "post"
|
||||||
|
|
||||||
|
# Short plain text → text format
|
||||||
|
if len(stripped) <= cls._TEXT_MAX_LEN:
|
||||||
|
return "text"
|
||||||
|
|
||||||
|
# Medium plain text without any formatting → post format
|
||||||
|
return "post"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _markdown_to_post(cls, content: str) -> str:
|
||||||
|
"""Convert markdown content to Feishu post message JSON.
|
||||||
|
|
||||||
|
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
|
||||||
|
Each line becomes a paragraph (row) in the post body.
|
||||||
|
"""
|
||||||
|
lines = content.strip().split("\n")
|
||||||
|
paragraphs: list[list[dict]] = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
elements: list[dict] = []
|
||||||
|
last_end = 0
|
||||||
|
|
||||||
|
for m in cls._MD_LINK_RE.finditer(line):
|
||||||
|
# Text before this link
|
||||||
|
before = line[last_end:m.start()]
|
||||||
|
if before:
|
||||||
|
elements.append({"tag": "text", "text": before})
|
||||||
|
elements.append({
|
||||||
|
"tag": "a",
|
||||||
|
"text": m.group(1),
|
||||||
|
"href": m.group(2),
|
||||||
|
})
|
||||||
|
last_end = m.end()
|
||||||
|
|
||||||
|
# Remaining text after last link
|
||||||
|
remaining = line[last_end:]
|
||||||
|
if remaining:
|
||||||
|
elements.append({"tag": "text", "text": remaining})
|
||||||
|
|
||||||
|
# Empty line → empty paragraph for spacing
|
||||||
|
if not elements:
|
||||||
|
elements.append({"tag": "text", "text": ""})
|
||||||
|
|
||||||
|
paragraphs.append(elements)
|
||||||
|
|
||||||
|
post_body = {
|
||||||
|
"zh_cn": {
|
||||||
|
"content": paragraphs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return json.dumps(post_body, ensure_ascii=False)
|
||||||
|
|
||||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
||||||
_AUDIO_EXTS = {".opus"}
|
_AUDIO_EXTS = {".opus"}
|
||||||
|
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
|
||||||
_FILE_TYPE_MAP = {
|
_FILE_TYPE_MAP = {
|
||||||
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
||||||
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
||||||
@@ -682,25 +815,50 @@ class FeishuChannel(BaseChannel):
|
|||||||
else:
|
else:
|
||||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||||
if key:
|
if key:
|
||||||
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
|
# Use msg_type "media" for audio/video so users can play inline;
|
||||||
|
# "file" for everything else (documents, archives, etc.)
|
||||||
|
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
|
||||||
|
media_type = "media"
|
||||||
|
else:
|
||||||
|
media_type = "file"
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None, self._send_message_sync,
|
None, self._send_message_sync,
|
||||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg.content and msg.content.strip():
|
if msg.content and msg.content.strip():
|
||||||
elements = self._build_card_elements(msg.content)
|
fmt = self._detect_msg_format(msg.content)
|
||||||
for chunk in self._split_elements_by_table_limit(elements):
|
|
||||||
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
|
if fmt == "text":
|
||||||
|
# Short plain text – send as simple text message
|
||||||
|
text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None, self._send_message_sync,
|
None, self._send_message_sync,
|
||||||
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
receive_id_type, msg.chat_id, "text", text_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif fmt == "post":
|
||||||
|
# Medium content with links – send as rich-text post
|
||||||
|
post_body = self._markdown_to_post(msg.content)
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync,
|
||||||
|
receive_id_type, msg.chat_id, "post", post_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Complex / long content – send as interactive card
|
||||||
|
elements = self._build_card_elements(msg.content)
|
||||||
|
for chunk in self._split_elements_by_table_limit(elements):
|
||||||
|
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync,
|
||||||
|
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending Feishu message: {}", e)
|
logger.error("Error sending Feishu message: {}", e)
|
||||||
|
|
||||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
def _on_message_sync(self, data: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Sync handler for incoming messages (called from WebSocket thread).
|
Sync handler for incoming messages (called from WebSocket thread).
|
||||||
Schedules async handling in the main event loop.
|
Schedules async handling in the main event loop.
|
||||||
@@ -708,7 +866,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
if self._loop and self._loop.is_running():
|
if self._loop and self._loop.is_running():
|
||||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||||
|
|
||||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
async def _on_message(self, data: Any) -> None:
|
||||||
"""Handle incoming message from Feishu."""
|
"""Handle incoming message from Feishu."""
|
||||||
try:
|
try:
|
||||||
event = data.event
|
event = data.event
|
||||||
@@ -768,6 +926,18 @@ class FeishuChannel(BaseChannel):
|
|||||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||||
if file_path:
|
if file_path:
|
||||||
media_paths.append(file_path)
|
media_paths.append(file_path)
|
||||||
|
|
||||||
|
# Transcribe audio using Groq Whisper
|
||||||
|
if msg_type == "audio" and file_path and self.groq_api_key:
|
||||||
|
try:
|
||||||
|
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||||
|
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||||
|
transcription = await transcriber.transcribe(file_path)
|
||||||
|
if transcription:
|
||||||
|
content_text = f"[transcription: {transcription}]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to transcribe audio: {}", e)
|
||||||
|
|
||||||
content_parts.append(content_text)
|
content_parts.append(content_text)
|
||||||
|
|
||||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||||
@@ -800,3 +970,16 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error processing Feishu message: {}", e)
|
logger.error("Error processing Feishu message: {}", e)
|
||||||
|
|
||||||
|
def _on_reaction_created(self, data: Any) -> None:
|
||||||
|
"""Ignore reaction events so they do not generate SDK noise."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_message_read(self, data: Any) -> None:
|
||||||
|
"""Ignore read events so they do not generate SDK noise."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
||||||
|
"""Ignore p2p-enter events when a user opens a bot chat."""
|
||||||
|
logger.debug("Bot entered p2p chat (user opened chat window)")
|
||||||
|
pass
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ class ChannelManager:
|
|||||||
try:
|
try:
|
||||||
from nanobot.channels.feishu import FeishuChannel
|
from nanobot.channels.feishu import FeishuChannel
|
||||||
self.channels["feishu"] = FeishuChannel(
|
self.channels["feishu"] = FeishuChannel(
|
||||||
self.config.channels.feishu, self.bus
|
self.config.channels.feishu, self.bus,
|
||||||
|
groq_api_key=self.config.providers.groq.api_key,
|
||||||
)
|
)
|
||||||
logger.info("Feishu channel enabled")
|
logger.info("Feishu channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
|
|||||||
@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import botpy
|
import botpy
|
||||||
from botpy.message import C2CMessage
|
from botpy.message import C2CMessage, GroupMessage
|
||||||
|
|
||||||
QQ_AVAILABLE = True
|
QQ_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
QQ_AVAILABLE = False
|
QQ_AVAILABLE = False
|
||||||
botpy = None
|
botpy = None
|
||||||
C2CMessage = None
|
C2CMessage = None
|
||||||
|
GroupMessage = None
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from botpy.message import C2CMessage
|
from botpy.message import C2CMessage, GroupMessage
|
||||||
|
|
||||||
|
|
||||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||||
@@ -38,10 +39,13 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
|||||||
logger.info("QQ bot ready: {}", self.robot.name)
|
logger.info("QQ bot ready: {}", self.robot.name)
|
||||||
|
|
||||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||||
await channel._on_message(message)
|
await channel._on_message(message, is_group=False)
|
||||||
|
|
||||||
|
async def on_group_at_message_create(self, message: "GroupMessage"):
|
||||||
|
await channel._on_message(message, is_group=True)
|
||||||
|
|
||||||
async def on_direct_message_create(self, message):
|
async def on_direct_message_create(self, message):
|
||||||
await channel._on_message(message)
|
await channel._on_message(message, is_group=False)
|
||||||
|
|
||||||
return _Bot
|
return _Bot
|
||||||
|
|
||||||
@@ -57,6 +61,7 @@ class QQChannel(BaseChannel):
|
|||||||
self._client: "botpy.Client | None" = None
|
self._client: "botpy.Client | None" = None
|
||||||
self._processed_ids: deque = deque(maxlen=1000)
|
self._processed_ids: deque = deque(maxlen=1000)
|
||||||
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
||||||
|
self._chat_type_cache: dict[str, str] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the QQ bot."""
|
"""Start the QQ bot."""
|
||||||
@@ -71,8 +76,7 @@ class QQChannel(BaseChannel):
|
|||||||
self._running = True
|
self._running = True
|
||||||
BotClass = _make_bot_class(self)
|
BotClass = _make_bot_class(self)
|
||||||
self._client = BotClass()
|
self._client = BotClass()
|
||||||
|
logger.info("QQ bot started (C2C & Group supported)")
|
||||||
logger.info("QQ bot started (C2C private message)")
|
|
||||||
await self._run_bot()
|
await self._run_bot()
|
||||||
|
|
||||||
async def _run_bot(self) -> None:
|
async def _run_bot(self) -> None:
|
||||||
@@ -101,20 +105,31 @@ class QQChannel(BaseChannel):
|
|||||||
if not self._client:
|
if not self._client:
|
||||||
logger.warning("QQ client not initialized")
|
logger.warning("QQ client not initialized")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
msg_id = msg.metadata.get("message_id")
|
msg_id = msg.metadata.get("message_id")
|
||||||
self._msg_seq += 1 # 递增序列号
|
self._msg_seq += 1
|
||||||
await self._client.api.post_c2c_message(
|
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||||
openid=msg.chat_id,
|
if msg_type == "group":
|
||||||
msg_type=0,
|
await self._client.api.post_group_message(
|
||||||
content=msg.content,
|
group_openid=msg.chat_id,
|
||||||
msg_id=msg_id,
|
msg_type=0,
|
||||||
msg_seq=self._msg_seq, # 添加序列号避免去重
|
content=msg.content,
|
||||||
)
|
msg_id=msg_id,
|
||||||
|
msg_seq=self._msg_seq,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._client.api.post_c2c_message(
|
||||||
|
openid=msg.chat_id,
|
||||||
|
msg_type=0,
|
||||||
|
content=msg.content,
|
||||||
|
msg_id=msg_id,
|
||||||
|
msg_seq=self._msg_seq,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending QQ message: {}", e)
|
logger.error("Error sending QQ message: {}", e)
|
||||||
|
|
||||||
async def _on_message(self, data: "C2CMessage") -> None:
|
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
|
||||||
"""Handle incoming message from QQ."""
|
"""Handle incoming message from QQ."""
|
||||||
try:
|
try:
|
||||||
# Dedup by message ID
|
# Dedup by message ID
|
||||||
@@ -122,18 +137,24 @@ class QQChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
self._processed_ids.append(data.id)
|
self._processed_ids.append(data.id)
|
||||||
|
|
||||||
author = data.author
|
|
||||||
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
|
|
||||||
content = (data.content or "").strip()
|
content = (data.content or "").strip()
|
||||||
if not content:
|
if not content:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if is_group:
|
||||||
|
chat_id = data.group_openid
|
||||||
|
user_id = data.author.member_openid
|
||||||
|
self._chat_type_cache[chat_id] = "group"
|
||||||
|
else:
|
||||||
|
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
|
||||||
|
user_id = chat_id
|
||||||
|
self._chat_type_cache[chat_id] = "c2c"
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=user_id,
|
sender_id=user_id,
|
||||||
chat_id=user_id,
|
chat_id=chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
metadata={"message_id": data.id},
|
metadata={"message_id": data.id},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error handling QQ message")
|
logger.exception("Error handling QQ message")
|
||||||
|
|
||||||
|
|||||||
@@ -82,13 +82,14 @@ class SlackChannel(BaseChannel):
|
|||||||
thread_ts = slack_meta.get("thread_ts")
|
thread_ts = slack_meta.get("thread_ts")
|
||||||
channel_type = slack_meta.get("channel_type")
|
channel_type = slack_meta.get("channel_type")
|
||||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||||
use_thread = thread_ts and channel_type != "im"
|
|
||||||
thread_ts_param = thread_ts if use_thread else None
|
thread_ts_param = thread_ts if use_thread else None
|
||||||
|
|
||||||
if msg.content:
|
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||||
|
# but send a single blank message when the bot has no text or files to send.
|
||||||
|
if msg.content or not (msg.media or []):
|
||||||
await self._web_client.chat_postMessage(
|
await self._web_client.chat_postMessage(
|
||||||
channel=msg.chat_id,
|
channel=msg.chat_id,
|
||||||
text=self._to_mrkdwn(msg.content),
|
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||||
thread_ts=thread_ts_param,
|
thread_ts=thread_ts_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from telegram import BotCommand, ReplyParameters, Update
|
from telegram import BotCommand, ReplyParameters, Update
|
||||||
@@ -14,6 +16,50 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import TelegramConfig
|
from nanobot.config.schema import TelegramConfig
|
||||||
|
from nanobot.utils.helpers import split_message
|
||||||
|
|
||||||
|
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_md(s: str) -> str:
|
||||||
|
"""Strip markdown inline formatting from text."""
|
||||||
|
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||||
|
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||||
|
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||||
|
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||||
|
return s.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _render_table_box(table_lines: list[str]) -> str:
|
||||||
|
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
|
||||||
|
|
||||||
|
def dw(s: str) -> int:
|
||||||
|
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||||
|
|
||||||
|
rows: list[list[str]] = []
|
||||||
|
has_sep = False
|
||||||
|
for line in table_lines:
|
||||||
|
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
|
||||||
|
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||||
|
has_sep = True
|
||||||
|
continue
|
||||||
|
rows.append(cells)
|
||||||
|
if not rows or not has_sep:
|
||||||
|
return '\n'.join(table_lines)
|
||||||
|
|
||||||
|
ncols = max(len(r) for r in rows)
|
||||||
|
for r in rows:
|
||||||
|
r.extend([''] * (ncols - len(r)))
|
||||||
|
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||||
|
|
||||||
|
def dr(cells: list[str]) -> str:
|
||||||
|
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||||
|
|
||||||
|
out = [dr(rows[0])]
|
||||||
|
out.append(' '.join('─' * w for w in widths))
|
||||||
|
for row in rows[1:]:
|
||||||
|
out.append(dr(row))
|
||||||
|
return '\n'.join(out)
|
||||||
|
|
||||||
|
|
||||||
def _markdown_to_telegram_html(text: str) -> str:
|
def _markdown_to_telegram_html(text: str) -> str:
|
||||||
@@ -31,6 +77,27 @@ def _markdown_to_telegram_html(text: str) -> str:
|
|||||||
|
|
||||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||||
|
|
||||||
|
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
|
||||||
|
lines = text.split('\n')
|
||||||
|
rebuilt: list[str] = []
|
||||||
|
li = 0
|
||||||
|
while li < len(lines):
|
||||||
|
if re.match(r'^\s*\|.+\|', lines[li]):
|
||||||
|
tbl: list[str] = []
|
||||||
|
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
|
||||||
|
tbl.append(lines[li])
|
||||||
|
li += 1
|
||||||
|
box = _render_table_box(tbl)
|
||||||
|
if box != '\n'.join(tbl):
|
||||||
|
code_blocks.append(box)
|
||||||
|
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
|
||||||
|
else:
|
||||||
|
rebuilt.extend(tbl)
|
||||||
|
else:
|
||||||
|
rebuilt.append(lines[li])
|
||||||
|
li += 1
|
||||||
|
text = '\n'.join(rebuilt)
|
||||||
|
|
||||||
# 2. Extract and protect inline code
|
# 2. Extract and protect inline code
|
||||||
inline_codes: list[str] = []
|
inline_codes: list[str] = []
|
||||||
def save_inline_code(m: re.Match) -> str:
|
def save_inline_code(m: re.Match) -> str:
|
||||||
@@ -79,26 +146,6 @@ def _markdown_to_telegram_html(text: str) -> str:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
|
||||||
"""Split content into chunks within max_len, preferring line breaks."""
|
|
||||||
if len(content) <= max_len:
|
|
||||||
return [content]
|
|
||||||
chunks: list[str] = []
|
|
||||||
while content:
|
|
||||||
if len(content) <= max_len:
|
|
||||||
chunks.append(content)
|
|
||||||
break
|
|
||||||
cut = content[:max_len]
|
|
||||||
pos = cut.rfind('\n')
|
|
||||||
if pos == -1:
|
|
||||||
pos = cut.rfind(' ')
|
|
||||||
if pos == -1:
|
|
||||||
pos = max_len
|
|
||||||
chunks.append(content[:pos])
|
|
||||||
content = content[pos:].lstrip()
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramChannel(BaseChannel):
|
class TelegramChannel(BaseChannel):
|
||||||
"""
|
"""
|
||||||
Telegram channel using long polling.
|
Telegram channel using long polling.
|
||||||
@@ -130,6 +177,26 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
self._media_group_buffers: dict[str, dict] = {}
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._message_threads: dict[tuple[str, int], int] = {}
|
||||||
|
|
||||||
|
def is_allowed(self, sender_id: str) -> bool:
|
||||||
|
"""Preserve Telegram's legacy id|username allowlist matching."""
|
||||||
|
if super().is_allowed(sender_id):
|
||||||
|
return True
|
||||||
|
|
||||||
|
allow_list = getattr(self.config, "allow_from", [])
|
||||||
|
if not allow_list or "*" in allow_list:
|
||||||
|
return False
|
||||||
|
|
||||||
|
sender_str = str(sender_id)
|
||||||
|
if sender_str.count("|") != 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
sid, username = sender_str.split("|", 1)
|
||||||
|
if not sid.isdigit() or not username:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return sid in allow_list or username in allow_list
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
@@ -140,16 +207,21 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
req = HTTPXRequest(
|
||||||
|
connection_pool_size=16,
|
||||||
|
pool_timeout=5.0,
|
||||||
|
connect_timeout=30.0,
|
||||||
|
read_timeout=30.0,
|
||||||
|
proxy=self.config.proxy if self.config.proxy else None,
|
||||||
|
)
|
||||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||||
if self.config.proxy:
|
|
||||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
|
||||||
self._app = builder.build()
|
self._app = builder.build()
|
||||||
self._app.add_error_handler(self._on_error)
|
self._app.add_error_handler(self._on_error)
|
||||||
|
|
||||||
# Add command handlers
|
# Add command handlers
|
||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||||
|
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
@@ -234,10 +306,16 @@ class TelegramChannel(BaseChannel):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
reply_to_message_id = msg.metadata.get("message_id")
|
||||||
|
message_thread_id = msg.metadata.get("message_thread_id")
|
||||||
|
if message_thread_id is None and reply_to_message_id is not None:
|
||||||
|
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||||
|
thread_kwargs = {}
|
||||||
|
if message_thread_id is not None:
|
||||||
|
thread_kwargs["message_thread_id"] = message_thread_id
|
||||||
|
|
||||||
reply_params = None
|
reply_params = None
|
||||||
if self.config.reply_to_message:
|
if self.config.reply_to_message:
|
||||||
reply_to_message_id = msg.metadata.get("message_id")
|
|
||||||
if reply_to_message_id:
|
if reply_to_message_id:
|
||||||
reply_params = ReplyParameters(
|
reply_params = ReplyParameters(
|
||||||
message_id=reply_to_message_id,
|
message_id=reply_to_message_id,
|
||||||
@@ -258,7 +336,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
await sender(
|
await sender(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
**{param: f},
|
**{param: f},
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params,
|
||||||
|
**thread_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
filename = media_path.rsplit("/", 1)[-1]
|
filename = media_path.rsplit("/", 1)[-1]
|
||||||
@@ -266,48 +345,71 @@ class TelegramChannel(BaseChannel):
|
|||||||
await self._app.bot.send_message(
|
await self._app.bot.send_message(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
text=f"[Failed to send: {filename}]",
|
text=f"[Failed to send: {filename}]",
|
||||||
reply_parameters=reply_params
|
reply_parameters=reply_params,
|
||||||
|
**thread_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send text content
|
# Send text content
|
||||||
if msg.content and msg.content != "[empty message]":
|
if msg.content and msg.content != "[empty message]":
|
||||||
is_progress = msg.metadata.get("_progress", False)
|
is_progress = msg.metadata.get("_progress", False)
|
||||||
draft_id = msg.metadata.get("message_id")
|
|
||||||
|
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||||
for chunk in _split_message(msg.content):
|
# Final response: simulate streaming via draft, then persist
|
||||||
try:
|
if not is_progress:
|
||||||
html = _markdown_to_telegram_html(chunk)
|
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||||
if is_progress and draft_id:
|
else:
|
||||||
await self._app.bot.send_message_draft(
|
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||||
chat_id=chat_id,
|
|
||||||
draft_id=draft_id,
|
async def _send_text(
|
||||||
text=html,
|
self,
|
||||||
parse_mode="HTML"
|
chat_id: int,
|
||||||
)
|
text: str,
|
||||||
else:
|
reply_params=None,
|
||||||
await self._app.bot.send_message(
|
thread_kwargs: dict | None = None,
|
||||||
chat_id=chat_id,
|
) -> None:
|
||||||
text=html,
|
"""Send a plain text message with HTML fallback."""
|
||||||
parse_mode="HTML",
|
try:
|
||||||
reply_parameters=reply_params
|
html = _markdown_to_telegram_html(text)
|
||||||
)
|
await self._app.bot.send_message(
|
||||||
except Exception as e:
|
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
reply_parameters=reply_params,
|
||||||
try:
|
**(thread_kwargs or {}),
|
||||||
if is_progress and draft_id:
|
)
|
||||||
await self._app.bot.send_message_draft(
|
except Exception as e:
|
||||||
chat_id=chat_id,
|
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
draft_id=draft_id,
|
try:
|
||||||
text=chunk
|
await self._app.bot.send_message(
|
||||||
)
|
chat_id=chat_id,
|
||||||
else:
|
text=text,
|
||||||
await self._app.bot.send_message(
|
reply_parameters=reply_params,
|
||||||
chat_id=chat_id,
|
**(thread_kwargs or {}),
|
||||||
text=chunk,
|
)
|
||||||
reply_parameters=reply_params
|
except Exception as e2:
|
||||||
)
|
logger.error("Error sending Telegram message: {}", e2)
|
||||||
except Exception as e2:
|
|
||||||
logger.error("Error sending Telegram message: {}", e2)
|
async def _send_with_streaming(
|
||||||
|
self,
|
||||||
|
chat_id: int,
|
||||||
|
text: str,
|
||||||
|
reply_params=None,
|
||||||
|
thread_kwargs: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||||
|
draft_id = int(time.time() * 1000) % (2**31)
|
||||||
|
try:
|
||||||
|
step = max(len(text) // 8, 40)
|
||||||
|
for i in range(step, len(text), step):
|
||||||
|
await self._app.bot.send_message_draft(
|
||||||
|
chat_id=chat_id, draft_id=draft_id, text=text[:i],
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.04)
|
||||||
|
await self._app.bot.send_message_draft(
|
||||||
|
chat_id=chat_id, draft_id=draft_id, text=text,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||||
|
|
||||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
@@ -338,14 +440,50 @@ class TelegramChannel(BaseChannel):
|
|||||||
sid = str(user.id)
|
sid = str(user.id)
|
||||||
return f"{sid}|{user.username}" if user.username else sid
|
return f"{sid}|{user.username}" if user.username else sid
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _derive_topic_session_key(message) -> str | None:
|
||||||
|
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||||
|
message_thread_id = getattr(message, "message_thread_id", None)
|
||||||
|
if message.chat.type == "private" or message_thread_id is None:
|
||||||
|
return None
|
||||||
|
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_message_metadata(message, user) -> dict:
|
||||||
|
"""Build common Telegram inbound metadata payload."""
|
||||||
|
return {
|
||||||
|
"message_id": message.message_id,
|
||||||
|
"user_id": user.id,
|
||||||
|
"username": user.username,
|
||||||
|
"first_name": user.first_name,
|
||||||
|
"is_group": message.chat.type != "private",
|
||||||
|
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||||
|
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _remember_thread_context(self, message) -> None:
|
||||||
|
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||||
|
message_thread_id = getattr(message, "message_thread_id", None)
|
||||||
|
if message_thread_id is None:
|
||||||
|
return
|
||||||
|
key = (str(message.chat_id), message.message_id)
|
||||||
|
self._message_threads[key] = message_thread_id
|
||||||
|
if len(self._message_threads) > 1000:
|
||||||
|
self._message_threads.pop(next(iter(self._message_threads)))
|
||||||
|
|
||||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
|
message = update.message
|
||||||
|
user = update.effective_user
|
||||||
|
self._remember_thread_context(message)
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=self._sender_id(update.effective_user),
|
sender_id=self._sender_id(user),
|
||||||
chat_id=str(update.message.chat_id),
|
chat_id=str(message.chat_id),
|
||||||
content=update.message.text,
|
content=message.text,
|
||||||
|
metadata=self._build_message_metadata(message, user),
|
||||||
|
session_key=self._derive_topic_session_key(message),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
@@ -357,6 +495,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
chat_id = message.chat_id
|
chat_id = message.chat_id
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
# Store chat_id for replies
|
# Store chat_id for replies
|
||||||
self._chat_ids[sender_id] = chat_id
|
self._chat_ids[sender_id] = chat_id
|
||||||
@@ -392,8 +531,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
if media_file and self._app:
|
if media_file and self._app:
|
||||||
try:
|
try:
|
||||||
file = await self._app.bot.get_file(media_file.file_id)
|
file = await self._app.bot.get_file(media_file.file_id)
|
||||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
ext = self._get_extension(
|
||||||
|
media_type,
|
||||||
|
getattr(media_file, 'mime_type', None),
|
||||||
|
getattr(media_file, 'file_name', None),
|
||||||
|
)
|
||||||
# Save to workspace/media/
|
# Save to workspace/media/
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
media_dir = Path.home() / ".nanobot" / "media"
|
media_dir = Path.home() / ".nanobot" / "media"
|
||||||
@@ -427,6 +569,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
metadata = self._build_message_metadata(message, user)
|
||||||
|
session_key = self._derive_topic_session_key(message)
|
||||||
|
|
||||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||||
if media_group_id := getattr(message, "media_group_id", None):
|
if media_group_id := getattr(message, "media_group_id", None):
|
||||||
@@ -435,11 +579,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._media_group_buffers[key] = {
|
self._media_group_buffers[key] = {
|
||||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||||
"contents": [], "media": [],
|
"contents": [], "media": [],
|
||||||
"metadata": {
|
"metadata": metadata,
|
||||||
"message_id": message.message_id, "user_id": user.id,
|
"session_key": session_key,
|
||||||
"username": user.username, "first_name": user.first_name,
|
|
||||||
"is_group": message.chat.type != "private",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
self._start_typing(str_chat_id)
|
self._start_typing(str_chat_id)
|
||||||
buf = self._media_group_buffers[key]
|
buf = self._media_group_buffers[key]
|
||||||
@@ -459,13 +600,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id=str_chat_id,
|
chat_id=str_chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
media=media_paths,
|
media=media_paths,
|
||||||
metadata={
|
metadata=metadata,
|
||||||
"message_id": message.message_id,
|
session_key=session_key,
|
||||||
"user_id": user.id,
|
|
||||||
"username": user.username,
|
|
||||||
"first_name": user.first_name,
|
|
||||||
"is_group": message.chat.type != "private"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _flush_media_group(self, key: str) -> None:
|
async def _flush_media_group(self, key: str) -> None:
|
||||||
@@ -479,6 +615,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||||
metadata=buf["metadata"],
|
metadata=buf["metadata"],
|
||||||
|
session_key=buf.get("session_key"),
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._media_group_tasks.pop(key, None)
|
self._media_group_tasks.pop(key, None)
|
||||||
@@ -510,8 +647,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
"""Log polling / handler errors instead of silently swallowing them."""
|
"""Log polling / handler errors instead of silently swallowing them."""
|
||||||
logger.error("Telegram error: {}", context.error)
|
logger.error("Telegram error: {}", context.error)
|
||||||
|
|
||||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
def _get_extension(
|
||||||
"""Get file extension based on media type."""
|
self,
|
||||||
|
media_type: str,
|
||||||
|
mime_type: str | None,
|
||||||
|
filename: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get file extension based on media type or original filename."""
|
||||||
if mime_type:
|
if mime_type:
|
||||||
ext_map = {
|
ext_map = {
|
||||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||||
@@ -521,4 +663,12 @@ class TelegramChannel(BaseChannel):
|
|||||||
return ext_map[mime_type]
|
return ext_map[mime_type]
|
||||||
|
|
||||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||||
return type_map.get(media_type, "")
|
if ext := type_map.get(media_type, ""):
|
||||||
|
return ext
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
return "".join(Path(filename).suffixes)
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import mimetypes
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||||
|
|
||||||
|
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||||
|
media_paths = data.get("media") or []
|
||||||
|
|
||||||
|
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
|
||||||
|
if media_paths:
|
||||||
|
for p in media_paths:
|
||||||
|
mime, _ = mimetypes.guess_type(p)
|
||||||
|
media_type = "image" if mime and mime.startswith("image/") else "file"
|
||||||
|
media_tag = f"[{media_type}: {p}]"
|
||||||
|
content = f"{content}\n{media_tag}" if content else media_tag
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=sender, # Use full LID for replies
|
chat_id=sender, # Use full LID for replies
|
||||||
content=content,
|
content=content,
|
||||||
|
media=media_paths,
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"timestamp": data.get("timestamp"),
|
"timestamp": data.get("timestamp"),
|
||||||
|
|||||||
@@ -7,6 +7,18 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Force UTF-8 encoding for Windows console
|
||||||
|
if sys.platform == "win32":
|
||||||
|
import locale
|
||||||
|
if sys.stdout.encoding != "utf-8":
|
||||||
|
os.environ["PYTHONIOENCODING"] = "utf-8"
|
||||||
|
# Re-open stdout/stderr with UTF-8 encoding
|
||||||
|
try:
|
||||||
|
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
||||||
|
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
from prompt_toolkit import PromptSession
|
from prompt_toolkit import PromptSession
|
||||||
from prompt_toolkit.formatted_text import HTML
|
from prompt_toolkit.formatted_text import HTML
|
||||||
@@ -200,9 +212,8 @@ def onboard():
|
|||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model)
|
||||||
@@ -213,6 +224,7 @@ def _make_provider(config: Config):
|
|||||||
return OpenAICodexProvider(default_model=model)
|
return OpenAICodexProvider(default_model=model)
|
||||||
|
|
||||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||||
|
from nanobot.providers.custom_provider import CustomProvider
|
||||||
if provider_name == "custom":
|
if provider_name == "custom":
|
||||||
return CustomProvider(
|
return CustomProvider(
|
||||||
api_key=p.api_key if p else "no-key",
|
api_key=p.api_key if p else "no-key",
|
||||||
@@ -220,6 +232,21 @@ def _make_provider(config: Config):
|
|||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||||
|
if provider_name == "azure_openai":
|
||||||
|
if not p or not p.api_key or not p.api_base:
|
||||||
|
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||||
|
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||||
|
console.print("Use the model field to specify the deployment name.")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
return AzureOpenAIProvider(
|
||||||
|
api_key=p.api_key,
|
||||||
|
api_base=p.api_base,
|
||||||
|
default_model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
spec = find_by_name(provider_name)
|
spec = find_by_name(provider_name)
|
||||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
|
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
|
||||||
@@ -530,9 +557,13 @@ def agent(
|
|||||||
|
|
||||||
signal.signal(signal.SIGINT, _handle_signal)
|
signal.signal(signal.SIGINT, _handle_signal)
|
||||||
signal.signal(signal.SIGTERM, _handle_signal)
|
signal.signal(signal.SIGTERM, _handle_signal)
|
||||||
signal.signal(signal.SIGHUP, _handle_signal)
|
# SIGHUP is not available on Windows
|
||||||
|
if hasattr(signal, 'SIGHUP'):
|
||||||
|
signal.signal(signal.SIGHUP, _handle_signal)
|
||||||
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
|
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
|
||||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
# SIGPIPE is not available on Windows
|
||||||
|
if hasattr(signal, 'SIGPIPE'):
|
||||||
|
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||||
|
|
||||||
async def run_interactive():
|
async def run_interactive():
|
||||||
bus_task = asyncio.create_task(agent_loop.run())
|
bus_task = asyncio.create_task(agent_loop.run())
|
||||||
|
|||||||
@@ -199,21 +199,6 @@ class QQConfig(Base):
|
|||||||
) # Allowed user openids (empty = public access)
|
) # Allowed user openids (empty = public access)
|
||||||
|
|
||||||
|
|
||||||
class MatrixConfig(Base):
|
|
||||||
"""Matrix (Element) channel configuration."""
|
|
||||||
|
|
||||||
enabled: bool = False
|
|
||||||
homeserver: str = "https://matrix.org"
|
|
||||||
access_token: str = ""
|
|
||||||
user_id: str = "" # e.g. @bot:matrix.org
|
|
||||||
device_id: str = ""
|
|
||||||
e2ee_enabled: bool = True # end-to-end encryption support
|
|
||||||
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
|
|
||||||
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
|
|
||||||
allow_from: list[str] = Field(default_factory=list)
|
|
||||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
|
||||||
group_allow_from: list[str] = Field(default_factory=list)
|
|
||||||
allow_room_mentions: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelsConfig(Base):
|
class ChannelsConfig(Base):
|
||||||
@@ -266,6 +251,7 @@ class ProvidersConfig(Base):
|
|||||||
"""Configuration for LLM providers."""
|
"""Configuration for LLM providers."""
|
||||||
|
|
||||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||||
|
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
|
||||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
@@ -278,12 +264,8 @@ class ProvidersConfig(Base):
|
|||||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
siliconflow: ProviderConfig = Field(
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||||
default_factory=ProviderConfig
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||||
) # SiliconFlow (硅基流动) API gateway
|
|
||||||
volcengine: ProviderConfig = Field(
|
|
||||||
default_factory=ProviderConfig
|
|
||||||
) # VolcEngine (火山引擎) API gateway
|
|
||||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||||
|
|
||||||
|
|||||||
@@ -3,5 +3,6 @@
|
|||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
|
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
||||||
|
|||||||
210
nanobot/providers/azure_openai_provider.py
Normal file
210
nanobot/providers/azure_openai_provider.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json_repair
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||||
|
|
||||||
|
|
||||||
|
class AzureOpenAIProvider(LLMProvider):
|
||||||
|
"""
|
||||||
|
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Hardcoded API version 2024-10-21
|
||||||
|
- Uses model field as Azure deployment name in URL path
|
||||||
|
- Uses api-key header instead of Authorization Bearer
|
||||||
|
- Uses max_completion_tokens instead of max_tokens
|
||||||
|
- Direct HTTP calls, bypasses LiteLLM
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str = "",
|
||||||
|
api_base: str = "",
|
||||||
|
default_model: str = "gpt-5.2-chat",
|
||||||
|
):
|
||||||
|
super().__init__(api_key, api_base)
|
||||||
|
self.default_model = default_model
|
||||||
|
self.api_version = "2024-10-21"
|
||||||
|
|
||||||
|
# Validate required parameters
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("Azure OpenAI api_key is required")
|
||||||
|
if not api_base:
|
||||||
|
raise ValueError("Azure OpenAI api_base is required")
|
||||||
|
|
||||||
|
# Ensure api_base ends with /
|
||||||
|
if not api_base.endswith('/'):
|
||||||
|
api_base += '/'
|
||||||
|
self.api_base = api_base
|
||||||
|
|
||||||
|
def _build_chat_url(self, deployment_name: str) -> str:
|
||||||
|
"""Build the Azure OpenAI chat completions URL."""
|
||||||
|
# Azure OpenAI URL format:
|
||||||
|
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||||
|
base_url = self.api_base
|
||||||
|
if not base_url.endswith('/'):
|
||||||
|
base_url += '/'
|
||||||
|
|
||||||
|
url = urljoin(
|
||||||
|
base_url,
|
||||||
|
f"openai/deployments/{deployment_name}/chat/completions"
|
||||||
|
)
|
||||||
|
return f"{url}?api-version={self.api_version}"
|
||||||
|
|
||||||
|
def _build_headers(self) -> dict[str, str]:
|
||||||
|
"""Build headers for Azure OpenAI API with api-key header."""
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||||
|
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _supports_temperature(
|
||||||
|
deployment_name: str,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True when temperature is likely supported for this deployment."""
|
||||||
|
if reasoning_effort:
|
||||||
|
return False
|
||||||
|
name = deployment_name.lower()
|
||||||
|
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||||
|
|
||||||
|
def _prepare_request_payload(
|
||||||
|
self,
|
||||||
|
deployment_name: str,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"messages": self._sanitize_request_messages(
|
||||||
|
self._sanitize_empty_content(messages),
|
||||||
|
_AZURE_MSG_KEYS,
|
||||||
|
),
|
||||||
|
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||||
|
payload["temperature"] = temperature
|
||||||
|
|
||||||
|
if reasoning_effort:
|
||||||
|
payload["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
payload["tools"] = tools
|
||||||
|
payload["tool_choice"] = "auto"
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""
|
||||||
|
Send a chat completion request to Azure OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'.
|
||||||
|
tools: Optional list of tool definitions in OpenAI format.
|
||||||
|
model: Model identifier (used as deployment name).
|
||||||
|
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
reasoning_effort: Optional reasoning effort parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLMResponse with content and/or tool calls.
|
||||||
|
"""
|
||||||
|
deployment_name = model or self.default_model
|
||||||
|
url = self._build_chat_url(deployment_name)
|
||||||
|
headers = self._build_headers()
|
||||||
|
payload = self._prepare_request_payload(
|
||||||
|
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||||
|
response = await client.post(url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return LLMResponse(
|
||||||
|
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
return self._parse_response(response_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return LLMResponse(
|
||||||
|
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||||
|
"""Parse Azure OpenAI response into our standard format."""
|
||||||
|
try:
|
||||||
|
choice = response["choices"][0]
|
||||||
|
message = choice["message"]
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
if message.get("tool_calls"):
|
||||||
|
for tc in message["tool_calls"]:
|
||||||
|
# Parse arguments from JSON string if needed
|
||||||
|
args = tc["function"]["arguments"]
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json_repair.loads(args)
|
||||||
|
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCallRequest(
|
||||||
|
id=tc["id"],
|
||||||
|
name=tc["function"]["name"],
|
||||||
|
arguments=args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = {}
|
||||||
|
if response.get("usage"):
|
||||||
|
usage_data = response["usage"]
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage_data.get("total_tokens", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
reasoning_content = message.get("reasoning_content") or None
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=message.get("content"),
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason=choice.get("finish_reason", "stop"),
|
||||||
|
usage=usage,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
except (KeyError, IndexError) as e:
|
||||||
|
return LLMResponse(
|
||||||
|
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
"""Get the default model (also used as default deployment name)."""
|
||||||
|
return self.default_model
|
||||||
@@ -87,6 +87,20 @@ class LLMProvider(ABC):
|
|||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_request_messages(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
allowed_keys: frozenset[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Keep only provider-safe message keys and normalize assistant content."""
|
||||||
|
sanitized = []
|
||||||
|
for msg in messages:
|
||||||
|
clean = {k: v for k, v in msg.items() if k in allowed_keys}
|
||||||
|
if clean.get("role") == "assistant" and "content" not in clean:
|
||||||
|
clean["content"] = None
|
||||||
|
sanitized.append(clean)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""LiteLLM provider implementation for multi-provider support."""
|
"""LiteLLM provider implementation for multi-provider support."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
@@ -8,6 +9,7 @@ from typing import Any
|
|||||||
import json_repair
|
import json_repair
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
@@ -165,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
return _ANTHROPIC_EXTRA_KEYS
|
return _ANTHROPIC_EXTRA_KEYS
|
||||||
return frozenset()
|
return frozenset()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||||
|
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
|
||||||
|
if not isinstance(tool_call_id, str):
|
||||||
|
return tool_call_id
|
||||||
|
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||||
|
return tool_call_id
|
||||||
|
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||||
sanitized = []
|
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||||
for msg in messages:
|
id_map: dict[str, str] = {}
|
||||||
clean = {k: v for k, v in msg.items() if k in allowed}
|
|
||||||
# Strict providers require "content" even when assistant only has tool_calls
|
def map_id(value: Any) -> Any:
|
||||||
if clean.get("role") == "assistant" and "content" not in clean:
|
if not isinstance(value, str):
|
||||||
clean["content"] = None
|
return value
|
||||||
sanitized.append(clean)
|
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||||
|
|
||||||
|
for clean in sanitized:
|
||||||
|
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||||
|
# shortening, otherwise strict providers reject the broken linkage.
|
||||||
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
|
normalized_tool_calls = []
|
||||||
|
for tc in clean["tool_calls"]:
|
||||||
|
if not isinstance(tc, dict):
|
||||||
|
normalized_tool_calls.append(tc)
|
||||||
|
continue
|
||||||
|
tc_clean = dict(tc)
|
||||||
|
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||||
|
normalized_tool_calls.append(tc_clean)
|
||||||
|
clean["tool_calls"] = normalized_tool_calls
|
||||||
|
|
||||||
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
@@ -255,20 +283,37 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
"""Parse LiteLLM response into our standard format."""
|
"""Parse LiteLLM response into our standard format."""
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
message = choice.message
|
message = choice.message
|
||||||
|
content = message.content
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
|
||||||
|
# Some providers (e.g. GitHub Copilot) split content and tool_calls
|
||||||
|
# across multiple choices. Merge them so tool_calls are not lost.
|
||||||
|
raw_tool_calls = []
|
||||||
|
for ch in response.choices:
|
||||||
|
msg = ch.message
|
||||||
|
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||||
|
raw_tool_calls.extend(msg.tool_calls)
|
||||||
|
if ch.finish_reason in ("tool_calls", "stop"):
|
||||||
|
finish_reason = ch.finish_reason
|
||||||
|
if not content and msg.content:
|
||||||
|
content = msg.content
|
||||||
|
|
||||||
|
if len(response.choices) > 1:
|
||||||
|
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
|
||||||
|
len(response.choices), len(raw_tool_calls))
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
for tc in raw_tool_calls:
|
||||||
for tc in message.tool_calls:
|
# Parse arguments from JSON string if needed
|
||||||
# Parse arguments from JSON string if needed
|
args = tc.function.arguments
|
||||||
args = tc.function.arguments
|
if isinstance(args, str):
|
||||||
if isinstance(args, str):
|
args = json_repair.loads(args)
|
||||||
args = json_repair.loads(args)
|
|
||||||
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=_short_tool_id(),
|
id=_short_tool_id(),
|
||||||
name=tc.function.name,
|
name=tc.function.name,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
))
|
))
|
||||||
|
|
||||||
usage = {}
|
usage = {}
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
@@ -280,11 +325,11 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||||
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=message.content,
|
content=content,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
finish_reason=choice.finish_reason or "stop",
|
finish_reason=finish_reason or "stop",
|
||||||
usage=usage,
|
usage=usage,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
thinking_blocks=thinking_blocks,
|
thinking_blocks=thinking_blocks,
|
||||||
|
|||||||
@@ -26,33 +26,33 @@ class ProviderSpec:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# identity
|
# identity
|
||||||
name: str # config field name, e.g. "dashscope"
|
name: str # config field name, e.g. "dashscope"
|
||||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||||
display_name: str = "" # shown in `nanobot status`
|
display_name: str = "" # shown in `nanobot status`
|
||||||
|
|
||||||
# model prefixing
|
# model prefixing
|
||||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||||
|
|
||||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||||
env_extras: tuple[tuple[str, str], ...] = ()
|
env_extras: tuple[tuple[str, str], ...] = ()
|
||||||
|
|
||||||
# gateway / local detection
|
# gateway / local detection
|
||||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||||
default_api_base: str = "" # fallback base URL
|
default_api_base: str = "" # fallback base URL
|
||||||
|
|
||||||
# gateway behavior
|
# gateway behavior
|
||||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||||
|
|
||||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||||
|
|
||||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||||
|
|
||||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||||
is_direct: bool = False
|
is_direct: bool = False
|
||||||
@@ -70,7 +70,6 @@ class ProviderSpec:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||||
|
|
||||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="custom",
|
name="custom",
|
||||||
@@ -81,16 +80,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
is_direct=True,
|
is_direct=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
|
||||||
|
ProviderSpec(
|
||||||
|
name="azure_openai",
|
||||||
|
keywords=("azure", "azure-openai"),
|
||||||
|
env_key="",
|
||||||
|
display_name="Azure OpenAI",
|
||||||
|
litellm_prefix="",
|
||||||
|
is_direct=True,
|
||||||
|
),
|
||||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||||
# Gateways can route any model, so they win in fallback.
|
# Gateways can route any model, so they win in fallback.
|
||||||
|
|
||||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openrouter",
|
name="openrouter",
|
||||||
keywords=("openrouter",),
|
keywords=("openrouter",),
|
||||||
env_key="OPENROUTER_API_KEY",
|
env_key="OPENROUTER_API_KEY",
|
||||||
display_name="OpenRouter",
|
display_name="OpenRouter",
|
||||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||||
skip_prefixes=(),
|
skip_prefixes=(),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
@@ -102,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
supports_prompt_caching=True,
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="aihubmix",
|
name="aihubmix",
|
||||||
keywords=("aihubmix",),
|
keywords=("aihubmix",),
|
||||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||||
display_name="AiHubMix",
|
display_name="AiHubMix",
|
||||||
litellm_prefix="openai", # → openai/{model}
|
litellm_prefix="openai", # → openai/{model}
|
||||||
skip_prefixes=(),
|
skip_prefixes=(),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=True,
|
is_gateway=True,
|
||||||
@@ -119,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
detect_by_key_prefix="",
|
detect_by_key_prefix="",
|
||||||
detect_by_base_keyword="aihubmix",
|
detect_by_base_keyword="aihubmix",
|
||||||
default_api_base="https://aihubmix.com/v1",
|
default_api_base="https://aihubmix.com/v1",
|
||||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="siliconflow",
|
name="siliconflow",
|
||||||
@@ -140,7 +145,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="volcengine",
|
name="volcengine",
|
||||||
@@ -158,9 +162,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# === Standard providers (matched by model-name keywords) ===============
|
# === Standard providers (matched by model-name keywords) ===============
|
||||||
|
|
||||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="anthropic",
|
name="anthropic",
|
||||||
@@ -179,7 +181,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
supports_prompt_caching=True,
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openai",
|
name="openai",
|
||||||
@@ -197,14 +198,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# OpenAI Codex: uses OAuth, not API key.
|
# OpenAI Codex: uses OAuth, not API key.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openai_codex",
|
name="openai_codex",
|
||||||
keywords=("openai-codex",),
|
keywords=("openai-codex",),
|
||||||
env_key="", # OAuth-based, no API key
|
env_key="", # OAuth-based, no API key
|
||||||
display_name="OpenAI Codex",
|
display_name="OpenAI Codex",
|
||||||
litellm_prefix="", # Not routed through LiteLLM
|
litellm_prefix="", # Not routed through LiteLLM
|
||||||
skip_prefixes=(),
|
skip_prefixes=(),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
@@ -214,16 +214,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
default_api_base="https://chatgpt.com/backend-api",
|
default_api_base="https://chatgpt.com/backend-api",
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
is_oauth=True, # OAuth-based authentication
|
is_oauth=True, # OAuth-based authentication
|
||||||
),
|
),
|
||||||
|
|
||||||
# Github Copilot: uses OAuth, not API key.
|
# Github Copilot: uses OAuth, not API key.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="github_copilot",
|
name="github_copilot",
|
||||||
keywords=("github_copilot", "copilot"),
|
keywords=("github_copilot", "copilot"),
|
||||||
env_key="", # OAuth-based, no API key
|
env_key="", # OAuth-based, no API key
|
||||||
display_name="Github Copilot",
|
display_name="Github Copilot",
|
||||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||||
skip_prefixes=("github_copilot/",),
|
skip_prefixes=("github_copilot/",),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
@@ -233,17 +232,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
default_api_base="",
|
default_api_base="",
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
is_oauth=True, # OAuth-based authentication
|
is_oauth=True, # OAuth-based authentication
|
||||||
),
|
),
|
||||||
|
|
||||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="deepseek",
|
name="deepseek",
|
||||||
keywords=("deepseek",),
|
keywords=("deepseek",),
|
||||||
env_key="DEEPSEEK_API_KEY",
|
env_key="DEEPSEEK_API_KEY",
|
||||||
display_name="DeepSeek",
|
display_name="DeepSeek",
|
||||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=False,
|
is_local=False,
|
||||||
@@ -253,15 +251,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="gemini",
|
name="gemini",
|
||||||
keywords=("gemini",),
|
keywords=("gemini",),
|
||||||
env_key="GEMINI_API_KEY",
|
env_key="GEMINI_API_KEY",
|
||||||
display_name="Gemini",
|
display_name="Gemini",
|
||||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=False,
|
is_local=False,
|
||||||
@@ -271,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||||
@@ -280,11 +276,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("zhipu", "glm", "zai"),
|
keywords=("zhipu", "glm", "zai"),
|
||||||
env_key="ZAI_API_KEY",
|
env_key="ZAI_API_KEY",
|
||||||
display_name="Zhipu AI",
|
display_name="Zhipu AI",
|
||||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||||
env_extras=(
|
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||||
("ZHIPUAI_API_KEY", "{api_key}"),
|
|
||||||
),
|
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=False,
|
is_local=False,
|
||||||
detect_by_key_prefix="",
|
detect_by_key_prefix="",
|
||||||
@@ -293,14 +287,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="dashscope",
|
name="dashscope",
|
||||||
keywords=("qwen", "dashscope"),
|
keywords=("qwen", "dashscope"),
|
||||||
env_key="DASHSCOPE_API_KEY",
|
env_key="DASHSCOPE_API_KEY",
|
||||||
display_name="DashScope",
|
display_name="DashScope",
|
||||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||||
skip_prefixes=("dashscope/", "openrouter/"),
|
skip_prefixes=("dashscope/", "openrouter/"),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
@@ -311,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||||
@@ -320,22 +312,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("moonshot", "kimi"),
|
keywords=("moonshot", "kimi"),
|
||||||
env_key="MOONSHOT_API_KEY",
|
env_key="MOONSHOT_API_KEY",
|
||||||
display_name="Moonshot",
|
display_name="Moonshot",
|
||||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||||
skip_prefixes=("moonshot/", "openrouter/"),
|
skip_prefixes=("moonshot/", "openrouter/"),
|
||||||
env_extras=(
|
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||||
("MOONSHOT_API_BASE", "{api_base}"),
|
|
||||||
),
|
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=False,
|
is_local=False,
|
||||||
detect_by_key_prefix="",
|
detect_by_key_prefix="",
|
||||||
detect_by_base_keyword="",
|
detect_by_base_keyword="",
|
||||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(
|
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||||
("kimi-k2.5", {"temperature": 1.0}),
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
|
|
||||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@@ -343,7 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("minimax",),
|
keywords=("minimax",),
|
||||||
env_key="MINIMAX_API_KEY",
|
env_key="MINIMAX_API_KEY",
|
||||||
display_name="MiniMax",
|
display_name="MiniMax",
|
||||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||||
skip_prefixes=("minimax/", "openrouter/"),
|
skip_prefixes=("minimax/", "openrouter/"),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
@@ -354,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||||
|
|
||||||
# vLLM / any OpenAI-compatible local server.
|
# vLLM / any OpenAI-compatible local server.
|
||||||
# Detected when config key is "vllm" (provider_name="vllm").
|
# Detected when config key is "vllm" (provider_name="vllm").
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@@ -364,20 +349,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("vllm",),
|
keywords=("vllm",),
|
||||||
env_key="HOSTED_VLLM_API_KEY",
|
env_key="HOSTED_VLLM_API_KEY",
|
||||||
display_name="vLLM/Local",
|
display_name="vLLM/Local",
|
||||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||||
skip_prefixes=(),
|
skip_prefixes=(),
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=True,
|
is_local=True,
|
||||||
detect_by_key_prefix="",
|
detect_by_key_prefix="",
|
||||||
detect_by_base_keyword="",
|
detect_by_base_keyword="",
|
||||||
default_api_base="", # user must provide in config
|
default_api_base="", # user must provide in config
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
# === Auxiliary (not a primary LLM provider) ============================
|
# === Auxiliary (not a primary LLM provider) ============================
|
||||||
|
|
||||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@@ -385,8 +368,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
keywords=("groq",),
|
keywords=("groq",),
|
||||||
env_key="GROQ_API_KEY",
|
env_key="GROQ_API_KEY",
|
||||||
display_name="Groq",
|
display_name="Groq",
|
||||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||||
skip_prefixes=("groq/",), # avoid double-prefix
|
skip_prefixes=("groq/",), # avoid double-prefix
|
||||||
env_extras=(),
|
env_extras=(),
|
||||||
is_gateway=False,
|
is_gateway=False,
|
||||||
is_local=False,
|
is_local=False,
|
||||||
@@ -403,6 +386,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
# Lookup helpers
|
# Lookup helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def find_by_model(model: str) -> ProviderSpec | None:
|
def find_by_model(model: str) -> ProviderSpec | None:
|
||||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||||
@@ -418,7 +402,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
|
|||||||
return spec
|
return spec
|
||||||
|
|
||||||
for spec in std_specs:
|
for spec in std_specs:
|
||||||
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
|
if any(
|
||||||
|
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
|
||||||
|
):
|
||||||
return spec
|
return spec
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,19 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def detect_image_mime(data: bytes) -> str | None:
|
||||||
|
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
||||||
|
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||||
|
return "image/png"
|
||||||
|
if data[:3] == b"\xff\xd8\xff":
|
||||||
|
return "image/jpeg"
|
||||||
|
if data[:6] in (b"GIF87a", b"GIF89a"):
|
||||||
|
return "image/gif"
|
||||||
|
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
||||||
|
return "image/webp"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def ensure_dir(path: Path) -> Path:
|
def ensure_dir(path: Path) -> Path:
|
||||||
"""Ensure directory exists, return it."""
|
"""Ensure directory exists, return it."""
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -35,6 +48,38 @@ def safe_filename(name: str) -> str:
|
|||||||
return _UNSAFE_CHARS.sub("_", name).strip()
|
return _UNSAFE_CHARS.sub("_", name).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||||
|
"""
|
||||||
|
Split content into chunks within max_len, preferring line breaks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The text content to split.
|
||||||
|
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of message chunks, each within max_len.
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return []
|
||||||
|
if len(content) <= max_len:
|
||||||
|
return [content]
|
||||||
|
chunks: list[str] = []
|
||||||
|
while content:
|
||||||
|
if len(content) <= max_len:
|
||||||
|
chunks.append(content)
|
||||||
|
break
|
||||||
|
cut = content[:max_len]
|
||||||
|
# Try to break at newline first, then space, then hard break
|
||||||
|
pos = cut.rfind('\n')
|
||||||
|
if pos <= 0:
|
||||||
|
pos = cut.rfind(' ')
|
||||||
|
if pos <= 0:
|
||||||
|
pos = max_len
|
||||||
|
chunks.append(content[:pos])
|
||||||
|
content = content[pos:].lstrip()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||||
from importlib.resources import files as pkg_files
|
from importlib.resources import files as pkg_files
|
||||||
|
|||||||
399
tests/test_azure_openai_provider.py
Normal file
399
tests/test_azure_openai_provider.py
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_openai_provider_init():
|
||||||
|
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o-deployment",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.api_key == "test-key"
|
||||||
|
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||||
|
assert provider.default_model == "gpt-4o-deployment"
|
||||||
|
assert provider.api_version == "2024-10-21"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_openai_provider_init_validation():
|
||||||
|
"""Test AzureOpenAIProvider initialization validation."""
|
||||||
|
# Missing api_key
|
||||||
|
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||||
|
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||||
|
|
||||||
|
# Missing api_base
|
||||||
|
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||||
|
AzureOpenAIProvider(api_key="test", api_base="")
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_chat_url():
|
||||||
|
"""Test Azure OpenAI URL building with different deployment names."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test various deployment names
|
||||||
|
test_cases = [
|
||||||
|
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||||
|
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||||
|
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for deployment_name, expected_url in test_cases:
|
||||||
|
url = provider._build_chat_url(deployment_name)
|
||||||
|
assert url == expected_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_chat_url_api_base_without_slash():
|
||||||
|
"""Test URL building when api_base doesn't end with slash."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
url = provider._build_chat_url("test-deployment")
|
||||||
|
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||||
|
assert url == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_headers():
|
||||||
|
"""Test Azure OpenAI header building with api-key authentication."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-api-key-123",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = provider._build_headers()
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||||
|
assert "x-session-affinity" in headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_request_payload():
|
||||||
|
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||||
|
|
||||||
|
assert payload["messages"] == messages
|
||||||
|
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||||
|
assert payload["temperature"] == 0.8
|
||||||
|
assert "tools" not in payload
|
||||||
|
|
||||||
|
# Test with tools
|
||||||
|
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||||
|
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||||
|
assert payload_with_tools["tools"] == tools
|
||||||
|
assert payload_with_tools["tool_choice"] == "auto"
|
||||||
|
|
||||||
|
# Test with reasoning_effort
|
||||||
|
payload_with_reasoning = provider._prepare_request_payload(
|
||||||
|
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||||
|
)
|
||||||
|
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||||
|
assert "temperature" not in payload_with_reasoning
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_request_payload_sanitizes_messages():
|
||||||
|
"""Test Azure payload strips non-standard message keys before sending."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||||
|
"reasoning_content": "hidden chain-of-thought",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_123",
|
||||||
|
"name": "x",
|
||||||
|
"content": "ok",
|
||||||
|
"extra_field": "should be removed",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||||
|
|
||||||
|
assert payload["messages"] == [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_123",
|
||||||
|
"name": "x",
|
||||||
|
"content": "ok",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_success():
|
||||||
|
"""Test successful chat request using model as deployment name."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o-deployment",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock response data
|
||||||
|
mock_response_data = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": "Hello! How can I help you today?",
|
||||||
|
"role": "assistant"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 12,
|
||||||
|
"completion_tokens": 18,
|
||||||
|
"total_tokens": 30
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json = Mock(return_value=mock_response_data)
|
||||||
|
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_context.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_context
|
||||||
|
|
||||||
|
# Test with specific model (deployment name)
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = await provider.chat(messages, model="custom-deployment")
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert result.content == "Hello! How can I help you today?"
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.usage["prompt_tokens"] == 12
|
||||||
|
assert result.usage["completion_tokens"] == 18
|
||||||
|
assert result.usage["total_tokens"] == 30
|
||||||
|
|
||||||
|
# Verify URL was built with the provided model as deployment name
|
||||||
|
call_args = mock_context.post.call_args
|
||||||
|
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||||
|
assert call_args[0][0] == expected_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_uses_default_model_when_no_model_provided():
|
||||||
|
"""Test that chat uses default_model when no model is specified."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="default-deployment",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response_data = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {"content": "Response", "role": "assistant"},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json = Mock(return_value=mock_response_data)
|
||||||
|
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_context.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_context
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Test"}]
|
||||||
|
await provider.chat(messages) # No model specified
|
||||||
|
|
||||||
|
# Verify URL was built with default model as deployment name
|
||||||
|
call_args = mock_context.post.call_args
|
||||||
|
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||||
|
assert call_args[0][0] == expected_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_tool_calls():
|
||||||
|
"""Test chat request with tool calls in response."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock response with tool calls
|
||||||
|
mock_response_data = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": None,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_12345",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"location": "San Francisco"}'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 20,
|
||||||
|
"completion_tokens": 15,
|
||||||
|
"total_tokens": 35
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json = Mock(return_value=mock_response_data)
|
||||||
|
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_context.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_context
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||||
|
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||||
|
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert result.content is None
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_api_error():
|
||||||
|
"""Test chat request API error handling."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 401
|
||||||
|
mock_response.text = "Invalid authentication credentials"
|
||||||
|
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_context.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_context
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = await provider.chat(messages)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert "Azure OpenAI API Error 401" in result.content
|
||||||
|
assert "Invalid authentication credentials" in result.content
|
||||||
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_connection_error():
|
||||||
|
"""Test chat request connection error handling."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_context = AsyncMock()
|
||||||
|
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_context
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
result = await provider.chat(messages)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||||
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_response_malformed():
|
||||||
|
"""Test response parsing with malformed data."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with missing choices
|
||||||
|
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||||
|
result = provider._parse_response(malformed_response)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert "Error parsing Azure OpenAI response" in result.content
|
||||||
|
assert result.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_model():
|
||||||
|
"""Test get_default_model method."""
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="my-custom-deployment",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.get_default_model() == "my-custom-deployment"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run basic tests
|
||||||
|
print("Running basic Azure OpenAI provider tests...")
|
||||||
|
|
||||||
|
# Test initialization
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key="test-key",
|
||||||
|
api_base="https://test-resource.openai.azure.com",
|
||||||
|
default_model="gpt-4o-deployment",
|
||||||
|
)
|
||||||
|
print("✅ Provider initialization successful")
|
||||||
|
|
||||||
|
# Test URL building
|
||||||
|
url = provider._build_chat_url("my-deployment")
|
||||||
|
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||||
|
assert url == expected
|
||||||
|
print("✅ URL building works correctly")
|
||||||
|
|
||||||
|
# Test headers
|
||||||
|
headers = provider._build_headers()
|
||||||
|
assert headers["api-key"] == "test-key"
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
print("✅ Header building works correctly")
|
||||||
|
|
||||||
|
# Test payload preparation
|
||||||
|
messages = [{"role": "user", "content": "Test"}]
|
||||||
|
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||||
|
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||||
|
print("✅ Payload preparation works correctly")
|
||||||
|
|
||||||
|
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||||
25
tests/test_base_channel.py
Normal file
25
tests/test_base_channel.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyChannel(BaseChannel):
|
||||||
|
name = "dummy"
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_requires_exact_match() -> None:
|
||||||
|
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
|
||||||
|
|
||||||
|
assert channel.is_allowed("allow@email.com") is True
|
||||||
|
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||||
66
tests/test_dingtalk_channel.py
Normal file
66
tests/test_dingtalk_channel.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.dingtalk import DingTalkChannel
|
||||||
|
from nanobot.config.schema import DingTalkConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self._json_body = json_body or {}
|
||||||
|
self.text = "{}"
|
||||||
|
|
||||||
|
def json(self) -> dict:
|
||||||
|
return self._json_body
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHttp:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post(self, url: str, json=None, headers=None):
|
||||||
|
self.calls.append({"url": url, "json": json, "headers": headers})
|
||||||
|
return _FakeResponse()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = DingTalkChannel(config, bus)
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
"hello",
|
||||||
|
sender_id="user1",
|
||||||
|
sender_name="Alice",
|
||||||
|
conversation_type="2",
|
||||||
|
conversation_id="conv123",
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await bus.consume_inbound()
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
assert msg.chat_id == "group:conv123"
|
||||||
|
assert msg.metadata["conversation_type"] == "2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_send_uses_group_messages_api() -> None:
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
channel._http = _FakeHttp()
|
||||||
|
|
||||||
|
ok = await channel._send_batch_message(
|
||||||
|
"token",
|
||||||
|
"group:conv123",
|
||||||
|
"sampleMarkdown",
|
||||||
|
{"text": "hello", "title": "Nanobot Reply"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
call = channel._http.calls[0]
|
||||||
|
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||||
|
assert call["json"]["openConversationId"] == "conv123"
|
||||||
|
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from nanobot.channels.feishu import _extract_post_content
|
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||||
|
|
||||||
|
|
||||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||||
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
|||||||
|
|
||||||
assert text == "Daily report"
|
assert text == "Daily report"
|
||||||
assert image_keys == ["img_a", "img_b"]
|
assert image_keys == ["img_a", "img_b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||||
|
class Builder:
|
||||||
|
pass
|
||||||
|
|
||||||
|
builder = Builder()
|
||||||
|
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||||
|
assert same is builder
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_optional_event_calls_supported_method() -> None:
|
||||||
|
called = []
|
||||||
|
|
||||||
|
class Builder:
|
||||||
|
def register_event(self, handler):
|
||||||
|
called.append(handler)
|
||||||
|
return self
|
||||||
|
|
||||||
|
builder = Builder()
|
||||||
|
handler = object()
|
||||||
|
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||||
|
|
||||||
|
assert same is builder
|
||||||
|
assert called == [handler]
|
||||||
|
|||||||
@@ -145,3 +145,78 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
provider.chat.assert_not_called()
|
provider.chat.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
|
||||||
|
"""Some providers return arguments as a list - extract first element if it's a dict."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
|
||||||
|
# Simulate arguments being a list containing a dict
|
||||||
|
response = LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments=[{
|
||||||
|
"history_entry": "[2026-01-01] User discussed testing.",
|
||||||
|
"memory_update": "# Memory\nUser likes testing.",
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
provider.chat = AsyncMock(return_value=response)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
|
assert "User likes testing." in store.memory_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
|
||||||
|
"""Empty list arguments should return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
|
||||||
|
response = LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments=[],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
provider.chat = AsyncMock(return_value=response)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
|
||||||
|
"""List with non-dict content should return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
|
||||||
|
response = LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments=["string", "content"],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
provider.chat = AsyncMock(return_value=response)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|||||||
@@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert "Hello" in result.content
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(
|
||||||
|
content="Visible<think>hidden</think>",
|
||||||
|
tool_calls=[tool_call],
|
||||||
|
reasoning_content="secret reasoning",
|
||||||
|
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||||
|
),
|
||||||
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
progress: list[tuple[str, bool]] = []
|
||||||
|
|
||||||
|
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
|
progress.append((content, tool_hint))
|
||||||
|
|
||||||
|
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||||
|
|
||||||
|
assert final_content == "Done"
|
||||||
|
assert progress == [
|
||||||
|
("Visible", False),
|
||||||
|
('read_file("foo.txt")', True),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestMessageToolTurnTracking:
|
class TestMessageToolTurnTracking:
|
||||||
|
|
||||||
|
|||||||
66
tests/test_qq_channel.py
Normal file
66
tests/test_qq_channel.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.qq import QQChannel
|
||||||
|
from nanobot.config.schema import QQConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeApi:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.c2c_calls: list[dict] = []
|
||||||
|
self.group_calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post_c2c_message(self, **kwargs) -> None:
|
||||||
|
self.c2c_calls.append(kwargs)
|
||||||
|
|
||||||
|
async def post_group_message(self, **kwargs) -> None:
|
||||||
|
self.group_calls.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.api = _FakeApi()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg1",
|
||||||
|
content="hello",
|
||||||
|
group_openid="group123",
|
||||||
|
author=SimpleNamespace(member_openid="user1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_message(data, is_group=True)
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
assert msg.chat_id == "group123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
channel._chat_type_cache["group123"] = "group"
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="group123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "msg1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(channel._client.api.group_calls) == 1
|
||||||
|
call = channel._client.api.group_calls[0]
|
||||||
|
assert call["group_openid"] == "group123"
|
||||||
|
assert call["msg_id"] == "msg1"
|
||||||
|
assert call["msg_seq"] == 2
|
||||||
|
assert not channel._client.api.c2c_calls
|
||||||
184
tests/test_telegram_channel.py
Normal file
184
tests/test_telegram_channel.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.telegram import TelegramChannel
|
||||||
|
from nanobot.config.schema import TelegramConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHTTPXRequest:
|
||||||
|
instances: list["_FakeHTTPXRequest"] = []
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.__class__.instances.append(self)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeUpdater:
|
||||||
|
def __init__(self, on_start_polling) -> None:
|
||||||
|
self._on_start_polling = on_start_polling
|
||||||
|
|
||||||
|
async def start_polling(self, **kwargs) -> None:
|
||||||
|
self._on_start_polling()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeBot:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent_messages: list[dict] = []
|
||||||
|
|
||||||
|
async def get_me(self):
|
||||||
|
return SimpleNamespace(username="nanobot_test")
|
||||||
|
|
||||||
|
async def set_my_commands(self, commands) -> None:
|
||||||
|
self.commands = commands
|
||||||
|
|
||||||
|
async def send_message(self, **kwargs) -> None:
|
||||||
|
self.sent_messages.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeApp:
|
||||||
|
def __init__(self, on_start_polling) -> None:
|
||||||
|
self.bot = _FakeBot()
|
||||||
|
self.updater = _FakeUpdater(on_start_polling)
|
||||||
|
self.handlers = []
|
||||||
|
self.error_handlers = []
|
||||||
|
|
||||||
|
def add_error_handler(self, handler) -> None:
|
||||||
|
self.error_handlers.append(handler)
|
||||||
|
|
||||||
|
def add_handler(self, handler) -> None:
|
||||||
|
self.handlers.append(handler)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeBuilder:
|
||||||
|
def __init__(self, app: _FakeApp) -> None:
|
||||||
|
self.app = app
|
||||||
|
self.token_value = None
|
||||||
|
self.request_value = None
|
||||||
|
self.get_updates_request_value = None
|
||||||
|
|
||||||
|
def token(self, token: str):
|
||||||
|
self.token_value = token
|
||||||
|
return self
|
||||||
|
|
||||||
|
def request(self, request):
|
||||||
|
self.request_value = request
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_updates_request(self, request):
|
||||||
|
self.get_updates_request_value = request
|
||||||
|
return self
|
||||||
|
|
||||||
|
def proxy(self, _proxy):
|
||||||
|
raise AssertionError("builder.proxy should not be called when request is set")
|
||||||
|
|
||||||
|
def get_updates_proxy(self, _proxy):
|
||||||
|
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
return self.app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||||
|
config = TelegramConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="123:abc",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
)
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = TelegramChannel(config, bus)
|
||||||
|
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||||
|
builder = _FakeBuilder(app)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.telegram.Application",
|
||||||
|
SimpleNamespace(builder=lambda: builder),
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert len(_FakeHTTPXRequest.instances) == 1
|
||||||
|
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
||||||
|
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
||||||
|
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||||
|
message = SimpleNamespace(
|
||||||
|
chat=SimpleNamespace(type="supergroup"),
|
||||||
|
chat_id=-100123,
|
||||||
|
message_thread_id=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_extension_falls_back_to_original_filename() -> None:
|
||||||
|
channel = TelegramChannel(TelegramConfig(), MessageBus())
|
||||||
|
|
||||||
|
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
|
||||||
|
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||||
|
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||||
|
|
||||||
|
assert channel.is_allowed("12345|carol") is True
|
||||||
|
assert channel.is_allowed("99999|alice") is True
|
||||||
|
assert channel.is_allowed("67890|bob") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
|
||||||
|
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
|
||||||
|
|
||||||
|
assert channel.is_allowed("attacker|alice|extra") is False
|
||||||
|
assert channel.is_allowed("not-a-number|alice") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||||
|
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||||
|
channel = TelegramChannel(config, MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"_progress": True, "message_thread_id": 42},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||||
|
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||||
|
channel = TelegramChannel(config, MessageBus())
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
channel._message_threads[("123", 10)] = 42
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": 10},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||||
|
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||||
@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
|||||||
paths = ExecTool._extract_absolute_paths(cmd)
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
assert "/tmp/data.txt" in paths
|
assert "/tmp/data.txt" in paths
|
||||||
assert "/tmp/out.txt" in paths
|
assert "/tmp/out.txt" in paths
|
||||||
|
|
||||||
|
|
||||||
|
# --- cast_params tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class CastTestTool(Tool):
|
||||||
|
"""Minimal tool for testing cast_params."""
|
||||||
|
|
||||||
|
def __init__(self, schema: dict[str, Any]) -> None:
|
||||||
|
self._schema = schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "cast_test"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "test tool for casting"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return self._schema
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_int() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": "42"})
|
||||||
|
assert result["count"] == 42
|
||||||
|
assert isinstance(result["count"], int)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_number() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"rate": "3.14"})
|
||||||
|
assert result["rate"] == 3.14
|
||||||
|
assert isinstance(result["rate"], float)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_bool() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"enabled": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||||
|
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||||
|
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_array_items() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||||
|
assert result["nums"] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_nested_object() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"port": {"type": "integer"},
|
||||||
|
"debug": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||||
|
assert result["config"]["port"] == 8080
|
||||||
|
assert result["config"]["debug"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||||
|
"""Booleans should not be silently cast to integers."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": True})
|
||||||
|
assert result["count"] is True
|
||||||
|
errors = tool.validate_params(result)
|
||||||
|
assert any("count should be integer" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_preserves_empty_string() -> None:
|
||||||
|
"""Empty strings should be preserved for string type."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"type": "string"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"name": ""})
|
||||||
|
assert result["name"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_string_false() -> None:
|
||||||
|
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"flag": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_string_invalid() -> None:
|
||||||
|
"""Invalid boolean strings should not be cast."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"flag": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Invalid strings should be preserved (validation will catch them)
|
||||||
|
result = tool.cast_params({"flag": "random"})
|
||||||
|
assert result["flag"] == "random"
|
||||||
|
result = tool.cast_params({"flag": "maybe"})
|
||||||
|
assert result["flag"] == "maybe"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_invalid_string_to_int() -> None:
|
||||||
|
"""Invalid strings should not be cast to integer."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": "abc"})
|
||||||
|
assert result["count"] == "abc" # Original value preserved
|
||||||
|
result = tool.cast_params({"count": "12.5.7"})
|
||||||
|
assert result["count"] == "12.5.7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_invalid_string_to_number() -> None:
|
||||||
|
"""Invalid strings should not be cast to number."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"rate": "not_a_number"})
|
||||||
|
assert result["rate"] == "not_a_number"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||||
|
"""Booleans should not pass number validation."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
errors = tool.validate_params({"rate": False})
|
||||||
|
assert any("rate should be number" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_none_values() -> None:
|
||||||
|
"""Test None handling for different types."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"count": {"type": "integer"},
|
||||||
|
"items": {"type": "array"},
|
||||||
|
"config": {"type": "object"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params(
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"count": None,
|
||||||
|
"items": None,
|
||||||
|
"config": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# None should be preserved for all types
|
||||||
|
assert result["name"] is None
|
||||||
|
assert result["count"] is None
|
||||||
|
assert result["items"] is None
|
||||||
|
assert result["config"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||||
|
"""Single values should NOT be automatically wrapped into arrays."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"items": {"type": "array"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Non-array values should be preserved (validation will catch them)
|
||||||
|
result = tool.cast_params({"items": 5})
|
||||||
|
assert result["items"] == 5 # Not wrapped to [5]
|
||||||
|
result = tool.cast_params({"items": "text"})
|
||||||
|
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||||
|
|||||||
Reference in New Issue
Block a user