Merge branch 'main' into pr-1635

This commit is contained in:
Re-bin
2026-03-08 03:14:20 +00:00
40 changed files with 2182 additions and 190 deletions

115
README.md
View File

@@ -20,9 +20,20 @@
## 📢 News
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
- **2026-03-04** 🛠️ Dependency cleanup, safer file reads, and another round of test and Cron fixes.
- **2026-03-03** 🧠 Cleaner user-message merging, safer multimodal saves, and stronger Cron guards.
- **2026-03-02** 🛡️ Safer default access control, sturdier Cron reloads, and cleaner Matrix media handling.
- **2026-03-01** 🌐 Web proxy support, smarter Cron reminders, and Feishu rich-text parsing improvements.
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
<details>
<summary>Earlier news</summary>
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
@@ -30,10 +41,6 @@
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
<details>
<summary>Earlier news</summary>
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
@@ -420,6 +427,10 @@ nanobot channels login
nanobot gateway
```
> WhatsApp bridge updates are not applied automatically for existing installations.
> If you upgrade nanobot and need the latest WhatsApp bridge, run:
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
</details>
<details>
@@ -671,6 +682,7 @@ Config file: `~/.nanobot/config.json`
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -894,31 +906,26 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
## Multiple Instances
## 🧩 Multiple Instances
Run multiple nanobot instances simultaneously, each with its own workspace and configuration.
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint, and optionally use `--workspace` to override the workspace for a specific run.
### Quick Start
```bash
# Instance A - Telegram bot
nanobot gateway -w ~/.nanobot/botA -p 18791
nanobot gateway --config ~/.nanobot-telegram/config.json
# Instance B - Discord bot
nanobot gateway -w ~/.nanobot/botB -p 18792
# Instance B - Discord bot
nanobot gateway --config ~/.nanobot-discord/config.json
# Instance C - Using custom config file
nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793
# Instance C - Feishu bot with custom port
nanobot gateway --config ~/.nanobot-feishu/config.json --port 18792
```
| Option | Short | Description |
|--------|-------|-------------|
| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) |
| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) |
| `--port` | `-p` | Gateway port (default: `18790`) |
### Path Resolution
Each instance has its own:
- Workspace directory (MEMORY.md, HEARTBEAT.md, session files)
- Cron jobs storage (`workspace/cron/jobs.json`)
- Configuration (if using `--config`)
When using `--config`, nanobot derives its runtime data directory from the config file location. The workspace still comes from `agents.defaults.workspace` unless you override it with `--workspace`.
To open a CLI session against one of these instances locally:
@@ -928,9 +935,75 @@ nanobot agent -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json
```
> `nanobot agent` starts a local CLI agent using the selected workspace/config. It does not attach to or proxy through an already running `nanobot gateway` process.
| Component | Resolved From | Example |
|-----------|---------------|---------|
| **Config** | `--config` path | `~/.nanobot-A/config.json` |
| **Workspace** | `--workspace` or config | `~/.nanobot-A/workspace/` |
| **Cron Jobs** | config directory | `~/.nanobot-A/cron/` |
| **Media / runtime state** | config directory | `~/.nanobot-A/media/` |
### How It Works
## CLI Reference
- `--config` selects which config file to load
- By default, the workspace comes from `agents.defaults.workspace` in that config
- If you pass `--workspace`, it overrides the workspace from the config file
### Minimal Setup
1. Copy your base config into a new instance directory.
2. Set a different `agents.defaults.workspace` for that instance.
3. Start the instance with `--config`.
Example config:
```json
{
"agents": {
"defaults": {
"workspace": "~/.nanobot-telegram/workspace",
"model": "anthropic/claude-sonnet-4-6"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_TELEGRAM_BOT_TOKEN"
}
},
"gateway": {
"port": 18790
}
}
```
Start separate instances:
```bash
nanobot gateway --config ~/.nanobot-telegram/config.json
nanobot gateway --config ~/.nanobot-discord/config.json
```
Override workspace for one-off runs when needed:
```bash
nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobot-telegram-test
```
### Common Use Cases
- Run separate bots for Telegram, Discord, Feishu, and other platforms
- Keep testing and production instances isolated
- Use different models or providers for different teams
- Serve multiple tenants with separate configs and runtime data
### Notes
- Each instance must use a different port if they run at the same time
- Use a different workspace per instance if you want isolated memory, sessions, and skills
- `--workspace` overrides the workspace defined in the config file
- Cron jobs and runtime media/state are derived from the config directory
## 💻 CLI Reference
| Command | Description |
|---------|-------------|

View File

@@ -9,11 +9,16 @@ import makeWASocket, {
useMultiFileAuthState,
fetchLatestBaileysVersion,
makeCacheableSignalKeyStore,
downloadMediaMessage,
extractMessageContent as baileysExtractMessageContent,
} from '@whiskeysockets/baileys';
import { Boom } from '@hapi/boom';
import qrcode from 'qrcode-terminal';
import pino from 'pino';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { randomBytes } from 'crypto';
const VERSION = '0.1.0';
@@ -24,6 +29,7 @@ export interface InboundMessage {
content: string;
timestamp: number;
isGroup: boolean;
media?: string[];
}
export interface WhatsAppClientOptions {
@@ -110,14 +116,33 @@ export class WhatsAppClient {
if (type !== 'notify') return;
for (const msg of messages) {
// Skip own messages
if (msg.key.fromMe) continue;
// Skip status updates
if (msg.key.remoteJid === 'status@broadcast') continue;
const content = this.extractMessageContent(msg);
if (!content) continue;
const unwrapped = baileysExtractMessageContent(msg.message);
if (!unwrapped) continue;
const content = this.getTextContent(unwrapped);
let fallbackContent: string | null = null;
const mediaPaths: string[] = [];
if (unwrapped.imageMessage) {
fallbackContent = '[Image]';
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.documentMessage) {
fallbackContent = '[Document]';
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
unwrapped.documentMessage.fileName ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.videoMessage) {
fallbackContent = '[Video]';
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
}
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
if (!finalContent && mediaPaths.length === 0) continue;
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
@@ -125,18 +150,45 @@ export class WhatsAppClient {
id: msg.key.id || '',
sender: msg.key.remoteJid || '',
pn: msg.key.remoteJidAlt || '',
content,
content: finalContent,
timestamp: msg.messageTimestamp as number,
isGroup,
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
});
}
});
}
private extractMessageContent(msg: any): string | null {
const message = msg.message;
if (!message) return null;
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
try {
const mediaDir = join(this.options.authDir, '..', 'media');
await mkdir(mediaDir, { recursive: true });
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
let outFilename: string;
if (fileName) {
// Documents have a filename — use it with a unique prefix to avoid collisions
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
outFilename = prefix + fileName;
} else {
const mime = mimetype || 'application/octet-stream';
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
}
const filepath = join(mediaDir, outFilename);
await writeFile(filepath, buffer);
return filepath;
} catch (err) {
console.error('Failed to download media:', err);
return null;
}
}
private getTextContent(message: any): string | null {
// Text message
if (message.conversation) {
return message.conversation;
@@ -147,19 +199,19 @@ export class WhatsAppClient {
return message.extendedTextMessage.text;
}
// Image with caption
if (message.imageMessage?.caption) {
return `[Image] ${message.imageMessage.caption}`;
// Image with optional caption
if (message.imageMessage) {
return message.imageMessage.caption || '';
}
// Video with caption
if (message.videoMessage?.caption) {
return `[Video] ${message.videoMessage.caption}`;
// Video with optional caption
if (message.videoMessage) {
return message.videoMessage.caption || '';
}
// Document with caption
if (message.documentMessage?.caption) {
return `[Document] ${message.documentMessage.caption}`;
// Document with optional caption
if (message.documentMessage) {
return message.documentMessage.caption || '';
}
// Voice/Audio message

View File

@@ -202,18 +202,9 @@ class AgentLoop:
if response.has_tool_calls:
if on_progress:
thoughts = [
self._strip_think(response.content),
response.reasoning_content,
*(
f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}"
for b in (response.thinking_blocks or [])
if isinstance(b, dict) and "signature" in b
),
]
combined_thoughts = "\n\n".join(filter(None, thoughts))
if combined_thoughts:
await on_progress(combined_thoughts)
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [

View File

@@ -52,6 +52,75 @@ class Tool(ABC):
"""
pass
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""Apply safe schema-driven casts before validation."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
return params
return self._cast_object(params, schema)
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
"""Cast an object (dict) according to schema."""
if not isinstance(obj, dict):
return obj
props = schema.get("properties", {})
result = {}
for key, value in obj.items():
if key in props:
result[key] = self._cast_value(value, props[key])
else:
result[key] = value
return result
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
"""Cast a single value according to schema."""
target_type = schema.get("type")
if target_type == "boolean" and isinstance(val, bool):
return val
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
return val
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
expected = self._TYPE_MAP[target_type]
if isinstance(val, expected):
return val
if target_type == "integer" and isinstance(val, str):
try:
return int(val)
except ValueError:
return val
if target_type == "number" and isinstance(val, str):
try:
return float(val)
except ValueError:
return val
if target_type == "string":
return val if val is None else str(val)
if target_type == "boolean" and isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
if val_lower in ("false", "0", "no"):
return False
return val
if target_type == "array" and isinstance(val, list):
item_schema = schema.get("items")
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
if target_type == "object" and isinstance(val, dict):
return self._cast_object(val, schema)
return val
def validate_params(self, params: dict[str, Any]) -> list[str]:
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
if not isinstance(params, dict):
@@ -63,7 +132,13 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
return [f"{label} should be integer"]
if t == "number" and (
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
):
return [f"{label} should be number"]
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"]
errors = []

View File

@@ -96,7 +96,7 @@ class MessageTool(Tool):
media=media or [],
metadata={
"message_id": message_id,
}
},
)
try:

View File

@@ -44,6 +44,10 @@ class ToolRegistry:
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT

View File

@@ -66,10 +66,7 @@ class BaseChannel(ABC):
return False
if "*" in allow_list:
return True
sender_str = str(sender_id)
return sender_str in allow_list or any(
p in allow_list for p in sender_str.split("|") if p
)
return str(sender_id) in allow_list
async def _handle_message(
self,

View File

@@ -70,12 +70,24 @@ class NanobotDingTalkHandler(CallbackHandler):
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
sender_name = chatbot_msg.sender_nick or "Unknown"
conversation_type = message.data.get("conversationType")
conversation_id = (
message.data.get("conversationId")
or message.data.get("openConversationId")
)
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
# Forward to Nanobot via _on_message (non-blocking).
# Store reference to prevent GC before task completes.
task = asyncio.create_task(
self.channel._on_message(content, sender_id, sender_name)
self.channel._on_message(
content,
sender_id,
sender_name,
conversation_type,
conversation_id,
)
)
self.channel._background_tasks.add(task)
task.add_done_callback(self.channel._background_tasks.discard)
@@ -95,8 +107,8 @@ class DingTalkChannel(BaseChannel):
Uses WebSocket to receive events via `dingtalk-stream` SDK.
Uses direct HTTP API to send messages (SDK is mainly for receiving).
Note: Currently only supports private (1:1) chat. Group messages are
received but replies are sent back as private messages to the sender.
Supports both private (1:1) and group chats.
Group chat_id is stored with a "group:" prefix to route replies back.
"""
name = "dingtalk"
@@ -301,14 +313,25 @@ class DingTalkChannel(BaseChannel):
logger.warning("DingTalk HTTP client not initialized, cannot send")
return False
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {"x-acs-dingtalk-access-token": token}
payload = {
"robotCode": self.config.client_id,
"userIds": [chat_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
if chat_id.startswith("group:"):
# Group chat
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
payload = {
"robotCode": self.config.client_id,
"openConversationId": chat_id[6:], # Remove "group:" prefix,
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
else:
# Private chat
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
payload = {
"robotCode": self.config.client_id,
"userIds": [chat_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
try:
resp = await self._http.post(url, json=payload, headers=headers)
@@ -417,7 +440,14 @@ class DingTalkChannel(BaseChannel):
f"[Attachment send failed: {filename}]",
)
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
async def _on_message(
self,
content: str,
sender_id: str,
sender_name: str,
conversation_type: str | None = None,
conversation_id: str | None = None,
) -> None:
"""Handle incoming message (called by NanobotDingTalkHandler).
Delegates to BaseChannel._handle_message() which enforces allow_from
@@ -425,13 +455,16 @@ class DingTalkChannel(BaseChannel):
"""
try:
logger.info("DingTalk inbound: {} from {}", content, sender_name)
is_group = conversation_type == "2" and conversation_id
chat_id = f"group:{conversation_id}" if is_group else sender_id
await self._handle_message(
sender_id=sender_id,
chat_id=sender_id, # For private chat, chat_id == sender_id
chat_id=chat_id,
content=str(content),
metadata={
"sender_name": sender_name,
"platform": "dingtalk",
"conversation_type": conversation_type,
},
)
except Exception as e:

View File

@@ -12,6 +12,7 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import DiscordConfig
from nanobot.utils.helpers import split_message
@@ -75,7 +76,7 @@ class DiscordChannel(BaseChannel):
self._http = None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API."""
"""Send a message through Discord REST API, including file attachments."""
if not self._http:
logger.warning("Discord HTTP client not initialized")
return
@@ -84,15 +85,31 @@ class DiscordChannel(BaseChannel):
headers = {"Authorization": f"Bot {self.config.token}"}
try:
sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# Only set reply reference on the first chunk
if i == 0 and msg.reply_to:
# Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
@@ -123,6 +140,54 @@ class DiscordChannel(BaseChannel):
await asyncio.sleep(1)
return False
async def _send_file(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws:
@@ -225,7 +290,7 @@ class DiscordChannel(BaseChannel):
content_parts = [content] if content else []
media_paths: list[str] = []
media_dir = Path.home() / ".nanobot" / "media"
media_dir = get_media_dir("discord")
for attachment in payload.get("attachments") or []:
url = attachment.get("url")

View File

@@ -14,6 +14,7 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import FeishuConfig
import importlib.util
@@ -244,15 +245,22 @@ class FeishuChannel(BaseChannel):
name = "feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus):
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
super().__init__(config, bus)
self.config: FeishuConfig = config
self.groq_api_key = groq_api_key
self._client: Any = None
self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None
@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
"""Register an event handler only when the SDK supports it."""
method = getattr(builder, method_name, None)
return method(handler) if callable(method) else builder
async def start(self) -> None:
"""Start the Feishu bot with WebSocket long connection."""
if not FEISHU_AVAILABLE:
@@ -273,14 +281,24 @@ class FeishuChannel(BaseChannel):
.app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \
.build()
# Create event handler (only register message receive, ignore other events)
event_handler = lark.EventDispatcherHandler.builder(
builder = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "",
self.config.verification_token or "",
).register_p2_im_message_receive_v1(
self._on_message_sync
).build()
)
builder = self._register_optional_event(
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
)
builder = self._register_optional_event(
builder, "register_p2_im_message_message_read_v1", self._on_message_read
)
builder = self._register_optional_event(
builder,
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
self._on_bot_p2p_chat_entered,
)
event_handler = builder.build()
# Create WebSocket client for long connection
self._ws_client = lark.ws.Client(
@@ -715,8 +733,7 @@ class FeishuChannel(BaseChannel):
(file_path, content_text) - file_path is None if download failed
"""
loop = asyncio.get_running_loop()
media_dir = Path.home() / ".nanobot" / "media"
media_dir.mkdir(parents=True, exist_ok=True)
media_dir = get_media_dir("feishu")
data, filename = None, None
@@ -841,7 +858,7 @@ class FeishuChannel(BaseChannel):
except Exception as e:
logger.error("Error sending Feishu message: {}", e)
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
def _on_message_sync(self, data: Any) -> None:
"""
Sync handler for incoming messages (called from WebSocket thread).
Schedules async handling in the main event loop.
@@ -849,7 +866,7 @@ class FeishuChannel(BaseChannel):
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
async def _on_message(self, data: Any) -> None:
"""Handle incoming message from Feishu."""
try:
event = data.event
@@ -909,6 +926,18 @@ class FeishuChannel(BaseChannel):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
if file_path:
media_paths.append(file_path)
# Transcribe audio using Groq Whisper
if msg_type == "audio" and file_path and self.groq_api_key:
try:
from nanobot.providers.transcription import GroqTranscriptionProvider
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path)
if transcription:
content_text = f"[transcription: {transcription}]"
except Exception as e:
logger.warning("Failed to transcribe audio: {}", e)
content_parts.append(content_text)
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
@@ -941,3 +970,16 @@ class FeishuChannel(BaseChannel):
except Exception as e:
logger.error("Error processing Feishu message: {}", e)
def _on_reaction_created(self, data: Any) -> None:
"""Ignore reaction events so they do not generate SDK noise."""
pass
def _on_message_read(self, data: Any) -> None:
"""Ignore read events so they do not generate SDK noise."""
pass
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
"""Ignore p2p-enter events when a user opens a bot chat."""
logger.debug("Bot entered p2p chat (user opened chat window)")
pass

View File

@@ -74,7 +74,8 @@ class ChannelManager:
try:
from nanobot.channels.feishu import FeishuChannel
self.channels["feishu"] = FeishuChannel(
self.config.channels.feishu, self.bus
self.config.channels.feishu, self.bus,
groq_api_key=self.config.providers.groq.api_key,
)
logger.info("Feishu channel enabled")
except ImportError as e:

View File

@@ -38,7 +38,7 @@ except ImportError as e:
from nanobot.bus.events import OutboundMessage
from nanobot.channels.base import BaseChannel
from nanobot.config.loader import get_data_dir
from nanobot.config.paths import get_data_dir, get_media_dir
from nanobot.utils.helpers import safe_filename
TYPING_NOTICE_TIMEOUT_MS = 30_000
@@ -490,9 +490,7 @@ class MatrixChannel(BaseChannel):
return False
def _media_dir(self) -> Path:
d = get_data_dir() / "media" / "matrix"
d.mkdir(parents=True, exist_ok=True)
return d
return get_media_dir("matrix")
@staticmethod
def _event_source_content(event: RoomMessage) -> dict[str, Any]:

View File

@@ -15,8 +15,8 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_runtime_subdir
from nanobot.config.schema import MochatConfig
from nanobot.utils.helpers import get_data_path
try:
import socketio
@@ -224,7 +224,7 @@ class MochatChannel(BaseChannel):
self._socket: Any = None
self._ws_connected = self._ws_ready = False
self._state_dir = get_data_path() / "mochat"
self._state_dir = get_runtime_subdir("mochat")
self._cursor_path = self._state_dir / "session_cursors.json"
self._session_cursor: dict[str, int] = {}
self._cursor_save_task: asyncio.Task | None = None

View File

@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
try:
import botpy
from botpy.message import C2CMessage
from botpy.message import C2CMessage, GroupMessage
QQ_AVAILABLE = True
except ImportError:
QQ_AVAILABLE = False
botpy = None
C2CMessage = None
GroupMessage = None
if TYPE_CHECKING:
from botpy.message import C2CMessage
from botpy.message import C2CMessage, GroupMessage
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
@@ -38,10 +39,13 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
logger.info("QQ bot ready: {}", self.robot.name)
async def on_c2c_message_create(self, message: "C2CMessage"):
await channel._on_message(message)
await channel._on_message(message, is_group=False)
async def on_group_at_message_create(self, message: "GroupMessage"):
await channel._on_message(message, is_group=True)
async def on_direct_message_create(self, message):
await channel._on_message(message)
await channel._on_message(message, is_group=False)
return _Bot
@@ -57,6 +61,7 @@ class QQChannel(BaseChannel):
self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000)
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
self._chat_type_cache: dict[str, str] = {}
async def start(self) -> None:
"""Start the QQ bot."""
@@ -71,8 +76,7 @@ class QQChannel(BaseChannel):
self._running = True
BotClass = _make_bot_class(self)
self._client = BotClass()
logger.info("QQ bot started (C2C private message)")
logger.info("QQ bot started (C2C & Group supported)")
await self._run_bot()
async def _run_bot(self) -> None:
@@ -101,20 +105,31 @@ class QQChannel(BaseChannel):
if not self._client:
logger.warning("QQ client not initialized")
return
try:
msg_id = msg.metadata.get("message_id")
self._msg_seq += 1 # 递增序列号
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq, # 添加序列号避免去重
)
self._msg_seq += 1
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
if msg_type == "group":
await self._client.api.post_group_message(
group_openid=msg.chat_id,
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq,
)
else:
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=0,
content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq,
)
except Exception as e:
logger.error("Error sending QQ message: {}", e)
async def _on_message(self, data: "C2CMessage") -> None:
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
"""Handle incoming message from QQ."""
try:
# Dedup by message ID
@@ -122,18 +137,24 @@ class QQChannel(BaseChannel):
return
self._processed_ids.append(data.id)
author = data.author
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
content = (data.content or "").strip()
if not content:
return
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
await self._handle_message(
sender_id=user_id,
chat_id=user_id,
chat_id=chat_id,
content=content,
metadata={"message_id": data.id},
)
except Exception:
logger.exception("Error handling QQ message")

View File

@@ -82,13 +82,14 @@ class SlackChannel(BaseChannel):
thread_ts = slack_meta.get("thread_ts")
channel_type = slack_meta.get("channel_type")
# Only reply in thread for channel/group messages; DMs don't use threads
use_thread = thread_ts and channel_type != "im"
thread_ts_param = thread_ts if use_thread else None
if msg.content:
# Slack rejects empty text payloads. Keep media-only messages media-only,
# but send a single blank message when the bot has no text or files to send.
if msg.content or not (msg.media or []):
await self._web_client.chat_postMessage(
channel=msg.chat_id,
text=self._to_mrkdwn(msg.content),
text=self._to_mrkdwn(msg.content) if msg.content else " ",
thread_ts=thread_ts_param,
)

View File

@@ -15,6 +15,7 @@ from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
@@ -177,6 +178,26 @@ class TelegramChannel(BaseChannel):
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
self._message_threads: dict[tuple[str, int], int] = {}
def is_allowed(self, sender_id: str) -> bool:
"""Preserve Telegram's legacy id|username allowlist matching."""
if super().is_allowed(sender_id):
return True
allow_list = getattr(self.config, "allow_from", [])
if not allow_list or "*" in allow_list:
return False
sender_str = str(sender_id)
if sender_str.count("|") != 1:
return False
sid, username = sender_str.split("|", 1)
if not sid.isdigit() or not username:
return False
return sid in allow_list or username in allow_list
async def start(self) -> None:
"""Start the Telegram bot with long polling."""
@@ -187,16 +208,21 @@ class TelegramChannel(BaseChannel):
self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
req = HTTPXRequest(
connection_pool_size=16,
pool_timeout=5.0,
connect_timeout=30.0,
read_timeout=30.0,
proxy=self.config.proxy if self.config.proxy else None,
)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
if self.config.proxy:
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build()
self._app.add_error_handler(self._on_error)
# Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents
@@ -281,10 +307,16 @@ class TelegramChannel(BaseChannel):
except ValueError:
logger.error("Invalid chat_id: {}", msg.chat_id)
return
reply_to_message_id = msg.metadata.get("message_id")
message_thread_id = msg.metadata.get("message_thread_id")
if message_thread_id is None and reply_to_message_id is not None:
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
thread_kwargs = {}
if message_thread_id is not None:
thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None
if self.config.reply_to_message:
reply_to_message_id = msg.metadata.get("message_id")
if reply_to_message_id:
reply_params = ReplyParameters(
message_id=reply_to_message_id,
@@ -305,7 +337,8 @@ class TelegramChannel(BaseChannel):
await sender(
chat_id=chat_id,
**{param: f},
reply_parameters=reply_params
reply_parameters=reply_params,
**thread_kwargs,
)
except Exception as e:
filename = media_path.rsplit("/", 1)[-1]
@@ -313,7 +346,8 @@ class TelegramChannel(BaseChannel):
await self._app.bot.send_message(
chat_id=chat_id,
text=f"[Failed to send: {filename}]",
reply_parameters=reply_params
reply_parameters=reply_params,
**thread_kwargs,
)
# Send text content
@@ -323,28 +357,44 @@ class TelegramChannel(BaseChannel):
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
# Final response: simulate streaming via draft, then persist
if not is_progress:
await self._send_with_streaming(chat_id, chunk, reply_params)
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
else:
await self._send_text(chat_id, chunk, reply_params)
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _send_text(self, chat_id: int, text: str, reply_params=None) -> None:
async def _send_text(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Send a plain text message with HTML fallback."""
try:
html = _markdown_to_telegram_html(text)
await self._app.bot.send_message(
chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
await self._app.bot.send_message(
chat_id=chat_id, text=text, reply_parameters=reply_params,
chat_id=chat_id,
text=text,
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e2:
logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(self, chat_id: int, text: str, reply_params=None) -> None:
async def _send_with_streaming(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31)
try:
@@ -360,7 +410,7 @@ class TelegramChannel(BaseChannel):
await asyncio.sleep(0.15)
except Exception:
pass
await self._send_text(chat_id, text, reply_params)
await self._send_text(chat_id, text, reply_params, thread_kwargs)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command."""
@@ -391,14 +441,50 @@ class TelegramChannel(BaseChannel):
sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
@staticmethod
def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload."""
return {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)),
}
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
key = (str(message.chat_id), message.message_id)
self._message_threads[key] = message_thread_id
if len(self._message_threads) > 1000:
self._message_threads.pop(next(iter(self._message_threads)))
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user:
return
message = update.message
user = update.effective_user
self._remember_thread_context(message)
await self._handle_message(
sender_id=self._sender_id(update.effective_user),
chat_id=str(update.message.chat_id),
content=update.message.text,
sender_id=self._sender_id(user),
chat_id=str(message.chat_id),
content=message.text,
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
)
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -410,6 +496,7 @@ class TelegramChannel(BaseChannel):
user = update.effective_user
chat_id = message.chat_id
sender_id = self._sender_id(user)
self._remember_thread_context(message)
# Store chat_id for replies
self._chat_ids[sender_id] = chat_id
@@ -445,12 +532,12 @@ class TelegramChannel(BaseChannel):
if media_file and self._app:
try:
file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
# Save to workspace/media/
from pathlib import Path
media_dir = Path.home() / ".nanobot" / "media"
media_dir.mkdir(parents=True, exist_ok=True)
ext = self._get_extension(
media_type,
getattr(media_file, 'mime_type', None),
getattr(media_file, 'file_name', None),
)
media_dir = get_media_dir("telegram")
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path))
@@ -480,6 +567,8 @@ class TelegramChannel(BaseChannel):
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id)
metadata = self._build_message_metadata(message, user)
session_key = self._derive_topic_session_key(message)
# Telegram media groups: buffer briefly, forward as one aggregated turn.
if media_group_id := getattr(message, "media_group_id", None):
@@ -488,11 +577,8 @@ class TelegramChannel(BaseChannel):
self._media_group_buffers[key] = {
"sender_id": sender_id, "chat_id": str_chat_id,
"contents": [], "media": [],
"metadata": {
"message_id": message.message_id, "user_id": user.id,
"username": user.username, "first_name": user.first_name,
"is_group": message.chat.type != "private",
},
"metadata": metadata,
"session_key": session_key,
}
self._start_typing(str_chat_id)
buf = self._media_group_buffers[key]
@@ -512,13 +598,8 @@ class TelegramChannel(BaseChannel):
chat_id=str_chat_id,
content=content,
media=media_paths,
metadata={
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private"
}
metadata=metadata,
session_key=session_key,
)
async def _flush_media_group(self, key: str) -> None:
@@ -532,6 +613,7 @@ class TelegramChannel(BaseChannel):
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
content=content, media=list(dict.fromkeys(buf["media"])),
metadata=buf["metadata"],
session_key=buf.get("session_key"),
)
finally:
self._media_group_tasks.pop(key, None)
@@ -563,8 +645,13 @@ class TelegramChannel(BaseChannel):
"""Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error)
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
"""Get file extension based on media type."""
def _get_extension(
self,
media_type: str,
mime_type: str | None,
filename: str | None = None,
) -> str:
"""Get file extension based on media type or original filename."""
if mime_type:
ext_map = {
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -574,4 +661,12 @@ class TelegramChannel(BaseChannel):
return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "")
if ext := type_map.get(media_type, ""):
return ext
if filename:
from pathlib import Path
return "".join(Path(filename).suffixes)
return ""

View File

@@ -2,6 +2,7 @@
import asyncio
import json
import mimetypes
from collections import OrderedDict
from loguru import logger
@@ -128,10 +129,22 @@ class WhatsAppChannel(BaseChannel):
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
mime, _ = mimetypes.guess_type(p)
media_type = "image" if mime and mime.startswith("image/") else "file"
media_tag = f"[{media_type}: {p}]"
content = f"{content}\n{media_tag}" if content else media_tag
await self._handle_message(
sender_id=sender_id,
chat_id=sender, # Use full LID for replies
content=content,
media=media_paths,
metadata={
"message_id": message_id,
"timestamp": data.get("timestamp"),

View File

@@ -29,6 +29,7 @@ from rich.table import Table
from rich.text import Text
from nanobot import __logo__, __version__
from nanobot.config.paths import get_workspace_path
from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates
@@ -98,7 +99,9 @@ def _init_prompt_session() -> None:
except Exception:
pass
history_file = Path.home() / ".nanobot" / "history" / "cli_history"
from nanobot.config.paths import get_cli_history_path
history_file = get_cli_history_path()
history_file.parent.mkdir(parents=True, exist_ok=True)
_PROMPT_SESSION = PromptSession(
@@ -169,7 +172,6 @@ def onboard():
"""Initialize nanobot configuration and workspace."""
from nanobot.config.loader import get_config_path, load_config, save_config
from nanobot.config.schema import Config
from nanobot.utils.helpers import get_workspace_path
config_path = get_config_path()
@@ -212,6 +214,7 @@ def onboard():
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
@@ -230,6 +233,20 @@ def _make_provider(config: Config):
default_model=model,
)
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
return AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
@@ -267,13 +284,24 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
def gateway(
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Start the nanobot gateway."""
# Set config path if provided (must be done before any imports that use get_data_dir)
if config:
from nanobot.config.loader import set_config_path
config_path = Path(config).expanduser().resolve()
if not config_path.exists():
console.print(f"[red]Error: Config file not found: {config_path}[/red]")
raise typer.Exit(1)
set_config_path(config_path)
console.print(f"[dim]Using config: {config_path}[/dim]")
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
from nanobot.config.paths import get_cron_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
@@ -292,8 +320,7 @@ def gateway(
session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation)
# Use workspace path for per-instance cron store
cron_store_path = config.workspace_path / "cron" / "jobs.json"
cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path)
# Create agent with cron service
@@ -464,7 +491,7 @@ def agent(
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir
from nanobot.config.paths import get_cron_dir
from nanobot.cron.service import CronService
config = _load_runtime_config(config, workspace)
@@ -474,7 +501,7 @@ def agent(
provider = _make_provider(config)
# Create cron service for tool usage (no callback needed for CLI unless running)
cron_store_path = get_data_dir() / "cron" / "jobs.json"
cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path)
if logs:
@@ -740,7 +767,9 @@ def _get_bridge_dir() -> Path:
import subprocess
# User's bridge location
user_bridge = Path.home() / ".nanobot" / "bridge"
from nanobot.config.paths import get_bridge_install_dir
user_bridge = get_bridge_install_dir()
# Check if already built
if (user_bridge / "dist" / "index.js").exists():
@@ -798,6 +827,7 @@ def channels_login():
import subprocess
from nanobot.config.loader import load_config
from nanobot.config.paths import get_runtime_subdir
config = load_config()
bridge_dir = _get_bridge_dir()
@@ -808,6 +838,7 @@ def channels_login():
env = {**os.environ}
if config.channels.whatsapp.bridge_token:
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
try:
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)

View File

@@ -1,6 +1,30 @@
"""Configuration module for nanobot."""
from nanobot.config.loader import get_config_path, load_config
from nanobot.config.paths import (
get_bridge_install_dir,
get_cli_history_path,
get_cron_dir,
get_data_dir,
get_legacy_sessions_dir,
get_logs_dir,
get_media_dir,
get_runtime_subdir,
get_workspace_path,
)
from nanobot.config.schema import Config
__all__ = ["Config", "load_config", "get_config_path"]
__all__ = [
"Config",
"load_config",
"get_config_path",
"get_data_dir",
"get_runtime_subdir",
"get_media_dir",
"get_cron_dir",
"get_logs_dir",
"get_workspace_path",
"get_cli_history_path",
"get_bridge_install_dir",
"get_legacy_sessions_dir",
]

View File

@@ -6,17 +6,23 @@ from pathlib import Path
from nanobot.config.schema import Config
# Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None
def set_config_path(path: Path) -> None:
"""Set the current config path (used to derive data directory)."""
global _current_config_path
_current_config_path = path
def get_config_path() -> Path:
"""Get the default configuration file path."""
"""Get the configuration file path."""
if _current_config_path:
return _current_config_path
return Path.home() / ".nanobot" / "config.json"
def get_data_dir() -> Path:
"""Get the nanobot data directory."""
from nanobot.utils.helpers import get_data_path
return get_data_path()
def load_config(config_path: Path | None = None) -> Config:
"""
Load configuration from file or create default.

55
nanobot/config/paths.py Normal file
View File

@@ -0,0 +1,55 @@
"""Runtime path helpers derived from the active config context."""
from __future__ import annotations
from pathlib import Path
from nanobot.config.loader import get_config_path
from nanobot.utils.helpers import ensure_dir
def get_data_dir() -> Path:
"""Return the instance-level runtime data directory."""
return ensure_dir(get_config_path().parent)
def get_runtime_subdir(name: str) -> Path:
"""Return a named runtime subdirectory under the instance data dir."""
return ensure_dir(get_data_dir() / name)
def get_media_dir(channel: str | None = None) -> Path:
"""Return the media directory, optionally namespaced per channel."""
base = get_runtime_subdir("media")
return ensure_dir(base / channel) if channel else base
def get_cron_dir() -> Path:
"""Return the cron storage directory."""
return get_runtime_subdir("cron")
def get_logs_dir() -> Path:
"""Return the logs directory."""
return get_runtime_subdir("logs")
def get_workspace_path(workspace: str | None = None) -> Path:
"""Resolve and ensure the agent workspace path."""
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
return ensure_dir(path)
def get_cli_history_path() -> Path:
"""Return the shared CLI history file path."""
return Path.home() / ".nanobot" / "history" / "cli_history"
def get_bridge_install_dir() -> Path:
"""Return the shared WhatsApp bridge installation directory."""
return Path.home() / ".nanobot" / "bridge"
def get_legacy_sessions_dir() -> Path:
"""Return the legacy global session directory used for migration fallback."""
return Path.home() / ".nanobot" / "sessions"

View File

@@ -251,6 +251,7 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)

View File

@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]

View File

@@ -0,0 +1,210 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod
async def chat(
self,

View File

@@ -1,5 +1,6 @@
"""LiteLLM provider implementation for multi-provider support."""
import hashlib
import os
import secrets
import string
@@ -166,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {}
def map_id(value: Any) -> Any:
if not isinstance(value, str):
return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(

View File

@@ -79,6 +79,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
litellm_prefix="",
is_direct=True,
),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-"

View File

@@ -9,6 +9,7 @@ from typing import Any
from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename
@@ -79,7 +80,7 @@ class SessionManager:
def __init__(self, workspace: Path):
self.workspace = workspace
self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
self.legacy_sessions_dir = get_legacy_sessions_dir()
self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path:

View File

@@ -1,5 +1,5 @@
"""Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
from nanobot.utils.helpers import ensure_dir
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
__all__ = ["ensure_dir"]

View File

@@ -24,17 +24,6 @@ def ensure_dir(path: Path) -> Path:
return path
def get_data_path() -> Path:
"""~/.nanobot data directory."""
return ensure_dir(Path.home() / ".nanobot")
def get_workspace_path(workspace: str | None = None) -> Path:
"""Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
return ensure_dir(path)
def timestamp() -> str:
"""Current ISO timestamp."""
return datetime.now().isoformat()

View File

@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert result.finish_reason == "error"
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")

View File

@@ -0,0 +1,25 @@
from types import SimpleNamespace
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel):
name = "dummy"
async def start(self) -> None:
return None
async def stop(self) -> None:
return None
async def send(self, msg: OutboundMessage) -> None:
return None
def test_is_allowed_requires_exact_match() -> None:
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False

View File

@@ -14,13 +14,17 @@ from nanobot.providers.registry import find_by_model
runner = CliRunner()
class _StopGateway(RuntimeError):
pass
@pytest.fixture
def mock_paths():
"""Mock config/workspace paths for test isolation."""
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
patch("nanobot.config.loader.save_config") as mock_sc, \
patch("nanobot.config.loader.load_config"), \
patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:
patch("nanobot.config.loader.load_config") as mock_lc, \
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
base_dir = Path("./test_onboard_data")
if base_dir.exists():
@@ -135,10 +139,10 @@ def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests."""
config = Config()
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
data_dir = tmp_path / "data"
cron_dir = tmp_path / "data" / "cron"
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
patch("nanobot.config.loader.get_data_dir", return_value=data_dir), \
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
patch("nanobot.cli.commands._make_provider", return_value=object()), \
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
@@ -221,3 +225,94 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime)
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert seen["config_path"] == config_file.resolve()
assert seen["workspace"] == Path(config.agents.defaults.workspace)
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
override = tmp_path / "override-workspace"
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(
app,
["gateway", "--config", str(config_file), "--workspace", str(override)],
)
assert isinstance(result.exception, _StopGateway)
assert seen["workspace"] == override
assert config.workspace_path == override
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGateway("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"

View File

@@ -0,0 +1,42 @@
from pathlib import Path
from nanobot.config.paths import (
get_bridge_install_dir,
get_cli_history_path,
get_cron_dir,
get_data_dir,
get_legacy_sessions_dir,
get_logs_dir,
get_media_dir,
get_runtime_subdir,
get_workspace_path,
)
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-a" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_data_dir() == config_file.parent
assert get_runtime_subdir("cron") == config_file.parent / "cron"
assert get_cron_dir() == config_file.parent / "cron"
assert get_logs_dir() == config_file.parent / "logs"
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-b" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_media_dir() == config_file.parent / "media"
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
def test_shared_and_legacy_paths_remain_global() -> None:
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
def test_workspace_path_is_explicitly_resolved() -> None:
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"

View File

@@ -0,0 +1,66 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.dingtalk import DingTalkChannel
from nanobot.config.schema import DingTalkConfig
class _FakeResponse:
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
self.status_code = status_code
self._json_body = json_body or {}
self.text = "{}"
def json(self) -> dict:
return self._json_body
class _FakeHttp:
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url: str, json=None, headers=None):
self.calls.append({"url": url, "json": json, "headers": headers})
return _FakeResponse()
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
bus = MessageBus()
channel = DingTalkChannel(config, bus)
await channel._on_message(
"hello",
sender_id="user1",
sender_name="Alice",
conversation_type="2",
conversation_id="conv123",
)
msg = await bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
assert msg.metadata["conversation_type"] == "2"
@pytest.mark.asyncio
async def test_group_send_uses_group_messages_api() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _FakeHttp()
ok = await channel._send_batch_message(
"token",
"group:conv123",
"sampleMarkdown",
{"text": "hello", "title": "Nanobot Reply"},
)
assert ok is True
call = channel._http.calls[0]
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"

View File

@@ -1,4 +1,4 @@
from nanobot.channels.feishu import _extract_post_content
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None:
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
class Builder:
pass
builder = Builder()
same = FeishuChannel._register_optional_event(builder, "missing", object())
assert same is builder
def test_register_optional_event_calls_supported_method() -> None:
called = []
class Builder:
def register_event(self, handler):
called.append(handler)
return self
builder = Builder()
handler = object()
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
assert same is builder
assert called == [handler]

View File

@@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic:
assert result is not None
assert "Hello" in result.content
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
calls = iter([
LLMResponse(
content="Visible<think>hidden</think>",
tool_calls=[tool_call],
reasoning_content="secret reasoning",
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
progress: list[tuple[str, bool]] = []
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [
("Visible", False),
('read_file("foo.txt")', True),
]
class TestMessageToolTurnTracking:

66
tests/test_qq_channel.py Normal file
View File

@@ -0,0 +1,66 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel
from nanobot.config.schema import QQConfig
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeClient:
def __init__(self) -> None:
self.api = _FakeApi()
@pytest.mark.asyncio
async def test_on_group_message_routes_to_group_chat_id() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
data = SimpleNamespace(
id="msg1",
content="hello",
group_openid="group123",
author=SimpleNamespace(member_openid="user1"),
)
await channel._on_message(data, is_group=True)
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group123"
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert not channel._client.api.c2c_calls

View File

@@ -0,0 +1,184 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig
class _FakeHTTPXRequest:
instances: list["_FakeHTTPXRequest"] = []
def __init__(self, **kwargs) -> None:
self.kwargs = kwargs
self.__class__.instances.append(self)
class _FakeUpdater:
def __init__(self, on_start_polling) -> None:
self._on_start_polling = on_start_polling
async def start_polling(self, **kwargs) -> None:
self._on_start_polling()
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
async def get_me(self):
return SimpleNamespace(username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
class _FakeApp:
def __init__(self, on_start_polling) -> None:
self.bot = _FakeBot()
self.updater = _FakeUpdater(on_start_polling)
self.handlers = []
self.error_handlers = []
def add_error_handler(self, handler) -> None:
self.error_handlers.append(handler)
def add_handler(self, handler) -> None:
self.handlers.append(handler)
async def initialize(self) -> None:
pass
async def start(self) -> None:
pass
class _FakeBuilder:
def __init__(self, app: _FakeApp) -> None:
self.app = app
self.token_value = None
self.request_value = None
self.get_updates_request_value = None
def token(self, token: str):
self.token_value = token
return self
def request(self, request):
self.request_value = request
return self
def get_updates_request(self, request):
self.get_updates_request_value = request
return self
def proxy(self, _proxy):
raise AssertionError("builder.proxy should not be called when request is set")
def get_updates_proxy(self, _proxy):
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
def build(self):
return self.app
@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
assert len(_FakeHTTPXRequest.instances) == 1
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
assert builder.request_value is _FakeHTTPXRequest.instances[0]
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
chat_id=-100123,
message_thread_id=42,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
def test_get_extension_falls_back_to_original_filename() -> None:
channel = TelegramChannel(TelegramConfig(), MessageBus())
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
assert channel.is_allowed("12345|carol") is True
assert channel.is_allowed("99999|alice") is True
assert channel.is_allowed("67890|bob") is True
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
assert channel.is_allowed("attacker|alice|extra") is False
assert channel.is_allowed("not-a-number|alice") is False
@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"_progress": True, "message_thread_id": 42},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
channel._message_threads[("123", 10)] = 42
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"message_id": 10},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10

View File

@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths
# --- cast_params tests ---
class CastTestTool(Tool):
"""Minimal tool for testing cast_params."""
def __init__(self, schema: dict[str, Any]) -> None:
self._schema = schema
@property
def name(self) -> str:
return "cast_test"
@property
def description(self) -> str:
return "test tool for casting"
@property
def parameters(self) -> dict[str, Any]:
return self._schema
async def execute(self, **kwargs: Any) -> str:
return "ok"
def test_cast_params_string_to_int() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "42"})
assert result["count"] == 42
assert isinstance(result["count"], int)
def test_cast_params_string_to_number() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "3.14"})
assert result["rate"] == 3.14
assert isinstance(result["rate"], float)
def test_cast_params_string_to_bool() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"enabled": {"type": "boolean"}},
}
)
assert tool.cast_params({"enabled": "true"})["enabled"] is True
assert tool.cast_params({"enabled": "false"})["enabled"] is False
assert tool.cast_params({"enabled": "1"})["enabled"] is True
def test_cast_params_array_items() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"nums": {"type": "array", "items": {"type": "integer"}},
},
}
)
result = tool.cast_params({"nums": ["1", "2", "3"]})
assert result["nums"] == [1, 2, 3]
def test_cast_params_nested_object() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"port": {"type": "integer"},
"debug": {"type": "boolean"},
},
},
},
}
)
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
assert result["config"]["port"] == 8080
assert result["config"]["debug"] is True
def test_cast_params_bool_not_cast_to_int() -> None:
"""Booleans should not be silently cast to integers."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": True})
assert result["count"] is True
errors = tool.validate_params(result)
assert any("count should be integer" in e for e in errors)
def test_cast_params_preserves_empty_string() -> None:
"""Empty strings should be preserved for string type."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": "string"}},
}
)
result = tool.cast_params({"name": ""})
assert result["name"] == ""
def test_cast_params_bool_string_false() -> None:
"""Test that 'false', '0', 'no' strings convert to False."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
assert tool.cast_params({"flag": "false"})["flag"] is False
assert tool.cast_params({"flag": "False"})["flag"] is False
assert tool.cast_params({"flag": "0"})["flag"] is False
assert tool.cast_params({"flag": "no"})["flag"] is False
assert tool.cast_params({"flag": "NO"})["flag"] is False
def test_cast_params_bool_string_invalid() -> None:
"""Invalid boolean strings should not be cast."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
# Invalid strings should be preserved (validation will catch them)
result = tool.cast_params({"flag": "random"})
assert result["flag"] == "random"
result = tool.cast_params({"flag": "maybe"})
assert result["flag"] == "maybe"
def test_cast_params_invalid_string_to_int() -> None:
"""Invalid strings should not be cast to integer."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "abc"})
assert result["count"] == "abc" # Original value preserved
result = tool.cast_params({"count": "12.5.7"})
assert result["count"] == "12.5.7"
def test_cast_params_invalid_string_to_number() -> None:
"""Invalid strings should not be cast to number."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "not_a_number"})
assert result["rate"] == "not_a_number"
def test_validate_params_bool_not_accepted_as_number() -> None:
"""Booleans should not pass number validation."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
errors = tool.validate_params({"rate": False})
assert any("rate should be number" in e for e in errors)
def test_cast_params_none_values() -> None:
"""Test None handling for different types."""
tool = CastTestTool(
{
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
"items": {"type": "array"},
"config": {"type": "object"},
},
}
)
result = tool.cast_params(
{
"name": None,
"count": None,
"items": None,
"config": None,
}
)
# None should be preserved for all types
assert result["name"] is None
assert result["count"] is None
assert result["items"] is None
assert result["config"] is None
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
"""Single values should NOT be automatically wrapped into arrays."""
tool = CastTestTool(
{
"type": "object",
"properties": {"items": {"type": "array"}},
}
)
# Non-array values should be preserved (validation will catch them)
result = tool.cast_params({"items": 5})
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]