Merge branch 'main' into pr-1845

This commit is contained in:
Re-bin
2026-03-11 17:25:22 +00:00
82 changed files with 7181 additions and 1498 deletions

4
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.worktrees/
.assets .assets
.env .env
*.pyc *.pyc
@@ -19,4 +20,5 @@ __pycache__/
poetry.lock poetry.lock
.pytest_cache/ .pytest_cache/
botpy.log botpy.log
tests/ nano.*.save

285
README.md
View File

@@ -12,17 +12,29 @@
</p> </p>
</div> </div>
🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw) 🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. ⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.
📏 Real-time line count: **3,935 lines** (run `bash core_agent_lines.sh` to verify anytime) 📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.
## 📢 News ## 📢 News
- **2026-03-08** 🚀 Released **v0.1.4.post4** — a reliability-packed release with safer defaults, better multi-instance support, sturdier MCP, and major channel and provider improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post4) for details.
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
- **2026-03-04** 🛠️ Dependency cleanup, safer file reads, and another round of test and Cron fixes.
- **2026-03-03** 🧠 Cleaner user-message merging, safer multimodal saves, and stronger Cron guards.
- **2026-03-02** 🛡️ Safer default access control, sturdier Cron reloads, and cleaner Matrix media handling.
- **2026-03-01** 🌐 Web proxy support, smarter Cron reminders, and Feishu rich-text parsing improvements.
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details. - **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes. - **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility. - **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
<details>
<summary>Earlier news</summary>
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync. - **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details. - **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes. - **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
@@ -30,10 +42,6 @@
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details. - **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood. - **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode. - **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
<details>
<summary>Earlier news</summary>
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching. - **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details. - **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills. - **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
@@ -70,6 +78,25 @@
<img src="nanobot_arch.png" alt="nanobot architecture" width="800"> <img src="nanobot_arch.png" alt="nanobot architecture" width="800">
</p> </p>
## Table of Contents
- [News](#-news)
- [Key Features](#key-features-of-nanobot)
- [Architecture](#-architecture)
- [Features](#-features)
- [Install](#-install)
- [Quick Start](#-quick-start)
- [Chat Apps](#-chat-apps)
- [Agent Social Network](#-agent-social-network)
- [Configuration](#-configuration)
- [Multiple Instances](#-multiple-instances)
- [CLI Reference](#-cli-reference)
- [Docker](#-docker)
- [Linux Service](#-linux-service)
- [Project Structure](#-project-structure)
- [Contribute & Roadmap](#-contribute--roadmap)
- [Star History](#-star-history)
## ✨ Features ## ✨ Features
<table align="center"> <table align="center">
@@ -115,6 +142,29 @@ uv tool install nanobot-ai
pip install nanobot-ai pip install nanobot-ai
``` ```
### Update to latest version
**PyPI / pip**
```bash
pip install -U nanobot-ai
nanobot --version
```
**uv**
```bash
uv tool upgrade nanobot-ai
nanobot --version
```
**Using WhatsApp?** Rebuild the local bridge after upgrading:
```bash
rm -rf ~/.nanobot/bridge
nanobot channels login
```
## 🚀 Quick Start ## 🚀 Quick Start
> [!TIP] > [!TIP]
@@ -177,6 +227,7 @@ Connect nanobot to your favorite chat platform.
| **Slack** | Bot token + App-Level token | | **Slack** | Bot token + App-Level token |
| **Email** | IMAP/SMTP credentials | | **Email** | IMAP/SMTP credentials |
| **QQ** | App ID + App Secret | | **QQ** | App ID + App Secret |
| **Wecom** | Bot ID + Bot Secret |
<details> <details>
<summary><b>Telegram</b> (Recommended)</summary> <summary><b>Telegram</b> (Recommended)</summary>
@@ -293,12 +344,18 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
"discord": { "discord": {
"enabled": true, "enabled": true,
"token": "YOUR_BOT_TOKEN", "token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"] "allowFrom": ["YOUR_USER_ID"],
"groupPolicy": "mention"
} }
} }
} }
``` ```
> `groupPolicy` controls how the bot responds in group channels:
> - `"mention"` (default) — Only respond when @mentioned
> - `"open"` — Respond to all messages
> DMs always respond when the sender is in `allowFrom`.
**5. Invite the bot** **5. Invite the bot**
- OAuth2 → URL Generator - OAuth2 → URL Generator
- Scopes: `bot` - Scopes: `bot`
@@ -361,7 +418,7 @@ pip install nanobot-ai[matrix]
| Option | Description | | Option | Description |
|--------|-------------| |--------|-------------|
| `allowFrom` | User IDs allowed to interact. Empty = all senders. | | `allowFrom` | User IDs allowed to interact. Empty denies all; use `["*"]` to allow everyone. |
| `groupPolicy` | `open` (default), `mention`, or `allowlist`. | | `groupPolicy` | `open` (default), `mention`, or `allowlist`. |
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). | | `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
| `allowRoomMentions` | Accept `@room` mentions in mention mode. | | `allowRoomMentions` | Accept `@room` mentions in mention mode. |
@@ -414,6 +471,10 @@ nanobot channels login
nanobot gateway nanobot gateway
``` ```
> WhatsApp bridge updates are not applied automatically for existing installations.
> After upgrading nanobot, rebuild the local bridge with:
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
</details> </details>
<details> <details>
@@ -636,6 +697,46 @@ nanobot gateway
</details> </details>
<details>
<summary><b>Wecom (企业微信)</b></summary>
> Here we use [wecom-aibot-sdk-python](https://github.com/chengyongru/wecom_aibot_sdk) (community Python version of the official [@wecom/aibot-node-sdk](https://www.npmjs.com/package/@wecom/aibot-node-sdk)).
>
> Uses **WebSocket** long connection — no public IP required.
**1. Install the optional dependency**
```bash
pip install nanobot-ai[wecom]
```
**2. Create a WeCom AI Bot**
Go to the WeCom admin console → Intelligent Robot → Create Robot → select **API mode** with **long connection**. Copy the Bot ID and Secret.
**3. Configure**
```json
{
"channels": {
"wecom": {
"enabled": true,
"botId": "your_bot_id",
"secret": "your_bot_secret",
"allowFrom": ["your_id"]
}
}
}
```
**4. Run**
```bash
nanobot gateway
```
</details>
## 🌐 Agent Social Network ## 🌐 Agent Social Network
🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!** 🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
@@ -658,12 +759,14 @@ Config file: `~/.nanobot/config.json`
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config. > - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key | | Provider | Purpose | Get API Key |
|----------|---------|-------------| |----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | | `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -675,6 +778,7 @@ Config file: `~/.nanobot/config.json`
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
| `ollama` | LLM (local, Ollama) | — |
| `vllm` | LLM (local, any OpenAI-compatible server) | — | | `vllm` | LLM (local, any OpenAI-compatible server) | — |
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` | | `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` | | `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
@@ -703,6 +807,12 @@ nanobot provider login openai-codex
**3. Chat:** **3. Chat:**
```bash ```bash
nanobot agent -m "Hello!" nanobot agent -m "Hello!"
# Target a specific workspace/config locally
nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello!"
# One-off workspace override on top of that config
nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -m "Hello!"
``` ```
> Docker users: use `docker run -it` for interactive OAuth login. > Docker users: use `docker run -it` for interactive OAuth login.
@@ -734,6 +844,37 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
</details> </details>
<details>
<summary><b>Ollama (local)</b></summary>
Run a local model with Ollama, then add to config:
**1. Start Ollama** (example):
```bash
ollama run llama3.2
```
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
```json
{
"providers": {
"ollama": {
"apiBase": "http://localhost:11434"
}
},
"agents": {
"defaults": {
"provider": "ollama",
"model": "llama3.2"
}
}
}
```
> `provider: "auto"` also works when `providers.ollama.apiBase` is configured, but setting `"provider": "ollama"` is the clearest option.
</details>
<details> <details>
<summary><b>vLLM (local / OpenAI-compatible)</b></summary> <summary><b>vLLM (local / OpenAI-compatible)</b></summary>
@@ -875,21 +1016,124 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
> [!TIP] > [!TIP]
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent. > For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
> **Change in source / post-`v0.1.4.post3`:** In `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions (including building from source), **empty `allowFrom` denies all access by default**. To allow all senders, set `"allowFrom": ["*"]`. > In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all senders. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default. To allow all senders, set `"allowFrom": ["*"]`.
| Option | Default | Description | | Option | Default | Description |
|--------|---------|-------------| |--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. | | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). | | `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. | | `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |
## CLI Reference ## 🧩 Multiple Instances
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint, and optionally use `--workspace` to override the workspace for a specific run.
### Quick Start
```bash
# Instance A - Telegram bot
nanobot gateway --config ~/.nanobot-telegram/config.json
# Instance B - Discord bot
nanobot gateway --config ~/.nanobot-discord/config.json
# Instance C - Feishu bot with custom port
nanobot gateway --config ~/.nanobot-feishu/config.json --port 18792
```
### Path Resolution
When using `--config`, nanobot derives its runtime data directory from the config file location. The workspace still comes from `agents.defaults.workspace` unless you override it with `--workspace`.
To open a CLI session against one of these instances locally:
```bash
nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello from Telegram instance"
nanobot agent -c ~/.nanobot-discord/config.json -m "Hello from Discord instance"
# Optional one-off workspace override
nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test
```
> `nanobot agent` starts a local CLI agent using the selected workspace/config. It does not attach to or proxy through an already running `nanobot gateway` process.
| Component | Resolved From | Example |
|-----------|---------------|---------|
| **Config** | `--config` path | `~/.nanobot-A/config.json` |
| **Workspace** | `--workspace` or config | `~/.nanobot-A/workspace/` |
| **Cron Jobs** | config directory | `~/.nanobot-A/cron/` |
| **Media / runtime state** | config directory | `~/.nanobot-A/media/` |
### How It Works
- `--config` selects which config file to load
- By default, the workspace comes from `agents.defaults.workspace` in that config
- If you pass `--workspace`, it overrides the workspace from the config file
### Minimal Setup
1. Copy your base config into a new instance directory.
2. Set a different `agents.defaults.workspace` for that instance.
3. Start the instance with `--config`.
Example config:
```json
{
"agents": {
"defaults": {
"workspace": "~/.nanobot-telegram/workspace",
"model": "anthropic/claude-sonnet-4-6"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_TELEGRAM_BOT_TOKEN"
}
},
"gateway": {
"port": 18790
}
}
```
Start separate instances:
```bash
nanobot gateway --config ~/.nanobot-telegram/config.json
nanobot gateway --config ~/.nanobot-discord/config.json
```
Override workspace for one-off runs when needed:
```bash
nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobot-telegram-test
```
### Common Use Cases
- Run separate bots for Telegram, Discord, Feishu, and other platforms
- Keep testing and production instances isolated
- Use different models or providers for different teams
- Serve multiple tenants with separate configs and runtime data
### Notes
- Each instance must use a different port if they run at the same time
- Use a different workspace per instance if you want isolated memory, sessions, and skills
- `--workspace` overrides the workspace defined in the config file
- Cron jobs and runtime media/state are derived from the config directory
## 💻 CLI Reference
| Command | Description | | Command | Description |
|---------|-------------| |---------|-------------|
| `nanobot onboard` | Initialize config & workspace | | `nanobot onboard` | Initialize config & workspace |
| `nanobot agent -m "..."` | Chat with the agent | | `nanobot agent -m "..."` | Chat with the agent |
| `nanobot agent -w <workspace>` | Chat against a specific workspace |
| `nanobot agent -w <workspace> -c <config>` | Chat against a specific workspace/config |
| `nanobot agent` | Interactive chat mode | | `nanobot agent` | Interactive chat mode |
| `nanobot agent --no-markdown` | Show plain-text replies | | `nanobot agent --no-markdown` | Show plain-text replies |
| `nanobot agent --logs` | Show runtime logs during chat | | `nanobot agent --logs` | Show runtime logs during chat |
@@ -901,23 +1145,6 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`. Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`.
<details>
<summary><b>Scheduled Tasks (Cron)</b></summary>
```bash
# Add a job
nanobot cron add --name "daily" --message "Good morning!" --cron "0 9 * * *"
nanobot cron add --name "hourly" --message "Check status" --every 3600
# List jobs
nanobot cron list
# Remove a job
nanobot cron remove <job_id>
```
</details>
<details> <details>
<summary><b>Heartbeat (Periodic Tasks)</b></summary> <summary><b>Heartbeat (Periodic Tasks)</b></summary>

View File

@@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json
``` ```
**Security Notes:** **Security Notes:**
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone. - In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all users. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default — set `["*"]` to explicitly allow everyone.
- Get your Telegram user ID from `@userinfobot` - Get your Telegram user ID from `@userinfobot`
- Use full phone numbers with country code for WhatsApp - Use full phone numbers with country code for WhatsApp
- Review access logs regularly for unauthorized access attempts - Review access logs regularly for unauthorized access attempts
@@ -212,7 +212,7 @@ If you suspect a security breach:
- Input length limits on HTTP requests - Input length limits on HTTP requests
✅ **Authentication** ✅ **Authentication**
- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all) - Allow-list based access control — in `v0.1.4.post3` and earlier empty `allowFrom` allowed all; since `v0.1.4.post4` it denies all (`["*"]` explicitly allows all)
- Failed authentication attempt logging - Failed authentication attempt logging
✅ **Resource Protection** ✅ **Resource Protection**

View File

@@ -9,11 +9,16 @@ import makeWASocket, {
useMultiFileAuthState, useMultiFileAuthState,
fetchLatestBaileysVersion, fetchLatestBaileysVersion,
makeCacheableSignalKeyStore, makeCacheableSignalKeyStore,
downloadMediaMessage,
extractMessageContent as baileysExtractMessageContent,
} from '@whiskeysockets/baileys'; } from '@whiskeysockets/baileys';
import { Boom } from '@hapi/boom'; import { Boom } from '@hapi/boom';
import qrcode from 'qrcode-terminal'; import qrcode from 'qrcode-terminal';
import pino from 'pino'; import pino from 'pino';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { randomBytes } from 'crypto';
const VERSION = '0.1.0'; const VERSION = '0.1.0';
@@ -24,6 +29,7 @@ export interface InboundMessage {
content: string; content: string;
timestamp: number; timestamp: number;
isGroup: boolean; isGroup: boolean;
media?: string[];
} }
export interface WhatsAppClientOptions { export interface WhatsAppClientOptions {
@@ -110,14 +116,33 @@ export class WhatsAppClient {
if (type !== 'notify') return; if (type !== 'notify') return;
for (const msg of messages) { for (const msg of messages) {
// Skip own messages
if (msg.key.fromMe) continue; if (msg.key.fromMe) continue;
// Skip status updates
if (msg.key.remoteJid === 'status@broadcast') continue; if (msg.key.remoteJid === 'status@broadcast') continue;
const content = this.extractMessageContent(msg); const unwrapped = baileysExtractMessageContent(msg.message);
if (!content) continue; if (!unwrapped) continue;
const content = this.getTextContent(unwrapped);
let fallbackContent: string | null = null;
const mediaPaths: string[] = [];
if (unwrapped.imageMessage) {
fallbackContent = '[Image]';
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.documentMessage) {
fallbackContent = '[Document]';
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
unwrapped.documentMessage.fileName ?? undefined);
if (path) mediaPaths.push(path);
} else if (unwrapped.videoMessage) {
fallbackContent = '[Video]';
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
if (path) mediaPaths.push(path);
}
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
if (!finalContent && mediaPaths.length === 0) continue;
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false; const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
@@ -125,18 +150,45 @@ export class WhatsAppClient {
id: msg.key.id || '', id: msg.key.id || '',
sender: msg.key.remoteJid || '', sender: msg.key.remoteJid || '',
pn: msg.key.remoteJidAlt || '', pn: msg.key.remoteJidAlt || '',
content, content: finalContent,
timestamp: msg.messageTimestamp as number, timestamp: msg.messageTimestamp as number,
isGroup, isGroup,
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
}); });
} }
}); });
} }
private extractMessageContent(msg: any): string | null { private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
const message = msg.message; try {
if (!message) return null; const mediaDir = join(this.options.authDir, '..', 'media');
await mkdir(mediaDir, { recursive: true });
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
let outFilename: string;
if (fileName) {
// Documents have a filename — use it with a unique prefix to avoid collisions
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
outFilename = prefix + fileName;
} else {
const mime = mimetype || 'application/octet-stream';
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
}
const filepath = join(mediaDir, outFilename);
await writeFile(filepath, buffer);
return filepath;
} catch (err) {
console.error('Failed to download media:', err);
return null;
}
}
private getTextContent(message: any): string | null {
// Text message // Text message
if (message.conversation) { if (message.conversation) {
return message.conversation; return message.conversation;
@@ -147,19 +199,19 @@ export class WhatsAppClient {
return message.extendedTextMessage.text; return message.extendedTextMessage.text;
} }
// Image with caption // Image with optional caption
if (message.imageMessage?.caption) { if (message.imageMessage) {
return `[Image] ${message.imageMessage.caption}`; return message.imageMessage.caption || '';
} }
// Video with caption // Video with optional caption
if (message.videoMessage?.caption) { if (message.videoMessage) {
return `[Video] ${message.videoMessage.caption}`; return message.videoMessage.caption || '';
} }
// Document with caption // Document with optional caption
if (message.documentMessage?.caption) { if (message.documentMessage) {
return `[Document] ${message.documentMessage.caption}`; return message.documentMessage.caption || '';
} }
// Voice/Audio message // Voice/Audio message

View File

@@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
printf " %-16s %5s lines\n" "(root)" "$root" printf " %-16s %5s lines\n" "(root)" "$root"
echo "" echo ""
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l) total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
echo " Core total: $total lines" echo " Core total: $total lines"
echo "" echo ""
echo " (excludes: channels/, cli/, providers/)" echo " (excludes: channels/, cli/, providers/, skills/)"

View File

@@ -2,5 +2,5 @@
nanobot - A lightweight AI agent framework nanobot - A lightweight AI agent framework
""" """
__version__ = "0.1.4.post3" __version__ = "0.1.4.post4"
__logo__ = "🐈" __logo__ = "🐈"

View File

@@ -10,12 +10,13 @@ from typing import Any
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
class ContextBuilder: class ContextBuilder:
"""Builds the context (system prompt + messages) for the agent.""" """Builds the context (system prompt + messages) for the agent."""
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
@@ -58,6 +59,19 @@ Skills with available="false" need dependencies installed first - you can try in
system = platform.system() system = platform.system()
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
platform_policy = ""
if system == "Windows":
platform_policy = """## Platform Policy (Windows)
- You are running on Windows. Do not assume GNU tools like `grep`, `sed`, or `awk` exist.
- Prefer Windows-native commands or file tools when they are more reliable.
- If terminal output is garbled, retry with UTF-8 output enabled.
"""
else:
platform_policy = """## Platform Policy (POSIX)
- You are running on a POSIX system. Prefer UTF-8 and standard shell tools.
- Use file tools when they are simpler or more reliable than shell commands.
"""
return f"""# nanobot 🐈 return f"""# nanobot 🐈
You are nanobot, a helpful AI assistant. You are nanobot, a helpful AI assistant.
@@ -71,6 +85,8 @@ Your workspace is at: {workspace_path}
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. - History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
{platform_policy}
## nanobot Guidelines ## nanobot Guidelines
- State intent before tool calls, but NEVER predict or claim results before receiving them. - State intent before tool calls, but NEVER predict or claim results before receiving them.
- Before modifying a file, read it first. Do not assume files or directories exist. - Before modifying a file, read it first. Do not assume files or directories exist.
@@ -112,11 +128,20 @@ Reply directly with text for conversations. Only use the 'message' tool to send
chat_id: str | None = None, chat_id: str | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Build the complete message list for an LLM call.""" """Build the complete message list for an LLM call."""
runtime_ctx = self._build_runtime_context(channel, chat_id)
user_content = self._build_user_content(current_message, media)
# Merge runtime context and user content into a single user message
# to avoid consecutive same-role messages that some providers reject.
if isinstance(user_content, str):
merged = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
return [ return [
{"role": "system", "content": self.build_system_prompt(skill_names)}, {"role": "system", "content": self.build_system_prompt(skill_names)},
*history, *history,
{"role": "user", "content": self._build_runtime_context(channel, chat_id)}, {"role": "user", "content": merged},
{"role": "user", "content": self._build_user_content(current_message, media)},
] ]
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
@@ -127,10 +152,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
images = [] images = []
for path in media: for path in media:
p = Path(path) p = Path(path)
mime, _ = mimetypes.guess_type(path) if not p.is_file():
if not p.is_file() or not mime or not mime.startswith("image/"):
continue continue
b64 = base64.b64encode(p.read_bytes()).decode() raw = p.read_bytes()
# Detect real MIME type from magic bytes; fallback to filename guess
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
if not mime or not mime.startswith("image/"):
continue
b64 = base64.b64encode(raw).decode()
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
if not images: if not images:
@@ -153,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send
thinking_blocks: list[dict] | None = None, thinking_blocks: list[dict] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Add an assistant message to the message list.""" """Add an assistant message to the message list."""
msg: dict[str, Any] = {"role": "assistant", "content": content} messages.append(build_assistant_message(
if tool_calls: content,
msg["tool_calls"] = tool_calls tool_calls=tool_calls,
if reasoning_content is not None: reasoning_content=reasoning_content,
msg["reasoning_content"] = reasoning_content thinking_blocks=thinking_blocks,
if thinking_blocks: ))
msg["thinking_blocks"] = thinking_blocks
messages.append(msg)
return messages return messages

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import re import re
import weakref
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -13,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
@@ -53,10 +52,7 @@ class AgentLoop:
workspace: Path, workspace: Path,
model: str | None = None, model: str | None = None,
max_iterations: int = 40, max_iterations: int = 40,
temperature: float = 0.1, context_window_tokens: int = 65_536,
max_tokens: int = 4096,
memory_window: int = 100,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None, exec_config: ExecToolConfig | None = None,
@@ -73,10 +69,7 @@ class AgentLoop:
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.temperature = temperature self.context_window_tokens = context_window_tokens
self.max_tokens = max_tokens
self.memory_window = memory_window
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@@ -91,9 +84,6 @@ class AgentLoop:
workspace=workspace, workspace=workspace,
bus=bus, bus=bus,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key, brave_api_key=brave_api_key,
web_proxy=web_proxy, web_proxy=web_proxy,
exec_config=self.exec_config, exec_config=self.exec_config,
@@ -105,11 +95,17 @@ class AgentLoop:
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._processing_lock = asyncio.Lock() self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator(
workspace=workspace,
provider=provider,
model=self.model,
sessions=self.sessions,
context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions,
)
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
@@ -182,7 +178,7 @@ class AgentLoop:
initial_messages: list[dict], initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None, on_progress: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict]]: ) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop. Returns (final_content, tools_used, messages).""" """Run the agent iteration loop."""
messages = initial_messages messages = initial_messages
iteration = 0 iteration = 0
final_content = None final_content = None
@@ -191,31 +187,23 @@ class AgentLoop:
while iteration < self.max_iterations: while iteration < self.max_iterations:
iteration += 1 iteration += 1
response = await self.provider.chat( tool_defs = self.tools.get_definitions()
response = await self.provider.chat_with_retry(
messages=messages, messages=messages,
tools=self.tools.get_definitions(), tools=tool_defs,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
if on_progress: if on_progress:
clean = self._strip_think(response.content) thought = self._strip_think(response.content)
if clean: if thought:
await on_progress(clean) await on_progress(thought)
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True) await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [ tool_call_dicts = [
{ tc.to_openai_tool_call()
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
}
}
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
@@ -341,8 +329,9 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=self.memory_window) history = session.get_history(max_messages=0)
messages = self.context.build_messages( messages = self.context.build_messages(
history=history, history=history,
current_message=msg.content, channel=channel, chat_id=chat_id, current_message=msg.content, channel=channel, chat_id=chat_id,
@@ -350,6 +339,7 @@ class AgentLoop:
final_content, _, all_msgs = await self._run_agent_loop(messages) final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.") content=final_content or "Background task completed.")
@@ -362,27 +352,20 @@ class AgentLoop:
# Slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
self._consolidating.add(session.key)
try: try:
async with lock: if not await self.memory_consolidator.archive_unconsolidated(session):
snapshot = session.messages[session.last_consolidated:] return OutboundMessage(
if snapshot: channel=msg.channel,
temp = Session(key=session.key) chat_id=msg.chat_id,
temp.messages = list(snapshot) content="Memory archival failed, session not cleared. Please try again.",
if not await self._consolidate_memory(temp, archive_all=True): )
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
except Exception: except Exception:
logger.exception("/new archival failed for {}", session.key) logger.exception("/new archival failed for {}", session.key)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel,
chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.", content="Memory archival failed, session not cleared. Please try again.",
) )
finally:
self._consolidating.discard(session.key)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
@@ -393,30 +376,14 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
unconsolidated = len(session.messages) - session.last_consolidated await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
self._consolidating.add(session.key)
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
async def _consolidate_and_unlock():
try:
async with lock:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)
_task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(_task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool): if isinstance(message_tool, MessageTool):
message_tool.start_turn() message_tool.start_turn()
history = session.get_history(max_messages=self.memory_window) history = session.get_history(max_messages=0)
initial_messages = self.context.build_messages( initial_messages = self.context.build_messages(
history=history, history=history,
current_message=msg.content, current_message=msg.content,
@@ -441,6 +408,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None return None
@@ -464,25 +432,29 @@ class AgentLoop:
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
elif role == "user": elif role == "user":
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
continue # Strip the runtime-context prefix, keep only the user text.
parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip():
entry["content"] = parts[1]
else:
continue
if isinstance(content, list): if isinstance(content, list):
entry["content"] = [ filtered = []
{"type": "text", "text": "[image]"} if ( for c in content:
c.get("type") == "image_url" if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
and c.get("image_url", {}).get("url", "").startswith("data:image/") continue # Strip runtime context from multimodal messages
) else c for c in content if (c.get("type") == "image_url"
] and c.get("image_url", {}).get("url", "").startswith("data:image/")):
filtered.append({"type": "text", "text": "[image]"})
else:
filtered.append(c)
if not filtered:
continue
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat()) entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry) session.messages.append(entry)
session.updated_at = datetime.now() session.updated_at = datetime.now()
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
return await MemoryStore(self.workspace).consolidate(
session, self.provider, self.model,
archive_all=archive_all, memory_window=self.memory_window,
)
async def process_direct( async def process_direct(
self, self,
content: str, content: str,

View File

@@ -2,17 +2,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import weakref
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Callable
from loguru import logger from loguru import logger
from nanobot.utils.helpers import ensure_dir from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
if TYPE_CHECKING: if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session from nanobot.session.manager import Session, SessionManager
_SAVE_MEMORY_TOOL = [ _SAVE_MEMORY_TOOL = [
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
"properties": { "properties": {
"history_entry": { "history_entry": {
"type": "string", "type": "string",
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. " "description": "A paragraph summarizing key events/decisions/topics. "
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
}, },
"memory_update": { "memory_update": {
@@ -42,6 +44,19 @@ _SAVE_MEMORY_TOOL = [
] ]
def _ensure_text(value: Any) -> str:
"""Normalize tool-call payload values to text for file storage."""
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
"""Normalize provider tool-call arguments to the expected dict shape."""
if isinstance(args, str):
args = json.loads(args)
if isinstance(args, list):
return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -66,40 +81,27 @@ class MemoryStore:
long_term = self.read_long_term() long_term = self.read_long_term()
return f"## Long-term Memory\n{long_term}" if long_term else "" return f"## Long-term Memory\n{long_term}" if long_term else ""
@staticmethod
def _format_messages(messages: list[dict]) -> str:
lines = []
for message in messages:
if not message.get("content"):
continue
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
lines.append(
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
)
return "\n".join(lines)
async def consolidate( async def consolidate(
self, self,
session: Session, messages: list[dict],
provider: LLMProvider, provider: LLMProvider,
model: str, model: str,
*,
archive_all: bool = False,
memory_window: int = 50,
) -> bool: ) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. """Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
if not messages:
Returns True on success (including no-op), False on failure. return True
"""
if archive_all:
old_messages = session.messages
keep_count = 0
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
else:
keep_count = memory_window // 2
if len(session.messages) <= keep_count:
return True
if len(session.messages) - session.last_consolidated <= 0:
return True
old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages:
return True
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
lines = []
for m in old_messages:
if not m.get("content"):
continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
current_memory = self.read_long_term() current_memory = self.read_long_term()
prompt = f"""Process this conversation and call the save_memory tool with your consolidation. prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
@@ -108,10 +110,10 @@ class MemoryStore:
{current_memory or "(empty)"} {current_memory or "(empty)"}
## Conversation to Process ## Conversation to Process
{chr(10).join(lines)}""" {self._format_messages(messages)}"""
try: try:
response = await provider.chat( response = await provider.chat_with_retry(
messages=[ messages=[
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
@@ -124,27 +126,158 @@ class MemoryStore:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping") logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return False return False
args = response.tool_calls[0].arguments args = _normalize_save_memory_args(response.tool_calls[0].arguments)
# Some providers return arguments as a JSON string instead of dict if args is None:
if isinstance(args, str): logger.warning("Memory consolidation: unexpected save_memory arguments")
args = json.loads(args)
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False return False
if entry := args.get("history_entry"): if entry := args.get("history_entry"):
if not isinstance(entry, str): self.append_history(_ensure_text(entry))
entry = json.dumps(entry, ensure_ascii=False)
self.append_history(entry)
if update := args.get("memory_update"): if update := args.get("memory_update"):
if not isinstance(update, str): update = _ensure_text(update)
update = json.dumps(update, ensure_ascii=False)
if update != current_memory: if update != current_memory:
self.write_long_term(update) self.write_long_term(update)
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count logger.info("Memory consolidation done for {} messages", len(messages))
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return False return False
class MemoryConsolidator:
"""Owns consolidation policy, locking, and session offset updates."""
_MAX_CONSOLIDATION_ROUNDS = 5
def __init__(
self,
workspace: Path,
provider: LLMProvider,
model: str,
sessions: SessionManager,
context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]],
):
self.store = MemoryStore(workspace)
self.provider = provider
self.model = model
self.sessions = sessions
self.context_window_tokens = context_window_tokens
self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock())
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
"""Archive a selected message chunk into persistent memory."""
return await self.store.consolidate(messages, self.provider, self.model)
def pick_consolidation_boundary(
self,
session: Session,
tokens_to_remove: int,
) -> tuple[int, int] | None:
"""Pick a user-turn boundary that removes enough old prompt tokens."""
start = session.last_consolidated
if start >= len(session.messages) or tokens_to_remove <= 0:
return None
removed_tokens = 0
last_boundary: tuple[int, int] | None = None
for idx in range(start, len(session.messages)):
message = session.messages[idx]
if idx > start and message.get("role") == "user":
last_boundary = (idx, removed_tokens)
if removed_tokens >= tokens_to_remove:
return last_boundary
removed_tokens += estimate_message_tokens(message)
return last_boundary
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self._build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
)
return estimate_prompt_tokens_chain(
self.provider,
self.model,
probe_messages,
self._get_tool_definitions(),
)
async def archive_unconsolidated(self, session: Session) -> bool:
"""Archive the full unconsolidated tail for /new-style session rollover."""
lock = self.get_lock(session.key)
async with lock:
snapshot = session.messages[session.last_consolidated:]
if not snapshot:
return True
return await self.consolidate_messages(snapshot)
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within half the context window."""
if not session.messages or self.context_window_tokens <= 0:
return
lock = self.get_lock(session.key)
async with lock:
target = self.context_window_tokens // 2
estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0:
return
if estimated < self.context_window_tokens:
logger.debug(
"Token consolidation idle {}: {}/{} via {}",
session.key,
estimated,
self.context_window_tokens,
source,
)
return
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
if estimated <= target:
return
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
if boundary is None:
logger.debug(
"Token consolidation: no safe boundary for {} (round {})",
session.key,
round_num,
)
return
end_idx = boundary[0]
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return
logger.info(
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
round_num,
session.key,
estimated,
self.context_window_tokens,
source,
len(chunk),
)
if not await self.consolidate_messages(chunk):
return
session.last_consolidated = end_idx
self.sessions.save(session)
estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0:
return

View File

@@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.utils.helpers import build_assistant_message
class SubagentManager: class SubagentManager:
@@ -27,9 +28,6 @@ class SubagentManager:
workspace: Path, workspace: Path,
bus: MessageBus, bus: MessageBus,
model: str | None = None, model: str | None = None,
temperature: float = 0.7,
max_tokens: int = 4096,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None, exec_config: "ExecToolConfig | None" = None,
@@ -40,9 +38,6 @@ class SubagentManager:
self.workspace = workspace self.workspace = workspace
self.bus = bus self.bus = bus
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.temperature = temperature
self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@@ -123,33 +118,23 @@ class SubagentManager:
while iteration < max_iterations: while iteration < max_iterations:
iteration += 1 iteration += 1
response = await self.provider.chat( response = await self.provider.chat_with_retry(
messages=messages, messages=messages,
tools=tools.get_definitions(), tools=tools.get_definitions(),
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
# Add assistant message with tool calls
tool_call_dicts = [ tool_call_dicts = [
{ tc.to_openai_tool_call()
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages.append({ messages.append(build_assistant_message(
"role": "assistant", response.content or "",
"content": response.content or "", tool_calls=tool_call_dicts,
"tool_calls": tool_call_dicts, reasoning_content=response.reasoning_content,
}) thinking_blocks=response.thinking_blocks,
))
# Execute tools # Execute tools
for tool_call in response.tool_calls: for tool_call in response.tool_calls:

View File

@@ -52,8 +52,79 @@ class Tool(ABC):
""" """
pass pass
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""Apply safe schema-driven casts before validation."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
return params
return self._cast_object(params, schema)
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
"""Cast an object (dict) according to schema."""
if not isinstance(obj, dict):
return obj
props = schema.get("properties", {})
result = {}
for key, value in obj.items():
if key in props:
result[key] = self._cast_value(value, props[key])
else:
result[key] = value
return result
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
"""Cast a single value according to schema."""
target_type = schema.get("type")
if target_type == "boolean" and isinstance(val, bool):
return val
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
return val
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
expected = self._TYPE_MAP[target_type]
if isinstance(val, expected):
return val
if target_type == "integer" and isinstance(val, str):
try:
return int(val)
except ValueError:
return val
if target_type == "number" and isinstance(val, str):
try:
return float(val)
except ValueError:
return val
if target_type == "string":
return val if val is None else str(val)
if target_type == "boolean" and isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
if val_lower in ("false", "0", "no"):
return False
return val
if target_type == "array" and isinstance(val, list):
item_schema = schema.get("items")
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
if target_type == "object" and isinstance(val, dict):
return self._cast_object(val, schema)
return val
def validate_params(self, params: dict[str, Any]) -> list[str]: def validate_params(self, params: dict[str, Any]) -> list[str]:
"""Validate tool parameters against JSON schema. Returns error list (empty if valid).""" """Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
if not isinstance(params, dict):
return [f"parameters must be an object, got {type(params).__name__}"]
schema = self.parameters or {} schema = self.parameters or {}
if schema.get("type", "object") != "object": if schema.get("type", "object") != "object":
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}") raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
@@ -61,7 +132,13 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter" t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
return [f"{label} should be integer"]
if t == "number" and (
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
):
return [f"{label} should be number"]
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"] return [f"{label} should be {t}"]
errors = [] errors = []

View File

@@ -1,5 +1,6 @@
"""Cron tool for scheduling reminders and tasks.""" """Cron tool for scheduling reminders and tasks."""
from contextvars import ContextVar
from typing import Any from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -14,12 +15,21 @@ class CronTool(Tool):
self._cron = cron_service self._cron = cron_service
self._channel = "" self._channel = ""
self._chat_id = "" self._chat_id = ""
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel = channel self._channel = channel
self._chat_id = chat_id self._chat_id = chat_id
def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback."""
return self._in_cron_context.set(active)
def reset_cron_context(self, token) -> None:
"""Restore previous cron context."""
self._in_cron_context.reset(token)
@property @property
def name(self) -> str: def name(self) -> str:
return "cron" return "cron"
@@ -72,6 +82,8 @@ class CronTool(Tool):
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if action == "add": if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at) return self._add_job(message, every_seconds, cron_expr, tz, at)
elif action == "list": elif action == "list":
return self._list_jobs() return self._list_jobs()
@@ -110,7 +122,10 @@ class CronTool(Tool):
elif at: elif at:
from datetime import datetime from datetime import datetime
dt = datetime.fromisoformat(at) try:
dt = datetime.fromisoformat(at)
except ValueError:
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
at_ms = int(dt.timestamp() * 1000) at_ms = int(dt.timestamp() * 1000)
schedule = CronSchedule(kind="at", at_ms=at_ms) schedule = CronSchedule(kind="at", at_ms=at_ms)
delete_after = True delete_after = True

View File

@@ -1,4 +1,4 @@
"""File system tools: read, write, edit.""" """File system tools: read, write, edit, list."""
import difflib import difflib
from pathlib import Path from pathlib import Path
@@ -23,51 +23,108 @@ def _resolve_path(
return resolved return resolved
class ReadFileTool(Tool): class _FsTool(Tool):
"""Tool to read file contents.""" """Shared base for filesystem tools — common init and path resolution."""
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
def _resolve(self, path: str) -> Path:
return _resolve_path(path, self._workspace, self._allowed_dir)
# ---------------------------------------------------------------------------
# read_file
# ---------------------------------------------------------------------------
class ReadFileTool(_FsTool):
"""Read file contents with optional line-based pagination."""
_MAX_CHARS = 128_000
_DEFAULT_LIMIT = 2000
@property @property
def name(self) -> str: def name(self) -> str:
return "read_file" return "read_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Read the contents of a file at the given path." return (
"Read the contents of a file. Returns numbered lines. "
"Use offset and limit to paginate through large files."
)
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": {"path": {"type": "string", "description": "The file path to read"}}, "properties": {
"path": {"type": "string", "description": "The file path to read"},
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, default 1)",
"minimum": 1,
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (default 2000)",
"minimum": 1,
},
},
"required": ["path"], "required": ["path"],
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) fp = self._resolve(path)
if not file_path.exists(): if not fp.exists():
return f"Error: File not found: {path}" return f"Error: File not found: {path}"
if not file_path.is_file(): if not fp.is_file():
return f"Error: Not a file: {path}" return f"Error: Not a file: {path}"
content = file_path.read_text(encoding="utf-8") all_lines = fp.read_text(encoding="utf-8").splitlines()
return content total = len(all_lines)
if offset < 1:
offset = 1
if total == 0:
return f"(Empty file: {path})"
if offset > total:
return f"Error: offset {offset} is beyond end of file ({total} lines)"
start = offset - 1
end = min(start + (limit or self._DEFAULT_LIMIT), total)
numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])]
result = "\n".join(numbered)
if len(result) > self._MAX_CHARS:
trimmed, chars = [], 0
for line in numbered:
chars += len(line) + 1
if chars > self._MAX_CHARS:
break
trimmed.append(line)
end = start + len(trimmed)
result = "\n".join(trimmed)
if end < total:
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
else:
result += f"\n\n(End of file — {total} lines total)"
return result
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error reading file: {str(e)}" return f"Error reading file: {e}"
class WriteFileTool(Tool): # ---------------------------------------------------------------------------
"""Tool to write content to a file.""" # write_file
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): class WriteFileTool(_FsTool):
self._workspace = workspace """Write content to a file."""
self._allowed_dir = allowed_dir
@property @property
def name(self) -> str: def name(self) -> str:
@@ -90,22 +147,48 @@ class WriteFileTool(Tool):
async def execute(self, path: str, content: str, **kwargs: Any) -> str: async def execute(self, path: str, content: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) fp = self._resolve(path)
file_path.parent.mkdir(parents=True, exist_ok=True) fp.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(content, encoding="utf-8") fp.write_text(content, encoding="utf-8")
return f"Successfully wrote {len(content)} bytes to {file_path}" return f"Successfully wrote {len(content)} bytes to {fp}"
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error writing file: {str(e)}" return f"Error writing file: {e}"
class EditFileTool(Tool): # ---------------------------------------------------------------------------
"""Tool to edit a file by replacing text.""" # edit_file
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
self._workspace = workspace """Locate old_text in content: exact first, then line-trimmed sliding window.
self._allowed_dir = allowed_dir
Both inputs should use LF line endings (caller normalises CRLF).
Returns (matched_fragment, count) or (None, 0).
"""
if old_text in content:
return old_text, content.count(old_text)
old_lines = old_text.splitlines()
if not old_lines:
return None, 0
stripped_old = [l.strip() for l in old_lines]
content_lines = content.splitlines()
candidates = []
for i in range(len(content_lines) - len(stripped_old) + 1):
window = content_lines[i : i + len(stripped_old)]
if [l.strip() for l in window] == stripped_old:
candidates.append("\n".join(window))
if candidates:
return candidates[0], len(candidates)
return None, 0
class EditFileTool(_FsTool):
"""Edit a file by replacing text with fallback matching."""
@property @property
def name(self) -> str: def name(self) -> str:
@@ -113,7 +196,11 @@ class EditFileTool(Tool):
@property @property
def description(self) -> str: def description(self) -> str:
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." return (
"Edit a file by replacing old_text with new_text. "
"Supports minor whitespace/line-ending differences. "
"Set replace_all=true to replace every occurrence."
)
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
@@ -121,40 +208,52 @@ class EditFileTool(Tool):
"type": "object", "type": "object",
"properties": { "properties": {
"path": {"type": "string", "description": "The file path to edit"}, "path": {"type": "string", "description": "The file path to edit"},
"old_text": {"type": "string", "description": "The exact text to find and replace"}, "old_text": {"type": "string", "description": "The text to find and replace"},
"new_text": {"type": "string", "description": "The text to replace with"}, "new_text": {"type": "string", "description": "The text to replace with"},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default false)",
},
}, },
"required": ["path", "old_text", "new_text"], "required": ["path", "old_text", "new_text"],
} }
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: async def execute(
self, path: str, old_text: str, new_text: str,
replace_all: bool = False, **kwargs: Any,
) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) fp = self._resolve(path)
if not file_path.exists(): if not fp.exists():
return f"Error: File not found: {path}" return f"Error: File not found: {path}"
content = file_path.read_text(encoding="utf-8") raw = fp.read_bytes()
uses_crlf = b"\r\n" in raw
content = raw.decode("utf-8").replace("\r\n", "\n")
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
if old_text not in content: if match is None:
return self._not_found_message(old_text, content, path) return self._not_found_msg(old_text, content, path)
if count > 1 and not replace_all:
return (
f"Warning: old_text appears {count} times. "
"Provide more context to make it unique, or set replace_all=true."
)
# Count occurrences norm_new = new_text.replace("\r\n", "\n")
count = content.count(old_text) new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
if count > 1: if uses_crlf:
return f"Warning: old_text appears {count} times. Please provide more context to make it unique." new_content = new_content.replace("\n", "\r\n")
new_content = content.replace(old_text, new_text, 1) fp.write_bytes(new_content.encode("utf-8"))
file_path.write_text(new_content, encoding="utf-8") return f"Successfully edited {fp}"
return f"Successfully edited {file_path}"
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error editing file: {str(e)}" return f"Error editing file: {e}"
@staticmethod @staticmethod
def _not_found_message(old_text: str, content: str, path: str) -> str: def _not_found_msg(old_text: str, content: str, path: str) -> str:
"""Build a helpful error when old_text is not found."""
lines = content.splitlines(keepends=True) lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True) old_lines = old_text.splitlines(keepends=True)
window = len(old_lines) window = len(old_lines)
@@ -166,27 +265,29 @@ class EditFileTool(Tool):
best_ratio, best_start = ratio, i best_ratio, best_start = ratio, i
if best_ratio > 0.5: if best_ratio > 0.5:
diff = "\n".join( diff = "\n".join(difflib.unified_diff(
difflib.unified_diff( old_lines, lines[best_start : best_start + window],
old_lines, fromfile="old_text (provided)",
lines[best_start : best_start + window], tofile=f"{path} (actual, line {best_start + 1})",
fromfile="old_text (provided)", lineterm="",
tofile=f"{path} (actual, line {best_start + 1})", ))
lineterm="",
)
)
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
return ( return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
)
class ListDirTool(Tool): # ---------------------------------------------------------------------------
"""Tool to list directory contents.""" # list_dir
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): class ListDirTool(_FsTool):
self._workspace = workspace """List directory contents with optional recursion."""
self._allowed_dir = allowed_dir
_DEFAULT_MAX = 200
_IGNORE_DIRS = {
".git", "node_modules", "__pycache__", ".venv", "venv",
"dist", "build", ".tox", ".mypy_cache", ".pytest_cache",
".ruff_cache", ".coverage", "htmlcov",
}
@property @property
def name(self) -> str: def name(self) -> str:
@@ -194,34 +295,71 @@ class ListDirTool(Tool):
@property @property
def description(self) -> str: def description(self) -> str:
return "List the contents of a directory." return (
"List the contents of a directory. "
"Set recursive=true to explore nested structure. "
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
)
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": {"path": {"type": "string", "description": "The directory path to list"}}, "properties": {
"path": {"type": "string", "description": "The directory path to list"},
"recursive": {
"type": "boolean",
"description": "Recursively list all files (default false)",
},
"max_entries": {
"type": "integer",
"description": "Maximum entries to return (default 200)",
"minimum": 1,
},
},
"required": ["path"], "required": ["path"],
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(
self, path: str, recursive: bool = False,
max_entries: int | None = None, **kwargs: Any,
) -> str:
try: try:
dir_path = _resolve_path(path, self._workspace, self._allowed_dir) dp = self._resolve(path)
if not dir_path.exists(): if not dp.exists():
return f"Error: Directory not found: {path}" return f"Error: Directory not found: {path}"
if not dir_path.is_dir(): if not dp.is_dir():
return f"Error: Not a directory: {path}" return f"Error: Not a directory: {path}"
items = [] cap = max_entries or self._DEFAULT_MAX
for item in sorted(dir_path.iterdir()): items: list[str] = []
prefix = "📁 " if item.is_dir() else "📄 " total = 0
items.append(f"{prefix}{item.name}")
if not items: if recursive:
for item in sorted(dp.rglob("*")):
if any(p in self._IGNORE_DIRS for p in item.parts):
continue
total += 1
if len(items) < cap:
rel = item.relative_to(dp)
items.append(f"{rel}/" if item.is_dir() else str(rel))
else:
for item in sorted(dp.iterdir()):
if item.name in self._IGNORE_DIRS:
continue
total += 1
if len(items) < cap:
pfx = "📁 " if item.is_dir() else "📄 "
items.append(f"{pfx}{item.name}")
if not items and total == 0:
return f"Directory {path} is empty" return f"Directory {path} is empty"
return "\n".join(items) result = "\n".join(items)
if total > cap:
result += f"\n\n(truncated, showing first {cap} of {total} entries)"
return result
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error listing directory: {str(e)}" return f"Error listing directory: {e}"

View File

@@ -36,6 +36,7 @@ class MCPToolWrapper(Tool):
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
from mcp import types from mcp import types
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self._session.call_tool(self._original_name, arguments=kwargs), self._session.call_tool(self._original_name, arguments=kwargs),
@@ -44,6 +45,23 @@ class MCPToolWrapper(Tool):
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout) logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
return f"(MCP tool call timed out after {self._tool_timeout}s)" return f"(MCP tool call timed out after {self._tool_timeout}s)"
except asyncio.CancelledError:
# MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure.
# Re-raise only if our task was externally cancelled (e.g. /stop).
task = asyncio.current_task()
if task is not None and task.cancelling() > 0:
raise
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
return "(MCP tool call was cancelled)"
except Exception as exc:
logger.exception(
"MCP tool '{}' failed: {}: {}",
self._name,
type(exc).__name__,
exc,
)
return f"(MCP tool call failed: {type(exc).__name__})"
parts = [] parts = []
for block in result.content: for block in result.content:
if isinstance(block, types.TextContent): if isinstance(block, types.TextContent):
@@ -58,17 +76,48 @@ async def connect_mcp_servers(
) -> None: ) -> None:
"""Connect to configured MCP servers and register their tools.""" """Connect to configured MCP servers and register their tools."""
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items(): for name, cfg in mcp_servers.items():
try: try:
if cfg.command: transport_type = cfg.type
if not transport_type:
if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
if transport_type == "stdio":
params = StdioServerParameters( params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None command=cfg.command, args=cfg.args, env=cfg.env or None
) )
read, write = await stack.enter_async_context(stdio_client(params)) read, write = await stack.enter_async_context(stdio_client(params))
elif cfg.url: elif transport_type == "sse":
from mcp.client.streamable_http import streamable_http_client def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
merged_headers = {**(cfg.headers or {}), **(headers or {})}
return httpx.AsyncClient(
headers=merged_headers or None,
follow_redirects=True,
timeout=timeout,
auth=auth,
)
read, write = await stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not # Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout. # inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context( http_client = await stack.enter_async_context(
@@ -82,7 +131,7 @@ async def connect_mcp_servers(
streamable_http_client(cfg.url, http_client=http_client) streamable_http_client(cfg.url, http_client=http_client)
) )
else: else:
logger.warning("MCP server '{}': no command or url configured, skipping", name) logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue continue
session = await stack.enter_async_context(ClientSession(read, write)) session = await stack.enter_async_context(ClientSession(read, write))

View File

@@ -96,7 +96,7 @@ class MessageTool(Tool):
media=media or [], media=media or [],
metadata={ metadata={
"message_id": message_id, "message_id": message_id,
} },
) )
try: try:

View File

@@ -44,6 +44,10 @@ class ToolRegistry:
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}" return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
try: try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params) errors = tool.validate_params(params)
if errors: if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT

View File

@@ -42,6 +42,9 @@ class ExecTool(Tool):
def name(self) -> str: def name(self) -> str:
return "exec" return "exec"
_MAX_TIMEOUT = 600
_MAX_OUTPUT = 10_000
@property @property
def description(self) -> str: def description(self) -> str:
return "Execute a shell command and return its output. Use with caution." return "Execute a shell command and return its output. Use with caution."
@@ -53,22 +56,36 @@ class ExecTool(Tool):
"properties": { "properties": {
"command": { "command": {
"type": "string", "type": "string",
"description": "The shell command to execute" "description": "The shell command to execute",
}, },
"working_dir": { "working_dir": {
"type": "string", "type": "string",
"description": "Optional working directory for the command" "description": "Optional working directory for the command",
} },
"timeout": {
"type": "integer",
"description": (
"Timeout in seconds. Increase for long-running commands "
"like compilation or installation (default 60, max 600)."
),
"minimum": 1,
"maximum": 600,
},
}, },
"required": ["command"] "required": ["command"],
} }
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str: async def execute(
self, command: str, working_dir: str | None = None,
timeout: int | None = None, **kwargs: Any,
) -> str:
cwd = working_dir or self.working_dir or os.getcwd() cwd = working_dir or self.working_dir or os.getcwd()
guard_error = self._guard_command(command, cwd) guard_error = self._guard_command(command, cwd)
if guard_error: if guard_error:
return guard_error return guard_error
effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
env = os.environ.copy() env = os.environ.copy()
if self.path_append: if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
@@ -85,17 +102,15 @@ class ExecTool(Tool):
try: try:
stdout, stderr = await asyncio.wait_for( stdout, stderr = await asyncio.wait_for(
process.communicate(), process.communicate(),
timeout=self.timeout timeout=effective_timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
process.kill() process.kill()
# Wait for the process to fully terminate so pipes are
# drained and file descriptors are released.
try: try:
await asyncio.wait_for(process.wait(), timeout=5.0) await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
return f"Error: Command timed out after {self.timeout} seconds" return f"Error: Command timed out after {effective_timeout} seconds"
output_parts = [] output_parts = []
@@ -107,15 +122,19 @@ class ExecTool(Tool):
if stderr_text.strip(): if stderr_text.strip():
output_parts.append(f"STDERR:\n{stderr_text}") output_parts.append(f"STDERR:\n{stderr_text}")
if process.returncode != 0: output_parts.append(f"\nExit code: {process.returncode}")
output_parts.append(f"\nExit code: {process.returncode}")
result = "\n".join(output_parts) if output_parts else "(no output)" result = "\n".join(output_parts) if output_parts else "(no output)"
# Truncate very long output # Head + tail truncation to preserve both start and end of output
max_len = 10000 max_len = self._MAX_OUTPUT
if len(result) > max_len: if len(result) > max_len:
result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)" half = max_len // 2
result = (
result[:half]
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ result[-half:]
)
return result return result
@@ -143,10 +162,11 @@ class ExecTool(Tool):
for raw in self._extract_absolute_paths(cmd): for raw in self._extract_absolute_paths(cmd):
try: try:
p = Path(raw.strip()).expanduser().resolve() expanded = os.path.expandvars(raw.strip())
p = Path(expanded).expanduser().resolve()
except Exception: except Exception:
continue continue
if not p.is_relative_to(cwd_path): if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
return "Error: Command blocked by safety guard (path outside working dir)" return "Error: Command blocked by safety guard (path outside working dir)"
return None return None
@@ -154,6 +174,6 @@ class ExecTool(Tool):
@staticmethod @staticmethod
def _extract_absolute_paths(command: str) -> list[str]: def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\... win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
tilde_paths = re.findall(r"(?:^|[\s|>])(~[^\s\"'>]*)", command) # Tilde: ~/... home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
return win_paths + posix_paths + tilde_paths return win_paths + posix_paths + home_paths

View File

@@ -1,6 +1,9 @@
"""Base channel interface for chat platforms.""" """Base channel interface for chat platforms."""
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -18,6 +21,8 @@ class BaseChannel(ABC):
""" """
name: str = "base" name: str = "base"
display_name: str = "Base"
transcription_api_key: str = ""
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """
@@ -31,6 +36,19 @@ class BaseChannel(ABC):
self.bus = bus self.bus = bus
self._running = False self._running = False
async def transcribe_audio(self, file_path: str | Path) -> str:
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
if not self.transcription_api_key:
return ""
try:
from nanobot.providers.transcription import GroqTranscriptionProvider
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
return await provider.transcribe(file_path)
except Exception as e:
logger.warning("{}: audio transcription failed: {}", self.name, e)
return ""
@abstractmethod @abstractmethod
async def start(self) -> None: async def start(self) -> None:
""" """
@@ -66,10 +84,7 @@ class BaseChannel(ABC):
return False return False
if "*" in allow_list: if "*" in allow_list:
return True return True
sender_str = str(sender_id) return str(sender_id) in allow_list
return sender_str in allow_list or any(
p in allow_list for p in sender_str.split("|") if p
)
async def _handle_message( async def _handle_message(
self, self,

View File

@@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler):
content = "" content = ""
if chatbot_msg.text: if chatbot_msg.text:
content = chatbot_msg.text.content.strip() content = chatbot_msg.text.content.strip()
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
content = chatbot_msg.extensions["content"]["recognition"].strip()
if not content: if not content:
content = message.data.get("text", {}).get("content", "").strip() content = message.data.get("text", {}).get("content", "").strip()
@@ -70,12 +72,24 @@ class NanobotDingTalkHandler(CallbackHandler):
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
sender_name = chatbot_msg.sender_nick or "Unknown" sender_name = chatbot_msg.sender_nick or "Unknown"
conversation_type = message.data.get("conversationType")
conversation_id = (
message.data.get("conversationId")
or message.data.get("openConversationId")
)
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content) logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
# Forward to Nanobot via _on_message (non-blocking). # Forward to Nanobot via _on_message (non-blocking).
# Store reference to prevent GC before task completes. # Store reference to prevent GC before task completes.
task = asyncio.create_task( task = asyncio.create_task(
self.channel._on_message(content, sender_id, sender_name) self.channel._on_message(
content,
sender_id,
sender_name,
conversation_type,
conversation_id,
)
) )
self.channel._background_tasks.add(task) self.channel._background_tasks.add(task)
task.add_done_callback(self.channel._background_tasks.discard) task.add_done_callback(self.channel._background_tasks.discard)
@@ -95,11 +109,12 @@ class DingTalkChannel(BaseChannel):
Uses WebSocket to receive events via `dingtalk-stream` SDK. Uses WebSocket to receive events via `dingtalk-stream` SDK.
Uses direct HTTP API to send messages (SDK is mainly for receiving). Uses direct HTTP API to send messages (SDK is mainly for receiving).
Note: Currently only supports private (1:1) chat. Group messages are Supports both private (1:1) and group chats.
received but replies are sent back as private messages to the sender. Group chat_id is stored with a "group:" prefix to route replies back.
""" """
name = "dingtalk" name = "dingtalk"
display_name = "DingTalk"
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
@@ -301,14 +316,25 @@ class DingTalkChannel(BaseChannel):
logger.warning("DingTalk HTTP client not initialized, cannot send") logger.warning("DingTalk HTTP client not initialized, cannot send")
return False return False
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {"x-acs-dingtalk-access-token": token} headers = {"x-acs-dingtalk-access-token": token}
payload = { if chat_id.startswith("group:"):
"robotCode": self.config.client_id, # Group chat
"userIds": [chat_id], url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
"msgKey": msg_key, payload = {
"msgParam": json.dumps(msg_param, ensure_ascii=False), "robotCode": self.config.client_id,
} "openConversationId": chat_id[6:], # Remove "group:" prefix,
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
else:
# Private chat
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
payload = {
"robotCode": self.config.client_id,
"userIds": [chat_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
try: try:
resp = await self._http.post(url, json=payload, headers=headers) resp = await self._http.post(url, json=payload, headers=headers)
@@ -417,7 +443,14 @@ class DingTalkChannel(BaseChannel):
f"[Attachment send failed: {filename}]", f"[Attachment send failed: {filename}]",
) )
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None: async def _on_message(
self,
content: str,
sender_id: str,
sender_name: str,
conversation_type: str | None = None,
conversation_id: str | None = None,
) -> None:
"""Handle incoming message (called by NanobotDingTalkHandler). """Handle incoming message (called by NanobotDingTalkHandler).
Delegates to BaseChannel._handle_message() which enforces allow_from Delegates to BaseChannel._handle_message() which enforces allow_from
@@ -425,13 +458,16 @@ class DingTalkChannel(BaseChannel):
""" """
try: try:
logger.info("DingTalk inbound: {} from {}", content, sender_name) logger.info("DingTalk inbound: {} from {}", content, sender_name)
is_group = conversation_type == "2" and conversation_id
chat_id = f"group:{conversation_id}" if is_group else sender_id
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender_id, # For private chat, chat_id == sender_id chat_id=chat_id,
content=str(content), content=str(content),
metadata={ metadata={
"sender_name": sender_name, "sender_name": sender_name,
"platform": "dingtalk", "platform": "dingtalk",
"conversation_type": conversation_type,
}, },
) )
except Exception as e: except Exception as e:

View File

@@ -12,39 +12,20 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import DiscordConfig from nanobot.config.schema import DiscordConfig
from nanobot.utils.helpers import split_message
DISCORD_API_BASE = "https://discord.com/api/v10" DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit MAX_MESSAGE_LEN = 2000 # Discord message character limit
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class DiscordChannel(BaseChannel): class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket.""" """Discord channel using Gateway websocket."""
name = "discord" name = "discord"
display_name = "Discord"
def __init__(self, config: DiscordConfig, bus: MessageBus): def __init__(self, config: DiscordConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -54,6 +35,7 @@ class DiscordChannel(BaseChannel):
self._heartbeat_task: asyncio.Task | None = None self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None self._http: httpx.AsyncClient | None = None
self._bot_user_id: str | None = None
async def start(self) -> None: async def start(self) -> None:
"""Start the Discord gateway connection.""" """Start the Discord gateway connection."""
@@ -95,7 +77,7 @@ class DiscordChannel(BaseChannel):
self._http = None self._http = None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API.""" """Send a message through Discord REST API, including file attachments."""
if not self._http: if not self._http:
logger.warning("Discord HTTP client not initialized") logger.warning("Discord HTTP client not initialized")
return return
@@ -104,15 +86,31 @@ class DiscordChannel(BaseChannel):
headers = {"Authorization": f"Bot {self.config.token}"} headers = {"Authorization": f"Bot {self.config.token}"}
try: try:
chunks = _split_message(msg.content or "") sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks: if not chunks:
return return
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk} payload: dict[str, Any] = {"content": chunk}
# Only set reply reference on the first chunk # Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to: if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to} payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False} payload["allowed_mentions"] = {"replied_user": False}
@@ -143,6 +141,54 @@ class DiscordChannel(BaseChannel):
await asyncio.sleep(1) await asyncio.sleep(1)
return False return False
async def _send_file(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None: async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events.""" """Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws: if not self._ws:
@@ -170,6 +216,10 @@ class DiscordChannel(BaseChannel):
await self._identify() await self._identify()
elif op == 0 and event_type == "READY": elif op == 0 and event_type == "READY":
logger.info("Discord gateway READY") logger.info("Discord gateway READY")
# Capture bot user ID for mention detection
user_data = payload.get("user") or {}
self._bot_user_id = user_data.get("id")
logger.info("Discord bot connected as user {}", self._bot_user_id)
elif op == 0 and event_type == "MESSAGE_CREATE": elif op == 0 and event_type == "MESSAGE_CREATE":
await self._handle_message_create(payload) await self._handle_message_create(payload)
elif op == 7: elif op == 7:
@@ -226,6 +276,7 @@ class DiscordChannel(BaseChannel):
sender_id = str(author.get("id", "")) sender_id = str(author.get("id", ""))
channel_id = str(payload.get("channel_id", "")) channel_id = str(payload.get("channel_id", ""))
content = payload.get("content") or "" content = payload.get("content") or ""
guild_id = payload.get("guild_id")
if not sender_id or not channel_id: if not sender_id or not channel_id:
return return
@@ -233,9 +284,14 @@ class DiscordChannel(BaseChannel):
if not self.is_allowed(sender_id): if not self.is_allowed(sender_id):
return return
# Check group channel policy (DMs always respond if is_allowed passes)
if guild_id is not None:
if not self._should_respond_in_group(payload, content):
return
content_parts = [content] if content else [] content_parts = [content] if content else []
media_paths: list[str] = [] media_paths: list[str] = []
media_dir = Path.home() / ".nanobot" / "media" media_dir = get_media_dir("discord")
for attachment in payload.get("attachments") or []: for attachment in payload.get("attachments") or []:
url = attachment.get("url") url = attachment.get("url")
@@ -269,11 +325,32 @@ class DiscordChannel(BaseChannel):
media=media_paths, media=media_paths,
metadata={ metadata={
"message_id": str(payload.get("id", "")), "message_id": str(payload.get("id", "")),
"guild_id": payload.get("guild_id"), "guild_id": guild_id,
"reply_to": reply_to, "reply_to": reply_to,
}, },
) )
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
"""Check if bot should respond in a group channel based on policy."""
if self.config.group_policy == "open":
return True
if self.config.group_policy == "mention":
# Check if bot was mentioned in the message
if self._bot_user_id:
# Check mentions array
mentions = payload.get("mentions") or []
for mention in mentions:
if str(mention.get("id")) == self._bot_user_id:
return True
# Also check content for mention format <@USER_ID>
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
return False
return True
async def _start_typing(self, channel_id: str) -> None: async def _start_typing(self, channel_id: str) -> None:
"""Start periodic typing indicator for a channel.""" """Start periodic typing indicator for a channel."""
await self._stop_typing(channel_id) await self._stop_typing(channel_id)

View File

@@ -35,6 +35,7 @@ class EmailChannel(BaseChannel):
""" """
name = "email" name = "email"
display_name = "Email"
_IMAP_MONTHS = ( _IMAP_MONTHS = (
"Jan", "Jan",
"Feb", "Feb",

View File

@@ -14,28 +14,12 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import FeishuConfig from nanobot.config.schema import FeishuConfig
try: import importlib.util
import lark_oapi as lark
from lark_oapi.api.im.v1 import ( FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
CreateFileRequest,
CreateFileRequestBody,
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
Emoji,
GetMessageResourceRequest,
P2ImMessageReceiveV1,
)
FEISHU_AVAILABLE = True
except ImportError:
FEISHU_AVAILABLE = False
lark = None
Emoji = None
# Message type display mapping # Message type display mapping
MSG_TYPE_MAP = { MSG_TYPE_MAP = {
@@ -260,6 +244,7 @@ class FeishuChannel(BaseChannel):
""" """
name = "feishu" name = "feishu"
display_name = "Feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus): def __init__(self, config: FeishuConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -270,6 +255,12 @@ class FeishuChannel(BaseChannel):
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
"""Register an event handler only when the SDK supports it."""
method = getattr(builder, method_name, None)
return method(handler) if callable(method) else builder
async def start(self) -> None: async def start(self) -> None:
"""Start the Feishu bot with WebSocket long connection.""" """Start the Feishu bot with WebSocket long connection."""
if not FEISHU_AVAILABLE: if not FEISHU_AVAILABLE:
@@ -280,6 +271,7 @@ class FeishuChannel(BaseChannel):
logger.error("Feishu app_id and app_secret not configured") logger.error("Feishu app_id and app_secret not configured")
return return
import lark_oapi as lark
self._running = True self._running = True
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
@@ -289,14 +281,24 @@ class FeishuChannel(BaseChannel):
.app_secret(self.config.app_secret) \ .app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \ .log_level(lark.LogLevel.INFO) \
.build() .build()
builder = lark.EventDispatcherHandler.builder(
# Create event handler (only register message receive, ignore other events)
event_handler = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "", self.config.encrypt_key or "",
self.config.verification_token or "", self.config.verification_token or "",
).register_p2_im_message_receive_v1( ).register_p2_im_message_receive_v1(
self._on_message_sync self._on_message_sync
).build() )
builder = self._register_optional_event(
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
)
builder = self._register_optional_event(
builder, "register_p2_im_message_message_read_v1", self._on_message_read
)
builder = self._register_optional_event(
builder,
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
self._on_bot_p2p_chat_entered,
)
event_handler = builder.build()
# Create WebSocket client for long connection # Create WebSocket client for long connection
self._ws_client = lark.ws.Client( self._ws_client = lark.ws.Client(
@@ -306,16 +308,28 @@ class FeishuChannel(BaseChannel):
log_level=lark.LogLevel.INFO log_level=lark.LogLevel.INFO
) )
# Start WebSocket client in a separate thread with reconnect loop # Start WebSocket client in a separate thread with reconnect loop.
# A dedicated event loop is created for this thread so that lark_oapi's
# module-level `loop = asyncio.get_event_loop()` picks up an idle loop
# instead of the already-running main asyncio loop, which would cause
# "This event loop is already running" errors.
def run_ws(): def run_ws():
while self._running: import time
try: import lark_oapi.ws.client as _lark_ws_client
self._ws_client.start() ws_loop = asyncio.new_event_loop()
except Exception as e: asyncio.set_event_loop(ws_loop)
logger.warning("Feishu WebSocket error: {}", e) # Patch the module-level loop used by lark's ws Client.start()
if self._running: _lark_ws_client.loop = ws_loop
import time try:
time.sleep(5) while self._running:
try:
self._ws_client.start()
except Exception as e:
logger.warning("Feishu WebSocket error: {}", e)
if self._running:
time.sleep(5)
finally:
ws_loop.close()
self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread = threading.Thread(target=run_ws, daemon=True)
self._ws_thread.start() self._ws_thread.start()
@@ -340,6 +354,7 @@ class FeishuChannel(BaseChannel):
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool).""" """Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
try: try:
request = CreateMessageReactionRequest.builder() \ request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \ .message_id(message_id) \
@@ -364,7 +379,7 @@ class FeishuChannel(BaseChannel):
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
""" """
if not self._client or not Emoji: if not self._client:
return return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@@ -413,6 +428,34 @@ class FeishuChannel(BaseChannel):
elements.extend(self._split_headings(remaining)) elements.extend(self._split_headings(remaining))
return elements or [{"tag": "markdown", "content": content}] return elements or [{"tag": "markdown", "content": content}]
@staticmethod
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
"""Split card elements into groups with at most *max_tables* table elements each.
Feishu cards have a hard limit of one table per card (API error 11310).
When the rendered content contains multiple markdown tables each table is
placed in a separate card message so every table reaches the user.
"""
if not elements:
return [[]]
groups: list[list[dict]] = []
current: list[dict] = []
table_count = 0
for el in elements:
if el.get("tag") == "table":
if table_count >= max_tables:
if current:
groups.append(current)
current = []
table_count = 0
current.append(el)
table_count += 1
else:
current.append(el)
if current:
groups.append(current)
return groups or [[]]
def _split_headings(self, content: str) -> list[dict]: def _split_headings(self, content: str) -> list[dict]:
"""Split content by headings, converting headings to div elements.""" """Split content by headings, converting headings to div elements."""
protected = content protected = content
@@ -447,8 +490,124 @@ class FeishuChannel(BaseChannel):
return elements or [{"tag": "markdown", "content": content}] return elements or [{"tag": "markdown", "content": content}]
# ── Smart format detection ──────────────────────────────────────────
# Patterns that indicate "complex" markdown needing card rendering
_COMPLEX_MD_RE = re.compile(
r"```" # fenced code block
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
r"|^#{1,6}\s+" # headings
, re.MULTILINE,
)
# Simple markdown patterns (bold, italic, strikethrough)
_SIMPLE_MD_RE = re.compile(
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
r"|~~.+?~~" # ~~strikethrough~~
, re.DOTALL,
)
# Markdown link: [text](url)
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
# Unordered list items
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
# Ordered list items
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
# Max length for plain text format
_TEXT_MAX_LEN = 200
# Max length for post (rich text) format; beyond this, use card
_POST_MAX_LEN = 2000
@classmethod
def _detect_msg_format(cls, content: str) -> str:
"""Determine the optimal Feishu message format for *content*.
Returns one of:
- ``"text"`` plain text, short and no markdown
- ``"post"`` rich text (links only, moderate length)
- ``"interactive"`` card with full markdown rendering
"""
stripped = content.strip()
# Complex markdown (code blocks, tables, headings) → always card
if cls._COMPLEX_MD_RE.search(stripped):
return "interactive"
# Long content → card (better readability with card layout)
if len(stripped) > cls._POST_MAX_LEN:
return "interactive"
# Has bold/italic/strikethrough → card (post format can't render these)
if cls._SIMPLE_MD_RE.search(stripped):
return "interactive"
# Has list items → card (post format can't render list bullets well)
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
return "interactive"
# Has links → post format (supports <a> tags)
if cls._MD_LINK_RE.search(stripped):
return "post"
# Short plain text → text format
if len(stripped) <= cls._TEXT_MAX_LEN:
return "text"
# Medium plain text without any formatting → post format
return "post"
@classmethod
def _markdown_to_post(cls, content: str) -> str:
"""Convert markdown content to Feishu post message JSON.
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
Each line becomes a paragraph (row) in the post body.
"""
lines = content.strip().split("\n")
paragraphs: list[list[dict]] = []
for line in lines:
elements: list[dict] = []
last_end = 0
for m in cls._MD_LINK_RE.finditer(line):
# Text before this link
before = line[last_end:m.start()]
if before:
elements.append({"tag": "text", "text": before})
elements.append({
"tag": "a",
"text": m.group(1),
"href": m.group(2),
})
last_end = m.end()
# Remaining text after last link
remaining = line[last_end:]
if remaining:
elements.append({"tag": "text", "text": remaining})
# Empty line → empty paragraph for spacing
if not elements:
elements.append({"tag": "text", "text": ""})
paragraphs.append(elements)
post_body = {
"zh_cn": {
"content": paragraphs,
}
}
return json.dumps(post_body, ensure_ascii=False)
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"} _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
_AUDIO_EXTS = {".opus"} _AUDIO_EXTS = {".opus"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = { _FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc", ".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt", ".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
@@ -456,6 +615,7 @@ class FeishuChannel(BaseChannel):
def _upload_image_sync(self, file_path: str) -> str | None: def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key.""" """Upload an image to Feishu and return the image_key."""
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
try: try:
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \ request = CreateImageRequest.builder() \
@@ -479,6 +639,7 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None: def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key.""" """Upload a file to Feishu and return the file_key."""
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
ext = os.path.splitext(file_path)[1].lower() ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream") file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
@@ -506,6 +667,7 @@ class FeishuChannel(BaseChannel):
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]: def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key.""" """Download an image from Feishu message by message_id and image_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
try: try:
request = GetMessageResourceRequest.builder() \ request = GetMessageResourceRequest.builder() \
.message_id(message_id) \ .message_id(message_id) \
@@ -530,6 +692,13 @@ class FeishuChannel(BaseChannel):
self, message_id: str, file_key: str, resource_type: str = "file" self, message_id: str, file_key: str, resource_type: str = "file"
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
"""Download a file/audio/media from a Feishu message by message_id and file_key.""" """Download a file/audio/media from a Feishu message by message_id and file_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
# Feishu API only accepts 'image' or 'file' as type parameter
# Convert 'audio' to 'file' for API compatibility
if resource_type == "audio":
resource_type = "file"
try: try:
request = ( request = (
GetMessageResourceRequest.builder() GetMessageResourceRequest.builder()
@@ -564,8 +733,7 @@ class FeishuChannel(BaseChannel):
(file_path, content_text) - file_path is None if download failed (file_path, content_text) - file_path is None if download failed
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
media_dir = Path.home() / ".nanobot" / "media" media_dir = get_media_dir("feishu")
media_dir.mkdir(parents=True, exist_ok=True)
data, filename = None, None data, filename = None, None
@@ -585,8 +753,9 @@ class FeishuChannel(BaseChannel):
None, self._download_file_sync, message_id, file_key, msg_type None, self._download_file_sync, message_id, file_key, msg_type
) )
if not filename: if not filename:
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "") filename = file_key[:16]
filename = f"{file_key[:16]}{ext}" if msg_type == "audio" and not filename.endswith(".opus"):
filename = f"{filename}.opus"
if data and filename: if data and filename:
file_path = media_dir / filename file_path = media_dir / filename
@@ -598,6 +767,7 @@ class FeishuChannel(BaseChannel):
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool: def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
"""Send a single message (text/image/file/interactive) synchronously.""" """Send a single message (text/image/file/interactive) synchronously."""
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
try: try:
request = CreateMessageRequest.builder() \ request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \ .receive_id_type(receive_id_type) \
@@ -646,23 +816,50 @@ class FeishuChannel(BaseChannel):
else: else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path) key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
if key: if key:
media_type = "audio" if ext in self._AUDIO_EXTS else "file" # Use msg_type "media" for audio/video so users can play inline;
# "file" for everything else (documents, archives, etc.)
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
media_type = "media"
else:
media_type = "file"
await loop.run_in_executor( await loop.run_in_executor(
None, self._send_message_sync, None, self._send_message_sync,
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False), receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
) )
if msg.content and msg.content.strip(): if msg.content and msg.content.strip():
card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)} fmt = self._detect_msg_format(msg.content)
await loop.run_in_executor(
None, self._send_message_sync, if fmt == "text":
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False), # Short plain text send as simple text message
) text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "text", text_body,
)
elif fmt == "post":
# Medium content with links send as rich-text post
post_body = self._markdown_to_post(msg.content)
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "post", post_body,
)
else:
# Complex / long content send as interactive card
elements = self._build_card_elements(msg.content)
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
)
except Exception as e: except Exception as e:
logger.error("Error sending Feishu message: {}", e) logger.error("Error sending Feishu message: {}", e)
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None: def _on_message_sync(self, data: Any) -> None:
""" """
Sync handler for incoming messages (called from WebSocket thread). Sync handler for incoming messages (called from WebSocket thread).
Schedules async handling in the main event loop. Schedules async handling in the main event loop.
@@ -670,7 +867,7 @@ class FeishuChannel(BaseChannel):
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop) asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: async def _on_message(self, data: Any) -> None:
"""Handle incoming message from Feishu.""" """Handle incoming message from Feishu."""
try: try:
event = data.event event = data.event
@@ -730,6 +927,12 @@ class FeishuChannel(BaseChannel):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id) file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
if file_path: if file_path:
media_paths.append(file_path) media_paths.append(file_path)
if msg_type == "audio" and file_path:
transcription = await self.transcribe_audio(file_path)
if transcription:
content_text = f"[transcription: {transcription}]"
content_parts.append(content_text) content_parts.append(content_text)
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"): elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
@@ -762,3 +965,16 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.error("Error processing Feishu message: {}", e) logger.error("Error processing Feishu message: {}", e)
def _on_reaction_created(self, data: Any) -> None:
"""Ignore reaction events so they do not generate SDK noise."""
pass
def _on_message_read(self, data: Any) -> None:
"""Ignore read events so they do not generate SDK noise."""
pass
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
"""Ignore p2p-enter events when a user opens a bot chat."""
logger.debug("Bot entered p2p chat (user opened chat window)")
pass

View File

@@ -7,7 +7,6 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config from nanobot.config.schema import Config
@@ -32,122 +31,23 @@ class ChannelManager:
self._init_channels() self._init_channels()
def _init_channels(self) -> None: def _init_channels(self) -> None:
"""Initialize channels based on config.""" """Initialize channels discovered via pkgutil scan."""
from nanobot.channels.registry import discover_channel_names, load_channel_class
# Telegram channel groq_key = self.config.providers.groq.api_key
if self.config.channels.telegram.enabled:
for modname in discover_channel_names():
section = getattr(self.config.channels, modname, None)
if not section or not getattr(section, "enabled", False):
continue
try: try:
from nanobot.channels.telegram import TelegramChannel cls = load_channel_class(modname)
self.channels["telegram"] = TelegramChannel( channel = cls(section, self.bus)
self.config.channels.telegram, channel.transcription_api_key = groq_key
self.bus, self.channels[modname] = channel
groq_api_key=self.config.providers.groq.api_key, logger.info("{} channel enabled", cls.display_name)
)
logger.info("Telegram channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Telegram channel not available: {}", e) logger.warning("{} channel not available: {}", modname, e)
# WhatsApp channel
if self.config.channels.whatsapp.enabled:
try:
from nanobot.channels.whatsapp import WhatsAppChannel
self.channels["whatsapp"] = WhatsAppChannel(
self.config.channels.whatsapp, self.bus
)
logger.info("WhatsApp channel enabled")
except ImportError as e:
logger.warning("WhatsApp channel not available: {}", e)
# Discord channel
if self.config.channels.discord.enabled:
try:
from nanobot.channels.discord import DiscordChannel
self.channels["discord"] = DiscordChannel(
self.config.channels.discord, self.bus
)
logger.info("Discord channel enabled")
except ImportError as e:
logger.warning("Discord channel not available: {}", e)
# Feishu channel
if self.config.channels.feishu.enabled:
try:
from nanobot.channels.feishu import FeishuChannel
self.channels["feishu"] = FeishuChannel(
self.config.channels.feishu, self.bus
)
logger.info("Feishu channel enabled")
except ImportError as e:
logger.warning("Feishu channel not available: {}", e)
# Mochat channel
if self.config.channels.mochat.enabled:
try:
from nanobot.channels.mochat import MochatChannel
self.channels["mochat"] = MochatChannel(
self.config.channels.mochat, self.bus
)
logger.info("Mochat channel enabled")
except ImportError as e:
logger.warning("Mochat channel not available: {}", e)
# DingTalk channel
if self.config.channels.dingtalk.enabled:
try:
from nanobot.channels.dingtalk import DingTalkChannel
self.channels["dingtalk"] = DingTalkChannel(
self.config.channels.dingtalk, self.bus
)
logger.info("DingTalk channel enabled")
except ImportError as e:
logger.warning("DingTalk channel not available: {}", e)
# Email channel
if self.config.channels.email.enabled:
try:
from nanobot.channels.email import EmailChannel
self.channels["email"] = EmailChannel(
self.config.channels.email, self.bus
)
logger.info("Email channel enabled")
except ImportError as e:
logger.warning("Email channel not available: {}", e)
# Slack channel
if self.config.channels.slack.enabled:
try:
from nanobot.channels.slack import SlackChannel
self.channels["slack"] = SlackChannel(
self.config.channels.slack, self.bus
)
logger.info("Slack channel enabled")
except ImportError as e:
logger.warning("Slack channel not available: {}", e)
# QQ channel
if self.config.channels.qq.enabled:
try:
from nanobot.channels.qq import QQChannel
self.channels["qq"] = QQChannel(
self.config.channels.qq,
self.bus,
)
logger.info("QQ channel enabled")
except ImportError as e:
logger.warning("QQ channel not available: {}", e)
# Matrix channel
if self.config.channels.matrix.enabled:
try:
from nanobot.channels.matrix import MatrixChannel
self.channels["matrix"] = MatrixChannel(
self.config.channels.matrix,
self.bus,
)
logger.info("Matrix channel enabled")
except ImportError as e:
logger.warning("Matrix channel not available: {}", e)
self._validate_allow_from() self._validate_allow_from()

View File

@@ -37,8 +37,9 @@ except ImportError as e:
) from e ) from e
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.loader import get_data_dir from nanobot.config.paths import get_data_dir, get_media_dir
from nanobot.utils.helpers import safe_filename from nanobot.utils.helpers import safe_filename
TYPING_NOTICE_TIMEOUT_MS = 30_000 TYPING_NOTICE_TIMEOUT_MS = 30_000
@@ -146,15 +147,15 @@ class MatrixChannel(BaseChannel):
"""Matrix (Element) channel using long-polling sync.""" """Matrix (Element) channel using long-polling sync."""
name = "matrix" name = "matrix"
display_name = "Matrix"
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False, def __init__(self, config: Any, bus: MessageBus):
workspace: Path | None = None):
super().__init__(config, bus) super().__init__(config, bus)
self.client: AsyncClient | None = None self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tasks: dict[str, asyncio.Task] = {}
self._restrict_to_workspace = restrict_to_workspace self._restrict_to_workspace = False
self._workspace = workspace.expanduser().resolve() if workspace else None self._workspace: Path | None = None
self._server_upload_limit_bytes: int | None = None self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False self._server_upload_limit_checked = False
@@ -490,9 +491,7 @@ class MatrixChannel(BaseChannel):
return False return False
def _media_dir(self) -> Path: def _media_dir(self) -> Path:
d = get_data_dir() / "media" / "matrix" return get_media_dir("matrix")
d.mkdir(parents=True, exist_ok=True)
return d
@staticmethod @staticmethod
def _event_source_content(event: RoomMessage) -> dict[str, Any]: def _event_source_content(event: RoomMessage) -> dict[str, Any]:
@@ -679,7 +678,14 @@ class MatrixChannel(BaseChannel):
parts: list[str] = [] parts: list[str] = []
if isinstance(body := getattr(event, "body", None), str) and body.strip(): if isinstance(body := getattr(event, "body", None), str) and body.strip():
parts.append(body.strip()) parts.append(body.strip())
if marker:
if attachment and attachment.get("type") == "audio":
transcription = await self.transcribe_audio(attachment["path"])
if transcription:
parts.append(f"[transcription: {transcription}]")
else:
parts.append(marker)
elif marker:
parts.append(marker) parts.append(marker)
await self._start_typing_keepalive(room.room_id) await self._start_typing_keepalive(room.room_id)

View File

@@ -15,8 +15,8 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_runtime_subdir
from nanobot.config.schema import MochatConfig from nanobot.config.schema import MochatConfig
from nanobot.utils.helpers import get_data_path
try: try:
import socketio import socketio
@@ -216,6 +216,7 @@ class MochatChannel(BaseChannel):
"""Mochat channel using socket.io with fallback polling workers.""" """Mochat channel using socket.io with fallback polling workers."""
name = "mochat" name = "mochat"
display_name = "Mochat"
def __init__(self, config: MochatConfig, bus: MessageBus): def __init__(self, config: MochatConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -224,7 +225,7 @@ class MochatChannel(BaseChannel):
self._socket: Any = None self._socket: Any = None
self._ws_connected = self._ws_ready = False self._ws_connected = self._ws_ready = False
self._state_dir = get_data_path() / "mochat" self._state_dir = get_runtime_subdir("mochat")
self._cursor_path = self._state_dir / "session_cursors.json" self._cursor_path = self._state_dir / "session_cursors.json"
self._session_cursor: dict[str, int] = {} self._session_cursor: dict[str, int] = {}
self._cursor_save_task: asyncio.Task | None = None self._cursor_save_task: asyncio.Task | None = None

View File

@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
try: try:
import botpy import botpy
from botpy.message import C2CMessage from botpy.message import C2CMessage, GroupMessage
QQ_AVAILABLE = True QQ_AVAILABLE = True
except ImportError: except ImportError:
QQ_AVAILABLE = False QQ_AVAILABLE = False
botpy = None botpy = None
C2CMessage = None C2CMessage = None
GroupMessage = None
if TYPE_CHECKING: if TYPE_CHECKING:
from botpy.message import C2CMessage from botpy.message import C2CMessage, GroupMessage
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]": def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
@@ -38,10 +39,13 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
logger.info("QQ bot ready: {}", self.robot.name) logger.info("QQ bot ready: {}", self.robot.name)
async def on_c2c_message_create(self, message: "C2CMessage"): async def on_c2c_message_create(self, message: "C2CMessage"):
await channel._on_message(message) await channel._on_message(message, is_group=False)
async def on_group_at_message_create(self, message: "GroupMessage"):
await channel._on_message(message, is_group=True)
async def on_direct_message_create(self, message): async def on_direct_message_create(self, message):
await channel._on_message(message) await channel._on_message(message, is_group=False)
return _Bot return _Bot
@@ -50,12 +54,15 @@ class QQChannel(BaseChannel):
"""QQ channel using botpy SDK with WebSocket connection.""" """QQ channel using botpy SDK with WebSocket connection."""
name = "qq" name = "qq"
display_name = "QQ"
def __init__(self, config: QQConfig, bus: MessageBus): def __init__(self, config: QQConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: QQConfig = config self.config: QQConfig = config
self._client: "botpy.Client | None" = None self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000) self._processed_ids: deque = deque(maxlen=1000)
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
self._chat_type_cache: dict[str, str] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the QQ bot.""" """Start the QQ bot."""
@@ -70,8 +77,7 @@ class QQChannel(BaseChannel):
self._running = True self._running = True
BotClass = _make_bot_class(self) BotClass = _make_bot_class(self)
self._client = BotClass() self._client = BotClass()
logger.info("QQ bot started (C2C & Group supported)")
logger.info("QQ bot started (C2C private message)")
await self._run_bot() await self._run_bot()
async def _run_bot(self) -> None: async def _run_bot(self) -> None:
@@ -100,18 +106,31 @@ class QQChannel(BaseChannel):
if not self._client: if not self._client:
logger.warning("QQ client not initialized") logger.warning("QQ client not initialized")
return return
try: try:
msg_id = msg.metadata.get("message_id") msg_id = msg.metadata.get("message_id")
await self._client.api.post_c2c_message( self._msg_seq += 1
openid=msg.chat_id, msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
msg_type=0, if msg_type == "group":
content=msg.content, await self._client.api.post_group_message(
msg_id=msg_id, group_openid=msg.chat_id,
) msg_type=2,
markdown={"content": msg.content},
msg_id=msg_id,
msg_seq=self._msg_seq,
)
else:
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=2,
markdown={"content": msg.content},
msg_id=msg_id,
msg_seq=self._msg_seq,
)
except Exception as e: except Exception as e:
logger.error("Error sending QQ message: {}", e) logger.error("Error sending QQ message: {}", e)
async def _on_message(self, data: "C2CMessage") -> None: async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
"""Handle incoming message from QQ.""" """Handle incoming message from QQ."""
try: try:
# Dedup by message ID # Dedup by message ID
@@ -119,15 +138,22 @@ class QQChannel(BaseChannel):
return return
self._processed_ids.append(data.id) self._processed_ids.append(data.id)
author = data.author
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
content = (data.content or "").strip() content = (data.content or "").strip()
if not content: if not content:
return return
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
await self._handle_message( await self._handle_message(
sender_id=user_id, sender_id=user_id,
chat_id=user_id, chat_id=chat_id,
content=content, content=content,
metadata={"message_id": data.id}, metadata={"message_id": data.id},
) )

View File

@@ -0,0 +1,35 @@
"""Auto-discovery for channel modules — no hardcoded registry."""
from __future__ import annotations
import importlib
import pkgutil
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanobot.channels.base import BaseChannel
_INTERNAL = frozenset({"base", "manager", "registry"})
def discover_channel_names() -> list[str]:
"""Return all channel module names by scanning the package (zero imports)."""
import nanobot.channels as pkg
return [
name
for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
if name not in _INTERNAL and not ispkg
]
def load_channel_class(module_name: str) -> type[BaseChannel]:
"""Import *module_name* and return the first BaseChannel subclass found."""
from nanobot.channels.base import BaseChannel as _Base
mod = importlib.import_module(f"nanobot.channels.{module_name}")
for attr in dir(mod):
obj = getattr(mod, attr)
if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
return obj
raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")

View File

@@ -21,6 +21,7 @@ class SlackChannel(BaseChannel):
"""Slack channel using Socket Mode.""" """Slack channel using Socket Mode."""
name = "slack" name = "slack"
display_name = "Slack"
def __init__(self, config: SlackConfig, bus: MessageBus): def __init__(self, config: SlackConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -81,14 +82,15 @@ class SlackChannel(BaseChannel):
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {} slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
thread_ts = slack_meta.get("thread_ts") thread_ts = slack_meta.get("thread_ts")
channel_type = slack_meta.get("channel_type") channel_type = slack_meta.get("channel_type")
# Only reply in thread for channel/group messages; DMs don't use threads # Slack DMs don't use threads; channel/group replies may keep thread_ts.
use_thread = thread_ts and channel_type != "im" thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
thread_ts_param = thread_ts if use_thread else None
if msg.content: # Slack rejects empty text payloads. Keep media-only messages media-only,
# but send a single blank message when the bot has no text or files to send.
if msg.content or not (msg.media or []):
await self._web_client.chat_postMessage( await self._web_client.chat_postMessage(
channel=msg.chat_id, channel=msg.chat_id,
text=self._to_mrkdwn(msg.content), text=self._to_mrkdwn(msg.content) if msg.content else " ",
thread_ts=thread_ts_param, thread_ts=thread_ts_param,
) )
@@ -277,4 +279,3 @@ class SlackChannel(BaseChannel):
if parts: if parts:
rows.append(" · ".join(parts)) rows.append(" · ".join(parts))
return "\n".join(rows) return "\n".join(rows)

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
import time
import unicodedata
from loguru import logger from loguru import logger
from telegram import BotCommand, ReplyParameters, Update from telegram import BotCommand, ReplyParameters, Update
@@ -13,7 +15,52 @@ from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import TelegramConfig from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
def _strip_md(s: str) -> str:
"""Strip markdown inline formatting from text."""
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
s = re.sub(r'__(.+?)__', r'\1', s)
s = re.sub(r'~~(.+?)~~', r'\1', s)
s = re.sub(r'`([^`]+)`', r'\1', s)
return s.strip()
def _render_table_box(table_lines: list[str]) -> str:
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
def dw(s: str) -> int:
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
rows: list[list[str]] = []
has_sep = False
for line in table_lines:
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
has_sep = True
continue
rows.append(cells)
if not rows or not has_sep:
return '\n'.join(table_lines)
ncols = max(len(r) for r in rows)
for r in rows:
r.extend([''] * (ncols - len(r)))
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
def dr(cells: list[str]) -> str:
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
out = [dr(rows[0])]
out.append(' '.join('' * w for w in widths))
for row in rows[1:]:
out.append(dr(row))
return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str: def _markdown_to_telegram_html(text: str) -> str:
@@ -31,6 +78,27 @@ def _markdown_to_telegram_html(text: str) -> str:
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
lines = text.split('\n')
rebuilt: list[str] = []
li = 0
while li < len(lines):
if re.match(r'^\s*\|.+\|', lines[li]):
tbl: list[str] = []
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
tbl.append(lines[li])
li += 1
box = _render_table_box(tbl)
if box != '\n'.join(tbl):
code_blocks.append(box)
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
else:
rebuilt.extend(tbl)
else:
rebuilt.append(lines[li])
li += 1
text = '\n'.join(rebuilt)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
@@ -79,26 +147,6 @@ def _markdown_to_telegram_html(text: str) -> str:
return text return text
def _split_message(content: str, max_len: int = 4000) -> list[str]:
"""Split content into chunks within max_len, preferring line breaks."""
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
pos = cut.rfind('\n')
if pos == -1:
pos = cut.rfind(' ')
if pos == -1:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
""" """
Telegram channel using long polling. Telegram channel using long polling.
@@ -107,6 +155,7 @@ class TelegramChannel(BaseChannel):
""" """
name = "telegram" name = "telegram"
display_name = "Telegram"
# Commands registered with Telegram's command menu # Commands registered with Telegram's command menu
BOT_COMMANDS = [ BOT_COMMANDS = [
@@ -116,20 +165,36 @@ class TelegramChannel(BaseChannel):
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
def __init__( def __init__(self, config: TelegramConfig, bus: MessageBus):
self,
config: TelegramConfig,
bus: MessageBus,
groq_api_key: str = "",
):
super().__init__(config, bus) super().__init__(config, bus)
self.config: TelegramConfig = config self.config: TelegramConfig = config
self.groq_api_key = groq_api_key
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {} self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {} self._media_group_tasks: dict[str, asyncio.Task] = {}
self._message_threads: dict[tuple[str, int], int] = {}
self._bot_user_id: int | None = None
self._bot_username: str | None = None
def is_allowed(self, sender_id: str) -> bool:
"""Preserve Telegram's legacy id|username allowlist matching."""
if super().is_allowed(sender_id):
return True
allow_list = getattr(self.config, "allow_from", [])
if not allow_list or "*" in allow_list:
return False
sender_str = str(sender_id)
if sender_str.count("|") != 1:
return False
sid, username = sender_str.split("|", 1)
if not sid.isdigit() or not username:
return False
return sid in allow_list or username in allow_list
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
@@ -140,16 +205,21 @@ class TelegramChannel(BaseChannel):
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs # Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) req = HTTPXRequest(
connection_pool_size=16,
pool_timeout=5.0,
connect_timeout=30.0,
read_timeout=30.0,
proxy=self.config.proxy if self.config.proxy else None,
)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
if self.config.proxy:
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
# Add command handlers # Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help)) self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
@@ -169,6 +239,8 @@ class TelegramChannel(BaseChannel):
# Get bot info and register command menu # Get bot info and register command menu
bot_info = await self._app.bot.get_me() bot_info = await self._app.bot.get_me()
self._bot_user_id = getattr(bot_info, "id", None)
self._bot_username = getattr(bot_info, "username", None)
logger.info("Telegram bot @{} connected", bot_info.username) logger.info("Telegram bot @{} connected", bot_info.username)
try: try:
@@ -225,17 +297,25 @@ class TelegramChannel(BaseChannel):
logger.warning("Telegram bot not running") logger.warning("Telegram bot not running")
return return
self._stop_typing(msg.chat_id) # Only stop typing indicator for final responses
if not msg.metadata.get("_progress", False):
self._stop_typing(msg.chat_id)
try: try:
chat_id = int(msg.chat_id) chat_id = int(msg.chat_id)
except ValueError: except ValueError:
logger.error("Invalid chat_id: {}", msg.chat_id) logger.error("Invalid chat_id: {}", msg.chat_id)
return return
reply_to_message_id = msg.metadata.get("message_id")
message_thread_id = msg.metadata.get("message_thread_id")
if message_thread_id is None and reply_to_message_id is not None:
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
thread_kwargs = {}
if message_thread_id is not None:
thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None reply_params = None
if self.config.reply_to_message: if self.config.reply_to_message:
reply_to_message_id = msg.metadata.get("message_id")
if reply_to_message_id: if reply_to_message_id:
reply_params = ReplyParameters( reply_params = ReplyParameters(
message_id=reply_to_message_id, message_id=reply_to_message_id,
@@ -256,7 +336,8 @@ class TelegramChannel(BaseChannel):
await sender( await sender(
chat_id=chat_id, chat_id=chat_id,
**{param: f}, **{param: f},
reply_parameters=reply_params reply_parameters=reply_params,
**thread_kwargs,
) )
except Exception as e: except Exception as e:
filename = media_path.rsplit("/", 1)[-1] filename = media_path.rsplit("/", 1)[-1]
@@ -264,30 +345,71 @@ class TelegramChannel(BaseChannel):
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=f"[Failed to send: {filename}]", text=f"[Failed to send: {filename}]",
reply_parameters=reply_params reply_parameters=reply_params,
**thread_kwargs,
) )
# Send text content # Send text content
if msg.content and msg.content != "[empty message]": if msg.content and msg.content != "[empty message]":
for chunk in _split_message(msg.content): is_progress = msg.metadata.get("_progress", False)
try:
html = _markdown_to_telegram_html(chunk) for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
await self._app.bot.send_message( # Final response: simulate streaming via draft, then persist
chat_id=chat_id, if not is_progress:
text=html, await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
parse_mode="HTML", else:
reply_parameters=reply_params await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
)
except Exception as e: async def _send_text(
logger.warning("HTML parse failed, falling back to plain text: {}", e) self,
try: chat_id: int,
await self._app.bot.send_message( text: str,
chat_id=chat_id, reply_params=None,
text=chunk, thread_kwargs: dict | None = None,
reply_parameters=reply_params ) -> None:
) """Send a plain text message with HTML fallback."""
except Exception as e2: try:
logger.error("Error sending Telegram message: {}", e2) html = _markdown_to_telegram_html(text)
await self._app.bot.send_message(
chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
await self._app.bot.send_message(
chat_id=chat_id,
text=text,
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e2:
logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31)
try:
step = max(len(text) // 8, 40)
for i in range(step, len(text), step):
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text[:i],
)
await asyncio.sleep(0.04)
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text,
)
await asyncio.sleep(0.15)
except Exception:
pass
await self._send_text(chat_id, text, reply_params, thread_kwargs)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
@@ -318,14 +440,114 @@ class TelegramChannel(BaseChannel):
sid = str(user.id) sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid return f"{sid}|{user.username}" if user.username else sid
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
@staticmethod
def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload."""
return {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)),
}
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
"""Load bot identity once and reuse it for mention/reply checks."""
if self._bot_user_id is not None or self._bot_username is not None:
return self._bot_user_id, self._bot_username
if not self._app:
return None, None
bot_info = await self._app.bot.get_me()
self._bot_user_id = getattr(bot_info, "id", None)
self._bot_username = getattr(bot_info, "username", None)
return self._bot_user_id, self._bot_username
@staticmethod
def _has_mention_entity(
text: str,
entities,
bot_username: str,
bot_id: int | None,
) -> bool:
"""Check Telegram mention entities against the bot username."""
handle = f"@{bot_username}".lower()
for entity in entities or []:
entity_type = getattr(entity, "type", None)
if entity_type == "text_mention":
user = getattr(entity, "user", None)
if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
return True
continue
if entity_type != "mention":
continue
offset = getattr(entity, "offset", None)
length = getattr(entity, "length", None)
if offset is None or length is None:
continue
if text[offset : offset + length].lower() == handle:
return True
return handle in text.lower()
async def _is_group_message_for_bot(self, message) -> bool:
"""Allow group messages when policy is open, @mentioned, or replying to the bot."""
if message.chat.type == "private" or self.config.group_policy == "open":
return True
bot_id, bot_username = await self._ensure_bot_identity()
if bot_username:
text = message.text or ""
caption = message.caption or ""
if self._has_mention_entity(
text,
getattr(message, "entities", None),
bot_username,
bot_id,
):
return True
if self._has_mention_entity(
caption,
getattr(message, "caption_entities", None),
bot_username,
bot_id,
):
return True
reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
return bool(bot_id and reply_user and reply_user.id == bot_id)
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
key = (str(message.chat_id), message.message_id)
self._message_threads[key] = message_thread_id
if len(self._message_threads) > 1000:
self._message_threads.pop(next(iter(self._message_threads)))
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop.""" """Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message
user = update.effective_user
self._remember_thread_context(message)
await self._handle_message( await self._handle_message(
sender_id=self._sender_id(update.effective_user), sender_id=self._sender_id(user),
chat_id=str(update.message.chat_id), chat_id=str(message.chat_id),
content=update.message.text, content=message.text,
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -337,10 +559,14 @@ class TelegramChannel(BaseChannel):
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user) sender_id = self._sender_id(user)
self._remember_thread_context(message)
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
if not await self._is_group_message_for_bot(message):
return
# Build content from text and/or media # Build content from text and/or media
content_parts = [] content_parts = []
media_paths = [] media_paths = []
@@ -372,23 +598,20 @@ class TelegramChannel(BaseChannel):
if media_file and self._app: if media_file and self._app:
try: try:
file = await self._app.bot.get_file(media_file.file_id) file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None)) ext = self._get_extension(
media_type,
# Save to workspace/media/ getattr(media_file, 'mime_type', None),
from pathlib import Path getattr(media_file, 'file_name', None),
media_dir = Path.home() / ".nanobot" / "media" )
media_dir.mkdir(parents=True, exist_ok=True) media_dir = get_media_dir("telegram")
file_path = media_dir / f"{media_file.file_id[:16]}{ext}" file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path)) await file.download_to_drive(str(file_path))
media_paths.append(str(file_path)) media_paths.append(str(file_path))
# Handle voice transcription if media_type in ("voice", "audio"):
if media_type == "voice" or media_type == "audio": transcription = await self.transcribe_audio(file_path)
from nanobot.providers.transcription import GroqTranscriptionProvider
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path)
if transcription: if transcription:
logger.info("Transcribed {}: {}...", media_type, transcription[:50]) logger.info("Transcribed {}: {}...", media_type, transcription[:50])
content_parts.append(f"[transcription: {transcription}]") content_parts.append(f"[transcription: {transcription}]")
@@ -407,6 +630,8 @@ class TelegramChannel(BaseChannel):
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
metadata = self._build_message_metadata(message, user)
session_key = self._derive_topic_session_key(message)
# Telegram media groups: buffer briefly, forward as one aggregated turn. # Telegram media groups: buffer briefly, forward as one aggregated turn.
if media_group_id := getattr(message, "media_group_id", None): if media_group_id := getattr(message, "media_group_id", None):
@@ -415,11 +640,8 @@ class TelegramChannel(BaseChannel):
self._media_group_buffers[key] = { self._media_group_buffers[key] = {
"sender_id": sender_id, "chat_id": str_chat_id, "sender_id": sender_id, "chat_id": str_chat_id,
"contents": [], "media": [], "contents": [], "media": [],
"metadata": { "metadata": metadata,
"message_id": message.message_id, "user_id": user.id, "session_key": session_key,
"username": user.username, "first_name": user.first_name,
"is_group": message.chat.type != "private",
},
} }
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
buf = self._media_group_buffers[key] buf = self._media_group_buffers[key]
@@ -439,13 +661,8 @@ class TelegramChannel(BaseChannel):
chat_id=str_chat_id, chat_id=str_chat_id,
content=content, content=content,
media=media_paths, media=media_paths,
metadata={ metadata=metadata,
"message_id": message.message_id, session_key=session_key,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private"
}
) )
async def _flush_media_group(self, key: str) -> None: async def _flush_media_group(self, key: str) -> None:
@@ -459,6 +676,7 @@ class TelegramChannel(BaseChannel):
sender_id=buf["sender_id"], chat_id=buf["chat_id"], sender_id=buf["sender_id"], chat_id=buf["chat_id"],
content=content, media=list(dict.fromkeys(buf["media"])), content=content, media=list(dict.fromkeys(buf["media"])),
metadata=buf["metadata"], metadata=buf["metadata"],
session_key=buf.get("session_key"),
) )
finally: finally:
self._media_group_tasks.pop(key, None) self._media_group_tasks.pop(key, None)
@@ -490,8 +708,13 @@ class TelegramChannel(BaseChannel):
"""Log polling / handler errors instead of silently swallowing them.""" """Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error) logger.error("Telegram error: {}", context.error)
def _get_extension(self, media_type: str, mime_type: str | None) -> str: def _get_extension(
"""Get file extension based on media type.""" self,
media_type: str,
mime_type: str | None,
filename: str | None = None,
) -> str:
"""Get file extension based on media type or original filename."""
if mime_type: if mime_type:
ext_map = { ext_map = {
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -501,4 +724,12 @@ class TelegramChannel(BaseChannel):
return ext_map[mime_type] return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""} type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "") if ext := type_map.get(media_type, ""):
return ext
if filename:
from pathlib import Path
return "".join(Path(filename).suffixes)
return ""

353
nanobot/channels/wecom.py Normal file
View File

@@ -0,0 +1,353 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
import asyncio
import importlib.util
import os
from collections import OrderedDict
from typing import Any
from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import WecomConfig
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
# Message type display mapping
MSG_TYPE_MAP = {
"image": "[image]",
"voice": "[voice]",
"file": "[file]",
"mixed": "[mixed content]",
}
class WecomChannel(BaseChannel):
"""
WeCom (Enterprise WeChat) channel using WebSocket long connection.
Uses WebSocket to receive events - no public IP or webhook required.
Requires:
- Bot ID and Secret from WeCom AI Bot platform
"""
name = "wecom"
display_name = "WeCom"
def __init__(self, config: WecomConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: WecomConfig = config
self._client: Any = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
self._loop: asyncio.AbstractEventLoop | None = None
self._generate_req_id = None
# Store frame headers for each chat to enable replies
self._chat_frames: dict[str, Any] = {}
async def start(self) -> None:
"""Start the WeCom bot with WebSocket long connection."""
if not WECOM_AVAILABLE:
logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
return
if not self.config.bot_id or not self.config.secret:
logger.error("WeCom bot_id and secret not configured")
return
from wecom_aibot_sdk import WSClient, generate_req_id
self._running = True
self._loop = asyncio.get_running_loop()
self._generate_req_id = generate_req_id
# Create WebSocket client
self._client = WSClient({
"bot_id": self.config.bot_id,
"secret": self.config.secret,
"reconnect_interval": 1000,
"max_reconnect_attempts": -1, # Infinite reconnect
"heartbeat_interval": 30000,
})
# Register event handlers
self._client.on("connected", self._on_connected)
self._client.on("authenticated", self._on_authenticated)
self._client.on("disconnected", self._on_disconnected)
self._client.on("error", self._on_error)
self._client.on("message.text", self._on_text_message)
self._client.on("message.image", self._on_image_message)
self._client.on("message.voice", self._on_voice_message)
self._client.on("message.file", self._on_file_message)
self._client.on("message.mixed", self._on_mixed_message)
self._client.on("event.enter_chat", self._on_enter_chat)
logger.info("WeCom bot starting with WebSocket long connection")
logger.info("No public IP required - using WebSocket to receive events")
# Connect
await self._client.connect_async()
# Keep running until stopped
while self._running:
await asyncio.sleep(1)
async def stop(self) -> None:
"""Stop the WeCom bot."""
self._running = False
if self._client:
await self._client.disconnect()
logger.info("WeCom bot stopped")
async def _on_connected(self, frame: Any) -> None:
"""Handle WebSocket connected event."""
logger.info("WeCom WebSocket connected")
async def _on_authenticated(self, frame: Any) -> None:
"""Handle authentication success event."""
logger.info("WeCom authenticated successfully")
async def _on_disconnected(self, frame: Any) -> None:
"""Handle WebSocket disconnected event."""
reason = frame.body if hasattr(frame, 'body') else str(frame)
logger.warning("WeCom WebSocket disconnected: {}", reason)
async def _on_error(self, frame: Any) -> None:
"""Handle error event."""
logger.error("WeCom error: {}", frame)
async def _on_text_message(self, frame: Any) -> None:
"""Handle text message."""
await self._process_message(frame, "text")
async def _on_image_message(self, frame: Any) -> None:
"""Handle image message."""
await self._process_message(frame, "image")
async def _on_voice_message(self, frame: Any) -> None:
"""Handle voice message."""
await self._process_message(frame, "voice")
async def _on_file_message(self, frame: Any) -> None:
"""Handle file message."""
await self._process_message(frame, "file")
async def _on_mixed_message(self, frame: Any) -> None:
"""Handle mixed content message."""
await self._process_message(frame, "mixed")
async def _on_enter_chat(self, frame: Any) -> None:
"""Handle enter_chat event (user opens chat with bot)."""
try:
# Extract body from WsFrame dataclass or dict
if hasattr(frame, 'body'):
body = frame.body or {}
elif isinstance(frame, dict):
body = frame.get("body", frame)
else:
body = {}
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
if chat_id and self.config.welcome_message:
await self._client.reply_welcome(frame, {
"msgtype": "text",
"text": {"content": self.config.welcome_message},
})
except Exception as e:
logger.error("Error handling enter_chat: {}", e)
async def _process_message(self, frame: Any, msg_type: str) -> None:
"""Process incoming message and forward to bus."""
try:
# Extract body from WsFrame dataclass or dict
if hasattr(frame, 'body'):
body = frame.body or {}
elif isinstance(frame, dict):
body = frame.get("body", frame)
else:
body = {}
# Ensure body is a dict
if not isinstance(body, dict):
logger.warning("Invalid body type: {}", type(body))
return
# Extract message info
msg_id = body.get("msgid", "")
if not msg_id:
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
# Deduplication check
if msg_id in self._processed_message_ids:
return
self._processed_message_ids[msg_id] = None
# Trim cache
while len(self._processed_message_ids) > 1000:
self._processed_message_ids.popitem(last=False)
# Extract sender info from "from" field (SDK format)
from_info = body.get("from", {})
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
# For single chat, chatid is the sender's userid
# For group chat, chatid is provided in body
chat_type = body.get("chattype", "single")
chat_id = body.get("chatid", sender_id)
content_parts = []
if msg_type == "text":
text = body.get("text", {}).get("content", "")
if text:
content_parts.append(text)
elif msg_type == "image":
image_info = body.get("image", {})
file_url = image_info.get("url", "")
aes_key = image_info.get("aeskey", "")
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "image")
if file_path:
filename = os.path.basename(file_path)
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
else:
content_parts.append("[image: download failed]")
else:
content_parts.append("[image: download failed]")
elif msg_type == "voice":
voice_info = body.get("voice", {})
# Voice message already contains transcribed content from WeCom
voice_content = voice_info.get("content", "")
if voice_content:
content_parts.append(f"[voice] {voice_content}")
else:
content_parts.append("[voice]")
elif msg_type == "file":
file_info = body.get("file", {})
file_url = file_info.get("url", "")
aes_key = file_info.get("aeskey", "")
file_name = file_info.get("name", "unknown")
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
else:
content_parts.append(f"[file: {file_name}: download failed]")
else:
content_parts.append(f"[file: {file_name}: download failed]")
elif msg_type == "mixed":
# Mixed content contains multiple message items
msg_items = body.get("mixed", {}).get("item", [])
for item in msg_items:
item_type = item.get("type", "")
if item_type == "text":
text = item.get("text", {}).get("content", "")
if text:
content_parts.append(text)
else:
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
else:
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
content = "\n".join(content_parts) if content_parts else ""
if not content:
return
# Store frame for this chat to enable replies
self._chat_frames[chat_id] = frame
# Forward to message bus
# Note: media paths are included in content for broader model compatibility
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=content,
media=None,
metadata={
"message_id": msg_id,
"msg_type": msg_type,
"chat_type": chat_type,
}
)
except Exception as e:
logger.error("Error processing WeCom message: {}", e)
async def _download_and_save_media(
self,
file_url: str,
aes_key: str,
media_type: str,
filename: str | None = None,
) -> str | None:
"""
Download and decrypt media from WeCom.
Returns:
file_path or None if download failed
"""
try:
data, fname = await self._client.download_file(file_url, aes_key)
if not data:
logger.warning("Failed to download media from WeCom")
return None
media_dir = get_media_dir("wecom")
if not filename:
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
filename = os.path.basename(filename)
file_path = media_dir / filename
file_path.write_bytes(data)
logger.debug("Downloaded {} to {}", media_type, file_path)
return str(file_path)
except Exception as e:
logger.error("Error downloading media: {}", e)
return None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WeCom."""
if not self._client:
logger.warning("WeCom client not initialized")
return
try:
content = msg.content.strip()
if not content:
return
# Get the stored frame for this chat
frame = self._chat_frames.get(msg.chat_id)
if not frame:
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
return
# Use streaming reply for better UX
stream_id = self._generate_req_id("stream")
# Send as streaming message with finish=True
await self._client.reply_stream(
frame,
stream_id,
content,
finish=True,
)
logger.debug("WeCom message sent to {}", msg.chat_id)
except Exception as e:
logger.error("Error sending WeCom message: {}", e)

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
import json import json
import mimetypes
from collections import OrderedDict from collections import OrderedDict
from loguru import logger from loguru import logger
@@ -21,6 +22,7 @@ class WhatsAppChannel(BaseChannel):
""" """
name = "whatsapp" name = "whatsapp"
display_name = "WhatsApp"
def __init__(self, config: WhatsAppConfig, bus: MessageBus): def __init__(self, config: WhatsAppConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -128,10 +130,22 @@ class WhatsAppChannel(BaseChannel):
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id) logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]" content = "[Voice Message: Transcription not available for WhatsApp yet]"
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
mime, _ = mimetypes.guess_type(p)
media_type = "image" if mime and mime.startswith("image/") else "file"
media_tag = f"[{media_type}: {p}]"
content = f"{content}\n{media_tag}" if content else media_tag
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender, # Use full LID for replies chat_id=sender, # Use full LID for replies
content=content, content=content,
media=media_paths,
metadata={ metadata={
"message_id": message_id, "message_id": message_id,
"timestamp": data.get("timestamp"), "timestamp": data.get("timestamp"),

View File

@@ -7,6 +7,17 @@ import signal
import sys import sys
from pathlib import Path from pathlib import Path
# Force UTF-8 encoding for Windows console
if sys.platform == "win32":
if sys.stdout.encoding != "utf-8":
os.environ["PYTHONIOENCODING"] = "utf-8"
# Re-open stdout/stderr with UTF-8 encoding
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
import typer import typer
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML from prompt_toolkit.formatted_text import HTML
@@ -18,6 +29,7 @@ from rich.table import Table
from rich.text import Text from rich.text import Text
from nanobot import __logo__, __version__ from nanobot import __logo__, __version__
from nanobot.config.paths import get_workspace_path
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates from nanobot.utils.helpers import sync_workspace_templates
@@ -87,7 +99,9 @@ def _init_prompt_session() -> None:
except Exception: except Exception:
pass pass
history_file = Path.home() / ".nanobot" / "history" / "cli_history" from nanobot.config.paths import get_cli_history_path
history_file = get_cli_history_path()
history_file.parent.mkdir(parents=True, exist_ok=True) history_file.parent.mkdir(parents=True, exist_ok=True)
_PROMPT_SESSION = PromptSession( _PROMPT_SESSION = PromptSession(
@@ -158,7 +172,6 @@ def onboard():
"""Initialize nanobot configuration and workspace.""" """Initialize nanobot configuration and workspace."""
from nanobot.config.loader import get_config_path, load_config, save_config from nanobot.config.loader import get_config_path, load_config, save_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import get_workspace_path
config_path = get_config_path() config_path = get_config_path()
@@ -178,6 +191,8 @@ def onboard():
save_config(Config()) save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}") console.print(f"[green]✓[/green] Created config at {config_path}")
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
# Create workspace # Create workspace
workspace = get_workspace_path() workspace = get_workspace_path()
@@ -200,9 +215,9 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.custom_provider import CustomProvider from nanobot.providers.base import GenerationSettings
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)
@@ -210,30 +225,79 @@ def _make_provider(config: Config):
# OpenAI Codex (OAuth) # OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"): if provider_name == "openai_codex" or model.startswith("openai-codex/"):
return OpenAICodexProvider(default_model=model) provider = OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
if provider_name == "custom": elif provider_name == "custom":
return CustomProvider( from nanobot.providers.custom_provider import CustomProvider
provider = CustomProvider(
api_key=p.api_key if p else "no-key", api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1", api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model, default_model=model,
) )
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
elif provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
else:
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
provider = LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
)
from nanobot.providers.registry import find_by_name defaults = config.agents.defaults
spec = find_by_name(provider_name) provider.generation = GenerationSettings(
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth): temperature=defaults.temperature,
console.print("[red]Error: No API key configured.[/red]") max_tokens=defaults.max_tokens,
console.print("Set one in ~/.nanobot/config.json under providers section") reasoning_effort=defaults.reasoning_effort,
raise typer.Exit(1)
return LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
) )
return provider
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
"""Load config and optionally override the active workspace."""
from nanobot.config.loader import load_config, set_config_path
config_path = None
if config:
config_path = Path(config).expanduser().resolve()
if not config_path.exists():
console.print(f"[red]Error: Config file not found: {config_path}[/red]")
raise typer.Exit(1)
set_config_path(config_path)
console.print(f"[dim]Using config: {config_path}[/dim]")
loaded = load_config(config_path)
if workspace:
loaded.agents.defaults.workspace = workspace
return loaded
def _print_deprecated_memory_window_notice(config: Config) -> None:
"""Warn when running with old memoryWindow-only config."""
if config.agents.defaults.should_warn_deprecated_memory_window:
console.print(
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
"`contextWindowTokens`. `memoryWindow` is ignored; run "
"[cyan]nanobot onboard[/cyan] to refresh your config template."
)
# ============================================================================ # ============================================================================
@@ -243,14 +307,16 @@ def _make_provider(config: Config):
@app.command() @app.command()
def gateway( def gateway(
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"), port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
): ):
"""Start the nanobot gateway.""" """Start the nanobot gateway."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.config.loader import get_data_dir, load_config from nanobot.config.paths import get_cron_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
@@ -260,16 +326,18 @@ def gateway(
import logging import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
console.print(f"{__logo__} Starting nanobot gateway on port {port}...") config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
port = port if port is not None else config.gateway.port
config = load_config() console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path) session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation) # Create cron service first (callback set after agent creation)
cron_store_path = get_data_dir() / "cron" / "jobs.json" cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path) cron = CronService(cron_store_path)
# Create agent with cron service # Create agent with cron service
@@ -278,11 +346,8 @@ def gateway(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, context_window_tokens=config.agents.defaults.context_window_tokens,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
@@ -296,6 +361,7 @@ def gateway(
# Set cron callback (needs agent) # Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None: async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent.""" """Execute a cron job through the agent."""
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
reminder_note = ( reminder_note = (
"[Scheduled Task] Timer finished.\n\n" "[Scheduled Task] Timer finished.\n\n"
@@ -303,12 +369,21 @@ def gateway(
f"Scheduled instruction: {job.payload.message}" f"Scheduled instruction: {job.payload.message}"
) )
response = await agent.process_direct( # Prevent the agent from scheduling new cron jobs during execution
reminder_note, cron_tool = agent.tools.get("cron")
session_key=f"cron:{job.id}", cron_token = None
channel=job.payload.channel or "cli", if isinstance(cron_tool, CronTool):
chat_id=job.payload.to or "direct", cron_token = cron_tool.set_cron_context(True)
) try:
response = await agent.process_direct(
reminder_note,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
)
finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
cron_tool.reset_cron_context(cron_token)
message_tool = agent.tools.get("message") message_tool = agent.tools.get("message")
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
@@ -420,6 +495,8 @@ def gateway(
def agent( def agent(
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"), message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"), session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"), markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
): ):
@@ -428,17 +505,18 @@ def agent(
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config from nanobot.config.paths import get_cron_dir
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
config = load_config() config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
sync_workspace_templates(config.workspace_path) sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
# Create cron service for tool usage (no callback needed for CLI unless running) # Create cron service for tool usage (no callback needed for CLI unless running)
cron_store_path = get_data_dir() / "cron" / "jobs.json" cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path) cron = CronService(cron_store_path)
if logs: if logs:
@@ -451,11 +529,8 @@ def agent(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, context_window_tokens=config.agents.defaults.context_window_tokens,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
@@ -501,12 +576,21 @@ def agent(
else: else:
cli_channel, cli_chat_id = "cli", session_id cli_channel, cli_chat_id = "cli", session_id
def _exit_on_sigint(signum, frame): def _handle_signal(signum, frame):
sig_name = signal.Signals(signum).name
_restore_terminal() _restore_terminal()
console.print("\nGoodbye!") console.print(f"\nReceived {sig_name}, goodbye!")
os._exit(0) sys.exit(0)
signal.signal(signal.SIGINT, _exit_on_sigint) signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
# SIGHUP is not available on Windows
if hasattr(signal, 'SIGHUP'):
signal.signal(signal.SIGHUP, _handle_signal)
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
# SIGPIPE is not available on Windows
if hasattr(signal, 'SIGPIPE'):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive(): async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run()) bus_task = asyncio.create_task(agent_loop.run())
@@ -599,6 +683,7 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status") @channels_app.command("status")
def channels_status(): def channels_status():
"""Show channel status.""" """Show channel status."""
from nanobot.channels.registry import discover_channel_names, load_channel_class
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
config = load_config() config = load_config()
@@ -606,85 +691,19 @@ def channels_status():
table = Table(title="Channel Status") table = Table(title="Channel Status")
table.add_column("Channel", style="cyan") table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green") table.add_column("Enabled", style="green")
table.add_column("Configuration", style="yellow")
# WhatsApp for modname in sorted(discover_channel_names()):
wa = config.channels.whatsapp section = getattr(config.channels, modname, None)
table.add_row( enabled = section and getattr(section, "enabled", False)
"WhatsApp", try:
"" if wa.enabled else "", cls = load_channel_class(modname)
wa.bridge_url display = cls.display_name
) except ImportError:
display = modname.title()
dc = config.channels.discord table.add_row(
table.add_row( display,
"Discord", "[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
"" if dc.enabled else "", )
dc.gateway_url
)
# Feishu
fs = config.channels.feishu
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
table.add_row(
"Feishu",
"" if fs.enabled else "",
fs_config
)
# Mochat
mc = config.channels.mochat
mc_base = mc.base_url or "[dim]not configured[/dim]"
table.add_row(
"Mochat",
"" if mc.enabled else "",
mc_base
)
# Telegram
tg = config.channels.telegram
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
table.add_row(
"Telegram",
"" if tg.enabled else "",
tg_config
)
# Slack
slack = config.channels.slack
slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
table.add_row(
"Slack",
"" if slack.enabled else "",
slack_config
)
# DingTalk
dt = config.channels.dingtalk
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
table.add_row(
"DingTalk",
"" if dt.enabled else "",
dt_config
)
# QQ
qq = config.channels.qq
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
table.add_row(
"QQ",
"" if qq.enabled else "",
qq_config
)
# Email
em = config.channels.email
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
table.add_row(
"Email",
"" if em.enabled else "",
em_config
)
console.print(table) console.print(table)
@@ -695,7 +714,9 @@ def _get_bridge_dir() -> Path:
import subprocess import subprocess
# User's bridge location # User's bridge location
user_bridge = Path.home() / ".nanobot" / "bridge" from nanobot.config.paths import get_bridge_install_dir
user_bridge = get_bridge_install_dir()
# Check if already built # Check if already built
if (user_bridge / "dist" / "index.js").exists(): if (user_bridge / "dist" / "index.js").exists():
@@ -753,6 +774,7 @@ def channels_login():
import subprocess import subprocess
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
from nanobot.config.paths import get_runtime_subdir
config = load_config() config = load_config()
bridge_dir = _get_bridge_dir() bridge_dir = _get_bridge_dir()
@@ -763,6 +785,7 @@ def channels_login():
env = {**os.environ} env = {**os.environ}
if config.channels.whatsapp.bridge_token: if config.channels.whatsapp.bridge_token:
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
try: try:
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env) subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
@@ -772,221 +795,6 @@ def channels_login():
console.print("[red]npm not found. Please install Node.js.[/red]") console.print("[red]npm not found. Please install Node.js.[/red]")
# ============================================================================
# Cron Commands
# ============================================================================
cron_app = typer.Typer(help="Manage scheduled tasks")
app.add_typer(cron_app, name="cron")
@cron_app.command("list")
def cron_list(
all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
):
"""List scheduled jobs."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
jobs = service.list_jobs(include_disabled=all)
if not jobs:
console.print("No scheduled jobs.")
return
table = Table(title="Scheduled Jobs")
table.add_column("ID", style="cyan")
table.add_column("Name")
table.add_column("Schedule")
table.add_column("Status")
table.add_column("Next Run")
import time
from datetime import datetime as _dt
from zoneinfo import ZoneInfo
for job in jobs:
# Format schedule
if job.schedule.kind == "every":
sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
elif job.schedule.kind == "cron":
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
else:
sched = "one-time"
# Format next run
next_run = ""
if job.state.next_run_at_ms:
ts = job.state.next_run_at_ms / 1000
try:
tz = ZoneInfo(job.schedule.tz) if job.schedule.tz else None
next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
except Exception:
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
table.add_row(job.id, job.name, sched, status, next_run)
console.print(table)
@cron_app.command("add")
def cron_add(
name: str = typer.Option(..., "--name", "-n", help="Job name"),
message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
tz: str | None = typer.Option(None, "--tz", help="IANA timezone for cron (e.g. 'America/Vancouver')"),
at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
to: str = typer.Option(None, "--to", help="Recipient for delivery"),
channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
):
"""Add a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule
if tz and not cron_expr:
console.print("[red]Error: --tz can only be used with --cron[/red]")
raise typer.Exit(1)
# Determine schedule type
if every:
schedule = CronSchedule(kind="every", every_ms=every * 1000)
elif cron_expr:
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at:
import datetime
dt = datetime.datetime.fromisoformat(at)
schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
else:
console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
raise typer.Exit(1)
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
try:
job = service.add_job(
name=name,
schedule=schedule,
message=message,
deliver=deliver,
to=to,
channel=channel,
)
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
@cron_app.command("remove")
def cron_remove(
job_id: str = typer.Argument(..., help="Job ID to remove"),
):
"""Remove a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
if service.remove_job(job_id):
console.print(f"[green]✓[/green] Removed job {job_id}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("enable")
def cron_enable(
job_id: str = typer.Argument(..., help="Job ID"),
disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
):
"""Enable or disable a job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
job = service.enable_job(job_id, enabled=not disable)
if job:
status = "disabled" if disable else "enabled"
console.print(f"[green]✓[/green] Job '{job.name}' {status}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("run")
def cron_run(
job_id: str = typer.Argument(..., help="Job ID to run"),
force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
):
"""Manually run a job."""
from loguru import logger
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
logger.disable("nanobot")
config = load_config()
provider = _make_provider(config)
bus = MessageBus()
agent_loop = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
)
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
result_holder = []
async def on_job(job: CronJob) -> str | None:
response = await agent_loop.process_direct(
job.payload.message,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
)
result_holder.append(response)
return response
service.on_job = on_job
async def run():
return await service.run_job(job_id, force=force)
if asyncio.run(run()):
console.print("[green]✓[/green] Job executed")
if result_holder:
_print_agent_response(result_holder[0], render_markdown=True)
else:
console.print(f"[red]Failed to run job {job_id}[/red]")
# ============================================================================ # ============================================================================
# Status Commands # Status Commands
# ============================================================================ # ============================================================================

View File

@@ -1,6 +1,30 @@
"""Configuration module for nanobot.""" """Configuration module for nanobot."""
from nanobot.config.loader import get_config_path, load_config from nanobot.config.loader import get_config_path, load_config
from nanobot.config.paths import (
get_bridge_install_dir,
get_cli_history_path,
get_cron_dir,
get_data_dir,
get_legacy_sessions_dir,
get_logs_dir,
get_media_dir,
get_runtime_subdir,
get_workspace_path,
)
from nanobot.config.schema import Config from nanobot.config.schema import Config
__all__ = ["Config", "load_config", "get_config_path"] __all__ = [
"Config",
"load_config",
"get_config_path",
"get_data_dir",
"get_runtime_subdir",
"get_media_dir",
"get_cron_dir",
"get_logs_dir",
"get_workspace_path",
"get_cli_history_path",
"get_bridge_install_dir",
"get_legacy_sessions_dir",
]

View File

@@ -6,17 +6,23 @@ from pathlib import Path
from nanobot.config.schema import Config from nanobot.config.schema import Config
# Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None
def set_config_path(path: Path) -> None:
"""Set the current config path (used to derive data directory)."""
global _current_config_path
_current_config_path = path
def get_config_path() -> Path: def get_config_path() -> Path:
"""Get the default configuration file path.""" """Get the configuration file path."""
if _current_config_path:
return _current_config_path
return Path.home() / ".nanobot" / "config.json" return Path.home() / ".nanobot" / "config.json"
def get_data_dir() -> Path:
"""Get the nanobot data directory."""
from nanobot.utils.helpers import get_data_path
return get_data_path()
def load_config(config_path: Path | None = None) -> Config: def load_config(config_path: Path | None = None) -> Config:
""" """
Load configuration from file or create default. Load configuration from file or create default.

55
nanobot/config/paths.py Normal file
View File

@@ -0,0 +1,55 @@
"""Runtime path helpers derived from the active config context."""
from __future__ import annotations
from pathlib import Path
from nanobot.config.loader import get_config_path
from nanobot.utils.helpers import ensure_dir
def get_data_dir() -> Path:
"""Return the instance-level runtime data directory."""
return ensure_dir(get_config_path().parent)
def get_runtime_subdir(name: str) -> Path:
"""Return a named runtime subdirectory under the instance data dir."""
return ensure_dir(get_data_dir() / name)
def get_media_dir(channel: str | None = None) -> Path:
"""Return the media directory, optionally namespaced per channel."""
base = get_runtime_subdir("media")
return ensure_dir(base / channel) if channel else base
def get_cron_dir() -> Path:
"""Return the cron storage directory."""
return get_runtime_subdir("cron")
def get_logs_dir() -> Path:
"""Return the logs directory."""
return get_runtime_subdir("logs")
def get_workspace_path(workspace: str | None = None) -> Path:
"""Resolve and ensure the agent workspace path."""
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
return ensure_dir(path)
def get_cli_history_path() -> Path:
"""Return the shared CLI history file path."""
return Path.home() / ".nanobot" / "history" / "cli_history"
def get_bridge_install_dir() -> Path:
"""Return the shared WhatsApp bridge installation directory."""
return Path.home() / ".nanobot" / "bridge"
def get_legacy_sessions_dir() -> Path:
"""Return the legacy global session directory used for migration fallback."""
return Path.home() / ".nanobot" / "sessions"

View File

@@ -29,8 +29,11 @@ class TelegramConfig(Base):
enabled: bool = False enabled: bool = False
token: str = "" # Bot token from @BotFather token: str = "" # Bot token from @BotFather
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
reply_to_message: bool = False # If true, bot replies quote the original message reply_to_message: bool = False # If true, bot replies quote the original message
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned or replied to, "open" responds to all
class FeishuConfig(Base): class FeishuConfig(Base):
@@ -42,7 +45,9 @@ class FeishuConfig(Base):
encrypt_key: str = "" # Encrypt Key for event subscription (optional) encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional) verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE) react_emoji: str = (
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
)
class DingTalkConfig(Base): class DingTalkConfig(Base):
@@ -62,6 +67,7 @@ class DiscordConfig(Base):
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
group_policy: Literal["mention", "open"] = "mention"
class MatrixConfig(Base): class MatrixConfig(Base):
@@ -72,9 +78,13 @@ class MatrixConfig(Base):
access_token: str = "" access_token: str = ""
user_id: str = "" # @bot:matrix.org user_id: str = "" # @bot:matrix.org
device_id: str = "" device_id: str = ""
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling). e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback. sync_stop_grace_seconds: int = (
max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound). 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
)
max_media_bytes: int = (
20 * 1024 * 1024
) # Max attachment size accepted for Matrix media handling (inbound + outbound).
allow_from: list[str] = Field(default_factory=list) allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open" group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list) group_allow_from: list[str] = Field(default_factory=list)
@@ -105,7 +115,9 @@ class EmailConfig(Base):
from_address: str = "" from_address: str = ""
# Behavior # Behavior
auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent auto_reply_enabled: bool = (
True # If false, inbound email is read but no automatic reply is sent
)
poll_interval_seconds: int = 30 poll_interval_seconds: int = 30
mark_seen: bool = True mark_seen: bool = True
max_body_chars: int = 12000 max_body_chars: int = 12000
@@ -183,27 +195,25 @@ class QQConfig(Base):
enabled: bool = False enabled: bool = False
app_id: str = "" # 机器人 ID (AppID) from q.qq.com app_id: str = "" # 机器人 ID (AppID) from q.qq.com
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access) allow_from: list[str] = Field(
default_factory=list
) # Allowed user openids (empty = public access)
class WecomConfig(Base):
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
class MatrixConfig(Base):
"""Matrix (Element) channel configuration."""
enabled: bool = False enabled: bool = False
homeserver: str = "https://matrix.org" bot_id: str = "" # Bot ID from WeCom AI Bot platform
access_token: str = "" secret: str = "" # Bot Secret from WeCom AI Bot platform
user_id: str = "" # e.g. @bot:matrix.org allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
device_id: str = "" welcome_message: str = "" # Welcome message for enter_chat event
e2ee_enabled: bool = True # end-to-end encryption support
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
class ChannelsConfig(Base): class ChannelsConfig(Base):
"""Configuration for chat channels.""" """Configuration for chat channels."""
send_progress: bool = True # stream agent's text progress to the channel send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…")) send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig) whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
telegram: TelegramConfig = Field(default_factory=TelegramConfig) telegram: TelegramConfig = Field(default_factory=TelegramConfig)
@@ -215,6 +225,7 @@ class ChannelsConfig(Base):
slack: SlackConfig = Field(default_factory=SlackConfig) slack: SlackConfig = Field(default_factory=SlackConfig)
qq: QQConfig = Field(default_factory=QQConfig) qq: QQConfig = Field(default_factory=QQConfig)
matrix: MatrixConfig = Field(default_factory=MatrixConfig) matrix: MatrixConfig = Field(default_factory=MatrixConfig)
wecom: WecomConfig = Field(default_factory=WecomConfig)
class AgentDefaults(Base): class AgentDefaults(Base):
@@ -222,13 +233,22 @@ class AgentDefaults(Base):
workspace: str = "~/.nanobot/workspace" workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5" model: str = "anthropic/claude-opus-4-5"
provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection provider: str = (
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
)
max_tokens: int = 8192 max_tokens: int = 8192
context_window_tokens: int = 65_536
temperature: float = 0.1 temperature: float = 0.1
max_tool_iterations: int = 40 max_tool_iterations: int = 40
memory_window: int = 100 # Deprecated compatibility field: accepted from old configs but ignored at runtime.
memory_window: int | None = Field(default=None, exclude=True)
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
@property
def should_warn_deprecated_memory_window(self) -> bool:
"""Return True when old memoryWindow is present without contextWindowTokens."""
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
class AgentsConfig(Base): class AgentsConfig(Base):
"""Agent configuration.""" """Agent configuration."""
@@ -248,6 +268,7 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers.""" """Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig) anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig) openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig) openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -260,8 +281,9 @@ class ProvidersConfig(Base):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -291,7 +313,9 @@ class WebSearchConfig(Base):
class WebToolsConfig(Base): class WebToolsConfig(Base):
"""Web tools configuration.""" """Web tools configuration."""
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
search: WebSearchConfig = Field(default_factory=WebSearchConfig) search: WebSearchConfig = Field(default_factory=WebSearchConfig)
@@ -305,12 +329,13 @@ class ExecToolConfig(Base):
class MCPServerConfig(Base): class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP).""" """MCP server connection configuration (stdio or HTTP)."""
type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
command: str = "" # Stdio: command to run (e.g. "npx") command: str = "" # Stdio: command to run (e.g. "npx")
args: list[str] = Field(default_factory=list) # Stdio: command arguments args: list[str] = Field(default_factory=list) # Stdio: command arguments
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
url: str = "" # HTTP: streamable HTTP endpoint URL url: str = "" # HTTP/SSE: endpoint URL
headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
tool_timeout: int = 30 # Seconds before a tool call is cancelled tool_timeout: int = 30 # seconds before a tool call is cancelled
class ToolsConfig(Base): class ToolsConfig(Base):
@@ -336,7 +361,9 @@ class Config(BaseSettings):
"""Get expanded workspace path.""" """Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser() return Path(self.agents.defaults.workspace).expanduser()
def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]: def _match_provider(
self, model: str | None = None
) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name).""" """Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS
@@ -358,16 +385,25 @@ class Config(BaseSettings):
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and model_prefix and normalized_prefix == spec.name: if p and model_prefix and normalized_prefix == spec.name:
if spec.is_oauth or p.api_key: if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name return p, spec.name
# Match by keyword (order follows PROVIDERS registry) # Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and any(_kw_matches(kw) for kw in spec.keywords): if p and any(_kw_matches(kw) for kw in spec.keywords):
if spec.is_oauth or p.api_key: if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name return p, spec.name
# Fallback: configured local providers can route models without
# provider-specific keywords (for example plain "llama3.2" on Ollama).
for spec in PROVIDERS:
if not spec.is_local:
continue
p = getattr(self.providers, spec.name, None)
if p and p.api_base:
return p, spec.name
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection # OAuth providers are NOT valid fallbacks — they require explicit model selection
for spec in PROVIDERS: for spec in PROVIDERS:
@@ -394,7 +430,7 @@ class Config(BaseSettings):
return p.api_key if p else None return p.api_key if p else None
def get_api_base(self, model: str | None = None) -> str | None: def get_api_base(self, model: str | None = None) -> str | None:
"""Get API base URL for the given model. Applies default URLs for known gateways.""" """Get API base URL for the given model. Applies default URLs for gateway/local providers."""
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model) p, name = self._match_provider(model)
@@ -405,7 +441,7 @@ class Config(BaseSettings):
# to avoid polluting the global litellm.api_base. # to avoid polluting the global litellm.api_base.
if name: if name:
spec = find_by_name(name) spec = find_by_name(name)
if spec and spec.is_gateway and spec.default_api_base: if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
return spec.default_api_base return spec.default_api_base
return None return None

View File

@@ -87,7 +87,7 @@ class HeartbeatService:
Returns (action, tasks) where action is 'skip' or 'run'. Returns (action, tasks) where action is 'skip' or 'run'.
""" """
response = await self.provider.chat( response = await self.provider.chat_with_retry(
messages=[ messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": ( {"role": "user", "content": (

View File

@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"] __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]

View File

@@ -0,0 +1,210 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model

View File

@@ -1,9 +1,13 @@
"""Base LLM provider interface.""" """Base LLM provider interface."""
import asyncio
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from loguru import logger
@dataclass @dataclass
class ToolCallRequest: class ToolCallRequest:
@@ -11,6 +15,24 @@ class ToolCallRequest:
id: str id: str
name: str name: str
arguments: dict[str, Any] arguments: dict[str, Any]
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.provider_specific_fields:
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
return tool_call
@dataclass @dataclass
@@ -29,6 +51,21 @@ class LLMResponse:
return len(self.tool_calls) > 0 return len(self.tool_calls) > 0
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
class LLMProvider(ABC): class LLMProvider(ABC):
""" """
Abstract base class for LLM providers. Abstract base class for LLM providers.
@@ -37,9 +74,28 @@ class LLMProvider(ABC):
while maintaining a consistent interface. while maintaining a consistent interface.
""" """
_CHAT_RETRY_DELAYS = (1, 2, 4)
_TRANSIENT_ERROR_MARKERS = (
"429",
"rate limit",
"500",
"502",
"503",
"504",
"overloaded",
"timeout",
"timed out",
"connection",
"server error",
"temporarily unavailable",
)
_SENTINEL = object()
def __init__(self, api_key: str | None = None, api_base: str | None = None): def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
self.generation: GenerationSettings = GenerationSettings()
@staticmethod @staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -87,6 +143,20 @@ class LLMProvider(ABC):
result.append(msg) result.append(msg)
return result return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,
@@ -112,6 +182,83 @@ class LLMProvider(ABC):
""" """
pass pass
@classmethod
def _is_transient_error(cls, content: str | None) -> bool:
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
async def chat_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = _SENTINEL,
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
) -> LLMResponse:
"""Call chat() with retry on transient provider failures.
Parameters default to ``self.generation`` when not explicitly passed,
so callers no longer need to thread temperature / max_tokens /
reasoning_effort through every layer.
"""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try:
response = await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
except asyncio.CancelledError:
raise
except Exception as exc:
response = LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
return response
err = (response.content or "").lower()
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt,
len(self._CHAT_RETRY_DELAYS),
delay,
err[:120],
)
await asyncio.sleep(delay)
try:
return await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
@abstractmethod @abstractmethod
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model for this provider.""" """Get the default model for this provider."""

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import uuid
from typing import Any from typing import Any
import json_repair import json_repair
@@ -15,7 +16,12 @@ class CustomProvider(LLMProvider):
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"): def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base) # Keep affinity stable for this provider instance to improve backend cache locality.
self._client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
default_headers={"x-session-affinity": uuid.uuid4().hex},
)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,

View File

@@ -1,5 +1,6 @@
"""LiteLLM provider implementation for multi-provider support.""" """LiteLLM provider implementation for multi-provider support."""
import hashlib
import os import os
import secrets import secrets
import string import string
@@ -8,6 +9,7 @@ from typing import Any
import json_repair import json_repair
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway from nanobot.providers.registry import find_by_model, find_gateway
@@ -165,17 +167,43 @@ class LiteLLMProvider(LLMProvider):
return _ANTHROPIC_EXTRA_KEYS return _ANTHROPIC_EXTRA_KEYS
return frozenset() return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod @staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key.""" """Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = [] sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
for msg in messages: id_map: dict[str, str] = {}
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls def map_id(value: Any) -> Any:
if clean.get("role") == "assistant" and "content" not in clean: if not isinstance(value, str):
clean["content"] = None return value
sanitized.append(clean) return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized return sanitized
async def chat( async def chat(
@@ -255,20 +283,44 @@ class LiteLLMProvider(LLMProvider):
"""Parse LiteLLM response into our standard format.""" """Parse LiteLLM response into our standard format."""
choice = response.choices[0] choice = response.choices[0]
message = choice.message message = choice.message
content = message.content
finish_reason = choice.finish_reason
# Some providers (e.g. GitHub Copilot) split content and tool_calls
# across multiple choices. Merge them so tool_calls are not lost.
raw_tool_calls = []
for ch in response.choices:
msg = ch.message
if hasattr(msg, "tool_calls") and msg.tool_calls:
raw_tool_calls.extend(msg.tool_calls)
if ch.finish_reason in ("tool_calls", "stop"):
finish_reason = ch.finish_reason
if not content and msg.content:
content = msg.content
if len(response.choices) > 1:
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
len(response.choices), len(raw_tool_calls))
tool_calls = [] tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls: for tc in raw_tool_calls:
for tc in message.tool_calls: # Parse arguments from JSON string if needed
# Parse arguments from JSON string if needed args = tc.function.arguments
args = tc.function.arguments if isinstance(args, str):
if isinstance(args, str): args = json_repair.loads(args)
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest( provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
id=_short_tool_id(), function_provider_specific_fields = (
name=tc.function.name, getattr(tc.function, "provider_specific_fields", None) or None
arguments=args, )
))
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
))
usage = {} usage = {}
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
@@ -282,9 +334,9 @@ class LiteLLMProvider(LLMProvider):
thinking_blocks = getattr(message, "thinking_blocks", None) or None thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=content,
tool_calls=tool_calls, tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop", finish_reason=finish_reason or "stop",
usage=usage, usage=usage,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks, thinking_blocks=thinking_blocks,

View File

@@ -52,6 +52,9 @@ class OpenAICodexProvider(LLMProvider):
"parallel_tool_calls": True, "parallel_tool_calls": True,
} }
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools: if tools:
body["tools"] = _convert_tools(tools) body["tools"] = _convert_tools(tools)

View File

@@ -26,33 +26,33 @@ class ProviderSpec:
""" """
# identity # identity
name: str # config field name, e.g. "dashscope" name: str # config field name, e.g. "dashscope"
keywords: tuple[str, ...] # model-name keywords for matching (lowercase) keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY" env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
display_name: str = "" # shown in `nanobot status` display_name: str = "" # shown in `nanobot status`
# model prefixing # model prefixing
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}" litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = () env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection # gateway / local detection
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix) is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
is_local: bool = False # local deployment (vLLM, Ollama) is_local: bool = False # local deployment (vLLM, Ollama)
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-" detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
detect_by_base_keyword: str = "" # match substring in api_base URL detect_by_base_keyword: str = "" # match substring in api_base URL
default_api_base: str = "" # fallback base URL default_api_base: str = "" # fallback base URL
# gateway behavior # gateway behavior
strip_model_prefix: bool = False # strip "provider/" before re-prefixing strip_model_prefix: bool = False # strip "provider/" before re-prefixing
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys # OAuth-based providers (e.g., OpenAI Codex) don't use API keys
is_oauth: bool = False # if True, uses OAuth flow instead of API key is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider) # Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False is_direct: bool = False
@@ -70,7 +70,6 @@ class ProviderSpec:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = ( PROVIDERS: tuple[ProviderSpec, ...] = (
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ====== # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec( ProviderSpec(
name="custom", name="custom",
@@ -81,16 +80,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
is_direct=True, is_direct=True,
), ),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) ========= # === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback. # Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-" # OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec( ProviderSpec(
name="openrouter", name="openrouter",
keywords=("openrouter",), keywords=("openrouter",),
env_key="OPENROUTER_API_KEY", env_key="OPENROUTER_API_KEY",
display_name="OpenRouter", display_name="OpenRouter",
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3 litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=True, is_gateway=True,
@@ -102,16 +109,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(), model_overrides=(),
supports_prompt_caching=True, supports_prompt_caching=True,
), ),
# AiHubMix: global gateway, OpenAI-compatible interface. # AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3", # strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3". # so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec( ProviderSpec(
name="aihubmix", name="aihubmix",
keywords=("aihubmix",), keywords=("aihubmix",),
env_key="OPENAI_API_KEY", # OpenAI-compatible env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix", display_name="AiHubMix",
litellm_prefix="openai", # → openai/{model} litellm_prefix="openai", # → openai/{model}
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=True, is_gateway=True,
@@ -119,10 +125,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="aihubmix", detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1", default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(), model_overrides=(),
), ),
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix # SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec( ProviderSpec(
name="siliconflow", name="siliconflow",
@@ -140,7 +145,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine (火山引擎): OpenAI-compatible gateway # VolcEngine (火山引擎): OpenAI-compatible gateway
ProviderSpec( ProviderSpec(
name="volcengine", name="volcengine",
@@ -158,9 +162,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Standard providers (matched by model-name keywords) =============== # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="anthropic", name="anthropic",
@@ -179,7 +181,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(), model_overrides=(),
supports_prompt_caching=True, supports_prompt_caching=True,
), ),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="openai", name="openai",
@@ -197,14 +198,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# OpenAI Codex: uses OAuth, not API key. # OpenAI Codex: uses OAuth, not API key.
ProviderSpec( ProviderSpec(
name="openai_codex", name="openai_codex",
keywords=("openai-codex",), keywords=("openai-codex",),
env_key="", # OAuth-based, no API key env_key="", # OAuth-based, no API key
display_name="OpenAI Codex", display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -214,16 +214,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api", default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
is_oauth=True, # OAuth-based authentication is_oauth=True, # OAuth-based authentication
), ),
# Github Copilot: uses OAuth, not API key. # Github Copilot: uses OAuth, not API key.
ProviderSpec( ProviderSpec(
name="github_copilot", name="github_copilot",
keywords=("github_copilot", "copilot"), keywords=("github_copilot", "copilot"),
env_key="", # OAuth-based, no API key env_key="", # OAuth-based, no API key
display_name="Github Copilot", display_name="Github Copilot",
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",), skip_prefixes=("github_copilot/",),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -233,17 +232,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="", default_api_base="",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
is_oauth=True, # OAuth-based authentication is_oauth=True, # OAuth-based authentication
), ),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing. # DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec( ProviderSpec(
name="deepseek", name="deepseek",
keywords=("deepseek",), keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY", env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek", display_name="DeepSeek",
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
skip_prefixes=("deepseek/",), # avoid double-prefix skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -253,15 +251,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Gemini: needs "gemini/" prefix for LiteLLM. # Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec( ProviderSpec(
name="gemini", name="gemini",
keywords=("gemini",), keywords=("gemini",),
env_key="GEMINI_API_KEY", env_key="GEMINI_API_KEY",
display_name="Gemini", display_name="Gemini",
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -271,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Zhipu: LiteLLM uses "zai/" prefix. # Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway. # skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +276,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"), keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY", env_key="ZAI_API_KEY",
display_name="Zhipu AI", display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4 litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=( env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
("ZHIPUAI_API_KEY", "{api_key}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
@@ -293,14 +287,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# DashScope: Qwen models, needs "dashscope/" prefix. # DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec( ProviderSpec(
name="dashscope", name="dashscope",
keywords=("qwen", "dashscope"), keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY", env_key="DASHSCOPE_API_KEY",
display_name="DashScope", display_name="DashScope",
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"), skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -311,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Moonshot: Kimi models, needs "moonshot/" prefix. # Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0. # Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +312,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"), keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY", env_key="MOONSHOT_API_KEY",
display_name="Moonshot", display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"), skip_prefixes=("moonshot/", "openrouter/"),
env_extras=( env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
("MOONSHOT_API_BASE", "{api_base}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="", detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=( model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
("kimi-k2.5", {"temperature": 1.0}),
),
), ),
# MiniMax: needs "minimax/" prefix for LiteLLM routing. # MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1. # Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec( ProviderSpec(
@@ -343,7 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",), keywords=("minimax",),
env_key="MINIMAX_API_KEY", env_key="MINIMAX_API_KEY",
display_name="MiniMax", display_name="MiniMax",
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1 litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"), skip_prefixes=("minimax/", "openrouter/"),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
@@ -354,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Local deployment (matched by config key, NOT by api_base) ========= # === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server. # vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm"). # Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec( ProviderSpec(
@@ -364,20 +349,35 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",), keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY", env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local", display_name="vLLM/Local",
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
skip_prefixes=(), skip_prefixes=(),
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=True, is_local=True,
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="", detect_by_base_keyword="",
default_api_base="", # user must provide in config default_api_base="", # user must provide in config
strip_model_prefix=False,
model_overrides=(),
),
# === Ollama (local, OpenAI-compatible) ===================================
ProviderSpec(
name="ollama",
keywords=("ollama", "nemotron"),
env_key="OLLAMA_API_KEY",
display_name="Ollama",
litellm_prefix="ollama_chat", # model → ollama_chat/model
skip_prefixes=("ollama/", "ollama_chat/"),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="11434",
default_api_base="http://localhost:11434",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Auxiliary (not a primary LLM provider) ============================ # === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM. # Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec( ProviderSpec(
@@ -385,8 +385,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",), keywords=("groq",),
env_key="GROQ_API_KEY", env_key="GROQ_API_KEY",
display_name="Groq", display_name="Groq",
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192 litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(), env_extras=(),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
@@ -403,6 +403,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers # Lookup helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def find_by_model(model: str) -> ProviderSpec | None: def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive). """Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead.""" Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -418,7 +419,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec return spec
for spec in std_specs: for spec in std_specs:
if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords): if any(
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
):
return spec return spec
return None return None

View File

@@ -9,6 +9,7 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename from nanobot.utils.helpers import ensure_dir, safe_filename
@@ -79,7 +80,7 @@ class SessionManager:
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.workspace = workspace self.workspace = workspace
self.sessions_dir = ensure_dir(self.workspace / "sessions") self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" self.legacy_sessions_dir = get_legacy_sessions_dir()
self._cache: dict[str, Session] = {} self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path: def _get_session_path(self, key: str) -> Path:

View File

@@ -9,15 +9,21 @@ always: true
## Structure ## Structure
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context. - `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM]. - `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep-style tools or in-memory filters. Each entry starts with [YYYY-MM-DD HH:MM].
## Search Past Events ## Search Past Events
```bash Choose the search method based on file size:
grep -i "keyword" memory/HISTORY.md
```
Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md` - Small `memory/HISTORY.md`: use `read_file`, then search in-memory
- Large or long-lived `memory/HISTORY.md`: use the `exec` tool for targeted search
Examples:
- **Linux/macOS:** `grep -i "keyword" memory/HISTORY.md`
- **Windows:** `findstr /i "keyword" memory\HISTORY.md`
- **Cross-platform Python:** `python -c "from pathlib import Path; text = Path('memory/HISTORY.md').read_text(encoding='utf-8'); print('\n'.join([l for l in text.splitlines() if 'keyword' in l.lower()][-20:]))"`
Prefer targeted command-line search for large history files.
## When to Update MEMORY.md ## When to Update MEMORY.md

View File

@@ -268,6 +268,8 @@ Skip this step only if the skill being developed already exists, and iteration o
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable. When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
For `nanobot`, custom skills should live under the active workspace `skills/` directory so they can be discovered automatically at runtime (for example, `<workspace>/skills/my-skill/SKILL.md`).
Usage: Usage:
```bash ```bash
@@ -277,9 +279,9 @@ scripts/init_skill.py <skill-name> --path <output-directory> [--resources script
Examples: Examples:
```bash ```bash
scripts/init_skill.py my-skill --path skills/public scripts/init_skill.py my-skill --path ./workspace/skills
scripts/init_skill.py my-skill --path skills/public --resources scripts,references scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts,references
scripts/init_skill.py my-skill --path skills/public --resources scripts --examples scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts --examples
``` ```
The script: The script:
@@ -326,7 +328,7 @@ Write the YAML frontmatter with `name` and `description`:
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent. - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
Do not include any other fields in YAML frontmatter. Keep frontmatter minimal. In `nanobot`, `metadata` and `always` are also supported when needed, but avoid adding extra fields unless they are actually required.
##### Body ##### Body
@@ -349,7 +351,6 @@ scripts/package_skill.py <path/to/skill-folder> ./dist
The packaging script will: The packaging script will:
1. **Validate** the skill automatically, checking: 1. **Validate** the skill automatically, checking:
- YAML frontmatter format and required fields - YAML frontmatter format and required fields
- Skill naming conventions and directory structure - Skill naming conventions and directory structure
- Description completeness and quality - Description completeness and quality
@@ -357,6 +358,8 @@ The packaging script will:
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension. 2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
Security restriction: symlinks are rejected and packaging fails when any symlink is present.
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again. If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
### Step 6: Iterate ### Step 6: Iterate

View File

@@ -0,0 +1,378 @@
#!/usr/bin/env python3
"""
Skill Initializer - Creates a new skill from template
Usage:
init_skill.py <skill-name> --path <path> [--resources scripts,references,assets] [--examples]
Examples:
init_skill.py my-new-skill --path skills/public
init_skill.py my-new-skill --path skills/public --resources scripts,references
init_skill.py my-api-helper --path skills/private --resources scripts --examples
init_skill.py custom-skill --path /custom/location
"""
import argparse
import re
import sys
from pathlib import Path
MAX_SKILL_NAME_LENGTH = 64
ALLOWED_RESOURCES = {"scripts", "references", "assets"}
SKILL_TEMPLATE = """---
name: {skill_name}
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
---
# {skill_title}
## Overview
[TODO: 1-2 sentences explaining what this skill enables]
## Structuring This Skill
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
**1. Workflow-Based** (best for sequential processes)
- Works well when there are clear step-by-step procedures
- Example: DOCX skill with "Workflow Decision Tree" -> "Reading" -> "Creating" -> "Editing"
- Structure: ## Overview -> ## Workflow Decision Tree -> ## Step 1 -> ## Step 2...
**2. Task-Based** (best for tool collections)
- Works well when the skill offers different operations/capabilities
- Example: PDF skill with "Quick Start" -> "Merge PDFs" -> "Split PDFs" -> "Extract Text"
- Structure: ## Overview -> ## Quick Start -> ## Task Category 1 -> ## Task Category 2...
**3. Reference/Guidelines** (best for standards or specifications)
- Works well for brand guidelines, coding standards, or requirements
- Example: Brand styling with "Brand Guidelines" -> "Colors" -> "Typography" -> "Features"
- Structure: ## Overview -> ## Guidelines -> ## Specifications -> ## Usage...
**4. Capabilities-Based** (best for integrated systems)
- Works well when the skill provides multiple interrelated features
- Example: Product Management with "Core Capabilities" -> numbered capability list
- Structure: ## Overview -> ## Core Capabilities -> ### 1. Feature -> ### 2. Feature...
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
## [TODO: Replace with the first main section based on chosen structure]
[TODO: Add content here. See examples in existing skills:
- Code samples for technical skills
- Decision trees for complex workflows
- Concrete examples with realistic user requests
- References to scripts/templates/references as needed]
## Resources (optional)
Create only the resource directories this skill actually needs. Delete this section if no resources are required.
### scripts/
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
**Examples from other skills:**
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
**Note:** Scripts may be executed without loading into context, but can still be read by Codex for patching or environment adjustments.
### references/
Documentation and reference material intended to be loaded into context to inform Codex's process and thinking.
**Examples from other skills:**
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
- BigQuery: API reference documentation and query examples
- Finance: Schema documentation, company policies
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Codex should reference while working.
### assets/
Files not intended to be loaded into context, but rather used within the output Codex produces.
**Examples from other skills:**
- Brand styling: PowerPoint template files (.pptx), logo files
- Frontend builder: HTML/React boilerplate project directories
- Typography: Font files (.ttf, .woff2)
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
---
**Not every skill requires all three types of resources.**
"""
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
"""
Example helper script for {skill_name}
This is a placeholder script that can be executed directly.
Replace with actual implementation or delete if not needed.
Example real scripts from other skills:
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
"""
def main():
print("This is an example script for {skill_name}")
# TODO: Add actual script logic here
# This could be data processing, file conversion, API calls, etc.
if __name__ == "__main__":
main()
'''
EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
This is a placeholder for detailed reference documentation.
Replace with actual reference content or delete if not needed.
Example real reference docs from other skills:
- product-management/references/communication.md - Comprehensive guide for status updates
- product-management/references/context_building.md - Deep-dive on gathering context
- bigquery/references/ - API references and query examples
## When Reference Docs Are Useful
Reference docs are ideal for:
- Comprehensive API documentation
- Detailed workflow guides
- Complex multi-step processes
- Information too lengthy for main SKILL.md
- Content that's only needed for specific use cases
## Structure Suggestions
### API Reference Example
- Overview
- Authentication
- Endpoints with examples
- Error codes
- Rate limits
### Workflow Guide Example
- Prerequisites
- Step-by-step instructions
- Common patterns
- Troubleshooting
- Best practices
"""
EXAMPLE_ASSET = """# Example Asset File
This placeholder represents where asset files would be stored.
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
Asset files are NOT intended to be loaded into context, but rather used within
the output Codex produces.
Example asset files from other skills:
- Brand guidelines: logo.png, slides_template.pptx
- Frontend builder: hello-world/ directory with HTML/React boilerplate
- Typography: custom-font.ttf, font-family.woff2
- Data: sample_data.csv, test_dataset.json
## Common Asset Types
- Templates: .pptx, .docx, boilerplate directories
- Images: .png, .jpg, .svg, .gif
- Fonts: .ttf, .otf, .woff, .woff2
- Boilerplate code: Project directories, starter files
- Icons: .ico, .svg
- Data files: .csv, .json, .xml, .yaml
Note: This is a text placeholder. Actual assets can be any file type.
"""
def normalize_skill_name(skill_name):
"""Normalize a skill name to lowercase hyphen-case."""
normalized = skill_name.strip().lower()
normalized = re.sub(r"[^a-z0-9]+", "-", normalized)
normalized = normalized.strip("-")
normalized = re.sub(r"-{2,}", "-", normalized)
return normalized
def title_case_skill_name(skill_name):
"""Convert hyphenated skill name to Title Case for display."""
return " ".join(word.capitalize() for word in skill_name.split("-"))
def parse_resources(raw_resources):
if not raw_resources:
return []
resources = [item.strip() for item in raw_resources.split(",") if item.strip()]
invalid = sorted({item for item in resources if item not in ALLOWED_RESOURCES})
if invalid:
allowed = ", ".join(sorted(ALLOWED_RESOURCES))
print(f"[ERROR] Unknown resource type(s): {', '.join(invalid)}")
print(f" Allowed: {allowed}")
sys.exit(1)
deduped = []
seen = set()
for resource in resources:
if resource not in seen:
deduped.append(resource)
seen.add(resource)
return deduped
def create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples):
for resource in resources:
resource_dir = skill_dir / resource
resource_dir.mkdir(exist_ok=True)
if resource == "scripts":
if include_examples:
example_script = resource_dir / "example.py"
example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
example_script.chmod(0o755)
print("[OK] Created scripts/example.py")
else:
print("[OK] Created scripts/")
elif resource == "references":
if include_examples:
example_reference = resource_dir / "api_reference.md"
example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
print("[OK] Created references/api_reference.md")
else:
print("[OK] Created references/")
elif resource == "assets":
if include_examples:
example_asset = resource_dir / "example_asset.txt"
example_asset.write_text(EXAMPLE_ASSET)
print("[OK] Created assets/example_asset.txt")
else:
print("[OK] Created assets/")
def init_skill(skill_name, path, resources, include_examples):
"""
Initialize a new skill directory with template SKILL.md.
Args:
skill_name: Name of the skill
path: Path where the skill directory should be created
resources: Resource directories to create
include_examples: Whether to create example files in resource directories
Returns:
Path to created skill directory, or None if error
"""
# Determine skill directory path
skill_dir = Path(path).resolve() / skill_name
# Check if directory already exists
if skill_dir.exists():
print(f"[ERROR] Skill directory already exists: {skill_dir}")
return None
# Create skill directory
try:
skill_dir.mkdir(parents=True, exist_ok=False)
print(f"[OK] Created skill directory: {skill_dir}")
except Exception as e:
print(f"[ERROR] Error creating directory: {e}")
return None
# Create SKILL.md from template
skill_title = title_case_skill_name(skill_name)
skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)
skill_md_path = skill_dir / "SKILL.md"
try:
skill_md_path.write_text(skill_content)
print("[OK] Created SKILL.md")
except Exception as e:
print(f"[ERROR] Error creating SKILL.md: {e}")
return None
# Create resource directories if requested
if resources:
try:
create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples)
except Exception as e:
print(f"[ERROR] Error creating resource directories: {e}")
return None
# Print next steps
print(f"\n[OK] Skill '{skill_name}' initialized successfully at {skill_dir}")
print("\nNext steps:")
print("1. Edit SKILL.md to complete the TODO items and update the description")
if resources:
if include_examples:
print("2. Customize or delete the example files in scripts/, references/, and assets/")
else:
print("2. Add resources to scripts/, references/, and assets/ as needed")
else:
print("2. Create resource directories only if needed (scripts/, references/, assets/)")
print("3. Run the validator when ready to check the skill structure")
return skill_dir
def main():
parser = argparse.ArgumentParser(
description="Create a new skill directory with a SKILL.md template.",
)
parser.add_argument("skill_name", help="Skill name (normalized to hyphen-case)")
parser.add_argument("--path", required=True, help="Output directory for the skill")
parser.add_argument(
"--resources",
default="",
help="Comma-separated list: scripts,references,assets",
)
parser.add_argument(
"--examples",
action="store_true",
help="Create example files inside the selected resource directories",
)
args = parser.parse_args()
raw_skill_name = args.skill_name
skill_name = normalize_skill_name(raw_skill_name)
if not skill_name:
print("[ERROR] Skill name must include at least one letter or digit.")
sys.exit(1)
if len(skill_name) > MAX_SKILL_NAME_LENGTH:
print(
f"[ERROR] Skill name '{skill_name}' is too long ({len(skill_name)} characters). "
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
)
sys.exit(1)
if skill_name != raw_skill_name:
print(f"Note: Normalized skill name from '{raw_skill_name}' to '{skill_name}'.")
resources = parse_resources(args.resources)
if args.examples and not resources:
print("[ERROR] --examples requires --resources to be set.")
sys.exit(1)
path = args.path
print(f"Initializing skill: {skill_name}")
print(f" Location: {path}")
if resources:
print(f" Resources: {', '.join(resources)}")
if args.examples:
print(" Examples: enabled")
else:
print(" Resources: none (create as needed)")
print()
result = init_skill(skill_name, path, resources, args.examples)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
Skill Packager - Creates a distributable .skill file of a skill folder
Usage:
python package_skill.py <path/to/skill-folder> [output-directory]
Example:
python package_skill.py skills/public/my-skill
python package_skill.py skills/public/my-skill ./dist
"""
import sys
import zipfile
from pathlib import Path
from quick_validate import validate_skill
def _is_within(path: Path, root: Path) -> bool:
try:
path.relative_to(root)
return True
except ValueError:
return False
def _cleanup_partial_archive(skill_filename: Path) -> None:
try:
if skill_filename.exists():
skill_filename.unlink()
except OSError:
pass
def package_skill(skill_path, output_dir=None):
"""
Package a skill folder into a .skill file.
Args:
skill_path: Path to the skill folder
output_dir: Optional output directory for the .skill file (defaults to current directory)
Returns:
Path to the created .skill file, or None if error
"""
skill_path = Path(skill_path).resolve()
# Validate skill folder exists
if not skill_path.exists():
print(f"[ERROR] Skill folder not found: {skill_path}")
return None
if not skill_path.is_dir():
print(f"[ERROR] Path is not a directory: {skill_path}")
return None
# Validate SKILL.md exists
skill_md = skill_path / "SKILL.md"
if not skill_md.exists():
print(f"[ERROR] SKILL.md not found in {skill_path}")
return None
# Run validation before packaging
print("Validating skill...")
valid, message = validate_skill(skill_path)
if not valid:
print(f"[ERROR] Validation failed: {message}")
print(" Please fix the validation errors before packaging.")
return None
print(f"[OK] {message}\n")
# Determine output location
skill_name = skill_path.name
if output_dir:
output_path = Path(output_dir).resolve()
output_path.mkdir(parents=True, exist_ok=True)
else:
output_path = Path.cwd()
skill_filename = output_path / f"{skill_name}.skill"
EXCLUDED_DIRS = {".git", ".svn", ".hg", "__pycache__", "node_modules"}
files_to_package = []
resolved_archive = skill_filename.resolve()
for file_path in skill_path.rglob("*"):
# Fail closed on symlinks so the packaged contents are explicit and predictable.
if file_path.is_symlink():
print(f"[ERROR] Symlink not allowed in packaged skill: {file_path}")
_cleanup_partial_archive(skill_filename)
return None
rel_parts = file_path.relative_to(skill_path).parts
if any(part in EXCLUDED_DIRS for part in rel_parts):
continue
if file_path.is_file():
resolved_file = file_path.resolve()
if not _is_within(resolved_file, skill_path):
print(f"[ERROR] File escapes skill root: {file_path}")
_cleanup_partial_archive(skill_filename)
return None
# If output lives under skill_path, avoid writing archive into itself.
if resolved_file == resolved_archive:
print(f"[WARN] Skipping output archive: {file_path}")
continue
files_to_package.append(file_path)
# Create the .skill file (zip format)
try:
with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
for file_path in files_to_package:
# Calculate the relative path within the zip.
arcname = Path(skill_name) / file_path.relative_to(skill_path)
zipf.write(file_path, arcname)
print(f" Added: {arcname}")
print(f"\n[OK] Successfully packaged skill to: {skill_filename}")
return skill_filename
except Exception as e:
_cleanup_partial_archive(skill_filename)
print(f"[ERROR] Error creating .skill file: {e}")
return None
def main():
if len(sys.argv) < 2:
print("Usage: python package_skill.py <path/to/skill-folder> [output-directory]")
print("\nExample:")
print(" python package_skill.py skills/public/my-skill")
print(" python package_skill.py skills/public/my-skill ./dist")
sys.exit(1)
skill_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else None
print(f"Packaging skill: {skill_path}")
if output_dir:
print(f" Output directory: {output_dir}")
print()
result = package_skill(skill_path, output_dir)
if result:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,213 @@
#!/usr/bin/env python3
"""
Minimal validator for nanobot skill folders.
"""
import re
import sys
from pathlib import Path
from typing import Optional
try:
import yaml
except ModuleNotFoundError:
yaml = None
MAX_SKILL_NAME_LENGTH = 64
ALLOWED_FRONTMATTER_KEYS = {
"name",
"description",
"metadata",
"always",
"license",
"allowed-tools",
}
ALLOWED_RESOURCE_DIRS = {"scripts", "references", "assets"}
PLACEHOLDER_MARKERS = ("[todo", "todo:")
def _extract_frontmatter(content: str) -> Optional[str]:
lines = content.splitlines()
if not lines or lines[0].strip() != "---":
return None
for i in range(1, len(lines)):
if lines[i].strip() == "---":
return "\n".join(lines[1:i])
return None
def _parse_simple_frontmatter(frontmatter_text: str) -> Optional[dict[str, str]]:
"""Fallback parser for simple frontmatter when PyYAML is unavailable."""
parsed: dict[str, str] = {}
current_key: Optional[str] = None
multiline_key: Optional[str] = None
for raw_line in frontmatter_text.splitlines():
stripped = raw_line.strip()
if not stripped or stripped.startswith("#"):
continue
is_indented = raw_line[:1].isspace()
if is_indented:
if current_key is None:
return None
current_value = parsed[current_key]
parsed[current_key] = f"{current_value}\n{stripped}" if current_value else stripped
continue
if ":" not in stripped:
return None
key, value = stripped.split(":", 1)
key = key.strip()
value = value.strip()
if not key:
return None
if value in {"|", ">"}:
parsed[key] = ""
current_key = key
multiline_key = key
continue
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
value = value[1:-1]
parsed[key] = value
current_key = key
multiline_key = None
if multiline_key is not None and multiline_key not in parsed:
return None
return parsed
def _load_frontmatter(frontmatter_text: str) -> tuple[Optional[dict], Optional[str]]:
if yaml is not None:
try:
frontmatter = yaml.safe_load(frontmatter_text)
except yaml.YAMLError as exc:
return None, f"Invalid YAML in frontmatter: {exc}"
if not isinstance(frontmatter, dict):
return None, "Frontmatter must be a YAML dictionary"
return frontmatter, None
frontmatter = _parse_simple_frontmatter(frontmatter_text)
if frontmatter is None:
return None, "Invalid YAML in frontmatter: unsupported syntax without PyYAML installed"
return frontmatter, None
def _validate_skill_name(name: str, folder_name: str) -> Optional[str]:
if not re.fullmatch(r"[a-z0-9]+(?:-[a-z0-9]+)*", name):
return (
f"Name '{name}' should be hyphen-case "
"(lowercase letters, digits, and single hyphens only)"
)
if len(name) > MAX_SKILL_NAME_LENGTH:
return (
f"Name is too long ({len(name)} characters). "
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
)
if name != folder_name:
return f"Skill name '{name}' must match directory name '{folder_name}'"
return None
def _validate_description(description: str) -> Optional[str]:
trimmed = description.strip()
if not trimmed:
return "Description cannot be empty"
lowered = trimmed.lower()
if any(marker in lowered for marker in PLACEHOLDER_MARKERS):
return "Description still contains TODO placeholder text"
if "<" in trimmed or ">" in trimmed:
return "Description cannot contain angle brackets (< or >)"
if len(trimmed) > 1024:
return f"Description is too long ({len(trimmed)} characters). Maximum is 1024 characters."
return None
def validate_skill(skill_path):
"""Validate a skill folder structure and required frontmatter."""
skill_path = Path(skill_path).resolve()
if not skill_path.exists():
return False, f"Skill folder not found: {skill_path}"
if not skill_path.is_dir():
return False, f"Path is not a directory: {skill_path}"
skill_md = skill_path / "SKILL.md"
if not skill_md.exists():
return False, "SKILL.md not found"
try:
content = skill_md.read_text(encoding="utf-8")
except OSError as exc:
return False, f"Could not read SKILL.md: {exc}"
frontmatter_text = _extract_frontmatter(content)
if frontmatter_text is None:
return False, "Invalid frontmatter format"
frontmatter, error = _load_frontmatter(frontmatter_text)
if error:
return False, error
unexpected_keys = sorted(set(frontmatter.keys()) - ALLOWED_FRONTMATTER_KEYS)
if unexpected_keys:
allowed = ", ".join(sorted(ALLOWED_FRONTMATTER_KEYS))
unexpected = ", ".join(unexpected_keys)
return (
False,
f"Unexpected key(s) in SKILL.md frontmatter: {unexpected}. Allowed properties are: {allowed}",
)
if "name" not in frontmatter:
return False, "Missing 'name' in frontmatter"
if "description" not in frontmatter:
return False, "Missing 'description' in frontmatter"
name = frontmatter["name"]
if not isinstance(name, str):
return False, f"Name must be a string, got {type(name).__name__}"
name_error = _validate_skill_name(name.strip(), skill_path.name)
if name_error:
return False, name_error
description = frontmatter["description"]
if not isinstance(description, str):
return False, f"Description must be a string, got {type(description).__name__}"
description_error = _validate_description(description)
if description_error:
return False, description_error
always = frontmatter.get("always")
if always is not None and not isinstance(always, bool):
return False, f"'always' must be a boolean, got {type(always).__name__}"
for child in skill_path.iterdir():
if child.name == "SKILL.md":
continue
if child.is_dir() and child.name in ALLOWED_RESOURCE_DIRS:
continue
if child.is_symlink():
continue
return (
False,
f"Unexpected file or directory in skill root: {child.name}. "
"Only SKILL.md, scripts/, references/, and assets/ are allowed.",
)
return True, "Skill is valid!"
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python quick_validate.py <skill_directory>")
sys.exit(1)
valid, message = validate_skill(sys.argv[1])
print(message)
sys.exit(0 if valid else 1)

View File

@@ -4,17 +4,15 @@ You are a helpful AI assistant. Be concise, accurate, and friendly.
## Scheduled Reminders ## Scheduled Reminders
When user asks for a reminder at a specific time, use `exec` to run: Before scheduling reminders, check available skills and follow skill guidance first.
``` Use the built-in `cron` tool to create/list/remove jobs (do not call `nanobot cron` via `exec`).
nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
```
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`). Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications. **Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
## Heartbeat Tasks ## Heartbeat Tasks
`HEARTBEAT.md` is checked every 30 minutes. Use file tools to manage periodic tasks: `HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
- **Add**: `edit_file` to append new tasks - **Add**: `edit_file` to append new tasks
- **Remove**: `edit_file` to delete completed tasks - **Remove**: `edit_file` to delete completed tasks

View File

@@ -1,5 +1,5 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path from nanobot.utils.helpers import ensure_dir
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"] __all__ = ["ensure_dir"]

View File

@@ -1,8 +1,25 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
import json
import re import re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any
import tiktoken
def detect_image_mime(data: bytes) -> str | None:
"""Detect image MIME type from magic bytes, ignoring file extension."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:3] == b"\xff\xd8\xff":
return "image/jpeg"
if data[:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
return None
def ensure_dir(path: Path) -> Path: def ensure_dir(path: Path) -> Path:
@@ -11,17 +28,6 @@ def ensure_dir(path: Path) -> Path:
return path return path
def get_data_path() -> Path:
"""~/.nanobot data directory."""
return ensure_dir(Path.home() / ".nanobot")
def get_workspace_path(workspace: str | None = None) -> Path:
"""Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
return ensure_dir(path)
def timestamp() -> str: def timestamp() -> str:
"""Current ISO timestamp.""" """Current ISO timestamp."""
return datetime.now().isoformat() return datetime.now().isoformat()
@@ -34,6 +40,136 @@ def safe_filename(name: str) -> str:
return _UNSAFE_CHARS.sub("_", name).strip() return _UNSAFE_CHARS.sub("_", name).strip()
def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
Split content into chunks within max_len, preferring line breaks.
Args:
content: The text content to split.
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
Returns:
List of message chunks, each within max_len.
"""
if not content:
return []
if len(content) <= max_len:
return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
# Try to break at newline first, then space, then hard break
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
def build_assistant_message(
content: str | None,
tool_calls: list[dict[str, Any]] | None = None,
reasoning_content: str | None = None,
thinking_blocks: list[dict] | None = None,
) -> dict[str, Any]:
"""Build a provider-safe assistant message with optional reasoning fields."""
msg: dict[str, Any] = {"role": "assistant", "content": content}
if tool_calls:
msg["tool_calls"] = tool_calls
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
if thinking_blocks:
msg["thinking_blocks"] = thinking_blocks
return msg
def estimate_prompt_tokens(
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> int:
"""Estimate prompt tokens with tiktoken."""
try:
enc = tiktoken.get_encoding("cl100k_base")
parts: list[str] = []
for msg in messages:
content = msg.get("content")
if isinstance(content, str):
parts.append(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
txt = part.get("text", "")
if txt:
parts.append(txt)
if tools:
parts.append(json.dumps(tools, ensure_ascii=False))
return len(enc.encode("\n".join(parts)))
except Exception:
return 0
def estimate_message_tokens(message: dict[str, Any]) -> int:
"""Estimate prompt tokens contributed by one persisted message."""
content = message.get("content")
parts: list[str] = []
if isinstance(content, str):
parts.append(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
text = part.get("text", "")
if text:
parts.append(text)
else:
parts.append(json.dumps(part, ensure_ascii=False))
elif content is not None:
parts.append(json.dumps(content, ensure_ascii=False))
for key in ("name", "tool_call_id"):
value = message.get(key)
if isinstance(value, str) and value:
parts.append(value)
if message.get("tool_calls"):
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
payload = "\n".join(parts)
if not payload:
return 1
try:
enc = tiktoken.get_encoding("cl100k_base")
return max(1, len(enc.encode(payload)))
except Exception:
return max(1, len(payload) // 4)
def estimate_prompt_tokens_chain(
provider: Any,
model: str | None,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> tuple[int, str]:
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
if callable(provider_counter):
try:
tokens, source = provider_counter(messages, tools, model)
if isinstance(tokens, (int, float)) and tokens > 0:
return int(tokens), str(source or "provider_counter")
except Exception:
pass
estimated = estimate_prompt_tokens(messages, tools)
if estimated > 0:
return int(estimated), "tiktoken"
return 0, "none"
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files.""" """Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files from importlib.resources import files as pkg_files
@@ -54,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
added.append(str(dest.relative_to(workspace))) added.append(str(dest.relative_to(workspace)))
for item in tpl.iterdir(): for item in tpl.iterdir():
if item.name.endswith(".md"): if item.name.endswith(".md") and not item.name.startswith("."):
_write(item, workspace / item.name) _write(item, workspace / item.name)
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md") _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
_write(None, workspace / "memory" / "HISTORY.md") _write(None, workspace / "memory" / "HISTORY.md")

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "nanobot-ai" name = "nanobot-ai"
version = "0.1.4.post3" version = "0.1.4.post4"
description = "A lightweight personal AI assistant framework" description = "A lightweight personal AI assistant framework"
requires-python = ">=3.11" requires-python = ">=3.11"
license = {text = "MIT"} license = {text = "MIT"}
@@ -18,7 +18,7 @@ classifiers = [
dependencies = [ dependencies = [
"typer>=0.20.0,<1.0.0", "typer>=0.20.0,<1.0.0",
"litellm>=1.81.5,<2.0.0", "litellm>=1.82.1,<2.0.0",
"pydantic>=2.12.0,<3.0.0", "pydantic>=2.12.0,<3.0.0",
"pydantic-settings>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0",
"websockets>=16.0,<17.0", "websockets>=16.0,<17.0",
@@ -30,7 +30,7 @@ dependencies = [
"rich>=14.0.0,<15.0.0", "rich>=14.0.0,<15.0.0",
"croniter>=6.0.0,<7.0.0", "croniter>=6.0.0,<7.0.0",
"dingtalk-stream>=0.24.0,<1.0.0", "dingtalk-stream>=0.24.0,<1.0.0",
"python-telegram-bot[socks]>=22.0,<23.0", "python-telegram-bot[socks]>=22.6,<23.0",
"lark-oapi>=1.5.0,<2.0.0", "lark-oapi>=1.5.0,<2.0.0",
"socksio>=1.0.0,<2.0.0", "socksio>=1.0.0,<2.0.0",
"python-socketio>=5.16.0,<6.0.0", "python-socketio>=5.16.0,<6.0.0",
@@ -42,9 +42,15 @@ dependencies = [
"prompt-toolkit>=3.0.50,<4.0.0", "prompt-toolkit>=3.0.50,<4.0.0",
"mcp>=1.26.0,<2.0.0", "mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0", "json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0",
"openai>=2.8.0",
"tiktoken>=0.12.0,<1.0.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
wecom = [
"wecom-aibot-sdk-python @ git+https://github.com/chengyongru/wecom_aibot_sdk.git@v0.1.2",
]
matrix = [ matrix = [
"matrix-nio[e2e]>=0.25.2", "matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0", "mistune>=3.0.0,<4.0.0",
@@ -54,6 +60,9 @@ dev = [
"pytest>=9.0.0,<10.0.0", "pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0", "pytest-asyncio>=1.3.0,<2.0.0",
"ruff>=0.1.0", "ruff>=0.1.0",
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
] ]
[project.scripts] [project.scripts]
@@ -63,6 +72,9 @@ nanobot = "nanobot.cli.commands:app"
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["nanobot"] packages = ["nanobot"]

View File

@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert result.finish_reason == "error"
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")

View File

@@ -0,0 +1,25 @@
from types import SimpleNamespace
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel):
name = "dummy"
async def start(self) -> None:
return None
async def stop(self) -> None:
return None
async def send(self, msg: OutboundMessage) -> None:
return None
def test_is_allowed_requires_exact_match() -> None:
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False

View File

@@ -1,6 +1,6 @@
import shutil import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from typer.testing import CliRunner from typer.testing import CliRunner
@@ -14,13 +14,17 @@ from nanobot.providers.registry import find_by_model
runner = CliRunner() runner = CliRunner()
class _StopGateway(RuntimeError):
pass
@pytest.fixture @pytest.fixture
def mock_paths(): def mock_paths():
"""Mock config/workspace paths for test isolation.""" """Mock config/workspace paths for test isolation."""
with patch("nanobot.config.loader.get_config_path") as mock_cp, \ with patch("nanobot.config.loader.get_config_path") as mock_cp, \
patch("nanobot.config.loader.save_config") as mock_sc, \ patch("nanobot.config.loader.save_config") as mock_sc, \
patch("nanobot.config.loader.load_config") as mock_lc, \ patch("nanobot.config.loader.load_config") as mock_lc, \
patch("nanobot.utils.helpers.get_workspace_path") as mock_ws: patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
base_dir = Path("./test_onboard_data") base_dir = Path("./test_onboard_data")
if base_dir.exists(): if base_dir.exists():
@@ -110,6 +114,35 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex" assert config.get_provider_name() == "openai_codex"
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
config = Config()
config.agents.defaults.provider = "ollama"
config.agents.defaults.model = "llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex") spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -128,3 +161,303 @@ def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex" assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex" assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
@pytest.fixture
def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests."""
config = Config()
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
cron_dir = tmp_path / "data" / "cron"
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
patch("nanobot.cli.commands._make_provider", return_value=object()), \
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
patch("nanobot.bus.queue.MessageBus"), \
patch("nanobot.cron.service.CronService"), \
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
agent_loop = MagicMock()
agent_loop.channels_config = None
agent_loop.process_direct = AsyncMock(return_value="mock-response")
agent_loop.close_mcp = AsyncMock(return_value=None)
mock_agent_loop_cls.return_value = agent_loop
yield {
"config": config,
"load_config": mock_load_config,
"sync_templates": mock_sync_templates,
"agent_loop_cls": mock_agent_loop_cls,
"agent_loop": agent_loop,
"print_response": mock_print_response,
}
def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0
assert "--workspace" in result.stdout
assert "-w" in result.stdout
assert "--config" in result.stdout
assert "-c" in result.stdout
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (None,)
assert mock_agent_runtime["sync_templates"].call_args.args == (
mock_agent_runtime["config"].workspace_path,
)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
mock_agent_runtime["config"].workspace_path
)
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
config_path = tmp_path / "agent-config.json"
config_path.write_text("{}")
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
pass
async def process_direct(self, *_args, **_kwargs) -> str:
return "ok"
async def close_mcp(self) -> None:
return None
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert seen["config_path"] == config_file.resolve()
def test_agent_overrides_workspace_path(mock_agent_runtime):
workspace_path = Path("/tmp/agent-workspace")
result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
assert result.exit_code == 0
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
config_path = tmp_path / "agent-config.json"
config_path.write_text("{}")
workspace_path = Path("/tmp/agent-workspace")
result = runner.invoke(
app,
["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
)
assert result.exit_code == 0
assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
mock_agent_runtime["config"].agents.defaults.memory_window = 100
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert seen["config_path"] == config_file.resolve()
assert seen["workspace"] == Path(config.agents.defaults.workspace)
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
override = tmp_path / "override-workspace"
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(
app,
["gateway", "--config", str(config_file), "--workspace", str(override)],
)
assert isinstance(result.exception, _StopGateway)
assert seen["workspace"] == override
assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.memory_window = 100
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGateway("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "port 18791" in result.stdout
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
assert isinstance(result.exception, _StopGateway)
assert "port 18792" in result.stdout

View File

@@ -0,0 +1,88 @@
import json
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 1234,
"memoryWindow": 42,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
assert config.agents.defaults.should_warn_deprecated_memory_window is True
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 2222,
"memoryWindow": 30,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 2222
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 3333,
"memoryWindow": 50,
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "contextWindowTokens" in result.stdout
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 3333
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults

View File

@@ -0,0 +1,42 @@
from pathlib import Path
from nanobot.config.paths import (
get_bridge_install_dir,
get_cli_history_path,
get_cron_dir,
get_data_dir,
get_legacy_sessions_dir,
get_logs_dir,
get_media_dir,
get_runtime_subdir,
get_workspace_path,
)
def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-a" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_data_dir() == config_file.parent
assert get_runtime_subdir("cron") == config_file.parent / "cron"
assert get_cron_dir() == config_file.parent / "cron"
assert get_logs_dir() == config_file.parent / "logs"
def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance-b" / "config.json"
monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
assert get_media_dir() == config_file.parent / "media"
assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
def test_shared_and_legacy_paths_remain_global() -> None:
assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
def test_workspace_path_is_explicitly_resolved() -> None:
assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"

View File

@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34) assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard: class TestNewCommandArchival:
"""Test that consolidation tasks are deduplicated and serialized.""" """Test /new archival behavior with the simplified consolidation flow."""
@pytest.mark.asyncio @staticmethod
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None: def _make_loop(tmp_path: Path):
"""Concurrent messages above memory_window spawn only one consolidation task."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop( loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
) )
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
return loop
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(5): for i in range(5):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session) loop.sessions.save(session)
before_count = len(session.messages) before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool: async def _failing_consolidate(_messages) -> bool:
if archive_all: return False
return False
return True
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg) response = await loop._process_message(new_msg)
assert response is not None assert response is not None
assert "failed" in response.content.lower() assert "failed" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test") assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task( async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(15): for i in range(15):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session) loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1 archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool: async def _fake_consolidate(messages) -> bool:
nonlocal archived_count nonlocal archived_count
if archive_all: archived_count = len(messages)
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
return True return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg)) response = await loop._process_message(new_msg)
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
assert archived_count == 3, ( assert archived_count == 3
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None: async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
for i in range(3): for i in range(3):
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
loop.sessions.save(session) loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool: async def _ok_consolidate(_messages) -> bool:
return True return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg) response = await loop._process_message(new_msg)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime as real_datetime from datetime import datetime as real_datetime
from importlib.resources import files as pkg_files
from pathlib import Path from pathlib import Path
import datetime as datetime_module import datetime as datetime_module
@@ -23,6 +24,13 @@ def _make_workspace(tmp_path: Path) -> Path:
return workspace return workspace
def test_bootstrap_files_are_backed_by_templates() -> None:
template_dir = pkg_files("nanobot") / "templates"
for filename in ContextBuilder.BOOTSTRAP_FILES:
assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None: def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
"""System prompt should not change just because wall clock minute changes.""" """System prompt should not change just because wall clock minute changes."""
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime) monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
@@ -40,7 +48,7 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
"""Runtime metadata should be a separate user message before the actual user message.""" """Runtime metadata should be merged with the user message."""
workspace = _make_workspace(tmp_path) workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace) builder = ContextBuilder(workspace)
@@ -54,13 +62,12 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert messages[0]["role"] == "system" assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"] assert "## Current Session" not in messages[0]["content"]
assert messages[-2]["role"] == "user" # Runtime context is now merged with user message into a single message
runtime_content = messages[-2]["content"]
assert isinstance(runtime_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
assert "Current Time:" in runtime_content
assert "Channel: cli" in runtime_content
assert "Chat ID: direct" in runtime_content
assert messages[-1]["role"] == "user" assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == "Return exactly: OK" user_content = messages[-1]["content"]
assert isinstance(user_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
assert "Current Time:" in user_content
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content

View File

@@ -1,29 +0,0 @@
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
result = runner.invoke(
app,
[
"cron",
"add",
"--name",
"demo",
"--message",
"hello",
"--cron",
"0 9 * * *",
"--tz",
"America/Vancovuer",
],
)
assert result.exit_code == 1
assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
assert not (tmp_path / "cron" / "jobs.json").exists()

View File

@@ -48,6 +48,8 @@ async def test_running_service_honors_external_disable(tmp_path) -> None:
) )
await service.start() await service.start()
try: try:
# Wait slightly to ensure file mtime is definitively different
await asyncio.sleep(0.05)
external = CronService(store_path) external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False) updated = external.enable_job(job.id, enabled=False)
assert updated is not None assert updated is not None

View File

@@ -0,0 +1,111 @@
import asyncio
from types import SimpleNamespace
import pytest
from nanobot.bus.queue import MessageBus
import nanobot.channels.dingtalk as dingtalk_module
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
from nanobot.config.schema import DingTalkConfig
class _FakeResponse:
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
self.status_code = status_code
self._json_body = json_body or {}
self.text = "{}"
def json(self) -> dict:
return self._json_body
class _FakeHttp:
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url: str, json=None, headers=None):
self.calls.append({"url": url, "json": json, "headers": headers})
return _FakeResponse()
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
bus = MessageBus()
channel = DingTalkChannel(config, bus)
await channel._on_message(
"hello",
sender_id="user1",
sender_name="Alice",
conversation_type="2",
conversation_id="conv123",
)
msg = await bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
assert msg.metadata["conversation_type"] == "2"
@pytest.mark.asyncio
async def test_group_send_uses_group_messages_api() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _FakeHttp()
ok = await channel._send_batch_message(
"token",
"group:conv123",
"sampleMarkdown",
{"text": "hello", "title": "Nanobot Reply"},
)
assert ok is True
call = channel._http.calls[0]
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeChatbotMessage:
text = None
extensions = {"content": {"recognition": "voice transcript"}}
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "audio"
@staticmethod
def from_dict(_data):
return _FakeChatbotMessage()
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "2",
"conversationId": "conv123",
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert msg.content == "voice transcript"
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"

View File

@@ -1,4 +1,4 @@
from nanobot.channels.feishu import _extract_post_content from nanobot.channels.feishu import FeishuChannel, _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None: def test_extract_post_content_supports_post_wrapper_shape() -> None:
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
assert text == "Daily report" assert text == "Daily report"
assert image_keys == ["img_a", "img_b"] assert image_keys == ["img_a", "img_b"]
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
class Builder:
pass
builder = Builder()
same = FeishuChannel._register_optional_event(builder, "missing", object())
assert same is builder
def test_register_optional_event_calls_supported_method() -> None:
called = []
class Builder:
def register_event(self, handler):
called.append(handler)
return self
builder = Builder()
handler = object()
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
assert same is builder
assert called == [handler]

View File

@@ -0,0 +1,104 @@
"""Tests for FeishuChannel._split_elements_by_table_limit.
Feishu cards reject messages that contain more than one table element
(API error 11310: card table number over limit). The helper splits a flat
list of card elements into groups so that each group contains at most one
table, allowing nanobot to send multiple cards instead of failing.
"""
from nanobot.channels.feishu import FeishuChannel
def _md(text: str) -> dict:
return {"tag": "markdown", "content": text}
def _table() -> dict:
return {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "v"}],
"page_size": 2,
}
split = FeishuChannel._split_elements_by_table_limit
def test_empty_list_returns_single_empty_group() -> None:
assert split([]) == [[]]
def test_no_tables_returns_single_group() -> None:
els = [_md("hello"), _md("world")]
result = split(els)
assert result == [els]
def test_single_table_stays_in_one_group() -> None:
els = [_md("intro"), _table(), _md("outro")]
result = split(els)
assert len(result) == 1
assert result[0] == els
def test_two_tables_split_into_two_groups() -> None:
# Use different row values so the two tables are not equal
t1 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "table-one"}],
"page_size": 2,
}
t2 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
"rows": [{"c0": "table-two"}],
"page_size": 2,
}
els = [_md("before"), t1, _md("between"), t2, _md("after")]
result = split(els)
assert len(result) == 2
# First group: text before table-1 + table-1
assert t1 in result[0]
assert t2 not in result[0]
# Second group: text between tables + table-2 + text after
assert t2 in result[1]
assert t1 not in result[1]
def test_three_tables_split_into_three_groups() -> None:
tables = [
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
for i in range(3)
]
els = tables[:]
result = split(els)
assert len(result) == 3
for i, group in enumerate(result):
assert tables[i] in group
def test_leading_markdown_stays_with_first_table() -> None:
intro = _md("intro")
t = _table()
result = split([intro, t])
assert len(result) == 1
assert result[0] == [intro, t]
def test_trailing_markdown_after_second_table() -> None:
t1, t2 = _table(), _table()
tail = _md("end")
result = split([t1, t2, tail])
assert len(result) == 2
assert result[1] == [t2, tail]
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
head = _md("head")
t1, t2 = _table(), _table()
result = split([head, t1, t2])
# head + t1 in group 0; t2 in group 1
assert result[0] == [head, t1]
assert result[1] == [t2]

View File

@@ -0,0 +1,251 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def sample_file(self, tmp_path):
f = tmp_path / "sample.txt"
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
return f
@pytest.mark.asyncio
async def test_basic_read_has_line_numbers(self, tool, sample_file):
result = await tool.execute(path=str(sample_file))
assert "1| line 1" in result
assert "20| line 20" in result
@pytest.mark.asyncio
async def test_offset_and_limit(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
assert "5| line 5" in result
assert "7| line 7" in result
assert "8| line 8" not in result
assert "Use offset=8 to continue" in result
@pytest.mark.asyncio
async def test_offset_beyond_end(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=999)
assert "Error" in result
assert "beyond end" in result
@pytest.mark.asyncio
async def test_end_of_file_marker(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
assert "End of file" in result
@pytest.mark.asyncio
async def test_empty_file(self, tool, tmp_path):
f = tmp_path / "empty.txt"
f.write_text("", encoding="utf-8")
result = await tool.execute(path=str(f))
assert "Empty file" in result
@pytest.mark.asyncio
async def test_file_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope.txt"))
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_char_budget_trims(self, tool, tmp_path):
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
f = tmp_path / "big.txt"
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
result = await tool.execute(path=str(f))
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
assert "Use offset=" in result
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
def test_exact_match(self):
match, count = _find_match("hello world", "world")
assert match == "world"
assert count == 1
def test_exact_no_match(self):
match, count = _find_match("hello world", "xyz")
assert match is None
assert count == 0
def test_crlf_normalisation(self):
# Caller normalises CRLF before calling _find_match, so test with
# pre-normalised content to verify exact match still works.
content = "line1\nline2\nline3"
old_text = "line1\nline2\nline3"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
def test_line_trim_fallback(self):
content = " def foo():\n pass\n"
old_text = "def foo():\n pass"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
# The returned match should be the *original* indented text
assert " def foo():" in match
def test_line_trim_multiple_candidates(self):
content = " a\n b\n a\n b\n"
old_text = "a\nb"
match, count = _find_match(content, old_text)
assert count == 2
def test_empty_old_text(self):
match, count = _find_match("hello", "")
# Empty string is always "in" any string via exact match
assert match == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_exact_match(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
assert f.read_text() == "hello earth"
@pytest.mark.asyncio
async def test_crlf_normalisation(self, tool, tmp_path):
f = tmp_path / "crlf.py"
f.write_bytes(b"line1\r\nline2\r\nline3")
result = await tool.execute(
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
)
assert "Successfully" in result
raw = f.read_bytes()
assert b"LINE1" in raw
# CRLF line endings should be preserved throughout the file
assert b"\r\n" in raw
@pytest.mark.asyncio
async def test_trim_fallback(self, tool, tmp_path):
f = tmp_path / "indent.py"
f.write_text(" def foo():\n pass\n", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
)
assert "Successfully" in result
assert "bar" in f.read_text()
@pytest.mark.asyncio
async def test_ambiguous_match(self, tool, tmp_path):
f = tmp_path / "dup.py"
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
assert "appears" in result.lower() or "Warning" in result
@pytest.mark.asyncio
async def test_replace_all(self, tool, tmp_path):
f = tmp_path / "multi.py"
f.write_text("foo bar foo bar foo", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="foo", new_text="baz", replace_all=True,
)
assert "Successfully" in result
assert f.read_text() == "baz bar baz bar baz"
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
f = tmp_path / "nf.py"
f.write_text("hello", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
assert "Error" in result
assert "not found" in result
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
@pytest.fixture()
def tool(self, tmp_path):
return ListDirTool(workspace=tmp_path)
@pytest.fixture()
def populated_dir(self, tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "main.py").write_text("pass")
(tmp_path / "src" / "utils.py").write_text("pass")
(tmp_path / "README.md").write_text("hi")
(tmp_path / ".git").mkdir()
(tmp_path / ".git" / "config").write_text("x")
(tmp_path / "node_modules").mkdir()
(tmp_path / "node_modules" / "pkg").mkdir()
return tmp_path
@pytest.mark.asyncio
async def test_basic_list(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir))
assert "README.md" in result
assert "src" in result
# .git and node_modules should be ignored
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_recursive(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir), recursive=True)
assert "src/main.py" in result
assert "src/utils.py" in result
assert "README.md" in result
# Ignored dirs should not appear
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_max_entries_truncation(self, tool, tmp_path):
for i in range(10):
(tmp_path / f"file_{i}.txt").write_text("x")
result = await tool.execute(path=str(tmp_path), max_entries=3)
assert "truncated" in result
assert "3 of 10" in result
@pytest.mark.asyncio
async def test_empty_dir(self, tool, tmp_path):
d = tmp_path / "empty"
d.mkdir()
result = await tool.execute(path=str(d))
assert "empty" in result.lower()
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope"))
assert "Error" in result
assert "not found" in result

View File

@@ -0,0 +1,53 @@
from types import SimpleNamespace
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.litellm_provider import LiteLLMProvider
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="call_123",
function=SimpleNamespace(
name="read_file",
arguments='{"path":"todo.md"}',
provider_specific_fields={"inner": "value"},
),
provider_specific_fields={"thought_signature": "signed-token"},
)
],
),
)
],
usage=None,
)
parsed = provider._parse_response(response)
assert len(parsed.tool_calls) == 1
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
def test_tool_call_request_serializes_provider_fields() -> None:
tool_call = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"thought_signature": "signed-token"},
function_provider_specific_fields={"inner": "value"},
)
message = tool_call.to_openai_tool_call()
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'

View File

@@ -3,18 +3,24 @@ import asyncio
import pytest import pytest
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider: class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]): def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses) self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse: async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses: if self._responses:
return self._responses.pop(0) return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[]) return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None: async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
) )
assert await service.trigger_now() is None assert await service.trigger_now() is None
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "run"
assert tasks == "check open tasks"
assert provider.calls == 2
assert delays == [1]

View File

@@ -0,0 +1,190 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
)
loop.tools.get_definitions = MagicMock(return_value=[])
return loop
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
await loop.process_direct("hello", session_key="cli:test")
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
]
loop.sessions.save(session)
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
assert session.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (300, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
"""Once triggered, consolidation should continue until it drops below half threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (150, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
"""Verify preflight consolidation runs before the LLM call in process_direct."""
order: list[str] = []
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
order.append("consolidate")
return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
async def track_llm(*args, **kwargs):
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
return (1000 if call_count[0] <= 1 else 80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")

View File

@@ -0,0 +1,41 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500
return loop
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
skip=0,
)
assert session.messages == []
def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
loop = _mk_loop()
session = Session(key="test:image")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{
"role": "user",
"content": [
{"type": "text", "text": runtime},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
}],
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]

View File

@@ -159,6 +159,7 @@ class _FakeAsyncClient:
def _make_config(**kwargs) -> MatrixConfig: def _make_config(**kwargs) -> MatrixConfig:
kwargs.setdefault("allow_from", ["*"])
return MatrixConfig( return MatrixConfig(
enabled=True, enabled=True,
homeserver="https://matrix.org", homeserver="https://matrix.org",
@@ -274,7 +275,7 @@ async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_room_invite_joins_when_allow_list_is_empty() -> None: async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
channel = MatrixChannel(_make_config(allow_from=[]), MessageBus()) channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
client = _FakeAsyncClient("", "", "", None) client = _FakeAsyncClient("", "", "", None)
channel.client = client channel.client = client
@@ -284,9 +285,22 @@ async def test_room_invite_joins_when_allow_list_is_empty() -> None:
await channel._on_room_invite(room, event) await channel._on_room_invite(room, event)
assert client.join_calls == ["!room:matrix.org"] assert client.join_calls == []
@pytest.mark.asyncio
async def test_room_invite_joins_when_sender_allowed() -> None:
channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
room = SimpleNamespace(room_id="!room:matrix.org")
event = SimpleNamespace(sender="@alice:matrix.org")
await channel._on_room_invite(room, event)
assert client.join_calls == ["!room:matrix.org"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_room_invite_respects_allow_list_when_configured() -> None: async def test_room_invite_respects_allow_list_when_configured() -> None:
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus()) channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
@@ -1163,6 +1177,8 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
assert "!room:matrix.org" in channel._typing_tasks assert "!room:matrix.org" in channel._typing_tasks
assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS) assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
await channel.stop()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_clears_typing_when_send_fails() -> None: async def test_send_clears_typing_when_send_fails() -> None:

99
tests/test_mcp_tool.py Normal file
View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import asyncio
import sys
from types import ModuleType, SimpleNamespace
import pytest
from nanobot.agent.tools.mcp import MCPToolWrapper
class _FakeTextContent:
def __init__(self, text: str) -> None:
self.text = text
@pytest.fixture(autouse=True)
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
mod = ModuleType("mcp")
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
monkeypatch.setitem(sys.modules, "mcp", mod)
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
tool_def = SimpleNamespace(
name="demo",
description="demo tool",
inputSchema={"type": "object", "properties": {}},
)
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
@pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
assert arguments == {"value": 1}
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute(value=1)
assert result == "hello\n42"
@pytest.mark.asyncio
async def test_execute_returns_timeout_message() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
await asyncio.sleep(1)
return SimpleNamespace(content=[])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP tool call timed out after 0.01s)"
@pytest.mark.asyncio
async def test_execute_handles_server_cancelled_error() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
raise asyncio.CancelledError()
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute()
assert result == "(MCP tool call was cancelled)"
@pytest.mark.asyncio
async def test_execute_re_raises_external_cancellation() -> None:
started = asyncio.Event()
async def call_tool(_name: str, arguments: dict) -> object:
started.set()
await asyncio.sleep(60)
return SimpleNamespace(content=[])
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
task = asyncio.create_task(wrapper.execute())
await started.wait()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
@pytest.mark.asyncio
async def test_execute_handles_generic_exception() -> None:
async def call_tool(_name: str, arguments: dict) -> object:
raise RuntimeError("boom")
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
result = await wrapper.execute()
assert result == "(MCP tool call failed: RuntimeError)"

View File

@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
import json import json
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock
import pytest import pytest
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50): def _make_messages(message_count: int = 30):
"""Create a mock session with messages.""" """Create a list of mock messages."""
session = MagicMock() return [
session.messages = [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"} {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count) for i in range(message_count)
] ]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update): def _make_tool_response(history_entry, memory_update):
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
) )
class ScriptedProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
class TestMemoryConsolidationTypeHandling: class TestMemoryConsolidationTypeHandling:
"""Test that consolidation handles various argument types correctly.""" """Test that consolidation handles various argument types correctly."""
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.", memory_update="# Memory\nUser likes testing.",
) )
) )
session = _make_session(message_count=60) provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert store.history_file.exists() assert store.history_file.exists()
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update={"facts": ["User likes testing"], "topics": ["testing"]}, memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
) )
) )
session = _make_session(message_count=60) provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert store.history_file.exists() assert store.history_file.exists()
@@ -112,9 +127,10 @@ class TestMemoryConsolidationTypeHandling:
], ],
) )
provider.chat = AsyncMock(return_value=response) provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60) provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
assert "User discussed testing." in store.history_file.read_text() assert "User discussed testing." in store.history_file.read_text()
@@ -127,21 +143,148 @@ class TestMemoryConsolidationTypeHandling:
provider.chat = AsyncMock( provider.chat = AsyncMock(
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[]) return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
) )
session = _make_session(message_count=60) provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is False assert result is False
assert not store.history_file.exists() assert not store.history_file.exists()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None: async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count.""" """Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
session = _make_session(message_count=10) provider.chat_with_retry = provider.chat
messages: list[dict] = []
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(messages, provider, "test-model")
assert result is True assert result is True
provider.chat.assert_not_called() provider.chat.assert_not_called()
@pytest.mark.asyncio
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
"""Some providers return arguments as a list - extract first element if it's a dict."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[{
"history_entry": "[2026-01-01] User discussed testing.",
"memory_update": "# Memory\nUser likes testing.",
}],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
assert "User likes testing." in store.memory_file.read_text()
@pytest.mark.asyncio
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
"""Empty list arguments should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
"""List with non-dict content should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=["string", "content"],
)
],
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path)
provider = ScriptedProvider([
LLMResponse(content="503 server error", finish_reason="error"),
_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
),
])
messages = _make_messages(message_count=60)
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
"""Consolidation no longer passes generation params — the provider owns them."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat_with_retry.assert_awaited_once()
_, kwargs = provider.chat_with_retry.await_args
assert kwargs["model"] == "test-model"
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs

View File

@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic: class TestMessageToolSuppressLogic:
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]), LLMResponse(content="Done", tool_calls=[]),
]) ])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = [] sent: list[OutboundMessage] = []
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]), LLMResponse(content="I've sent the email.", tool_calls=[]),
]) ])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = [] sent: list[OutboundMessage] = []
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
@@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic:
assert result is not None assert result is not None
assert "Hello" in result.content assert "Hello" in result.content
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
calls = iter([
LLMResponse(
content="Visible<think>hidden</think>",
tool_calls=[tool_call],
reasoning_content="secret reasoning",
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
progress: list[tuple[str, bool]] = []
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [
("Visible", False),
('read_file("foo.txt")', True),
]
class TestMessageToolTurnTracking: class TestMessageToolTurnTracking:

View File

@@ -0,0 +1,125 @@
import asyncio
import pytest
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider):
def __init__(self, responses):
super().__init__()
self._responses = list(responses)
self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(content="ok"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "stop"
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "401 unauthorized"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit a", finish_reason="error"),
LLMResponse(content="429 rate limit b", finish_reason="error"),
LLMResponse(content="429 rate limit c", finish_reason="error"),
LLMResponse(content="503 final server error", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "503 final server error"
assert provider.calls == 4
assert delays == [1, 2, 4]
@pytest.mark.asyncio
async def test_chat_with_retry_preserves_cancelled_error() -> None:
provider = ScriptedProvider([asyncio.CancelledError()])
with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert provider.last_kwargs["temperature"] == 0.2
assert provider.last_kwargs["max_tokens"] == 321
assert provider.last_kwargs["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
"""Explicit kwargs should override provider.generation defaults."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
temperature=0.9,
max_tokens=9999,
reasoning_effort="low",
)
assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low"

66
tests/test_qq_channel.py Normal file
View File

@@ -0,0 +1,66 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel
from nanobot.config.schema import QQConfig
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeClient:
def __init__(self) -> None:
self.api = _FakeApi()
@pytest.mark.asyncio
async def test_on_group_message_routes_to_group_chat_id() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
data = SimpleNamespace(
id="msg1",
content="hello",
group_openid="group123",
author=SimpleNamespace(member_openid="user1"),
)
await channel._on_message(data, is_group=True)
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group123"
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert not channel._client.api.c2c_calls

View File

@@ -0,0 +1,127 @@
import importlib
import shutil
import sys
import zipfile
from pathlib import Path
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
if str(SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPT_DIR))
init_skill = importlib.import_module("init_skill")
package_skill = importlib.import_module("package_skill")
quick_validate = importlib.import_module("quick_validate")
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
skill_dir = init_skill.init_skill(
"demo-skill",
tmp_path,
["scripts", "references", "assets"],
include_examples=True,
)
assert skill_dir == tmp_path / "demo-skill"
assert (skill_dir / "SKILL.md").exists()
assert (skill_dir / "scripts" / "example.py").exists()
assert (skill_dir / "references" / "api_reference.md").exists()
assert (skill_dir / "assets" / "example_asset.txt").exists()
def test_validate_skill_accepts_existing_skill_creator() -> None:
valid, message = quick_validate.validate_skill(
Path("nanobot/skills/skill-creator").resolve()
)
assert valid, message
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
skill_dir = tmp_path / "placeholder-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: placeholder-skill\n"
'description: "[TODO: fill me in]"\n'
"---\n"
"# Placeholder\n",
encoding="utf-8",
)
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "TODO placeholder" in message
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
skill_dir = tmp_path / "bad-root-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: bad-root-skill\n"
"description: Valid description\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "Unexpected file or directory in skill root" in message
def test_package_skill_creates_archive(tmp_path: Path) -> None:
skill_dir = tmp_path / "package-me"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: package-me\n"
"description: Package this skill.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path == (tmp_path / "dist" / "package-me.skill")
assert archive_path.exists()
with zipfile.ZipFile(archive_path, "r") as archive:
names = set(archive.namelist())
assert "package-me/SKILL.md" in names
assert "package-me/scripts/helper.py" in names
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
skill_dir = tmp_path / "symlink-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: symlink-skill\n"
"description: Reject symlinks during packaging.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
target = tmp_path / "outside.txt"
target.write_text("secret\n", encoding="utf-8")
link = scripts_dir / "outside.txt"
try:
link.symlink_to(target)
except (OSError, NotImplementedError):
return
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path is None
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel
from nanobot.config.schema import SlackConfig
class _FakeAsyncWebClient:
def __init__(self) -> None:
self.chat_post_calls: list[dict[str, object | None]] = []
self.file_upload_calls: list[dict[str, object | None]] = []
async def chat_postMessage(
self,
*,
channel: str,
text: str,
thread_ts: str | None = None,
) -> None:
self.chat_post_calls.append(
{
"channel": channel,
"text": text,
"thread_ts": thread_ts,
}
)
async def files_upload_v2(
self,
*,
channel: str,
file: str,
thread_ts: str | None = None,
) -> None:
self.file_upload_calls.append(
{
"channel": channel,
"file": file,
"thread_ts": thread_ts,
}
)
@pytest.mark.asyncio
async def test_send_uses_thread_for_channel_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
@pytest.mark.asyncio
async def test_send_omits_thread_for_dm_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="D123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] is None
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] is None

View File

@@ -165,3 +165,46 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0 assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def scripted_chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]

View File

@@ -0,0 +1,338 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig
class _FakeHTTPXRequest:
instances: list["_FakeHTTPXRequest"] = []
def __init__(self, **kwargs) -> None:
self.kwargs = kwargs
self.__class__.instances.append(self)
class _FakeUpdater:
def __init__(self, on_start_polling) -> None:
self._on_start_polling = on_start_polling
async def start_polling(self, **kwargs) -> None:
self._on_start_polling()
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
self.get_me_calls = 0
async def get_me(self):
self.get_me_calls += 1
return SimpleNamespace(id=999, username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
async def send_chat_action(self, **kwargs) -> None:
pass
class _FakeApp:
def __init__(self, on_start_polling) -> None:
self.bot = _FakeBot()
self.updater = _FakeUpdater(on_start_polling)
self.handlers = []
self.error_handlers = []
def add_error_handler(self, handler) -> None:
self.error_handlers.append(handler)
def add_handler(self, handler) -> None:
self.handlers.append(handler)
async def initialize(self) -> None:
pass
async def start(self) -> None:
pass
class _FakeBuilder:
def __init__(self, app: _FakeApp) -> None:
self.app = app
self.token_value = None
self.request_value = None
self.get_updates_request_value = None
def token(self, token: str):
self.token_value = token
return self
def request(self, request):
self.request_value = request
return self
def get_updates_request(self, request):
self.get_updates_request_value = request
return self
def proxy(self, _proxy):
raise AssertionError("builder.proxy should not be called when request is set")
def get_updates_proxy(self, _proxy):
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
def build(self):
return self.app
def _make_telegram_update(
*,
chat_type: str = "group",
text: str | None = None,
caption: str | None = None,
entities=None,
caption_entities=None,
reply_to_message=None,
):
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
message = SimpleNamespace(
chat=SimpleNamespace(type=chat_type, is_forum=False),
chat_id=-100123,
text=text,
caption=caption,
entities=entities or [],
caption_entities=caption_entities or [],
reply_to_message=reply_to_message,
photo=None,
voice=None,
audio=None,
document=None,
media_group_id=None,
message_thread_id=None,
message_id=1,
)
return SimpleNamespace(message=message, effective_user=user)
@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
assert len(_FakeHTTPXRequest.instances) == 1
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
assert builder.request_value is _FakeHTTPXRequest.instances[0]
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
chat_id=-100123,
message_thread_id=42,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
def test_get_extension_falls_back_to_original_filename() -> None:
channel = TelegramChannel(TelegramConfig(), MessageBus())
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_telegram_group_policy_defaults_to_mention() -> None:
assert TelegramConfig().group_policy == "mention"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
assert channel.is_allowed("12345|carol") is True
assert channel.is_allowed("99999|alice") is True
assert channel.is_allowed("67890|bob") is True
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
assert channel.is_allowed("attacker|alice|extra") is False
assert channel.is_allowed("not-a-number|alice") is False
@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"_progress": True, "message_thread_id": 42},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
channel._message_threads[("123", 10)] = 42
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"message_id": 10},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
@pytest.mark.asyncio
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
assert handled == []
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
assert len(handled) == 2
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_caption_mention() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
None,
)
assert len(handled) == 1
assert handled[0]["content"] == "@nanobot_test photo"
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
assert len(handled) == 1
@pytest.mark.asyncio
async def test_group_policy_open_accepts_plain_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello group"), None)
assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0

View File

@@ -106,3 +106,301 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
paths = ExecTool._extract_absolute_paths(cmd) paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths assert "/tmp/out.txt" in paths
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "~/.nanobot/config.json" in paths
assert "~/out.txt" in paths
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "~/.nanobot/config.json" in paths
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---
class CastTestTool(Tool):
"""Minimal tool for testing cast_params."""
def __init__(self, schema: dict[str, Any]) -> None:
self._schema = schema
@property
def name(self) -> str:
return "cast_test"
@property
def description(self) -> str:
return "test tool for casting"
@property
def parameters(self) -> dict[str, Any]:
return self._schema
async def execute(self, **kwargs: Any) -> str:
return "ok"
def test_cast_params_string_to_int() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "42"})
assert result["count"] == 42
assert isinstance(result["count"], int)
def test_cast_params_string_to_number() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "3.14"})
assert result["rate"] == 3.14
assert isinstance(result["rate"], float)
def test_cast_params_string_to_bool() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"enabled": {"type": "boolean"}},
}
)
assert tool.cast_params({"enabled": "true"})["enabled"] is True
assert tool.cast_params({"enabled": "false"})["enabled"] is False
assert tool.cast_params({"enabled": "1"})["enabled"] is True
def test_cast_params_array_items() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"nums": {"type": "array", "items": {"type": "integer"}},
},
}
)
result = tool.cast_params({"nums": ["1", "2", "3"]})
assert result["nums"] == [1, 2, 3]
def test_cast_params_nested_object() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"port": {"type": "integer"},
"debug": {"type": "boolean"},
},
},
},
}
)
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
assert result["config"]["port"] == 8080
assert result["config"]["debug"] is True
def test_cast_params_bool_not_cast_to_int() -> None:
"""Booleans should not be silently cast to integers."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": True})
assert result["count"] is True
errors = tool.validate_params(result)
assert any("count should be integer" in e for e in errors)
def test_cast_params_preserves_empty_string() -> None:
"""Empty strings should be preserved for string type."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": "string"}},
}
)
result = tool.cast_params({"name": ""})
assert result["name"] == ""
def test_cast_params_bool_string_false() -> None:
"""Test that 'false', '0', 'no' strings convert to False."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
assert tool.cast_params({"flag": "false"})["flag"] is False
assert tool.cast_params({"flag": "False"})["flag"] is False
assert tool.cast_params({"flag": "0"})["flag"] is False
assert tool.cast_params({"flag": "no"})["flag"] is False
assert tool.cast_params({"flag": "NO"})["flag"] is False
def test_cast_params_bool_string_invalid() -> None:
"""Invalid boolean strings should not be cast."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
# Invalid strings should be preserved (validation will catch them)
result = tool.cast_params({"flag": "random"})
assert result["flag"] == "random"
result = tool.cast_params({"flag": "maybe"})
assert result["flag"] == "maybe"
def test_cast_params_invalid_string_to_int() -> None:
"""Invalid strings should not be cast to integer."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "abc"})
assert result["count"] == "abc" # Original value preserved
result = tool.cast_params({"count": "12.5.7"})
assert result["count"] == "12.5.7"
def test_cast_params_invalid_string_to_number() -> None:
"""Invalid strings should not be cast to number."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "not_a_number"})
assert result["rate"] == "not_a_number"
def test_validate_params_bool_not_accepted_as_number() -> None:
"""Booleans should not pass number validation."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
errors = tool.validate_params({"rate": False})
assert any("rate should be number" in e for e in errors)
def test_cast_params_none_values() -> None:
"""Test None handling for different types."""
tool = CastTestTool(
{
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
"items": {"type": "array"},
"config": {"type": "object"},
},
}
)
result = tool.cast_params(
{
"name": None,
"count": None,
"items": None,
"config": None,
}
)
# None should be preserved for all types
assert result["name"] is None
assert result["count"] is None
assert result["items"] is None
assert result["config"] is None
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
"""Single values should NOT be automatically wrapped into arrays."""
tool = CastTestTool(
{
"type": "object",
"properties": {"items": {"type": "array"}},
}
)
# Non-array values should be preserved (validation will catch them)
result = tool.cast_params({"items": 5})
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
"""Exit code should appear in output even on success (exit 0)."""
tool = ExecTool()
result = await tool.execute(command="echo hello")
assert "Exit code: 0" in result
assert "hello" in result
async def test_exec_head_tail_truncation() -> None:
"""Long output should preserve both head and tail."""
tool = ExecTool()
# Generate output that exceeds _MAX_OUTPUT
big = "A" * 6000 + "\n" + "B" * 6000
result = await tool.execute(command=f"echo '{big}'")
assert "chars truncated" in result
# Head portion should start with As
assert result.startswith("A")
# Tail portion should end with the exit code which comes after Bs
assert "Exit code:" in result
async def test_exec_timeout_parameter() -> None:
"""LLM-supplied timeout should override the constructor default."""
tool = ExecTool(timeout=60)
# A very short timeout should cause the command to be killed
result = await tool.execute(command="sleep 10", timeout=1)
assert "timed out" in result
assert "1 seconds" in result
async def test_exec_timeout_capped_at_max() -> None:
"""Timeout values above _MAX_TIMEOUT should be clamped."""
tool = ExecTool()
# Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result