diff --git a/README.md b/README.md index 86869a2..bc11cc8 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,20 @@ ## πŸ“’ News +- **2026-03-07** πŸš€ Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish. +- **2026-03-06** πŸͺ„ Lighter providers, smarter media handling, and sturdier memory and CLI compatibility. +- **2026-03-05** ⚑️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes. +- **2026-03-04** πŸ› οΈ Dependency cleanup, safer file reads, and another round of test and Cron fixes. +- **2026-03-03** 🧠 Cleaner user-message merging, safer multimodal saves, and stronger Cron guards. +- **2026-03-02** πŸ›‘οΈ Safer default access control, sturdier Cron reloads, and cleaner Matrix media handling. +- **2026-03-01** 🌐 Web proxy support, smarter Cron reminders, and Feishu rich-text parsing improvements. - **2026-02-28** πŸš€ Released **v0.1.4.post3** β€” cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details. - **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes. - **2026-02-26** πŸ›‘οΈ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility. + +
+Earlier news + - **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync. - **2026-02-24** πŸš€ Released **v0.1.4.post2** β€” a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details. - **2026-02-23** πŸ”§ Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes. @@ -30,10 +41,6 @@ - **2026-02-21** πŸŽ‰ Released **v0.1.4.post1** β€” new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details. - **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood. - **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode. - -
-Earlier news - - **2026-02-18** ⚑️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching. - **2026-02-17** πŸŽ‰ Released **v0.1.4** β€” MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details. - **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill β€” search and install public agent skills. @@ -420,6 +427,10 @@ nanobot channels login nanobot gateway ``` +> WhatsApp bridge updates are not applied automatically for existing installations. +> If you upgrade nanobot and need the latest WhatsApp bridge, run: +> `rm -rf ~/.nanobot/bridge && nanobot channels login` +
@@ -671,6 +682,7 @@ Config file: `~/.nanobot/config.json` | `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | β€” | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | @@ -894,31 +906,26 @@ MCP tools are automatically discovered and registered on startup. The LLM can us | `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. | -## Multiple Instances +## 🧩 Multiple Instances -Run multiple nanobot instances simultaneously, each with its own workspace and configuration. +Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint, and optionally use `--workspace` to override the workspace for a specific run. 
+ +### Quick Start ```bash # Instance A - Telegram bot -nanobot gateway -w ~/.nanobot/botA -p 18791 +nanobot gateway --config ~/.nanobot-telegram/config.json -# Instance B - Discord bot -nanobot gateway -w ~/.nanobot/botB -p 18792 +# Instance B - Discord bot +nanobot gateway --config ~/.nanobot-discord/config.json -# Instance C - Using custom config file -nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793 +# Instance C - Feishu bot with custom port +nanobot gateway --config ~/.nanobot-feishu/config.json --port 18792 ``` -| Option | Short | Description | -|--------|-------|-------------| -| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) | -| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) | -| `--port` | `-p` | Gateway port (default: `18790`) | +### Path Resolution -Each instance has its own: -- Workspace directory (MEMORY.md, HEARTBEAT.md, session files) -- Cron jobs storage (`workspace/cron/jobs.json`) -- Configuration (if using `--config`) +When using `--config`, nanobot derives its runtime data directory from the config file location. The workspace still comes from `agents.defaults.workspace` unless you override it with `--workspace`. To open a CLI session against one of these instances locally: @@ -928,9 +935,75 @@ nanobot agent -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json ``` > `nanobot agent` starts a local CLI agent using the selected workspace/config. It does not attach to or proxy through an already running `nanobot gateway` process. 
+| Component | Resolved From | Example |
+|-----------|---------------|---------|
+| **Config** | `--config` path | `~/.nanobot-A/config.json` |
+| **Workspace** | `--workspace` or config | `~/.nanobot-A/workspace/` |
+| **Cron Jobs** | config directory | `~/.nanobot-A/cron/` |
+| **Media / runtime state** | config directory | `~/.nanobot-A/media/` |
+### How It Works
-## CLI Reference
+- `--config` selects which config file to load
+- By default, the workspace comes from `agents.defaults.workspace` in that config
+- If you pass `--workspace`, it overrides the workspace from the config file
+
+### Minimal Setup
+
+1. Copy your base config into a new instance directory.
+2. Set a different `agents.defaults.workspace` for that instance.
+3. Start the instance with `--config`.
+
+Example config:
+
+```json
+{
+  "agents": {
+    "defaults": {
+      "workspace": "~/.nanobot-telegram/workspace",
+      "model": "anthropic/claude-sonnet-4-6"
+    }
+  },
+  "channels": {
+    "telegram": {
+      "enabled": true,
+      "token": "YOUR_TELEGRAM_BOT_TOKEN"
+    }
+  },
+  "gateway": {
+    "port": 18790
+  }
+}
+```
+
+Start separate instances:
+
+```bash
+nanobot gateway --config ~/.nanobot-telegram/config.json
+nanobot gateway --config ~/.nanobot-discord/config.json
+```
+
+Override workspace for one-off runs when needed:
+
+```bash
+nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobot-telegram-test
+```
+
+### Common Use Cases
+
+- Run separate bots for Telegram, Discord, Feishu, and other platforms
+- Keep testing and production instances isolated
+- Use different models or providers for different teams
+- Serve multiple tenants with separate configs and runtime data
+
+### Notes
+
+- Instances running at the same time must each use a different port
+- Use a different workspace per instance if you want isolated memory, sessions, and skills
+- `--workspace` overrides the workspace defined in the config file
+- Cron jobs and runtime media/state are derived from the config 
directory + +## πŸ’» CLI Reference | Command | Description | |---------|-------------| diff --git a/bridge/src/whatsapp.ts b/bridge/src/whatsapp.ts index 069d72b..f0485bd 100644 --- a/bridge/src/whatsapp.ts +++ b/bridge/src/whatsapp.ts @@ -9,11 +9,16 @@ import makeWASocket, { useMultiFileAuthState, fetchLatestBaileysVersion, makeCacheableSignalKeyStore, + downloadMediaMessage, + extractMessageContent as baileysExtractMessageContent, } from '@whiskeysockets/baileys'; import { Boom } from '@hapi/boom'; import qrcode from 'qrcode-terminal'; import pino from 'pino'; +import { writeFile, mkdir } from 'fs/promises'; +import { join } from 'path'; +import { randomBytes } from 'crypto'; const VERSION = '0.1.0'; @@ -24,6 +29,7 @@ export interface InboundMessage { content: string; timestamp: number; isGroup: boolean; + media?: string[]; } export interface WhatsAppClientOptions { @@ -110,14 +116,33 @@ export class WhatsAppClient { if (type !== 'notify') return; for (const msg of messages) { - // Skip own messages if (msg.key.fromMe) continue; - - // Skip status updates if (msg.key.remoteJid === 'status@broadcast') continue; - const content = this.extractMessageContent(msg); - if (!content) continue; + const unwrapped = baileysExtractMessageContent(msg.message); + if (!unwrapped) continue; + + const content = this.getTextContent(unwrapped); + let fallbackContent: string | null = null; + const mediaPaths: string[] = []; + + if (unwrapped.imageMessage) { + fallbackContent = '[Image]'; + const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined); + if (path) mediaPaths.push(path); + } else if (unwrapped.documentMessage) { + fallbackContent = '[Document]'; + const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined, + unwrapped.documentMessage.fileName ?? 
undefined);
+          if (path) mediaPaths.push(path);
+        } else if (unwrapped.videoMessage) {
+          fallbackContent = '[Video]';
+          const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
+          if (path) mediaPaths.push(path);
+        }
+
+        const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
+        if (!finalContent && mediaPaths.length === 0) continue;
 
         const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
 
@@ -125,18 +150,45 @@
           id: msg.key.id || '',
           sender: msg.key.remoteJid || '',
           pn: msg.key.remoteJidAlt || '',
-          content,
+          content: finalContent,
           timestamp: msg.messageTimestamp as number,
           isGroup,
+          ...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
         });
       }
     });
   }
-  private extractMessageContent(msg: any): string | null {
-    const message = msg.message;
-    if (!message) return null;
+  private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
+    try {
+      const mediaDir = join(this.options.authDir, '..', 'media');
+      await mkdir(mediaDir, { recursive: true });
+      const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
+
+      let outFilename: string;
+      if (fileName) {
+        // Documents have a filename β€” use it with a unique prefix to avoid collisions
+        const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
+        outFilename = prefix + fileName;
+      } else {
+        const mime = mimetype || 'application/octet-stream';
+        // Derive extension from mimetype subtype (e.g. "image/png" β†’ ".png", "application/pdf" β†’ ".pdf")
+        const ext = '.' 
+ (mime.split('/').pop()?.split(';')[0] || 'bin'); + outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`; + } + + const filepath = join(mediaDir, outFilename); + await writeFile(filepath, buffer); + + return filepath; + } catch (err) { + console.error('Failed to download media:', err); + return null; + } + } + + private getTextContent(message: any): string | null { // Text message if (message.conversation) { return message.conversation; @@ -147,19 +199,19 @@ export class WhatsAppClient { return message.extendedTextMessage.text; } - // Image with caption - if (message.imageMessage?.caption) { - return `[Image] ${message.imageMessage.caption}`; + // Image with optional caption + if (message.imageMessage) { + return message.imageMessage.caption || ''; } - // Video with caption - if (message.videoMessage?.caption) { - return `[Video] ${message.videoMessage.caption}`; + // Video with optional caption + if (message.videoMessage) { + return message.videoMessage.caption || ''; } - // Document with caption - if (message.documentMessage?.caption) { - return `[Document] ${message.documentMessage.caption}`; + // Document with optional caption + if (message.documentMessage) { + return message.documentMessage.caption || ''; } // Voice/Audio message diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 7f129a2..ca9a06e 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -202,18 +202,9 @@ class AgentLoop: if response.has_tool_calls: if on_progress: - thoughts = [ - self._strip_think(response.content), - response.reasoning_content, - *( - f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}" - for b in (response.thinking_blocks or []) - if isinstance(b, dict) and "signature" in b - ), - ] - combined_thoughts = "\n\n".join(filter(None, thoughts)) - if combined_thoughts: - await on_progress(combined_thoughts) + thought = self._strip_think(response.content) + if thought: + await on_progress(thought) await 
on_progress(self._tool_hint(response.tool_calls), tool_hint=True) tool_call_dicts = [ diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 051fc9a..06f5bdd 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -52,6 +52,75 @@ class Tool(ABC): """ pass + def cast_params(self, params: dict[str, Any]) -> dict[str, Any]: + """Apply safe schema-driven casts before validation.""" + schema = self.parameters or {} + if schema.get("type", "object") != "object": + return params + + return self._cast_object(params, schema) + + def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]: + """Cast an object (dict) according to schema.""" + if not isinstance(obj, dict): + return obj + + props = schema.get("properties", {}) + result = {} + + for key, value in obj.items(): + if key in props: + result[key] = self._cast_value(value, props[key]) + else: + result[key] = value + + return result + + def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any: + """Cast a single value according to schema.""" + target_type = schema.get("type") + + if target_type == "boolean" and isinstance(val, bool): + return val + if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool): + return val + if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"): + expected = self._TYPE_MAP[target_type] + if isinstance(val, expected): + return val + + if target_type == "integer" and isinstance(val, str): + try: + return int(val) + except ValueError: + return val + + if target_type == "number" and isinstance(val, str): + try: + return float(val) + except ValueError: + return val + + if target_type == "string": + return val if val is None else str(val) + + if target_type == "boolean" and isinstance(val, str): + val_lower = val.lower() + if val_lower in ("true", "1", "yes"): + return True + if val_lower in ("false", "0", "no"): + return False + return val + + if target_type == 
"array" and isinstance(val, list): + item_schema = schema.get("items") + return [self._cast_value(item, item_schema) for item in val] if item_schema else val + + if target_type == "object" and isinstance(val, dict): + return self._cast_object(val, schema) + + return val + def validate_params(self, params: dict[str, Any]) -> list[str]: """Validate tool parameters against JSON schema. Returns error list (empty if valid).""" if not isinstance(params, dict): @@ -63,7 +132,13 @@ class Tool(ABC): def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: t, label = schema.get("type"), path or "parameter" - if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): + if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)): + return [f"{label} should be integer"] + if t == "number" and ( + not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool) + ): + return [f"{label} should be number"] + if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]): return [f"{label} should be {t}"] errors = [] diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 35e519a..0a52427 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -96,7 +96,7 @@ class MessageTool(Tool): media=media or [], metadata={ "message_id": message_id, - } + }, ) try: diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 5d36e52..896491f 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -44,6 +44,10 @@ class ToolRegistry: return f"Error: Tool '{name}' not found. 
Available: {', '.join(self.tool_names)}" try: + # Attempt to cast parameters to match schema types + params = tool.cast_params(params) + + # Validate parameters errors = tool.validate_params(params) if errors: return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index b38fcaf..dc53ba4 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -66,10 +66,7 @@ class BaseChannel(ABC): return False if "*" in allow_list: return True - sender_str = str(sender_id) - return sender_str in allow_list or any( - p in allow_list for p in sender_str.split("|") if p - ) + return str(sender_id) in allow_list async def _handle_message( self, diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 8d02fa6..3c301a9 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -70,12 +70,24 @@ class NanobotDingTalkHandler(CallbackHandler): sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id sender_name = chatbot_msg.sender_nick or "Unknown" + conversation_type = message.data.get("conversationType") + conversation_id = ( + message.data.get("conversationId") + or message.data.get("openConversationId") + ) + logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content) # Forward to Nanobot via _on_message (non-blocking). # Store reference to prevent GC before task completes. task = asyncio.create_task( - self.channel._on_message(content, sender_id, sender_name) + self.channel._on_message( + content, + sender_id, + sender_name, + conversation_type, + conversation_id, + ) ) self.channel._background_tasks.add(task) task.add_done_callback(self.channel._background_tasks.discard) @@ -95,8 +107,8 @@ class DingTalkChannel(BaseChannel): Uses WebSocket to receive events via `dingtalk-stream` SDK. Uses direct HTTP API to send messages (SDK is mainly for receiving). 
-    Note: Currently only supports private (1:1) chat. Group messages are
-    received but replies are sent back as private messages to the sender.
+    Supports both private (1:1) and group chats.
+    Group chat_id is stored with a "group:" prefix to route replies back.
     """
 
     name = "dingtalk"
@@ -301,14 +313,25 @@
             logger.warning("DingTalk HTTP client not initialized, cannot send")
             return False
-        url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
         headers = {"x-acs-dingtalk-access-token": token}
-        payload = {
-            "robotCode": self.config.client_id,
-            "userIds": [chat_id],
-            "msgKey": msg_key,
-            "msgParam": json.dumps(msg_param, ensure_ascii=False),
-        }
+        if chat_id.startswith("group:"):
+            # Group chat
+            url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
+            payload = {
+                "robotCode": self.config.client_id,
+                "openConversationId": chat_id[6:],  # strip the "group:" prefix
+                "msgKey": msg_key,
+                "msgParam": json.dumps(msg_param, ensure_ascii=False),
+            }
+        else:
+            # Private chat
+            url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
+            payload = {
+                "robotCode": self.config.client_id,
+                "userIds": [chat_id],
+                "msgKey": msg_key,
+                "msgParam": json.dumps(msg_param, ensure_ascii=False),
+            }
 
         try:
             resp = await self._http.post(url, json=payload, headers=headers)
@@ -417,7 +440,14 @@
                 f"[Attachment send failed: (unknown)]",
             )
-    async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
+    async def _on_message(
+        self,
+        content: str,
+        sender_id: str,
+        sender_name: str,
+        conversation_type: str | None = None,
+        conversation_id: str | None = None,
+    ) -> None:
         """Handle incoming message (called by NanobotDingTalkHandler).
Delegates to BaseChannel._handle_message() which enforces allow_from @@ -425,13 +455,16 @@ class DingTalkChannel(BaseChannel): """ try: logger.info("DingTalk inbound: {} from {}", content, sender_name) + is_group = conversation_type == "2" and conversation_id + chat_id = f"group:{conversation_id}" if is_group else sender_id await self._handle_message( sender_id=sender_id, - chat_id=sender_id, # For private chat, chat_id == sender_id + chat_id=chat_id, content=str(content), metadata={ "sender_name": sender_name, "platform": "dingtalk", + "conversation_type": conversation_type, }, ) except Exception as e: diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index c868bbf..2ee4f77 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -12,6 +12,7 @@ from loguru import logger from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir from nanobot.config.schema import DiscordConfig from nanobot.utils.helpers import split_message @@ -75,7 +76,7 @@ class DiscordChannel(BaseChannel): self._http = None async def send(self, msg: OutboundMessage) -> None: - """Send a message through Discord REST API.""" + """Send a message through Discord REST API, including file attachments.""" if not self._http: logger.warning("Discord HTTP client not initialized") return @@ -84,15 +85,31 @@ class DiscordChannel(BaseChannel): headers = {"Authorization": f"Bot {self.config.token}"} try: + sent_media = False + failed_media: list[str] = [] + + # Send file attachments first + for media_path in msg.media or []: + if await self._send_file(url, headers, media_path, reply_to=msg.reply_to): + sent_media = True + else: + failed_media.append(Path(media_path).name) + + # Send text content chunks = split_message(msg.content or "", MAX_MESSAGE_LEN) + if not chunks and failed_media and not sent_media: + chunks = split_message( + 
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media), + MAX_MESSAGE_LEN, + ) if not chunks: return for i, chunk in enumerate(chunks): payload: dict[str, Any] = {"content": chunk} - # Only set reply reference on the first chunk - if i == 0 and msg.reply_to: + # Let the first successful attachment carry the reply if present. + if i == 0 and msg.reply_to and not sent_media: payload["message_reference"] = {"message_id": msg.reply_to} payload["allowed_mentions"] = {"replied_user": False} @@ -123,6 +140,54 @@ class DiscordChannel(BaseChannel): await asyncio.sleep(1) return False + async def _send_file( + self, + url: str, + headers: dict[str, str], + file_path: str, + reply_to: str | None = None, + ) -> bool: + """Send a file attachment via Discord REST API using multipart/form-data.""" + path = Path(file_path) + if not path.is_file(): + logger.warning("Discord file not found, skipping: {}", file_path) + return False + + if path.stat().st_size > MAX_ATTACHMENT_BYTES: + logger.warning("Discord file too large (>20MB), skipping: {}", path.name) + return False + + payload_json: dict[str, Any] = {} + if reply_to: + payload_json["message_reference"] = {"message_id": reply_to} + payload_json["allowed_mentions"] = {"replied_user": False} + + for attempt in range(3): + try: + with open(path, "rb") as f: + files = {"files[0]": (path.name, f, "application/octet-stream")} + data: dict[str, Any] = {} + if payload_json: + data["payload_json"] = json.dumps(payload_json) + response = await self._http.post( + url, headers=headers, files=files, data=data + ) + if response.status_code == 429: + resp_data = response.json() + retry_after = float(resp_data.get("retry_after", 1.0)) + logger.warning("Discord rate limited, retrying in {}s", retry_after) + await asyncio.sleep(retry_after) + continue + response.raise_for_status() + logger.info("Discord file sent: {}", path.name) + return True + except Exception as e: + if attempt == 2: + logger.error("Error sending Discord 
file {}: {}", path.name, e) + else: + await asyncio.sleep(1) + return False + async def _gateway_loop(self) -> None: """Main gateway loop: identify, heartbeat, dispatch events.""" if not self._ws: @@ -225,7 +290,7 @@ class DiscordChannel(BaseChannel): content_parts = [content] if content else [] media_paths: list[str] = [] - media_dir = Path.home() / ".nanobot" / "media" + media_dir = get_media_dir("discord") for attachment in payload.get("attachments") or []: url = attachment.get("url") diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 8f69c09..a637025 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -14,6 +14,7 @@ from loguru import logger from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir from nanobot.config.schema import FeishuConfig import importlib.util @@ -244,15 +245,22 @@ class FeishuChannel(BaseChannel): name = "feishu" - def __init__(self, config: FeishuConfig, bus: MessageBus): + def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""): super().__init__(config, bus) self.config: FeishuConfig = config + self.groq_api_key = groq_api_key self._client: Any = None self._ws_client: Any = None self._ws_thread: threading.Thread | None = None self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._loop: asyncio.AbstractEventLoop | None = None + @staticmethod + def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any: + """Register an event handler only when the SDK supports it.""" + method = getattr(builder, method_name, None) + return method(handler) if callable(method) else builder + async def start(self) -> None: """Start the Feishu bot with WebSocket long connection.""" if not FEISHU_AVAILABLE: @@ -273,14 +281,24 @@ class FeishuChannel(BaseChannel): .app_secret(self.config.app_secret) \ 
.log_level(lark.LogLevel.INFO) \ .build() - - # Create event handler (only register message receive, ignore other events) - event_handler = lark.EventDispatcherHandler.builder( + builder = lark.EventDispatcherHandler.builder( self.config.encrypt_key or "", self.config.verification_token or "", ).register_p2_im_message_receive_v1( self._on_message_sync - ).build() + ) + builder = self._register_optional_event( + builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created + ) + builder = self._register_optional_event( + builder, "register_p2_im_message_message_read_v1", self._on_message_read + ) + builder = self._register_optional_event( + builder, + "register_p2_im_chat_access_event_bot_p2p_chat_entered_v1", + self._on_bot_p2p_chat_entered, + ) + event_handler = builder.build() # Create WebSocket client for long connection self._ws_client = lark.ws.Client( @@ -715,8 +733,7 @@ class FeishuChannel(BaseChannel): (file_path, content_text) - file_path is None if download failed """ loop = asyncio.get_running_loop() - media_dir = Path.home() / ".nanobot" / "media" - media_dir.mkdir(parents=True, exist_ok=True) + media_dir = get_media_dir("feishu") data, filename = None, None @@ -841,7 +858,7 @@ class FeishuChannel(BaseChannel): except Exception as e: logger.error("Error sending Feishu message: {}", e) - def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None: + def _on_message_sync(self, data: Any) -> None: """ Sync handler for incoming messages (called from WebSocket thread). Schedules async handling in the main event loop. 
@@ -849,7 +866,7 @@
         if self._loop and self._loop.is_running():
             asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
 
-    async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
+    async def _on_message(self, data: Any) -> None:
         """Handle incoming message from Feishu."""
         try:
             event = data.event
@@ -909,6 +926,18 @@
                 file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
                 if file_path:
                     media_paths.append(file_path)
+
+                # Transcribe audio using Groq Whisper
+                if msg_type == "audio" and file_path and self.groq_api_key:
+                    try:
+                        from nanobot.providers.transcription import GroqTranscriptionProvider
+                        transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
+                        transcription = await transcriber.transcribe(file_path)
+                        if transcription:
+                            content_text = f"[transcription: {transcription}]"
+                    except Exception as e:
+                        logger.warning("Failed to transcribe audio: {}", e)
+
                 content_parts.append(content_text)
             elif msg_type in ("share_chat", "share_user", "interactive",
                               "share_calendar_event", "system", "merge_forward"):
@@ -941,3 +970,15 @@
 
         except Exception as e:
             logger.error("Error processing Feishu message: {}", e)
+
+    def _on_reaction_created(self, data: Any) -> None:
+        """Ignore reaction events so they do not generate SDK noise."""
+        pass
+
+    def _on_message_read(self, data: Any) -> None:
+        """Ignore read events so they do not generate SDK noise."""
+        pass
+
+    def _on_bot_p2p_chat_entered(self, data: Any) -> None:
+        """Ignore p2p-enter events when a user opens a bot chat."""
+        logger.debug("Bot entered p2p chat (user opened chat window)")
diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py
index 7d7d110..51539dd 100644
--- a/nanobot/channels/manager.py
+++ b/nanobot/channels/manager.py
@@ -74,7 +74,8 @@
         try:
             from nanobot.channels.feishu import FeishuChannel
                 self.channels["feishu"] = FeishuChannel(
-                    self.config.channels.feishu, self.bus
+                    self.config.channels.feishu, self.bus,
+                    groq_api_key=self.config.providers.groq.api_key,
                 )
                 logger.info("Feishu channel enabled")
             except ImportError as e:
diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py
index 4967ac1..63cb0ca 100644
--- a/nanobot/channels/matrix.py
+++ b/nanobot/channels/matrix.py
@@ -38,7 +38,7 @@ except ImportError as e:
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.channels.base import BaseChannel
-from nanobot.config.loader import get_data_dir
+from nanobot.config.paths import get_data_dir, get_media_dir
 from nanobot.utils.helpers import safe_filename
 
 TYPING_NOTICE_TIMEOUT_MS = 30_000
@@ -490,9 +490,7 @@ class MatrixChannel(BaseChannel):
         return False
 
     def _media_dir(self) -> Path:
-        d = get_data_dir() / "media" / "matrix"
-        d.mkdir(parents=True, exist_ok=True)
-        return d
+        return get_media_dir("matrix")
 
     @staticmethod
     def _event_source_content(event: RoomMessage) -> dict[str, Any]:
diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py
index e762dfd..09e31c3 100644
--- a/nanobot/channels/mochat.py
+++ b/nanobot/channels/mochat.py
@@ -15,8 +15,8 @@ from loguru import logger
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_runtime_subdir
 from nanobot.config.schema import MochatConfig
-from nanobot.utils.helpers import get_data_path
 
 try:
     import socketio
@@ -224,7 +224,7 @@ class MochatChannel(BaseChannel):
         self._socket: Any = None
         self._ws_connected = self._ws_ready = False
-        self._state_dir = get_data_path() / "mochat"
+        self._state_dir = get_runtime_subdir("mochat")
         self._cursor_path = self._state_dir / "session_cursors.json"
         self._session_cursor: dict[str, int] = {}
         self._cursor_save_task: asyncio.Task | None = None
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
index 6c58049..4809fd3 100644
--- a/nanobot/channels/qq.py
+++ b/nanobot/channels/qq.py
@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
 
 try:
     import botpy
-    from botpy.message import C2CMessage
+    from botpy.message import C2CMessage, GroupMessage
     QQ_AVAILABLE = True
 except ImportError:
     QQ_AVAILABLE = False
     botpy = None
     C2CMessage = None
+    GroupMessage = None
 
 if TYPE_CHECKING:
-    from botpy.message import C2CMessage
+    from botpy.message import C2CMessage, GroupMessage
 
 
 def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
@@ -38,10 +39,13 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
             logger.info("QQ bot ready: {}", self.robot.name)
 
         async def on_c2c_message_create(self, message: "C2CMessage"):
-            await channel._on_message(message)
+            await channel._on_message(message, is_group=False)
+
+        async def on_group_at_message_create(self, message: "GroupMessage"):
+            await channel._on_message(message, is_group=True)
 
         async def on_direct_message_create(self, message):
-            await channel._on_message(message)
+            await channel._on_message(message, is_group=False)
 
     return _Bot
@@ -57,6 +61,7 @@ class QQChannel(BaseChannel):
         self._client: "botpy.Client | None" = None
         self._processed_ids: deque = deque(maxlen=1000)
         self._msg_seq: int = 1  # message sequence number, avoids de-duplication by the QQ API
+        self._chat_type_cache: dict[str, str] = {}
 
     async def start(self) -> None:
         """Start the QQ bot."""
@@ -71,8 +76,7 @@ class QQChannel(BaseChannel):
         self._running = True
         BotClass = _make_bot_class(self)
         self._client = BotClass()
-
-        logger.info("QQ bot started (C2C private message)")
+        logger.info("QQ bot started (C2C & Group supported)")
         await self._run_bot()
 
     async def _run_bot(self) -> None:
@@ -101,20 +105,31 @@ class QQChannel(BaseChannel):
         if not self._client:
             logger.warning("QQ client not initialized")
             return
+
        try:
             msg_id = msg.metadata.get("message_id")
-            self._msg_seq += 1  # increment the sequence number
-            await self._client.api.post_c2c_message(
-                openid=msg.chat_id,
-                msg_type=0,
-                content=msg.content,
-                msg_id=msg_id,
-                msg_seq=self._msg_seq,  # attach a sequence number to avoid de-duplication
-            )
+            self._msg_seq += 1
+            msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
+            if msg_type == "group":
+                await self._client.api.post_group_message(
+                    group_openid=msg.chat_id,
+                    msg_type=0,
+                    content=msg.content,
+                    msg_id=msg_id,
+                    msg_seq=self._msg_seq,
+                )
+            else:
+                await self._client.api.post_c2c_message(
+                    openid=msg.chat_id,
+                    msg_type=0,
+                    content=msg.content,
+                    msg_id=msg_id,
+                    msg_seq=self._msg_seq,
+                )
         except Exception as e:
             logger.error("Error sending QQ message: {}", e)
 
-    async def _on_message(self, data: "C2CMessage") -> None:
+    async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
         """Handle incoming message from QQ."""
         try:
             # Dedup by message ID
@@ -122,18 +137,24 @@
                 return
             self._processed_ids.append(data.id)
 
-            author = data.author
-            user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
             content = (data.content or "").strip()
             if not content:
                 return
 
+            if is_group:
+                chat_id = data.group_openid
+                user_id = data.author.member_openid
+                self._chat_type_cache[chat_id] = "group"
+            else:
+                chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
+                user_id = chat_id
+                self._chat_type_cache[chat_id] = "c2c"
+
             await self._handle_message(
                 sender_id=user_id,
-                chat_id=user_id,
+                chat_id=chat_id,
                 content=content,
                 metadata={"message_id": data.id},
             )
         except Exception:
             logger.exception("Error handling QQ message")
-
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
index afd1d2d..a4e7324 100644
--- a/nanobot/channels/slack.py
+++ b/nanobot/channels/slack.py
@@ -82,13 +82,14 @@ class SlackChannel(BaseChannel):
         thread_ts = slack_meta.get("thread_ts")
         channel_type = slack_meta.get("channel_type")
         # Only reply in thread for channel/group messages; DMs don't use threads
         use_thread = thread_ts and channel_type != "im"
         thread_ts_param = thread_ts if use_thread else None
 
-        if msg.content:
+        # Slack rejects empty text payloads. Keep media-only messages media-only,
+        # but send a single blank message when the bot has no text or files to send.
+        if msg.content or not (msg.media or []):
             await self._web_client.chat_postMessage(
                 channel=msg.chat_id,
-                text=self._to_mrkdwn(msg.content),
+                text=self._to_mrkdwn(msg.content) if msg.content else " ",
                 thread_ts=thread_ts_param,
             )
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index aaa24e7..ecb1440 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -15,6 +15,7 @@ from telegram.request import HTTPXRequest
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir
 from nanobot.config.schema import TelegramConfig
 from nanobot.utils.helpers import split_message
 
@@ -177,6 +178,26 @@ class TelegramChannel(BaseChannel):
         self._typing_tasks: dict[str, asyncio.Task] = {}  # chat_id -> typing loop task
         self._media_group_buffers: dict[str, dict] = {}
         self._media_group_tasks: dict[str, asyncio.Task] = {}
+        self._message_threads: dict[tuple[str, int], int] = {}
+
+    def is_allowed(self, sender_id: str) -> bool:
+        """Preserve Telegram's legacy id|username allowlist matching."""
+        if super().is_allowed(sender_id):
+            return True
+
+        allow_list = getattr(self.config, "allow_from", [])
+        if not allow_list or "*" in allow_list:
+            return False
+
+        sender_str = str(sender_id)
+        if sender_str.count("|") != 1:
+            return False
+
+        sid, username = sender_str.split("|", 1)
+        if not sid.isdigit() or not username:
+            return False
+
+        return sid in allow_list or username in allow_list
 
     async def start(self) -> None:
         """Start the Telegram bot with long polling."""
@@ -187,16 +208,21 @@
         self._running = True
 
         # Build the application with larger connection pool to avoid pool-timeout on long runs
-        req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
+        req = HTTPXRequest(
+            connection_pool_size=16,
+            pool_timeout=5.0,
+            connect_timeout=30.0,
+            read_timeout=30.0,
+            proxy=self.config.proxy if self.config.proxy else None,
+        )
         builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
-        if self.config.proxy:
-            builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
         self._app = builder.build()
         self._app.add_error_handler(self._on_error)
 
         # Add command handlers
         self._app.add_handler(CommandHandler("start", self._on_start))
         self._app.add_handler(CommandHandler("new", self._forward_command))
+        self._app.add_handler(CommandHandler("stop", self._forward_command))
         self._app.add_handler(CommandHandler("help", self._on_help))
 
         # Add message handler for text, photos, voice, documents
@@ -281,10 +307,16 @@ class TelegramChannel(BaseChannel):
         except ValueError:
             logger.error("Invalid chat_id: {}", msg.chat_id)
             return
+        reply_to_message_id = msg.metadata.get("message_id")
+        message_thread_id = msg.metadata.get("message_thread_id")
+        if message_thread_id is None and reply_to_message_id is not None:
+            message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
+        thread_kwargs = {}
+        if message_thread_id is not None:
+            thread_kwargs["message_thread_id"] = message_thread_id
 
         reply_params = None
         if self.config.reply_to_message:
-            reply_to_message_id = msg.metadata.get("message_id")
             if reply_to_message_id:
                 reply_params = ReplyParameters(
                     message_id=reply_to_message_id,
@@ -305,7 +337,8 @@
                     await sender(
                         chat_id=chat_id,
                         **{param: f},
-                        reply_parameters=reply_params
+                        reply_parameters=reply_params,
+                        **thread_kwargs,
                     )
             except Exception as e:
                 filename = media_path.rsplit("/", 1)[-1]
@@ -313,7 +346,8 @@
                 await self._app.bot.send_message(
                     chat_id=chat_id,
                     text=f"[Failed to send: {filename}]",
-                    reply_parameters=reply_params
+                    reply_parameters=reply_params,
+                    **thread_kwargs,
                 )
 
         # Send text content
@@ -323,28 +357,44 @@
             for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
                 # Final response: simulate streaming via draft, then persist
                 if not is_progress:
-                    await self._send_with_streaming(chat_id, chunk, reply_params)
+                    await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
                 else:
-                    await self._send_text(chat_id, chunk, reply_params)
+                    await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
 
-    async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None:
+    async def _send_text(
+        self,
+        chat_id: int,
+        text: str,
+        reply_params=None,
+        thread_kwargs: dict | None = None,
+    ) -> None:
         """Send a plain text message with HTML fallback."""
         try:
             html = _markdown_to_telegram_html(text)
             await self._app.bot.send_message(
                 chat_id=chat_id,
                 text=html,
                 parse_mode="HTML",
                 reply_parameters=reply_params,
+                **(thread_kwargs or {}),
             )
         except Exception as e:
             logger.warning("HTML parse failed, falling back to plain text: {}", e)
             try:
                 await self._app.bot.send_message(
-                    chat_id=chat_id, text=text, reply_parameters=reply_params,
+                    chat_id=chat_id,
+                    text=text,
+                    reply_parameters=reply_params,
+                    **(thread_kwargs or {}),
                 )
             except Exception as e2:
                 logger.error("Error sending Telegram message: {}", e2)
 
-    async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None:
+    async def _send_with_streaming(
+        self,
+        chat_id: int,
+        text: str,
+        reply_params=None,
+        thread_kwargs: dict | None = None,
+    ) -> None:
         """Simulate streaming via send_message_draft, then persist with send_message."""
         draft_id = int(time.time() * 1000) % (2**31)
         try:
@@ -360,7 +410,7 @@
                 await asyncio.sleep(0.15)
             except Exception:
                 pass
-        await self._send_text(chat_id, text, reply_params)
+        await self._send_text(chat_id, text, reply_params, thread_kwargs)
 
     async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Handle /start command."""
@@ -391,14 +441,50 @@
         sid = str(user.id)
         return f"{sid}|{user.username}" if user.username else sid
 
+    @staticmethod
+    def _derive_topic_session_key(message) -> str | None:
+        """Derive topic-scoped session key for non-private Telegram chats."""
+        message_thread_id = getattr(message, "message_thread_id", None)
+        if message.chat.type == "private" or message_thread_id is None:
+            return None
+        return f"telegram:{message.chat_id}:topic:{message_thread_id}"
+
+    @staticmethod
+    def _build_message_metadata(message, user) -> dict:
+        """Build common Telegram inbound metadata payload."""
+        return {
+            "message_id": message.message_id,
+            "user_id": user.id,
+            "username": user.username,
+            "first_name": user.first_name,
+            "is_group": message.chat.type != "private",
+            "message_thread_id": getattr(message, "message_thread_id", None),
+            "is_forum": bool(getattr(message.chat, "is_forum", False)),
+        }
+
+    def _remember_thread_context(self, message) -> None:
+        """Cache topic thread id by chat/message id for follow-up replies."""
+        message_thread_id = getattr(message, "message_thread_id", None)
+        if message_thread_id is None:
+            return
+        key = (str(message.chat_id), message.message_id)
+        self._message_threads[key] = message_thread_id
+        if len(self._message_threads) > 1000:
+            self._message_threads.pop(next(iter(self._message_threads)))
+
     async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Forward slash commands to the bus for unified handling in AgentLoop."""
         if not update.message or not update.effective_user:
             return
+        message = update.message
+        user = update.effective_user
+        self._remember_thread_context(message)
         await self._handle_message(
-            sender_id=self._sender_id(update.effective_user),
-            chat_id=str(update.message.chat_id),
-            content=update.message.text,
+            sender_id=self._sender_id(user),
+            chat_id=str(message.chat_id),
+            content=message.text,
+            metadata=self._build_message_metadata(message, user),
+            session_key=self._derive_topic_session_key(message),
         )
 
     async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -410,6 +496,7 @@
         user = update.effective_user
         chat_id = message.chat_id
         sender_id = self._sender_id(user)
+        self._remember_thread_context(message)
 
         # Store chat_id for replies
         self._chat_ids[sender_id] = chat_id
@@ -445,12 +532,12 @@
         if media_file and self._app:
             try:
                 file = await self._app.bot.get_file(media_file.file_id)
-                ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
-
-                # Save to workspace/media/
-                from pathlib import Path
-                media_dir = Path.home() / ".nanobot" / "media"
-                media_dir.mkdir(parents=True, exist_ok=True)
+                ext = self._get_extension(
+                    media_type,
+                    getattr(media_file, 'mime_type', None),
+                    getattr(media_file, 'file_name', None),
+                )
+                media_dir = get_media_dir("telegram")
                 file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
                 await file.download_to_drive(str(file_path))
 
@@ -480,6 +567,8 @@
         logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
 
         str_chat_id = str(chat_id)
+        metadata = self._build_message_metadata(message, user)
+        session_key = self._derive_topic_session_key(message)
 
         # Telegram media groups: buffer briefly, forward as one aggregated turn.
         if media_group_id := getattr(message, "media_group_id", None):
@@ -488,11 +577,8 @@
                 self._media_group_buffers[key] = {
                     "sender_id": sender_id, "chat_id": str_chat_id,
                     "contents": [], "media": [],
-                    "metadata": {
-                        "message_id": message.message_id, "user_id": user.id,
-                        "username": user.username, "first_name": user.first_name,
-                        "is_group": message.chat.type != "private",
-                    },
+                    "metadata": metadata,
+                    "session_key": session_key,
                 }
                 self._start_typing(str_chat_id)
             buf = self._media_group_buffers[key]
@@ -512,13 +598,8 @@
             chat_id=str_chat_id,
             content=content,
             media=media_paths,
-            metadata={
-                "message_id": message.message_id,
-                "user_id": user.id,
-                "username": user.username,
-                "first_name": user.first_name,
-                "is_group": message.chat.type != "private"
-            }
+            metadata=metadata,
+            session_key=session_key,
         )
 
     async def _flush_media_group(self, key: str) -> None:
@@ -532,6 +613,7 @@
                 sender_id=buf["sender_id"], chat_id=buf["chat_id"],
                 content=content, media=list(dict.fromkeys(buf["media"])),
                 metadata=buf["metadata"],
+                session_key=buf.get("session_key"),
             )
         finally:
             self._media_group_tasks.pop(key, None)
@@ -563,8 +645,13 @@
         """Log polling / handler errors instead of silently swallowing them."""
         logger.error("Telegram error: {}", context.error)
 
-    def _get_extension(self, media_type: str, mime_type: str | None) -> str:
-        """Get file extension based on media type."""
+    def _get_extension(
+        self,
+        media_type: str,
+        mime_type: str | None,
+        filename: str | None = None,
+    ) -> str:
+        """Get file extension based on media type or original filename."""
         if mime_type:
             ext_map = {
                 "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -574,4 +661,12 @@
                 return ext_map[mime_type]
 
         type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
-        return type_map.get(media_type, "")
+        if ext := type_map.get(media_type, ""):
+            return ext
+
+        if filename:
+            from pathlib import Path
+
+            return "".join(Path(filename).suffixes)
+
+        return ""
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index 0d1ec7e..1307716 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import json
+import mimetypes
 from collections import OrderedDict
 
 from loguru import logger
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
             logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
             content = "[Voice Message: Transcription not available for WhatsApp yet]"
 
+        # Extract media paths (images/documents/videos downloaded by the bridge)
+        media_paths = data.get("media") or []
+
+        # Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
+        if media_paths:
+            for p in media_paths:
+                mime, _ = mimetypes.guess_type(p)
+                media_type = "image" if mime and mime.startswith("image/") else "file"
+                media_tag = f"[{media_type}: {p}]"
+                content = f"{content}\n{media_tag}" if content else media_tag
+
         await self._handle_message(
             sender_id=sender_id,
             chat_id=sender,  # Use full LID for replies
             content=content,
+            media=media_paths,
             metadata={
                 "message_id": message_id,
                 "timestamp": data.get("timestamp"),
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index 5987796..d03ef93 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -29,6 +29,7 @@ from rich.table import Table
 from rich.text import Text
 
 from nanobot import __logo__, __version__
+from nanobot.config.paths import get_workspace_path
 from nanobot.config.schema import Config
 from nanobot.utils.helpers import sync_workspace_templates
 
@@ -98,7 +99,9 @@ def _init_prompt_session() -> None:
     except Exception:
         pass
 
-    history_file = Path.home() / ".nanobot" / "history" / "cli_history"
+    from nanobot.config.paths import get_cli_history_path
+
+    history_file = get_cli_history_path()
     history_file.parent.mkdir(parents=True, exist_ok=True)
 
     _PROMPT_SESSION = PromptSession(
@@ -169,7 +172,6 @@ def onboard():
     """Initialize nanobot configuration and workspace."""
     from nanobot.config.loader import get_config_path, load_config, save_config
     from nanobot.config.schema import Config
-    from nanobot.utils.helpers import get_workspace_path
 
     config_path = get_config_path()
 
@@ -212,6 +214,7 @@ def onboard():
 def _make_provider(config: Config):
     """Create the appropriate LLM provider from config."""
     from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+    from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 
     model = config.agents.defaults.model
     provider_name = config.get_provider_name(model)
@@ -230,6 +233,20 @@ def _make_provider(config: Config):
             default_model=model,
         )
 
+    # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
+    if provider_name == "azure_openai":
+        if not p or not p.api_key or not p.api_base:
+            console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
+            console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
+            console.print("Use the model field to specify the deployment name.")
+            raise typer.Exit(1)
+
+        return AzureOpenAIProvider(
+            api_key=p.api_key,
+            api_base=p.api_base,
+            default_model=model,
+        )
+
     from nanobot.providers.litellm_provider import LiteLLMProvider
     from nanobot.providers.registry import find_by_name
     spec = find_by_name(provider_name)
@@ -267,13 +284,24 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
 def gateway(
     port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
     workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
-    config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
+    config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
 ):
     """Start the nanobot gateway."""
+    # Set config path if provided (must be done before any imports that use get_data_dir)
+    if config:
+        from nanobot.config.loader import set_config_path
+        config_path = Path(config).expanduser().resolve()
+        if not config_path.exists():
+            console.print(f"[red]Error: Config file not found: {config_path}[/red]")
+            raise typer.Exit(1)
+        set_config_path(config_path)
+        console.print(f"[dim]Using config: {config_path}[/dim]")
+
     from nanobot.agent.loop import AgentLoop
     from nanobot.bus.queue import MessageBus
     from nanobot.channels.manager import ChannelManager
+    from nanobot.config.paths import get_cron_dir
     from nanobot.cron.service import CronService
     from nanobot.cron.types import CronJob
     from nanobot.heartbeat.service import HeartbeatService
@@ -292,8 +320,7 @@ def gateway(
     session_manager = SessionManager(config.workspace_path)
 
     # Create cron service first (callback set after agent creation)
-    # Use workspace path for per-instance cron store
-    cron_store_path = config.workspace_path / "cron" / "jobs.json"
+    cron_store_path = get_cron_dir() / "jobs.json"
     cron = CronService(cron_store_path)
 
     # Create agent with cron service
@@ -464,7 +491,7 @@ def agent(
 
     from nanobot.agent.loop import AgentLoop
     from nanobot.bus.queue import MessageBus
-    from nanobot.config.loader import get_data_dir
+    from nanobot.config.paths import get_cron_dir
     from nanobot.cron.service import CronService
 
     config = _load_runtime_config(config, workspace)
@@ -474,7 +501,7 @@
     provider = _make_provider(config)
 
     # Create cron service for tool usage (no callback needed for CLI unless running)
-    cron_store_path = get_data_dir() / "cron" / "jobs.json"
+    cron_store_path = get_cron_dir() / "jobs.json"
     cron = CronService(cron_store_path)
 
     if logs:
@@ -740,7 +767,9 @@ def _get_bridge_dir() -> Path:
     import subprocess
 
     # User's bridge location
-    user_bridge = Path.home() / ".nanobot" / "bridge"
+    from nanobot.config.paths import get_bridge_install_dir
+
+    user_bridge = get_bridge_install_dir()
 
     # Check if already built
     if (user_bridge / "dist" / "index.js").exists():
@@ -798,6 +827,7 @@ def channels_login():
     import subprocess
 
     from nanobot.config.loader import load_config
+    from nanobot.config.paths import get_runtime_subdir
 
     config = load_config()
     bridge_dir = _get_bridge_dir()
@@ -808,6 +838,7 @@
     env = {**os.environ}
     if config.channels.whatsapp.bridge_token:
         env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
+        env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
 
     try:
         subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
diff --git a/nanobot/config/__init__.py b/nanobot/config/__init__.py
index 6c59668..e2c24f8 100644
--- a/nanobot/config/__init__.py
+++ b/nanobot/config/__init__.py
@@ -1,6 +1,30 @@
 """Configuration module for nanobot."""
 
 from nanobot.config.loader import get_config_path, load_config
+from nanobot.config.paths import (
+    get_bridge_install_dir,
+    get_cli_history_path,
+    get_cron_dir,
+    get_data_dir,
+    get_legacy_sessions_dir,
+    get_logs_dir,
+    get_media_dir,
+    get_runtime_subdir,
+    get_workspace_path,
+)
 from nanobot.config.schema import Config
 
-__all__ = ["Config", "load_config", "get_config_path"]
+__all__ = [
+    "Config",
+    "load_config",
+    "get_config_path",
+    "get_data_dir",
+    "get_runtime_subdir",
+    "get_media_dir",
+    "get_cron_dir",
+    "get_logs_dir",
+    "get_workspace_path",
+    "get_cli_history_path",
+    "get_bridge_install_dir",
+    "get_legacy_sessions_dir",
+]
diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py
index c789efd..7d309e5 100644
--- a/nanobot/config/loader.py
+++ b/nanobot/config/loader.py
@@ -6,17 +6,23 @@ from pathlib import Path
 
 from nanobot.config.schema import Config
 
+# Global variable to store current config path (for multi-instance support)
+_current_config_path: Path | None = None
+
+
+def set_config_path(path: Path) -> None:
+    """Set the current config path (used to derive the data directory)."""
+    global _current_config_path
+    _current_config_path = path
+
+
 def get_config_path() -> Path:
-    """Get the default configuration file path."""
+    """Get the configuration file path."""
+    if _current_config_path:
+        return _current_config_path
     return Path.home() / ".nanobot" / "config.json"
 
 
-def get_data_dir() -> Path:
-    """Get the nanobot data directory."""
-    from nanobot.utils.helpers import get_data_path
-    return get_data_path()
-
-
 def load_config(config_path: Path | None = None) -> Config:
     """
     Load configuration from file or create default.
diff --git a/nanobot/config/paths.py b/nanobot/config/paths.py
new file mode 100644
index 0000000..f4dfbd9
--- /dev/null
+++ b/nanobot/config/paths.py
@@ -0,0 +1,55 @@
+"""Runtime path helpers derived from the active config context."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from nanobot.config.loader import get_config_path
+from nanobot.utils.helpers import ensure_dir
+
+
+def get_data_dir() -> Path:
+    """Return the instance-level runtime data directory."""
+    return ensure_dir(get_config_path().parent)
+
+
+def get_runtime_subdir(name: str) -> Path:
+    """Return a named runtime subdirectory under the instance data dir."""
+    return ensure_dir(get_data_dir() / name)
+
+
+def get_media_dir(channel: str | None = None) -> Path:
+    """Return the media directory, optionally namespaced per channel."""
+    base = get_runtime_subdir("media")
+    return ensure_dir(base / channel) if channel else base
+
+
+def get_cron_dir() -> Path:
+    """Return the cron storage directory."""
+    return get_runtime_subdir("cron")
+
+
+def get_logs_dir() -> Path:
+    """Return the logs directory."""
+    return get_runtime_subdir("logs")
+
+
+def get_workspace_path(workspace: str | None = None) -> Path:
+    """Resolve and ensure the agent workspace path."""
+    path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
+    return ensure_dir(path)
+
+
+def get_cli_history_path() -> Path:
+    """Return the shared CLI history file path."""
+    return Path.home() / ".nanobot" / "history" / "cli_history"
+
+
+def get_bridge_install_dir() -> Path:
+    """Return the shared WhatsApp bridge installation directory."""
+    return Path.home() / ".nanobot" / "bridge"
+
+
+def get_legacy_sessions_dir() -> Path:
+    """Return the legacy global session directory used for migration fallback."""
+    return Path.home() / ".nanobot" / "sessions"
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index 2073eeb..803cb61 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -251,6 +251,7 @@ class ProvidersConfig(Base):
     """Configuration for LLM providers."""
 
     custom: ProviderConfig = Field(default_factory=ProviderConfig)  # Any OpenAI-compatible endpoint
+    azure_openai: ProviderConfig = Field(default_factory=ProviderConfig)  # Azure OpenAI (model = deployment name)
    anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
     openai: ProviderConfig = Field(default_factory=ProviderConfig)
     openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py
index b2bb2b9..5bd06f9 100644
--- a/nanobot/providers/__init__.py
+++ b/nanobot/providers/__init__.py
@@ -3,5 +3,6 @@
 from nanobot.providers.base import LLMProvider, LLMResponse
 from nanobot.providers.litellm_provider import LiteLLMProvider
 from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 
-__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
+__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py
new file mode 100644
index 0000000..bd79b00
--- /dev/null
+++ b/nanobot/providers/azure_openai_provider.py
@@ -0,0 +1,210 @@
+"""Azure OpenAI provider implementation with API version 2024-10-21."""
+
+from __future__ import annotations
+
+import uuid
+from typing import Any
+from urllib.parse import urljoin
+
+import httpx
+import json_repair
+
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
+
+_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+
+
+class AzureOpenAIProvider(LLMProvider):
+    """
+    Azure OpenAI provider with API version 2024-10-21 compliance.
+
+    Features:
+    - Hardcoded API version 2024-10-21
+    - Uses model field as Azure deployment name in URL path
+    - Uses api-key header instead of Authorization Bearer
+    - Uses max_completion_tokens instead of max_tokens
+    - Direct HTTP calls, bypasses LiteLLM
+    """
+
+    def __init__(
+        self,
+        api_key: str = "",
+        api_base: str = "",
+        default_model: str = "gpt-5.2-chat",
+    ):
+        super().__init__(api_key, api_base)
+        self.default_model = default_model
+        self.api_version = "2024-10-21"
+
+        # Validate required parameters
+        if not api_key:
+            raise ValueError("Azure OpenAI api_key is required")
+        if not api_base:
+            raise ValueError("Azure OpenAI api_base is required")
+
+        # Ensure api_base ends with /
+        if not api_base.endswith('/'):
+            api_base += '/'
+        self.api_base = api_base
+
+    def _build_chat_url(self, deployment_name: str) -> str:
+        """Build the Azure OpenAI chat completions URL."""
+        # Azure OpenAI URL format:
+        # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
+        base_url = self.api_base
+        if not base_url.endswith('/'):
+            base_url += '/'
+
+        url = urljoin(
+            base_url,
+            f"openai/deployments/{deployment_name}/chat/completions"
+        )
+        return f"{url}?api-version={self.api_version}"
+
+    def _build_headers(self) -> dict[str, str]:
+        """Build headers for Azure OpenAI API with api-key header."""
+        return {
+            "Content-Type": "application/json",
+            "api-key": self.api_key,  # Azure OpenAI uses api-key header, not Authorization
+            "x-session-affinity": uuid.uuid4().hex,  # For cache locality
+        }
+
+    @staticmethod
+    def _supports_temperature(
+        deployment_name: str,
+        reasoning_effort: str | None = None,
+    ) -> bool:
+        """Return True when temperature is likely supported for this deployment."""
+        if reasoning_effort:
+            return False
+        name = deployment_name.lower()
+        return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
+
+    def _prepare_request_payload(
+        self,
+        deployment_name: str,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        reasoning_effort: str | None = None,
+    ) -> dict[str, Any]:
+        """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
+        payload: dict[str, Any] = {
+            "messages": self._sanitize_request_messages(
+                self._sanitize_empty_content(messages),
+                _AZURE_MSG_KEYS,
+            ),
+            "max_completion_tokens": max(1, max_tokens),  # Azure API 2024-10-21 uses max_completion_tokens
+        }
+
+        if self._supports_temperature(deployment_name, reasoning_effort):
+            payload["temperature"] = temperature
+
+        if reasoning_effort:
+            payload["reasoning_effort"] = reasoning_effort
+
+        if tools:
+            payload["tools"] = tools
+            payload["tool_choice"] = "auto"
+
+        return payload
+
+    async def chat(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None = None,
+        model: str | None = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        reasoning_effort: str | None = None,
+    ) -> LLMResponse:
+        """
+        Send a chat completion request to Azure OpenAI.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'.
+            tools: Optional list of tool definitions in OpenAI format.
+            model: Model identifier (used as deployment name).
+            max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
+            temperature: Sampling temperature.
+            reasoning_effort: Optional reasoning effort parameter.
+
+        Returns:
+            LLMResponse with content and/or tool calls.
+        """
+        deployment_name = model or self.default_model
+        url = self._build_chat_url(deployment_name)
+        headers = self._build_headers()
+        payload = self._prepare_request_payload(
+            deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
+        )
+
+        try:
+            async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
+                response = await client.post(url, headers=headers, json=payload)
+                if response.status_code != 200:
+                    return LLMResponse(
+                        content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
+                        finish_reason="error",
+                    )
+
+                response_data = response.json()
+                return self._parse_response(response_data)
+
+        except Exception as e:
+            return LLMResponse(
+                content=f"Error calling Azure OpenAI: {repr(e)}",
+                finish_reason="error",
+            )
+
+    def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
+        """Parse Azure OpenAI response into our standard format."""
+        try:
+            choice = response["choices"][0]
+            message = choice["message"]
+
+            tool_calls = []
+            if message.get("tool_calls"):
+                for tc in message["tool_calls"]:
+                    # Parse arguments from JSON string if needed
+                    args = tc["function"]["arguments"]
+                    if isinstance(args, str):
+                        args = json_repair.loads(args)
+
+                    tool_calls.append(
+                        ToolCallRequest(
+                            id=tc["id"],
+                            name=tc["function"]["name"],
+                            arguments=args,
+                        )
+                    )
+
+            usage = {}
+            if response.get("usage"):
+                usage_data = response["usage"]
+                usage = {
+                    "prompt_tokens": usage_data.get("prompt_tokens", 0),
+                    "completion_tokens": usage_data.get("completion_tokens", 0),
+                    "total_tokens": usage_data.get("total_tokens", 0),
+                }
+
+            reasoning_content = message.get("reasoning_content") or None
+
+            return LLMResponse(
+                content=message.get("content"),
+                tool_calls=tool_calls,
+                finish_reason=choice.get("finish_reason", "stop"),
+                usage=usage,
+                reasoning_content=reasoning_content,
+            )
+
+        except (KeyError, IndexError) as e:
+            return LLMResponse(
+                content=f"Error parsing Azure OpenAI response: {str(e)}",
+                finish_reason="error",
+            )
+
+    def get_default_model(self) -> str:
+        """Get the default model (also used as default deployment name)."""
+        return self.default_model
\ No newline at end of file
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 55bd805..0f73544 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -87,6 +87,20 @@ class LLMProvider(ABC):
             result.append(msg)
         return result
 
+    @staticmethod
+    def _sanitize_request_messages(
+        messages: list[dict[str, Any]],
+        allowed_keys: frozenset[str],
+    ) -> list[dict[str, Any]]:
+        """Keep only provider-safe message keys and normalize assistant content."""
+        sanitized = []
+        for msg in messages:
+            clean = {k: v for k, v in msg.items() if k in allowed_keys}
+            if clean.get("role") == "assistant" and "content" not in clean:
+                clean["content"] = None
+            sanitized.append(clean)
+        return sanitized
+
     @abstractmethod
     async def chat(
         self,
diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py
index 620424e..cb67635 100644
--- a/nanobot/providers/litellm_provider.py
+++ b/nanobot/providers/litellm_provider.py
@@ -1,5 +1,6 @@
 """LiteLLM provider implementation for multi-provider support."""
 
+import hashlib
 import os
 import secrets
 import string
@@ -166,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
             return _ANTHROPIC_EXTRA_KEYS
         return frozenset()
 
+    @staticmethod
+    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
+        """Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
+        if not isinstance(tool_call_id, str):
+            return tool_call_id
+        if len(tool_call_id) == 9 and tool_call_id.isalnum():
+            return tool_call_id
+        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
+
     @staticmethod
     def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
         """Strip non-standard keys and ensure assistant messages have a content key."""
         allowed = _ALLOWED_MSG_KEYS | extra_keys
-        sanitized = []
-        for msg in messages:
-            clean = {k: v for k, v in msg.items() if k in allowed}
-            # Strict providers require "content" even when assistant only has tool_calls
-            if clean.get("role") == "assistant" and "content" not in clean:
-                clean["content"] = None
-            sanitized.append(clean)
+        sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
+        id_map: dict[str, str] = {}
+
+        def map_id(value: Any) -> Any:
+            if not isinstance(value, str):
+                return value
+            return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
+
+        for clean in sanitized:
+            # Keep assistant tool_calls[].id and tool tool_call_id in sync after
+            # shortening, otherwise strict providers reject the broken linkage.
+            if isinstance(clean.get("tool_calls"), list):
+                normalized_tool_calls = []
+                for tc in clean["tool_calls"]:
+                    if not isinstance(tc, dict):
+                        normalized_tool_calls.append(tc)
+                        continue
+                    tc_clean = dict(tc)
+                    tc_clean["id"] = map_id(tc_clean.get("id"))
+                    normalized_tool_calls.append(tc_clean)
+                clean["tool_calls"] = normalized_tool_calls
+
+            if "tool_call_id" in clean and clean["tool_call_id"]:
+                clean["tool_call_id"] = map_id(clean["tool_call_id"])
         return sanitized
 
     async def chat(
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index 59ba31a..3ba1a0e 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -79,6 +79,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         litellm_prefix="",
         is_direct=True,
     ),
+
+    # === Azure OpenAI (direct API calls with API version 2024-10-21) =====
+    ProviderSpec(
+        name="azure_openai",
+        keywords=("azure", "azure-openai"),
+        env_key="",
+        display_name="Azure OpenAI",
+        litellm_prefix="",
+        is_direct=True,
+    ),
 
     # === Gateways (detected by api_key / api_base, not model name) =========
     # Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-" diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index dce4b2e..f0a6484 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -9,6 +9,7 @@ from typing import Any from loguru import logger +from nanobot.config.paths import get_legacy_sessions_dir from nanobot.utils.helpers import ensure_dir, safe_filename @@ -79,7 +80,7 @@ class SessionManager: def __init__(self, workspace: Path): self.workspace = workspace self.sessions_dir = ensure_dir(self.workspace / "sessions") - self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" + self.legacy_sessions_dir = get_legacy_sessions_dir() self._cache: dict[str, Session] = {} def _get_session_path(self, key: str) -> Path: diff --git a/nanobot/utils/__init__.py b/nanobot/utils/__init__.py index 9163e38..46f02ac 100644 --- a/nanobot/utils/__init__.py +++ b/nanobot/utils/__init__.py @@ -1,5 +1,5 @@ """Utility functions for nanobot.""" -from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path +from nanobot.utils.helpers import ensure_dir -__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"] +__all__ = ["ensure_dir"] diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index c57c365..57c60dc 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -24,17 +24,6 @@ def ensure_dir(path: Path) -> Path: return path -def get_data_path() -> Path: - """~/.nanobot data directory.""" - return ensure_dir(Path.home() / ".nanobot") - - -def get_workspace_path(workspace: str | None = None) -> Path: - """Resolve and ensure workspace path. 
Defaults to ~/.nanobot/workspace.""" - path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace" - return ensure_dir(path) - - def timestamp() -> str: """Current ISO timestamp.""" return datetime.now().isoformat() diff --git a/tests/test_azure_openai_provider.py b/tests/test_azure_openai_provider.py new file mode 100644 index 0000000..77f36d4 --- /dev/null +++ b/tests/test_azure_openai_provider.py @@ -0,0 +1,399 @@ +"""Test Azure OpenAI provider implementation (updated for model-based deployment names).""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from nanobot.providers.azure_openai_provider import AzureOpenAIProvider +from nanobot.providers.base import LLMResponse + + +def test_azure_openai_provider_init(): + """Test AzureOpenAIProvider initialization without deployment_name.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o-deployment", + ) + + assert provider.api_key == "test-key" + assert provider.api_base == "https://test-resource.openai.azure.com/" + assert provider.default_model == "gpt-4o-deployment" + assert provider.api_version == "2024-10-21" + + +def test_azure_openai_provider_init_validation(): + """Test AzureOpenAIProvider initialization validation.""" + # Missing api_key + with pytest.raises(ValueError, match="Azure OpenAI api_key is required"): + AzureOpenAIProvider(api_key="", api_base="https://test.com") + + # Missing api_base + with pytest.raises(ValueError, match="Azure OpenAI api_base is required"): + AzureOpenAIProvider(api_key="test", api_base="") + + +def test_build_chat_url(): + """Test Azure OpenAI URL building with different deployment names.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + # Test various deployment names + test_cases = [ + ("gpt-4o-deployment", 
"https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"), + ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"), + ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"), + ] + + for deployment_name, expected_url in test_cases: + url = provider._build_chat_url(deployment_name) + assert url == expected_url + + +def test_build_chat_url_api_base_without_slash(): + """Test URL building when api_base doesn't end with slash.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", # No trailing slash + default_model="gpt-4o", + ) + + url = provider._build_chat_url("test-deployment") + expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21" + assert url == expected + + +def test_build_headers(): + """Test Azure OpenAI header building with api-key authentication.""" + provider = AzureOpenAIProvider( + api_key="test-api-key-123", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + headers = provider._build_headers() + assert headers["Content-Type"] == "application/json" + assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header + assert "x-session-affinity" in headers + + +def test_prepare_request_payload(): + """Test request payload preparation with Azure OpenAI 2024-10-21 compliance.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + messages = [{"role": "user", "content": "Hello"}] + payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8) + + assert payload["messages"] == messages + assert payload["max_completion_tokens"] == 1500 # Azure API 
2024-10-21 uses max_completion_tokens + assert payload["temperature"] == 0.8 + assert "tools" not in payload + + # Test with tools + tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] + payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools) + assert payload_with_tools["tools"] == tools + assert payload_with_tools["tool_choice"] == "auto" + + # Test with reasoning_effort + payload_with_reasoning = provider._prepare_request_payload( + "gpt-5-chat", messages, reasoning_effort="medium" + ) + assert payload_with_reasoning["reasoning_effort"] == "medium" + assert "temperature" not in payload_with_reasoning + + +def test_prepare_request_payload_sanitizes_messages(): + """Test Azure payload strips non-standard message keys before sending.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + messages = [ + { + "role": "assistant", + "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + "reasoning_content": "hidden chain-of-thought", + }, + { + "role": "tool", + "tool_call_id": "call_123", + "name": "x", + "content": "ok", + "extra_field": "should be removed", + }, + ] + + payload = provider._prepare_request_payload("gpt-4o", messages) + + assert payload["messages"] == [ + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + }, + { + "role": "tool", + "tool_call_id": "call_123", + "name": "x", + "content": "ok", + }, + ] + + +@pytest.mark.asyncio +async def test_chat_success(): + """Test successful chat request using model as deployment name.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o-deployment", + ) + + # Mock response data + mock_response_data = { + "choices": [{ + "message": { + "content": "Hello! 
How can I help you today?", + "role": "assistant" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 18, + "total_tokens": 30 + } + } + + with patch("httpx.AsyncClient") as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = Mock(return_value=mock_response_data) + + mock_context = AsyncMock() + mock_context.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_context + + # Test with specific model (deployment name) + messages = [{"role": "user", "content": "Hello"}] + result = await provider.chat(messages, model="custom-deployment") + + assert isinstance(result, LLMResponse) + assert result.content == "Hello! How can I help you today?" + assert result.finish_reason == "stop" + assert result.usage["prompt_tokens"] == 12 + assert result.usage["completion_tokens"] == 18 + assert result.usage["total_tokens"] == 30 + + # Verify URL was built with the provided model as deployment name + call_args = mock_context.post.call_args + expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21" + assert call_args[0][0] == expected_url + + +@pytest.mark.asyncio +async def test_chat_uses_default_model_when_no_model_provided(): + """Test that chat uses default_model when no model is specified.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="default-deployment", + ) + + mock_response_data = { + "choices": [{ + "message": {"content": "Response", "role": "assistant"}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10} + } + + with patch("httpx.AsyncClient") as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = Mock(return_value=mock_response_data) + + mock_context = AsyncMock() + 
mock_context.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_context + + messages = [{"role": "user", "content": "Test"}] + await provider.chat(messages) # No model specified + + # Verify URL was built with default model as deployment name + call_args = mock_context.post.call_args + expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21" + assert call_args[0][0] == expected_url + + +@pytest.mark.asyncio +async def test_chat_with_tool_calls(): + """Test chat request with tool calls in response.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + # Mock response with tool calls + mock_response_data = { + "choices": [{ + "message": { + "content": None, + "role": "assistant", + "tool_calls": [{ + "id": "call_12345", + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}' + } + }] + }, + "finish_reason": "tool_calls" + }], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 15, + "total_tokens": 35 + } + } + + with patch("httpx.AsyncClient") as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = Mock(return_value=mock_response_data) + + mock_context = AsyncMock() + mock_context.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_context + + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] + result = await provider.chat(messages, tools=tools, model="weather-model") + + assert isinstance(result, LLMResponse) + assert result.content is None + assert result.finish_reason == "tool_calls" + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments 
== {"location": "San Francisco"} + + +@pytest.mark.asyncio +async def test_chat_api_error(): + """Test chat request API error handling.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + with patch("httpx.AsyncClient") as mock_client: + mock_response = AsyncMock() + mock_response.status_code = 401 + mock_response.text = "Invalid authentication credentials" + + mock_context = AsyncMock() + mock_context.post = AsyncMock(return_value=mock_response) + mock_client.return_value.__aenter__.return_value = mock_context + + messages = [{"role": "user", "content": "Hello"}] + result = await provider.chat(messages) + + assert isinstance(result, LLMResponse) + assert "Azure OpenAI API Error 401" in result.content + assert "Invalid authentication credentials" in result.content + assert result.finish_reason == "error" + + +@pytest.mark.asyncio +async def test_chat_connection_error(): + """Test chat request connection error handling.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + with patch("httpx.AsyncClient") as mock_client: + mock_context = AsyncMock() + mock_context.post = AsyncMock(side_effect=Exception("Connection failed")) + mock_client.return_value.__aenter__.return_value = mock_context + + messages = [{"role": "user", "content": "Hello"}] + result = await provider.chat(messages) + + assert isinstance(result, LLMResponse) + assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content + assert result.finish_reason == "error" + + +def test_parse_response_malformed(): + """Test response parsing with malformed data.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + # Test with missing choices + malformed_response = {"usage": {"prompt_tokens": 10}} + result = 
provider._parse_response(malformed_response) + + assert isinstance(result, LLMResponse) + assert "Error parsing Azure OpenAI response" in result.content + assert result.finish_reason == "error" + + +def test_get_default_model(): + """Test get_default_model method.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="my-custom-deployment", + ) + + assert provider.get_default_model() == "my-custom-deployment" + + +if __name__ == "__main__": + # Run basic tests + print("Running basic Azure OpenAI provider tests...") + + # Test initialization + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o-deployment", + ) + print("βœ… Provider initialization successful") + + # Test URL building + url = provider._build_chat_url("my-deployment") + expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21" + assert url == expected + print("βœ… URL building works correctly") + + # Test headers + headers = provider._build_headers() + assert headers["api-key"] == "test-key" + assert headers["Content-Type"] == "application/json" + print("βœ… Header building works correctly") + + # Test payload preparation + messages = [{"role": "user", "content": "Test"}] + payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000) + assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format + print("βœ… Payload preparation works correctly") + + print("βœ… All basic tests passed! 
Updated test file is working correctly.") \ No newline at end of file diff --git a/tests/test_base_channel.py b/tests/test_base_channel.py new file mode 100644 index 0000000..5d10d4e --- /dev/null +++ b/tests/test_base_channel.py @@ -0,0 +1,25 @@ +from types import SimpleNamespace + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel + + +class _DummyChannel(BaseChannel): + name = "dummy" + + async def start(self) -> None: + return None + + async def stop(self) -> None: + return None + + async def send(self, msg: OutboundMessage) -> None: + return None + + +def test_is_allowed_requires_exact_match() -> None: + channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus()) + + assert channel.is_allowed("allow@email.com") is True + assert channel.is_allowed("attacker|allow@email.com") is False diff --git a/tests/test_commands.py b/tests/test_commands.py index 46ee7d0..e3709da 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -14,13 +14,17 @@ from nanobot.providers.registry import find_by_model runner = CliRunner() +class _StopGateway(RuntimeError): + pass + + @pytest.fixture def mock_paths(): """Mock config/workspace paths for test isolation.""" with patch("nanobot.config.loader.get_config_path") as mock_cp, \ patch("nanobot.config.loader.save_config") as mock_sc, \ - patch("nanobot.config.loader.load_config"), \ - patch("nanobot.utils.helpers.get_workspace_path") as mock_ws: + patch("nanobot.config.loader.load_config") as mock_lc, \ + patch("nanobot.cli.commands.get_workspace_path") as mock_ws: base_dir = Path("./test_onboard_data") if base_dir.exists(): @@ -135,10 +139,10 @@ def mock_agent_runtime(tmp_path): """Mock agent command dependencies for focused CLI tests.""" config = Config() config.agents.defaults.workspace = str(tmp_path / "default-workspace") - data_dir = tmp_path / "data" + cron_dir = tmp_path / "data" / "cron" with 
patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ - patch("nanobot.config.loader.get_data_dir", return_value=data_dir), \ + patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ patch("nanobot.cli.commands._make_provider", return_value=object()), \ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ @@ -221,3 +225,94 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime) assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path) assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,) assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path + + +def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + seen: dict[str, Path] = {} + + monkeypatch.setattr( + "nanobot.config.loader.set_config_path", + lambda path: seen.__setitem__("config_path", path), + ) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr( + "nanobot.cli.commands.sync_workspace_templates", + lambda path: seen.__setitem__("workspace", path), + ) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGateway) + assert seen["config_path"] == config_file.resolve() + assert seen["workspace"] == Path(config.agents.defaults.workspace) + + +def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None: + config_file 
= tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + override = tmp_path / "override-workspace" + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr( + "nanobot.cli.commands.sync_workspace_templates", + lambda path: seen.__setitem__("workspace", path), + ) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke( + app, + ["gateway", "--config", str(config_file), "--workspace", str(override)], + ) + + assert isinstance(result.exception, _StopGateway) + assert seen["workspace"] == override + assert config.workspace_path == override + + +def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + seen: dict[str, Path] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron") + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) + monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + + class _StopCron: + def __init__(self, store_path: Path) -> None: + 
seen["cron_store"] = store_path + raise _StopGateway("stop") + + monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGateway) + assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" diff --git a/tests/test_config_paths.py b/tests/test_config_paths.py new file mode 100644 index 0000000..473a6c8 --- /dev/null +++ b/tests/test_config_paths.py @@ -0,0 +1,42 @@ +from pathlib import Path + +from nanobot.config.paths import ( + get_bridge_install_dir, + get_cli_history_path, + get_cron_dir, + get_data_dir, + get_legacy_sessions_dir, + get_logs_dir, + get_media_dir, + get_runtime_subdir, + get_workspace_path, +) + + +def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance-a" / "config.json" + monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file) + + assert get_data_dir() == config_file.parent + assert get_runtime_subdir("cron") == config_file.parent / "cron" + assert get_cron_dir() == config_file.parent / "cron" + assert get_logs_dir() == config_file.parent / "logs" + + +def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance-b" / "config.json" + monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file) + + assert get_media_dir() == config_file.parent / "media" + assert get_media_dir("telegram") == config_file.parent / "media" / "telegram" + + +def test_shared_and_legacy_paths_remain_global() -> None: + assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history" + assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge" + assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions" + + +def test_workspace_path_is_explicitly_resolved() -> None: + assert get_workspace_path() == Path.home() / ".nanobot" 
/ "workspace" + assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace" diff --git a/tests/test_dingtalk_channel.py b/tests/test_dingtalk_channel.py new file mode 100644 index 0000000..7595a33 --- /dev/null +++ b/tests/test_dingtalk_channel.py @@ -0,0 +1,66 @@ +from types import SimpleNamespace + +import pytest + +from nanobot.bus.queue import MessageBus +from nanobot.channels.dingtalk import DingTalkChannel +from nanobot.config.schema import DingTalkConfig + + +class _FakeResponse: + def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None: + self.status_code = status_code + self._json_body = json_body or {} + self.text = "{}" + + def json(self) -> dict: + return self._json_body + + +class _FakeHttp: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def post(self, url: str, json=None, headers=None): + self.calls.append({"url": url, "json": json, "headers": headers}) + return _FakeResponse() + + +@pytest.mark.asyncio +async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None: + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]) + bus = MessageBus() + channel = DingTalkChannel(config, bus) + + await channel._on_message( + "hello", + sender_id="user1", + sender_name="Alice", + conversation_type="2", + conversation_id="conv123", + ) + + msg = await bus.consume_inbound() + assert msg.sender_id == "user1" + assert msg.chat_id == "group:conv123" + assert msg.metadata["conversation_type"] == "2" + + +@pytest.mark.asyncio +async def test_group_send_uses_group_messages_api() -> None: + config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]) + channel = DingTalkChannel(config, MessageBus()) + channel._http = _FakeHttp() + + ok = await channel._send_batch_message( + "token", + "group:conv123", + "sampleMarkdown", + {"text": "hello", "title": "Nanobot Reply"}, + ) + + assert ok is True + call = channel._http.calls[0] + assert 
call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send" + assert call["json"]["openConversationId"] == "conv123" + assert call["json"]["msgKey"] == "sampleMarkdown" diff --git a/tests/test_feishu_post_content.py b/tests/test_feishu_post_content.py index bf1ea82..7b1cb9d 100644 --- a/tests/test_feishu_post_content.py +++ b/tests/test_feishu_post_content.py @@ -1,4 +1,4 @@ -from nanobot.channels.feishu import _extract_post_content +from nanobot.channels.feishu import FeishuChannel, _extract_post_content def test_extract_post_content_supports_post_wrapper_shape() -> None: @@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None: assert text == "Daily report" assert image_keys == ["img_a", "img_b"] + + +def test_register_optional_event_keeps_builder_when_method_missing() -> None: + class Builder: + pass + + builder = Builder() + same = FeishuChannel._register_optional_event(builder, "missing", object()) + assert same is builder + + +def test_register_optional_event_calls_supported_method() -> None: + called = [] + + class Builder: + def register_event(self, handler): + called.append(handler) + return self + + builder = Builder() + handler = object() + same = FeishuChannel._register_optional_event(builder, "register_event", handler) + + assert same is builder + assert called == [handler] diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index 26b8a16..63b0fd1 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic: assert result is not None assert "Hello" in result.content + async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) + calls = iter([ + LLMResponse( + content="Visiblehidden", + tool_calls=[tool_call], + reasoning_content="secret reasoning", + 
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}], + ), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + + progress: list[tuple[str, bool]] = [] + + async def on_progress(content: str, *, tool_hint: bool = False) -> None: + progress.append((content, tool_hint)) + + final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + + assert final_content == "Done" + assert progress == [ + ("Visible", False), + ('read_file("foo.txt")', True), + ] + class TestMessageToolTurnTracking: diff --git a/tests/test_qq_channel.py b/tests/test_qq_channel.py new file mode 100644 index 0000000..90b4e60 --- /dev/null +++ b/tests/test_qq_channel.py @@ -0,0 +1,66 @@ +from types import SimpleNamespace + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.qq import QQChannel +from nanobot.config.schema import QQConfig + + +class _FakeApi: + def __init__(self) -> None: + self.c2c_calls: list[dict] = [] + self.group_calls: list[dict] = [] + + async def post_c2c_message(self, **kwargs) -> None: + self.c2c_calls.append(kwargs) + + async def post_group_message(self, **kwargs) -> None: + self.group_calls.append(kwargs) + + +class _FakeClient: + def __init__(self) -> None: + self.api = _FakeApi() + + +@pytest.mark.asyncio +async def test_on_group_message_routes_to_group_chat_id() -> None: + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus()) + + data = SimpleNamespace( + id="msg1", + content="hello", + group_openid="group123", + author=SimpleNamespace(member_openid="user1"), + ) + + await channel._on_message(data, is_group=True) + + msg = await channel.bus.consume_inbound() + assert msg.sender_id == "user1" + assert msg.chat_id == "group123" + + 
+@pytest.mark.asyncio
+async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
+    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
+    channel._client = _FakeClient()
+    channel._chat_type_cache["group123"] = "group"
+
+    await channel.send(
+        OutboundMessage(
+            channel="qq",
+            chat_id="group123",
+            content="hello",
+            metadata={"message_id": "msg1"},
+        )
+    )
+
+    assert len(channel._client.api.group_calls) == 1
+    call = channel._client.api.group_calls[0]
+    assert call["group_openid"] == "group123"
+    assert call["msg_id"] == "msg1"
+    assert call["msg_seq"] == 2
+    assert not channel._client.api.c2c_calls
diff --git a/tests/test_telegram_channel.py b/tests/test_telegram_channel.py
new file mode 100644
index 0000000..88c3f54
--- /dev/null
+++ b/tests/test_telegram_channel.py
@@ -0,0 +1,184 @@
+from types import SimpleNamespace
+
+import pytest
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.telegram import TelegramChannel
+from nanobot.config.schema import TelegramConfig
+
+
+class _FakeHTTPXRequest:
+    instances: list["_FakeHTTPXRequest"] = []
+
+    def __init__(self, **kwargs) -> None:
+        self.kwargs = kwargs
+        self.__class__.instances.append(self)
+
+
+class _FakeUpdater:
+    def __init__(self, on_start_polling) -> None:
+        self._on_start_polling = on_start_polling
+
+    async def start_polling(self, **kwargs) -> None:
+        self._on_start_polling()
+
+
+class _FakeBot:
+    def __init__(self) -> None:
+        self.sent_messages: list[dict] = []
+
+    async def get_me(self):
+        return SimpleNamespace(username="nanobot_test")
+
+    async def set_my_commands(self, commands) -> None:
+        self.commands = commands
+
+    async def send_message(self, **kwargs) -> None:
+        self.sent_messages.append(kwargs)
+
+
+class _FakeApp:
+    def __init__(self, on_start_polling) -> None:
+        self.bot = _FakeBot()
+        self.updater = _FakeUpdater(on_start_polling)
+        self.handlers = []
+        self.error_handlers = []
+
+    def add_error_handler(self, handler) -> None:
+        self.error_handlers.append(handler)
+
+    def add_handler(self, handler) -> None:
+        self.handlers.append(handler)
+
+    async def initialize(self) -> None:
+        pass
+
+    async def start(self) -> None:
+        pass
+
+
+class _FakeBuilder:
+    def __init__(self, app: _FakeApp) -> None:
+        self.app = app
+        self.token_value = None
+        self.request_value = None
+        self.get_updates_request_value = None
+
+    def token(self, token: str):
+        self.token_value = token
+        return self
+
+    def request(self, request):
+        self.request_value = request
+        return self
+
+    def get_updates_request(self, request):
+        self.get_updates_request_value = request
+        return self
+
+    def proxy(self, _proxy):
+        raise AssertionError("builder.proxy should not be called when request is set")
+
+    def get_updates_proxy(self, _proxy):
+        raise AssertionError("builder.get_updates_proxy should not be called when request is set")
+
+    def build(self):
+        return self.app
+
+
+@pytest.mark.asyncio
+async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
+    config = TelegramConfig(
+        enabled=True,
+        token="123:abc",
+        allow_from=["*"],
+        proxy="http://127.0.0.1:7890",
+    )
+    bus = MessageBus()
+    channel = TelegramChannel(config, bus)
+    app = _FakeApp(lambda: setattr(channel, "_running", False))
+    builder = _FakeBuilder(app)
+
+    monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
+    monkeypatch.setattr(
+        "nanobot.channels.telegram.Application",
+        SimpleNamespace(builder=lambda: builder),
+    )
+
+    await channel.start()
+
+    assert len(_FakeHTTPXRequest.instances) == 1
+    assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
+    assert builder.request_value is _FakeHTTPXRequest.instances[0]
+    assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
+
+
+def test_derive_topic_session_key_uses_thread_id() -> None:
+    message = SimpleNamespace(
+        chat=SimpleNamespace(type="supergroup"),
+        chat_id=-100123,
+        message_thread_id=42,
+    )
+
+    assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
+
+
+def test_get_extension_falls_back_to_original_filename() -> None:
+    channel = TelegramChannel(TelegramConfig(), MessageBus())
+
+    assert channel._get_extension("file", None, "report.pdf") == ".pdf"
+    assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
+
+
+def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
+    channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
+
+    assert channel.is_allowed("12345|carol") is True
+    assert channel.is_allowed("99999|alice") is True
+    assert channel.is_allowed("67890|bob") is True
+
+
+def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
+    channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
+
+    assert channel.is_allowed("attacker|alice|extra") is False
+    assert channel.is_allowed("not-a-number|alice") is False
+
+
+@pytest.mark.asyncio
+async def test_send_progress_keeps_message_in_topic() -> None:
+    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
+    channel = TelegramChannel(config, MessageBus())
+    channel._app = _FakeApp(lambda: None)
+
+    await channel.send(
+        OutboundMessage(
+            channel="telegram",
+            chat_id="123",
+            content="hello",
+            metadata={"_progress": True, "message_thread_id": 42},
+        )
+    )
+
+    assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
+
+
+@pytest.mark.asyncio
+async def test_send_reply_infers_topic_from_message_id_cache() -> None:
+    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
+    channel = TelegramChannel(config, MessageBus())
+    channel._app = _FakeApp(lambda: None)
+    channel._message_threads[("123", 10)] = 42
+
+    await channel.send(
+        OutboundMessage(
+            channel="telegram",
+            chat_id="123",
+            content="hello",
+            metadata={"message_id": 10},
+        )
+    )
+
+    assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
+    assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py
index cb50fb0..c2b4b6a 100644
--- a/tests/test_tool_validation.py
+++ b/tests/test_tool_validation.py
@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
     paths = ExecTool._extract_absolute_paths(cmd)
     assert "/tmp/data.txt" in paths
     assert "/tmp/out.txt" in paths
+
+
+# --- cast_params tests ---
+
+
+class CastTestTool(Tool):
+    """Minimal tool for testing cast_params."""
+
+    def __init__(self, schema: dict[str, Any]) -> None:
+        self._schema = schema
+
+    @property
+    def name(self) -> str:
+        return "cast_test"
+
+    @property
+    def description(self) -> str:
+        return "test tool for casting"
+
+    @property
+    def parameters(self) -> dict[str, Any]:
+        return self._schema
+
+    async def execute(self, **kwargs: Any) -> str:
+        return "ok"
+
+
+def test_cast_params_string_to_int() -> None:
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"count": {"type": "integer"}},
+        }
+    )
+    result = tool.cast_params({"count": "42"})
+    assert result["count"] == 42
+    assert isinstance(result["count"], int)
+
+
+def test_cast_params_string_to_number() -> None:
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"rate": {"type": "number"}},
+        }
+    )
+    result = tool.cast_params({"rate": "3.14"})
+    assert result["rate"] == 3.14
+    assert isinstance(result["rate"], float)
+
+
+def test_cast_params_string_to_bool() -> None:
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"enabled": {"type": "boolean"}},
+        }
+    )
+    assert tool.cast_params({"enabled": "true"})["enabled"] is True
+    assert tool.cast_params({"enabled": "false"})["enabled"] is False
+    assert tool.cast_params({"enabled": "1"})["enabled"] is True
+
+
+def test_cast_params_array_items() -> None:
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {
+                "nums": {"type": "array", "items": {"type": "integer"}},
+            },
+        }
+    )
+    result = tool.cast_params({"nums": ["1", "2", "3"]})
+    assert result["nums"] == [1, 2, 3]
+
+
+def test_cast_params_nested_object() -> None:
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {
+                "config": {
+                    "type": "object",
+                    "properties": {
+                        "port": {"type": "integer"},
+                        "debug": {"type": "boolean"},
+                    },
+                },
+            },
+        }
+    )
+    result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
+    assert result["config"]["port"] == 8080
+    assert result["config"]["debug"] is True
+
+
+def test_cast_params_bool_not_cast_to_int() -> None:
+    """Booleans should not be silently cast to integers."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"count": {"type": "integer"}},
+        }
+    )
+    result = tool.cast_params({"count": True})
+    assert result["count"] is True
+    errors = tool.validate_params(result)
+    assert any("count should be integer" in e for e in errors)
+
+
+def test_cast_params_preserves_empty_string() -> None:
+    """Empty strings should be preserved for string type."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+    )
+    result = tool.cast_params({"name": ""})
+    assert result["name"] == ""
+
+
+def test_cast_params_bool_string_false() -> None:
+    """Test that 'false', '0', 'no' strings convert to False."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"flag": {"type": "boolean"}},
+        }
+    )
+    assert tool.cast_params({"flag": "false"})["flag"] is False
+    assert tool.cast_params({"flag": "False"})["flag"] is False
+    assert tool.cast_params({"flag": "0"})["flag"] is False
+    assert tool.cast_params({"flag": "no"})["flag"] is False
+    assert tool.cast_params({"flag": "NO"})["flag"] is False
+
+
+def test_cast_params_bool_string_invalid() -> None:
+    """Invalid boolean strings should not be cast."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"flag": {"type": "boolean"}},
+        }
+    )
+    # Invalid strings should be preserved (validation will catch them)
+    result = tool.cast_params({"flag": "random"})
+    assert result["flag"] == "random"
+    result = tool.cast_params({"flag": "maybe"})
+    assert result["flag"] == "maybe"
+
+
+def test_cast_params_invalid_string_to_int() -> None:
+    """Invalid strings should not be cast to integer."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"count": {"type": "integer"}},
+        }
+    )
+    result = tool.cast_params({"count": "abc"})
+    assert result["count"] == "abc"  # Original value preserved
+    result = tool.cast_params({"count": "12.5.7"})
+    assert result["count"] == "12.5.7"
+
+
+def test_cast_params_invalid_string_to_number() -> None:
+    """Invalid strings should not be cast to number."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"rate": {"type": "number"}},
+        }
+    )
+    result = tool.cast_params({"rate": "not_a_number"})
+    assert result["rate"] == "not_a_number"
+
+
+def test_validate_params_bool_not_accepted_as_number() -> None:
+    """Booleans should not pass number validation."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"rate": {"type": "number"}},
+        }
+    )
+    errors = tool.validate_params({"rate": False})
+    assert any("rate should be number" in e for e in errors)
+
+
+def test_cast_params_none_values() -> None:
+    """Test None handling for different types."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "count": {"type": "integer"},
+                "items": {"type": "array"},
+                "config": {"type": "object"},
+            },
+        }
+    )
+    result = tool.cast_params(
+        {
+            "name": None,
+            "count": None,
+            "items": None,
+            "config": None,
+        }
+    )
+    # None should be preserved for all types
+    assert result["name"] is None
+    assert result["count"] is None
+    assert result["items"] is None
+    assert result["config"] is None
+
+
+def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
+    """Single values should NOT be automatically wrapped into arrays."""
+    tool = CastTestTool(
+        {
+            "type": "object",
+            "properties": {"items": {"type": "array"}},
+        }
+    )
+    # Non-array values should be preserved (validation will catch them)
+    result = tool.cast_params({"items": 5})
+    assert result["items"] == 5  # Not wrapped to [5]
+    result = tool.cast_params({"items": "text"})
+    assert result["items"] == "text"  # Not wrapped to ["text"]