Merge branch 'main' into feat-volcengine-tuning
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,4 +20,5 @@ __pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
nano.*.save
|
||||
|
||||
|
||||
265
README.md
265
README.md
@@ -20,9 +20,21 @@
|
||||
|
||||
## 📢 News
|
||||
|
||||
- **2026-03-08** 🚀 Released **v0.1.4.post4** — a reliability-packed release with safer defaults, better multi-instance support, sturdier MCP, and major channel and provider improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post4) for details.
|
||||
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
|
||||
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
|
||||
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
|
||||
- **2026-03-04** 🛠️ Dependency cleanup, safer file reads, and another round of test and Cron fixes.
|
||||
- **2026-03-03** 🧠 Cleaner user-message merging, safer multimodal saves, and stronger Cron guards.
|
||||
- **2026-03-02** 🛡️ Safer default access control, sturdier Cron reloads, and cleaner Matrix media handling.
|
||||
- **2026-03-01** 🌐 Web proxy support, smarter Cron reminders, and Feishu rich-text parsing improvements.
|
||||
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
|
||||
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
|
||||
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
|
||||
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
||||
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
||||
@@ -30,10 +42,6 @@
|
||||
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
||||
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
||||
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
|
||||
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
|
||||
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
|
||||
@@ -56,7 +64,7 @@
|
||||
|
||||
## Key Features of nanobot:
|
||||
|
||||
🪶 **Ultra-Lightweight**: Just ~4,000 lines of core agent code — 99% smaller than Clawdbot.
|
||||
🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
|
||||
|
||||
🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
|
||||
|
||||
@@ -70,6 +78,25 @@
|
||||
<img src="nanobot_arch.png" alt="nanobot architecture" width="800">
|
||||
</p>
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [News](#-news)
|
||||
- [Key Features](#key-features-of-nanobot)
|
||||
- [Architecture](#️-architecture)
|
||||
- [Features](#-features)
|
||||
- [Install](#-install)
|
||||
- [Quick Start](#-quick-start)
|
||||
- [Chat Apps](#-chat-apps)
|
||||
- [Agent Social Network](#-agent-social-network)
|
||||
- [Configuration](#️-configuration)
|
||||
- [Multiple Instances](#-multiple-instances)
|
||||
- [CLI Reference](#-cli-reference)
|
||||
- [Docker](#-docker)
|
||||
- [Linux Service](#-linux-service)
|
||||
- [Project Structure](#-project-structure)
|
||||
- [Contribute & Roadmap](#-contribute--roadmap)
|
||||
- [Star History](#-star-history)
|
||||
|
||||
## ✨ Features
|
||||
|
||||
<table align="center">
|
||||
@@ -115,6 +142,29 @@ uv tool install nanobot-ai
|
||||
pip install nanobot-ai
|
||||
```
|
||||
|
||||
### Update to latest version
|
||||
|
||||
**PyPI / pip**
|
||||
|
||||
```bash
|
||||
pip install -U nanobot-ai
|
||||
nanobot --version
|
||||
```
|
||||
|
||||
**uv**
|
||||
|
||||
```bash
|
||||
uv tool upgrade nanobot-ai
|
||||
nanobot --version
|
||||
```
|
||||
|
||||
**Using WhatsApp?** Rebuild the local bridge after upgrading:
|
||||
|
||||
```bash
|
||||
rm -rf ~/.nanobot/bridge
|
||||
nanobot channels login
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
> [!TIP]
|
||||
@@ -177,6 +227,7 @@ Connect nanobot to your favorite chat platform.
|
||||
| **Slack** | Bot token + App-Level token |
|
||||
| **Email** | IMAP/SMTP credentials |
|
||||
| **QQ** | App ID + App Secret |
|
||||
| **Wecom** | Bot ID + Bot Secret |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (Recommended)</summary>
|
||||
@@ -367,7 +418,7 @@ pip install nanobot-ai[matrix]
|
||||
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `allowFrom` | User IDs allowed to interact. Empty = all senders. |
|
||||
| `allowFrom` | User IDs allowed to interact. Empty denies all; use `["*"]` to allow everyone. |
|
||||
| `groupPolicy` | `open` (default), `mention`, or `allowlist`. |
|
||||
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
|
||||
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
|
||||
@@ -420,6 +471,10 @@ nanobot channels login
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
> WhatsApp bridge updates are not applied automatically for existing installations.
|
||||
> After upgrading nanobot, rebuild the local bridge with:
|
||||
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -447,7 +502,8 @@ Uses **WebSocket** long connection — no public IP required.
|
||||
"appSecret": "xxx",
|
||||
"encryptKey": "",
|
||||
"verificationToken": "",
|
||||
"allowFrom": ["ou_YOUR_OPEN_ID"]
|
||||
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
||||
"groupPolicy": "mention"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -455,6 +511,7 @@ Uses **WebSocket** long connection — no public IP required.
|
||||
|
||||
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
||||
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
||||
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
|
||||
|
||||
**3. Run**
|
||||
|
||||
@@ -642,6 +699,46 @@ nanobot gateway
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Wecom (企业微信)</b></summary>
|
||||
|
||||
> Here we use [wecom-aibot-sdk-python](https://github.com/chengyongru/wecom_aibot_sdk) (community Python version of the official [@wecom/aibot-node-sdk](https://www.npmjs.com/package/@wecom/aibot-node-sdk)).
|
||||
>
|
||||
> Uses **WebSocket** long connection — no public IP required.
|
||||
|
||||
**1. Install the optional dependency**
|
||||
|
||||
```bash
|
||||
pip install nanobot-ai[wecom]
|
||||
```
|
||||
|
||||
**2. Create a WeCom AI Bot**
|
||||
|
||||
Go to the WeCom admin console → Intelligent Robot → Create Robot → select **API mode** with **long connection**. Copy the Bot ID and Secret.
|
||||
|
||||
**3. Configure**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"wecom": {
|
||||
"enabled": true,
|
||||
"botId": "your_bot_id",
|
||||
"secret": "your_bot_secret",
|
||||
"allowFrom": ["your_id"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**4. Run**
|
||||
|
||||
```bash
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🌐 Agent Social Network
|
||||
|
||||
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
|
||||
@@ -664,12 +761,14 @@ Config file: `~/.nanobot/config.json`
|
||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
||||
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
|
||||
|
||||
| Provider | Purpose | Get API Key |
|
||||
|----------|---------|-------------|
|
||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
@@ -681,6 +780,7 @@ Config file: `~/.nanobot/config.json`
|
||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||
| `ollama` | LLM (local, Ollama) | — |
|
||||
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
|
||||
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
|
||||
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
|
||||
@@ -709,6 +809,12 @@ nanobot provider login openai-codex
|
||||
**3. Chat:**
|
||||
```bash
|
||||
nanobot agent -m "Hello!"
|
||||
|
||||
# Target a specific workspace/config locally
|
||||
nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello!"
|
||||
|
||||
# One-off workspace override on top of that config
|
||||
nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -m "Hello!"
|
||||
```
|
||||
|
||||
> Docker users: use `docker run -it` for interactive OAuth login.
|
||||
@@ -740,6 +846,37 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Ollama (local)</b></summary>
|
||||
|
||||
Run a local model with Ollama, then add to config:
|
||||
|
||||
**1. Start Ollama** (example):
|
||||
```bash
|
||||
ollama run llama3.2
|
||||
```
|
||||
|
||||
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"apiBase": "http://localhost:11434"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "ollama",
|
||||
"model": "llama3.2"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> `provider: "auto"` also works when `providers.ollama.apiBase` is configured, but setting `"provider": "ollama"` is the clearest option.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
|
||||
|
||||
@@ -881,48 +1018,124 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
|
||||
|
||||
> [!TIP]
|
||||
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
|
||||
> **Change in source / post-`v0.1.4.post3`:** In `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions (including building from source), **empty `allowFrom` denies all access by default**. To allow all senders, set `"allowFrom": ["*"]`.
|
||||
> In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all senders. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default. To allow all senders, set `"allowFrom": ["*"]`.
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
|
||||
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
|
||||
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
|
||||
| `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |
|
||||
|
||||
|
||||
## Multiple Instances
|
||||
## 🧩 Multiple Instances
|
||||
|
||||
Run multiple nanobot instances simultaneously, each with its own workspace and configuration.
|
||||
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint, and optionally use `--workspace` to override the workspace for a specific run.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Instance A - Telegram bot
|
||||
nanobot gateway -w ~/.nanobot/botA -p 18791
|
||||
nanobot gateway --config ~/.nanobot-telegram/config.json
|
||||
|
||||
# Instance B - Discord bot
|
||||
nanobot gateway -w ~/.nanobot/botB -p 18792
|
||||
# Instance B - Discord bot
|
||||
nanobot gateway --config ~/.nanobot-discord/config.json
|
||||
|
||||
# Instance C - Using custom config file
|
||||
nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793
|
||||
# Instance C - Feishu bot with custom port
|
||||
nanobot gateway --config ~/.nanobot-feishu/config.json --port 18792
|
||||
```
|
||||
|
||||
| Option | Short | Description |
|
||||
|--------|-------|-------------|
|
||||
| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) |
|
||||
| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) |
|
||||
| `--port` | `-p` | Gateway port (default: `18790`) |
|
||||
### Path Resolution
|
||||
|
||||
Each instance has its own:
|
||||
- Workspace directory (MEMORY.md, HEARTBEAT.md, session files)
|
||||
- Cron jobs storage (`workspace/cron/jobs.json`)
|
||||
- Configuration (if using `--config`)
|
||||
When using `--config`, nanobot derives its runtime data directory from the config file location. The workspace still comes from `agents.defaults.workspace` unless you override it with `--workspace`.
|
||||
|
||||
To open a CLI session against one of these instances locally:
|
||||
|
||||
## CLI Reference
|
||||
```bash
|
||||
nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello from Telegram instance"
|
||||
nanobot agent -c ~/.nanobot-discord/config.json -m "Hello from Discord instance"
|
||||
|
||||
# Optional one-off workspace override
|
||||
nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test
|
||||
```
|
||||
|
||||
> `nanobot agent` starts a local CLI agent using the selected workspace/config. It does not attach to or proxy through an already running `nanobot gateway` process.
|
||||
|
||||
| Component | Resolved From | Example |
|
||||
|-----------|---------------|---------|
|
||||
| **Config** | `--config` path | `~/.nanobot-A/config.json` |
|
||||
| **Workspace** | `--workspace` or config | `~/.nanobot-A/workspace/` |
|
||||
| **Cron Jobs** | config directory | `~/.nanobot-A/cron/` |
|
||||
| **Media / runtime state** | config directory | `~/.nanobot-A/media/` |
|
||||
|
||||
### How It Works
|
||||
|
||||
- `--config` selects which config file to load
|
||||
- By default, the workspace comes from `agents.defaults.workspace` in that config
|
||||
- If you pass `--workspace`, it overrides the workspace from the config file
|
||||
|
||||
### Minimal Setup
|
||||
|
||||
1. Copy your base config into a new instance directory.
|
||||
2. Set a different `agents.defaults.workspace` for that instance.
|
||||
3. Start the instance with `--config`.
|
||||
|
||||
Example config:
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.nanobot-telegram/workspace",
|
||||
"model": "anthropic/claude-sonnet-4-6"
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN"
|
||||
}
|
||||
},
|
||||
"gateway": {
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Start separate instances:
|
||||
|
||||
```bash
|
||||
nanobot gateway --config ~/.nanobot-telegram/config.json
|
||||
nanobot gateway --config ~/.nanobot-discord/config.json
|
||||
```
|
||||
|
||||
Override workspace for one-off runs when needed:
|
||||
|
||||
```bash
|
||||
nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobot-telegram-test
|
||||
```
|
||||
|
||||
### Common Use Cases
|
||||
|
||||
- Run separate bots for Telegram, Discord, Feishu, and other platforms
|
||||
- Keep testing and production instances isolated
|
||||
- Use different models or providers for different teams
|
||||
- Serve multiple tenants with separate configs and runtime data
|
||||
|
||||
### Notes
|
||||
|
||||
- Each instance must use a different port if they run at the same time
|
||||
- Use a different workspace per instance if you want isolated memory, sessions, and skills
|
||||
- `--workspace` overrides the workspace defined in the config file
|
||||
- Cron jobs and runtime media/state are derived from the config directory
|
||||
|
||||
## 💻 CLI Reference
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `nanobot onboard` | Initialize config & workspace |
|
||||
| `nanobot agent -m "..."` | Chat with the agent |
|
||||
| `nanobot agent -w <workspace>` | Chat against a specific workspace |
|
||||
| `nanobot agent -w <workspace> -c <config>` | Chat against a specific workspace/config |
|
||||
| `nanobot agent` | Interactive chat mode |
|
||||
| `nanobot agent --no-markdown` | Show plain-text replies |
|
||||
| `nanobot agent --logs` | Show runtime logs during chat |
|
||||
|
||||
@@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone.
|
||||
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all users. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default — set `["*"]` to explicitly allow everyone.
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
@@ -212,7 +212,7 @@ If you suspect a security breach:
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all)
|
||||
- Allow-list based access control — in `v0.1.4.post3` and earlier empty `allowFrom` allowed all; since `v0.1.4.post4` it denies all (`["*"]` explicitly allows all)
|
||||
- Failed authentication attempt logging
|
||||
|
||||
✅ **Resource Protection**
|
||||
|
||||
@@ -9,11 +9,16 @@ import makeWASocket, {
|
||||
useMultiFileAuthState,
|
||||
fetchLatestBaileysVersion,
|
||||
makeCacheableSignalKeyStore,
|
||||
downloadMediaMessage,
|
||||
extractMessageContent as baileysExtractMessageContent,
|
||||
} from '@whiskeysockets/baileys';
|
||||
|
||||
import { Boom } from '@hapi/boom';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
import pino from 'pino';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { randomBytes } from 'crypto';
|
||||
|
||||
const VERSION = '0.1.0';
|
||||
|
||||
@@ -24,6 +29,7 @@ export interface InboundMessage {
|
||||
content: string;
|
||||
timestamp: number;
|
||||
isGroup: boolean;
|
||||
media?: string[];
|
||||
}
|
||||
|
||||
export interface WhatsAppClientOptions {
|
||||
@@ -110,14 +116,33 @@ export class WhatsAppClient {
|
||||
if (type !== 'notify') return;
|
||||
|
||||
for (const msg of messages) {
|
||||
// Skip own messages
|
||||
if (msg.key.fromMe) continue;
|
||||
|
||||
// Skip status updates
|
||||
if (msg.key.remoteJid === 'status@broadcast') continue;
|
||||
|
||||
const content = this.extractMessageContent(msg);
|
||||
if (!content) continue;
|
||||
const unwrapped = baileysExtractMessageContent(msg.message);
|
||||
if (!unwrapped) continue;
|
||||
|
||||
const content = this.getTextContent(unwrapped);
|
||||
let fallbackContent: string | null = null;
|
||||
const mediaPaths: string[] = [];
|
||||
|
||||
if (unwrapped.imageMessage) {
|
||||
fallbackContent = '[Image]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.documentMessage) {
|
||||
fallbackContent = '[Document]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
|
||||
unwrapped.documentMessage.fileName ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.videoMessage) {
|
||||
fallbackContent = '[Video]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
}
|
||||
|
||||
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
|
||||
if (!finalContent && mediaPaths.length === 0) continue;
|
||||
|
||||
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
||||
|
||||
@@ -125,18 +150,45 @@ export class WhatsAppClient {
|
||||
id: msg.key.id || '',
|
||||
sender: msg.key.remoteJid || '',
|
||||
pn: msg.key.remoteJidAlt || '',
|
||||
content,
|
||||
content: finalContent,
|
||||
timestamp: msg.messageTimestamp as number,
|
||||
isGroup,
|
||||
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private extractMessageContent(msg: any): string | null {
|
||||
const message = msg.message;
|
||||
if (!message) return null;
|
||||
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
|
||||
try {
|
||||
const mediaDir = join(this.options.authDir, '..', 'media');
|
||||
await mkdir(mediaDir, { recursive: true });
|
||||
|
||||
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
|
||||
|
||||
let outFilename: string;
|
||||
if (fileName) {
|
||||
// Documents have a filename — use it with a unique prefix to avoid collisions
|
||||
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
|
||||
outFilename = prefix + fileName;
|
||||
} else {
|
||||
const mime = mimetype || 'application/octet-stream';
|
||||
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
|
||||
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
|
||||
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
|
||||
}
|
||||
|
||||
const filepath = join(mediaDir, outFilename);
|
||||
await writeFile(filepath, buffer);
|
||||
|
||||
return filepath;
|
||||
} catch (err) {
|
||||
console.error('Failed to download media:', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private getTextContent(message: any): string | null {
|
||||
// Text message
|
||||
if (message.conversation) {
|
||||
return message.conversation;
|
||||
@@ -147,19 +199,19 @@ export class WhatsAppClient {
|
||||
return message.extendedTextMessage.text;
|
||||
}
|
||||
|
||||
// Image with caption
|
||||
if (message.imageMessage?.caption) {
|
||||
return `[Image] ${message.imageMessage.caption}`;
|
||||
// Image with optional caption
|
||||
if (message.imageMessage) {
|
||||
return message.imageMessage.caption || '';
|
||||
}
|
||||
|
||||
// Video with caption
|
||||
if (message.videoMessage?.caption) {
|
||||
return `[Video] ${message.videoMessage.caption}`;
|
||||
// Video with optional caption
|
||||
if (message.videoMessage) {
|
||||
return message.videoMessage.caption || '';
|
||||
}
|
||||
|
||||
// Document with caption
|
||||
if (message.documentMessage?.caption) {
|
||||
return `[Document] ${message.documentMessage.caption}`;
|
||||
// Document with optional caption
|
||||
if (message.documentMessage) {
|
||||
return message.documentMessage.caption || '';
|
||||
}
|
||||
|
||||
// Voice/Audio message
|
||||
|
||||
@@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l)
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, providers/)"
|
||||
echo " (excludes: channels/, cli/, providers/, skills/)"
|
||||
|
||||
@@ -2,5 +2,5 @@
|
||||
nanobot - A lightweight AI agent framework
|
||||
"""
|
||||
|
||||
__version__ = "0.1.4.post3"
|
||||
__version__ = "0.1.4.post4"
|
||||
__logo__ = "🐈"
|
||||
|
||||
@@ -10,12 +10,13 @@ from typing import Any
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
@@ -58,6 +59,19 @@ Skills with available="false" need dependencies installed first - you can try in
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
|
||||
platform_policy = ""
|
||||
if system == "Windows":
|
||||
platform_policy = """## Platform Policy (Windows)
|
||||
- You are running on Windows. Do not assume GNU tools like `grep`, `sed`, or `awk` exist.
|
||||
- Prefer Windows-native commands or file tools when they are more reliable.
|
||||
- If terminal output is garbled, retry with UTF-8 output enabled.
|
||||
"""
|
||||
else:
|
||||
platform_policy = """## Platform Policy (POSIX)
|
||||
- You are running on a POSIX system. Prefer UTF-8 and standard shell tools.
|
||||
- Use file tools when they are simpler or more reliable than shell commands.
|
||||
"""
|
||||
|
||||
return f"""# nanobot 🐈
|
||||
|
||||
You are nanobot, a helpful AI assistant.
|
||||
@@ -71,6 +85,8 @@ Your workspace is at: {workspace_path}
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
{platform_policy}
|
||||
|
||||
## nanobot Guidelines
|
||||
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||
@@ -136,10 +152,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
||||
images = []
|
||||
for path in media:
|
||||
p = Path(path)
|
||||
mime, _ = mimetypes.guess_type(path)
|
||||
if not p.is_file() or not mime or not mime.startswith("image/"):
|
||||
if not p.is_file():
|
||||
continue
|
||||
b64 = base64.b64encode(p.read_bytes()).decode()
|
||||
raw = p.read_bytes()
|
||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||
if not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||
|
||||
if not images:
|
||||
@@ -162,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
||||
thinking_blocks: list[dict] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list."""
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
if thinking_blocks:
|
||||
msg["thinking_blocks"] = thinking_blocks
|
||||
messages.append(msg)
|
||||
messages.append(build_assistant_message(
|
||||
content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
))
|
||||
return messages
|
||||
|
||||
@@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import weakref
|
||||
import sys
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
@@ -13,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
@@ -44,7 +45,7 @@ class AgentLoop:
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 500
|
||||
_TOOL_RESULT_MAX_CHARS = 16_000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -53,10 +54,7 @@ class AgentLoop:
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 100,
|
||||
reasoning_effort: str | None = None,
|
||||
context_window_tokens: int = 65_536,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
@@ -73,10 +71,7 @@ class AgentLoop:
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.memory_window = memory_window
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@@ -91,9 +86,6 @@ class AgentLoop:
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=reasoning_effort,
|
||||
brave_api_key=brave_api_key,
|
||||
web_proxy=web_proxy,
|
||||
exec_config=self.exec_config,
|
||||
@@ -105,11 +97,17 @@ class AgentLoop:
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self.memory_consolidator = MemoryConsolidator(
|
||||
workspace=workspace,
|
||||
provider=provider,
|
||||
model=self.model,
|
||||
sessions=self.sessions,
|
||||
context_window_tokens=context_window_tokens,
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
)
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
@@ -182,7 +180,7 @@ class AgentLoop:
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||
"""Run the agent iteration loop."""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
@@ -191,40 +189,23 @@ class AgentLoop:
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
tool_defs = self.tools.get_definitions()
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_progress:
|
||||
thoughts = [
|
||||
self._strip_think(response.content),
|
||||
response.reasoning_content,
|
||||
*(
|
||||
f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}"
|
||||
for b in (response.thinking_blocks or [])
|
||||
if isinstance(b, dict) and "signature" in b
|
||||
),
|
||||
]
|
||||
combined_thoughts = "\n\n".join(filter(None, thoughts))
|
||||
if combined_thoughts:
|
||||
await on_progress(combined_thoughts)
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||
}
|
||||
}
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
@@ -277,8 +258,11 @@ class AgentLoop:
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if msg.content.strip().lower() == "/stop":
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/stop":
|
||||
await self._handle_stop(msg)
|
||||
elif cmd == "/restart":
|
||||
await self._handle_restart(msg)
|
||||
else:
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
@@ -295,11 +279,23 @@ class AgentLoop:
|
||||
pass
|
||||
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||
total = cancelled + sub_cancelled
|
||||
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||
content = f"Stopped {total} task(s)." if total else "No active task to stop."
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
))
|
||||
|
||||
async def _handle_restart(self, msg: InboundMessage) -> None:
|
||||
"""Restart the process in-place via os.execv."""
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
|
||||
))
|
||||
|
||||
async def _do_restart():
|
||||
await asyncio.sleep(1)
|
||||
os.execv(sys.executable, [sys.executable] + sys.argv)
|
||||
|
||||
asyncio.create_task(_do_restart())
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
@@ -350,8 +346,9 @@ class AgentLoop:
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
history = session.get_history(max_messages=0)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
@@ -359,6 +356,7 @@ class AgentLoop:
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
@@ -371,27 +369,20 @@ class AgentLoop:
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
self._consolidating.add(session.key)
|
||||
try:
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if snapshot:
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
if not await self.memory_consolidator.archive_unconsolidated(session):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
@@ -399,33 +390,24 @@ class AgentLoop:
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||
|
||||
unconsolidated = len(session.messages) - session.last_consolidated
|
||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||
self._consolidating.add(session.key)
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
_task = asyncio.current_task()
|
||||
if _task is not None:
|
||||
self._consolidation_tasks.discard(_task)
|
||||
|
||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||
self._consolidation_tasks.add(_task)
|
||||
lines = [
|
||||
"🐈 nanobot commands:",
|
||||
"/new — Start a new conversation",
|
||||
"/stop — Stop the current task",
|
||||
"/restart — Restart the bot",
|
||||
"/help — Show available commands",
|
||||
]
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||
)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
history = session.get_history(max_messages=0)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
@@ -450,6 +432,7 @@ class AgentLoop:
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
@@ -496,13 +479,6 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
||||
return await MemoryStore(self.workspace).consolidate(
|
||||
session, self.provider, self.model,
|
||||
archive_all=archive_all, memory_window=self.memory_window,
|
||||
)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@@ -2,17 +2,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
_SAVE_MEMORY_TOOL = [
|
||||
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
|
||||
"properties": {
|
||||
"history_entry": {
|
||||
"type": "string",
|
||||
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
||||
"description": "A paragraph summarizing key events/decisions/topics. "
|
||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||
},
|
||||
"memory_update": {
|
||||
@@ -42,6 +44,19 @@ _SAVE_MEMORY_TOOL = [
|
||||
]
|
||||
|
||||
|
||||
def _ensure_text(value: Any) -> str:
|
||||
"""Normalize tool-call payload values to text for file storage."""
|
||||
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
||||
"""Normalize provider tool-call arguments to the expected dict shape."""
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
if isinstance(args, list):
|
||||
return args[0] if args and isinstance(args[0], dict) else None
|
||||
return args if isinstance(args, dict) else None
|
||||
|
||||
class MemoryStore:
|
||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||
|
||||
@@ -66,40 +81,27 @@ class MemoryStore:
|
||||
long_term = self.read_long_term()
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
@staticmethod
|
||||
def _format_messages(messages: list[dict]) -> str:
|
||||
lines = []
|
||||
for message in messages:
|
||||
if not message.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
|
||||
lines.append(
|
||||
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
async def consolidate(
|
||||
self,
|
||||
session: Session,
|
||||
messages: list[dict],
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
*,
|
||||
archive_all: bool = False,
|
||||
memory_window: int = 50,
|
||||
) -> bool:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||
|
||||
Returns True on success (including no-op), False on failure.
|
||||
"""
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
||||
else:
|
||||
keep_count = memory_window // 2
|
||||
if len(session.messages) <= keep_count:
|
||||
return True
|
||||
if len(session.messages) - session.last_consolidated <= 0:
|
||||
return True
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return True
|
||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
||||
|
||||
lines = []
|
||||
for m in old_messages:
|
||||
if not m.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
||||
"""Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
|
||||
if not messages:
|
||||
return True
|
||||
|
||||
current_memory = self.read_long_term()
|
||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||
@@ -108,50 +110,175 @@ class MemoryStore:
|
||||
{current_memory or "(empty)"}
|
||||
|
||||
## Conversation to Process
|
||||
{chr(10).join(lines)}"""
|
||||
{self._format_messages(messages)}"""
|
||||
|
||||
try:
|
||||
response = await provider.chat(
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=_SAVE_MEMORY_TOOL,
|
||||
model=model,
|
||||
tool_choice="required",
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||
return False
|
||||
|
||||
args = response.tool_calls[0].arguments
|
||||
# Some providers return arguments as a JSON string instead of dict
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
# Some providers return arguments as a list (handle edge case)
|
||||
if isinstance(args, list):
|
||||
if args and isinstance(args[0], dict):
|
||||
args = args[0]
|
||||
else:
|
||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
||||
return False
|
||||
if not isinstance(args, dict):
|
||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||
if args is None:
|
||||
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||
return False
|
||||
|
||||
if entry := args.get("history_entry"):
|
||||
if not isinstance(entry, str):
|
||||
entry = json.dumps(entry, ensure_ascii=False)
|
||||
self.append_history(entry)
|
||||
self.append_history(_ensure_text(entry))
|
||||
if update := args.get("memory_update"):
|
||||
if not isinstance(update, str):
|
||||
update = json.dumps(update, ensure_ascii=False)
|
||||
update = _ensure_text(update)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Memory consolidation failed")
|
||||
return False
|
||||
|
||||
|
||||
class MemoryConsolidator:
|
||||
"""Owns consolidation policy, locking, and session offset updates."""
|
||||
|
||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
sessions: SessionManager,
|
||||
context_window_tokens: int,
|
||||
build_messages: Callable[..., list[dict[str, Any]]],
|
||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||
):
|
||||
self.store = MemoryStore(workspace)
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.sessions = sessions
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
|
||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||
"""Return the shared consolidation lock for one session."""
|
||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||
"""Archive a selected message chunk into persistent memory."""
|
||||
return await self.store.consolidate(messages, self.provider, self.model)
|
||||
|
||||
def pick_consolidation_boundary(
|
||||
self,
|
||||
session: Session,
|
||||
tokens_to_remove: int,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Pick a user-turn boundary that removes enough old prompt tokens."""
|
||||
start = session.last_consolidated
|
||||
if start >= len(session.messages) or tokens_to_remove <= 0:
|
||||
return None
|
||||
|
||||
removed_tokens = 0
|
||||
last_boundary: tuple[int, int] | None = None
|
||||
for idx in range(start, len(session.messages)):
|
||||
message = session.messages[idx]
|
||||
if idx > start and message.get("role") == "user":
|
||||
last_boundary = (idx, removed_tokens)
|
||||
if removed_tokens >= tokens_to_remove:
|
||||
return last_boundary
|
||||
removed_tokens += estimate_message_tokens(message)
|
||||
|
||||
return last_boundary
|
||||
|
||||
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||
"""Estimate current prompt size for the normal session history view."""
|
||||
history = session.get_history(max_messages=0)
|
||||
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||
probe_messages = self._build_messages(
|
||||
history=history,
|
||||
current_message="[token-probe]",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return estimate_prompt_tokens_chain(
|
||||
self.provider,
|
||||
self.model,
|
||||
probe_messages,
|
||||
self._get_tool_definitions(),
|
||||
)
|
||||
|
||||
async def archive_unconsolidated(self, session: Session) -> bool:
|
||||
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if not snapshot:
|
||||
return True
|
||||
return await self.consolidate_messages(snapshot)
|
||||
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
target = self.context_window_tokens // 2
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
if estimated < self.context_window_tokens:
|
||||
logger.debug(
|
||||
"Token consolidation idle {}: {}/{} via {}",
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
)
|
||||
return
|
||||
|
||||
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
|
||||
if estimated <= target:
|
||||
return
|
||||
|
||||
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
|
||||
if boundary is None:
|
||||
logger.debug(
|
||||
"Token consolidation: no safe boundary for {} (round {})",
|
||||
session.key,
|
||||
round_num,
|
||||
)
|
||||
return
|
||||
|
||||
end_idx = boundary[0]
|
||||
chunk = session.messages[session.last_consolidated:end_idx]
|
||||
if not chunk:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
|
||||
round_num,
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
len(chunk),
|
||||
)
|
||||
if not await self.consolidate_messages(chunk):
|
||||
return
|
||||
session.last_consolidated = end_idx
|
||||
self.sessions.save(session)
|
||||
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
|
||||
@@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
@@ -27,9 +28,6 @@ class SubagentManager:
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
reasoning_effort: str | None = None,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
@@ -40,9 +38,6 @@ class SubagentManager:
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@@ -123,33 +118,23 @@ class SubagentManager:
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": tool_call_dicts,
|
||||
})
|
||||
messages.append(build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
@@ -230,7 +215,7 @@ Stay focused on the assigned task. Your final response will be reported back to
|
||||
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
async def cancel_by_session(self, session_key: str) -> int:
|
||||
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||
|
||||
@@ -52,6 +52,75 @@ class Tool(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply safe schema-driven casts before validation."""
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
return params
|
||||
|
||||
return self._cast_object(params, schema)
|
||||
|
||||
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cast an object (dict) according to schema."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for key, value in obj.items():
|
||||
if key in props:
|
||||
result[key] = self._cast_value(value, props[key])
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||
"""Cast a single value according to schema."""
|
||||
target_type = schema.get("type")
|
||||
|
||||
if target_type == "boolean" and isinstance(val, bool):
|
||||
return val
|
||||
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||
return val
|
||||
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||
expected = self._TYPE_MAP[target_type]
|
||||
if isinstance(val, expected):
|
||||
return val
|
||||
|
||||
if target_type == "integer" and isinstance(val, str):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "number" and isinstance(val, str):
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "string":
|
||||
return val if val is None else str(val)
|
||||
|
||||
if target_type == "boolean" and isinstance(val, str):
|
||||
val_lower = val.lower()
|
||||
if val_lower in ("true", "1", "yes"):
|
||||
return True
|
||||
if val_lower in ("false", "0", "no"):
|
||||
return False
|
||||
return val
|
||||
|
||||
if target_type == "array" and isinstance(val, list):
|
||||
item_schema = schema.get("items")
|
||||
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||
|
||||
if target_type == "object" and isinstance(val, dict):
|
||||
return self._cast_object(val, schema)
|
||||
|
||||
return val
|
||||
|
||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||
if not isinstance(params, dict):
|
||||
@@ -63,7 +132,13 @@ class Tool(ABC):
|
||||
|
||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||
t, label = schema.get("type"), path or "parameter"
|
||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
||||
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||
return [f"{label} should be integer"]
|
||||
if t == "number" and (
|
||||
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||
):
|
||||
return [f"{label} should be number"]
|
||||
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||
return [f"{label} should be {t}"]
|
||||
|
||||
errors = []
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""File system tools: read, write, edit."""
|
||||
"""File system tools: read, write, edit, list."""
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
@@ -23,62 +23,108 @@ def _resolve_path(
|
||||
return resolved
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Tool to read file contents."""
|
||||
|
||||
_MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context
|
||||
class _FsTool(Tool):
|
||||
"""Shared base for filesystem tools — common init and path resolution."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
def _resolve(self, path: str) -> Path:
|
||||
return _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ReadFileTool(_FsTool):
|
||||
"""Read file contents with optional line-based pagination."""
|
||||
|
||||
_MAX_CHARS = 128_000
|
||||
_DEFAULT_LIMIT = 2000
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Read the contents of a file at the given path."
|
||||
return (
|
||||
"Read the contents of a file. Returns numbered lines. "
|
||||
"Use offset and limit to paginate through large files."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string", "description": "The file path to read"}},
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to read"},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default 1)",
|
||||
"minimum": 1,
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default 2000)",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not file_path.exists():
|
||||
fp = self._resolve(path)
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not file_path.is_file():
|
||||
if not fp.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
size = file_path.stat().st_size
|
||||
if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes)
|
||||
return (
|
||||
f"Error: File too large ({size:,} bytes). "
|
||||
f"Use exec tool with head/tail/grep to read portions."
|
||||
)
|
||||
all_lines = fp.read_text(encoding="utf-8").splitlines()
|
||||
total = len(all_lines)
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
if len(content) > self._MAX_CHARS:
|
||||
return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
|
||||
return content
|
||||
if offset < 1:
|
||||
offset = 1
|
||||
if total == 0:
|
||||
return f"(Empty file: {path})"
|
||||
if offset > total:
|
||||
return f"Error: offset {offset} is beyond end of file ({total} lines)"
|
||||
|
||||
start = offset - 1
|
||||
end = min(start + (limit or self._DEFAULT_LIMIT), total)
|
||||
numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])]
|
||||
result = "\n".join(numbered)
|
||||
|
||||
if len(result) > self._MAX_CHARS:
|
||||
trimmed, chars = [], 0
|
||||
for line in numbered:
|
||||
chars += len(line) + 1
|
||||
if chars > self._MAX_CHARS:
|
||||
break
|
||||
trimmed.append(line)
|
||||
end = start + len(trimmed)
|
||||
result = "\n".join(trimmed)
|
||||
|
||||
if end < total:
|
||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||
else:
|
||||
result += f"\n\n(End of file — {total} lines total)"
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
"""Tool to write content to a file."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
class WriteFileTool(_FsTool):
|
||||
"""Write content to a file."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -101,22 +147,48 @@ class WriteFileTool(Tool):
|
||||
|
||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
return f"Successfully wrote {len(content)} bytes to {file_path}"
|
||||
fp = self._resolve(path)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(content, encoding="utf-8")
|
||||
return f"Successfully wrote {len(content)} bytes to {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error writing file: {str(e)}"
|
||||
return f"Error writing file: {e}"
|
||||
|
||||
|
||||
class EditFileTool(Tool):
|
||||
"""Tool to edit a file by replacing text."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# edit_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
"""Locate old_text in content: exact first, then line-trimmed sliding window.
|
||||
|
||||
Both inputs should use LF line endings (caller normalises CRLF).
|
||||
Returns (matched_fragment, count) or (None, 0).
|
||||
"""
|
||||
if old_text in content:
|
||||
return old_text, content.count(old_text)
|
||||
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
return None, 0
|
||||
stripped_old = [l.strip() for l in old_lines]
|
||||
content_lines = content.splitlines()
|
||||
|
||||
candidates = []
|
||||
for i in range(len(content_lines) - len(stripped_old) + 1):
|
||||
window = content_lines[i : i + len(stripped_old)]
|
||||
if [l.strip() for l in window] == stripped_old:
|
||||
candidates.append("\n".join(window))
|
||||
|
||||
if candidates:
|
||||
return candidates[0], len(candidates)
|
||||
return None, 0
|
||||
|
||||
|
||||
class EditFileTool(_FsTool):
|
||||
"""Edit a file by replacing text with fallback matching."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -124,7 +196,11 @@ class EditFileTool(Tool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Supports minor whitespace/line-ending differences. "
|
||||
"Set replace_all=true to replace every occurrence."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
@@ -132,40 +208,52 @@ class EditFileTool(Tool):
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to edit"},
|
||||
"old_text": {"type": "string", "description": "The exact text to find and replace"},
|
||||
"old_text": {"type": "string", "description": "The text to find and replace"},
|
||||
"new_text": {"type": "string", "description": "The text to replace with"},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default false)",
|
||||
},
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
||||
async def execute(
|
||||
self, path: str, old_text: str, new_text: str,
|
||||
replace_all: bool = False, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not file_path.exists():
|
||||
fp = self._resolve(path)
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
raw = fp.read_bytes()
|
||||
uses_crlf = b"\r\n" in raw
|
||||
content = raw.decode("utf-8").replace("\r\n", "\n")
|
||||
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
|
||||
|
||||
if old_text not in content:
|
||||
return self._not_found_message(old_text, content, path)
|
||||
if match is None:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
if count > 1 and not replace_all:
|
||||
return (
|
||||
f"Warning: old_text appears {count} times. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
)
|
||||
|
||||
# Count occurrences
|
||||
count = content.count(old_text)
|
||||
if count > 1:
|
||||
return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
new_content = content.replace(old_text, new_text, 1)
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
return f"Successfully edited {file_path}"
|
||||
fp.write_bytes(new_content.encode("utf-8"))
|
||||
return f"Successfully edited {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {str(e)}"
|
||||
return f"Error editing file: {e}"
|
||||
|
||||
@staticmethod
|
||||
def _not_found_message(old_text: str, content: str, path: str) -> str:
|
||||
"""Build a helpful error when old_text is not found."""
|
||||
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
@@ -177,27 +265,29 @@ class EditFileTool(Tool):
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(
|
||||
difflib.unified_diff(
|
||||
old_lines,
|
||||
lines[best_start : best_start + window],
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
)
|
||||
)
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines, lines[best_start : best_start + window],
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
return (
|
||||
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
)
|
||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
|
||||
|
||||
class ListDirTool(Tool):
|
||||
"""Tool to list directory contents."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
class ListDirTool(_FsTool):
|
||||
"""List directory contents with optional recursion."""
|
||||
|
||||
_DEFAULT_MAX = 200
|
||||
_IGNORE_DIRS = {
|
||||
".git", "node_modules", "__pycache__", ".venv", "venv",
|
||||
"dist", "build", ".tox", ".mypy_cache", ".pytest_cache",
|
||||
".ruff_cache", ".coverage", "htmlcov",
|
||||
}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -205,34 +295,71 @@ class ListDirTool(Tool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "List the contents of a directory."
|
||||
return (
|
||||
"List the contents of a directory. "
|
||||
"Set recursive=true to explore nested structure. "
|
||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string", "description": "The directory path to list"}},
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The directory path to list"},
|
||||
"recursive": {
|
||||
"type": "boolean",
|
||||
"description": "Recursively list all files (default false)",
|
||||
},
|
||||
"max_entries": {
|
||||
"type": "integer",
|
||||
"description": "Maximum entries to return (default 200)",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
async def execute(
|
||||
self, path: str, recursive: bool = False,
|
||||
max_entries: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not dir_path.exists():
|
||||
dp = self._resolve(path)
|
||||
if not dp.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
if not dir_path.is_dir():
|
||||
if not dp.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
items = []
|
||||
for item in sorted(dir_path.iterdir()):
|
||||
prefix = "📁 " if item.is_dir() else "📄 "
|
||||
items.append(f"{prefix}{item.name}")
|
||||
cap = max_entries or self._DEFAULT_MAX
|
||||
items: list[str] = []
|
||||
total = 0
|
||||
|
||||
if not items:
|
||||
if recursive:
|
||||
for item in sorted(dp.rglob("*")):
|
||||
if any(p in self._IGNORE_DIRS for p in item.parts):
|
||||
continue
|
||||
total += 1
|
||||
if len(items) < cap:
|
||||
rel = item.relative_to(dp)
|
||||
items.append(f"{rel}/" if item.is_dir() else str(rel))
|
||||
else:
|
||||
for item in sorted(dp.iterdir()):
|
||||
if item.name in self._IGNORE_DIRS:
|
||||
continue
|
||||
total += 1
|
||||
if len(items) < cap:
|
||||
pfx = "📁 " if item.is_dir() else "📄 "
|
||||
items.append(f"{pfx}{item.name}")
|
||||
|
||||
if not items and total == 0:
|
||||
return f"Directory {path} is empty"
|
||||
|
||||
return "\n".join(items)
|
||||
result = "\n".join(items)
|
||||
if total > cap:
|
||||
result += f"\n\n(truncated, showing first {cap} of {total} entries)"
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {str(e)}"
|
||||
return f"Error listing directory: {e}"
|
||||
|
||||
@@ -36,6 +36,7 @@ class MCPToolWrapper(Tool):
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||
@@ -44,6 +45,23 @@ class MCPToolWrapper(Tool):
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||
return f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||
except asyncio.CancelledError:
|
||||
# MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure.
|
||||
# Re-raise only if our task was externally cancelled (e.g. /stop).
|
||||
task = asyncio.current_task()
|
||||
if task is not None and task.cancelling() > 0:
|
||||
raise
|
||||
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP tool call was cancelled)"
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"MCP tool '{}' failed: {}: {}",
|
||||
self._name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return f"(MCP tool call failed: {type(exc).__name__})"
|
||||
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
|
||||
@@ -96,7 +96,7 @@ class MessageTool(Tool):
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -44,6 +44,10 @@ class ToolRegistry:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
|
||||
@@ -42,6 +42,9 @@ class ExecTool(Tool):
|
||||
def name(self) -> str:
|
||||
return "exec"
|
||||
|
||||
_MAX_TIMEOUT = 600
|
||||
_MAX_OUTPUT = 10_000
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
@@ -53,22 +56,36 @@ class ExecTool(Tool):
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command"
|
||||
}
|
||||
"description": "Optional working directory for the command",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Timeout in seconds. Increase for long-running commands "
|
||||
"like compilation or installation (default 60, max 600)."
|
||||
),
|
||||
"minimum": 1,
|
||||
"maximum": 600,
|
||||
},
|
||||
},
|
||||
"required": ["command"]
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
|
||||
|
||||
async def execute(
|
||||
self, command: str, working_dir: str | None = None,
|
||||
timeout: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
guard_error = self._guard_command(command, cwd)
|
||||
if guard_error:
|
||||
return guard_error
|
||||
|
||||
|
||||
effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
|
||||
|
||||
env = os.environ.copy()
|
||||
if self.path_append:
|
||||
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||
@@ -81,44 +98,46 @@ class ExecTool(Tool):
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=self.timeout
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
# Wait for the process to fully terminate so pipes are
|
||||
# drained and file descriptors are released.
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
return f"Error: Command timed out after {self.timeout} seconds"
|
||||
|
||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
||||
|
||||
output_parts = []
|
||||
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
if process.returncode != 0:
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
# Truncate very long output
|
||||
max_len = 10000
|
||||
|
||||
# Head + tail truncation to preserve both start and end of output
|
||||
max_len = self._MAX_OUTPUT
|
||||
if len(result) > max_len:
|
||||
result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)"
|
||||
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
@@ -143,7 +162,8 @@ class ExecTool(Tool):
|
||||
|
||||
for raw in self._extract_absolute_paths(cmd):
|
||||
try:
|
||||
p = Path(raw.strip()).resolve()
|
||||
expanded = os.path.expandvars(raw.strip())
|
||||
p = Path(expanded).expanduser().resolve()
|
||||
except Exception:
|
||||
continue
|
||||
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
|
||||
@@ -154,5 +174,6 @@ class ExecTool(Tool):
|
||||
@staticmethod
|
||||
def _extract_absolute_paths(command: str) -> list[str]:
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
|
||||
return win_paths + posix_paths
|
||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
|
||||
return win_paths + posix_paths + home_paths
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Base channel interface for chat platforms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@@ -18,6 +21,8 @@ class BaseChannel(ABC):
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
display_name: str = "Base"
|
||||
transcription_api_key: str = ""
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
@@ -31,6 +36,19 @@ class BaseChannel(ABC):
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
async def transcribe_audio(self, file_path: str | Path) -> str:
|
||||
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
|
||||
if not self.transcription_api_key:
|
||||
return ""
|
||||
try:
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
|
||||
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
|
||||
return await provider.transcribe(file_path)
|
||||
except Exception as e:
|
||||
logger.warning("{}: audio transcription failed: {}", self.name, e)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
@@ -66,10 +84,7 @@ class BaseChannel(ABC):
|
||||
return False
|
||||
if "*" in allow_list:
|
||||
return True
|
||||
sender_str = str(sender_id)
|
||||
return sender_str in allow_list or any(
|
||||
p in allow_list for p in sender_str.split("|") if p
|
||||
)
|
||||
return str(sender_id) in allow_list
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
|
||||
@@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler):
|
||||
content = ""
|
||||
if chatbot_msg.text:
|
||||
content = chatbot_msg.text.content.strip()
|
||||
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
|
||||
content = chatbot_msg.extensions["content"]["recognition"].strip()
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
@@ -70,12 +72,24 @@ class NanobotDingTalkHandler(CallbackHandler):
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
conversation_type = message.data.get("conversationType")
|
||||
conversation_id = (
|
||||
message.data.get("conversationId")
|
||||
or message.data.get("openConversationId")
|
||||
)
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(content, sender_id, sender_name)
|
||||
self.channel._on_message(
|
||||
content,
|
||||
sender_id,
|
||||
sender_name,
|
||||
conversation_type,
|
||||
conversation_id,
|
||||
)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
@@ -95,11 +109,12 @@ class DingTalkChannel(BaseChannel):
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Note: Currently only supports private (1:1) chat. Group messages are
|
||||
received but replies are sent back as private messages to the sender.
|
||||
Supports both private (1:1) and group chats.
|
||||
Group chat_id is stored with a "group:" prefix to route replies back.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
display_name = "DingTalk"
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||
@@ -301,14 +316,25 @@ class DingTalkChannel(BaseChannel):
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return False
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [chat_id],
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
if chat_id.startswith("group:"):
|
||||
# Group chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"openConversationId": chat_id[6:], # Remove "group:" prefix,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
else:
|
||||
# Private chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [chat_id],
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=payload, headers=headers)
|
||||
@@ -417,7 +443,14 @@ class DingTalkChannel(BaseChannel):
|
||||
f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||
async def _on_message(
|
||||
self,
|
||||
content: str,
|
||||
sender_id: str,
|
||||
sender_name: str,
|
||||
conversation_type: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
@@ -425,13 +458,16 @@ class DingTalkChannel(BaseChannel):
|
||||
"""
|
||||
try:
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
is_group = conversation_type == "2" and conversation_id
|
||||
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
||||
chat_id=chat_id,
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
"conversation_type": conversation_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,39 +12,20 @@ from loguru import logger
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if not content:
|
||||
return []
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos <= 0:
|
||||
pos = cut.rfind(' ')
|
||||
if pos <= 0:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
|
||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -96,7 +77,7 @@ class DiscordChannel(BaseChannel):
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API."""
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
@@ -105,15 +86,31 @@ class DiscordChannel(BaseChannel):
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
chunks = _split_message(msg.content or "")
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Only set reply reference on the first chunk
|
||||
if i == 0 and msg.reply_to:
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
@@ -144,6 +141,54 @@ class DiscordChannel(BaseChannel):
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
@@ -246,7 +291,7 @@ class DiscordChannel(BaseChannel):
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
|
||||
@@ -35,6 +35,7 @@ class EmailChannel(BaseChannel):
|
||||
"""
|
||||
|
||||
name = "email"
|
||||
display_name = "Email"
|
||||
_IMAP_MONTHS = (
|
||||
"Jan",
|
||||
"Feb",
|
||||
|
||||
@@ -14,6 +14,7 @@ from loguru import logger
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import FeishuConfig
|
||||
|
||||
import importlib.util
|
||||
@@ -243,6 +244,7 @@ class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
|
||||
name = "feishu"
|
||||
display_name = "Feishu"
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -253,6 +255,12 @@ class FeishuChannel(BaseChannel):
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@staticmethod
|
||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||
"""Register an event handler only when the SDK supports it."""
|
||||
method = getattr(builder, method_name, None)
|
||||
return method(handler) if callable(method) else builder
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
@@ -273,14 +281,24 @@ class FeishuChannel(BaseChannel):
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
|
||||
# Create event handler (only register message receive, ignore other events)
|
||||
event_handler = lark.EventDispatcherHandler.builder(
|
||||
builder = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
).build()
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_message_read_v1", self._on_message_read
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder,
|
||||
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
|
||||
self._on_bot_p2p_chat_entered,
|
||||
)
|
||||
event_handler = builder.build()
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
@@ -334,6 +352,27 @@ class FeishuChannel(BaseChannel):
|
||||
self._running = False
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
def _is_bot_mentioned(self, message: Any) -> bool:
|
||||
"""Check if the bot is @mentioned in the message."""
|
||||
raw_content = message.content or ""
|
||||
if "@_all" in raw_content:
|
||||
return True
|
||||
|
||||
for mention in getattr(message, "mentions", None) or []:
|
||||
mid = getattr(mention, "id", None)
|
||||
if not mid:
|
||||
continue
|
||||
# Bot mentions have no user_id (None or "") but a valid open_id
|
||||
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_group_message_for_bot(self, message: Any) -> bool:
|
||||
"""Allow group messages when policy is open or bot is @mentioned."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
return self._is_bot_mentioned(message)
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||
@@ -715,8 +754,7 @@ class FeishuChannel(BaseChannel):
|
||||
(file_path, content_text) - file_path is None if download failed
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
media_dir = get_media_dir("feishu")
|
||||
|
||||
data, filename = None, None
|
||||
|
||||
@@ -736,8 +774,9 @@ class FeishuChannel(BaseChannel):
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
||||
filename = f"{file_key[:16]}{ext}"
|
||||
filename = file_key[:16]
|
||||
if msg_type == "audio" and not filename.endswith(".opus"):
|
||||
filename = f"{filename}.opus"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
@@ -841,7 +880,7 @@ class FeishuChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
def _on_message_sync(self, data: Any) -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
Schedules async handling in the main event loop.
|
||||
@@ -849,7 +888,7 @@ class FeishuChannel(BaseChannel):
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
async def _on_message(self, data: Any) -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
event = data.event
|
||||
@@ -875,6 +914,10 @@ class FeishuChannel(BaseChannel):
|
||||
chat_type = message.chat_type
|
||||
msg_type = message.message_type
|
||||
|
||||
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||
return
|
||||
|
||||
# Add reaction
|
||||
await self._add_reaction(message_id, self.config.react_emoji)
|
||||
|
||||
@@ -909,6 +952,12 @@ class FeishuChannel(BaseChannel):
|
||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
|
||||
if msg_type == "audio" and file_path:
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
if transcription:
|
||||
content_text = f"[transcription: {transcription}]"
|
||||
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||
@@ -941,3 +990,16 @@ class FeishuChannel(BaseChannel):
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing Feishu message: {}", e)
|
||||
|
||||
def _on_reaction_created(self, data: Any) -> None:
|
||||
"""Ignore reaction events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_message_read(self, data: Any) -> None:
|
||||
"""Ignore read events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
||||
"""Ignore p2p-enter events when a user opens a bot chat."""
|
||||
logger.debug("Bot entered p2p chat (user opened chat window)")
|
||||
pass
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
@@ -32,122 +31,23 @@ class ChannelManager:
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels based on config."""
|
||||
"""Initialize channels discovered via pkgutil scan."""
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
|
||||
# Telegram channel
|
||||
if self.config.channels.telegram.enabled:
|
||||
groq_key = self.config.providers.groq.api_key
|
||||
|
||||
for modname in discover_channel_names():
|
||||
section = getattr(self.config.channels, modname, None)
|
||||
if not section or not getattr(section, "enabled", False):
|
||||
continue
|
||||
try:
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
self.channels["telegram"] = TelegramChannel(
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
cls = load_channel_class(modname)
|
||||
channel = cls(section, self.bus)
|
||||
channel.transcription_api_key = groq_key
|
||||
self.channels[modname] = channel
|
||||
logger.info("{} channel enabled", cls.display_name)
|
||||
except ImportError as e:
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
# WhatsApp channel
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
self.channels["whatsapp"] = WhatsAppChannel(
|
||||
self.config.channels.whatsapp, self.bus
|
||||
)
|
||||
logger.info("WhatsApp channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("WhatsApp channel not available: {}", e)
|
||||
|
||||
# Discord channel
|
||||
if self.config.channels.discord.enabled:
|
||||
try:
|
||||
from nanobot.channels.discord import DiscordChannel
|
||||
self.channels["discord"] = DiscordChannel(
|
||||
self.config.channels.discord, self.bus
|
||||
)
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
# Feishu channel
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
self.channels["feishu"] = FeishuChannel(
|
||||
self.config.channels.feishu, self.bus
|
||||
)
|
||||
logger.info("Feishu channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Feishu channel not available: {}", e)
|
||||
|
||||
# Mochat channel
|
||||
if self.config.channels.mochat.enabled:
|
||||
try:
|
||||
from nanobot.channels.mochat import MochatChannel
|
||||
|
||||
self.channels["mochat"] = MochatChannel(
|
||||
self.config.channels.mochat, self.bus
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Mochat channel not available: {}", e)
|
||||
|
||||
# DingTalk channel
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
try:
|
||||
from nanobot.channels.dingtalk import DingTalkChannel
|
||||
self.channels["dingtalk"] = DingTalkChannel(
|
||||
self.config.channels.dingtalk, self.bus
|
||||
)
|
||||
logger.info("DingTalk channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("DingTalk channel not available: {}", e)
|
||||
|
||||
# Email channel
|
||||
if self.config.channels.email.enabled:
|
||||
try:
|
||||
from nanobot.channels.email import EmailChannel
|
||||
self.channels["email"] = EmailChannel(
|
||||
self.config.channels.email, self.bus
|
||||
)
|
||||
logger.info("Email channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Email channel not available: {}", e)
|
||||
|
||||
# Slack channel
|
||||
if self.config.channels.slack.enabled:
|
||||
try:
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
self.channels["slack"] = SlackChannel(
|
||||
self.config.channels.slack, self.bus
|
||||
)
|
||||
logger.info("Slack channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Slack channel not available: {}", e)
|
||||
|
||||
# QQ channel
|
||||
if self.config.channels.qq.enabled:
|
||||
try:
|
||||
from nanobot.channels.qq import QQChannel
|
||||
self.channels["qq"] = QQChannel(
|
||||
self.config.channels.qq,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
# Matrix channel
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
from nanobot.channels.matrix import MatrixChannel
|
||||
self.channels["matrix"] = MatrixChannel(
|
||||
self.config.channels.matrix,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
logger.warning("{} channel not available: {}", modname, e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
|
||||
@@ -37,8 +37,9 @@ except ImportError as e:
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.loader import get_data_dir
|
||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
@@ -146,15 +147,15 @@ class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
|
||||
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
|
||||
workspace: Path | None = None):
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = restrict_to_workspace
|
||||
self._workspace = workspace.expanduser().resolve() if workspace else None
|
||||
self._restrict_to_workspace = False
|
||||
self._workspace: Path | None = None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
|
||||
@@ -490,9 +491,7 @@ class MatrixChannel(BaseChannel):
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
d = get_data_dir() / "media" / "matrix"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
return get_media_dir("matrix")
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
@@ -679,7 +678,14 @@ class MatrixChannel(BaseChannel):
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
if marker:
|
||||
|
||||
if attachment and attachment.get("type") == "audio":
|
||||
transcription = await self.transcribe_audio(attachment["path"])
|
||||
if transcription:
|
||||
parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
parts.append(marker)
|
||||
elif marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
|
||||
@@ -15,8 +15,8 @@ from loguru import logger
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
from nanobot.config.schema import MochatConfig
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
|
||||
try:
|
||||
import socketio
|
||||
@@ -216,6 +216,7 @@ class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
display_name = "Mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -224,7 +225,7 @@ class MochatChannel(BaseChannel):
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_data_path() / "mochat"
|
||||
self._state_dir = get_runtime_subdir("mochat")
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
botpy = None
|
||||
C2CMessage = None
|
||||
GroupMessage = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botpy.message import C2CMessage
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
@@ -38,10 +39,13 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
await channel._on_message(message)
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
async def on_group_at_message_create(self, message: "GroupMessage"):
|
||||
await channel._on_message(message, is_group=True)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
await channel._on_message(message)
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
return _Bot
|
||||
|
||||
@@ -50,6 +54,7 @@ class QQChannel(BaseChannel):
|
||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||
|
||||
name = "qq"
|
||||
display_name = "QQ"
|
||||
|
||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -57,6 +62,7 @@ class QQChannel(BaseChannel):
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
||||
self._chat_type_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
@@ -71,8 +77,7 @@ class QQChannel(BaseChannel):
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
|
||||
logger.info("QQ bot started (C2C private message)")
|
||||
logger.info("QQ bot started (C2C & Group supported)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
@@ -101,20 +106,31 @@ class QQChannel(BaseChannel):
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
self._msg_seq += 1 # 递增序列号
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=0,
|
||||
content=msg.content,
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq, # 添加序列号避免去重
|
||||
)
|
||||
self._msg_seq += 1
|
||||
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
if msg_type == "group":
|
||||
await self._client.api.post_group_message(
|
||||
group_openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
else:
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage") -> None:
|
||||
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
try:
|
||||
# Dedup by message ID
|
||||
@@ -122,18 +138,24 @@ class QQChannel(BaseChannel):
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
author = data.author
|
||||
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
|
||||
content = (data.content or "").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
|
||||
|
||||
35
nanobot/channels/registry.py
Normal file
35
nanobot/channels/registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Auto-discovery for channel modules — no hardcoded registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
_INTERNAL = frozenset({"base", "manager", "registry"})
|
||||
|
||||
|
||||
def discover_channel_names() -> list[str]:
|
||||
"""Return all channel module names by scanning the package (zero imports)."""
|
||||
import nanobot.channels as pkg
|
||||
|
||||
return [
|
||||
name
|
||||
for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
|
||||
if name not in _INTERNAL and not ispkg
|
||||
]
|
||||
|
||||
|
||||
def load_channel_class(module_name: str) -> type[BaseChannel]:
|
||||
"""Import *module_name* and return the first BaseChannel subclass found."""
|
||||
from nanobot.channels.base import BaseChannel as _Base
|
||||
|
||||
mod = importlib.import_module(f"nanobot.channels.{module_name}")
|
||||
for attr in dir(mod):
|
||||
obj = getattr(mod, attr)
|
||||
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
|
||||
return obj
|
||||
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
|
||||
@@ -21,6 +21,7 @@ class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
name = "slack"
|
||||
display_name = "Slack"
|
||||
|
||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -81,14 +82,15 @@ class SlackChannel(BaseChannel):
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||
use_thread = thread_ts and channel_type != "im"
|
||||
thread_ts_param = thread_ts if use_thread else None
|
||||
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
||||
|
||||
if msg.content:
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content),
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
@@ -277,4 +279,3 @@ class SlackChannel(BaseChannel):
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, ReplyParameters, Update
|
||||
@@ -13,7 +15,53 @@ from telegram.request import HTTPXRequest
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||
TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
|
||||
|
||||
|
||||
def _strip_md(s: str) -> str:
|
||||
"""Strip markdown inline formatting from text."""
|
||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def _render_table_box(table_lines: list[str]) -> str:
|
||||
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
|
||||
|
||||
def dw(s: str) -> int:
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||
|
||||
rows: list[list[str]] = []
|
||||
has_sep = False
|
||||
for line in table_lines:
|
||||
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
|
||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||
has_sep = True
|
||||
continue
|
||||
rows.append(cells)
|
||||
if not rows or not has_sep:
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
ncols = max(len(r) for r in rows)
|
||||
for r in rows:
|
||||
r.extend([''] * (ncols - len(r)))
|
||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||
|
||||
def dr(cells: list[str]) -> str:
|
||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||
|
||||
out = [dr(rows[0])]
|
||||
out.append(' '.join('─' * w for w in widths))
|
||||
for row in rows[1:]:
|
||||
out.append(dr(row))
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
@@ -31,6 +79,27 @@ def _markdown_to_telegram_html(text: str) -> str:
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
|
||||
lines = text.split('\n')
|
||||
rebuilt: list[str] = []
|
||||
li = 0
|
||||
while li < len(lines):
|
||||
if re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl: list[str] = []
|
||||
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl.append(lines[li])
|
||||
li += 1
|
||||
box = _render_table_box(tbl)
|
||||
if box != '\n'.join(tbl):
|
||||
code_blocks.append(box)
|
||||
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
|
||||
else:
|
||||
rebuilt.extend(tbl)
|
||||
else:
|
||||
rebuilt.append(lines[li])
|
||||
li += 1
|
||||
text = '\n'.join(rebuilt)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
@@ -79,26 +148,6 @@ def _markdown_to_telegram_html(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos == -1:
|
||||
pos = cut.rfind(' ')
|
||||
if pos == -1:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
@@ -107,6 +156,7 @@ class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
display_name = "Telegram"
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
@@ -114,22 +164,39 @@ class TelegramChannel(BaseChannel):
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
BotCommand("restart", "Restart the bot"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
):
|
||||
def __init__(self, config: TelegramConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
self._media_group_buffers: dict[str, dict] = {}
|
||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||
self._message_threads: dict[tuple[str, int], int] = {}
|
||||
self._bot_user_id: int | None = None
|
||||
self._bot_username: str | None = None
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Preserve Telegram's legacy id|username allowlist matching."""
|
||||
if super().is_allowed(sender_id):
|
||||
return True
|
||||
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list or "*" in allow_list:
|
||||
return False
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str.count("|") != 1:
|
||||
return False
|
||||
|
||||
sid, username = sender_str.split("|", 1)
|
||||
if not sid.isdigit() or not username:
|
||||
return False
|
||||
|
||||
return sid in allow_list or username in allow_list
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
@@ -140,16 +207,22 @@ class TelegramChannel(BaseChannel):
|
||||
self._running = True
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
req = HTTPXRequest(
|
||||
connection_pool_size=16,
|
||||
pool_timeout=5.0,
|
||||
connect_timeout=30.0,
|
||||
read_timeout=30.0,
|
||||
proxy=self.config.proxy if self.config.proxy else None,
|
||||
)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
if self.config.proxy:
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("restart", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
@@ -169,6 +242,8 @@ class TelegramChannel(BaseChannel):
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
self._bot_user_id = getattr(bot_info, "id", None)
|
||||
self._bot_username = getattr(bot_info, "username", None)
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
@@ -234,10 +309,16 @@ class TelegramChannel(BaseChannel):
|
||||
except ValueError:
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
message_thread_id = msg.metadata.get("message_thread_id")
|
||||
if message_thread_id is None and reply_to_message_id is not None:
|
||||
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||
thread_kwargs = {}
|
||||
if message_thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = message_thread_id
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
@@ -258,7 +339,8 @@ class TelegramChannel(BaseChannel):
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
@@ -266,48 +348,71 @@ class TelegramChannel(BaseChannel):
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
is_progress = msg.metadata.get("_progress", False)
|
||||
draft_id = msg.metadata.get("message_id")
|
||||
|
||||
for chunk in _split_message(msg.content):
|
||||
try:
|
||||
html = _markdown_to_telegram_html(chunk)
|
||||
if is_progress and draft_id:
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id,
|
||||
draft_id=draft_id,
|
||||
text=html,
|
||||
parse_mode="HTML"
|
||||
)
|
||||
else:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=html,
|
||||
parse_mode="HTML",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
if is_progress and draft_id:
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id,
|
||||
draft_id=draft_id,
|
||||
text=chunk
|
||||
)
|
||||
else:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=chunk,
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||
# Final response: simulate streaming via draft, then persist
|
||||
if not is_progress:
|
||||
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||
else:
|
||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||
|
||||
async def _send_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Send a plain text message with HTML fallback."""
|
||||
try:
|
||||
html = _markdown_to_telegram_html(text)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _send_with_streaming(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||
draft_id = int(time.time() * 1000) % (2**31)
|
||||
try:
|
||||
step = max(len(text) // 8, 40)
|
||||
for i in range(step, len(text), step):
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text[:i],
|
||||
)
|
||||
await asyncio.sleep(0.04)
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text,
|
||||
)
|
||||
await asyncio.sleep(0.15)
|
||||
except Exception:
|
||||
pass
|
||||
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
@@ -338,14 +443,180 @@ class TelegramChannel(BaseChannel):
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
@staticmethod
|
||||
def _derive_topic_session_key(message) -> str | None:
|
||||
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message.chat.type == "private" or message_thread_id is None:
|
||||
return None
|
||||
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||
|
||||
@staticmethod
|
||||
def _build_message_metadata(message, user) -> dict:
|
||||
"""Build common Telegram inbound metadata payload."""
|
||||
reply_to = getattr(message, "reply_to_message", None)
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private",
|
||||
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_reply_context(message) -> str | None:
|
||||
"""Extract text from the message being replied to, if any."""
|
||||
reply = getattr(message, "reply_to_message", None)
|
||||
if not reply:
|
||||
return None
|
||||
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
|
||||
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
|
||||
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
|
||||
return f"[Reply to: {text}]" if text else None
|
||||
|
||||
async def _download_message_media(
|
||||
self, msg, *, add_failure_content: bool = False
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Download media from a message (current or reply). Returns (media_paths, content_parts)."""
|
||||
media_file = None
|
||||
media_type = None
|
||||
if getattr(msg, "photo", None):
|
||||
media_file = msg.photo[-1]
|
||||
media_type = "image"
|
||||
elif getattr(msg, "voice", None):
|
||||
media_file = msg.voice
|
||||
media_type = "voice"
|
||||
elif getattr(msg, "audio", None):
|
||||
media_file = msg.audio
|
||||
media_type = "audio"
|
||||
elif getattr(msg, "document", None):
|
||||
media_file = msg.document
|
||||
media_type = "file"
|
||||
elif getattr(msg, "video", None):
|
||||
media_file = msg.video
|
||||
media_type = "video"
|
||||
elif getattr(msg, "video_note", None):
|
||||
media_file = msg.video_note
|
||||
media_type = "video"
|
||||
elif getattr(msg, "animation", None):
|
||||
media_file = msg.animation
|
||||
media_type = "animation"
|
||||
if not media_file or not self._app:
|
||||
return [], []
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(
|
||||
media_type,
|
||||
getattr(media_file, "mime_type", None),
|
||||
getattr(media_file, "file_name", None),
|
||||
)
|
||||
media_dir = get_media_dir("telegram")
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
path_str = str(file_path)
|
||||
if media_type in ("voice", "audio"):
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
return [path_str], [f"[transcription: {transcription}]"]
|
||||
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||
return [path_str], [f"[{media_type}: {path_str}]"]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download message media: {}", e)
|
||||
if add_failure_content:
|
||||
return [], [f"[{media_type}: download failed]"]
|
||||
return [], []
|
||||
|
||||
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
|
||||
"""Load bot identity once and reuse it for mention/reply checks."""
|
||||
if self._bot_user_id is not None or self._bot_username is not None:
|
||||
return self._bot_user_id, self._bot_username
|
||||
if not self._app:
|
||||
return None, None
|
||||
bot_info = await self._app.bot.get_me()
|
||||
self._bot_user_id = getattr(bot_info, "id", None)
|
||||
self._bot_username = getattr(bot_info, "username", None)
|
||||
return self._bot_user_id, self._bot_username
|
||||
|
||||
@staticmethod
|
||||
def _has_mention_entity(
|
||||
text: str,
|
||||
entities,
|
||||
bot_username: str,
|
||||
bot_id: int | None,
|
||||
) -> bool:
|
||||
"""Check Telegram mention entities against the bot username."""
|
||||
handle = f"@{bot_username}".lower()
|
||||
for entity in entities or []:
|
||||
entity_type = getattr(entity, "type", None)
|
||||
if entity_type == "text_mention":
|
||||
user = getattr(entity, "user", None)
|
||||
if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
|
||||
return True
|
||||
continue
|
||||
if entity_type != "mention":
|
||||
continue
|
||||
offset = getattr(entity, "offset", None)
|
||||
length = getattr(entity, "length", None)
|
||||
if offset is None or length is None:
|
||||
continue
|
||||
if text[offset : offset + length].lower() == handle:
|
||||
return True
|
||||
return handle in text.lower()
|
||||
|
||||
async def _is_group_message_for_bot(self, message) -> bool:
|
||||
"""Allow group messages when policy is open, @mentioned, or replying to the bot."""
|
||||
if message.chat.type == "private" or self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
bot_id, bot_username = await self._ensure_bot_identity()
|
||||
if bot_username:
|
||||
text = message.text or ""
|
||||
caption = message.caption or ""
|
||||
if self._has_mention_entity(
|
||||
text,
|
||||
getattr(message, "entities", None),
|
||||
bot_username,
|
||||
bot_id,
|
||||
):
|
||||
return True
|
||||
if self._has_mention_entity(
|
||||
caption,
|
||||
getattr(message, "caption_entities", None),
|
||||
bot_username,
|
||||
bot_id,
|
||||
):
|
||||
return True
|
||||
|
||||
reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
|
||||
return bool(bot_id and reply_user and reply_user.id == bot_id)
|
||||
|
||||
def _remember_thread_context(self, message) -> None:
|
||||
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message_thread_id is None:
|
||||
return
|
||||
key = (str(message.chat_id), message.message_id)
|
||||
self._message_threads[key] = message_thread_id
|
||||
if len(self._message_threads) > 1000:
|
||||
self._message_threads.pop(next(iter(self._message_threads)))
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
self._remember_thread_context(message)
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(update.effective_user),
|
||||
chat_id=str(update.message.chat_id),
|
||||
content=update.message.text,
|
||||
sender_id=self._sender_id(user),
|
||||
chat_id=str(message.chat_id),
|
||||
content=message.text or "",
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
@@ -357,10 +628,14 @@ class TelegramChannel(BaseChannel):
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
self._remember_thread_context(message)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
if not await self._is_group_message_for_bot(message):
|
||||
return
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
@@ -371,62 +646,33 @@ class TelegramChannel(BaseChannel):
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
elif message.voice:
|
||||
media_file = message.voice
|
||||
media_type = "voice"
|
||||
elif message.audio:
|
||||
media_file = message.audio
|
||||
media_type = "audio"
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||
|
||||
# Save to workspace/media/
|
||||
from pathlib import Path
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
# Download current message media
|
||||
current_media_paths, current_media_parts = await self._download_message_media(
|
||||
message, add_failure_content=True
|
||||
)
|
||||
media_paths.extend(current_media_paths)
|
||||
content_parts.extend(current_media_parts)
|
||||
if current_media_paths:
|
||||
logger.debug("Downloaded message media to {}", current_media_paths[0])
|
||||
|
||||
# Reply context: text and/or media from the replied-to message
|
||||
reply = getattr(message, "reply_to_message", None)
|
||||
if reply is not None:
|
||||
reply_ctx = self._extract_reply_context(message)
|
||||
reply_media, reply_media_parts = await self._download_message_media(reply)
|
||||
if reply_media:
|
||||
media_paths = reply_media + media_paths
|
||||
logger.debug("Attached replied-to media: {}", reply_media[0])
|
||||
tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
|
||||
if tag:
|
||||
content_parts.insert(0, tag)
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
metadata = self._build_message_metadata(message, user)
|
||||
session_key = self._derive_topic_session_key(message)
|
||||
|
||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||
if media_group_id := getattr(message, "media_group_id", None):
|
||||
@@ -435,11 +681,8 @@ class TelegramChannel(BaseChannel):
|
||||
self._media_group_buffers[key] = {
|
||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||
"contents": [], "media": [],
|
||||
"metadata": {
|
||||
"message_id": message.message_id, "user_id": user.id,
|
||||
"username": user.username, "first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private",
|
||||
},
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
}
|
||||
self._start_typing(str_chat_id)
|
||||
buf = self._media_group_buffers[key]
|
||||
@@ -459,13 +702,8 @@ class TelegramChannel(BaseChannel):
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private"
|
||||
}
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
async def _flush_media_group(self, key: str) -> None:
|
||||
@@ -479,6 +717,7 @@ class TelegramChannel(BaseChannel):
|
||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||
metadata=buf["metadata"],
|
||||
session_key=buf.get("session_key"),
|
||||
)
|
||||
finally:
|
||||
self._media_group_tasks.pop(key, None)
|
||||
@@ -510,8 +749,13 @@ class TelegramChannel(BaseChannel):
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||
"""Get file extension based on media type."""
|
||||
def _get_extension(
|
||||
self,
|
||||
media_type: str,
|
||||
mime_type: str | None,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get file extension based on media type or original filename."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
@@ -521,4 +765,12 @@ class TelegramChannel(BaseChannel):
|
||||
return ext_map[mime_type]
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
return type_map.get(media_type, "")
|
||||
if ext := type_map.get(media_type, ""):
|
||||
return ext
|
||||
|
||||
if filename:
|
||||
from pathlib import Path
|
||||
|
||||
return "".join(Path(filename).suffixes)
|
||||
|
||||
return ""
|
||||
|
||||
353
nanobot/channels/wecom.py
Normal file
353
nanobot/channels/wecom.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import WecomConfig
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"voice": "[voice]",
|
||||
"file": "[file]",
|
||||
"mixed": "[mixed content]",
|
||||
}
|
||||
|
||||
|
||||
class WecomChannel(BaseChannel):
|
||||
"""
|
||||
WeCom (Enterprise WeChat) channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- Bot ID and Secret from WeCom AI Bot platform
|
||||
"""
|
||||
|
||||
name = "wecom"
|
||||
display_name = "WeCom"
|
||||
|
||||
def __init__(self, config: WecomConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WecomConfig = config
|
||||
self._client: Any = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._generate_req_id = None
|
||||
# Store frame headers for each chat to enable replies
|
||||
self._chat_frames: dict[str, Any] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WeCom bot with WebSocket long connection."""
|
||||
if not WECOM_AVAILABLE:
|
||||
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
|
||||
return
|
||||
|
||||
if not self.config.bot_id or not self.config.secret:
|
||||
logger.error("WeCom bot_id and secret not configured")
|
||||
return
|
||||
|
||||
from wecom_aibot_sdk import WSClient, generate_req_id
|
||||
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._generate_req_id = generate_req_id
|
||||
|
||||
# Create WebSocket client
|
||||
self._client = WSClient({
|
||||
"bot_id": self.config.bot_id,
|
||||
"secret": self.config.secret,
|
||||
"reconnect_interval": 1000,
|
||||
"max_reconnect_attempts": -1, # Infinite reconnect
|
||||
"heartbeat_interval": 30000,
|
||||
})
|
||||
|
||||
# Register event handlers
|
||||
self._client.on("connected", self._on_connected)
|
||||
self._client.on("authenticated", self._on_authenticated)
|
||||
self._client.on("disconnected", self._on_disconnected)
|
||||
self._client.on("error", self._on_error)
|
||||
self._client.on("message.text", self._on_text_message)
|
||||
self._client.on("message.image", self._on_image_message)
|
||||
self._client.on("message.voice", self._on_voice_message)
|
||||
self._client.on("message.file", self._on_file_message)
|
||||
self._client.on("message.mixed", self._on_mixed_message)
|
||||
self._client.on("event.enter_chat", self._on_enter_chat)
|
||||
|
||||
logger.info("WeCom bot starting with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Connect
|
||||
await self._client.connect_async()
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WeCom bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
logger.info("WeCom bot stopped")
|
||||
|
||||
async def _on_connected(self, frame: Any) -> None:
|
||||
"""Handle WebSocket connected event."""
|
||||
logger.info("WeCom WebSocket connected")
|
||||
|
||||
async def _on_authenticated(self, frame: Any) -> None:
|
||||
"""Handle authentication success event."""
|
||||
logger.info("WeCom authenticated successfully")
|
||||
|
||||
async def _on_disconnected(self, frame: Any) -> None:
|
||||
"""Handle WebSocket disconnected event."""
|
||||
reason = frame.body if hasattr(frame, 'body') else str(frame)
|
||||
logger.warning("WeCom WebSocket disconnected: {}", reason)
|
||||
|
||||
async def _on_error(self, frame: Any) -> None:
|
||||
"""Handle error event."""
|
||||
logger.error("WeCom error: {}", frame)
|
||||
|
||||
async def _on_text_message(self, frame: Any) -> None:
|
||||
"""Handle text message."""
|
||||
await self._process_message(frame, "text")
|
||||
|
||||
async def _on_image_message(self, frame: Any) -> None:
|
||||
"""Handle image message."""
|
||||
await self._process_message(frame, "image")
|
||||
|
||||
async def _on_voice_message(self, frame: Any) -> None:
|
||||
"""Handle voice message."""
|
||||
await self._process_message(frame, "voice")
|
||||
|
||||
async def _on_file_message(self, frame: Any) -> None:
|
||||
"""Handle file message."""
|
||||
await self._process_message(frame, "file")
|
||||
|
||||
async def _on_mixed_message(self, frame: Any) -> None:
|
||||
"""Handle mixed content message."""
|
||||
await self._process_message(frame, "mixed")
|
||||
|
||||
async def _on_enter_chat(self, frame: Any) -> None:
|
||||
"""Handle enter_chat event (user opens chat with bot)."""
|
||||
try:
|
||||
# Extract body from WsFrame dataclass or dict
|
||||
if hasattr(frame, 'body'):
|
||||
body = frame.body or {}
|
||||
elif isinstance(frame, dict):
|
||||
body = frame.get("body", frame)
|
||||
else:
|
||||
body = {}
|
||||
|
||||
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
|
||||
|
||||
if chat_id and self.config.welcome_message:
|
||||
await self._client.reply_welcome(frame, {
|
||||
"msgtype": "text",
|
||||
"text": {"content": self.config.welcome_message},
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error("Error handling enter_chat: {}", e)
|
||||
|
||||
async def _process_message(self, frame: Any, msg_type: str) -> None:
|
||||
"""Process incoming message and forward to bus."""
|
||||
try:
|
||||
# Extract body from WsFrame dataclass or dict
|
||||
if hasattr(frame, 'body'):
|
||||
body = frame.body or {}
|
||||
elif isinstance(frame, dict):
|
||||
body = frame.get("body", frame)
|
||||
else:
|
||||
body = {}
|
||||
|
||||
# Ensure body is a dict
|
||||
if not isinstance(body, dict):
|
||||
logger.warning("Invalid body type: {}", type(body))
|
||||
return
|
||||
|
||||
# Extract message info
|
||||
msg_id = body.get("msgid", "")
|
||||
if not msg_id:
|
||||
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
|
||||
|
||||
# Deduplication check
|
||||
if msg_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[msg_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Extract sender info from "from" field (SDK format)
|
||||
from_info = body.get("from", {})
|
||||
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
|
||||
|
||||
# For single chat, chatid is the sender's userid
|
||||
# For group chat, chatid is provided in body
|
||||
chat_type = body.get("chattype", "single")
|
||||
chat_id = body.get("chatid", sender_id)
|
||||
|
||||
content_parts = []
|
||||
|
||||
if msg_type == "text":
|
||||
text = body.get("text", {}).get("content", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "image":
|
||||
image_info = body.get("image", {})
|
||||
file_url = image_info.get("url", "")
|
||||
aes_key = image_info.get("aeskey", "")
|
||||
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||
if file_path:
|
||||
filename = os.path.basename(file_path)
|
||||
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
|
||||
elif msg_type == "voice":
|
||||
voice_info = body.get("voice", {})
|
||||
# Voice message already contains transcribed content from WeCom
|
||||
voice_content = voice_info.get("content", "")
|
||||
if voice_content:
|
||||
content_parts.append(f"[voice] {voice_content}")
|
||||
else:
|
||||
content_parts.append("[voice]")
|
||||
|
||||
elif msg_type == "file":
|
||||
file_info = body.get("file", {})
|
||||
file_url = file_info.get("url", "")
|
||||
aes_key = file_info.get("aeskey", "")
|
||||
file_name = file_info.get("name", "unknown")
|
||||
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||
if file_path:
|
||||
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
|
||||
elif msg_type == "mixed":
|
||||
# Mixed content contains multiple message items
|
||||
msg_items = body.get("mixed", {}).get("item", [])
|
||||
for item in msg_items:
|
||||
item_type = item.get("type", "")
|
||||
if item_type == "text":
|
||||
text = item.get("text", {}).get("content", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Store frame for this chat to enable replies
|
||||
self._chat_frames[chat_id] = frame
|
||||
|
||||
# Forward to message bus
|
||||
# Note: media paths are included in content for broader model compatibility
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=None,
|
||||
metadata={
|
||||
"message_id": msg_id,
|
||||
"msg_type": msg_type,
|
||||
"chat_type": chat_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing WeCom message: {}", e)
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
file_url: str,
|
||||
aes_key: str,
|
||||
media_type: str,
|
||||
filename: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Download and decrypt media from WeCom.
|
||||
|
||||
Returns:
|
||||
file_path or None if download failed
|
||||
"""
|
||||
try:
|
||||
data, fname = await self._client.download_file(file_url, aes_key)
|
||||
|
||||
if not data:
|
||||
logger.warning("Failed to download media from WeCom")
|
||||
return None
|
||||
|
||||
media_dir = get_media_dir("wecom")
|
||||
if not filename:
|
||||
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
return str(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error downloading media: {}", e)
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WeCom."""
|
||||
if not self._client:
|
||||
logger.warning("WeCom client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
content = msg.content.strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Get the stored frame for this chat
|
||||
frame = self._chat_frames.get(msg.chat_id)
|
||||
if not frame:
|
||||
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
|
||||
return
|
||||
|
||||
# Use streaming reply for better UX
|
||||
stream_id = self._generate_req_id("stream")
|
||||
|
||||
# Send as streaming message with finish=True
|
||||
await self._client.reply_stream(
|
||||
frame,
|
||||
stream_id,
|
||||
content,
|
||||
finish=True,
|
||||
)
|
||||
|
||||
logger.debug("WeCom message sent to {}", msg.chat_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending WeCom message: {}", e)
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
from loguru import logger
|
||||
@@ -21,6 +22,7 @@ class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
|
||||
name = "whatsapp"
|
||||
display_name = "WhatsApp"
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
@@ -128,10 +130,22 @@ class WhatsAppChannel(BaseChannel):
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||
media_paths = data.get("media") or []
|
||||
|
||||
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
|
||||
if media_paths:
|
||||
for p in media_paths:
|
||||
mime, _ = mimetypes.guess_type(p)
|
||||
media_type = "image" if mime and mime.startswith("image/") else "file"
|
||||
media_tag = f"[{media_type}: {p}]"
|
||||
content = f"{content}\n{media_tag}" if content else media_tag
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
|
||||
@@ -9,7 +9,6 @@ from pathlib import Path
|
||||
|
||||
# Force UTF-8 encoding for Windows console
|
||||
if sys.platform == "win32":
|
||||
import locale
|
||||
if sys.stdout.encoding != "utf-8":
|
||||
os.environ["PYTHONIOENCODING"] = "utf-8"
|
||||
# Re-open stdout/stderr with UTF-8 encoding
|
||||
@@ -30,6 +29,7 @@ from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from nanobot import __logo__, __version__
|
||||
from nanobot.config.paths import get_workspace_path
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
|
||||
@@ -99,7 +99,9 @@ def _init_prompt_session() -> None:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
history_file = Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
from nanobot.config.paths import get_cli_history_path
|
||||
|
||||
history_file = get_cli_history_path()
|
||||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_PROMPT_SESSION = PromptSession(
|
||||
@@ -170,7 +172,6 @@ def onboard():
|
||||
"""Initialize nanobot configuration and workspace."""
|
||||
from nanobot.config.loader import get_config_path, load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import get_workspace_path
|
||||
|
||||
config_path = get_config_path()
|
||||
|
||||
@@ -190,6 +191,8 @@ def onboard():
|
||||
save_config(Config())
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
|
||||
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
|
||||
|
||||
# Create workspace
|
||||
workspace = get_workspace_path()
|
||||
|
||||
@@ -212,9 +215,9 @@ def onboard():
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config."""
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
@@ -222,30 +225,79 @@ def _make_provider(config: Config):
|
||||
|
||||
# OpenAI Codex (OAuth)
|
||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||
return OpenAICodexProvider(default_model=model)
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||
if provider_name == "custom":
|
||||
return CustomProvider(
|
||||
elif provider_name == "custom":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
default_model=model,
|
||||
)
|
||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||
elif provider_name == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
provider = LiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
from nanobot.providers.registry import find_by_name
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
"""Load config and optionally override the active workspace."""
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config_path = None
|
||||
if config:
|
||||
config_path = Path(config).expanduser().resolve()
|
||||
if not config_path.exists():
|
||||
console.print(f"[red]Error: Config file not found: {config_path}[/red]")
|
||||
raise typer.Exit(1)
|
||||
set_config_path(config_path)
|
||||
console.print(f"[dim]Using config: {config_path}[/dim]")
|
||||
|
||||
loaded = load_config(config_path)
|
||||
if workspace:
|
||||
loaded.agents.defaults.workspace = workspace
|
||||
return loaded
|
||||
|
||||
|
||||
def _print_deprecated_memory_window_notice(config: Config) -> None:
|
||||
"""Warn when running with old memoryWindow-only config."""
|
||||
if config.agents.defaults.should_warn_deprecated_memory_window:
|
||||
console.print(
|
||||
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
|
||||
"`contextWindowTokens`. `memoryWindow` is ignored; run "
|
||||
"[cyan]nanobot onboard[/cyan] to refresh your config template."
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -255,16 +307,16 @@ def _make_provider(config: Config):
|
||||
|
||||
@app.command()
|
||||
def gateway(
|
||||
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
|
||||
port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the nanobot gateway."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
@@ -274,10 +326,9 @@ def gateway(
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
config_path = Path(config) if config else None
|
||||
config = load_config(config_path)
|
||||
if workspace:
|
||||
config.agents.defaults.workspace = workspace
|
||||
config = _load_runtime_config(config, workspace)
|
||||
_print_deprecated_memory_window_notice(config)
|
||||
port = port if port is not None else config.gateway.port
|
||||
|
||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
@@ -286,8 +337,7 @@ def gateway(
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
|
||||
# Create cron service first (callback set after agent creation)
|
||||
# Use workspace path for per-instance cron store
|
||||
cron_store_path = config.workspace_path / "cron" / "jobs.json"
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
# Create agent with cron service
|
||||
@@ -296,11 +346,8 @@ def gateway(
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
@@ -448,6 +495,8 @@ def gateway(
|
||||
def agent(
|
||||
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
|
||||
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
|
||||
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
|
||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||
):
|
||||
@@ -456,17 +505,18 @@ def agent(
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.loader import get_data_dir, load_config
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
config = load_config()
|
||||
config = _load_runtime_config(config, workspace)
|
||||
_print_deprecated_memory_window_notice(config)
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
|
||||
# Create cron service for tool usage (no callback needed for CLI unless running)
|
||||
cron_store_path = get_data_dir() / "cron" / "jobs.json"
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
if logs:
|
||||
@@ -479,11 +529,8 @@ def agent(
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
@@ -636,6 +683,7 @@ app.add_typer(channels_app, name="channels")
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
"""Show channel status."""
|
||||
from nanobot.channels.registry import discover_channel_names, load_channel_class
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
config = load_config()
|
||||
@@ -643,85 +691,19 @@ def channels_status():
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
table.add_column("Enabled", style="green")
|
||||
table.add_column("Configuration", style="yellow")
|
||||
|
||||
# WhatsApp
|
||||
wa = config.channels.whatsapp
|
||||
table.add_row(
|
||||
"WhatsApp",
|
||||
"✓" if wa.enabled else "✗",
|
||||
wa.bridge_url
|
||||
)
|
||||
|
||||
dc = config.channels.discord
|
||||
table.add_row(
|
||||
"Discord",
|
||||
"✓" if dc.enabled else "✗",
|
||||
dc.gateway_url
|
||||
)
|
||||
|
||||
# Feishu
|
||||
fs = config.channels.feishu
|
||||
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Feishu",
|
||||
"✓" if fs.enabled else "✗",
|
||||
fs_config
|
||||
)
|
||||
|
||||
# Mochat
|
||||
mc = config.channels.mochat
|
||||
mc_base = mc.base_url or "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Mochat",
|
||||
"✓" if mc.enabled else "✗",
|
||||
mc_base
|
||||
)
|
||||
|
||||
# Telegram
|
||||
tg = config.channels.telegram
|
||||
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Telegram",
|
||||
"✓" if tg.enabled else "✗",
|
||||
tg_config
|
||||
)
|
||||
|
||||
# Slack
|
||||
slack = config.channels.slack
|
||||
slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Slack",
|
||||
"✓" if slack.enabled else "✗",
|
||||
slack_config
|
||||
)
|
||||
|
||||
# DingTalk
|
||||
dt = config.channels.dingtalk
|
||||
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"DingTalk",
|
||||
"✓" if dt.enabled else "✗",
|
||||
dt_config
|
||||
)
|
||||
|
||||
# QQ
|
||||
qq = config.channels.qq
|
||||
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"QQ",
|
||||
"✓" if qq.enabled else "✗",
|
||||
qq_config
|
||||
)
|
||||
|
||||
# Email
|
||||
em = config.channels.email
|
||||
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Email",
|
||||
"✓" if em.enabled else "✗",
|
||||
em_config
|
||||
)
|
||||
for modname in sorted(discover_channel_names()):
|
||||
section = getattr(config.channels, modname, None)
|
||||
enabled = section and getattr(section, "enabled", False)
|
||||
try:
|
||||
cls = load_channel_class(modname)
|
||||
display = cls.display_name
|
||||
except ImportError:
|
||||
display = modname.title()
|
||||
table.add_row(
|
||||
display,
|
||||
"[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
@@ -732,7 +714,9 @@ def _get_bridge_dir() -> Path:
|
||||
import subprocess
|
||||
|
||||
# User's bridge location
|
||||
user_bridge = Path.home() / ".nanobot" / "bridge"
|
||||
from nanobot.config.paths import get_bridge_install_dir
|
||||
|
||||
user_bridge = get_bridge_install_dir()
|
||||
|
||||
# Check if already built
|
||||
if (user_bridge / "dist" / "index.js").exists():
|
||||
@@ -790,6 +774,7 @@ def channels_login():
|
||||
import subprocess
|
||||
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
|
||||
config = load_config()
|
||||
bridge_dir = _get_bridge_dir()
|
||||
@@ -800,6 +785,7 @@ def channels_login():
|
||||
env = {**os.environ}
|
||||
if config.channels.whatsapp.bridge_token:
|
||||
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
||||
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
||||
|
||||
try:
|
||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
||||
|
||||
@@ -1,6 +1,30 @@
|
||||
"""Configuration module for nanobot."""
|
||||
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
__all__ = ["Config", "load_config", "get_config_path"]
|
||||
__all__ = [
|
||||
"Config",
|
||||
"load_config",
|
||||
"get_config_path",
|
||||
"get_data_dir",
|
||||
"get_runtime_subdir",
|
||||
"get_media_dir",
|
||||
"get_cron_dir",
|
||||
"get_logs_dir",
|
||||
"get_workspace_path",
|
||||
"get_cli_history_path",
|
||||
"get_bridge_install_dir",
|
||||
"get_legacy_sessions_dir",
|
||||
]
|
||||
|
||||
@@ -6,17 +6,23 @@ from pathlib import Path
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
# Global variable to store current config path (for multi-instance support)
|
||||
_current_config_path: Path | None = None
|
||||
|
||||
|
||||
def set_config_path(path: Path) -> None:
|
||||
"""Set the current config path (used to derive data directory)."""
|
||||
global _current_config_path
|
||||
_current_config_path = path
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Get the default configuration file path."""
|
||||
"""Get the configuration file path."""
|
||||
if _current_config_path:
|
||||
return _current_config_path
|
||||
return Path.home() / ".nanobot" / "config.json"
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Get the nanobot data directory."""
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
return get_data_path()
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> Config:
|
||||
"""
|
||||
Load configuration from file or create default.
|
||||
|
||||
55
nanobot/config/paths.py
Normal file
55
nanobot/config/paths.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Runtime path helpers derived from the active config context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.loader import get_config_path
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Return the instance-level runtime data directory."""
|
||||
return ensure_dir(get_config_path().parent)
|
||||
|
||||
|
||||
def get_runtime_subdir(name: str) -> Path:
|
||||
"""Return a named runtime subdirectory under the instance data dir."""
|
||||
return ensure_dir(get_data_dir() / name)
|
||||
|
||||
|
||||
def get_media_dir(channel: str | None = None) -> Path:
|
||||
"""Return the media directory, optionally namespaced per channel."""
|
||||
base = get_runtime_subdir("media")
|
||||
return ensure_dir(base / channel) if channel else base
|
||||
|
||||
|
||||
def get_cron_dir() -> Path:
|
||||
"""Return the cron storage directory."""
|
||||
return get_runtime_subdir("cron")
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""Return the logs directory."""
|
||||
return get_runtime_subdir("logs")
|
||||
|
||||
|
||||
def get_workspace_path(workspace: str | None = None) -> Path:
|
||||
"""Resolve and ensure the agent workspace path."""
|
||||
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
|
||||
return ensure_dir(path)
|
||||
|
||||
|
||||
def get_cli_history_path() -> Path:
|
||||
"""Return the shared CLI history file path."""
|
||||
return Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
|
||||
|
||||
def get_bridge_install_dir() -> Path:
|
||||
"""Return the shared WhatsApp bridge installation directory."""
|
||||
return Path.home() / ".nanobot" / "bridge"
|
||||
|
||||
|
||||
def get_legacy_sessions_dir() -> Path:
|
||||
"""Return the legacy global session directory used for migration fallback."""
|
||||
return Path.home() / ".nanobot" / "sessions"
|
||||
@@ -33,6 +33,7 @@ class TelegramConfig(Base):
|
||||
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||
)
|
||||
reply_to_message: bool = False # If true, bot replies quote the original message
|
||||
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned or replied to, "open" responds to all
|
||||
|
||||
|
||||
class FeishuConfig(Base):
|
||||
@@ -47,6 +48,7 @@ class FeishuConfig(Base):
|
||||
react_emoji: str = (
|
||||
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
||||
)
|
||||
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned, "open" responds to all
|
||||
|
||||
|
||||
class DingTalkConfig(Base):
|
||||
@@ -199,21 +201,14 @@ class QQConfig(Base):
|
||||
) # Allowed user openids (empty = public access)
|
||||
|
||||
|
||||
class MatrixConfig(Base):
|
||||
"""Matrix (Element) channel configuration."""
|
||||
class WecomConfig(Base):
|
||||
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
homeserver: str = "https://matrix.org"
|
||||
access_token: str = ""
|
||||
user_id: str = "" # e.g. @bot:matrix.org
|
||||
device_id: str = ""
|
||||
e2ee_enabled: bool = True # end-to-end encryption support
|
||||
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
|
||||
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
bot_id: str = "" # Bot ID from WeCom AI Bot platform
|
||||
secret: str = "" # Bot Secret from WeCom AI Bot platform
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
|
||||
welcome_message: str = "" # Welcome message for enter_chat event
|
||||
|
||||
|
||||
class ChannelsConfig(Base):
|
||||
@@ -231,6 +226,7 @@ class ChannelsConfig(Base):
|
||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||
qq: QQConfig = Field(default_factory=QQConfig)
|
||||
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
||||
wecom: WecomConfig = Field(default_factory=WecomConfig)
|
||||
|
||||
|
||||
class AgentDefaults(Base):
|
||||
@@ -242,11 +238,18 @@ class AgentDefaults(Base):
|
||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
context_window_tokens: int = 65_536
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
memory_window: int = 100
|
||||
# Deprecated compatibility field: accepted from old configs but ignored at runtime.
|
||||
memory_window: int | None = Field(default=None, exclude=True)
|
||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||
|
||||
@property
|
||||
def should_warn_deprecated_memory_window(self) -> bool:
|
||||
"""Return True when old memoryWindow is present without contextWindowTokens."""
|
||||
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
|
||||
|
||||
|
||||
class AgentsConfig(Base):
|
||||
"""Agent configuration."""
|
||||
@@ -266,6 +269,7 @@ class ProvidersConfig(Base):
|
||||
"""Configuration for LLM providers."""
|
||||
|
||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
|
||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
@@ -283,10 +287,16 @@ class ProvidersConfig(Base):
|
||||
) # SiliconFlow (硅基流动) API gateway
|
||||
volcengine: ProviderConfig = Field(
|
||||
default_factory=ProviderConfig
|
||||
) # VolcEngine (火山引擎) pay-per-use
|
||||
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
||||
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (火山引擎海外版) pay-per-use
|
||||
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
|
||||
) # VolcEngine (火山引擎) API gateway
|
||||
volcengine_coding_plan: ProviderConfig = Field(
|
||||
default_factory=ProviderConfig
|
||||
) # VolcEngine Coding Plan (火山引擎 Coding Plan)
|
||||
byteplus: ProviderConfig = Field(
|
||||
default_factory=ProviderConfig
|
||||
) # BytePlus (VolcEngine international)
|
||||
byteplus_coding_plan: ProviderConfig = Field(
|
||||
default_factory=ProviderConfig
|
||||
) # BytePlus Coding Plan
|
||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||
|
||||
@@ -388,16 +398,25 @@ class Config(BaseSettings):
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and model_prefix and normalized_prefix == spec.name:
|
||||
if spec.is_oauth or p.api_key:
|
||||
if spec.is_oauth or spec.is_local or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# Match by keyword (order follows PROVIDERS registry)
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and any(_kw_matches(kw) for kw in spec.keywords):
|
||||
if spec.is_oauth or p.api_key:
|
||||
if spec.is_oauth or spec.is_local or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# Fallback: configured local providers can route models without
|
||||
# provider-specific keywords (for example plain "llama3.2" on Ollama).
|
||||
for spec in PROVIDERS:
|
||||
if not spec.is_local:
|
||||
continue
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and p.api_base:
|
||||
return p, spec.name
|
||||
|
||||
# Fallback: gateways first, then others (follows registry order)
|
||||
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
||||
for spec in PROVIDERS:
|
||||
@@ -424,7 +443,7 @@ class Config(BaseSettings):
|
||||
return p.api_key if p else None
|
||||
|
||||
def get_api_base(self, model: str | None = None) -> str | None:
|
||||
"""Get API base URL for the given model. Applies default URLs for known gateways."""
|
||||
"""Get API base URL for the given model. Applies default URLs for gateway/local providers."""
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
p, name = self._match_provider(model)
|
||||
@@ -435,7 +454,7 @@ class Config(BaseSettings):
|
||||
# to avoid polluting the global litellm.api_base.
|
||||
if name:
|
||||
spec = find_by_name(name)
|
||||
if spec and spec.is_gateway and spec.default_api_base:
|
||||
if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
|
||||
return spec.default_api_base
|
||||
return None
|
||||
|
||||
|
||||
@@ -87,7 +87,7 @@ class HeartbeatService:
|
||||
|
||||
Returns (action, tasks) where action is 'skip' or 'run'.
|
||||
"""
|
||||
response = await self.provider.chat(
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||
{"role": "user", "content": (
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
||||
|
||||
213
nanobot/providers/azure_openai_provider.py
Normal file
213
nanobot/providers/azure_openai_provider.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
api_base: str = "",
|
||||
default_model: str = "gpt-5.2-chat",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
deployment_name: str,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> bool:
|
||||
"""Return True when temperature is likely supported for this deployment."""
|
||||
if reasoning_effort:
|
||||
return False
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return payload
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
@@ -1,9 +1,13 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
@@ -11,6 +15,24 @@ class ToolCallRequest:
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
provider_specific_fields: dict[str, Any] | None = None
|
||||
function_provider_specific_fields: dict[str, Any] | None = None
|
||||
|
||||
def to_openai_tool_call(self) -> dict[str, Any]:
|
||||
"""Serialize to an OpenAI-style tool_call payload."""
|
||||
tool_call = {
|
||||
"id": self.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
if self.provider_specific_fields:
|
||||
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
||||
if self.function_provider_specific_fields:
|
||||
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
|
||||
return tool_call
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,6 +51,21 @@ class LLMResponse:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls.
|
||||
|
||||
Stored on the provider so every call site inherits the same defaults
|
||||
without having to pass temperature / max_tokens / reasoning_effort
|
||||
through every layer. Individual call sites can still override by
|
||||
passing explicit keyword arguments to chat() / chat_with_retry().
|
||||
"""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
reasoning_effort: str | None = None
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
@@ -37,9 +74,28 @@ class LLMProvider(ABC):
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"overloaded",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"connection",
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
)
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.generation: GenerationSettings = GenerationSettings()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
@@ -87,6 +143,20 @@ class LLMProvider(ABC):
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_request_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
allowed_keys: frozenset[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keep only provider-safe message keys and normalize assistant content."""
|
||||
sanitized = []
|
||||
for msg in messages:
|
||||
clean = {k: v for k, v in msg.items() if k in allowed_keys}
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
@@ -96,6 +166,7 @@ class LLMProvider(ABC):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
@@ -106,12 +177,93 @@ class LLMProvider(ABC):
|
||||
model: Model identifier (provider-specific).
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _is_transient_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: object = _SENTINEL,
|
||||
temperature: object = _SENTINEL,
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures.
|
||||
|
||||
Parameters default to ``self.generation`` when not explicitly passed,
|
||||
so callers no longer need to thread temperature / max_tokens /
|
||||
reasoning_effort through every layer.
|
||||
"""
|
||||
if max_tokens is self._SENTINEL:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
temperature = self.generation.temperature
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
if not self._is_transient_error(response.content):
|
||||
return response
|
||||
|
||||
err = (response.content or "").lower()
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
err[:120],
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
|
||||
@@ -25,7 +25,8 @@ class CustomProvider(LLMProvider):
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None) -> LLMResponse:
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self._sanitize_empty_content(messages),
|
||||
@@ -35,7 +36,7 @@ class CustomProvider(LLMProvider):
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice="auto")
|
||||
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
@@ -8,6 +9,7 @@ from typing import Any
|
||||
import json_repair
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.registry import find_by_model, find_gateway
|
||||
@@ -165,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
|
||||
return _ANTHROPIC_EXTRA_KEYS
|
||||
return frozenset()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
|
||||
if not isinstance(tool_call_id, str):
|
||||
return tool_call_id
|
||||
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||
return tool_call_id
|
||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||
sanitized = []
|
||||
for msg in messages:
|
||||
clean = {k: v for k, v in msg.items() if k in allowed}
|
||||
# Strict providers require "content" even when assistant only has tool_calls
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||
id_map: dict[str, str] = {}
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||
|
||||
for clean in sanitized:
|
||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||
# shortening, otherwise strict providers reject the broken linkage.
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized_tool_calls = []
|
||||
for tc in clean["tool_calls"]:
|
||||
if not isinstance(tc, dict):
|
||||
normalized_tool_calls.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized_tool_calls.append(tc_clean)
|
||||
clean["tool_calls"] = normalized_tool_calls
|
||||
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
async def chat(
|
||||
@@ -186,6 +214,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
@@ -239,7 +268,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
@@ -255,20 +284,44 @@ class LiteLLMProvider(LLMProvider):
|
||||
"""Parse LiteLLM response into our standard format."""
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
content = message.content
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
# Some providers (e.g. GitHub Copilot) split content and tool_calls
|
||||
# across multiple choices. Merge them so tool_calls are not lost.
|
||||
raw_tool_calls = []
|
||||
for ch in response.choices:
|
||||
msg = ch.message
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
raw_tool_calls.extend(msg.tool_calls)
|
||||
if ch.finish_reason in ("tool_calls", "stop"):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and msg.content:
|
||||
content = msg.content
|
||||
|
||||
if len(response.choices) > 1:
|
||||
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
|
||||
len(response.choices), len(raw_tool_calls))
|
||||
|
||||
tool_calls = []
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tc in message.tool_calls:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
for tc in raw_tool_calls:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
|
||||
function_provider_specific_fields = (
|
||||
getattr(tc.function, "provider_specific_fields", None) or None
|
||||
)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
function_provider_specific_fields=function_provider_specific_fields,
|
||||
))
|
||||
|
||||
usage = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
@@ -280,11 +333,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
||||
|
||||
|
||||
return LLMResponse(
|
||||
content=message.content,
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
finish_reason=finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
|
||||
@@ -32,6 +32,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
@@ -48,7 +49,7 @@ class OpenAICodexProvider(LLMProvider):
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"prompt_cache_key": _prompt_cache_key(messages),
|
||||
"tool_choice": "auto",
|
||||
"tool_choice": tool_choice or "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
|
||||
|
||||
@@ -26,33 +26,33 @@ class ProviderSpec:
|
||||
"""
|
||||
|
||||
# identity
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# model prefixing
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
|
||||
# gateway / local detection
|
||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
|
||||
# gateway behavior
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
|
||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
|
||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||
|
||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||
is_direct: bool = False
|
||||
@@ -70,7 +70,6 @@ class ProviderSpec:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
|
||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
@@ -81,16 +80,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
is_direct=True,
|
||||
),
|
||||
|
||||
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
|
||||
ProviderSpec(
|
||||
name="azure_openai",
|
||||
keywords=("azure", "azure-openai"),
|
||||
env_key="",
|
||||
display_name="Azure OpenAI",
|
||||
litellm_prefix="",
|
||||
is_direct=True,
|
||||
),
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
|
||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
@@ -102,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
@@ -119,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||
ProviderSpec(
|
||||
name="siliconflow",
|
||||
@@ -213,8 +218,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
@@ -233,7 +238,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
@@ -251,14 +255,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# OpenAI Codex: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex",),
|
||||
env_key="", # OAuth-based, no API key
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="OpenAI Codex",
|
||||
litellm_prefix="", # Not routed through LiteLLM
|
||||
litellm_prefix="", # Not routed through LiteLLM
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
@@ -268,16 +271,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
|
||||
# Github Copilot: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="", # OAuth-based, no API key
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||
skip_prefixes=("github_copilot/",),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
@@ -287,17 +289,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
@@ -307,15 +308,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
@@ -325,7 +325,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
@@ -334,11 +333,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(
|
||||
("ZHIPUAI_API_KEY", "{api_key}"),
|
||||
),
|
||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
@@ -347,14 +344,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
@@ -365,7 +361,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
@@ -374,22 +369,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(
|
||||
("MOONSHOT_API_BASE", "{api_base}"),
|
||||
),
|
||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(
|
||||
("kimi-k2.5", {"temperature": 1.0}),
|
||||
),
|
||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||
),
|
||||
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
@@ -397,7 +387,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
@@ -408,9 +398,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# Detected when config key is "vllm" (provider_name="vllm").
|
||||
ProviderSpec(
|
||||
@@ -418,20 +406,35 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="", # user must provide in config
|
||||
default_api_base="", # user must provide in config
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Ollama (local, OpenAI-compatible) ===================================
|
||||
ProviderSpec(
|
||||
name="ollama",
|
||||
keywords=("ollama", "nemotron"),
|
||||
env_key="OLLAMA_API_KEY",
|
||||
display_name="Ollama",
|
||||
litellm_prefix="ollama_chat", # model → ollama_chat/model
|
||||
skip_prefixes=("ollama/", "ollama_chat/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="11434",
|
||||
default_api_base="http://localhost:11434",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
ProviderSpec(
|
||||
@@ -439,8 +442,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
@@ -457,6 +460,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||
@@ -472,7 +476,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
|
||||
return spec
|
||||
|
||||
for spec in std_specs:
|
||||
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
|
||||
if any(
|
||||
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
|
||||
):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
|
||||
|
||||
@@ -79,7 +80,7 @@ class SessionManager:
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
||||
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
|
||||
self.legacy_sessions_dir = get_legacy_sessions_dir()
|
||||
self._cache: dict[str, Session] = {}
|
||||
|
||||
def _get_session_path(self, key: str) -> Path:
|
||||
|
||||
@@ -9,15 +9,21 @@ always: true
|
||||
## Structure
|
||||
|
||||
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
|
||||
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep-style tools or in-memory filters. Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
|
||||
## Search Past Events
|
||||
|
||||
```bash
|
||||
grep -i "keyword" memory/HISTORY.md
|
||||
```
|
||||
Choose the search method based on file size:
|
||||
|
||||
Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md`
|
||||
- Small `memory/HISTORY.md`: use `read_file`, then search in-memory
|
||||
- Large or long-lived `memory/HISTORY.md`: use the `exec` tool for targeted search
|
||||
|
||||
Examples:
|
||||
- **Linux/macOS:** `grep -i "keyword" memory/HISTORY.md`
|
||||
- **Windows:** `findstr /i "keyword" memory\HISTORY.md`
|
||||
- **Cross-platform Python:** `python -c "from pathlib import Path; text = Path('memory/HISTORY.md').read_text(encoding='utf-8'); print('\n'.join([l for l in text.splitlines() if 'keyword' in l.lower()][-20:]))"`
|
||||
|
||||
Prefer targeted command-line search for large history files.
|
||||
|
||||
## When to Update MEMORY.md
|
||||
|
||||
|
||||
@@ -268,6 +268,8 @@ Skip this step only if the skill being developed already exists, and iteration o
|
||||
|
||||
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
|
||||
|
||||
For `nanobot`, custom skills should live under the active workspace `skills/` directory so they can be discovered automatically at runtime (for example, `<workspace>/skills/my-skill/SKILL.md`).
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
@@ -277,9 +279,9 @@ scripts/init_skill.py <skill-name> --path <output-directory> [--resources script
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
scripts/init_skill.py my-skill --path skills/public
|
||||
scripts/init_skill.py my-skill --path skills/public --resources scripts,references
|
||||
scripts/init_skill.py my-skill --path skills/public --resources scripts --examples
|
||||
scripts/init_skill.py my-skill --path ./workspace/skills
|
||||
scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts,references
|
||||
scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts --examples
|
||||
```
|
||||
|
||||
The script:
|
||||
@@ -326,7 +328,7 @@ Write the YAML frontmatter with `name` and `description`:
|
||||
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
|
||||
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
|
||||
|
||||
Do not include any other fields in YAML frontmatter.
|
||||
Keep frontmatter minimal. In `nanobot`, `metadata` and `always` are also supported when needed, but avoid adding extra fields unless they are actually required.
|
||||
|
||||
##### Body
|
||||
|
||||
@@ -349,7 +351,6 @@ scripts/package_skill.py <path/to/skill-folder> ./dist
|
||||
The packaging script will:
|
||||
|
||||
1. **Validate** the skill automatically, checking:
|
||||
|
||||
- YAML frontmatter format and required fields
|
||||
- Skill naming conventions and directory structure
|
||||
- Description completeness and quality
|
||||
@@ -357,6 +358,8 @@ The packaging script will:
|
||||
|
||||
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
|
||||
|
||||
Security restriction: symlinks are rejected and packaging fails when any symlink is present.
|
||||
|
||||
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
|
||||
|
||||
### Step 6: Iterate
|
||||
|
||||
378
nanobot/skills/skill-creator/scripts/init_skill.py
Executable file
378
nanobot/skills/skill-creator/scripts/init_skill.py
Executable file
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Skill Initializer - Creates a new skill from template
|
||||
|
||||
Usage:
|
||||
init_skill.py <skill-name> --path <path> [--resources scripts,references,assets] [--examples]
|
||||
|
||||
Examples:
|
||||
init_skill.py my-new-skill --path skills/public
|
||||
init_skill.py my-new-skill --path skills/public --resources scripts,references
|
||||
init_skill.py my-api-helper --path skills/private --resources scripts --examples
|
||||
init_skill.py custom-skill --path /custom/location
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
MAX_SKILL_NAME_LENGTH = 64
|
||||
ALLOWED_RESOURCES = {"scripts", "references", "assets"}
|
||||
|
||||
SKILL_TEMPLATE = """---
|
||||
name: {skill_name}
|
||||
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
|
||||
---
|
||||
|
||||
# {skill_title}
|
||||
|
||||
## Overview
|
||||
|
||||
[TODO: 1-2 sentences explaining what this skill enables]
|
||||
|
||||
## Structuring This Skill
|
||||
|
||||
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
|
||||
|
||||
**1. Workflow-Based** (best for sequential processes)
|
||||
- Works well when there are clear step-by-step procedures
|
||||
- Example: DOCX skill with "Workflow Decision Tree" -> "Reading" -> "Creating" -> "Editing"
|
||||
- Structure: ## Overview -> ## Workflow Decision Tree -> ## Step 1 -> ## Step 2...
|
||||
|
||||
**2. Task-Based** (best for tool collections)
|
||||
- Works well when the skill offers different operations/capabilities
|
||||
- Example: PDF skill with "Quick Start" -> "Merge PDFs" -> "Split PDFs" -> "Extract Text"
|
||||
- Structure: ## Overview -> ## Quick Start -> ## Task Category 1 -> ## Task Category 2...
|
||||
|
||||
**3. Reference/Guidelines** (best for standards or specifications)
|
||||
- Works well for brand guidelines, coding standards, or requirements
|
||||
- Example: Brand styling with "Brand Guidelines" -> "Colors" -> "Typography" -> "Features"
|
||||
- Structure: ## Overview -> ## Guidelines -> ## Specifications -> ## Usage...
|
||||
|
||||
**4. Capabilities-Based** (best for integrated systems)
|
||||
- Works well when the skill provides multiple interrelated features
|
||||
- Example: Product Management with "Core Capabilities" -> numbered capability list
|
||||
- Structure: ## Overview -> ## Core Capabilities -> ### 1. Feature -> ### 2. Feature...
|
||||
|
||||
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
|
||||
|
||||
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
|
||||
|
||||
## [TODO: Replace with the first main section based on chosen structure]
|
||||
|
||||
[TODO: Add content here. See examples in existing skills:
|
||||
- Code samples for technical skills
|
||||
- Decision trees for complex workflows
|
||||
- Concrete examples with realistic user requests
|
||||
- References to scripts/templates/references as needed]
|
||||
|
||||
## Resources (optional)
|
||||
|
||||
Create only the resource directories this skill actually needs. Delete this section if no resources are required.
|
||||
|
||||
### scripts/
|
||||
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
|
||||
|
||||
**Examples from other skills:**
|
||||
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
|
||||
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
|
||||
|
||||
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
|
||||
|
||||
**Note:** Scripts may be executed without loading into context, but can still be read by Codex for patching or environment adjustments.
|
||||
|
||||
### references/
|
||||
Documentation and reference material intended to be loaded into context to inform Codex's process and thinking.
|
||||
|
||||
**Examples from other skills:**
|
||||
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
|
||||
- BigQuery: API reference documentation and query examples
|
||||
- Finance: Schema documentation, company policies
|
||||
|
||||
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Codex should reference while working.
|
||||
|
||||
### assets/
|
||||
Files not intended to be loaded into context, but rather used within the output Codex produces.
|
||||
|
||||
**Examples from other skills:**
|
||||
- Brand styling: PowerPoint template files (.pptx), logo files
|
||||
- Frontend builder: HTML/React boilerplate project directories
|
||||
- Typography: Font files (.ttf, .woff2)
|
||||
|
||||
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
|
||||
|
||||
---
|
||||
|
||||
**Not every skill requires all three types of resources.**
|
||||
"""
|
||||
|
||||
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
|
||||
"""
|
||||
Example helper script for {skill_name}
|
||||
|
||||
This is a placeholder script that can be executed directly.
|
||||
Replace with actual implementation or delete if not needed.
|
||||
|
||||
Example real scripts from other skills:
|
||||
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
|
||||
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
|
||||
"""
|
||||
|
||||
def main():
|
||||
print("This is an example script for {skill_name}")
|
||||
# TODO: Add actual script logic here
|
||||
# This could be data processing, file conversion, API calls, etc.
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
'''
|
||||
|
||||
EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
|
||||
|
||||
This is a placeholder for detailed reference documentation.
|
||||
Replace with actual reference content or delete if not needed.
|
||||
|
||||
Example real reference docs from other skills:
|
||||
- product-management/references/communication.md - Comprehensive guide for status updates
|
||||
- product-management/references/context_building.md - Deep-dive on gathering context
|
||||
- bigquery/references/ - API references and query examples
|
||||
|
||||
## When Reference Docs Are Useful
|
||||
|
||||
Reference docs are ideal for:
|
||||
- Comprehensive API documentation
|
||||
- Detailed workflow guides
|
||||
- Complex multi-step processes
|
||||
- Information too lengthy for main SKILL.md
|
||||
- Content that's only needed for specific use cases
|
||||
|
||||
## Structure Suggestions
|
||||
|
||||
### API Reference Example
|
||||
- Overview
|
||||
- Authentication
|
||||
- Endpoints with examples
|
||||
- Error codes
|
||||
- Rate limits
|
||||
|
||||
### Workflow Guide Example
|
||||
- Prerequisites
|
||||
- Step-by-step instructions
|
||||
- Common patterns
|
||||
- Troubleshooting
|
||||
- Best practices
|
||||
"""
|
||||
|
||||
EXAMPLE_ASSET = """# Example Asset File
|
||||
|
||||
This placeholder represents where asset files would be stored.
|
||||
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
|
||||
|
||||
Asset files are NOT intended to be loaded into context, but rather used within
|
||||
the output Codex produces.
|
||||
|
||||
Example asset files from other skills:
|
||||
- Brand guidelines: logo.png, slides_template.pptx
|
||||
- Frontend builder: hello-world/ directory with HTML/React boilerplate
|
||||
- Typography: custom-font.ttf, font-family.woff2
|
||||
- Data: sample_data.csv, test_dataset.json
|
||||
|
||||
## Common Asset Types
|
||||
|
||||
- Templates: .pptx, .docx, boilerplate directories
|
||||
- Images: .png, .jpg, .svg, .gif
|
||||
- Fonts: .ttf, .otf, .woff, .woff2
|
||||
- Boilerplate code: Project directories, starter files
|
||||
- Icons: .ico, .svg
|
||||
- Data files: .csv, .json, .xml, .yaml
|
||||
|
||||
Note: This is a text placeholder. Actual assets can be any file type.
|
||||
"""
|
||||
|
||||
|
||||
def normalize_skill_name(skill_name):
|
||||
"""Normalize a skill name to lowercase hyphen-case."""
|
||||
normalized = skill_name.strip().lower()
|
||||
normalized = re.sub(r"[^a-z0-9]+", "-", normalized)
|
||||
normalized = normalized.strip("-")
|
||||
normalized = re.sub(r"-{2,}", "-", normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def title_case_skill_name(skill_name):
|
||||
"""Convert hyphenated skill name to Title Case for display."""
|
||||
return " ".join(word.capitalize() for word in skill_name.split("-"))
|
||||
|
||||
|
||||
def parse_resources(raw_resources):
|
||||
if not raw_resources:
|
||||
return []
|
||||
resources = [item.strip() for item in raw_resources.split(",") if item.strip()]
|
||||
invalid = sorted({item for item in resources if item not in ALLOWED_RESOURCES})
|
||||
if invalid:
|
||||
allowed = ", ".join(sorted(ALLOWED_RESOURCES))
|
||||
print(f"[ERROR] Unknown resource type(s): {', '.join(invalid)}")
|
||||
print(f" Allowed: {allowed}")
|
||||
sys.exit(1)
|
||||
deduped = []
|
||||
seen = set()
|
||||
for resource in resources:
|
||||
if resource not in seen:
|
||||
deduped.append(resource)
|
||||
seen.add(resource)
|
||||
return deduped
|
||||
|
||||
|
||||
def create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples):
|
||||
for resource in resources:
|
||||
resource_dir = skill_dir / resource
|
||||
resource_dir.mkdir(exist_ok=True)
|
||||
if resource == "scripts":
|
||||
if include_examples:
|
||||
example_script = resource_dir / "example.py"
|
||||
example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
|
||||
example_script.chmod(0o755)
|
||||
print("[OK] Created scripts/example.py")
|
||||
else:
|
||||
print("[OK] Created scripts/")
|
||||
elif resource == "references":
|
||||
if include_examples:
|
||||
example_reference = resource_dir / "api_reference.md"
|
||||
example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
|
||||
print("[OK] Created references/api_reference.md")
|
||||
else:
|
||||
print("[OK] Created references/")
|
||||
elif resource == "assets":
|
||||
if include_examples:
|
||||
example_asset = resource_dir / "example_asset.txt"
|
||||
example_asset.write_text(EXAMPLE_ASSET)
|
||||
print("[OK] Created assets/example_asset.txt")
|
||||
else:
|
||||
print("[OK] Created assets/")
|
||||
|
||||
|
||||
def init_skill(skill_name, path, resources, include_examples):
|
||||
"""
|
||||
Initialize a new skill directory with template SKILL.md.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
path: Path where the skill directory should be created
|
||||
resources: Resource directories to create
|
||||
include_examples: Whether to create example files in resource directories
|
||||
|
||||
Returns:
|
||||
Path to created skill directory, or None if error
|
||||
"""
|
||||
# Determine skill directory path
|
||||
skill_dir = Path(path).resolve() / skill_name
|
||||
|
||||
# Check if directory already exists
|
||||
if skill_dir.exists():
|
||||
print(f"[ERROR] Skill directory already exists: {skill_dir}")
|
||||
return None
|
||||
|
||||
# Create skill directory
|
||||
try:
|
||||
skill_dir.mkdir(parents=True, exist_ok=False)
|
||||
print(f"[OK] Created skill directory: {skill_dir}")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Error creating directory: {e}")
|
||||
return None
|
||||
|
||||
# Create SKILL.md from template
|
||||
skill_title = title_case_skill_name(skill_name)
|
||||
skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)
|
||||
|
||||
skill_md_path = skill_dir / "SKILL.md"
|
||||
try:
|
||||
skill_md_path.write_text(skill_content)
|
||||
print("[OK] Created SKILL.md")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Error creating SKILL.md: {e}")
|
||||
return None
|
||||
|
||||
# Create resource directories if requested
|
||||
if resources:
|
||||
try:
|
||||
create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Error creating resource directories: {e}")
|
||||
return None
|
||||
|
||||
# Print next steps
|
||||
print(f"\n[OK] Skill '{skill_name}' initialized successfully at {skill_dir}")
|
||||
print("\nNext steps:")
|
||||
print("1. Edit SKILL.md to complete the TODO items and update the description")
|
||||
if resources:
|
||||
if include_examples:
|
||||
print("2. Customize or delete the example files in scripts/, references/, and assets/")
|
||||
else:
|
||||
print("2. Add resources to scripts/, references/, and assets/ as needed")
|
||||
else:
|
||||
print("2. Create resource directories only if needed (scripts/, references/, assets/)")
|
||||
print("3. Run the validator when ready to check the skill structure")
|
||||
|
||||
return skill_dir
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create a new skill directory with a SKILL.md template.",
|
||||
)
|
||||
parser.add_argument("skill_name", help="Skill name (normalized to hyphen-case)")
|
||||
parser.add_argument("--path", required=True, help="Output directory for the skill")
|
||||
parser.add_argument(
|
||||
"--resources",
|
||||
default="",
|
||||
help="Comma-separated list: scripts,references,assets",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--examples",
|
||||
action="store_true",
|
||||
help="Create example files inside the selected resource directories",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
raw_skill_name = args.skill_name
|
||||
skill_name = normalize_skill_name(raw_skill_name)
|
||||
if not skill_name:
|
||||
print("[ERROR] Skill name must include at least one letter or digit.")
|
||||
sys.exit(1)
|
||||
if len(skill_name) > MAX_SKILL_NAME_LENGTH:
|
||||
print(
|
||||
f"[ERROR] Skill name '{skill_name}' is too long ({len(skill_name)} characters). "
|
||||
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
|
||||
)
|
||||
sys.exit(1)
|
||||
if skill_name != raw_skill_name:
|
||||
print(f"Note: Normalized skill name from '{raw_skill_name}' to '{skill_name}'.")
|
||||
|
||||
resources = parse_resources(args.resources)
|
||||
if args.examples and not resources:
|
||||
print("[ERROR] --examples requires --resources to be set.")
|
||||
sys.exit(1)
|
||||
|
||||
path = args.path
|
||||
|
||||
print(f"Initializing skill: {skill_name}")
|
||||
print(f" Location: {path}")
|
||||
if resources:
|
||||
print(f" Resources: {', '.join(resources)}")
|
||||
if args.examples:
|
||||
print(" Examples: enabled")
|
||||
else:
|
||||
print(" Resources: none (create as needed)")
|
||||
print()
|
||||
|
||||
result = init_skill(skill_name, path, resources, args.examples)
|
||||
|
||||
if result:
|
||||
sys.exit(0)
|
||||
else:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
154
nanobot/skills/skill-creator/scripts/package_skill.py
Executable file
154
nanobot/skills/skill-creator/scripts/package_skill.py
Executable file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Skill Packager - Creates a distributable .skill file of a skill folder
|
||||
|
||||
Usage:
|
||||
python package_skill.py <path/to/skill-folder> [output-directory]
|
||||
|
||||
Example:
|
||||
python package_skill.py skills/public/my-skill
|
||||
python package_skill.py skills/public/my-skill ./dist
|
||||
"""
|
||||
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
from quick_validate import validate_skill
|
||||
|
||||
|
||||
def _is_within(path: Path, root: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(root)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _cleanup_partial_archive(skill_filename: Path) -> None:
|
||||
try:
|
||||
if skill_filename.exists():
|
||||
skill_filename.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def package_skill(skill_path, output_dir=None):
|
||||
"""
|
||||
Package a skill folder into a .skill file.
|
||||
|
||||
Args:
|
||||
skill_path: Path to the skill folder
|
||||
output_dir: Optional output directory for the .skill file (defaults to current directory)
|
||||
|
||||
Returns:
|
||||
Path to the created .skill file, or None if error
|
||||
"""
|
||||
skill_path = Path(skill_path).resolve()
|
||||
|
||||
# Validate skill folder exists
|
||||
if not skill_path.exists():
|
||||
print(f"[ERROR] Skill folder not found: {skill_path}")
|
||||
return None
|
||||
|
||||
if not skill_path.is_dir():
|
||||
print(f"[ERROR] Path is not a directory: {skill_path}")
|
||||
return None
|
||||
|
||||
# Validate SKILL.md exists
|
||||
skill_md = skill_path / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
print(f"[ERROR] SKILL.md not found in {skill_path}")
|
||||
return None
|
||||
|
||||
# Run validation before packaging
|
||||
print("Validating skill...")
|
||||
valid, message = validate_skill(skill_path)
|
||||
if not valid:
|
||||
print(f"[ERROR] Validation failed: {message}")
|
||||
print(" Please fix the validation errors before packaging.")
|
||||
return None
|
||||
print(f"[OK] {message}\n")
|
||||
|
||||
# Determine output location
|
||||
skill_name = skill_path.name
|
||||
if output_dir:
|
||||
output_path = Path(output_dir).resolve()
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
output_path = Path.cwd()
|
||||
|
||||
skill_filename = output_path / f"{skill_name}.skill"
|
||||
|
||||
EXCLUDED_DIRS = {".git", ".svn", ".hg", "__pycache__", "node_modules"}
|
||||
|
||||
files_to_package = []
|
||||
resolved_archive = skill_filename.resolve()
|
||||
|
||||
for file_path in skill_path.rglob("*"):
|
||||
# Fail closed on symlinks so the packaged contents are explicit and predictable.
|
||||
if file_path.is_symlink():
|
||||
print(f"[ERROR] Symlink not allowed in packaged skill: {file_path}")
|
||||
_cleanup_partial_archive(skill_filename)
|
||||
return None
|
||||
|
||||
rel_parts = file_path.relative_to(skill_path).parts
|
||||
if any(part in EXCLUDED_DIRS for part in rel_parts):
|
||||
continue
|
||||
|
||||
if file_path.is_file():
|
||||
resolved_file = file_path.resolve()
|
||||
if not _is_within(resolved_file, skill_path):
|
||||
print(f"[ERROR] File escapes skill root: {file_path}")
|
||||
_cleanup_partial_archive(skill_filename)
|
||||
return None
|
||||
# If output lives under skill_path, avoid writing archive into itself.
|
||||
if resolved_file == resolved_archive:
|
||||
print(f"[WARN] Skipping output archive: {file_path}")
|
||||
continue
|
||||
files_to_package.append(file_path)
|
||||
|
||||
# Create the .skill file (zip format)
|
||||
try:
|
||||
with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
for file_path in files_to_package:
|
||||
# Calculate the relative path within the zip.
|
||||
arcname = Path(skill_name) / file_path.relative_to(skill_path)
|
||||
zipf.write(file_path, arcname)
|
||||
print(f" Added: {arcname}")
|
||||
|
||||
print(f"\n[OK] Successfully packaged skill to: {skill_filename}")
|
||||
return skill_filename
|
||||
|
||||
except Exception as e:
|
||||
_cleanup_partial_archive(skill_filename)
|
||||
print(f"[ERROR] Error creating .skill file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python package_skill.py <path/to/skill-folder> [output-directory]")
|
||||
print("\nExample:")
|
||||
print(" python package_skill.py skills/public/my-skill")
|
||||
print(" python package_skill.py skills/public/my-skill ./dist")
|
||||
sys.exit(1)
|
||||
|
||||
skill_path = sys.argv[1]
|
||||
output_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||
|
||||
print(f"Packaging skill: {skill_path}")
|
||||
if output_dir:
|
||||
print(f" Output directory: {output_dir}")
|
||||
print()
|
||||
|
||||
result = package_skill(skill_path, output_dir)
|
||||
|
||||
if result:
|
||||
sys.exit(0)
|
||||
else:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
213
nanobot/skills/skill-creator/scripts/quick_validate.py
Normal file
213
nanobot/skills/skill-creator/scripts/quick_validate.py
Normal file
@@ -0,0 +1,213 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Minimal validator for nanobot skill folders.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ModuleNotFoundError:
|
||||
yaml = None
|
||||
|
||||
MAX_SKILL_NAME_LENGTH = 64
|
||||
ALLOWED_FRONTMATTER_KEYS = {
|
||||
"name",
|
||||
"description",
|
||||
"metadata",
|
||||
"always",
|
||||
"license",
|
||||
"allowed-tools",
|
||||
}
|
||||
ALLOWED_RESOURCE_DIRS = {"scripts", "references", "assets"}
|
||||
PLACEHOLDER_MARKERS = ("[todo", "todo:")
|
||||
|
||||
|
||||
def _extract_frontmatter(content: str) -> Optional[str]:
|
||||
lines = content.splitlines()
|
||||
if not lines or lines[0].strip() != "---":
|
||||
return None
|
||||
for i in range(1, len(lines)):
|
||||
if lines[i].strip() == "---":
|
||||
return "\n".join(lines[1:i])
|
||||
return None
|
||||
|
||||
|
||||
def _parse_simple_frontmatter(frontmatter_text: str) -> Optional[dict[str, str]]:
|
||||
"""Fallback parser for simple frontmatter when PyYAML is unavailable."""
|
||||
parsed: dict[str, str] = {}
|
||||
current_key: Optional[str] = None
|
||||
multiline_key: Optional[str] = None
|
||||
|
||||
for raw_line in frontmatter_text.splitlines():
|
||||
stripped = raw_line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
|
||||
is_indented = raw_line[:1].isspace()
|
||||
if is_indented:
|
||||
if current_key is None:
|
||||
return None
|
||||
current_value = parsed[current_key]
|
||||
parsed[current_key] = f"{current_value}\n{stripped}" if current_value else stripped
|
||||
continue
|
||||
|
||||
if ":" not in stripped:
|
||||
return None
|
||||
|
||||
key, value = stripped.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if not key:
|
||||
return None
|
||||
|
||||
if value in {"|", ">"}:
|
||||
parsed[key] = ""
|
||||
current_key = key
|
||||
multiline_key = key
|
||||
continue
|
||||
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
value = value[1:-1]
|
||||
parsed[key] = value
|
||||
current_key = key
|
||||
multiline_key = None
|
||||
|
||||
if multiline_key is not None and multiline_key not in parsed:
|
||||
return None
|
||||
return parsed
|
||||
|
||||
|
||||
def _load_frontmatter(frontmatter_text: str) -> tuple[Optional[dict], Optional[str]]:
|
||||
if yaml is not None:
|
||||
try:
|
||||
frontmatter = yaml.safe_load(frontmatter_text)
|
||||
except yaml.YAMLError as exc:
|
||||
return None, f"Invalid YAML in frontmatter: {exc}"
|
||||
if not isinstance(frontmatter, dict):
|
||||
return None, "Frontmatter must be a YAML dictionary"
|
||||
return frontmatter, None
|
||||
|
||||
frontmatter = _parse_simple_frontmatter(frontmatter_text)
|
||||
if frontmatter is None:
|
||||
return None, "Invalid YAML in frontmatter: unsupported syntax without PyYAML installed"
|
||||
return frontmatter, None
|
||||
|
||||
|
||||
def _validate_skill_name(name: str, folder_name: str) -> Optional[str]:
|
||||
if not re.fullmatch(r"[a-z0-9]+(?:-[a-z0-9]+)*", name):
|
||||
return (
|
||||
f"Name '{name}' should be hyphen-case "
|
||||
"(lowercase letters, digits, and single hyphens only)"
|
||||
)
|
||||
if len(name) > MAX_SKILL_NAME_LENGTH:
|
||||
return (
|
||||
f"Name is too long ({len(name)} characters). "
|
||||
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
|
||||
)
|
||||
if name != folder_name:
|
||||
return f"Skill name '{name}' must match directory name '{folder_name}'"
|
||||
return None
|
||||
|
||||
|
||||
def _validate_description(description: str) -> Optional[str]:
|
||||
trimmed = description.strip()
|
||||
if not trimmed:
|
||||
return "Description cannot be empty"
|
||||
lowered = trimmed.lower()
|
||||
if any(marker in lowered for marker in PLACEHOLDER_MARKERS):
|
||||
return "Description still contains TODO placeholder text"
|
||||
if "<" in trimmed or ">" in trimmed:
|
||||
return "Description cannot contain angle brackets (< or >)"
|
||||
if len(trimmed) > 1024:
|
||||
return f"Description is too long ({len(trimmed)} characters). Maximum is 1024 characters."
|
||||
return None
|
||||
|
||||
|
||||
def validate_skill(skill_path):
|
||||
"""Validate a skill folder structure and required frontmatter."""
|
||||
skill_path = Path(skill_path).resolve()
|
||||
|
||||
if not skill_path.exists():
|
||||
return False, f"Skill folder not found: {skill_path}"
|
||||
if not skill_path.is_dir():
|
||||
return False, f"Path is not a directory: {skill_path}"
|
||||
|
||||
skill_md = skill_path / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
return False, "SKILL.md not found"
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
return False, f"Could not read SKILL.md: {exc}"
|
||||
|
||||
frontmatter_text = _extract_frontmatter(content)
|
||||
if frontmatter_text is None:
|
||||
return False, "Invalid frontmatter format"
|
||||
|
||||
frontmatter, error = _load_frontmatter(frontmatter_text)
|
||||
if error:
|
||||
return False, error
|
||||
|
||||
unexpected_keys = sorted(set(frontmatter.keys()) - ALLOWED_FRONTMATTER_KEYS)
|
||||
if unexpected_keys:
|
||||
allowed = ", ".join(sorted(ALLOWED_FRONTMATTER_KEYS))
|
||||
unexpected = ", ".join(unexpected_keys)
|
||||
return (
|
||||
False,
|
||||
f"Unexpected key(s) in SKILL.md frontmatter: {unexpected}. Allowed properties are: {allowed}",
|
||||
)
|
||||
|
||||
if "name" not in frontmatter:
|
||||
return False, "Missing 'name' in frontmatter"
|
||||
if "description" not in frontmatter:
|
||||
return False, "Missing 'description' in frontmatter"
|
||||
|
||||
name = frontmatter["name"]
|
||||
if not isinstance(name, str):
|
||||
return False, f"Name must be a string, got {type(name).__name__}"
|
||||
name_error = _validate_skill_name(name.strip(), skill_path.name)
|
||||
if name_error:
|
||||
return False, name_error
|
||||
|
||||
description = frontmatter["description"]
|
||||
if not isinstance(description, str):
|
||||
return False, f"Description must be a string, got {type(description).__name__}"
|
||||
description_error = _validate_description(description)
|
||||
if description_error:
|
||||
return False, description_error
|
||||
|
||||
always = frontmatter.get("always")
|
||||
if always is not None and not isinstance(always, bool):
|
||||
return False, f"'always' must be a boolean, got {type(always).__name__}"
|
||||
|
||||
for child in skill_path.iterdir():
|
||||
if child.name == "SKILL.md":
|
||||
continue
|
||||
if child.is_dir() and child.name in ALLOWED_RESOURCE_DIRS:
|
||||
continue
|
||||
if child.is_symlink():
|
||||
continue
|
||||
return (
|
||||
False,
|
||||
f"Unexpected file or directory in skill root: {child.name}. "
|
||||
"Only SKILL.md, scripts/, references/, and assets/ are allowed.",
|
||||
)
|
||||
|
||||
return True, "Skill is valid!"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python quick_validate.py <skill_directory>")
|
||||
sys.exit(1)
|
||||
|
||||
valid, message = validate_skill(sys.argv[1])
|
||||
print(message)
|
||||
sys.exit(0 if valid else 1)
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Utility functions for nanobot."""
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
|
||||
__all__ = ["ensure_dir"]
|
||||
|
||||
@@ -1,8 +1,25 @@
|
||||
"""Utility functions for nanobot."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
|
||||
def detect_image_mime(data: bytes) -> str | None:
|
||||
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
||||
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return "image/png"
|
||||
if data[:3] == b"\xff\xd8\xff":
|
||||
return "image/jpeg"
|
||||
if data[:6] in (b"GIF87a", b"GIF89a"):
|
||||
return "image/gif"
|
||||
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
return None
|
||||
|
||||
|
||||
def ensure_dir(path: Path) -> Path:
|
||||
@@ -11,17 +28,6 @@ def ensure_dir(path: Path) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
def get_data_path() -> Path:
|
||||
"""~/.nanobot data directory."""
|
||||
return ensure_dir(Path.home() / ".nanobot")
|
||||
|
||||
|
||||
def get_workspace_path(workspace: str | None = None) -> Path:
|
||||
"""Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
|
||||
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
|
||||
return ensure_dir(path)
|
||||
|
||||
|
||||
def timestamp() -> str:
|
||||
"""Current ISO timestamp."""
|
||||
return datetime.now().isoformat()
|
||||
@@ -34,6 +40,136 @@ def safe_filename(name: str) -> str:
|
||||
return _UNSAFE_CHARS.sub("_", name).strip()
|
||||
|
||||
|
||||
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||
"""
|
||||
Split content into chunks within max_len, preferring line breaks.
|
||||
|
||||
Args:
|
||||
content: The text content to split.
|
||||
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
|
||||
|
||||
Returns:
|
||||
List of message chunks, each within max_len.
|
||||
"""
|
||||
if not content:
|
||||
return []
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
# Try to break at newline first, then space, then hard break
|
||||
pos = cut.rfind('\n')
|
||||
if pos <= 0:
|
||||
pos = cut.rfind(' ')
|
||||
if pos <= 0:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
def build_assistant_message(
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
thinking_blocks: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a provider-safe assistant message with optional reasoning fields."""
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
if thinking_blocks:
|
||||
msg["thinking_blocks"] = thinking_blocks
|
||||
return msg
|
||||
|
||||
|
||||
def estimate_prompt_tokens(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
"""Estimate prompt tokens with tiktoken."""
|
||||
try:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
parts: list[str] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
txt = part.get("text", "")
|
||||
if txt:
|
||||
parts.append(txt)
|
||||
if tools:
|
||||
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||
return len(enc.encode("\n".join(parts)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def estimate_message_tokens(message: dict[str, Any]) -> int:
|
||||
"""Estimate prompt tokens contributed by one persisted message."""
|
||||
content = message.get("content")
|
||||
parts: list[str] = []
|
||||
if isinstance(content, str):
|
||||
parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
parts.append(text)
|
||||
else:
|
||||
parts.append(json.dumps(part, ensure_ascii=False))
|
||||
elif content is not None:
|
||||
parts.append(json.dumps(content, ensure_ascii=False))
|
||||
|
||||
for key in ("name", "tool_call_id"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
parts.append(value)
|
||||
if message.get("tool_calls"):
|
||||
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||
|
||||
payload = "\n".join(parts)
|
||||
if not payload:
|
||||
return 1
|
||||
try:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
return max(1, len(enc.encode(payload)))
|
||||
except Exception:
|
||||
return max(1, len(payload) // 4)
|
||||
|
||||
|
||||
def estimate_prompt_tokens_chain(
|
||||
provider: Any,
|
||||
model: str | None,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
|
||||
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
|
||||
if callable(provider_counter):
|
||||
try:
|
||||
tokens, source = provider_counter(messages, tools, model)
|
||||
if isinstance(tokens, (int, float)) and tokens > 0:
|
||||
return int(tokens), str(source or "provider_counter")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
estimated = estimate_prompt_tokens(messages, tools)
|
||||
if estimated > 0:
|
||||
return int(estimated), "tiktoken"
|
||||
return 0, "none"
|
||||
|
||||
|
||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||
from importlib.resources import files as pkg_files
|
||||
@@ -54,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
||||
added.append(str(dest.relative_to(workspace)))
|
||||
|
||||
for item in tpl.iterdir():
|
||||
if item.name.endswith(".md"):
|
||||
if item.name.endswith(".md") and not item.name.startswith("."):
|
||||
_write(item, workspace / item.name)
|
||||
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
||||
_write(None, workspace / "memory" / "HISTORY.md")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "nanobot-ai"
|
||||
version = "0.1.4.post3"
|
||||
version = "0.1.4.post4"
|
||||
description = "A lightweight personal AI assistant framework"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "MIT"}
|
||||
@@ -18,7 +18,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
"typer>=0.20.0,<1.0.0",
|
||||
"litellm>=1.81.5,<2.0.0",
|
||||
"litellm>=1.82.1,<2.0.0",
|
||||
"pydantic>=2.12.0,<3.0.0",
|
||||
"pydantic-settings>=2.12.0,<3.0.0",
|
||||
"websockets>=16.0,<17.0",
|
||||
@@ -44,9 +44,13 @@ dependencies = [
|
||||
"json-repair>=0.57.0,<1.0.0",
|
||||
"chardet>=3.0.2,<6.0.0",
|
||||
"openai>=2.8.0",
|
||||
"tiktoken>=0.12.0,<1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
wecom = [
|
||||
"wecom-aibot-sdk-python>=0.1.2",
|
||||
]
|
||||
matrix = [
|
||||
"matrix-nio[e2e]>=0.25.2",
|
||||
"mistune>=3.0.0,<4.0.0",
|
||||
@@ -68,6 +72,9 @@ nanobot = "nanobot.cli.commands:app"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["nanobot"]
|
||||
|
||||
|
||||
399
tests/test_azure_openai_provider.py
Normal file
399
tests/test_azure_openai_provider.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
25
tests/test_base_channel.py
Normal file
25
tests/test_base_channel.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
|
||||
class _DummyChannel(BaseChannel):
|
||||
name = "dummy"
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_is_allowed_requires_exact_match() -> None:
|
||||
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
@@ -1,6 +1,6 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
@@ -14,13 +14,17 @@ from nanobot.providers.registry import find_by_model
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class _StopGateway(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paths():
|
||||
"""Mock config/workspace paths for test isolation."""
|
||||
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
|
||||
patch("nanobot.config.loader.save_config") as mock_sc, \
|
||||
patch("nanobot.config.loader.load_config") as mock_lc, \
|
||||
patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:
|
||||
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
|
||||
|
||||
base_dir = Path("./test_onboard_data")
|
||||
if base_dir.exists():
|
||||
@@ -110,6 +114,35 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
|
||||
assert config.get_provider_name() == "openai_codex"
|
||||
|
||||
|
||||
def test_config_matches_explicit_ollama_prefix_without_api_key():
|
||||
config = Config()
|
||||
config.agents.defaults.model = "ollama/llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||
config = Config()
|
||||
config.agents.defaults.provider = "ollama"
|
||||
config.agents.defaults.model = "llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_config_auto_detects_ollama_from_local_api_base():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
|
||||
|
||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||
|
||||
@@ -128,3 +161,303 @@ def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_runtime(tmp_path):
|
||||
"""Mock agent command dependencies for focused CLI tests."""
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
|
||||
cron_dir = tmp_path / "data" / "cron"
|
||||
|
||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||
patch("nanobot.bus.queue.MessageBus"), \
|
||||
patch("nanobot.cron.service.CronService"), \
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
||||
|
||||
agent_loop = MagicMock()
|
||||
agent_loop.channels_config = None
|
||||
agent_loop.process_direct = AsyncMock(return_value="mock-response")
|
||||
agent_loop.close_mcp = AsyncMock(return_value=None)
|
||||
mock_agent_loop_cls.return_value = agent_loop
|
||||
|
||||
yield {
|
||||
"config": config,
|
||||
"load_config": mock_load_config,
|
||||
"sync_templates": mock_sync_templates,
|
||||
"agent_loop_cls": mock_agent_loop_cls,
|
||||
"agent_loop": agent_loop,
|
||||
"print_response": mock_print_response,
|
||||
}
|
||||
|
||||
|
||||
def test_agent_help_shows_workspace_and_config_options():
|
||||
result = runner.invoke(app, ["agent", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "--workspace" in result.stdout
|
||||
assert "-w" in result.stdout
|
||||
assert "--config" in result.stdout
|
||||
assert "-c" in result.stdout
|
||||
|
||||
|
||||
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (None,)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (
|
||||
mock_agent_runtime["config"].workspace_path,
|
||||
)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
|
||||
mock_agent_runtime["config"].workspace_path
|
||||
)
|
||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||
mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
|
||||
|
||||
|
||||
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
|
||||
|
||||
def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs) -> str:
|
||||
return "ok"
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
|
||||
|
||||
def test_agent_overrides_workspace_path(mock_agent_runtime):
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
|
||||
config_path = tmp_path / "agent-config.json"
|
||||
config_path.write_text("{}")
|
||||
workspace_path = Path("/tmp/agent-workspace")
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
|
||||
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
|
||||
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
|
||||
mock_agent_runtime["config"].agents.defaults.memory_window = 100
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
assert seen["workspace"] == Path(config.agents.defaults.workspace)
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["gateway", "--config", str(config_file), "--workspace", str(override)],
|
||||
)
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["workspace"] == override
|
||||
assert config.workspace_path == override
|
||||
|
||||
|
||||
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.memory_window = 100
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGateway("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "port 18791" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
88
tests/test_config_migration.py
Normal file
88
tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 1234,
|
||||
"memoryWindow": 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.agents.defaults.max_tokens == 1234
|
||||
assert config.agents.defaults.context_window_tokens == 65_536
|
||||
assert config.agents.defaults.should_warn_deprecated_memory_window is True
|
||||
|
||||
|
||||
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 2222,
|
||||
"memoryWindow": 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
save_config(config, config_path)
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
|
||||
assert defaults["maxTokens"] == 2222
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
|
||||
|
||||
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 3333,
|
||||
"memoryWindow": 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
assert defaults["maxTokens"] == 3333
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
42
tests/test_config_paths.py
Normal file
42
tests/test_config_paths.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-a" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_data_dir() == config_file.parent
|
||||
assert get_runtime_subdir("cron") == config_file.parent / "cron"
|
||||
assert get_cron_dir() == config_file.parent / "cron"
|
||||
assert get_logs_dir() == config_file.parent / "logs"
|
||||
|
||||
|
||||
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance-b" / "config.json"
|
||||
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
|
||||
|
||||
assert get_media_dir() == config_file.parent / "media"
|
||||
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
|
||||
|
||||
|
||||
def test_shared_and_legacy_paths_remain_global() -> None:
|
||||
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
|
||||
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
|
||||
|
||||
|
||||
def test_workspace_path_is_explicitly_resolved() -> None:
|
||||
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
|
||||
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"
|
||||
@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestConsolidationDeduplicationGuard:
|
||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||
class TestNewCommandArchival:
|
||||
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||
@staticmethod
|
||||
def _make_loop(tmp_path: Path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=1,
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls
|
||||
consolidation_calls += 1
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await loop._process_message(msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 1, (
|
||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls, active, max_active
|
||||
consolidation_calls += 1
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
await asyncio.sleep(0.05)
|
||||
active -= 1
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
await loop._process_message(new_msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 2, (
|
||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||
)
|
||||
assert max_active == 1, (
|
||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||
started.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
await started.wait()
|
||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
assert len(loop._consolidation_tasks) == 0, (
|
||||
"Task reference must be removed after completion"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new waits for in-flight consolidation and archives before clear."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = 0
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
started.set()
|
||||
await release.wait()
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||
return loop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||
"""/new must keep session data if archive step reports failure."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
if archive_all:
|
||||
return False
|
||||
return True
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
return False
|
||||
|
||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "failed" in response.content.lower()
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert len(session_after.messages) == before_count, (
|
||||
"Session must remain intact when /new archival fails"
|
||||
)
|
||||
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new should archive only messages not yet consolidated by prior task."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
session.last_consolidated = len(session.messages) - 3
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
|
||||
started.set()
|
||||
await release.wait()
|
||||
sess.last_consolidated = len(sess.messages) - 3
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done()
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count == 3, (
|
||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||
)
|
||||
assert archived_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||
"""/new clears session and returns confirmation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime as real_datetime
|
||||
from importlib.resources import files as pkg_files
|
||||
from pathlib import Path
|
||||
import datetime as datetime_module
|
||||
|
||||
@@ -23,6 +24,13 @@ def _make_workspace(tmp_path: Path) -> Path:
|
||||
return workspace
|
||||
|
||||
|
||||
def test_bootstrap_files_are_backed_by_templates() -> None:
|
||||
template_dir = pkg_files("nanobot") / "templates"
|
||||
|
||||
for filename in ContextBuilder.BOOTSTRAP_FILES:
|
||||
assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
|
||||
|
||||
|
||||
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||
"""System prompt should not change just because wall clock minute changes."""
|
||||
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||
|
||||
111
tests/test_dingtalk_channel.py
Normal file
111
tests/test_dingtalk_channel.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
import nanobot.channels.dingtalk as dingtalk_module
|
||||
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body or {}
|
||||
self.text = "{}"
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def post(self, url: str, json=None, headers=None):
|
||||
self.calls.append({"url": url, "json": json, "headers": headers})
|
||||
return _FakeResponse()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(config, bus)
|
||||
|
||||
await channel._on_message(
|
||||
"hello",
|
||||
sender_id="user1",
|
||||
sender_name="Alice",
|
||||
conversation_type="2",
|
||||
conversation_id="conv123",
|
||||
)
|
||||
|
||||
msg = await bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
assert msg.metadata["conversation_type"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_send_uses_group_messages_api() -> None:
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||
channel = DingTalkChannel(config, MessageBus())
|
||||
channel._http = _FakeHttp()
|
||||
|
||||
ok = await channel._send_batch_message(
|
||||
"token",
|
||||
"group:conv123",
|
||||
"sampleMarkdown",
|
||||
{"text": "hello", "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
assert ok is True
|
||||
call = channel._http.calls[0]
|
||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
assert call["json"]["openConversationId"] == "conv123"
|
||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeChatbotMessage:
|
||||
text = None
|
||||
extensions = {"content": {"recognition": "voice transcript"}}
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "audio"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeChatbotMessage()
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "2",
|
||||
"conversationId": "conv123",
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert msg.content == "voice transcript"
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
@@ -1,4 +1,4 @@
|
||||
from nanobot.channels.feishu import _extract_post_content
|
||||
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||
|
||||
|
||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
||||
|
||||
assert text == "Daily report"
|
||||
assert image_keys == ["img_a", "img_b"]
|
||||
|
||||
|
||||
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||
class Builder:
|
||||
pass
|
||||
|
||||
builder = Builder()
|
||||
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||
assert same is builder
|
||||
|
||||
|
||||
def test_register_optional_event_calls_supported_method() -> None:
|
||||
called = []
|
||||
|
||||
class Builder:
|
||||
def register_event(self, handler):
|
||||
called.append(handler)
|
||||
return self
|
||||
|
||||
builder = Builder()
|
||||
handler = object()
|
||||
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||
|
||||
assert same is builder
|
||||
assert called == [handler]
|
||||
|
||||
251
tests/test_filesystem_tools.py
Normal file
251
tests/test_filesystem_tools.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import (
|
||||
EditFileTool,
|
||||
ListDirTool,
|
||||
ReadFileTool,
|
||||
_find_match,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_file(self, tmp_path):
|
||||
f = tmp_path / "sample.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
return f
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_read_has_line_numbers(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file))
|
||||
assert "1| line 1" in result
|
||||
assert "20| line 20" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_and_limit(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
|
||||
assert "5| line 5" in result
|
||||
assert "7| line 7" in result
|
||||
assert "8| line 8" not in result
|
||||
assert "Use offset=8 to continue" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_beyond_end(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=999)
|
||||
assert "Error" in result
|
||||
assert "beyond end" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_of_file_marker(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
|
||||
assert "End of file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.txt"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert "Empty file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope.txt"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_char_budget_trims(self, tool, tmp_path):
|
||||
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
|
||||
f = tmp_path / "big.txt"
|
||||
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
|
||||
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
|
||||
assert "Use offset=" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_match (unit tests for the helper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindMatch:
|
||||
|
||||
def test_exact_match(self):
|
||||
match, count = _find_match("hello world", "world")
|
||||
assert match == "world"
|
||||
assert count == 1
|
||||
|
||||
def test_exact_no_match(self):
|
||||
match, count = _find_match("hello world", "xyz")
|
||||
assert match is None
|
||||
assert count == 0
|
||||
|
||||
def test_crlf_normalisation(self):
|
||||
# Caller normalises CRLF before calling _find_match, so test with
|
||||
# pre-normalised content to verify exact match still works.
|
||||
content = "line1\nline2\nline3"
|
||||
old_text = "line1\nline2\nline3"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_line_trim_fallback(self):
|
||||
content = " def foo():\n pass\n"
|
||||
old_text = "def foo():\n pass"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# The returned match should be the *original* indented text
|
||||
assert " def foo():" in match
|
||||
|
||||
def test_line_trim_multiple_candidates(self):
|
||||
content = " a\n b\n a\n b\n"
|
||||
old_text = "a\nb"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert count == 2
|
||||
|
||||
def test_empty_old_text(self):
|
||||
match, count = _find_match("hello", "")
|
||||
# Empty string is always "in" any string via exact match
|
||||
assert match == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EditFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crlf_normalisation(self, tool, tmp_path):
|
||||
f = tmp_path / "crlf.py"
|
||||
f.write_bytes(b"line1\r\nline2\r\nline3")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
raw = f.read_bytes()
|
||||
assert b"LINE1" in raw
|
||||
# CRLF line endings should be preserved throughout the file
|
||||
assert b"\r\n" in raw
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
f.write_text(" def foo():\n pass\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert "bar" in f.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears" in result.lower() or "Warning" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all(self, tool, tmp_path):
|
||||
f = tmp_path / "multi.py"
|
||||
f.write_text("foo bar foo bar foo", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="foo", new_text="baz", replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "baz bar baz bar baz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
f = tmp_path / "nf.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ListDirTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListDirTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ListDirTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def populated_dir(self, tmp_path):
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "main.py").write_text("pass")
|
||||
(tmp_path / "src" / "utils.py").write_text("pass")
|
||||
(tmp_path / "README.md").write_text("hi")
|
||||
(tmp_path / ".git").mkdir()
|
||||
(tmp_path / ".git" / "config").write_text("x")
|
||||
(tmp_path / "node_modules").mkdir()
|
||||
(tmp_path / "node_modules" / "pkg").mkdir()
|
||||
return tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_list(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir))
|
||||
assert "README.md" in result
|
||||
assert "src" in result
|
||||
# .git and node_modules should be ignored
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir), recursive=True)
|
||||
assert "src/main.py" in result
|
||||
assert "src/utils.py" in result
|
||||
assert "README.md" in result
|
||||
# Ignored dirs should not appear
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_entries_truncation(self, tool, tmp_path):
|
||||
for i in range(10):
|
||||
(tmp_path / f"file_{i}.txt").write_text("x")
|
||||
result = await tool.execute(path=str(tmp_path), max_entries=3)
|
||||
assert "truncated" in result
|
||||
assert "3 of 10" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_dir(self, tool, tmp_path):
|
||||
d = tmp_path / "empty"
|
||||
d.mkdir()
|
||||
result = await tool.execute(path=str(d))
|
||||
assert "empty" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
53
tests/test_gemini_thought_signature.py
Normal file
53
tests/test_gemini_thought_signature.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
|
||||
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
|
||||
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
|
||||
|
||||
response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="tool_calls",
|
||||
message=SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
id="call_123",
|
||||
function=SimpleNamespace(
|
||||
name="read_file",
|
||||
arguments='{"path":"todo.md"}',
|
||||
provider_specific_fields={"inner": "value"},
|
||||
),
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
|
||||
parsed = provider._parse_response(response)
|
||||
|
||||
assert len(parsed.tool_calls) == 1
|
||||
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
|
||||
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
|
||||
def test_tool_call_request_serializes_provider_fields() -> None:
|
||||
tool_call = ToolCallRequest(
|
||||
id="abc123xyz",
|
||||
name="read_file",
|
||||
arguments={"path": "todo.md"},
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
function_provider_specific_fields={"inner": "value"},
|
||||
)
|
||||
|
||||
message = tool_call.to_openai_tool_call()
|
||||
|
||||
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
|
||||
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||
assert message["function"]["arguments"] == '{"path": "todo.md"}'
|
||||
@@ -3,18 +3,24 @@ import asyncio
|
||||
import pytest
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class DummyProvider:
|
||||
class DummyProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_is_idempotent(tmp_path) -> None:
|
||||
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
||||
)
|
||||
|
||||
assert await service.trigger_now() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||
provider = DummyProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check open tasks"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="openai/gpt-4o-mini",
|
||||
)
|
||||
|
||||
action, tasks = await service._decide("heartbeat content")
|
||||
|
||||
assert action == "run"
|
||||
assert tasks == "check open tasks"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
190
tests/test_loop_consolidation_tokens.py
Normal file
190
tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||
assert session.last_consolidated == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (300, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (150, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||
order: list[str] = []
|
||||
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
|
||||
async def track_consolidate(messages):
|
||||
order.append("consolidate")
|
||||
return True
|
||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||
|
||||
async def track_llm(*args, **kwargs):
|
||||
order.append("llm")
|
||||
return LLMResponse(content="ok", tool_calls=[])
|
||||
loop.provider.chat_with_retry = track_llm
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert "consolidate" in order
|
||||
assert "llm" in order
|
||||
assert order.index("consolidate") < order.index("llm")
|
||||
@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = 500
|
||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||
return loop
|
||||
|
||||
|
||||
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
|
||||
skip=0,
|
||||
)
|
||||
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
|
||||
|
||||
|
||||
def test_save_turn_keeps_tool_results_under_16k() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:tool-result")
|
||||
content = "x" * 12_000
|
||||
|
||||
loop._save_turn(
|
||||
session,
|
||||
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
|
||||
skip=0,
|
||||
)
|
||||
|
||||
assert session.messages[0]["content"] == content
|
||||
|
||||
99
tests/test_mcp_tool.py
Normal file
99
tests/test_mcp_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.mcp import MCPToolWrapper
|
||||
|
||||
|
||||
class _FakeTextContent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
mod = ModuleType("mcp")
|
||||
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
||||
monkeypatch.setitem(sys.modules, "mcp", mod)
|
||||
|
||||
|
||||
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_text_blocks() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
assert arguments == {"value": 1}
|
||||
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute(value=1)
|
||||
|
||||
assert result == "hello\n42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_timeout_message() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call timed out after 0.01s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_server_cancelled_error() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call was cancelled)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_re_raises_external_cancellation() -> None:
|
||||
started = asyncio.Event()
|
||||
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
|
||||
task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_generic_exception() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call failed: RuntimeError)"
|
||||
@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||
"""Create a mock session with messages."""
|
||||
session = MagicMock()
|
||||
session.messages = [
|
||||
def _make_messages(message_count: int = 30):
|
||||
"""Create a list of mock messages."""
|
||||
return [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
session.last_consolidated = 0
|
||||
return session
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
|
||||
)
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses: list[LLMResponse]):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
class TestMemoryConsolidationTypeHandling:
|
||||
"""Test that consolidation handles various argument types correctly."""
|
||||
|
||||
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -112,9 +127,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -127,21 +143,23 @@ class TestMemoryConsolidationTypeHandling:
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when messages < keep_count."""
|
||||
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
session = _make_session(message_count=10)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages: list[dict] = []
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
@@ -167,9 +185,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -192,9 +211,10 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -215,8 +235,56 @@ class TestMemoryConsolidationTypeHandling:
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
provider.chat_with_retry = provider.chat
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="503 server error", finish_reason="error"),
|
||||
_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
),
|
||||
])
|
||||
messages = _make_messages(message_count=60)
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
|
||||
"""Consolidation no longer passes generation params — the provider owns them."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry="[2026-01-01] User discussed testing.",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat_with_retry.assert_awaited_once()
|
||||
_, kwargs = provider.chat_with_retry.await_args
|
||||
assert kwargs["model"] == "test-model"
|
||||
assert "temperature" not in kwargs
|
||||
assert "max_tokens" not in kwargs
|
||||
assert "reasoning_effort" not in kwargs
|
||||
|
||||
@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
@@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic:
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
calls = iter([
|
||||
LLMResponse(
|
||||
content="Visible<think>hidden</think>",
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content="secret reasoning",
|
||||
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
progress: list[tuple[str, bool]] = []
|
||||
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
("Visible", False),
|
||||
('read_file("foo.txt")', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
|
||||
125
tests/test_provider_retry.py
Normal file
125
tests/test_provider_retry.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||
|
||||
|
||||
class ScriptedProvider(LLMProvider):
|
||||
def __init__(self, responses):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
self.last_kwargs: dict = {}
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
self.last_kwargs = kwargs
|
||||
response = self._responses.pop(0)
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
return response
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.finish_reason == "stop"
|
||||
assert response.content == "ok"
|
||||
assert provider.calls == 2
|
||||
assert delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "401 unauthorized"
|
||||
assert provider.calls == 1
|
||||
assert delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit a", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit b", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit c", finish_reason="error"),
|
||||
LLMResponse(content="503 final server error", finish_reason="error"),
|
||||
])
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.content == "503 final server error"
|
||||
assert provider.calls == 4
|
||||
assert delays == [1, 2, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||
provider = ScriptedProvider([asyncio.CancelledError()])
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||
"""When callers omit generation params, provider.generation defaults are used."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.2
|
||||
assert provider.last_kwargs["max_tokens"] == 321
|
||||
assert provider.last_kwargs["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||
"""Explicit kwargs should override provider.generation defaults."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||
|
||||
await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.9,
|
||||
max_tokens=9999,
|
||||
reasoning_effort="low",
|
||||
)
|
||||
|
||||
assert provider.last_kwargs["temperature"] == 0.9
|
||||
assert provider.last_kwargs["max_tokens"] == 9999
|
||||
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||
66
tests/test_qq_channel.py
Normal file
66
tests/test_qq_channel.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.api = _FakeApi()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_group_message_routes_to_group_chat_id() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
|
||||
|
||||
data = SimpleNamespace(
|
||||
id="msg1",
|
||||
content="hello",
|
||||
group_openid="group123",
|
||||
author=SimpleNamespace(member_openid="user1"),
|
||||
)
|
||||
|
||||
await channel._on_message(data, is_group=True)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
channel._chat_type_cache["group123"] = "group"
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="group123",
|
||||
content="hello",
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(channel._client.api.group_calls) == 1
|
||||
call = channel._client.api.group_calls[0]
|
||||
assert call["group_openid"] == "group123"
|
||||
assert call["msg_id"] == "msg1"
|
||||
assert call["msg_seq"] == 2
|
||||
assert not channel._client.api.c2c_calls
|
||||
76
tests/test_restart_command.py
Normal file
76
tests/test_restart_command.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for /restart slash command."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
|
||||
def _make_loop():
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager"):
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||
return loop, bus
|
||||
|
||||
|
||||
class TestRestartCommand:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_sends_message_and_calls_execv(self):
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
|
||||
|
||||
with patch("nanobot.agent.loop.os.execv") as mock_execv:
|
||||
await loop._handle_restart(msg)
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "Restarting" in out.content
|
||||
|
||||
await asyncio.sleep(1.5)
|
||||
mock_execv.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restart_intercepted_in_run_loop(self):
|
||||
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
|
||||
|
||||
with patch.object(loop, "_handle_restart") as mock_handle:
|
||||
mock_handle.return_value = None
|
||||
await bus.publish_inbound(msg)
|
||||
|
||||
loop._running = True
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0.1)
|
||||
loop._running = False
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
mock_handle.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_includes_restart(self):
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
|
||||
|
||||
response = await loop._process_message(msg)
|
||||
|
||||
assert response is not None
|
||||
assert "/restart" in response.content
|
||||
127
tests/test_skill_creator_scripts.py
Normal file
127
tests/test_skill_creator_scripts.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
|
||||
if str(SCRIPT_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(SCRIPT_DIR))
|
||||
|
||||
init_skill = importlib.import_module("init_skill")
|
||||
package_skill = importlib.import_module("package_skill")
|
||||
quick_validate = importlib.import_module("quick_validate")
|
||||
|
||||
|
||||
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
|
||||
skill_dir = init_skill.init_skill(
|
||||
"demo-skill",
|
||||
tmp_path,
|
||||
["scripts", "references", "assets"],
|
||||
include_examples=True,
|
||||
)
|
||||
|
||||
assert skill_dir == tmp_path / "demo-skill"
|
||||
assert (skill_dir / "SKILL.md").exists()
|
||||
assert (skill_dir / "scripts" / "example.py").exists()
|
||||
assert (skill_dir / "references" / "api_reference.md").exists()
|
||||
assert (skill_dir / "assets" / "example_asset.txt").exists()
|
||||
|
||||
|
||||
def test_validate_skill_accepts_existing_skill_creator() -> None:
|
||||
valid, message = quick_validate.validate_skill(
|
||||
Path("nanobot/skills/skill-creator").resolve()
|
||||
)
|
||||
|
||||
assert valid, message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "placeholder-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: placeholder-skill\n"
|
||||
'description: "[TODO: fill me in]"\n'
|
||||
"---\n"
|
||||
"# Placeholder\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "TODO placeholder" in message
|
||||
|
||||
|
||||
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "bad-root-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: bad-root-skill\n"
|
||||
"description: Valid description\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
|
||||
|
||||
valid, message = quick_validate.validate_skill(skill_dir)
|
||||
|
||||
assert not valid
|
||||
assert "Unexpected file or directory in skill root" in message
|
||||
|
||||
|
||||
def test_package_skill_creates_archive(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "package-me"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: package-me\n"
|
||||
"description: Package this skill.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path == (tmp_path / "dist" / "package-me.skill")
|
||||
assert archive_path.exists()
|
||||
with zipfile.ZipFile(archive_path, "r") as archive:
|
||||
names = set(archive.namelist())
|
||||
assert "package-me/SKILL.md" in names
|
||||
assert "package-me/scripts/helper.py" in names
|
||||
|
||||
|
||||
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
|
||||
skill_dir = tmp_path / "symlink-skill"
|
||||
skill_dir.mkdir()
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: symlink-skill\n"
|
||||
"description: Reject symlinks during packaging.\n"
|
||||
"---\n"
|
||||
"# Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
scripts_dir.mkdir()
|
||||
target = tmp_path / "outside.txt"
|
||||
target.write_text("secret\n", encoding="utf-8")
|
||||
link = scripts_dir / "outside.txt"
|
||||
|
||||
try:
|
||||
link.symlink_to(target)
|
||||
except (OSError, NotImplementedError):
|
||||
return
|
||||
|
||||
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||
|
||||
assert archive_path is None
|
||||
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
|
||||
90
tests/test_slack_channel.py
Normal file
90
tests/test_slack_channel.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
def __init__(self) -> None:
|
||||
self.chat_post_calls: list[dict[str, object | None]] = []
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
text: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.chat_post_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
async def files_upload_v2(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
file: str,
|
||||
thread_ts: str | None = None,
|
||||
) -> None:
|
||||
self.file_upload_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"file": file,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_omits_thread_for_dm_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="D123",
|
||||
content="hello",
|
||||
media=["/tmp/demo.txt"],
|
||||
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||
assert fake_web.chat_post_calls[0]["thread_ts"] is None
|
||||
assert len(fake_web.file_upload_calls) == 1
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||
@@ -165,3 +165,46 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
captured_second_call: list[dict] = []
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def scripted_chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
provider.chat_with_retry = scripted_chat_with_retry
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
assistant_messages = [
|
||||
msg for msg in captured_second_call
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
]
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
599
tests/test_telegram_channel.py
Normal file
599
tests/test_telegram_channel.py
Normal file
@@ -0,0 +1,599 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
class _FakeHTTPXRequest:
|
||||
instances: list["_FakeHTTPXRequest"] = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
|
||||
class _FakeUpdater:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self._on_start_polling = on_start_polling
|
||||
|
||||
async def start_polling(self, **kwargs) -> None:
|
||||
self._on_start_polling()
|
||||
|
||||
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
self.get_me_calls = 0
|
||||
|
||||
async def get_me(self):
|
||||
self.get_me_calls += 1
|
||||
return SimpleNamespace(id=999, username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
|
||||
async def send_message(self, **kwargs) -> None:
|
||||
self.sent_messages.append(kwargs)
|
||||
|
||||
async def send_chat_action(self, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
async def get_file(self, file_id: str):
|
||||
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
|
||||
async def _fake_download(path) -> None:
|
||||
pass
|
||||
return SimpleNamespace(download_to_drive=_fake_download)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self.bot = _FakeBot()
|
||||
self.updater = _FakeUpdater(on_start_polling)
|
||||
self.handlers = []
|
||||
self.error_handlers = []
|
||||
|
||||
def add_error_handler(self, handler) -> None:
|
||||
self.error_handlers.append(handler)
|
||||
|
||||
def add_handler(self, handler) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeBuilder:
|
||||
def __init__(self, app: _FakeApp) -> None:
|
||||
self.app = app
|
||||
self.token_value = None
|
||||
self.request_value = None
|
||||
self.get_updates_request_value = None
|
||||
|
||||
def token(self, token: str):
|
||||
self.token_value = token
|
||||
return self
|
||||
|
||||
def request(self, request):
|
||||
self.request_value = request
|
||||
return self
|
||||
|
||||
def get_updates_request(self, request):
|
||||
self.get_updates_request_value = request
|
||||
return self
|
||||
|
||||
def proxy(self, _proxy):
|
||||
raise AssertionError("builder.proxy should not be called when request is set")
|
||||
|
||||
def get_updates_proxy(self, _proxy):
|
||||
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||
|
||||
def build(self):
|
||||
return self.app
|
||||
|
||||
|
||||
def _make_telegram_update(
|
||||
*,
|
||||
chat_type: str = "group",
|
||||
text: str | None = None,
|
||||
caption: str | None = None,
|
||||
entities=None,
|
||||
caption_entities=None,
|
||||
reply_to_message=None,
|
||||
):
|
||||
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||
chat_id=-100123,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
reply_to_message=reply_to_message,
|
||||
photo=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
media_group_id=None,
|
||||
message_thread_id=None,
|
||||
message_id=1,
|
||||
)
|
||||
return SimpleNamespace(message=message, effective_user=user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert len(_FakeHTTPXRequest.instances) == 1
|
||||
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
||||
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
||||
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
||||
|
||||
|
||||
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type="supergroup"),
|
||||
chat_id=-100123,
|
||||
message_thread_id=42,
|
||||
)
|
||||
|
||||
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||
|
||||
|
||||
def test_get_extension_falls_back_to_original_filename() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(), MessageBus())
|
||||
|
||||
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
|
||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||
|
||||
|
||||
def test_telegram_group_policy_defaults_to_mention() -> None:
|
||||
assert TelegramConfig().group_policy == "mention"
|
||||
|
||||
|
||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("12345|carol") is True
|
||||
assert channel.is_allowed("99999|alice") is True
|
||||
assert channel.is_allowed("67890|bob") is True
|
||||
|
||||
|
||||
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
|
||||
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
|
||||
|
||||
assert channel.is_allowed("attacker|alice|extra") is False
|
||||
assert channel.is_allowed("not-a-number|alice") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"_progress": True, "message_thread_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._message_threads[("123", 10)] = 42
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"message_id": 10},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
|
||||
|
||||
assert handled == []
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
|
||||
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
|
||||
|
||||
assert len(handled) == 2
|
||||
assert channel._app.bot.get_me_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_caption_mention() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||
await channel._on_message(
|
||||
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
|
||||
None,
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "@nanobot_test photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
handled = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello group"), None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert channel._app.bot.get_me_calls == 0
|
||||
|
||||
|
||||
def test_extract_reply_context_no_reply() -> None:
|
||||
"""When there is no reply_to_message, _extract_reply_context returns None."""
|
||||
message = SimpleNamespace(reply_to_message=None)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
def test_extract_reply_context_with_text() -> None:
|
||||
"""When reply has text, return prefixed string."""
|
||||
reply = SimpleNamespace(text="Hello world", caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
|
||||
|
||||
|
||||
def test_extract_reply_context_with_caption_only() -> None:
|
||||
"""When reply has only caption (no text), caption is used."""
|
||||
reply = SimpleNamespace(text=None, caption="Photo caption")
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
|
||||
|
||||
|
||||
def test_extract_reply_context_truncation() -> None:
|
||||
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
|
||||
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
|
||||
reply = SimpleNamespace(text=long_text, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
result = TelegramChannel._extract_reply_context(message)
|
||||
assert result is not None
|
||||
assert result.startswith("[Reply to: ")
|
||||
assert result.endswith("...]")
|
||||
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
|
||||
|
||||
|
||||
def test_extract_reply_context_no_text_returns_none() -> None:
|
||||
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
|
||||
reply = SimpleNamespace(text=None, caption=None)
|
||||
message = SimpleNamespace(reply_to_message=reply)
|
||||
assert TelegramChannel._extract_reply_context(message) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_includes_reply_context() -> None:
|
||||
"""When user replies to a message, content passed to bus starts with reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="translate this", reply_to_message=reply)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: Hello]")
|
||||
assert "translate this" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_message_media_returns_path_when_download_succeeds(
|
||||
monkeypatch, tmp_path
|
||||
) -> None:
|
||||
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
|
||||
msg = SimpleNamespace(
|
||||
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
|
||||
voice=None,
|
||||
audio=None,
|
||||
document=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
paths, parts = await channel._download_message_media(msg)
|
||||
assert len(paths) == 1
|
||||
assert len(parts) == 1
|
||||
assert "fid123" in paths[0]
|
||||
assert "[image:" in parts[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
|
||||
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what is the image?",
|
||||
reply_to_message=reply_with_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"].startswith("[Reply to: [image:")
|
||||
assert "what is the image?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "reply_photo_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
|
||||
"""When reply has media but download fails, no media attached and no reply tag."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.get_file = None
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption=None,
|
||||
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "what is this?" in handled[0]["content"]
|
||||
assert handled[0]["media"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
|
||||
"""When replying to a message with caption + photo, both text context and media are included."""
|
||||
media_dir = tmp_path / "media" / "telegram"
|
||||
media_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.get_media_dir",
|
||||
lambda channel=None: media_dir if channel else tmp_path / "media",
|
||||
)
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
app = _FakeApp(lambda: None)
|
||||
app.bot.get_file = AsyncMock(
|
||||
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
|
||||
)
|
||||
channel._app = app
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
channel._start_typing = lambda _chat_id: None
|
||||
|
||||
reply_with_caption_and_photo = SimpleNamespace(
|
||||
text=None,
|
||||
caption="A cute cat",
|
||||
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
|
||||
document=None,
|
||||
voice=None,
|
||||
audio=None,
|
||||
video=None,
|
||||
video_note=None,
|
||||
animation=None,
|
||||
)
|
||||
update = _make_telegram_update(
|
||||
text="what breed is this?",
|
||||
reply_to_message=reply_with_caption_and_photo,
|
||||
)
|
||||
await channel._on_message(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert "[Reply to: A cute cat]" in handled[0]["content"]
|
||||
assert "what breed is this?" in handled[0]["content"]
|
||||
assert len(handled[0]["media"]) == 1
|
||||
assert "cat_fid" in handled[0]["media"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_command_does_not_inject_reply_context() -> None:
|
||||
"""Slash commands forwarded via _forward_command must not include reply context."""
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
handled = []
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
channel._handle_message = capture_handle
|
||||
|
||||
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
|
||||
update = _make_telegram_update(text="/new", reply_to_message=reply)
|
||||
await channel._forward_command(update, None)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
@@ -106,3 +106,301 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
|
||||
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
assert "~/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
|
||||
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
|
||||
|
||||
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True
|
||||
errors = tool.validate_params(result)
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
|
||||
def test_cast_params_bool_string_false() -> None:
|
||||
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||
|
||||
|
||||
def test_cast_params_bool_string_invalid() -> None:
|
||||
"""Invalid boolean strings should not be cast."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
# Invalid strings should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"flag": "random"})
|
||||
assert result["flag"] == "random"
|
||||
result = tool.cast_params({"flag": "maybe"})
|
||||
assert result["flag"] == "maybe"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_int() -> None:
|
||||
"""Invalid strings should not be cast to integer."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "abc"})
|
||||
assert result["count"] == "abc" # Original value preserved
|
||||
result = tool.cast_params({"count": "12.5.7"})
|
||||
assert result["count"] == "12.5.7"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_number() -> None:
|
||||
"""Invalid strings should not be cast to number."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "not_a_number"})
|
||||
assert result["rate"] == "not_a_number"
|
||||
|
||||
|
||||
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||
"""Booleans should not pass number validation."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"rate": False})
|
||||
assert any("rate should be number" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_none_values() -> None:
|
||||
"""Test None handling for different types."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"items": {"type": "array"},
|
||||
"config": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params(
|
||||
{
|
||||
"name": None,
|
||||
"count": None,
|
||||
"items": None,
|
||||
"config": None,
|
||||
}
|
||||
)
|
||||
# None should be preserved for all types
|
||||
assert result["name"] is None
|
||||
assert result["count"] is None
|
||||
assert result["items"] is None
|
||||
assert result["config"] is None
|
||||
|
||||
|
||||
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
"""Single values should NOT be automatically wrapped into arrays."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
}
|
||||
)
|
||||
# Non-array values should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"items": 5})
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
|
||||
# --- ExecTool enhancement tests ---
|
||||
|
||||
|
||||
async def test_exec_always_returns_exit_code() -> None:
|
||||
"""Exit code should appear in output even on success (exit 0)."""
|
||||
tool = ExecTool()
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "Exit code: 0" in result
|
||||
assert "hello" in result
|
||||
|
||||
|
||||
async def test_exec_head_tail_truncation() -> None:
|
||||
"""Long output should preserve both head and tail."""
|
||||
tool = ExecTool()
|
||||
# Generate output that exceeds _MAX_OUTPUT
|
||||
big = "A" * 6000 + "\n" + "B" * 6000
|
||||
result = await tool.execute(command=f"echo '{big}'")
|
||||
assert "chars truncated" in result
|
||||
# Head portion should start with As
|
||||
assert result.startswith("A")
|
||||
# Tail portion should end with the exit code which comes after Bs
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_parameter() -> None:
|
||||
"""LLM-supplied timeout should override the constructor default."""
|
||||
tool = ExecTool(timeout=60)
|
||||
# A very short timeout should cause the command to be killed
|
||||
result = await tool.execute(command="sleep 10", timeout=1)
|
||||
assert "timed out" in result
|
||||
assert "1 seconds" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_capped_at_max() -> None:
|
||||
"""Timeout values above _MAX_TIMEOUT should be clamped."""
|
||||
tool = ExecTool()
|
||||
# Should not raise — just clamp to 600
|
||||
result = await tool.execute(command="echo ok", timeout=9999)
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
Reference in New Issue
Block a user