merge main into pr-673 and keep slack empty-text fallback without regressing thread/media support

This commit is contained in:
Re-bin
2026-03-07 16:51:48 +00:00
88 changed files with 9997 additions and 2492 deletions

.gitignore (3 changes)

@@ -1,3 +1,4 @@
.worktrees/
.assets
.env
*.pyc
@@ -19,4 +20,4 @@ __pycache__/
poetry.lock
.pytest_cache/
botpy.log
tests/

README.md (436 changes)

@@ -12,26 +12,48 @@
</p>
</div>

🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).

⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.

📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.

## 📢 News

- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
<details>
<summary>Earlier news</summary>
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
- **2026-02-15** 🔑 nanobot now supports OpenAI Codex provider with OAuth login support.
- **2026-02-14** 🔌 nanobot now supports MCP! See [MCP section](#mcp-model-context-protocol) for details.
- **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
- **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
- **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
- **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
- **2026-02-07** 🚀 Released **v0.1.3.post5** with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details.
- **2026-02-06** ✨ Added Moonshot/Kimi provider, Discord integration, and enhanced security hardening!
- **2026-02-05** ✨ Added Feishu channel, DeepSeek provider, and enhanced scheduled tasks support!
- **2026-02-04** 🚀 Released **v0.1.3.post4** with multi-provider & Docker support! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post4) for details.
- **2026-02-03** ⚡ Integrated vLLM for local LLM support and improved natural language task scheduling!
- **2026-02-02** 🎉 nanobot officially launched! Welcome to try 🐈 nanobot!
</details>
## Key Features of nanobot:

🪶 **Ultra-Lightweight**: Just ~4,000 lines of core agent code — 99% smaller than Clawdbot.
@@ -107,17 +129,26 @@ nanobot onboard
**2. Configure** (`~/.nanobot/config.json`)

Add or merge these **two parts** into your config (other options have defaults).
*Set your API key* (e.g. OpenRouter, recommended for global users):
```json
{
  "providers": {
    "openrouter": {
      "apiKey": "sk-or-v1-xxx"
    }
  }
}
```

*Set your model* (optionally pin a provider — defaults to auto-detection):

```json
{
  "agents": {
    "defaults": {
      "model": "anthropic/claude-opus-4-5",
      "provider": "openrouter"
    }
  }
}
```
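Merged, the two parts form a single config file (the API key shown is a placeholder):

```json
{
  "providers": {
    "openrouter": { "apiKey": "sk-or-v1-xxx" }
  },
  "agents": {
    "defaults": {
      "model": "anthropic/claude-opus-4-5",
      "provider": "openrouter"
    }
  }
}
```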
@@ -126,63 +157,26 @@ For OpenRouter - recommended for global users:
**3. Chat**

```bash
nanobot agent
```

That's it! You have a working AI assistant in 2 minutes.
## 💬 Chat Apps

Connect nanobot to your favorite chat platform.

| Channel | What you need |
|---------|---------------|
| **Telegram** | Bot token from @BotFather |
| **Discord** | Bot token + Message Content intent |
| **WhatsApp** | QR code scan |
| **Feishu** | App ID + App Secret |
| **Mochat** | Claw token (auto-setup available) |
| **DingTalk** | App Key + App Secret |
| **Slack** | Bot token + App-Level token |
| **Email** | IMAP/SMTP credentials |
| **QQ** | App ID + App Secret |

<details>
<summary><b>Telegram</b> (Recommended)</summary>
@@ -299,12 +293,18 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
"discord": { "discord": {
"enabled": true, "enabled": true,
"token": "YOUR_BOT_TOKEN", "token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"] "allowFrom": ["YOUR_USER_ID"],
"groupPolicy": "mention"
} }
} }
} }
``` ```
> `groupPolicy` controls how the bot responds in group channels:
> - `"mention"` (default) — Only respond when @mentioned
> - `"open"` — Respond to all messages
> DMs always respond when the sender is in `allowFrom`.
**5. Invite the bot**
- OAuth2 → URL Generator
- Scopes: `bot`
@@ -319,6 +319,72 @@ nanobot gateway
</details>
<details>
<summary><b>Matrix (Element)</b></summary>
Install Matrix dependencies first:
```bash
pip install nanobot-ai[matrix]
```
**1. Create/choose a Matrix account**
- Create or reuse a Matrix account on your homeserver (for example `matrix.org`).
- Confirm you can log in with Element.
**2. Get credentials**
- You need:
- `userId` (example: `@nanobot:matrix.org`)
- `accessToken`
- `deviceId` (recommended so sync tokens can be restored across restarts)
- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings.
**3. Configure**
```json
{
  "channels": {
    "matrix": {
      "enabled": true,
      "homeserver": "https://matrix.org",
      "userId": "@nanobot:matrix.org",
      "accessToken": "syt_xxx",
      "deviceId": "NANOBOT01",
      "e2eeEnabled": true,
      "allowFrom": ["@your_user:matrix.org"],
      "groupPolicy": "open",
      "groupAllowFrom": [],
      "allowRoomMentions": false,
      "maxMediaBytes": 20971520
    }
  }
}
```
> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts.
| Option | Description |
|--------|-------------|
| `allowFrom` | User IDs allowed to interact. Empty = all senders. |
| `groupPolicy` | `open` (default), `mention`, or `allowlist`. |
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |
**4. Run**
```bash
nanobot gateway
```
</details>
<details>
<summary><b>WhatsApp</b></summary>
@@ -354,6 +420,10 @@ nanobot channels login
nanobot gateway
```
> WhatsApp bridge updates are not applied automatically for existing installations.
> If you upgrade nanobot and need the latest WhatsApp bridge, run:
> `rm -rf ~/.nanobot/bridge && nanobot channels login`
</details>
<details> <details>
@@ -364,7 +434,7 @@ Uses **WebSocket** long connection — no public IP required.
**1. Create a Feishu bot**
- Visit [Feishu Open Platform](https://open.feishu.cn/app)
- Create a new app → Enable **Bot** capability
- **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages)
- **Events**: Add `im.message.receive_v1` (receive messages)
- Select **Long Connection** mode (requires running nanobot first to establish connection)
- Get **App ID** and **App Secret** from "Credentials & Basic Info"
@@ -381,14 +451,14 @@ Uses **WebSocket** long connection — no public IP required.
"appSecret": "xxx", "appSecret": "xxx",
"encryptKey": "", "encryptKey": "",
"verificationToken": "", "verificationToken": "",
"allowFrom": [] "allowFrom": ["ou_YOUR_OPEN_ID"]
} }
} }
} }
``` ```
> `encryptKey` and `verificationToken` are optional for Long Connection mode. > `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Leave empty to allow all users, or add `["ou_xxx"]` to restrict access. > `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
**3. Run** **3. Run**
@@ -418,7 +488,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
**3. Configure**

> - `allowFrom`: Add your openid (find it in nanobot logs when you message the bot). Use `["*"]` for public access.
> - For production: submit a review in the bot console and publish. See [QQ Bot Docs](https://bot.q.qq.com/wiki/) for the full publishing flow.
```json
@@ -428,7 +498,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
"enabled": true, "enabled": true,
"appId": "YOUR_APP_ID", "appId": "YOUR_APP_ID",
"secret": "YOUR_APP_SECRET", "secret": "YOUR_APP_SECRET",
"allowFrom": [] "allowFrom": ["YOUR_OPENID"]
} }
} }
} }
@@ -467,13 +537,13 @@ Uses **Stream Mode** — no public IP required.
"enabled": true, "enabled": true,
"clientId": "YOUR_APP_KEY", "clientId": "YOUR_APP_KEY",
"clientSecret": "YOUR_APP_SECRET", "clientSecret": "YOUR_APP_SECRET",
"allowFrom": [] "allowFrom": ["YOUR_STAFF_ID"]
} }
} }
} }
``` ```
> `allowFrom`: Leave empty to allow all users, or add `["staffId"]` to restrict access. > `allowFrom`: Add your staff ID. Use `["*"]` to allow all users.
**3. Run** **3. Run**
@@ -508,6 +578,7 @@ Uses **Socket Mode** — no public URL required.
"enabled": true, "enabled": true,
"botToken": "xoxb-...", "botToken": "xoxb-...",
"appToken": "xapp-...", "appToken": "xapp-...",
"allowFrom": ["YOUR_SLACK_USER_ID"],
"groupPolicy": "mention" "groupPolicy": "mention"
} }
} }
@@ -541,7 +612,7 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
**2. Configure**

> - `consentGranted` must be `true` to allow mailbox access. This is a safety gate — set `false` to fully disable.
> - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
> - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
> - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
@@ -596,27 +667,64 @@ Config file: `~/.nanobot/config.json`
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
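A coding-plan override just sets `apiBase` on the provider entry. For example, using the VolcEngine value from the note above (a sketch; the API key is a placeholder):

```json
{
  "providers": {
    "volcengine": {
      "apiKey": "sk-xxx",
      "apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"
    }
  }
}
```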
| Provider | Purpose | Get API Key |
|----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
| `vllm` | LLM (local, any OpenAI-compatible server) | — |
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
<details>
<summary><b>OpenAI Codex (OAuth)</b></summary>
Codex uses OAuth instead of API keys. Requires a ChatGPT Plus or Pro account.
**1. Login:**
```bash
nanobot provider login openai-codex
```
**2. Set model** (merge into `~/.nanobot/config.json`):
```json
{
  "agents": {
    "defaults": {
      "model": "openai-codex/gpt-5.1-codex"
    }
  }
}
```
**3. Chat:**
```bash
nanobot agent -m "Hello!"
```
> Docker users: use `docker run -it` for interactive OAuth login.
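For example (a sketch assuming the image is tagged `nanobot` as in the Docker section below):

```bash
docker run -it -v ~/.nanobot:/root/.nanobot --rm nanobot provider login openai-codex
```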
</details>
<details>
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>

Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; the model name is passed as-is.

```json
{
@@ -634,7 +742,44 @@ If your provider is not listed above but exposes an **OpenAI-compatible API** (e
}
```

> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
</details>
<details>
<summary><b>vLLM (local / OpenAI-compatible)</b></summary>
Run your own model with vLLM or any OpenAI-compatible server, then add to config:
**1. Start the server** (example):
```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
```
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
*Provider (key can be any non-empty string for local):*
```json
{
  "providers": {
    "vllm": {
      "apiKey": "dummy",
      "apiBase": "http://localhost:8000/v1"
    }
  }
}
```
*Model:*
```json
{
  "agents": {
    "defaults": {
      "model": "meta-llama/Llama-3.1-8B-Instruct"
    }
  }
}
```
</details>
@@ -699,6 +844,12 @@ Add MCP servers to your `config.json`:
"filesystem": { "filesystem": {
"command": "npx", "command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"] "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]
},
"my-remote-mcp": {
"url": "https://example.com/mcp/",
"headers": {
"Authorization": "Bearer xxxxx"
}
} }
} }
} }
@@ -710,7 +861,22 @@ Two transport modes are supported:
| Mode | Config | Example |
|------|--------|---------|
| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
| **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) |
Use `toolTimeout` to override the default 30s per-call timeout for slow servers:
```json
{
  "tools": {
    "mcpServers": {
      "my-slow-server": {
        "url": "https://example.com/mcp/",
        "toolTimeout": 120
      }
    }
  }
}
```
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
@@ -719,14 +885,44 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
### Security

> [!TIP]
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.

> **Behavior change after `v0.1.4.post3`:** in `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions (including builds from source), an empty `allowFrom` **denies all access** by default. To allow all senders, set `"allowFrom": ["*"]`.
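For example, to keep the old open behavior on a newer build, set the wildcard explicitly (a sketch; shown for Telegram, but any channel works the same way):

```json
{
  "channels": {
    "telegram": {
      "allowFrom": ["*"]
    }
  }
}
```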
| Option | Default | Description |
|--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` | Whitelist of user IDs. In `v0.1.4.post3` and earlier, empty = allow everyone; in newer builds, empty = deny all. Use `["*"]` to allow everyone. |
## Multiple Instances
Run multiple nanobot instances simultaneously, each with its own workspace and configuration.
```bash
# Instance A - Telegram bot
nanobot gateway -w ~/.nanobot/botA -p 18791
# Instance B - Discord bot
nanobot gateway -w ~/.nanobot/botB -p 18792
# Instance C - Using custom config file
nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793
```
| Option | Short | Description |
|--------|-------|-------------|
| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) |
| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) |
| `--port` | `-p` | Gateway port (default: `18790`) |
Each instance has its own:
- Workspace directory (MEMORY.md, HEARTBEAT.md, session files)
- Cron jobs storage (`workspace/cron/jobs.json`)
- Configuration (if using `--config`)
## CLI Reference

| Command | Description |
@@ -738,26 +934,30 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
| `nanobot agent --logs` | Show runtime logs during chat |
| `nanobot gateway` | Start the gateway |
| `nanobot status` | Show status |
| `nanobot provider login openai-codex` | OAuth login for providers |
| `nanobot channels login` | Link WhatsApp (scan QR) |
| `nanobot channels status` | Show channel status |

Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`.
<details>
<summary><b>Heartbeat (Periodic Tasks)</b></summary>

The gateway wakes up every 30 minutes and checks `HEARTBEAT.md` in your workspace (`~/.nanobot/workspace/HEARTBEAT.md`). If the file has tasks, the agent executes them and delivers results to your most recently active chat channel.

**Setup:** edit `~/.nanobot/workspace/HEARTBEAT.md` (created automatically by `nanobot onboard`):

```markdown
## Periodic Tasks
- [ ] Check weather forecast and send a summary
- [ ] Scan inbox for urgent emails
```

The agent can also manage this file itself — ask it to "add a periodic task" and it will update `HEARTBEAT.md` for you.

> **Note:** The gateway must be running (`nanobot gateway`) and you must have chatted with the bot at least once so it knows which channel to deliver to.

</details>
## 🐳 Docker
@@ -765,7 +965,21 @@ nanobot cron remove <job_id>
> [!TIP]
> The `-v ~/.nanobot:/root/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.

### Docker Compose
```bash
docker compose run --rm nanobot-cli onboard # first-time setup
vim ~/.nanobot/config.json # add API keys
docker compose up -d nanobot-gateway # start gateway
```
```bash
docker compose run --rm nanobot-cli agent -m "Hello!" # run CLI
docker compose logs -f nanobot-gateway # view logs
docker compose down # stop
```
### Docker
```bash
# Build the image
@@ -785,6 +999,59 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
``` ```
## 🐧 Linux Service
Run the gateway as a systemd user service so it starts automatically and restarts on failure.
**1. Find the nanobot binary path:**
```bash
which nanobot # e.g. /home/user/.local/bin/nanobot
```
**2. Create the service file** at `~/.config/systemd/user/nanobot-gateway.service` (replace `ExecStart` path if needed):
```ini
[Unit]
Description=Nanobot Gateway
After=network.target
[Service]
Type=simple
ExecStart=%h/.local/bin/nanobot gateway
Restart=always
RestartSec=10
NoNewPrivileges=yes
ProtectSystem=strict
ReadWritePaths=%h
[Install]
WantedBy=default.target
```
**3. Enable and start:**
```bash
systemctl --user daemon-reload
systemctl --user enable --now nanobot-gateway
```
**Common operations:**
```bash
systemctl --user status nanobot-gateway # check status
systemctl --user restart nanobot-gateway # restart after config changes
journalctl --user -u nanobot-gateway -f # follow logs
```
If you edit the `.service` file itself, run `systemctl --user daemon-reload` before restarting.
> **Note:** User services only run while you are logged in. To keep the gateway running after logout, enable lingering:
>
> ```bash
> loginctl enable-linger $USER
> ```
## 📁 Project Structure

```
@@ -813,7 +1080,6 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
**Roadmap** — Pick an item and [open a PR](https://github.com/HKUDS/nanobot/pulls)!

- [ ] **Multi-modal** — See and hear (images, voice, video)
- [ ] **Long-term memory** — Never forget important context
- [ ] **Better reasoning** — Multi-step planning and reflection


@@ -5,7 +5,7 @@
If you discover a security vulnerability in nanobot, please report it by:
1. **DO NOT** open a public GitHub issue
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
3. Include:
   - Description of the vulnerability
   - Steps to reproduce
@@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json
```

**Security Notes:**
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone.
- Get your Telegram user ID from `@userinfobot`
- Use full phone numbers with country code for WhatsApp
- Review access logs regularly for unauthorized access attempts
@@ -212,9 +212,8 @@ If you suspect a security breach:
- Input length limits on HTTP requests

✅ **Authentication**
- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all)
- Failed authentication attempt logging

✅ **Resource Protection**
- Command execution timeouts (60s default)


@@ -9,11 +9,17 @@ import makeWASocket, {
  useMultiFileAuthState,
  fetchLatestBaileysVersion,
  makeCacheableSignalKeyStore,
  downloadMediaMessage,
  extractMessageContent as baileysExtractMessageContent,
} from '@whiskeysockets/baileys';
import { Boom } from '@hapi/boom';
import qrcode from 'qrcode-terminal';
import pino from 'pino';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { homedir } from 'os';
import { randomBytes } from 'crypto';
const VERSION = '0.1.0';
@@ -24,6 +30,7 @@ export interface InboundMessage {
  content: string;
  timestamp: number;
  isGroup: boolean;
  media?: string[];
}

export interface WhatsAppClientOptions {
@@ -110,14 +117,33 @@ export class WhatsAppClient {
      if (type !== 'notify') return;

      for (const msg of messages) {
        // Skip own messages
        if (msg.key.fromMe) continue;
        // Skip status updates
        if (msg.key.remoteJid === 'status@broadcast') continue;

        const unwrapped = baileysExtractMessageContent(msg.message);
        if (!unwrapped) continue;

        const content = this.getTextContent(unwrapped);
        let fallbackContent: string | null = null;
        const mediaPaths: string[] = [];

        if (unwrapped.imageMessage) {
          fallbackContent = '[Image]';
          const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
          if (path) mediaPaths.push(path);
        } else if (unwrapped.documentMessage) {
          fallbackContent = '[Document]';
          const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
            unwrapped.documentMessage.fileName ?? undefined);
          if (path) mediaPaths.push(path);
        } else if (unwrapped.videoMessage) {
          fallbackContent = '[Video]';
          const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
          if (path) mediaPaths.push(path);
        }
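        // Empty-text fallback: prefer caption text; if the download failed (no media saved),
        // fall back to the '[Image]'/'[Document]'/'[Video]' placeholder so the event isn't lost.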
        const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
        if (!finalContent && mediaPaths.length === 0) continue;

        const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
@@ -125,18 +151,45 @@ export class WhatsAppClient {
          id: msg.key.id || '',
          sender: msg.key.remoteJid || '',
          pn: msg.key.remoteJidAlt || '',
          content: finalContent,
          timestamp: msg.messageTimestamp as number,
          isGroup,
          ...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
        });
      }
    });
  }
  private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
    try {
      const mediaDir = join(homedir(), '.nanobot', 'media');
      await mkdir(mediaDir, { recursive: true });

      const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;

      let outFilename: string;
      if (fileName) {
        // Documents have a filename — use it with a unique prefix to avoid collisions
        const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
        outFilename = prefix + fileName;
      } else {
        const mime = mimetype || 'application/octet-stream';
        // Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
        const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
        outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
      }

      const filepath = join(mediaDir, outFilename);
      await writeFile(filepath, buffer);
      return filepath;
    } catch (err) {
      console.error('Failed to download media:', err);
      return null;
    }
  }

  private getTextContent(message: any): string | null {
    // Text message
    if (message.conversation) {
      return message.conversation;
@@ -147,19 +200,19 @@ export class WhatsAppClient {
      return message.extendedTextMessage.text;
    }

    // Image with optional caption
    if (message.imageMessage) {
      return message.imageMessage.caption || '';
    }

    // Video with optional caption
    if (message.videoMessage) {
      return message.videoMessage.caption || '';
    }

    // Document with optional caption
    if (message.documentMessage) {
      return message.documentMessage.caption || '';
    }

    // Voice/Audio message

docker-compose.yml (new file, 31 lines)

@@ -0,0 +1,31 @@
x-common-config: &common-config
  build:
    context: .
    dockerfile: Dockerfile
  volumes:
    - ~/.nanobot:/root/.nanobot

services:
  nanobot-gateway:
    container_name: nanobot-gateway
    <<: *common-config
    command: ["gateway"]
    restart: unless-stopped
    ports:
      - 18790:18790
    deploy:
      resources:
        limits:
          cpus: '1'
          memory: 1G
        reservations:
          cpus: '0.25'
          memory: 256M

  nanobot-cli:
    <<: *common-config
    profiles:
      - cli
    command: ["status"]
    stdin_open: true
    tty: true


@@ -2,5 +2,5 @@
nanobot - A lightweight AI agent framework
"""

__version__ = "0.1.4.post3"
__logo__ = "🐈"


@@ -1,7 +1,7 @@
"""Agent core module.""" """Agent core module."""
from nanobot.agent.loop import AgentLoop
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader


@@ -3,22 +3,21 @@
import base64
import mimetypes
import platform
import time
from datetime import datetime
from pathlib import Path
from typing import Any

from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import detect_image_mime


class ContextBuilder:
    """Builds the context (system prompt + messages) for the agent."""

    BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
    _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
    def __init__(self, workspace: Path):
        self.workspace = workspace
@@ -26,39 +25,23 @@ class ContextBuilder:
        self.skills = SkillsLoader(workspace)

    def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
        """Build the system prompt from identity, bootstrap files, memory, and skills."""
        parts = [self._get_identity()]

        bootstrap = self._load_bootstrap_files()
        if bootstrap:
            parts.append(bootstrap)

        memory = self.memory.get_memory_context()
        if memory:
            parts.append(f"# Memory\n\n{memory}")

        always_skills = self.skills.get_always_skills()
        if always_skills:
            always_content = self.skills.load_skills_for_context(always_skills)
            if always_content:
                parts.append(f"# Active Skills\n\n{always_content}")

        skills_summary = self.skills.build_skills_summary()
        if skills_summary:
            parts.append(f"""# Skills
@@ -72,42 +55,41 @@ Skills with available="false" need dependencies installed first - you can try in
    def _get_identity(self) -> str:
        """Get the core identity section."""
        workspace_path = str(self.workspace.expanduser().resolve())
        system = platform.system()
        runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"

        return f"""# nanobot 🐈

You are nanobot, a helpful AI assistant.

## Runtime
{runtime}

## Workspace
Your workspace is at: {workspace_path}
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md

## nanobot Guidelines
- State intent before tool calls, but NEVER predict or claim results before receiving them.
- Before modifying a file, read it first. Do not assume files or directories exist.
- After writing or editing a file, re-read it if accuracy matters.
- If a tool call fails, analyze the error before retrying with a different approach.
- Ask for clarification when the request is ambiguous.

Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""

    @staticmethod
    def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
        """Build untrusted runtime metadata block for injection before the user message."""
        now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
        tz = time.strftime("%Z") or "UTC"
        lines = [f"Current Time: {now} ({tz})"]
        if channel and chat_id:
            lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
        return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
    def _load_bootstrap_files(self) -> str:
        """Load all bootstrap files from workspace."""
@@ -130,36 +112,22 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
        channel: str | None = None,
        chat_id: str | None = None,
    ) -> list[dict[str, Any]]:
        """Build the complete message list for an LLM call."""
        runtime_ctx = self._build_runtime_context(channel, chat_id)
        user_content = self._build_user_content(current_message, media)

        # Merge runtime context and user content into a single user message
        # to avoid consecutive same-role messages that some providers reject.
        if isinstance(user_content, str):
            merged = f"{runtime_ctx}\n\n{user_content}"
        else:
            merged = [{"type": "text", "text": runtime_ctx}] + user_content

        return [
            {"role": "system", "content": self.build_system_prompt(skill_names)},
            *history,
            {"role": "user", "content": merged},
        ]
    def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
        """Build user message content with optional base64-encoded images."""
@@ -169,10 +137,14 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
        images = []
        for path in media:
            p = Path(path)
            if not p.is_file():
                continue
            raw = p.read_bytes()
            # Detect real MIME type from magic bytes; fallback to filename guess
            mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
            if not mime or not mime.startswith("image/"):
                continue
            b64 = base64.b64encode(raw).decode()
            images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})

        if not images:
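For reference, a magic-bytes detector like the imported `detect_image_mime` could look roughly like this (a sketch; the real helper lives in `nanobot.utils.helpers` and may differ):

```python
def detect_image_mime(data: bytes) -> str | None:
    """Guess an image MIME type from leading magic bytes (illustrative sketch)."""
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    return None
```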
@@ -180,59 +152,27 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
        return images + [{"type": "text", "text": text}]

    def add_tool_result(
        self, messages: list[dict[str, Any]],
        tool_call_id: str, tool_name: str, result: str,
    ) -> list[dict[str, Any]]:
        """Add a tool result to the message list."""
        messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
        return messages

    def add_assistant_message(
        self, messages: list[dict[str, Any]],
        content: str | None,
        tool_calls: list[dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
        thinking_blocks: list[dict] | None = None,
    ) -> list[dict[str, Any]]:
        """Add an assistant message to the message list."""
        msg: dict[str, Any] = {"role": "assistant", "content": content}
        if tool_calls:
            msg["tool_calls"] = tool_calls
        # Thinking models reject history without this
        if reasoning_content:
            msg["reasoning_content"] = reasoning_content
        if thinking_blocks:
            msg["thinking_blocks"] = thinking_blocks
        messages.append(msg)
        return messages


@@ -1,29 +1,36 @@
"""Agent loop: the core processing engine.""" """Agent loop: the core processing engine."""
from __future__ import annotations
import asyncio import asyncio
from contextlib import AsyncExitStack
import json import json
import json_repair import re
import weakref
from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryStore
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.agent.context import ContextBuilder
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.memory import MemoryStore
from nanobot.agent.subagent import SubagentManager
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
if TYPE_CHECKING:
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
from nanobot.cron.service import CronService
class AgentLoop:
    """
@@ -37,26 +44,31 @@ class AgentLoop:
    5. Sends responses back
    """

    _TOOL_RESULT_MAX_CHARS = 500

    def __init__(
        self,
        bus: MessageBus,
        provider: LLMProvider,
        workspace: Path,
        model: str | None = None,
        max_iterations: int = 40,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        memory_window: int = 100,
        reasoning_effort: str | None = None,
        brave_api_key: str | None = None,
        web_proxy: str | None = None,
        exec_config: ExecToolConfig | None = None,
        cron_service: CronService | None = None,
        restrict_to_workspace: bool = False,
        session_manager: SessionManager | None = None,
        mcp_servers: dict | None = None,
        channels_config: ChannelsConfig | None = None,
    ):
        from nanobot.config.schema import ExecToolConfig

        self.bus = bus
        self.channels_config = channels_config
        self.provider = provider
        self.workspace = workspace
        self.model = model or provider.get_default_model()
@@ -64,7 +76,9 @@ class AgentLoop:
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.memory_window = memory_window
        self.reasoning_effort = reasoning_effort
        self.brave_api_key = brave_api_key
        self.web_proxy = web_proxy
        self.exec_config = exec_config or ExecToolConfig()
        self.cron_service = cron_service
        self.restrict_to_workspace = restrict_to_workspace
@@ -79,7 +93,9 @@ class AgentLoop:
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            reasoning_effort=reasoning_effort,
            brave_api_key=brave_api_key,
            web_proxy=web_proxy,
            exec_config=self.exec_config,
            restrict_to_workspace=restrict_to_workspace,
        )
@@ -88,74 +104,85 @@ class AgentLoop:
self._mcp_servers = mcp_servers or {} self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._processing_lock = asyncio.Lock()
self._register_default_tools() self._register_default_tools()
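The per-session consolidation locks above live in a weakref.WeakValueDictionary, so a lock entry vanishes once no coroutine holds a strong reference to it; idle sessions do not accumulate Lock objects. A minimal sketch of the pattern (session key invented):

import asyncio
import weakref

locks: "weakref.WeakValueDictionary[str, asyncio.Lock]" = weakref.WeakValueDictionary()

async def with_session_lock(key: str) -> None:
    lock = locks.setdefault(key, asyncio.Lock())  # local strong ref keeps the entry alive
    async with lock:
        ...  # critical section; the entry is collected after `lock` goes out of scope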
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
"""Register the default set of tools.""" """Register the default set of tools."""
# File tools (restrict to workspace if configured)
allowed_dir = self.workspace if self.restrict_to_workspace else None allowed_dir = self.workspace if self.restrict_to_workspace else None
self.tools.register(ReadFileTool(allowed_dir=allowed_dir)) for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(WriteFileTool(allowed_dir=allowed_dir)) self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(EditFileTool(allowed_dir=allowed_dir))
self.tools.register(ListDirTool(allowed_dir=allowed_dir))
# Shell tool
self.tools.register(ExecTool( self.tools.register(ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),
timeout=self.exec_config.timeout, timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
)) ))
self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
# Web tools self.tools.register(WebFetchTool(proxy=self.web_proxy))
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(WebFetchTool()) self.tools.register(SpawnTool(manager=self.subagents))
# Message tool
message_tool = MessageTool(send_callback=self.bus.publish_outbound)
self.tools.register(message_tool)
# Spawn tool (for subagents)
spawn_tool = SpawnTool(manager=self.subagents)
self.tools.register(spawn_tool)
# Cron tool (for scheduling)
if self.cron_service: if self.cron_service:
self.tools.register(CronTool(self.cron_service)) self.tools.register(CronTool(self.cron_service))
async def _connect_mcp(self) -> None: async def _connect_mcp(self) -> None:
"""Connect to configured MCP servers (one-time, lazy).""" """Connect to configured MCP servers (one-time, lazy)."""
if self._mcp_connected or not self._mcp_servers: if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
return return
self._mcp_connected = True self._mcp_connecting = True
from nanobot.agent.tools.mcp import connect_mcp_servers from nanobot.agent.tools.mcp import connect_mcp_servers
try:
self._mcp_stack = AsyncExitStack() self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__() await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True
except Exception as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack:
try:
await self._mcp_stack.aclose()
except Exception:
pass
self._mcp_stack = None
finally:
self._mcp_connecting = False
def _set_tool_context(self, channel: str, chat_id: str) -> None: def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Update context for all tools that need routing info.""" """Update context for all tools that need routing info."""
if message_tool := self.tools.get("message"): for name in ("message", "spawn", "cron"):
if isinstance(message_tool, MessageTool): if tool := self.tools.get(name):
message_tool.set_context(channel, chat_id) if hasattr(tool, "set_context"):
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
if spawn_tool := self.tools.get("spawn"): @staticmethod
if isinstance(spawn_tool, SpawnTool): def _strip_think(text: str | None) -> str | None:
spawn_tool.set_context(channel, chat_id) """Remove <think>…</think> blocks that some models embed in content."""
if not text:
return None
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
if cron_tool := self.tools.get("cron"): @staticmethod
if isinstance(cron_tool, CronTool): def _tool_hint(tool_calls: list) -> str:
cron_tool.set_context(channel, chat_id) """Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc):
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str):
return tc.name
return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")'
return ", ".join(_fmt(tc) for tc in tool_calls)
async def _run_agent_loop(self, initial_messages: list[dict]) -> tuple[str | None, list[str]]: async def _run_agent_loop(
""" self,
Run the agent iteration loop. initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None,
Args: ) -> tuple[str | None, list[str], list[dict]]:
initial_messages: Starting messages for the LLM conversation. """Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
Returns:
Tuple of (final_content, list_of_tools_used).
"""
messages = initial_messages messages = initial_messages
iteration = 0 iteration = 0
final_content = None final_content = None
@@ -170,16 +197,23 @@ class AgentLoop:
model=self.model, model=self.model,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
if on_progress:
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [ tool_call_dicts = [
{ {
"id": tc.id, "id": tc.id,
"type": "function", "type": "function",
"function": { "function": {
"name": tc.name, "name": tc.name,
"arguments": json.dumps(tc.arguments) "arguments": json.dumps(tc.arguments, ensure_ascii=False)
} }
} }
for tc in response.tool_calls for tc in response.tool_calls
@@ -187,49 +221,98 @@ class AgentLoop:
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, tool_call_dicts, messages, response.content, tool_call_dicts,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
) )
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
tools_used.append(tool_call.name) tools_used.append(tool_call.name)
args_str = json.dumps(tool_call.arguments, ensure_ascii=False) args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
result = await self.tools.execute(tool_call.name, tool_call.arguments) result = await self.tools.execute(tool_call.name, tool_call.arguments)
messages = self.context.add_tool_result( messages = self.context.add_tool_result(
messages, tool_call.id, tool_call.name, result messages, tool_call.id, tool_call.name, result
) )
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
else: else:
final_content = response.content clean = self._strip_think(response.content)
# Don't persist error responses to session history — they can
# poison the context and cause permanent 400 loops (#1303).
if response.finish_reason == "error":
logger.error("LLM returned error: {}", (clean or "")[:200])
final_content = clean or "Sorry, I encountered an error calling the AI model."
break
messages = self.context.add_assistant_message(
messages, clean, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
final_content = clean
break break
return final_content, tools_used if final_content is None and iteration >= self.max_iterations:
logger.warning("Max iterations ({}) reached", self.max_iterations)
final_content = (
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
"without completing the task. You can try breaking the task into smaller steps."
)
return final_content, tools_used, messages
async def run(self) -> None: async def run(self) -> None:
"""Run the agent loop, processing messages from the bus.""" """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
self._running = True self._running = True
await self._connect_mcp() await self._connect_mcp()
logger.info("Agent loop started") logger.info("Agent loop started")
while self._running: while self._running:
try: try:
msg = await asyncio.wait_for( msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
self.bus.consume_inbound(),
timeout=1.0
)
try:
response = await self._process_message(msg)
if response:
await self.bus.publish_outbound(response)
except Exception as e:
logger.error(f"Error processing message: {e}")
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}"
))
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
if msg.content.strip().lower() == "/stop":
await self._handle_stop(msg)
else:
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task)
def _discard(t: asyncio.Task, k: str = msg.session_key) -> None:
    if t in self._active_tasks.get(k, []):
        self._active_tasks[k].remove(t)
task.add_done_callback(_discard)
async def _handle_stop(self, msg: InboundMessage) -> None:
"""Cancel all active tasks and subagents for the session."""
tasks = self._active_tasks.pop(msg.session_key, [])
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
for t in tasks:
try:
await t
except (asyncio.CancelledError, Exception):
pass
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
total = cancelled + sub_cancelled
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
))
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message under the global lock."""
async with self._processing_lock:
try:
response = await self._process_message(msg)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
))
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Close MCP connections.""" """Close MCP connections."""
if self._mcp_stack: if self._mcp_stack:
@@ -244,206 +327,172 @@ class AgentLoop:
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
async def _process_message(self, msg: InboundMessage, session_key: str | None = None) -> OutboundMessage | None: async def _process_message(
""" self,
Process a single inbound message. msg: InboundMessage,
session_key: str | None = None,
Args: on_progress: Callable[[str], Awaitable[None]] | None = None,
msg: The inbound message to process. ) -> OutboundMessage | None:
session_key: Override session key (used by process_direct). """Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
Returns:
The response message, or None if no response needed.
"""
# System messages route back via chat_id ("channel:chat_id")
if msg.channel == "system": if msg.channel == "system":
return await self._process_system_message(msg) channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
else ("cli", msg.chat_id))
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=self.memory_window)
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
)
final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.")
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}") logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
key = session_key or msg.session_key key = session_key or msg.session_key
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
# Handle slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
# Capture messages before clearing (avoid race condition with background task) lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
messages_to_archive = session.messages.copy() self._consolidating.add(session.key)
try:
async with lock:
snapshot = session.messages[session.last_consolidated:]
if snapshot:
temp = Session(key=session.key)
temp.messages = list(snapshot)
if not await self._consolidate_memory(temp, archive_all=True):
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
except Exception:
logger.exception("/new archival failed for {}", session.key)
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
finally:
self._consolidating.discard(session.key)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
self.sessions.invalidate(session.key) self.sessions.invalidate(session.key)
async def _consolidate_and_cleanup():
temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive
await self._consolidate_memory(temp_session, archive_all=True)
asyncio.create_task(_consolidate_and_cleanup())
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.") content="New session started.")
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands") content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
if len(session.messages) > self.memory_window: unconsolidated = len(session.messages) - session.last_consolidated
asyncio.create_task(self._consolidate_memory(session)) if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
self._consolidating.add(session.key)
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
self._set_tool_context(msg.channel, msg.chat_id) async def _consolidate_and_unlock():
try:
async with lock:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)
_task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(_task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
history = session.get_history(max_messages=self.memory_window)
initial_messages = self.context.build_messages( initial_messages = self.context.build_messages(
history=session.get_history(max_messages=self.memory_window), history=history,
current_message=msg.content, current_message=msg.content,
media=msg.media if msg.media else None, media=msg.media if msg.media else None,
channel=msg.channel, channel=msg.channel, chat_id=msg.chat_id,
chat_id=msg.chat_id, )
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {})
meta["_progress"] = True
meta["_tool_hint"] = tool_hint
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
) )
final_content, tools_used = await self._run_agent_loop(initial_messages)
if final_content is None: if final_content is None:
final_content = "I've completed processing but have no response to give." final_content = "I've completed processing but have no response to give."
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}") logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
session.add_message("user", msg.content)
session.add_message("assistant", final_content,
tools_used=tools_used if tools_used else None)
self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel, chat_id=msg.chat_id, content=final_content,
chat_id=msg.chat_id, metadata=msg.metadata or {},
content=final_content,
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
) )
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
""" """Save new-turn messages into session, truncating large tool results."""
Process a system message (e.g., subagent announce). from datetime import datetime
for m in messages[skip:]:
The chat_id field contains "original_channel:original_chat_id" to route entry = dict(m)
the response back to the correct destination. role, content = entry.get("role"), entry.get("content")
""" if role == "assistant" and not content and not entry.get("tool_calls"):
logger.info(f"Processing system message from {msg.sender_id}") continue # skip empty assistant messages — they poison session context
if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
# Parse origin from chat_id (format: "channel:chat_id") entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
if ":" in msg.chat_id: elif role == "user":
parts = msg.chat_id.split(":", 1) if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
origin_channel = parts[0] # Strip the runtime-context prefix, keep only the user text.
origin_chat_id = parts[1] parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip():
entry["content"] = parts[1]
else: else:
# Fallback
origin_channel = "cli"
origin_chat_id = msg.chat_id
session_key = f"{origin_channel}:{origin_chat_id}"
session = self.sessions.get_or_create(session_key)
self._set_tool_context(origin_channel, origin_chat_id)
initial_messages = self.context.build_messages(
history=session.get_history(max_messages=self.memory_window),
current_message=msg.content,
channel=origin_channel,
chat_id=origin_chat_id,
)
final_content, _ = await self._run_agent_loop(initial_messages)
if final_content is None:
final_content = "Background task completed."
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
session.add_message("assistant", final_content)
self.sessions.save(session)
return OutboundMessage(
channel=origin_channel,
chat_id=origin_chat_id,
content=final_content
)
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
"""Consolidate old messages into MEMORY.md + HISTORY.md.
Args:
archive_all: If True, clear all messages and reset session (for /new command).
If False, only write to files without modifying session.
"""
memory = MemoryStore(self.workspace)
if archive_all:
old_messages = session.messages
keep_count = 0
logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
else:
keep_count = self.memory_window // 2
if len(session.messages) <= keep_count:
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
return
messages_to_process = len(session.messages) - session.last_consolidated
if messages_to_process <= 0:
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
return
old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages:
return
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
lines = []
for m in old_messages:
if not m.get("content"):
continue continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" if isinstance(content, list):
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") filtered = []
conversation = "\n".join(lines) for c in content:
current_memory = memory.read_long_term() if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
continue # Strip runtime context from multimodal messages
prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys: if (c.get("type") == "image_url"
and c.get("image_url", {}).get("url", "").startswith("data:image/")):
1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later. filtered.append({"type": "text", "text": "[image]"})
2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
## Current Long-term Memory
{current_memory or "(empty)"}
## Conversation to Process
{conversation}
Respond with ONLY valid JSON, no markdown fences."""
try:
response = await self.provider.chat(
messages=[
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
{"role": "user", "content": prompt},
],
model=self.model,
)
text = (response.content or "").strip()
if not text:
logger.warning("Memory consolidation: LLM returned empty response, skipping")
return
if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text)
if not isinstance(result, dict):
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
return
if entry := result.get("history_entry"):
memory.append_history(entry)
if update := result.get("memory_update"):
if update != current_memory:
memory.write_long_term(update)
if archive_all:
session.last_consolidated = 0
else: else:
session.last_consolidated = len(session.messages) - keep_count filtered.append(c)
logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}") if not filtered:
except Exception as e: continue
logger.error(f"Memory consolidation failed: {e}") entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
session.updated_at = datetime.now()
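A small sketch of what the filters above persist, assuming the default 500-char tool-result cap:

turn = [
    {"role": "assistant", "content": "", "tool_calls": None},  # dropped: empty assistant message
    {"role": "tool", "content": "x" * 1200},  # stored as the first 500 chars + "\n... (truncated)"
]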
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
return await MemoryStore(self.workspace).consolidate(
session, self.provider, self.model,
archive_all=archive_all, memory_window=self.memory_window,
)
async def process_direct( async def process_direct(
self, self,
@@ -451,26 +500,10 @@ Respond with ONLY valid JSON, no markdown fences."""
session_key: str = "cli:direct", session_key: str = "cli:direct",
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
on_progress: Callable[[str], Awaitable[None]] | None = None,
) -> str: ) -> str:
""" """Process a message directly (for CLI or cron usage)."""
Process a message directly (for CLI or cron usage).
Args:
content: The message content.
session_key: Session identifier (overrides channel:chat_id for session lookup).
channel: Source channel (for tool context routing).
chat_id: Source chat ID (for tool context routing).
Returns:
The agent's response.
"""
await self._connect_mcp() await self._connect_mcp()
msg = InboundMessage( msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
channel=channel, response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
sender_id="user",
chat_id=chat_id,
content=content
)
response = await self._process_message(msg, session_key=session_key)
return response.content if response else "" return response.content if response else ""
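A minimal usage sketch for the new on_progress hook; only the callback signature is taken from this diff, the printing body and the agent instance are assumptions:

async def print_progress(content: str, *, tool_hint: bool = False) -> None:
    print(("-> " if tool_hint else "... ") + content)

reply = await agent.process_direct("summarize TODO.md", on_progress=print_progress)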


@@ -1,9 +1,46 @@
"""Memory system for persistent agent memory.""" """Memory system for persistent agent memory."""
from __future__ import annotations
import json
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from loguru import logger
from nanobot.utils.helpers import ensure_dir from nanobot.utils.helpers import ensure_dir
if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session
_SAVE_MEMORY_TOOL = [
{
"type": "function",
"function": {
"name": "save_memory",
"description": "Save the memory consolidation result to persistent storage.",
"parameters": {
"type": "object",
"properties": {
"history_entry": {
"type": "string",
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
},
"memory_update": {
"type": "string",
"description": "Full updated long-term memory as markdown. Include all existing "
"facts plus new ones. Return unchanged if nothing new.",
},
},
"required": ["history_entry", "memory_update"],
},
},
}
]
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -28,3 +65,93 @@ class MemoryStore:
def get_memory_context(self) -> str: def get_memory_context(self) -> str:
long_term = self.read_long_term() long_term = self.read_long_term()
return f"## Long-term Memory\n{long_term}" if long_term else "" return f"## Long-term Memory\n{long_term}" if long_term else ""
async def consolidate(
self,
session: Session,
provider: LLMProvider,
model: str,
*,
archive_all: bool = False,
memory_window: int = 50,
) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
Returns True on success (including no-op), False on failure.
"""
if archive_all:
old_messages = session.messages
keep_count = 0
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
else:
keep_count = memory_window // 2
if len(session.messages) <= keep_count:
return True
if len(session.messages) - session.last_consolidated <= 0:
return True
old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages:
return True
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
lines = []
for m in old_messages:
if not m.get("content"):
continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
current_memory = self.read_long_term()
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
## Current Long-term Memory
{current_memory or "(empty)"}
## Conversation to Process
{chr(10).join(lines)}"""
try:
response = await provider.chat(
messages=[
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
],
tools=_SAVE_MEMORY_TOOL,
model=model,
)
if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return False
args = response.tool_calls[0].arguments
# Some providers return arguments as a JSON string instead of dict
if isinstance(args, str):
args = json.loads(args)
# Some providers return arguments as a list (handle edge case)
if isinstance(args, list):
if args and isinstance(args[0], dict):
args = args[0]
else:
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
return False
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False
if entry := args.get("history_entry"):
if not isinstance(entry, str):
entry = json.dumps(entry, ensure_ascii=False)
self.append_history(entry)
if update := args.get("memory_update"):
if not isinstance(update, str):
update = json.dumps(update, ensure_ascii=False)
if update != current_memory:
self.write_long_term(update)
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
return True
except Exception:
logger.exception("Memory consolidation failed")
return False
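With consolidation moved onto MemoryStore, callers outside AgentLoop can drive it directly; a hedged sketch (session, provider, and model are whatever the caller already holds):

store = MemoryStore(workspace)
ok = await store.consolidate(session, provider, model, archive_all=True)
if not ok:
    ...  # caller decides policy; AgentLoop refuses to clear the session on failure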


@@ -134,7 +134,7 @@ class SkillsLoader:
if missing: if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>") lines.append(f" <requires>{escape_xml(missing)}</requires>")
lines.append(f" </skill>") lines.append(" </skill>")
lines.append("</skills>") lines.append("</skills>")
return "\n".join(lines) return "\n".join(lines)
@@ -167,10 +167,10 @@ class SkillsLoader:
return content return content
def _parse_nanobot_metadata(self, raw: str) -> dict: def _parse_nanobot_metadata(self, raw: str) -> dict:
"""Parse nanobot metadata JSON from frontmatter.""" """Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
try: try:
data = json.loads(raw) data = json.loads(raw)
return data.get("nanobot", {}) if isinstance(data, dict) else {} return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}
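The widened metadata lookup means skills exported for OpenClaw load unmodified; an illustrative frontmatter blob (contents invented):

import json

raw = '{"openclaw": {"requires": ["ffmpeg"]}}'  # no "nanobot" key present
data = json.loads(raw)
meta = data.get("nanobot", data.get("openclaw", {}))  # -> {"requires": ["ffmpeg"]}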


@@ -8,23 +8,18 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
class SubagentManager: class SubagentManager:
""" """Manages background subagent execution."""
Manages background subagent execution.
Subagents are lightweight agent instances that run in the background
to handle specific tasks. They share the same LLM provider but have
isolated context and a focused system prompt.
"""
def __init__( def __init__(
self, self,
@@ -34,7 +29,9 @@ class SubagentManager:
model: str | None = None, model: str | None = None,
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 4096, max_tokens: int = 4096,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None, exec_config: "ExecToolConfig | None" = None,
restrict_to_workspace: bool = False, restrict_to_workspace: bool = False,
): ):
@@ -45,10 +42,13 @@ class SubagentManager:
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self._running_tasks: dict[str, asyncio.Task[None]] = {} self._running_tasks: dict[str, asyncio.Task[None]] = {}
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
async def spawn( async def spawn(
self, self,
@@ -56,37 +56,30 @@ class SubagentManager:
label: str | None = None, label: str | None = None,
origin_channel: str = "cli", origin_channel: str = "cli",
origin_chat_id: str = "direct", origin_chat_id: str = "direct",
session_key: str | None = None,
) -> str: ) -> str:
""" """Spawn a subagent to execute a task in the background."""
Spawn a subagent to execute a task in the background.
Args:
task: The task description for the subagent.
label: Optional human-readable label for the task.
origin_channel: The channel to announce results to.
origin_chat_id: The chat ID to announce results to.
Returns:
Status message indicating the subagent was started.
"""
task_id = str(uuid.uuid4())[:8] task_id = str(uuid.uuid4())[:8]
display_label = label or task[:30] + ("..." if len(task) > 30 else "") display_label = label or task[:30] + ("..." if len(task) > 30 else "")
origin = {"channel": origin_channel, "chat_id": origin_chat_id}
origin = {
"channel": origin_channel,
"chat_id": origin_chat_id,
}
# Create background task
bg_task = asyncio.create_task( bg_task = asyncio.create_task(
self._run_subagent(task_id, task, display_label, origin) self._run_subagent(task_id, task, display_label, origin)
) )
self._running_tasks[task_id] = bg_task self._running_tasks[task_id] = bg_task
if session_key:
self._session_tasks.setdefault(session_key, set()).add(task_id)
# Cleanup when done def _cleanup(_: asyncio.Task) -> None:
bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None)) self._running_tasks.pop(task_id, None)
if session_key and (ids := self._session_tasks.get(session_key)):
ids.discard(task_id)
if not ids:
del self._session_tasks[session_key]
logger.info(f"Spawned subagent [{task_id}]: {display_label}") bg_task.add_done_callback(_cleanup)
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes." return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
async def _run_subagent( async def _run_subagent(
@@ -97,26 +90,26 @@ class SubagentManager:
origin: dict[str, str], origin: dict[str, str],
) -> None: ) -> None:
"""Execute the subagent task and announce the result.""" """Execute the subagent task and announce the result."""
logger.info(f"Subagent [{task_id}] starting task: {label}") logger.info("Subagent [{}] starting task: {}", task_id, label)
try: try:
# Build subagent tools (no message tool, no spawn tool) # Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry() tools = ToolRegistry()
allowed_dir = self.workspace if self.restrict_to_workspace else None allowed_dir = self.workspace if self.restrict_to_workspace else None
tools.register(ReadFileTool(allowed_dir=allowed_dir)) tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(WriteFileTool(allowed_dir=allowed_dir)) tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(EditFileTool(allowed_dir=allowed_dir)) tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ListDirTool(allowed_dir=allowed_dir)) tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ExecTool( tools.register(ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),
timeout=self.exec_config.timeout, timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
)) ))
tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
tools.register(WebFetchTool()) tools.register(WebFetchTool(proxy=self.web_proxy))
# Build messages with subagent-specific prompt system_prompt = self._build_subagent_prompt()
system_prompt = self._build_subagent_prompt(task)
messages: list[dict[str, Any]] = [ messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": task}, {"role": "user", "content": task},
@@ -136,6 +129,7 @@ class SubagentManager:
model=self.model, model=self.model,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
@@ -146,7 +140,7 @@ class SubagentManager:
"type": "function", "type": "function",
"function": { "function": {
"name": tc.name, "name": tc.name,
"arguments": json.dumps(tc.arguments), "arguments": json.dumps(tc.arguments, ensure_ascii=False),
}, },
} }
for tc in response.tool_calls for tc in response.tool_calls
@@ -159,8 +153,8 @@ class SubagentManager:
# Execute tools # Execute tools
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
args_str = json.dumps(tool_call.arguments) args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug(f"Subagent [{task_id}] executing: {tool_call.name} with arguments: {args_str}") logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
result = await tools.execute(tool_call.name, tool_call.arguments) result = await tools.execute(tool_call.name, tool_call.arguments)
messages.append({ messages.append({
"role": "tool", "role": "tool",
@@ -175,12 +169,12 @@ class SubagentManager:
if final_result is None: if final_result is None:
final_result = "Task completed but no final response was generated." final_result = "Task completed but no final response was generated."
logger.info(f"Subagent [{task_id}] completed successfully") logger.info("Subagent [{}] completed successfully", task_id)
await self._announce_result(task_id, label, task, final_result, origin, "ok") await self._announce_result(task_id, label, task, final_result, origin, "ok")
except Exception as e: except Exception as e:
error_msg = f"Error: {str(e)}" error_msg = f"Error: {str(e)}"
logger.error(f"Subagent [{task_id}] failed: {e}") logger.error("Subagent [{}] failed: {}", task_id, e)
await self._announce_result(task_id, label, task, error_msg, origin, "error") await self._announce_result(task_id, label, task, error_msg, origin, "error")
async def _announce_result( async def _announce_result(
@@ -213,44 +207,39 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
logger.debug(f"Subagent [{task_id}] announced result to {origin['channel']}:{origin['chat_id']}") logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
def _build_subagent_prompt(self, task: str) -> str: def _build_subagent_prompt(self) -> str:
"""Build a focused system prompt for the subagent.""" """Build a focused system prompt for the subagent."""
from datetime import datetime from nanobot.agent.context import ContextBuilder
import time as _time from nanobot.agent.skills import SkillsLoader
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
tz = _time.strftime("%Z") or "UTC"
return f"""# Subagent time_ctx = ContextBuilder._build_runtime_context(None, None)
parts = [f"""# Subagent
## Current Time {time_ctx}
{now} ({tz})
You are a subagent spawned by the main agent to complete a specific task. You are a subagent spawned by the main agent to complete a specific task.
Stay focused on the assigned task. Your final response will be reported back to the main agent.
## Rules
1. Stay focused - complete only the assigned task, nothing else
2. Your final response will be reported back to the main agent
3. Do not initiate conversations or take on side tasks
4. Be concise but informative in your findings
## What You Can Do
- Read and write files in the workspace
- Execute shell commands
- Search the web and fetch web pages
- Complete the task thoroughly
## What You Cannot Do
- Send messages directly to users (no message tool available)
- Spawn other subagents
- Access the main agent's conversation history
## Workspace ## Workspace
Your workspace is at: {self.workspace} {self.workspace}"""]
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
When you have completed the task, provide a clear summary of your findings or actions.""" skills_summary = SkillsLoader(self.workspace).build_skills_summary()
if skills_summary:
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
return "\n\n".join(parts)
async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents for the given session. Returns count cancelled."""
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
if tid in self._running_tasks and not self._running_tasks[tid].done()]
for t in tasks:
t.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
return len(tasks)
def get_running_count(self) -> int: def get_running_count(self) -> int:
"""Return the number of currently running subagents.""" """Return the number of currently running subagents."""


@@ -52,8 +52,79 @@ class Tool(ABC):
""" """
pass pass
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""Apply safe schema-driven casts before validation."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
return params
return self._cast_object(params, schema)
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
"""Cast an object (dict) according to schema."""
if not isinstance(obj, dict):
return obj
props = schema.get("properties", {})
result = {}
for key, value in obj.items():
if key in props:
result[key] = self._cast_value(value, props[key])
else:
result[key] = value
return result
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
"""Cast a single value according to schema."""
target_type = schema.get("type")
if target_type == "boolean" and isinstance(val, bool):
return val
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
return val
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
expected = self._TYPE_MAP[target_type]
if isinstance(val, expected):
return val
if target_type == "integer" and isinstance(val, str):
try:
return int(val)
except ValueError:
return val
if target_type == "number" and isinstance(val, str):
try:
return float(val)
except ValueError:
return val
if target_type == "string":
return val if val is None else str(val)
if target_type == "boolean" and isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
if val_lower in ("false", "0", "no"):
return False
return val
if target_type == "array" and isinstance(val, list):
item_schema = schema.get("items")
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
if target_type == "object" and isinstance(val, dict):
return self._cast_object(val, schema)
return val
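The casts above are deliberately conservative: a failed conversion returns the original value so validate_params can report it. Illustrative inputs and outputs (schema invented):

schema = {"type": "object", "properties": {
    "count": {"type": "integer"},
    "force": {"type": "boolean"},
    "ratio": {"type": "number"},
}}
# tool.cast_params({"count": "5", "force": "yes", "ratio": "0.5"})
#   -> {"count": 5, "force": True, "ratio": 0.5}
# tool.cast_params({"count": "five"})
#   -> {"count": "five"}  (left as-is; validation flags it)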
def validate_params(self, params: dict[str, Any]) -> list[str]: def validate_params(self, params: dict[str, Any]) -> list[str]:
"""Validate tool parameters against JSON schema. Returns error list (empty if valid).""" """Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
if not isinstance(params, dict):
return [f"parameters must be an object, got {type(params).__name__}"]
schema = self.parameters or {} schema = self.parameters or {}
if schema.get("type", "object") != "object": if schema.get("type", "object") != "object":
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}") raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
@@ -61,7 +132,13 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter" t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
return [f"{label} should be integer"]
if t == "number" and (
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
):
return [f"{label} should be number"]
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"] return [f"{label} should be {t}"]
errors = [] errors = []
@@ -84,10 +161,12 @@ class Tool(ABC):
errors.append(f"missing required {path + '.' + k if path else k}") errors.append(f"missing required {path + '.' + k if path else k}")
for k, v in val.items(): for k, v in val.items():
if k in props: if k in props:
errors.extend(self._validate(v, props[k], path + '.' + k if path else k)) errors.extend(self._validate(v, props[k], path + "." + k if path else k))
if t == "array" and "items" in schema: if t == "array" and "items" in schema:
for i, item in enumerate(val): for i, item in enumerate(val):
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")) errors.extend(
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
)
return errors return errors
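The integer/number special cases exist because bool is a subclass of int in Python, so a plain isinstance check would quietly accept booleans:

isinstance(True, int)  # True (the trap)
# With the branches above:
#   {"n": True} against {"n": {"type": "integer"}} -> ["n should be integer"]
#   {"n": 1}    against the same schema            -> []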
def to_schema(self) -> dict[str, Any]: def to_schema(self) -> dict[str, Any]:
@@ -98,5 +177,5 @@ class Tool(ABC):
"name": self.name, "name": self.name,
"description": self.description, "description": self.description,
"parameters": self.parameters, "parameters": self.parameters,
} },
} }


@@ -1,5 +1,6 @@
"""Cron tool for scheduling reminders and tasks.""" """Cron tool for scheduling reminders and tasks."""
from contextvars import ContextVar
from typing import Any from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -14,12 +15,21 @@ class CronTool(Tool):
self._cron = cron_service self._cron = cron_service
self._channel = "" self._channel = ""
self._chat_id = "" self._chat_id = ""
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel = channel self._channel = channel
self._chat_id = chat_id self._chat_id = chat_id
def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback."""
return self._in_cron_context.set(active)
def reset_cron_context(self, token) -> None:
"""Restore previous cron context."""
self._in_cron_context.reset(token)
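The set/reset pair follows the standard ContextVar token protocol; a sketch of how the cron service is presumably expected to wrap job execution (the wrapper name is invented):

token = cron_tool.set_cron_context(True)
try:
    await run_job_callback()  # any cron "add" issued in here is rejected
finally:
    cron_tool.reset_cron_context(token)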
@property @property
def name(self) -> str: def name(self) -> str:
return "cron" return "cron"
@@ -36,30 +46,28 @@ class CronTool(Tool):
"action": { "action": {
"type": "string", "type": "string",
"enum": ["add", "list", "remove"], "enum": ["add", "list", "remove"],
"description": "Action to perform" "description": "Action to perform",
},
"message": {
"type": "string",
"description": "Reminder message (for add)"
}, },
"message": {"type": "string", "description": "Reminder message (for add)"},
"every_seconds": { "every_seconds": {
"type": "integer", "type": "integer",
"description": "Interval in seconds (for recurring tasks)" "description": "Interval in seconds (for recurring tasks)",
}, },
"cron_expr": { "cron_expr": {
"type": "string", "type": "string",
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)" "description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
},
"tz": {
"type": "string",
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
}, },
"at": { "at": {
"type": "string", "type": "string",
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')" "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
}, },
"job_id": { "job_id": {"type": "string", "description": "Job ID (for remove)"},
"type": "string",
"description": "Job ID (for remove)"
}
}, },
"required": ["action"] "required": ["action"],
} }
async def execute( async def execute(
@@ -68,33 +76,56 @@ class CronTool(Tool):
message: str = "", message: str = "",
every_seconds: int | None = None, every_seconds: int | None = None,
cron_expr: str | None = None, cron_expr: str | None = None,
tz: str | None = None,
at: str | None = None, at: str | None = None,
job_id: str | None = None, job_id: str | None = None,
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
if action == "add": if action == "add":
return self._add_job(message, every_seconds, cron_expr, at) if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at)
elif action == "list": elif action == "list":
return self._list_jobs() return self._list_jobs()
elif action == "remove": elif action == "remove":
return self._remove_job(job_id) return self._remove_job(job_id)
return f"Unknown action: {action}" return f"Unknown action: {action}"
def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None, at: str | None) -> str: def _add_job(
self,
message: str,
every_seconds: int | None,
cron_expr: str | None,
tz: str | None,
at: str | None,
) -> str:
if not message: if not message:
return "Error: message is required for add" return "Error: message is required for add"
if not self._channel or not self._chat_id: if not self._channel or not self._chat_id:
return "Error: no session context (channel/chat_id)" return "Error: no session context (channel/chat_id)"
if tz and not cron_expr:
return "Error: tz can only be used with cron_expr"
if tz:
from zoneinfo import ZoneInfo
try:
ZoneInfo(tz)
except Exception:
return f"Error: unknown timezone '{tz}'"
# Build schedule # Build schedule
delete_after = False delete_after = False
if every_seconds: if every_seconds:
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000) schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
elif cron_expr: elif cron_expr:
schedule = CronSchedule(kind="cron", expr=cron_expr) schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at: elif at:
from datetime import datetime from datetime import datetime
try:
dt = datetime.fromisoformat(at) dt = datetime.fromisoformat(at)
except ValueError:
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
at_ms = int(dt.timestamp() * 1000) at_ms = int(dt.timestamp() * 1000)
schedule = CronSchedule(kind="at", at_ms=at_ms) schedule = CronSchedule(kind="at", at_ms=at_ms)
delete_after = True delete_after = True
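Putting the new tz parameter together, a hedged usage sketch (set_context must have been called for delivery to work; all values invented):

cron_tool.set_context("slack", "C0123456")
result = await cron_tool.execute(
    action="add",
    message="Stand-up in 10 minutes",
    cron_expr="0 9 * * 1-5",
    tz="America/Vancouver",
)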


@@ -1,15 +1,24 @@
"""File system tools: read, write, edit.""" """File system tools: read, write, edit."""
import difflib
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
def _resolve_path(path: str, allowed_dir: Path | None = None) -> Path: def _resolve_path(
"""Resolve path and optionally enforce directory restriction.""" path: str, workspace: Path | None = None, allowed_dir: Path | None = None
resolved = Path(path).expanduser().resolve() ) -> Path:
if allowed_dir and not str(resolved).startswith(str(allowed_dir.resolve())): """Resolve path against workspace (if relative) and enforce directory restriction."""
p = Path(path).expanduser()
if not p.is_absolute() and workspace:
p = workspace / p
resolved = p.resolve()
if allowed_dir:
try:
resolved.relative_to(allowed_dir.resolve())
except ValueError:
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
return resolved return resolved
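The switch from startswith to Path.relative_to closes a classic prefix-collision hole; a two-line demonstration:

from pathlib import Path

str(Path("/srv/ws-evil/x")).startswith("/srv/ws")    # True: the old check passed this
Path("/srv/ws-evil/x").relative_to(Path("/srv/ws"))  # raises ValueError: the new check rejects it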
@@ -17,7 +26,10 @@ def _resolve_path(path: str, allowed_dir: Path | None = None) -> Path:
class ReadFileTool(Tool): class ReadFileTool(Tool):
"""Tool to read file contents.""" """Tool to read file contents."""
def __init__(self, allowed_dir: Path | None = None): _MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
@property @property
@@ -32,24 +44,28 @@ class ReadFileTool(Tool):
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The file path to read"}},
"path": { "required": ["path"],
"type": "string",
"description": "The file path to read"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
if not file_path.exists(): if not file_path.exists():
return f"Error: File not found: {path}" return f"Error: File not found: {path}"
if not file_path.is_file(): if not file_path.is_file():
return f"Error: Not a file: {path}" return f"Error: Not a file: {path}"
size = file_path.stat().st_size
if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes)
return (
f"Error: File too large ({size:,} bytes). "
f"Use exec tool with head/tail/grep to read portions."
)
content = file_path.read_text(encoding="utf-8") content = file_path.read_text(encoding="utf-8")
if len(content) > self._MAX_CHARS:
return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
return content return content
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
@@ -60,7 +76,8 @@ class ReadFileTool(Tool):
class WriteFileTool(Tool): class WriteFileTool(Tool):
"""Tool to write content to a file.""" """Tool to write content to a file."""
def __init__(self, allowed_dir: Path | None = None): def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
@property @property
@@ -76,24 +93,18 @@ class WriteFileTool(Tool):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to write to"},
"type": "string", "content": {"type": "string", "description": "The content to write"},
"description": "The file path to write to"
}, },
"content": { "required": ["path", "content"],
"type": "string",
"description": "The content to write"
}
},
"required": ["path", "content"]
} }
async def execute(self, path: str, content: str, **kwargs: Any) -> str: async def execute(self, path: str, content: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(content, encoding="utf-8") file_path.write_text(content, encoding="utf-8")
return f"Successfully wrote {len(content)} bytes to {path}" return f"Successfully wrote {len(content)} bytes to {file_path}"
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
@@ -103,7 +114,8 @@ class WriteFileTool(Tool):
class EditFileTool(Tool): class EditFileTool(Tool):
"""Tool to edit a file by replacing text.""" """Tool to edit a file by replacing text."""
def __init__(self, allowed_dir: Path | None = None): def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
@property @property
@@ -119,32 +131,23 @@ class EditFileTool(Tool):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to edit"},
"type": "string", "old_text": {"type": "string", "description": "The exact text to find and replace"},
"description": "The file path to edit" "new_text": {"type": "string", "description": "The text to replace with"},
}, },
"old_text": { "required": ["path", "old_text", "new_text"],
"type": "string",
"description": "The exact text to find and replace"
},
"new_text": {
"type": "string",
"description": "The text to replace with"
}
},
"required": ["path", "old_text", "new_text"]
} }
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
if not file_path.exists(): if not file_path.exists():
return f"Error: File not found: {path}" return f"Error: File not found: {path}"
content = file_path.read_text(encoding="utf-8") content = file_path.read_text(encoding="utf-8")
if old_text not in content: if old_text not in content:
return f"Error: old_text not found in file. Make sure it matches exactly." return self._not_found_message(old_text, content, path)
# Count occurrences # Count occurrences
count = content.count(old_text) count = content.count(old_text)
@@ -154,17 +157,46 @@ class EditFileTool(Tool):
new_content = content.replace(old_text, new_text, 1) new_content = content.replace(old_text, new_text, 1)
file_path.write_text(new_content, encoding="utf-8") file_path.write_text(new_content, encoding="utf-8")
return f"Successfully edited {path}" return f"Successfully edited {file_path}"
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error editing file: {str(e)}" return f"Error editing file: {str(e)}"
@staticmethod
def _not_found_message(old_text: str, content: str, path: str) -> str:
"""Build a helpful error when old_text is not found."""
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = len(old_lines)
best_ratio, best_start = 0.0, 0
for i in range(max(1, len(lines) - window + 1)):
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
if ratio > best_ratio:
best_ratio, best_start = ratio, i
if best_ratio > 0.5:
diff = "\n".join(
difflib.unified_diff(
old_lines,
lines[best_start : best_start + window],
fromfile="old_text (provided)",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
)
)
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
return (
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
)
class ListDirTool(Tool): class ListDirTool(Tool):
"""Tool to list directory contents.""" """Tool to list directory contents."""
def __init__(self, allowed_dir: Path | None = None): def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
@property @property
@@ -179,18 +211,13 @@ class ListDirTool(Tool):
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The directory path to list"}},
"path": { "required": ["path"],
"type": "string",
"description": "The directory path to list"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
dir_path = _resolve_path(path, self._allowed_dir) dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
if not dir_path.exists(): if not dir_path.exists():
return f"Error: Directory not found: {path}" return f"Error: Directory not found: {path}"
if not dir_path.is_dir(): if not dir_path.is_dir():
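For reviewers, a minimal standalone sketch of the new resolution rules above (the logic mirrors the diff; the function name and the usage lines are illustrative only). The switch from `str.startswith` to `Path.relative_to` matters: string prefixing would accept a sibling like `/tmp/ws-evil` under `/tmp/ws`, while `relative_to` rejects it.

```python
from pathlib import Path

def resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
    """Sketch of the diff's rules: workspace-relative join, then containment check."""
    p = Path(path).expanduser()
    if not p.is_absolute() and workspace:
        p = workspace / p                     # relative paths are anchored to the workspace
    resolved = p.resolve()
    if allowed_dir:
        try:
            resolved.relative_to(allowed_dir.resolve())  # ValueError means outside the sandbox
        except ValueError:
            raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
    return resolved

ws = Path("/tmp/ws")
print(resolve_path("notes.txt", workspace=ws, allowed_dir=ws))  # /tmp/ws/notes.txt
```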

View File

@@ -1,8 +1,10 @@
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools.""" """MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import Any from typing import Any
import httpx
from loguru import logger from loguru import logger
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -12,12 +14,13 @@ from nanobot.agent.tools.registry import ToolRegistry
class MCPToolWrapper(Tool): class MCPToolWrapper(Tool):
"""Wraps a single MCP server tool as a nanobot Tool.""" """Wraps a single MCP server tool as a nanobot Tool."""
def __init__(self, session, server_name: str, tool_def): def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
self._session = session self._session = session
self._original_name = tool_def.name self._original_name = tool_def.name
self._name = f"mcp_{server_name}_{tool_def.name}" self._name = f"mcp_{server_name}_{tool_def.name}"
self._description = tool_def.description or tool_def.name self._description = tool_def.description or tool_def.name
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}} self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
self._tool_timeout = tool_timeout
@property @property
def name(self) -> str: def name(self) -> str:
@@ -33,7 +36,14 @@ class MCPToolWrapper(Tool):
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
from mcp import types from mcp import types
result = await self._session.call_tool(self._original_name, arguments=kwargs) try:
result = await asyncio.wait_for(
self._session.call_tool(self._original_name, arguments=kwargs),
timeout=self._tool_timeout,
)
except asyncio.TimeoutError:
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
return f"(MCP tool call timed out after {self._tool_timeout}s)"
parts = [] parts = []
for block in result.content: for block in result.content:
if isinstance(block, types.TextContent): if isinstance(block, types.TextContent):
@@ -48,22 +58,62 @@ async def connect_mcp_servers(
) -> None: ) -> None:
"""Connect to configured MCP servers and register their tools.""" """Connect to configured MCP servers and register their tools."""
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items(): for name, cfg in mcp_servers.items():
try: try:
transport_type = cfg.type
if not transport_type:
if cfg.command: if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
if transport_type == "stdio":
params = StdioServerParameters( params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None command=cfg.command, args=cfg.args, env=cfg.env or None
) )
read, write = await stack.enter_async_context(stdio_client(params)) read, write = await stack.enter_async_context(stdio_client(params))
elif cfg.url: elif transport_type == "sse":
from mcp.client.streamable_http import streamable_http_client def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
merged_headers = {**(cfg.headers or {}), **(headers or {})}
return httpx.AsyncClient(
headers=merged_headers or None,
follow_redirects=True,
timeout=timeout,
auth=auth,
)
read, write = await stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.headers or None,
follow_redirects=True,
timeout=None,
)
)
read, write, _ = await stack.enter_async_context( read, write, _ = await stack.enter_async_context(
streamable_http_client(cfg.url) streamable_http_client(cfg.url, http_client=http_client)
) )
else: else:
logger.warning(f"MCP server '{name}': no command or url configured, skipping") logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue continue
session = await stack.enter_async_context(ClientSession(read, write)) session = await stack.enter_async_context(ClientSession(read, write))
@@ -71,10 +121,10 @@ async def connect_mcp_servers(
tools = await session.list_tools() tools = await session.list_tools()
for tool_def in tools.tools: for tool_def in tools.tools:
wrapper = MCPToolWrapper(session, name, tool_def) wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
registry.register(wrapper) registry.register(wrapper)
logger.debug(f"MCP: registered tool '{wrapper.name}' from server '{name}'") logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
logger.info(f"MCP server '{name}': connected, {len(tools.tools)} tools registered") logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools))
except Exception as e: except Exception as e:
logger.error(f"MCP server '{name}': failed to connect: {e}") logger.error("MCP server '{}': failed to connect: {}", name, e)

View File

@@ -1,6 +1,6 @@
"""Message tool for sending messages to users.""" """Message tool for sending messages to users."""
from typing import Any, Callable, Awaitable from typing import Any, Awaitable, Callable
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
@@ -13,21 +13,29 @@ class MessageTool(Tool):
self, self,
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None, send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
default_channel: str = "", default_channel: str = "",
default_chat_id: str = "" default_chat_id: str = "",
default_message_id: str | None = None,
): ):
self._send_callback = send_callback self._send_callback = send_callback
self._default_channel = default_channel self._default_channel = default_channel
self._default_chat_id = default_chat_id self._default_chat_id = default_chat_id
self._default_message_id = default_message_id
self._sent_in_turn: bool = False
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Set the current message context.""" """Set the current message context."""
self._default_channel = channel self._default_channel = channel
self._default_chat_id = chat_id self._default_chat_id = chat_id
self._default_message_id = message_id
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None: def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
"""Set the callback for sending messages.""" """Set the callback for sending messages."""
self._send_callback = callback self._send_callback = callback
def start_turn(self) -> None:
"""Reset per-turn send tracking."""
self._sent_in_turn = False
@property @property
def name(self) -> str: def name(self) -> str:
return "message" return "message"
@@ -52,6 +60,11 @@ class MessageTool(Tool):
"chat_id": { "chat_id": {
"type": "string", "type": "string",
"description": "Optional: target chat/user ID" "description": "Optional: target chat/user ID"
},
"media": {
"type": "array",
"items": {"type": "string"},
"description": "Optional: list of file paths to attach (images, audio, documents)"
} }
}, },
"required": ["content"] "required": ["content"]
@@ -62,10 +75,13 @@ class MessageTool(Tool):
content: str, content: str,
channel: str | None = None, channel: str | None = None,
chat_id: str | None = None, chat_id: str | None = None,
message_id: str | None = None,
media: list[str] | None = None,
**kwargs: Any **kwargs: Any
) -> str: ) -> str:
channel = channel or self._default_channel channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id chat_id = chat_id or self._default_chat_id
message_id = message_id or self._default_message_id
if not channel or not chat_id: if not channel or not chat_id:
return "Error: No target channel/chat specified" return "Error: No target channel/chat specified"
@@ -76,11 +92,18 @@ class MessageTool(Tool):
msg = OutboundMessage( msg = OutboundMessage(
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
content=content content=content,
media=media or [],
metadata={
"message_id": message_id,
},
) )
try: try:
await self._send_callback(msg) await self._send_callback(msg)
return f"Message sent to {channel}:{chat_id}" if channel == self._default_channel and chat_id == self._default_chat_id:
self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}"
except Exception as e: except Exception as e:
return f"Error sending message: {str(e)}" return f"Error sending message: {str(e)}"

View File

@@ -36,30 +36,27 @@ class ToolRegistry:
return [tool.to_schema() for tool in self._tools.values()] return [tool.to_schema() for tool in self._tools.values()]
async def execute(self, name: str, params: dict[str, Any]) -> str: async def execute(self, name: str, params: dict[str, Any]) -> str:
""" """Execute a tool by name with given parameters."""
Execute a tool by name with given parameters. _HINT = "\n\n[Analyze the error above and try a different approach.]"
Args:
name: Tool name.
params: Tool parameters.
Returns:
Tool execution result as string.
Raises:
KeyError: If tool not found.
"""
tool = self._tools.get(name) tool = self._tools.get(name)
if not tool: if not tool:
return f"Error: Tool '{name}' not found" return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
try: try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params) errors = tool.validate_params(params)
if errors: if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
return await tool.execute(**params) result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"):
return result + _HINT
return result
except Exception as e: except Exception as e:
return f"Error executing {name}: {str(e)}" return f"Error executing {name}: {str(e)}" + _HINT
@property @property
def tool_names(self) -> list[str]: def tool_names(self) -> list[str]:
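The hint rule is simple but easy to miss in diff form: any string result beginning with "Error" (tool not found, validation failure, exception, or a tool's own error string) gets the retry hint appended; everything else passes through untouched. In isolation:

```python
_HINT = "\n\n[Analyze the error above and try a different approach.]"

def finalize(result: str) -> str:
    """Mirror of the registry's post-processing: hint only on error results."""
    return result + _HINT if result.startswith("Error") else result

print(finalize("Error: File not found: notes.txt"))  # hint appended
print(finalize("Successfully wrote 42 bytes"))       # unchanged
```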

View File

@@ -19,6 +19,7 @@ class ExecTool(Tool):
deny_patterns: list[str] | None = None, deny_patterns: list[str] | None = None,
allow_patterns: list[str] | None = None, allow_patterns: list[str] | None = None,
restrict_to_workspace: bool = False, restrict_to_workspace: bool = False,
path_append: str = "",
): ):
self.timeout = timeout self.timeout = timeout
self.working_dir = working_dir self.working_dir = working_dir
@@ -26,7 +27,8 @@ class ExecTool(Tool):
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
r"\bdel\s+/[fq]\b", # del /f, del /q r"\bdel\s+/[fq]\b", # del /f, del /q
r"\brmdir\s+/s\b", # rmdir /s r"\brmdir\s+/s\b", # rmdir /s
r"\b(format|mkfs|diskpart)\b", # disk operations r"(?:^|[;&|]\s*)format\b", # format (as standalone command only)
r"\b(mkfs|diskpart)\b", # disk operations
r"\bdd\s+if=", # dd r"\bdd\s+if=", # dd
r">\s*/dev/sd", # write to disk r">\s*/dev/sd", # write to disk
r"\b(shutdown|reboot|poweroff)\b", # system power r"\b(shutdown|reboot|poweroff)\b", # system power
@@ -34,6 +36,7 @@ class ExecTool(Tool):
] ]
self.allow_patterns = allow_patterns or [] self.allow_patterns = allow_patterns or []
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append
@property @property
def name(self) -> str: def name(self) -> str:
@@ -66,12 +69,17 @@ class ExecTool(Tool):
if guard_error: if guard_error:
return guard_error return guard_error
env = os.environ.copy()
if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
try: try:
process = await asyncio.create_subprocess_shell( process = await asyncio.create_subprocess_shell(
command, command,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=cwd, cwd=cwd,
env=env,
) )
try: try:
@@ -81,6 +89,12 @@ class ExecTool(Tool):
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
process.kill() process.kill()
# Wait for the process to fully terminate so pipes are
# drained and file descriptors are released.
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
return f"Error: Command timed out after {self.timeout} seconds" return f"Error: Command timed out after {self.timeout} seconds"
output_parts = [] output_parts = []
@@ -127,13 +141,7 @@ class ExecTool(Tool):
cwd_path = Path(cwd).resolve() cwd_path = Path(cwd).resolve()
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) for raw in self._extract_absolute_paths(cmd):
# Only match absolute paths — avoid false positives on relative
# paths like ".venv/bin/python" where "/bin/python" would be
# incorrectly extracted by the old pattern.
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
for raw in win_paths + posix_paths:
try: try:
p = Path(raw.strip()).resolve() p = Path(raw.strip()).resolve()
except Exception: except Exception:
@@ -142,3 +150,9 @@ class ExecTool(Tool):
return "Error: Command blocked by safety guard (path outside working dir)" return "Error: Command blocked by safety guard (path outside working dir)"
return None return None
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
return win_paths + posix_paths
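The extracted helper preserves the earlier fix: only absolute paths are matched, so relative invocations such as `.venv/bin/python` no longer trip the workspace guard. A standalone sketch of the two patterns (regexes copied from the diff):

```python
import re

def extract_absolute_paths(command: str) -> list[str]:
    win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command)    # Windows: C:\...
    posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: leading / only
    return win_paths + posix_paths

print(extract_absolute_paths("cat /etc/passwd > out.txt"))   # ['/etc/passwd']
print(extract_absolute_paths(".venv/bin/python -m pytest"))  # [] (relative path ignored)
```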

View File

@@ -1,6 +1,6 @@
"""Spawn tool for creating background subagents.""" """Spawn tool for creating background subagents."""
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -9,22 +9,19 @@ if TYPE_CHECKING:
class SpawnTool(Tool): class SpawnTool(Tool):
""" """Tool to spawn a subagent for background task execution."""
Tool to spawn a subagent for background task execution.
The subagent runs asynchronously and announces its result back
to the main agent when complete.
"""
def __init__(self, manager: "SubagentManager"): def __init__(self, manager: "SubagentManager"):
self._manager = manager self._manager = manager
self._origin_channel = "cli" self._origin_channel = "cli"
self._origin_chat_id = "direct" self._origin_chat_id = "direct"
self._session_key = "cli:direct"
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the origin context for subagent announcements.""" """Set the origin context for subagent announcements."""
self._origin_channel = channel self._origin_channel = channel
self._origin_chat_id = chat_id self._origin_chat_id = chat_id
self._session_key = f"{channel}:{chat_id}"
@property @property
def name(self) -> str: def name(self) -> str:
@@ -62,4 +59,5 @@ class SpawnTool(Tool):
label=label, label=label,
origin_channel=self._origin_channel, origin_channel=self._origin_channel,
origin_chat_id=self._origin_chat_id, origin_chat_id=self._origin_chat_id,
session_key=self._session_key,
) )
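The derivation is two lines, but it is what lets a subagent's completion announcement land back in the exact session that spawned it. In isolation:

```python
def session_key(channel: str, chat_id: str) -> str:
    """Same derivation SpawnTool.set_context uses for its origin context."""
    return f"{channel}:{chat_id}"

assert session_key("cli", "direct") == "cli:direct"
assert session_key("slack", "C123") == "slack:C123"
```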

View File

@@ -8,6 +8,7 @@ from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from loguru import logger
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -57,17 +58,28 @@ class WebSearchTool(Tool):
"required": ["query"] "required": ["query"]
} }
def __init__(self, api_key: str | None = None, max_results: int = 5): def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None):
self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "") self._init_api_key = api_key
self.max_results = max_results self.max_results = max_results
self.proxy = proxy
@property
def api_key(self) -> str:
"""Resolve API key at call time so env/config changes are picked up."""
return self._init_api_key or os.environ.get("BRAVE_API_KEY", "")
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
if not self.api_key: if not self.api_key:
return "Error: BRAVE_API_KEY not configured" return (
"Error: Brave Search API key not configured. Set it in "
"~/.nanobot/config.json under tools.web.search.apiKey "
"(or export BRAVE_API_KEY), then restart the gateway."
)
try: try:
n = min(max(count or self.max_results, 1), 10) n = min(max(count or self.max_results, 1), 10)
async with httpx.AsyncClient() as client: logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection")
async with httpx.AsyncClient(proxy=self.proxy) as client:
r = await client.get( r = await client.get(
"https://api.search.brave.com/res/v1/web/search", "https://api.search.brave.com/res/v1/web/search",
params={"q": query, "count": n}, params={"q": query, "count": n},
@@ -76,17 +88,21 @@ class WebSearchTool(Tool):
) )
r.raise_for_status() r.raise_for_status()
results = r.json().get("web", {}).get("results", []) results = r.json().get("web", {}).get("results", [])[:n]
if not results: if not results:
return f"No results for: {query}" return f"No results for: {query}"
lines = [f"Results for: {query}\n"] lines = [f"Results for: {query}\n"]
for i, item in enumerate(results[:n], 1): for i, item in enumerate(results, 1):
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}") lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
if desc := item.get("description"): if desc := item.get("description"):
lines.append(f" {desc}") lines.append(f" {desc}")
return "\n".join(lines) return "\n".join(lines)
except httpx.ProxyError as e:
logger.error("WebSearch proxy error: {}", e)
return f"Proxy error: {e}"
except Exception as e: except Exception as e:
logger.error("WebSearch error: {}", e)
return f"Error: {e}" return f"Error: {e}"
@@ -105,34 +121,33 @@ class WebFetchTool(Tool):
"required": ["url"] "required": ["url"]
} }
def __init__(self, max_chars: int = 50000): def __init__(self, max_chars: int = 50000, proxy: str | None = None):
self.max_chars = max_chars self.max_chars = max_chars
self.proxy = proxy
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str: async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
from readability import Document from readability import Document
max_chars = maxChars or self.max_chars max_chars = maxChars or self.max_chars
# Validate URL before fetching
is_valid, error_msg = _validate_url(url) is_valid, error_msg = _validate_url(url)
if not is_valid: if not is_valid:
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}) return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
try: try:
logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection")
async with httpx.AsyncClient( async with httpx.AsyncClient(
follow_redirects=True, follow_redirects=True,
max_redirects=MAX_REDIRECTS, max_redirects=MAX_REDIRECTS,
timeout=30.0 timeout=30.0,
proxy=self.proxy,
) as client: ) as client:
r = await client.get(url, headers={"User-Agent": USER_AGENT}) r = await client.get(url, headers={"User-Agent": USER_AGENT})
r.raise_for_status() r.raise_for_status()
ctype = r.headers.get("content-type", "") ctype = r.headers.get("content-type", "")
# JSON
if "application/json" in ctype: if "application/json" in ctype:
text, extractor = json.dumps(r.json(), indent=2), "json" text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
# HTML
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")): elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
doc = Document(r.text) doc = Document(r.text)
content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary()) content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
@@ -142,13 +157,16 @@ class WebFetchTool(Tool):
text, extractor = r.text, "raw" text, extractor = r.text, "raw"
truncated = len(text) > max_chars truncated = len(text) > max_chars
if truncated: if truncated: text = text[:max_chars]
text = text[:max_chars]
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code, return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}) "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
except httpx.ProxyError as e:
logger.error("WebFetch proxy error for {}: {}", url, e)
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
except Exception as e: except Exception as e:
return json.dumps({"error": str(e), "url": url}) logger.error("WebFetch error for {}: {}", url, e)
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
def _to_markdown(self, html: str) -> str: def _to_markdown(self, html: str) -> str:
"""Convert HTML to markdown.""" """Convert HTML to markdown."""

View File

@@ -16,11 +16,12 @@ class InboundMessage:
timestamp: datetime = field(default_factory=datetime.now) timestamp: datetime = field(default_factory=datetime.now)
media: list[str] = field(default_factory=list) # Media URLs media: list[str] = field(default_factory=list) # Media URLs
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
session_key_override: str | None = None # Optional override for thread-scoped sessions
@property @property
def session_key(self) -> str: def session_key(self) -> str:
"""Unique key for session identification.""" """Unique key for session identification."""
return f"{self.channel}:{self.chat_id}" return self.session_key_override or f"{self.channel}:{self.chat_id}"
@dataclass @dataclass
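The override leaves the default `channel:chat_id` key in place unless a channel supplies something finer-grained, such as a Slack thread. A sketch using the dataclass from this diff (field names are inferred from its usage elsewhere in the commit; the thread key value is hypothetical):

```python
from nanobot.bus.events import InboundMessage

plain = InboundMessage(channel="slack", chat_id="C123", sender_id="U1", content="hi")
threaded = InboundMessage(
    channel="slack", chat_id="C123", sender_id="U1", content="hi",
    session_key_override="slack:C123:169.42",  # hypothetical thread-scoped key
)
print(plain.session_key)     # slack:C123
print(threaded.session_key)  # slack:C123:169.42
```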

View File

@@ -1,9 +1,6 @@
"""Async message queue for decoupled channel-agent communication.""" """Async message queue for decoupled channel-agent communication."""
import asyncio import asyncio
from typing import Callable, Awaitable
from loguru import logger
from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.events import InboundMessage, OutboundMessage
@@ -19,8 +16,6 @@ class MessageBus:
def __init__(self): def __init__(self):
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue() self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue() self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
self._outbound_subscribers: dict[str, list[Callable[[OutboundMessage], Awaitable[None]]]] = {}
self._running = False
async def publish_inbound(self, msg: InboundMessage) -> None: async def publish_inbound(self, msg: InboundMessage) -> None:
"""Publish a message from a channel to the agent.""" """Publish a message from a channel to the agent."""
@@ -38,38 +33,6 @@ class MessageBus:
"""Consume the next outbound message (blocks until available).""" """Consume the next outbound message (blocks until available)."""
return await self.outbound.get() return await self.outbound.get()
def subscribe_outbound(
self,
channel: str,
callback: Callable[[OutboundMessage], Awaitable[None]]
) -> None:
"""Subscribe to outbound messages for a specific channel."""
if channel not in self._outbound_subscribers:
self._outbound_subscribers[channel] = []
self._outbound_subscribers[channel].append(callback)
async def dispatch_outbound(self) -> None:
"""
Dispatch outbound messages to subscribed channels.
Run this as a background task.
"""
self._running = True
while self._running:
try:
msg = await asyncio.wait_for(self.outbound.get(), timeout=1.0)
subscribers = self._outbound_subscribers.get(msg.channel, [])
for callback in subscribers:
try:
await callback(msg)
except Exception as e:
logger.error(f"Error dispatching to {msg.channel}: {e}")
except asyncio.TimeoutError:
continue
def stop(self) -> None:
"""Stop the dispatcher loop."""
self._running = False
@property @property
def inbound_size(self) -> int: def inbound_size(self) -> int:
"""Number of pending inbound messages.""" """Number of pending inbound messages."""

View File

@@ -59,29 +59,14 @@ class BaseChannel(ABC):
pass pass
def is_allowed(self, sender_id: str) -> bool: def is_allowed(self, sender_id: str) -> bool:
""" """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
Check if a sender is allowed to use this bot.
Args:
sender_id: The sender's identifier.
Returns:
True if allowed, False otherwise.
"""
allow_list = getattr(self.config, "allow_from", []) allow_list = getattr(self.config, "allow_from", [])
# If no allow list, allow everyone
if not allow_list: if not allow_list:
return True logger.warning("{}: allow_from is empty — all access denied", self.name)
sender_str = str(sender_id)
if sender_str in allow_list:
return True
if "|" in sender_str:
for part in sender_str.split("|"):
if part and part in allow_list:
return True
return False return False
if "*" in allow_list:
return True
return str(sender_id) in allow_list
async def _handle_message( async def _handle_message(
self, self,
@@ -89,7 +74,8 @@ class BaseChannel(ABC):
chat_id: str, chat_id: str,
content: str, content: str,
media: list[str] | None = None, media: list[str] | None = None,
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None,
session_key: str | None = None,
) -> None: ) -> None:
""" """
Handle an incoming message from the chat platform. Handle an incoming message from the chat platform.
@@ -102,11 +88,13 @@ class BaseChannel(ABC):
content: Message text content. content: Message text content.
media: Optional list of media URLs. media: Optional list of media URLs.
metadata: Optional channel-specific metadata. metadata: Optional channel-specific metadata.
session_key: Optional session key override (e.g. thread-scoped sessions).
""" """
if not self.is_allowed(sender_id): if not self.is_allowed(sender_id):
logger.warning( logger.warning(
f"Access denied for sender {sender_id} on channel {self.name}. " "Access denied for sender {} on channel {}. "
f"Add them to allowFrom list in config to grant access." "Add them to allowFrom list in config to grant access.",
sender_id, self.name,
) )
return return
@@ -116,7 +104,8 @@ class BaseChannel(ABC):
chat_id=str(chat_id), chat_id=str(chat_id),
content=content, content=content,
media=media or [], media=media or [],
metadata=metadata or {} metadata=metadata or {},
session_key_override=session_key,
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
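The new semantics are deliberately fail-closed: an empty list now denies everyone (with a warning) instead of allowing everyone, the old `|`-splitting of compound sender IDs is gone, and `"*"` is the explicit opt-in for open access. The decision table in isolation:

```python
def is_allowed(sender_id: str, allow_list: list[str]) -> bool:
    """Fail-closed mirror of the diff: empty denies, '*' allows, else exact match."""
    if not allow_list:
        return False             # previously: allow everyone when unset
    if "*" in allow_list:
        return True
    return str(sender_id) in allow_list

assert is_allowed("U1", []) is False
assert is_allowed("anyone", ["*"]) is True
assert is_allowed("U1", ["U1", "U2"]) is True
```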

View File

@@ -2,11 +2,15 @@
import asyncio import asyncio
import json import json
import mimetypes
import os
import time import time
from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import unquote, urlparse
from loguru import logger
import httpx import httpx
from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -15,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig
try: try:
from dingtalk_stream import ( from dingtalk_stream import (
DingTalkStreamClient, AckMessage,
Credential,
CallbackHandler, CallbackHandler,
CallbackMessage, CallbackMessage,
AckMessage, Credential,
DingTalkStreamClient,
) )
from dingtalk_stream.chatbot import ChatbotMessage from dingtalk_stream.chatbot import ChatbotMessage
@@ -58,19 +62,32 @@ class NanobotDingTalkHandler(CallbackHandler):
if not content: if not content:
logger.warning( logger.warning(
f"Received empty or unsupported message type: {chatbot_msg.message_type}" "Received empty or unsupported message type: {}",
chatbot_msg.message_type,
) )
return AckMessage.STATUS_OK, "OK" return AckMessage.STATUS_OK, "OK"
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
sender_name = chatbot_msg.sender_nick or "Unknown" sender_name = chatbot_msg.sender_nick or "Unknown"
logger.info(f"Received DingTalk message from {sender_name} ({sender_id}): {content}") conversation_type = message.data.get("conversationType")
conversation_id = (
message.data.get("conversationId")
or message.data.get("openConversationId")
)
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
# Forward to Nanobot via _on_message (non-blocking). # Forward to Nanobot via _on_message (non-blocking).
# Store reference to prevent GC before task completes. # Store reference to prevent GC before task completes.
task = asyncio.create_task( task = asyncio.create_task(
self.channel._on_message(content, sender_id, sender_name) self.channel._on_message(
content,
sender_id,
sender_name,
conversation_type,
conversation_id,
)
) )
self.channel._background_tasks.add(task) self.channel._background_tasks.add(task)
task.add_done_callback(self.channel._background_tasks.discard) task.add_done_callback(self.channel._background_tasks.discard)
@@ -78,7 +95,7 @@ class NanobotDingTalkHandler(CallbackHandler):
return AckMessage.STATUS_OK, "OK" return AckMessage.STATUS_OK, "OK"
except Exception as e: except Exception as e:
logger.error(f"Error processing DingTalk message: {e}") logger.error("Error processing DingTalk message: {}", e)
# Return OK to avoid retry loop from DingTalk server # Return OK to avoid retry loop from DingTalk server
return AckMessage.STATUS_OK, "Error" return AckMessage.STATUS_OK, "Error"
@@ -90,11 +107,14 @@ class DingTalkChannel(BaseChannel):
Uses WebSocket to receive events via `dingtalk-stream` SDK. Uses WebSocket to receive events via `dingtalk-stream` SDK.
Uses direct HTTP API to send messages (SDK is mainly for receiving). Uses direct HTTP API to send messages (SDK is mainly for receiving).
Note: Currently only supports private (1:1) chat. Group messages are Supports both private (1:1) and group chats.
received but replies are sent back as private messages to the sender. Group chat_id is stored with a "group:" prefix to route replies back.
""" """
name = "dingtalk" name = "dingtalk"
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
def __init__(self, config: DingTalkConfig, bus: MessageBus): def __init__(self, config: DingTalkConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -126,7 +146,8 @@ class DingTalkChannel(BaseChannel):
self._http = httpx.AsyncClient() self._http = httpx.AsyncClient()
logger.info( logger.info(
f"Initializing DingTalk Stream Client with Client ID: {self.config.client_id}..." "Initializing DingTalk Stream Client with Client ID: {}...",
self.config.client_id,
) )
credential = Credential(self.config.client_id, self.config.client_secret) credential = Credential(self.config.client_id, self.config.client_secret)
self._client = DingTalkStreamClient(credential) self._client = DingTalkStreamClient(credential)
@@ -142,13 +163,13 @@ class DingTalkChannel(BaseChannel):
try: try:
await self._client.start() await self._client.start()
except Exception as e: except Exception as e:
logger.warning(f"DingTalk stream error: {e}") logger.warning("DingTalk stream error: {}", e)
if self._running: if self._running:
logger.info("Reconnecting DingTalk stream in 5 seconds...") logger.info("Reconnecting DingTalk stream in 5 seconds...")
await asyncio.sleep(5) await asyncio.sleep(5)
except Exception as e: except Exception as e:
logger.exception(f"Failed to start DingTalk channel: {e}") logger.exception("Failed to start DingTalk channel: {}", e)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the DingTalk bot.""" """Stop the DingTalk bot."""
@@ -186,60 +207,265 @@ class DingTalkChannel(BaseChannel):
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60 self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
return self._access_token return self._access_token
except Exception as e: except Exception as e:
logger.error(f"Failed to get DingTalk access token: {e}") logger.error("Failed to get DingTalk access token: {}", e)
return None return None
@staticmethod
def _is_http_url(value: str) -> bool:
return urlparse(value).scheme in ("http", "https")
def _guess_upload_type(self, media_ref: str) -> str:
ext = Path(urlparse(media_ref).path).suffix.lower()
if ext in self._IMAGE_EXTS: return "image"
if ext in self._AUDIO_EXTS: return "voice"
if ext in self._VIDEO_EXTS: return "video"
return "file"
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
name = os.path.basename(urlparse(media_ref).path)
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
async def _read_media_bytes(
self,
media_ref: str,
) -> tuple[bytes | None, str | None, str | None]:
if not media_ref:
return None, None, None
if self._is_http_url(media_ref):
if not self._http:
return None, None, None
try:
resp = await self._http.get(media_ref, follow_redirects=True)
if resp.status_code >= 400:
logger.warning(
"DingTalk media download failed status={} ref={}",
resp.status_code,
media_ref,
)
return None, None, None
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
return resp.content, filename, content_type or None
except Exception as e:
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
return None, None, None
try:
if media_ref.startswith("file://"):
parsed = urlparse(media_ref)
local_path = Path(unquote(parsed.path))
else:
local_path = Path(os.path.expanduser(media_ref))
if not local_path.is_file():
logger.warning("DingTalk media file not found: {}", local_path)
return None, None, None
data = await asyncio.to_thread(local_path.read_bytes)
content_type = mimetypes.guess_type(local_path.name)[0]
return data, local_path.name, content_type
except Exception as e:
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
return None, None, None
async def _upload_media(
self,
token: str,
data: bytes,
media_type: str,
filename: str,
content_type: str | None,
) -> str | None:
if not self._http:
return None
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
files = {"media": (filename, data, mime)}
try:
resp = await self._http.post(url, files=files)
text = resp.text
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
if resp.status_code >= 400:
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
return None
errcode = result.get("errcode", 0)
if errcode != 0:
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
return None
sub = result.get("result") or {}
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
if not media_id:
logger.error("DingTalk media upload missing media_id body={}", text[:500])
return None
return str(media_id)
except Exception as e:
logger.error("DingTalk media upload error type={} err={}", media_type, e)
return None
async def _send_batch_message(
self,
token: str,
chat_id: str,
msg_key: str,
msg_param: dict[str, Any],
) -> bool:
if not self._http:
logger.warning("DingTalk HTTP client not initialized, cannot send")
return False
headers = {"x-acs-dingtalk-access-token": token}
if chat_id.startswith("group:"):
# Group chat
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
payload = {
"robotCode": self.config.client_id,
"openConversationId": chat_id[6:], # Remove "group:" prefix,
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
else:
# Private chat
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
payload = {
"robotCode": self.config.client_id,
"userIds": [chat_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
try:
resp = await self._http.post(url, json=payload, headers=headers)
body = resp.text
if resp.status_code != 200:
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
return False
try: result = resp.json()
except Exception: result = {}
errcode = result.get("errcode")
if errcode not in (None, 0):
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
return False
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
return True
except Exception as e:
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
return False
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
return await self._send_batch_message(
token,
chat_id,
"sampleMarkdown",
{"text": content, "title": "Nanobot Reply"},
)
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
media_ref = (media_ref or "").strip()
if not media_ref:
return True
upload_type = self._guess_upload_type(media_ref)
if upload_type == "image" and self._is_http_url(media_ref):
ok = await self._send_batch_message(
token,
chat_id,
"sampleImageMsg",
{"photoURL": media_ref},
)
if ok:
return True
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
data, filename, content_type = await self._read_media_bytes(media_ref)
if not data:
logger.error("DingTalk media read failed: {}", media_ref)
return False
filename = filename or self._guess_filename(media_ref, upload_type)
file_type = Path(filename).suffix.lower().lstrip(".")
if not file_type:
guessed = mimetypes.guess_extension(content_type or "")
file_type = (guessed or ".bin").lstrip(".")
if file_type == "jpeg":
file_type = "jpg"
media_id = await self._upload_media(
token=token,
data=data,
media_type=upload_type,
filename=filename,
content_type=content_type,
)
if not media_id:
return False
if upload_type == "image":
# Verified in production: sampleImageMsg accepts media_id in photoURL.
ok = await self._send_batch_message(
token,
chat_id,
"sampleImageMsg",
{"photoURL": media_id},
)
if ok:
return True
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
return await self._send_batch_message(
token,
chat_id,
"sampleFile",
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
)
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through DingTalk.""" """Send a message through DingTalk."""
token = await self._get_access_token() token = await self._get_access_token()
if not token: if not token:
return return
# oToMessages/batchSend: sends to individual users (private chat) if msg.content and msg.content.strip():
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {"x-acs-dingtalk-access-token": token} for media_ref in msg.media or []:
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
if ok:
continue
logger.error("DingTalk media send failed for {}", media_ref)
# Send visible fallback so failures are observable by the user.
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
await self._send_markdown_text(
token,
msg.chat_id,
f"[Attachment send failed: {filename}]",
)
data = { async def _on_message(
"robotCode": self.config.client_id, self,
"userIds": [msg.chat_id], # chat_id is the user's staffId content: str,
"msgKey": "sampleMarkdown", sender_id: str,
"msgParam": json.dumps({ sender_name: str,
"text": msg.content, conversation_type: str | None = None,
"title": "Nanobot Reply", conversation_id: str | None = None,
}), ) -> None:
}
if not self._http:
logger.warning("DingTalk HTTP client not initialized, cannot send")
return
try:
resp = await self._http.post(url, json=data, headers=headers)
if resp.status_code != 200:
logger.error(f"DingTalk send failed: {resp.text}")
else:
logger.debug(f"DingTalk message sent to {msg.chat_id}")
except Exception as e:
logger.error(f"Error sending DingTalk message: {e}")
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
"""Handle incoming message (called by NanobotDingTalkHandler). """Handle incoming message (called by NanobotDingTalkHandler).
Delegates to BaseChannel._handle_message() which enforces allow_from Delegates to BaseChannel._handle_message() which enforces allow_from
permission checks before publishing to the bus. permission checks before publishing to the bus.
""" """
try: try:
logger.info(f"DingTalk inbound: {content} from {sender_name}") logger.info("DingTalk inbound: {} from {}", content, sender_name)
is_group = conversation_type == "2" and conversation_id
chat_id = f"group:{conversation_id}" if is_group else sender_id
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender_id, # For private chat, chat_id == sender_id chat_id=chat_id,
content=str(content), content=str(content),
metadata={ metadata={
"sender_name": sender_name, "sender_name": sender_name,
"platform": "dingtalk", "platform": "dingtalk",
"conversation_type": conversation_type,
}, },
) )
except Exception as e: except Exception as e:
logger.error(f"Error publishing DingTalk message: {e}") logger.error("Error publishing DingTalk message: {}", e)

View File

@@ -13,10 +13,11 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import DiscordConfig from nanobot.config.schema import DiscordConfig
from nanobot.utils.helpers import split_message
DISCORD_API_BASE = "https://discord.com/api/v10" DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit
class DiscordChannel(BaseChannel): class DiscordChannel(BaseChannel):
@@ -32,6 +33,7 @@ class DiscordChannel(BaseChannel):
self._heartbeat_task: asyncio.Task | None = None self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None self._http: httpx.AsyncClient | None = None
self._bot_user_id: str | None = None
async def start(self) -> None: async def start(self) -> None:
"""Start the Discord gateway connection.""" """Start the Discord gateway connection."""
@@ -51,7 +53,7 @@ class DiscordChannel(BaseChannel):
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.warning(f"Discord gateway error: {e}") logger.warning("Discord gateway error: {}", e)
if self._running: if self._running:
logger.info("Reconnecting to Discord gateway in 5 seconds...") logger.info("Reconnecting to Discord gateway in 5 seconds...")
await asyncio.sleep(5) await asyncio.sleep(5)
@@ -73,39 +75,117 @@ class DiscordChannel(BaseChannel):
self._http = None self._http = None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API.""" """Send a message through Discord REST API, including file attachments."""
if not self._http: if not self._http:
logger.warning("Discord HTTP client not initialized") logger.warning("Discord HTTP client not initialized")
return return
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages" url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
payload: dict[str, Any] = {"content": msg.content}
if msg.reply_to:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
headers = {"Authorization": f"Bot {self.config.token}"} headers = {"Authorization": f"Bot {self.config.token}"}
try: try:
sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# The first text chunk carries the reply reference unless an attachment already did.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
if not await self._send_payload(url, headers, payload):
break # Abort remaining chunks on failure
finally:
await self._stop_typing(msg.chat_id)
async def _send_payload(
self, url: str, headers: dict[str, str], payload: dict[str, Any]
) -> bool:
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
for attempt in range(3): for attempt in range(3):
try: try:
response = await self._http.post(url, headers=headers, json=payload) response = await self._http.post(url, headers=headers, json=payload)
if response.status_code == 429: if response.status_code == 429:
data = response.json() data = response.json()
retry_after = float(data.get("retry_after", 1.0)) retry_after = float(data.get("retry_after", 1.0))
logger.warning(f"Discord rate limited, retrying in {retry_after}s") logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after) await asyncio.sleep(retry_after)
continue continue
response.raise_for_status() response.raise_for_status()
return return True
except Exception as e: except Exception as e:
if attempt == 2: if attempt == 2:
logger.error(f"Error sending Discord message: {e}") logger.error("Error sending Discord message: {}", e)
else: else:
await asyncio.sleep(1) await asyncio.sleep(1)
finally: return False
await self._stop_typing(msg.chat_id)
async def _send_file(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
    async def _gateway_loop(self) -> None:
        """Main gateway loop: identify, heartbeat, dispatch events."""
@@ -116,7 +196,7 @@ class DiscordChannel(BaseChannel):
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
-               logger.warning(f"Invalid JSON from Discord gateway: {raw[:100]}")
+               logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
                continue
            op = data.get("op")
@@ -134,6 +214,10 @@ class DiscordChannel(BaseChannel):
                await self._identify()
            elif op == 0 and event_type == "READY":
                logger.info("Discord gateway READY")
+               # Capture bot user ID for mention detection
+               user_data = payload.get("user") or {}
+               self._bot_user_id = user_data.get("id")
+               logger.info("Discord bot connected as user {}", self._bot_user_id)
            elif op == 0 and event_type == "MESSAGE_CREATE":
                await self._handle_message_create(payload)
            elif op == 7:
@@ -175,7 +259,7 @@ class DiscordChannel(BaseChannel):
            try:
                await self._ws.send(json.dumps(payload))
            except Exception as e:
-               logger.warning(f"Discord heartbeat failed: {e}")
+               logger.warning("Discord heartbeat failed: {}", e)
                break
            await asyncio.sleep(interval_s)
@@ -190,6 +274,7 @@ class DiscordChannel(BaseChannel):
        sender_id = str(author.get("id", ""))
        channel_id = str(payload.get("channel_id", ""))
        content = payload.get("content") or ""
+       guild_id = payload.get("guild_id")
        if not sender_id or not channel_id:
            return
@@ -197,6 +282,11 @@ class DiscordChannel(BaseChannel):
        if not self.is_allowed(sender_id):
            return
+       # Check group channel policy (DMs always respond if is_allowed passes)
+       if guild_id is not None:
+           if not self._should_respond_in_group(payload, content):
+               return
        content_parts = [content] if content else []
        media_paths: list[str] = []
        media_dir = Path.home() / ".nanobot" / "media"
@@ -219,7 +309,7 @@ class DiscordChannel(BaseChannel):
                media_paths.append(str(file_path))
                content_parts.append(f"[attachment: {file_path}]")
            except Exception as e:
-               logger.warning(f"Failed to download Discord attachment: {e}")
+               logger.warning("Failed to download Discord attachment: {}", e)
                content_parts.append(f"[attachment: {filename} - download failed]")
        reply_to = (payload.get("referenced_message") or {}).get("id")
@@ -233,11 +323,32 @@ class DiscordChannel(BaseChannel):
            media=media_paths,
            metadata={
                "message_id": str(payload.get("id", "")),
-               "guild_id": payload.get("guild_id"),
+               "guild_id": guild_id,
                "reply_to": reply_to,
            },
        )
+   def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
+       """Check if bot should respond in a group channel based on policy."""
+       if self.config.group_policy == "open":
+           return True
+       if self.config.group_policy == "mention":
+           # Check if bot was mentioned in the message
+           if self._bot_user_id:
+               # Check mentions array
+               mentions = payload.get("mentions") or []
+               for mention in mentions:
+                   if str(mention.get("id")) == self._bot_user_id:
+                       return True
+               # Also check content for mention format <@USER_ID>
+               if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
+                   return True
+           logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
+           return False
+       return True
    async def _start_typing(self, channel_id: str) -> None:
        """Start periodic typing indicator for a channel."""
        await self._stop_typing(channel_id)
@@ -248,8 +359,11 @@ class DiscordChannel(BaseChannel):
            while self._running:
                try:
                    await self._http.post(url, headers=headers)
-               except Exception:
-                   pass
+               except asyncio.CancelledError:
+                   return
+               except Exception as e:
+                   logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
+                   return
                await asyncio.sleep(8)
        self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
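The `mention` policy above is pure dictionary-and-string logic, so it can be exercised without a gateway connection. A minimal standalone sketch (payload shaped like Discord's MESSAGE_CREATE event; names are illustrative, not part of the channel API):

def mentioned(payload: dict, content: str, bot_id: str) -> bool:
    """Standalone version of the 'mention' policy check above."""
    if any(str(m.get("id")) == bot_id for m in payload.get("mentions") or []):
        return True
    # <@id> is a plain mention, <@!id> a nickname mention
    return f"<@{bot_id}>" in content or f"<@!{bot_id}>" in content

assert mentioned({"mentions": [{"id": 42}]}, "", "42")
assert mentioned({}, "hey <@!42>, are you up?", "42")
assert not mentioned({"mentions": []}, "no ping here", "42")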


@@ -94,7 +94,7 @@ class EmailChannel(BaseChannel):
                    metadata=item.get("metadata", {}),
                )
            except Exception as e:
-               logger.error(f"Email polling error: {e}")
+               logger.error("Email polling error: {}", e)
            await asyncio.sleep(poll_seconds)
@@ -108,11 +108,6 @@ class EmailChannel(BaseChannel):
            logger.warning("Skip email send: consent_granted is false")
            return
-       force_send = bool((msg.metadata or {}).get("force_send"))
-       if not self.config.auto_reply_enabled and not force_send:
-           logger.info("Skip automatic email reply: auto_reply_enabled is false")
-           return
        if not self.config.smtp_host:
            logger.warning("Email channel SMTP host not configured")
            return
@@ -122,6 +117,15 @@ class EmailChannel(BaseChannel):
            logger.warning("Email channel missing recipient address")
            return
+       # Determine if this is a reply (recipient has sent us an email before)
+       is_reply = to_addr in self._last_subject_by_chat
+       force_send = bool((msg.metadata or {}).get("force_send"))
+       # autoReplyEnabled only controls automatic replies, not proactive sends
+       if is_reply and not self.config.auto_reply_enabled and not force_send:
+           logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
+           return
        base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
        subject = self._reply_subject(base_subject)
        if msg.metadata and isinstance(msg.metadata.get("subject"), str):
@@ -143,7 +147,7 @@ class EmailChannel(BaseChannel):
        try:
            await asyncio.to_thread(self._smtp_send, email_msg)
        except Exception as e:
-           logger.error(f"Error sending email to {to_addr}: {e}")
+           logger.error("Error sending email to {}: {}", to_addr, e)
            raise
    def _validate_config(self) -> bool:
@@ -162,7 +166,7 @@ class EmailChannel(BaseChannel):
            missing.append("smtp_password")
        if missing:
-           logger.error(f"Email channel not configured, missing: {', '.join(missing)}")
+           logger.error("Email channel not configured, missing: {}", ', '.join(missing))
            return False
        return True
@@ -304,7 +308,8 @@ class EmailChannel(BaseChannel):
        self._processed_uids.add(uid)
        # mark_seen is the primary dedup; this set is a safety net
        if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
-           self._processed_uids.clear()
+           # Evict a random half to cap memory; mark_seen is the primary dedup
+           self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
        if mark_seen:
            client.store(imap_id, "+FLAGS", "\\Seen")
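The new eviction keeps half of the UID safety net instead of wiping it. The pattern in isolation (tiny cap and names purely for illustration):

processed: set[str] = set()
MAX_UIDS = 4

def remember(uid: str) -> bool:
    """Return True for unseen UIDs; evict an arbitrary half once the cap is exceeded."""
    global processed
    if uid in processed:
        return False
    processed.add(uid)
    if len(processed) > MAX_UIDS:
        # set iteration order is arbitrary, so this drops a pseudo-random half
        processed = set(list(processed)[len(processed) // 2:])
    return True

for uid in "abcde":
    remember(uid)
assert 0 < len(processed) <= MAX_UIDS  # bounded, but recent dedup state survives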

File diff suppressed because it is too large


@@ -45,7 +45,7 @@ class ChannelManager:
            )
            logger.info("Telegram channel enabled")
        except ImportError as e:
-           logger.warning(f"Telegram channel not available: {e}")
+           logger.warning("Telegram channel not available: {}", e)
        # WhatsApp channel
        if self.config.channels.whatsapp.enabled:
@@ -56,7 +56,7 @@ class ChannelManager:
            )
            logger.info("WhatsApp channel enabled")
        except ImportError as e:
-           logger.warning(f"WhatsApp channel not available: {e}")
+           logger.warning("WhatsApp channel not available: {}", e)
        # Discord channel
        if self.config.channels.discord.enabled:
@@ -67,18 +67,19 @@ class ChannelManager:
            )
            logger.info("Discord channel enabled")
        except ImportError as e:
-           logger.warning(f"Discord channel not available: {e}")
+           logger.warning("Discord channel not available: {}", e)
        # Feishu channel
        if self.config.channels.feishu.enabled:
            try:
                from nanobot.channels.feishu import FeishuChannel
                self.channels["feishu"] = FeishuChannel(
-                   self.config.channels.feishu, self.bus
+                   self.config.channels.feishu, self.bus,
+                   groq_api_key=self.config.providers.groq.api_key,
                )
                logger.info("Feishu channel enabled")
            except ImportError as e:
-               logger.warning(f"Feishu channel not available: {e}")
+               logger.warning("Feishu channel not available: {}", e)
        # Mochat channel
        if self.config.channels.mochat.enabled:
@@ -90,7 +91,7 @@ class ChannelManager:
            )
            logger.info("Mochat channel enabled")
        except ImportError as e:
-           logger.warning(f"Mochat channel not available: {e}")
+           logger.warning("Mochat channel not available: {}", e)
        # DingTalk channel
        if self.config.channels.dingtalk.enabled:
@@ -101,7 +102,7 @@ class ChannelManager:
            )
            logger.info("DingTalk channel enabled")
        except ImportError as e:
-           logger.warning(f"DingTalk channel not available: {e}")
+           logger.warning("DingTalk channel not available: {}", e)
        # Email channel
        if self.config.channels.email.enabled:
@@ -112,7 +113,7 @@ class ChannelManager:
            )
            logger.info("Email channel enabled")
        except ImportError as e:
-           logger.warning(f"Email channel not available: {e}")
+           logger.warning("Email channel not available: {}", e)
        # Slack channel
        if self.config.channels.slack.enabled:
@@ -123,7 +124,7 @@ class ChannelManager:
            )
            logger.info("Slack channel enabled")
        except ImportError as e:
-           logger.warning(f"Slack channel not available: {e}")
+           logger.warning("Slack channel not available: {}", e)
        # QQ channel
        if self.config.channels.qq.enabled:
@@ -135,14 +136,36 @@ class ChannelManager:
            )
            logger.info("QQ channel enabled")
        except ImportError as e:
-           logger.warning(f"QQ channel not available: {e}")
+           logger.warning("QQ channel not available: {}", e)
+       # Matrix channel
+       if self.config.channels.matrix.enabled:
+           try:
+               from nanobot.channels.matrix import MatrixChannel
+               self.channels["matrix"] = MatrixChannel(
+                   self.config.channels.matrix,
+                   self.bus,
+               )
+               logger.info("Matrix channel enabled")
+           except ImportError as e:
+               logger.warning("Matrix channel not available: {}", e)
+       self._validate_allow_from()
+   def _validate_allow_from(self) -> None:
+       for name, ch in self.channels.items():
+           if getattr(ch.config, "allow_from", None) == []:
+               raise SystemExit(
+                   f'Error: "{name}" has empty allowFrom (denies all). '
+                   f'Set ["*"] to allow everyone, or add specific user IDs.'
+               )
    async def _start_channel(self, name: str, channel: BaseChannel) -> None:
        """Start a channel and log any exceptions."""
        try:
            await channel.start()
        except Exception as e:
-           logger.error(f"Failed to start channel {name}: {e}")
+           logger.error("Failed to start channel {}: {}", name, e)
    async def start_all(self) -> None:
        """Start all channels and the outbound dispatcher."""
@@ -156,7 +179,7 @@ class ChannelManager:
        # Start channels
        tasks = []
        for name, channel in self.channels.items():
-           logger.info(f"Starting {name} channel...")
+           logger.info("Starting {} channel...", name)
            tasks.append(asyncio.create_task(self._start_channel(name, channel)))
        # Wait for all to complete (they should run forever)
@@ -178,9 +201,9 @@ class ChannelManager:
        for name, channel in self.channels.items():
            try:
                await channel.stop()
-               logger.info(f"Stopped {name} channel")
+               logger.info("Stopped {} channel", name)
            except Exception as e:
-               logger.error(f"Error stopping {name}: {e}")
+               logger.error("Error stopping {}: {}", name, e)
    async def _dispatch_outbound(self) -> None:
        """Dispatch outbound messages to the appropriate channel."""
@@ -193,14 +216,20 @@ class ChannelManager:
                    timeout=1.0
                )
+               if msg.metadata.get("_progress"):
+                   if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
+                       continue
+                   if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
+                       continue
                channel = self.channels.get(msg.channel)
                if channel:
                    try:
                        await channel.send(msg)
                    except Exception as e:
-                       logger.error(f"Error sending to {msg.channel}: {e}")
+                       logger.error("Error sending to {}: {}", msg.channel, e)
                else:
-                   logger.warning(f"Unknown channel: {msg.channel}")
+                   logger.warning("Unknown channel: {}", msg.channel)
            except asyncio.TimeoutError:
                continue
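The progress gate added to the dispatcher reduces to a three-way decision. A sketch, assuming `send_progress` and `send_tool_hints` are the two config flags referenced above (function name is illustrative):

def should_deliver(meta: dict, send_progress: bool, send_tool_hints: bool) -> bool:
    """Mirror of the dispatcher gate above, for reference only."""
    if not meta.get("_progress"):
        return True                # normal messages always go out
    if meta.get("_tool_hint"):
        return send_tool_hints     # tool hints have a dedicated switch
    return send_progress           # plain progress follows sendProgress

assert should_deliver({}, False, False)
assert not should_deliver({"_progress": True}, False, True)
assert should_deliver({"_progress": True, "_tool_hint": True}, False, True)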

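The startup guard exists because an empty allowFrom list silently denies everyone. A sketch of the semantics being protected, assuming `is_allowed` treats an unset list as open and honors the "*" wildcard (which is how the channels use it):

def is_allowed(sender: str, allow_from: list[str] | None) -> bool:
    if allow_from is None or "*" in allow_from:
        return True
    return sender in allow_from  # [] can never match, so it denies all

assert is_allowed("u1", None)    # unset: open
assert is_allowed("u1", ["*"])   # wildcard: open
assert not is_allowed("u1", [])  # empty: now rejected at startup instead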
nanobot/channels/matrix.py (new file, 699 lines)

@@ -0,0 +1,699 @@
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
import asyncio
import logging
import mimetypes
from pathlib import Path
from typing import Any, TypeAlias
from loguru import logger
try:
import nh3
from mistune import create_markdown
from nio import (
AsyncClient,
AsyncClientConfig,
ContentRepositoryConfigError,
DownloadError,
InviteEvent,
JoinError,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
RoomMessage,
RoomMessageMedia,
RoomMessageText,
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
)
from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError
except ImportError as e:
raise ImportError(
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
) from e
from nanobot.bus.events import OutboundMessage
from nanobot.channels.base import BaseChannel
from nanobot.config.loader import get_data_dir
from nanobot.utils.helpers import safe_filename
TYPING_NOTICE_TIMEOUT_MS = 30_000
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
_ATTACH_MARKER = "[attachment: {}]"
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
_ATTACH_FAILED = "[attachment: {} - download failed]"
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
_DEFAULT_ATTACH_NAME = "attachment"
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
MATRIX_MARKDOWN = create_markdown(
escape=True,
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
)
MATRIX_ALLOWED_HTML_TAGS = {
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
"caption", "sup", "sub", "img",
}
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
"a": {"href"}, "code": {"class"}, "ol": {"start"},
"img": {"src", "alt", "title", "width", "height"},
}
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
"""Filter attribute values to a safe Matrix-compatible subset."""
if tag == "a" and attr == "href":
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
if tag == "img" and attr == "src":
return value if value.lower().startswith("mxc://") else None
if tag == "code" and attr == "class":
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
return " ".join(classes) if classes else None
return value
MATRIX_HTML_CLEANER = nh3.Cleaner(
tags=MATRIX_ALLOWED_HTML_TAGS,
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
attribute_filter=_filter_matrix_html_attribute,
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
strip_comments=True,
link_rel="noopener noreferrer",
)
def _render_markdown_html(text: str) -> str | None:
"""Render markdown to sanitized HTML; returns None for plain text."""
try:
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
except Exception:
return None
if not formatted:
return None
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
if formatted.startswith("<p>") and formatted.endswith("</p>"):
inner = formatted[3:-4]
if "<" not in inner and ">" not in inner:
return None
return formatted
def _build_matrix_text_content(text: str) -> dict[str, object]:
"""Build Matrix m.text payload with optional HTML formatted_body."""
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
if html := _render_markdown_html(text):
content["format"] = MATRIX_HTML_FORMAT
content["formatted_body"] = html
return content
class _NioLoguruHandler(logging.Handler):
"""Route matrix-nio stdlib logs into Loguru."""
def emit(self, record: logging.LogRecord) -> None:
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
frame, depth = logging.currentframe(), 2
while frame and frame.f_code.co_filename == logging.__file__:
frame, depth = frame.f_back, depth + 1
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
def _configure_nio_logging_bridge() -> None:
"""Bridge matrix-nio logs to Loguru (idempotent)."""
nio_logger = logging.getLogger("nio")
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
nio_logger.handlers = [_NioLoguruHandler()]
nio_logger.propagate = False
class MatrixChannel(BaseChannel):
"""Matrix (Element) channel using long-polling sync."""
name = "matrix"
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
workspace: Path | None = None):
super().__init__(config, bus)
self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._restrict_to_workspace = restrict_to_workspace
self._workspace = workspace.expanduser().resolve() if workspace else None
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False
async def start(self) -> None:
"""Start Matrix client and begin sync loop."""
self._running = True
_configure_nio_logging_bridge()
store_path = get_data_dir() / "matrix-store"
store_path.mkdir(parents=True, exist_ok=True)
self.client = AsyncClient(
homeserver=self.config.homeserver, user=self.config.user_id,
store_path=store_path,
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
)
self.client.user_id = self.config.user_id
self.client.access_token = self.config.access_token
self.client.device_id = self.config.device_id
self._register_event_callbacks()
self._register_response_callbacks()
if not self.config.e2ee_enabled:
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
if self.config.device_id:
try:
self.client.load_store()
except Exception:
logger.exception("Matrix store load failed; restart may replay recent messages.")
else:
logger.warning("Matrix device_id empty; restart may replay recent messages.")
self._sync_task = asyncio.create_task(self._sync_loop())
async def stop(self) -> None:
"""Stop the Matrix channel with graceful sync shutdown."""
self._running = False
for room_id in list(self._typing_tasks):
await self._stop_typing_keepalive(room_id, clear_typing=False)
if self.client:
self.client.stop_sync_forever()
if self._sync_task:
try:
await asyncio.wait_for(asyncio.shield(self._sync_task),
timeout=self.config.sync_stop_grace_seconds)
except (asyncio.TimeoutError, asyncio.CancelledError):
self._sync_task.cancel()
try:
await self._sync_task
except asyncio.CancelledError:
pass
if self.client:
await self.client.close()
def _is_workspace_path_allowed(self, path: Path) -> bool:
"""Check path is inside workspace (when restriction enabled)."""
if not self._restrict_to_workspace or not self._workspace:
return True
try:
path.resolve(strict=False).relative_to(self._workspace)
return True
except ValueError:
return False
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
"""Deduplicate and resolve outbound attachment paths."""
seen: set[str] = set()
candidates: list[Path] = []
for raw in media:
if not isinstance(raw, str) or not raw.strip():
continue
path = Path(raw.strip()).expanduser()
try:
key = str(path.resolve(strict=False))
except OSError:
key = str(path)
if key not in seen:
seen.add(key)
candidates.append(path)
return candidates
@staticmethod
def _build_outbound_attachment_content(
*, filename: str, mime: str, size_bytes: int,
mxc_url: str, encryption_info: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Build Matrix content payload for an uploaded file/image/audio/video."""
prefix = mime.split("/")[0]
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
content: dict[str, Any] = {
"msgtype": msgtype, "body": filename, "filename": filename,
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
}
if encryption_info:
content["file"] = {**encryption_info, "url": mxc_url}
else:
content["url"] = mxc_url
return content
def _is_encrypted_room(self, room_id: str) -> bool:
if not self.client:
return False
room = getattr(self.client, "rooms", {}).get(room_id)
return bool(getattr(room, "encrypted", False))
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
"""Send m.room.message with E2EE options."""
if not self.client:
return
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
if self.config.e2ee_enabled:
kwargs["ignore_unverified_devices"] = True
await self.client.room_send(**kwargs)
async def _resolve_server_upload_limit_bytes(self) -> int | None:
"""Query homeserver upload limit once per channel lifecycle."""
if self._server_upload_limit_checked:
return self._server_upload_limit_bytes
self._server_upload_limit_checked = True
if not self.client:
return None
try:
response = await self.client.content_repository_config()
except Exception:
return None
upload_size = getattr(response, "upload_size", None)
if isinstance(upload_size, int) and upload_size > 0:
self._server_upload_limit_bytes = upload_size
return upload_size
return None
async def _effective_media_limit_bytes(self) -> int:
"""min(local config, server advertised) — 0 blocks all uploads."""
local_limit = max(int(self.config.max_media_bytes), 0)
server_limit = await self._resolve_server_upload_limit_bytes()
if server_limit is None:
return local_limit
return min(local_limit, server_limit) if local_limit else 0
async def _upload_and_send_attachment(
self, room_id: str, path: Path, limit_bytes: int,
relates_to: dict[str, Any] | None = None,
) -> str | None:
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
if not self.client:
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
resolved = path.expanduser().resolve(strict=False)
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
fail = _ATTACH_UPLOAD_FAILED.format(filename)
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
return fail
try:
size_bytes = resolved.stat().st_size
except OSError:
return fail
if limit_bytes <= 0 or size_bytes > limit_bytes:
return _ATTACH_TOO_LARGE.format(filename)
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
try:
with resolved.open("rb") as f:
upload_result = await self.client.upload(
f, content_type=mime, filename=filename,
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
filesize=size_bytes,
)
except Exception:
return fail
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
if isinstance(upload_response, UploadError):
return fail
mxc_url = getattr(upload_response, "content_uri", None)
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
return fail
content = self._build_outbound_attachment_content(
filename=filename, mime=mime, size_bytes=size_bytes,
mxc_url=mxc_url, encryption_info=encryption_info,
)
if relates_to:
content["m.relates_to"] = relates_to
try:
await self._send_room_content(room_id, content)
except Exception:
return fail
return None
async def send(self, msg: OutboundMessage) -> None:
"""Send outbound content; clear typing for non-progress messages."""
if not self.client:
return
text = msg.content or ""
candidates = self._collect_outbound_media_candidates(msg.media)
relates_to = self._build_thread_relates_to(msg.metadata)
is_progress = bool((msg.metadata or {}).get("_progress"))
try:
failures: list[str] = []
if candidates:
limit_bytes = await self._effective_media_limit_bytes()
for path in candidates:
if fail := await self._upload_and_send_attachment(
room_id=msg.chat_id,
path=path,
limit_bytes=limit_bytes,
relates_to=relates_to,
):
failures.append(fail)
if failures:
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
if text or not candidates:
content = _build_matrix_text_content(text)
if relates_to:
content["m.relates_to"] = relates_to
await self._send_room_content(msg.chat_id, content)
finally:
if not is_progress:
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
def _register_event_callbacks(self) -> None:
self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
self.client.add_event_callback(self._on_room_invite, InviteEvent)
def _register_response_callbacks(self) -> None:
self.client.add_response_callback(self._on_sync_error, SyncError)
self.client.add_response_callback(self._on_join_error, JoinError)
self.client.add_response_callback(self._on_send_error, RoomSendError)
def _log_response_error(self, label: str, response: Any) -> None:
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
code = getattr(response, "status_code", None)
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
is_fatal = is_auth or getattr(response, "soft_logout", False)
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
async def _on_sync_error(self, response: SyncError) -> None:
self._log_response_error("sync", response)
async def _on_join_error(self, response: JoinError) -> None:
self._log_response_error("join", response)
async def _on_send_error(self, response: RoomSendError) -> None:
self._log_response_error("send", response)
async def _set_typing(self, room_id: str, typing: bool) -> None:
"""Best-effort typing indicator update."""
if not self.client:
return
try:
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
timeout=TYPING_NOTICE_TIMEOUT_MS)
if isinstance(response, RoomTypingError):
logger.debug("Matrix typing failed for {}: {}", room_id, response)
except Exception:
pass
async def _start_typing_keepalive(self, room_id: str) -> None:
"""Start periodic typing refresh (spec-recommended keepalive)."""
await self._stop_typing_keepalive(room_id, clear_typing=False)
await self._set_typing(room_id, True)
if not self._running:
return
async def loop() -> None:
try:
while self._running:
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
await self._set_typing(room_id, True)
except asyncio.CancelledError:
pass
self._typing_tasks[room_id] = asyncio.create_task(loop())
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
if task := self._typing_tasks.pop(room_id, None):
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
if clear_typing:
await self._set_typing(room_id, False)
async def _sync_loop(self) -> None:
while self._running:
try:
await self.client.sync_forever(timeout=30000, full_state=True)
except asyncio.CancelledError:
break
except Exception:
await asyncio.sleep(2)
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
if self.is_allowed(event.sender):
await self.client.join(room.room_id)
def _is_direct_room(self, room: MatrixRoom) -> bool:
count = getattr(room, "member_count", None)
return isinstance(count, int) and count <= 2
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
"""Check m.mentions payload for bot mention."""
source = getattr(event, "source", None)
if not isinstance(source, dict):
return False
mentions = (source.get("content") or {}).get("m.mentions")
if not isinstance(mentions, dict):
return False
user_ids = mentions.get("user_ids")
if isinstance(user_ids, list) and self.config.user_id in user_ids:
return True
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
"""Apply sender and room policy checks."""
if not self.is_allowed(event.sender):
return False
if self._is_direct_room(room):
return True
policy = self.config.group_policy
if policy == "open":
return True
if policy == "allowlist":
return room.room_id in (self.config.group_allow_from or [])
if policy == "mention":
return self._is_bot_mentioned(event)
return False
def _media_dir(self) -> Path:
d = get_data_dir() / "media" / "matrix"
d.mkdir(parents=True, exist_ok=True)
return d
@staticmethod
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
source = getattr(event, "source", None)
if not isinstance(source, dict):
return {}
content = source.get("content")
return content if isinstance(content, dict) else {}
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
relates_to = self._event_source_content(event).get("m.relates_to")
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
return None
root_id = relates_to.get("event_id")
return root_id if isinstance(root_id, str) and root_id else None
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
if not (root_id := self._event_thread_root_id(event)):
return None
meta: dict[str, str] = {"thread_root_event_id": root_id}
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
meta["thread_reply_to_event_id"] = reply_to
return meta
@staticmethod
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
if not metadata:
return None
root_id = metadata.get("thread_root_event_id")
if not isinstance(root_id, str) or not root_id:
return None
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
if not isinstance(reply_to, str) or not reply_to:
return None
return {"rel_type": "m.thread", "event_id": root_id,
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
msgtype = self._event_source_content(event).get("msgtype")
return _MSGTYPE_MAP.get(msgtype, "file")
@staticmethod
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
return (isinstance(getattr(event, "key", None), dict)
and isinstance(getattr(event, "hashes", None), dict)
and isinstance(getattr(event, "iv", None), str))
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
info = self._event_source_content(event).get("info")
size = info.get("size") if isinstance(info, dict) else None
return size if isinstance(size, int) and size >= 0 else None
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
info = self._event_source_content(event).get("info")
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
return m
m = getattr(event, "mimetype", None)
return m if isinstance(m, str) and m else None
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
body = getattr(event, "body", None)
if isinstance(body, str) and body.strip():
if candidate := safe_filename(Path(body).name):
return candidate
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
filename: str, mime: str | None) -> Path:
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
suffix = Path(safe_name).suffix
if not suffix and mime:
if guessed := mimetypes.guess_extension(mime, strict=False):
safe_name, suffix = f"{safe_name}{guessed}", guessed
stem = (Path(safe_name).stem or attachment_type)[:72]
suffix = suffix[:16]
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
event_prefix = (event_id[:24] or "evt").strip("_")
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
if not self.client:
return None
response = await self.client.download(mxc=mxc_url)
if isinstance(response, DownloadError):
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
return None
body = getattr(response, "body", None)
if isinstance(body, (bytes, bytearray)):
return bytes(body)
if isinstance(response, MemoryDownloadResponse):
return bytes(response.body)
if isinstance(body, (str, Path)):
path = Path(body)
if path.is_file():
try:
return path.read_bytes()
except OSError:
return None
return None
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
key = key_obj.get("k") if isinstance(key_obj, dict) else None
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
if not all(isinstance(v, str) for v in (key, sha256, iv)):
return None
try:
return decrypt_attachment(ciphertext, key, sha256, iv)
except (EncryptionError, ValueError, TypeError):
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
return None
async def _fetch_media_attachment(
self, room: MatrixRoom, event: MatrixMediaEvent,
) -> tuple[dict[str, Any] | None, str]:
"""Download, decrypt if needed, and persist a Matrix attachment."""
atype = self._event_attachment_type(event)
mime = self._event_mime(event)
filename = self._event_filename(event, atype)
mxc_url = getattr(event, "url", None)
fail = _ATTACH_FAILED.format(filename)
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
return None, fail
limit_bytes = await self._effective_media_limit_bytes()
declared = self._event_declared_size_bytes(event)
if declared is not None and declared > limit_bytes:
return None, _ATTACH_TOO_LARGE.format(filename)
downloaded = await self._download_media_bytes(mxc_url)
if downloaded is None:
return None, fail
encrypted = self._is_encrypted_media_event(event)
data = downloaded
if encrypted:
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
return None, fail
if len(data) > limit_bytes:
return None, _ATTACH_TOO_LARGE.format(filename)
path = self._build_attachment_path(event, atype, filename, mime)
try:
path.write_bytes(data)
except OSError:
return None, fail
attachment = {
"type": atype, "mime": mime, "filename": filename,
"event_id": str(getattr(event, "event_id", "") or ""),
"encrypted": encrypted, "size_bytes": len(data),
"path": str(path), "mxc_url": mxc_url,
}
return attachment, _ATTACH_MARKER.format(path)
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
"""Build common metadata for text and media handlers."""
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
meta["event_id"] = eid
if thread := self._thread_metadata(event):
meta.update(thread)
return meta
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
if event.sender == self.config.user_id or not self._should_process_message(room, event):
return
await self._start_typing_keepalive(room.room_id)
try:
await self._handle_message(
sender_id=event.sender, chat_id=room.room_id,
content=event.body, metadata=self._base_metadata(room, event),
)
except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
raise
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
if event.sender == self.config.user_id or not self._should_process_message(room, event):
return
attachment, marker = await self._fetch_media_attachment(room, event)
parts: list[str] = []
if isinstance(body := getattr(event, "body", None), str) and body.strip():
parts.append(body.strip())
if marker:
parts.append(marker)
await self._start_typing_keepalive(room.room_id)
try:
meta = self._base_metadata(room, event)
meta["attachments"] = []
if attachment:
meta["attachments"] = [attachment]
await self._handle_message(
sender_id=event.sender, chat_id=room.room_id,
content="\n".join(parts),
media=[attachment["path"]] if attachment else [],
metadata=meta,
)
except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
raise
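One detail worth calling out in the file above: thread context round-trips through message metadata. `_thread_metadata` records the root event on the way in, and `_build_thread_relates_to` rebuilds the `m.relates_to` payload on the way out. A minimal sketch of the shapes involved (event IDs made up):

inbound_meta = {
    "thread_root_event_id": "$root",        # captured by _thread_metadata
    "thread_reply_to_event_id": "$latest",  # the specific event being answered
}

# What _build_thread_relates_to emits on the outbound reply:
relates_to = {
    "rel_type": "m.thread",
    "event_id": "$root",
    "m.in_reply_to": {"event_id": "$latest"},
    "is_falling_back": True,  # lets non-threaded clients render a plain reply
}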


@@ -322,7 +322,7 @@ class MochatChannel(BaseChannel):
            await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
                                 content, msg.reply_to)
        except Exception as e:
-           logger.error(f"Failed to send Mochat message: {e}")
+           logger.error("Failed to send Mochat message: {}", e)
    # ---- config / init helpers ---------------------------------------------
@@ -380,7 +380,7 @@ class MochatChannel(BaseChannel):
        @client.event
        async def connect_error(data: Any) -> None:
-           logger.error(f"Mochat websocket connect error: {data}")
+           logger.error("Mochat websocket connect error: {}", data)
        @client.on("claw.session.events")
        async def on_session_events(payload: dict[str, Any]) -> None:
@@ -407,7 +407,7 @@ class MochatChannel(BaseChannel):
            )
            return True
        except Exception as e:
-           logger.error(f"Failed to connect Mochat websocket: {e}")
+           logger.error("Failed to connect Mochat websocket: {}", e)
            try:
                await client.disconnect()
            except Exception:
@@ -444,7 +444,7 @@ class MochatChannel(BaseChannel):
            "limit": self.config.watch_limit,
        })
        if not ack.get("result"):
-           logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
+           logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
            return False
        data = ack.get("data")
@@ -466,7 +466,7 @@ class MochatChannel(BaseChannel):
            return True
        ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
        if not ack.get("result"):
-           logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
+           logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
            return False
        return True
@@ -488,7 +488,7 @@ class MochatChannel(BaseChannel):
        try:
            await self._refresh_targets(subscribe_new=self._ws_ready)
        except Exception as e:
-           logger.warning(f"Mochat refresh failed: {e}")
+           logger.warning("Mochat refresh failed: {}", e)
        if self._fallback_mode:
            await self._ensure_fallback_workers()
@@ -502,7 +502,7 @@ class MochatChannel(BaseChannel):
        try:
            response = await self._post_json("/api/claw/sessions/list", {})
        except Exception as e:
-           logger.warning(f"Mochat listSessions failed: {e}")
+           logger.warning("Mochat listSessions failed: {}", e)
            return
        sessions = response.get("sessions")
@@ -536,7 +536,7 @@ class MochatChannel(BaseChannel):
        try:
            response = await self._post_json("/api/claw/groups/get", {})
        except Exception as e:
-           logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
+           logger.warning("Mochat getWorkspaceGroup failed: {}", e)
            return
        raw_panels = response.get("panels")
@@ -598,7 +598,7 @@ class MochatChannel(BaseChannel):
            except asyncio.CancelledError:
                break
            except Exception as e:
-               logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
+               logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
                await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
    async def _panel_poll_worker(self, panel_id: str) -> None:
@@ -625,7 +625,7 @@ class MochatChannel(BaseChannel):
            except asyncio.CancelledError:
                break
            except Exception as e:
-               logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
+               logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
                await asyncio.sleep(sleep_s)
    # ---- inbound event processing ------------------------------------------
@@ -836,7 +836,7 @@ class MochatChannel(BaseChannel):
        try:
            data = json.loads(self._cursor_path.read_text("utf-8"))
        except Exception as e:
-           logger.warning(f"Failed to read Mochat cursor file: {e}")
+           logger.warning("Failed to read Mochat cursor file: {}", e)
            return
        cursors = data.get("cursors") if isinstance(data, dict) else None
        if isinstance(cursors, dict):
@@ -852,7 +852,7 @@ class MochatChannel(BaseChannel):
                "cursors": self._session_cursor,
            }, ensure_ascii=False, indent=2) + "\n", "utf-8")
        except Exception as e:
-           logger.warning(f"Failed to save Mochat cursor file: {e}")
+           logger.warning("Failed to save Mochat cursor file: {}", e)
    # ---- HTTP helpers ------------------------------------------------------
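Every hunk in this file is the same mechanical change: f-string logging becomes loguru's deferred `{}` style. The practical difference, sketched:

from loguru import logger

err = ConnectionError("socket closed")
# Eager: the f-string is built even when no sink accepts WARNING records.
logger.warning(f"Mochat refresh failed: {err}")
# Deferred: loguru interpolates only if the record is actually emitted,
# and the value never becomes part of the format string itself.
logger.warning("Mochat refresh failed: {}", err)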


@@ -13,16 +13,17 @@ from nanobot.config.schema import QQConfig
try:
    import botpy
-   from botpy.message import C2CMessage
+   from botpy.message import C2CMessage, GroupMessage
    QQ_AVAILABLE = True
except ImportError:
    QQ_AVAILABLE = False
    botpy = None
    C2CMessage = None
+   GroupMessage = None
if TYPE_CHECKING:
-   from botpy.message import C2CMessage
+   from botpy.message import C2CMessage, GroupMessage
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
@@ -31,16 +32,20 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
    class _Bot(botpy.Client):
        def __init__(self):
-           super().__init__(intents=intents)
+           # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+           super().__init__(intents=intents, ext_handlers=False)
        async def on_ready(self):
-           logger.info(f"QQ bot ready: {self.robot.name}")
+           logger.info("QQ bot ready: {}", self.robot.name)
        async def on_c2c_message_create(self, message: "C2CMessage"):
-           await channel._on_message(message)
+           await channel._on_message(message, is_group=False)
+       async def on_group_at_message_create(self, message: "GroupMessage"):
+           await channel._on_message(message, is_group=True)
        async def on_direct_message_create(self, message):
-           await channel._on_message(message)
+           await channel._on_message(message, is_group=False)
    return _Bot
@@ -55,7 +60,8 @@ class QQChannel(BaseChannel):
        self.config: QQConfig = config
        self._client: "botpy.Client | None" = None
        self._processed_ids: deque = deque(maxlen=1000)
-       self._bot_task: asyncio.Task | None = None
+       self._msg_seq: int = 1  # message sequence number, so the QQ API does not dedup repeated replies
+       self._chat_type_cache: dict[str, str] = {}
    async def start(self) -> None:
        """Start the QQ bot."""
@@ -70,9 +76,8 @@ class QQChannel(BaseChannel):
        self._running = True
        BotClass = _make_bot_class(self)
        self._client = BotClass()
-       self._bot_task = asyncio.create_task(self._run_bot())
-       logger.info("QQ bot started (C2C private message)")
+       logger.info("QQ bot started (C2C & Group supported)")
+       await self._run_bot()
    async def _run_bot(self) -> None:
        """Run the bot connection with auto-reconnect."""
@@ -80,7 +85,7 @@ class QQChannel(BaseChannel):
            try:
                await self._client.start(appid=self.config.app_id, secret=self.config.secret)
            except Exception as e:
-               logger.warning(f"QQ bot error: {e}")
+               logger.warning("QQ bot error: {}", e)
                if self._running:
                    logger.info("Reconnecting QQ bot in 5 seconds...")
                    await asyncio.sleep(5)
@@ -88,11 +93,10 @@ class QQChannel(BaseChannel):
    async def stop(self) -> None:
        """Stop the QQ bot."""
        self._running = False
-       if self._bot_task:
-           self._bot_task.cancel()
+       if self._client:
            try:
-               await self._bot_task
-           except asyncio.CancelledError:
+               await self._client.close()
+           except Exception:
                pass
        logger.info("QQ bot stopped")
@@ -101,16 +105,31 @@ class QQChannel(BaseChannel):
        if not self._client:
            logger.warning("QQ client not initialized")
            return
        try:
+           msg_id = msg.metadata.get("message_id")
+           self._msg_seq += 1
+           msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
+           if msg_type == "group":
+               await self._client.api.post_group_message(
+                   group_openid=msg.chat_id,
+                   msg_type=0,
+                   content=msg.content,
+                   msg_id=msg_id,
+                   msg_seq=self._msg_seq,
+               )
+           else:
                await self._client.api.post_c2c_message(
                    openid=msg.chat_id,
                    msg_type=0,
                    content=msg.content,
+                   msg_id=msg_id,
+                   msg_seq=self._msg_seq,
                )
        except Exception as e:
-           logger.error(f"Error sending QQ message: {e}")
+           logger.error("Error sending QQ message: {}", e)
-   async def _on_message(self, data: "C2CMessage") -> None:
+   async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
        """Handle incoming message from QQ."""
        try:
            # Dedup by message ID
@@ -118,17 +137,24 @@ class QQChannel(BaseChannel):
                return
            self._processed_ids.append(data.id)
-           author = data.author
-           user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
            content = (data.content or "").strip()
            if not content:
                return
+           if is_group:
+               chat_id = data.group_openid
+               user_id = data.author.member_openid
+               self._chat_type_cache[chat_id] = "group"
+           else:
+               chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
+               user_id = chat_id
+               self._chat_type_cache[chat_id] = "c2c"
            await self._handle_message(
                sender_id=user_id,
-               chat_id=user_id,
+               chat_id=chat_id,
                content=content,
                metadata={"message_id": data.id},
            )
-       except Exception as e:
-           logger.error(f"Error handling QQ message: {e}")
+       except Exception:
+           logger.exception("Error handling QQ message")


@@ -5,10 +5,11 @@ import re
from typing import Any
from loguru import logger
-from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
+from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient
+from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
@@ -34,7 +35,7 @@ class SlackChannel(BaseChannel):
            logger.error("Slack bot/app token not configured")
            return
        if self.config.mode != "socket":
-           logger.error(f"Unsupported Slack mode: {self.config.mode}")
+           logger.error("Unsupported Slack mode: {}", self.config.mode)
            return
        self._running = True
@@ -51,9 +52,9 @@ class SlackChannel(BaseChannel):
        try:
            auth = await self._web_client.auth_test()
            self._bot_user_id = auth.get("user_id")
-           logger.info(f"Slack bot connected as {self._bot_user_id}")
+           logger.info("Slack bot connected as {}", self._bot_user_id)
        except Exception as e:
-           logger.warning(f"Slack auth_test failed: {e}")
+           logger.warning("Slack auth_test failed: {}", e)
        logger.info("Starting Slack Socket Mode client...")
        await self._socket_client.connect()
@@ -68,7 +69,7 @@ class SlackChannel(BaseChannel):
        try:
            await self._socket_client.close()
        except Exception as e:
-           logger.warning(f"Slack socket close failed: {e}")
+           logger.warning("Slack socket close failed: {}", e)
        self._socket_client = None
    async def send(self, msg: OutboundMessage) -> None:
@@ -81,14 +82,28 @@ class SlackChannel(BaseChannel):
            thread_ts = slack_meta.get("thread_ts")
            channel_type = slack_meta.get("channel_type")
            # Only reply in thread for channel/group messages; DMs don't use threads
            use_thread = thread_ts and channel_type != "im"
+           thread_ts_param = thread_ts if use_thread else None
+           # Slack rejects empty text payloads. Keep media-only messages media-only,
+           # but send a single blank message when the bot has no text or files to send.
+           if msg.content or not (msg.media or []):
                await self._web_client.chat_postMessage(
                    channel=msg.chat_id,
-                   text=msg.content or "<empty_response_from_the_bot>",
-                   thread_ts=thread_ts if use_thread else None,
+                   text=self._to_mrkdwn(msg.content) if msg.content else " ",
+                   thread_ts=thread_ts_param,
                )
+           for media_path in msg.media or []:
+               try:
+                   await self._web_client.files_upload_v2(
+                       channel=msg.chat_id,
+                       file=media_path,
+                       thread_ts=thread_ts_param,
+                   )
+               except Exception as e:
+                   logger.error("Failed to upload file {}: {}", media_path, e)
        except Exception as e:
-           logger.error(f"Error sending Slack message: {e}")
+           logger.error("Error sending Slack message: {}", e)
    async def _on_socket_request(
        self,
@@ -150,18 +165,24 @@ class SlackChannel(BaseChannel):
        text = self._strip_bot_mention(text)
-       thread_ts = event.get("thread_ts") or event.get("ts")
+       thread_ts = event.get("thread_ts")
+       if self.config.reply_in_thread and not thread_ts:
+           thread_ts = event.get("ts")
        # Add :eyes: reaction to the triggering message (best-effort)
        try:
            if self._web_client and event.get("ts"):
                await self._web_client.reactions_add(
                    channel=chat_id,
-                   name="eyes",
+                   name=self.config.react_emoji,
                    timestamp=event.get("ts"),
                )
        except Exception as e:
-           logger.debug(f"Slack reactions_add failed: {e}")
+           logger.debug("Slack reactions_add failed: {}", e)
+       # Thread-scoped session key for channel/group messages
+       session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
+       try:
            await self._handle_message(
                sender_id=sender_id,
                chat_id=chat_id,
@@ -171,9 +192,12 @@ class SlackChannel(BaseChannel):
                    "event": event,
                    "thread_ts": thread_ts,
                    "channel_type": channel_type,
-               }
+               },
                },
+               session_key=session_key,
            )
+       except Exception:
+           logger.exception("Error handling Slack message from {}", sender_id)
    def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
        if channel_type == "im":
@@ -203,3 +227,55 @@ class SlackChannel(BaseChannel):
        if not text or not self._bot_user_id:
            return text
        return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
@classmethod
def _to_mrkdwn(cls, text: str) -> str:
"""Convert Markdown to Slack mrkdwn, including tables."""
if not text:
return ""
text = cls._TABLE_RE.sub(cls._convert_table, text)
return cls._fixup_mrkdwn(slackify_markdown(text))
@classmethod
def _fixup_mrkdwn(cls, text: str) -> str:
"""Fix markdown artifacts that slackify_markdown misses."""
code_blocks: list[str] = []
def _save_code(m: re.Match) -> str:
code_blocks.append(m.group(0))
return f"\x00CB{len(code_blocks) - 1}\x00"
text = cls._CODE_FENCE_RE.sub(_save_code, text)
text = cls._INLINE_CODE_RE.sub(_save_code, text)
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&amp;", "&"), text)
for i, block in enumerate(code_blocks):
text = text.replace(f"\x00CB{i}\x00", block)
return text
@staticmethod
def _convert_table(match: re.Match) -> str:
"""Convert a Markdown table to a Slack-readable list."""
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
if len(lines) < 2:
return match.group(0)
headers = [h.strip() for h in lines[0].strip("|").split("|")]
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
rows: list[str] = []
for line in lines[start:]:
cells = [c.strip() for c in line.strip("|").split("|")]
cells = (cells + [""] * len(headers))[: len(headers)]
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
if parts:
rows.append(" · ".join(parts))
return "\n".join(rows)


@@ -4,15 +4,62 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
import time
import unicodedata
from loguru import logger from loguru import logger
from telegram import BotCommand, Update from telegram import BotCommand, ReplyParameters, Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import TelegramConfig from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # stay safely under Telegram's 4096-character limit
def _strip_md(s: str) -> str:
"""Strip markdown inline formatting from text."""
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
s = re.sub(r'__(.+?)__', r'\1', s)
s = re.sub(r'~~(.+?)~~', r'\1', s)
s = re.sub(r'`([^`]+)`', r'\1', s)
return s.strip()
def _render_table_box(table_lines: list[str]) -> str:
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
def dw(s: str) -> int:
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
rows: list[list[str]] = []
has_sep = False
for line in table_lines:
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
has_sep = True
continue
rows.append(cells)
if not rows or not has_sep:
return '\n'.join(table_lines)
ncols = max(len(r) for r in rows)
for r in rows:
r.extend([''] * (ncols - len(r)))
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
def dr(cells: list[str]) -> str:
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
out = [dr(rows[0])]
out.append(' '.join('─' * w for w in widths))
for row in rows[1:]:
out.append(dr(row))
return '\n'.join(out)
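The `dw` helper above exists because terminal columns, not code points, determine alignment: East Asian "Wide"/"Fullwidth" characters occupy two cells, so `len()` alone would misalign mixed-script tables. A quick check:

```python
import unicodedata

def dw(s: str) -> int:
    # Display width: wide/fullwidth characters count as two columns.
    return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)

print(len("你好"), dw("你好"))  # 2 4
print(len("hi"), dw("hi"))      # 2 2
```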
def _markdown_to_telegram_html(text: str) -> str: def _markdown_to_telegram_html(text: str) -> str:
@@ -30,6 +77,27 @@ def _markdown_to_telegram_html(text: str) -> str:
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
lines = text.split('\n')
rebuilt: list[str] = []
li = 0
while li < len(lines):
if re.match(r'^\s*\|.+\|', lines[li]):
tbl: list[str] = []
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
tbl.append(lines[li])
li += 1
box = _render_table_box(tbl)
if box != '\n'.join(tbl):
code_blocks.append(box)
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
else:
rebuilt.extend(tbl)
else:
rebuilt.append(lines[li])
li += 1
text = '\n'.join(rebuilt)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
@@ -91,6 +159,7 @@ class TelegramChannel(BaseChannel):
BOT_COMMANDS = [ BOT_COMMANDS = [
BotCommand("start", "Start the bot"), BotCommand("start", "Start the bot"),
BotCommand("new", "Start a new conversation"), BotCommand("new", "Start a new conversation"),
BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
@@ -106,6 +175,28 @@ class TelegramChannel(BaseChannel):
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
self._message_threads: dict[tuple[str, int], int] = {}
def is_allowed(self, sender_id: str) -> bool:
"""Preserve Telegram's legacy id|username allowlist matching."""
if super().is_allowed(sender_id):
return True
allow_list = getattr(self.config, "allow_from", [])
if not allow_list or "*" in allow_list:
return False
sender_str = str(sender_id)
if sender_str.count("|") != 1:
return False
sid, username = sender_str.split("|", 1)
if not sid.isdigit() or not username:
return False
return sid in allow_list or username in allow_list
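A minimal sketch of the legacy matching above, runnable on its own (assumed inputs; in the real method the wildcard case is already accepted by `super().is_allowed`, which is why it returns False here):

```python
def legacy_allowed(sender_id: str, allow_list: list[str]) -> bool:
    # Accept "id|username" composites and match either half.
    if not allow_list or "*" in allow_list:
        return False  # wildcard already accepted by the base-class check
    if sender_id.count("|") != 1:
        return False
    sid, username = sender_id.split("|", 1)
    if not sid.isdigit() or not username:
        return False
    return sid in allow_list or username in allow_list

print(legacy_allowed("12345|alice", ["alice"]))  # True  (username match)
print(legacy_allowed("12345|alice", ["12345"]))  # True  (numeric id match)
print(legacy_allowed("12345|alice", ["67890"]))  # False
```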
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
@@ -116,17 +207,22 @@ class TelegramChannel(BaseChannel):
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs # Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) req = HTTPXRequest(
connection_pool_size=16,
pool_timeout=5.0,
connect_timeout=30.0,
read_timeout=30.0,
proxy=self.config.proxy if self.config.proxy else None,
)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
if self.config.proxy:
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
# Add command handlers # Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("help", self._forward_command)) self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
self._app.add_handler( self._app.add_handler(
@@ -145,13 +241,13 @@ class TelegramChannel(BaseChannel):
# Get bot info and register command menu # Get bot info and register command menu
bot_info = await self._app.bot.get_me() bot_info = await self._app.bot.get_me()
logger.info(f"Telegram bot @{bot_info.username} connected") logger.info("Telegram bot @{} connected", bot_info.username)
try: try:
await self._app.bot.set_my_commands(self.BOT_COMMANDS) await self._app.bot.set_my_commands(self.BOT_COMMANDS)
logger.debug("Telegram bot commands registered") logger.debug("Telegram bot commands registered")
except Exception as e: except Exception as e:
logger.warning(f"Failed to register bot commands: {e}") logger.warning("Failed to register bot commands: {}", e)
# Start polling (this runs until stopped) # Start polling (this runs until stopped)
await self._app.updater.start_polling( await self._app.updater.start_polling(
@@ -171,6 +267,11 @@ class TelegramChannel(BaseChannel):
for chat_id in list(self._typing_tasks): for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id) self._stop_typing(chat_id)
for task in self._media_group_tasks.values():
task.cancel()
self._media_group_tasks.clear()
self._media_group_buffers.clear()
if self._app: if self._app:
logger.info("Stopping Telegram bot...") logger.info("Stopping Telegram bot...")
await self._app.updater.stop() await self._app.updater.stop()
@@ -178,37 +279,137 @@ class TelegramChannel(BaseChannel):
await self._app.shutdown() await self._app.shutdown()
self._app = None self._app = None
@staticmethod
def _get_media_type(path: str) -> str:
"""Guess media type from file extension."""
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
return "photo"
if ext == "ogg":
return "voice"
if ext in ("mp3", "m4a", "wav", "aac"):
return "audio"
return "document"
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Telegram.""" """Send a message through Telegram."""
if not self._app: if not self._app:
logger.warning("Telegram bot not running") logger.warning("Telegram bot not running")
return return
# Stop typing indicator for this chat # Only stop typing indicator for final responses
if not msg.metadata.get("_progress", False):
self._stop_typing(msg.chat_id) self._stop_typing(msg.chat_id)
try: try:
# chat_id should be the Telegram chat ID (integer)
chat_id = int(msg.chat_id) chat_id = int(msg.chat_id)
# Convert markdown to Telegram HTML except ValueError:
html_content = _markdown_to_telegram_html(msg.content) logger.error("Invalid chat_id: {}", msg.chat_id)
return
reply_to_message_id = msg.metadata.get("message_id")
message_thread_id = msg.metadata.get("message_thread_id")
if message_thread_id is None and reply_to_message_id is not None:
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
thread_kwargs = {}
if message_thread_id is not None:
thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None
if self.config.reply_to_message:
if reply_to_message_id:
reply_params = ReplyParameters(
message_id=reply_to_message_id,
allow_sending_without_reply=True
)
# Send media files
for media_path in (msg.media or []):
try:
media_type = self._get_media_type(media_path)
sender = {
"photo": self._app.bot.send_photo,
"voice": self._app.bot.send_voice,
"audio": self._app.bot.send_audio,
}.get(media_type, self._app.bot.send_document)
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
with open(media_path, 'rb') as f:
await sender(
chat_id=chat_id,
**{param: f},
reply_parameters=reply_params,
**thread_kwargs,
)
except Exception as e:
filename = media_path.rsplit("/", 1)[-1]
logger.error("Failed to send media {}: {}", media_path, e)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=html_content, text=f"[Failed to send: {filename}]",
parse_mode="HTML" reply_parameters=reply_params,
**thread_kwargs,
)
# Send text content
if msg.content and msg.content != "[empty message]":
is_progress = msg.metadata.get("_progress", False)
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
# Final response: simulate streaming via draft, then persist
if not is_progress:
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
else:
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
async def _send_text(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Send a plain text message with HTML fallback."""
try:
html = _markdown_to_telegram_html(text)
await self._app.bot.send_message(
chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params,
**(thread_kwargs or {}),
) )
except ValueError:
logger.error(f"Invalid chat_id: {msg.chat_id}")
except Exception as e: except Exception as e:
# Fallback to plain text if HTML parsing fails logger.warning("HTML parse failed, falling back to plain text: {}", e)
logger.warning(f"HTML parse failed, falling back to plain text: {e}")
try: try:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=int(msg.chat_id), chat_id=chat_id,
text=msg.content text=text,
reply_parameters=reply_params,
**(thread_kwargs or {}),
) )
except Exception as e2: except Exception as e2:
logger.error(f"Error sending Telegram message: {e2}") logger.error("Error sending Telegram message: {}", e2)
async def _send_with_streaming(
self,
chat_id: int,
text: str,
reply_params=None,
thread_kwargs: dict | None = None,
) -> None:
"""Simulate streaming via send_message_draft, then persist with send_message."""
draft_id = int(time.time() * 1000) % (2**31)
try:
step = max(len(text) // 8, 40)
for i in range(step, len(text), step):
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text[:i],
)
await asyncio.sleep(0.04)
await self._app.bot.send_message_draft(
chat_id=chat_id, draft_id=draft_id, text=text,
)
await asyncio.sleep(0.15)
except Exception:
pass
await self._send_text(chat_id, text, reply_params, thread_kwargs)
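The stepping schedule above grows the visible prefix in roughly eight increments, never fewer than 40 characters at a time. A pure-logic sketch (no Bot API calls; `send_message_draft` is whatever draft API the client in this codebase exposes):

```python
text = "x" * 300
step = max(len(text) // 8, 40)  # ~8 increments, minimum 40 chars each
prefixes = [text[:i] for i in range(step, len(text), step)] + [text]
print([len(p) for p in prefixes])  # [40, 80, 120, 160, 200, 240, 280, 300]
```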
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
@@ -222,14 +423,67 @@ class TelegramChannel(BaseChannel):
"Type /help to see available commands." "Type /help to see available commands."
) )
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /help command, bypassing ACL so all users can access it."""
if not update.message:
return
await update.message.reply_text(
"🐈 nanobot commands:\n"
"/new — Start a new conversation\n"
"/stop — Stop the current task\n"
"/help — Show available commands"
)
@staticmethod
def _sender_id(user) -> str:
"""Build sender_id with username for allowlist matching."""
sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
@staticmethod
def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload."""
return {
"message_id": message.message_id,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)),
}
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
key = (str(message.chat_id), message.message_id)
self._message_threads[key] = message_thread_id
if len(self._message_threads) > 1000:
self._message_threads.pop(next(iter(self._message_threads)))
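The eviction above leans on insertion-ordered dicts (guaranteed since Python 3.7): popping the first key always drops the oldest entry, giving a bounded FIFO cache with no extra dependency. A sketch:

```python
cache: dict[tuple[str, int], int] = {}
for i in range(1003):
    cache[("chat", i)] = i
    if len(cache) > 1000:
        cache.pop(next(iter(cache)))  # evict the oldest insertion
print(next(iter(cache)), len(cache))  # ('chat', 3) 1000
```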
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop.""" """Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message
user = update.effective_user
self._remember_thread_context(message)
await self._handle_message( await self._handle_message(
sender_id=str(update.effective_user.id), sender_id=self._sender_id(user),
chat_id=str(update.message.chat_id), chat_id=str(message.chat_id),
content=update.message.text, content=message.text,
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -240,11 +494,8 @@ class TelegramChannel(BaseChannel):
message = update.message message = update.message
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user)
# Use stable numeric ID, but keep username for allowlist compatibility self._remember_thread_context(message)
sender_id = str(user.id)
if user.username:
sender_id = f"{sender_id}|{user.username}"
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
@@ -280,8 +531,11 @@ class TelegramChannel(BaseChannel):
if media_file and self._app: if media_file and self._app:
try: try:
file = await self._app.bot.get_file(media_file.file_id) file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None)) ext = self._get_extension(
media_type,
getattr(media_file, 'mime_type', None),
getattr(media_file, 'file_name', None),
)
# Save to workspace/media/ # Save to workspace/media/
from pathlib import Path from pathlib import Path
media_dir = Path.home() / ".nanobot" / "media" media_dir = Path.home() / ".nanobot" / "media"
@@ -298,23 +552,44 @@ class TelegramChannel(BaseChannel):
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key) transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
transcription = await transcriber.transcribe(file_path) transcription = await transcriber.transcribe(file_path)
if transcription: if transcription:
logger.info(f"Transcribed {media_type}: {transcription[:50]}...") logger.info("Transcribed {}: {}...", media_type, transcription[:50])
content_parts.append(f"[transcription: {transcription}]") content_parts.append(f"[transcription: {transcription}]")
else: else:
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
else: else:
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
logger.debug(f"Downloaded {media_type} to {file_path}") logger.debug("Downloaded {} to {}", media_type, file_path)
except Exception as e: except Exception as e:
logger.error(f"Failed to download media: {e}") logger.error("Failed to download media: {}", e)
content_parts.append(f"[{media_type}: download failed]") content_parts.append(f"[{media_type}: download failed]")
content = "\n".join(content_parts) if content_parts else "[empty message]" content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug(f"Telegram message from {sender_id}: {content[:50]}...") logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
metadata = self._build_message_metadata(message, user)
session_key = self._derive_topic_session_key(message)
# Telegram media groups: buffer briefly, forward as one aggregated turn.
if media_group_id := getattr(message, "media_group_id", None):
key = f"{str_chat_id}:{media_group_id}"
if key not in self._media_group_buffers:
self._media_group_buffers[key] = {
"sender_id": sender_id, "chat_id": str_chat_id,
"contents": [], "media": [],
"metadata": metadata,
"session_key": session_key,
}
self._start_typing(str_chat_id)
buf = self._media_group_buffers[key]
if content and content != "[empty message]":
buf["contents"].append(content)
buf["media"].extend(media_paths)
if key not in self._media_group_tasks:
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
return
# Start typing indicator before processing # Start typing indicator before processing
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
@@ -325,15 +600,26 @@ class TelegramChannel(BaseChannel):
chat_id=str_chat_id, chat_id=str_chat_id,
content=content, content=content,
media=media_paths, media=media_paths,
metadata={ metadata=metadata,
"message_id": message.message_id, session_key=session_key,
"user_id": user.id,
"username": user.username,
"first_name": user.first_name,
"is_group": message.chat.type != "private"
}
) )
async def _flush_media_group(self, key: str) -> None:
"""Wait briefly, then forward buffered media-group as one turn."""
try:
await asyncio.sleep(0.6)
if not (buf := self._media_group_buffers.pop(key, None)):
return
content = "\n".join(buf["contents"]) or "[empty message]"
await self._handle_message(
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
content=content, media=list(dict.fromkeys(buf["media"])),
metadata=buf["metadata"],
session_key=buf.get("session_key"),
)
finally:
self._media_group_tasks.pop(key, None)
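Telegram delivers each photo of an album as a separate update, so the buffer-and-flush pair above debounces them into a single agent turn. A minimal self-contained sketch of the same pattern (hypothetical names):

```python
import asyncio

buffers: dict[str, list[str]] = {}
tasks: dict[str, asyncio.Task] = {}

async def flush(key: str) -> None:
    try:
        await asyncio.sleep(0.6)  # grace window for the rest of the album
        print(key, buffers.pop(key, []))
    finally:
        tasks.pop(key, None)

async def on_item(key: str, item: str) -> None:
    buffers.setdefault(key, []).append(item)
    if key not in tasks:  # first item schedules the single flush task
        tasks[key] = asyncio.create_task(flush(key))

async def main() -> None:
    await on_item("chat:42", "photo1.jpg")
    await on_item("chat:42", "photo2.jpg")
    await asyncio.sleep(1)  # prints: chat:42 ['photo1.jpg', 'photo2.jpg']

asyncio.run(main())
```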
def _start_typing(self, chat_id: str) -> None: def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat.""" """Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat # Cancel any existing typing task for this chat
@@ -355,14 +641,19 @@ class TelegramChannel(BaseChannel):
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
logger.debug(f"Typing indicator stopped for {chat_id}: {e}") logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Log polling / handler errors instead of silently swallowing them.""" """Log polling / handler errors instead of silently swallowing them."""
logger.error(f"Telegram error: {context.error}") logger.error("Telegram error: {}", context.error)
def _get_extension(self, media_type: str, mime_type: str | None) -> str: def _get_extension(
"""Get file extension based on media type.""" self,
media_type: str,
mime_type: str | None,
filename: str | None = None,
) -> str:
"""Get file extension based on media type or original filename."""
if mime_type: if mime_type:
ext_map = { ext_map = {
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -372,4 +663,12 @@ class TelegramChannel(BaseChannel):
return ext_map[mime_type] return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""} type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "") if ext := type_map.get(media_type, ""):
return ext
if filename:
from pathlib import Path
return "".join(Path(filename).suffixes)
return ""
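`Path.suffixes` (plural) is used above so compound extensions survive intact:

```python
from pathlib import Path

print("".join(Path("backup.tar.gz").suffixes))  # .tar.gz
print("".join(Path("photo.jpeg").suffixes))     # .jpeg
```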

View File

@@ -2,7 +2,8 @@
import asyncio import asyncio
import json import json
from typing import Any import mimetypes
from collections import OrderedDict
from loguru import logger from loguru import logger
@@ -27,6 +28,7 @@ class WhatsAppChannel(BaseChannel):
self.config: WhatsAppConfig = config self.config: WhatsAppConfig = config
self._ws = None self._ws = None
self._connected = False self._connected = False
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
async def start(self) -> None: async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge.""" """Start the WhatsApp channel by connecting to the bridge."""
@@ -34,7 +36,7 @@ class WhatsAppChannel(BaseChannel):
bridge_url = self.config.bridge_url bridge_url = self.config.bridge_url
logger.info(f"Connecting to WhatsApp bridge at {bridge_url}...") logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
self._running = True self._running = True
@@ -53,14 +55,14 @@ class WhatsAppChannel(BaseChannel):
try: try:
await self._handle_bridge_message(message) await self._handle_bridge_message(message)
except Exception as e: except Exception as e:
logger.error(f"Error handling bridge message: {e}") logger.error("Error handling bridge message: {}", e)
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
self._connected = False self._connected = False
self._ws = None self._ws = None
logger.warning(f"WhatsApp bridge connection error: {e}") logger.warning("WhatsApp bridge connection error: {}", e)
if self._running: if self._running:
logger.info("Reconnecting in 5 seconds...") logger.info("Reconnecting in 5 seconds...")
@@ -87,16 +89,16 @@ class WhatsAppChannel(BaseChannel):
"to": msg.chat_id, "to": msg.chat_id,
"text": msg.content "text": msg.content
} }
await self._ws.send(json.dumps(payload)) await self._ws.send(json.dumps(payload, ensure_ascii=False))
except Exception as e: except Exception as e:
logger.error(f"Error sending WhatsApp message: {e}") logger.error("Error sending WhatsApp message: {}", e)
async def _handle_bridge_message(self, raw: str) -> None: async def _handle_bridge_message(self, raw: str) -> None:
"""Handle a message from the bridge.""" """Handle a message from the bridge."""
try: try:
data = json.loads(raw) data = json.loads(raw)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Invalid JSON from bridge: {raw[:100]}") logger.warning("Invalid JSON from bridge: {}", raw[:100])
return return
msg_type = data.get("type") msg_type = data.get("type")
@@ -108,23 +110,43 @@ class WhatsAppChannel(BaseChannel):
# New LID style typically: # New LID style typically:
sender = data.get("sender", "") sender = data.get("sender", "")
content = data.get("content", "") content = data.get("content", "")
message_id = data.get("id", "")
if message_id:
if message_id in self._processed_message_ids:
return
self._processed_message_ids[message_id] = None
while len(self._processed_message_ids) > 1000:
self._processed_message_ids.popitem(last=False)
# Extract just the phone number or lid as chat_id # Extract just the phone number or lid as chat_id
user_id = pn if pn else sender user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info(f"Sender {sender}") logger.info("Sender {}", sender)
# Handle voice transcription if it's a voice message # Handle voice transcription if it's a voice message
if content == "[Voice Message]": if content == "[Voice Message]":
logger.info(f"Voice message received from {sender_id}, but direct download from bridge is not yet supported.") logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]" content = "[Voice Message: Transcription not available for WhatsApp yet]"
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
mime, _ = mimetypes.guess_type(p)
media_type = "image" if mime and mime.startswith("image/") else "file"
media_tag = f"[{media_type}: {p}]"
content = f"{content}\n{media_tag}" if content else media_tag
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender, # Use full LID for replies chat_id=sender, # Use full LID for replies
content=content, content=content,
media=media_paths,
metadata={ metadata={
"message_id": data.get("id"), "message_id": message_id,
"timestamp": data.get("timestamp"), "timestamp": data.get("timestamp"),
"is_group": data.get("isGroup", False) "is_group": data.get("isGroup", False)
} }
@@ -133,7 +155,7 @@ class WhatsAppChannel(BaseChannel):
elif msg_type == "status": elif msg_type == "status":
# Connection status update # Connection status update
status = data.get("status") status = data.get("status")
logger.info(f"WhatsApp status: {status}") logger.info("WhatsApp status: {}", status)
if status == "connected": if status == "connected":
self._connected = True self._connected = True
@@ -145,4 +167,4 @@ class WhatsAppChannel(BaseChannel):
logger.info("Scan QR code in the bridge terminal to connect WhatsApp") logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
elif msg_type == "error": elif msg_type == "error":
logger.error(f"WhatsApp bridge error: {data.get('error')}") logger.error("WhatsApp bridge error: {}", data.get('error'))
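The dedup added above keeps a bounded, insertion-ordered set of seen ids, so a reconnecting bridge that redelivers messages cannot trigger duplicate agent turns. A standalone sketch of the same idea:

```python
from collections import OrderedDict

seen: OrderedDict[str, None] = OrderedDict()

def is_duplicate(message_id: str, limit: int = 1000) -> bool:
    if message_id in seen:
        return True
    seen[message_id] = None
    while len(seen) > limit:
        seen.popitem(last=False)  # drop the oldest seen id
    return False

print(is_duplicate("3EB0A9C1"))  # False (first sighting)
print(is_duplicate("3EB0A9C1"))  # True  (bridge redelivered it)
```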

View File

@@ -2,23 +2,36 @@
import asyncio import asyncio
import os import os
import signal
from pathlib import Path
import select import select
import signal
import sys import sys
from pathlib import Path
# Force UTF-8 encoding for Windows console
if sys.platform == "win32":
if (sys.stdout.encoding or "").lower() != "utf-8":
os.environ["PYTHONIOENCODING"] = "utf-8"
# Re-open stdout/stderr with UTF-8 encoding
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
pass
import typer import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
from rich.text import Text from rich.text import Text
from prompt_toolkit import PromptSession from nanobot import __logo__, __version__
from prompt_toolkit.formatted_text import HTML from nanobot.config.schema import Config
from prompt_toolkit.history import FileHistory from nanobot.utils.helpers import sync_workspace_templates
from prompt_toolkit.patch_stdout import patch_stdout
from nanobot import __version__, __logo__
app = typer.Typer( app = typer.Typer(
name="nanobot", name="nanobot",
@@ -184,8 +197,7 @@ def onboard():
workspace.mkdir(parents=True, exist_ok=True) workspace.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}") console.print(f"[green]✓[/green] Created workspace at {workspace}")
# Create default bootstrap files sync_workspace_templates(workspace)
_create_workspace_templates(workspace)
console.print(f"\n{__logo__} nanobot is ready!") console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:") console.print("\nNext steps:")
@@ -197,102 +209,57 @@ def onboard():
def _create_workspace_templates(workspace: Path):
"""Create default workspace template files."""
templates = {
"AGENTS.md": """# Agent Instructions
You are a helpful AI assistant. Be concise, accurate, and friendly. def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
## Guidelines
- Always explain what you're doing before taking actions
- Ask for clarification when the request is ambiguous
- Use tools to help accomplish tasks
- Remember important information in memory/MEMORY.md; past events are logged in memory/HISTORY.md
""",
"SOUL.md": """# Soul
I am nanobot, a lightweight AI assistant.
## Personality
- Helpful and friendly
- Concise and to the point
- Curious and eager to learn
## Values
- Accuracy over speed
- User privacy and safety
- Transparency in actions
""",
"USER.md": """# User
Information about the user goes here.
## Preferences
- Communication style: (casual/formal)
- Timezone: (your timezone)
- Language: (your preferred language)
""",
}
for filename, content in templates.items():
file_path = workspace / filename
if not file_path.exists():
file_path.write_text(content)
console.print(f" [dim]Created {filename}[/dim]")
# Create memory directory and MEMORY.md
memory_dir = workspace / "memory"
memory_dir.mkdir(exist_ok=True)
memory_file = memory_dir / "MEMORY.md"
if not memory_file.exists():
memory_file.write_text("""# Long-term Memory
This file stores important information that should persist across sessions.
## User Information
(Important facts about the user)
## Preferences
(User preferences learned over time)
## Important Notes
(Things to remember)
""")
console.print(" [dim]Created memory/MEMORY.md[/dim]")
history_file = memory_dir / "HISTORY.md"
if not history_file.exists():
history_file.write_text("")
console.print(" [dim]Created memory/HISTORY.md[/dim]")
# Create skills directory for custom user skills
skills_dir = workspace / "skills"
skills_dir.mkdir(exist_ok=True)
def _make_provider(config):
"""Create LiteLLMProvider from config. Exits if no API key found."""
from nanobot.providers.litellm_provider import LiteLLMProvider
p = config.get_provider()
model = config.agents.defaults.model model = config.agents.defaults.model
if not (p and p.api_key) and not model.startswith("bedrock/"): provider_name = config.get_provider_name(model)
p = config.get_provider(model)
# OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
return OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
from nanobot.providers.custom_provider import CustomProvider
if provider_name == "custom":
return CustomProvider(
api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model,
)
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
return AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
console.print("[red]Error: No API key configured.[/red]") console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section") console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1) raise typer.Exit(1)
return LiteLLMProvider( return LiteLLMProvider(
api_key=p.api_key if p else None, api_key=p.api_key if p else None,
api_base=config.get_api_base(), api_base=config.get_api_base(model),
default_model=model, default_model=model,
extra_headers=p.extra_headers if p else None, extra_headers=p.extra_headers if p else None,
provider_name=config.get_provider_name(), provider_name=provider_name,
) )
@@ -304,31 +271,38 @@ def _make_provider(config):
@app.command() @app.command()
def gateway( def gateway(
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"), port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
): ):
"""Start the nanobot gateway.""" """Start the nanobot gateway."""
from nanobot.config.loader import load_config, get_data_dir
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.session.manager import SessionManager from nanobot.config.loader import load_config
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.session.manager import SessionManager
if verbose: if verbose:
import logging import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
console.print(f"{__logo__} Starting nanobot gateway on port {port}...") config_path = Path(config) if config else None
config = load_config(config_path)
if workspace:
config.agents.defaults.workspace = workspace
config = load_config() console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path) session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation) # Create cron service first (callback set after agent creation)
cron_store_path = get_data_dir() / "cron" / "jobs.json" # Use workspace path for per-instance cron store
cron_store_path = config.workspace_path / "cron" / "jobs.json"
cron = CronService(cron_store_path) cron = CronService(cron_store_path)
# Create agent with cron service # Create agent with cron service
@@ -341,48 +315,112 @@ def gateway(
max_tokens=config.agents.defaults.max_tokens, max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, memory_window=config.agents.defaults.memory_window,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
cron_service=cron, cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace, restrict_to_workspace=config.tools.restrict_to_workspace,
session_manager=session_manager, session_manager=session_manager,
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
) )
# Set cron callback (needs agent) # Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None: async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent.""" """Execute a cron job through the agent."""
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
reminder_note = (
"[Scheduled Task] Timer finished.\n\n"
f"Task '{job.name}' has been triggered.\n"
f"Scheduled instruction: {job.payload.message}"
)
# Prevent the agent from scheduling new cron jobs during execution
cron_tool = agent.tools.get("cron")
cron_token = None
if isinstance(cron_tool, CronTool):
cron_token = cron_tool.set_cron_context(True)
try:
response = await agent.process_direct( response = await agent.process_direct(
job.payload.message, reminder_note,
session_key=f"cron:{job.id}", session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli", channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct", chat_id=job.payload.to or "direct",
) )
if job.payload.deliver and job.payload.to: finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
cron_tool.reset_cron_context(cron_token)
message_tool = agent.tools.get("message")
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
return response
if job.payload.deliver and job.payload.to and response:
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
await bus.publish_outbound(OutboundMessage( await bus.publish_outbound(OutboundMessage(
channel=job.payload.channel or "cli", channel=job.payload.channel or "cli",
chat_id=job.payload.to, chat_id=job.payload.to,
content=response or "" content=response
)) ))
return response return response
cron.on_job = on_cron_job cron.on_job = on_cron_job
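The token set/reset pair above follows the standard save-and-restore guard shape. A sketch with `contextvars` (an assumption about how `set_cron_context`/`reset_cron_context` might be implemented; the real CronTool is not shown in this hunk):

```python
from contextvars import ContextVar

IN_CRON: ContextVar[bool] = ContextVar("in_cron", default=False)

def run_guarded(fn):
    token = IN_CRON.set(True)   # like cron_tool.set_cron_context(True)
    try:
        return fn()
    finally:
        IN_CRON.reset(token)    # like cron_tool.reset_cron_context(token)

print(run_guarded(lambda: IN_CRON.get()))  # True: flagged inside the job
print(IN_CRON.get())                       # False: restored afterwards
```

The `finally` matters: even if the job raises, the flag is restored, so one failed cron run cannot leave the agent permanently unable to schedule new jobs.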
# Create heartbeat service
async def on_heartbeat(prompt: str) -> str:
"""Execute heartbeat through the agent."""
return await agent.process_direct(prompt, session_key="heartbeat")
heartbeat = HeartbeatService(
workspace=config.workspace_path,
on_heartbeat=on_heartbeat,
interval_s=30 * 60, # 30 minutes
enabled=True
)
# Create channel manager # Create channel manager
channels = ChannelManager(config, bus) channels = ChannelManager(config, bus)
def _pick_heartbeat_target() -> tuple[str, str]:
"""Pick a routable channel/chat target for heartbeat-triggered messages."""
enabled = set(channels.enabled_channels)
# Prefer the most recently updated non-internal session on an enabled channel.
for item in session_manager.list_sessions():
key = item.get("key") or ""
if ":" not in key:
continue
channel, chat_id = key.split(":", 1)
if channel in {"cli", "system"}:
continue
if channel in enabled and chat_id:
return channel, chat_id
# Fallback keeps prior behavior but remains explicit.
return "cli", "direct"
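A runnable sketch of that selection rule, with assumed session shapes (`list_sessions` items carrying a `key` like `"channel:chat_id"`, most recent first):

```python
def pick_target(sessions: list[dict], enabled: set[str]) -> tuple[str, str]:
    for item in sessions:  # assumed most-recently-updated-first ordering
        key = item.get("key") or ""
        if ":" not in key:
            continue  # internal keys like "heartbeat" have no channel part
        channel, chat_id = key.split(":", 1)
        if channel in {"cli", "system"}:
            continue
        if channel in enabled and chat_id:
            return channel, chat_id
    return "cli", "direct"  # explicit fallback, same as the prior behavior

print(pick_target([{"key": "heartbeat"}, {"key": "telegram:12345"}], {"telegram"}))
# ('telegram', '12345')
```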
# Create heartbeat service
async def on_heartbeat_execute(tasks: str) -> str:
"""Phase 2: execute heartbeat tasks through the full agent loop."""
channel, chat_id = _pick_heartbeat_target()
async def _silent(*_args, **_kwargs):
pass
return await agent.process_direct(
tasks,
session_key="heartbeat",
channel=channel,
chat_id=chat_id,
on_progress=_silent,
)
async def on_heartbeat_notify(response: str) -> None:
"""Deliver a heartbeat response to the user's channel."""
from nanobot.bus.events import OutboundMessage
channel, chat_id = _pick_heartbeat_target()
if channel == "cli":
return # No external channel available to deliver to
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
hb_cfg = config.gateway.heartbeat
heartbeat = HeartbeatService(
workspace=config.workspace_path,
provider=provider,
model=agent.model,
on_execute=on_heartbeat_execute,
on_notify=on_heartbeat_notify,
interval_s=hb_cfg.interval_s,
enabled=hb_cfg.enabled,
)
if channels.enabled_channels: if channels.enabled_channels:
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}") console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
else: else:
@@ -392,7 +430,7 @@ def gateway(
if cron_status["jobs"] > 0: if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs") console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
console.print(f"[green]✓[/green] Heartbeat: every 30m") console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
async def run(): async def run():
try: try:
@@ -429,16 +467,23 @@ def agent(
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
): ):
"""Interact with the agent directly.""" """Interact with the agent directly."""
from nanobot.config.loader import load_config
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
from loguru import logger from loguru import logger
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService
config = load_config() config = load_config()
sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
# Create cron service for tool usage (no callback needed for CLI unless running)
cron_store_path = get_data_dir() / "cron" / "jobs.json"
cron = CronService(cron_store_path)
if logs: if logs:
logger.enable("nanobot") logger.enable("nanobot")
else: else:
@@ -453,10 +498,14 @@ def agent(
max_tokens=config.agents.defaults.max_tokens, max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, memory_window=config.agents.defaults.memory_window,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace, restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
) )
# Show spinner when logs are off (no output to miss); skip when logs are on # Show spinner when logs are off (no output to miss); skip when logs are on
@@ -467,28 +516,83 @@ def agent(
# Animated spinner is safe to use with prompt_toolkit input handling # Animated spinner is safe to use with prompt_toolkit input handling
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots") return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
ch = agent_loop.channels_config
if ch and tool_hint and not ch.send_tool_hints:
return
if ch and not tool_hint and not ch.send_progress:
return
console.print(f" [dim]↳ {content}[/dim]")
if message: if message:
# Single message mode # Single message mode — direct call, no bus needed
async def run_once(): async def run_once():
with _thinking_ctx(): with _thinking_ctx():
response = await agent_loop.process_direct(message, session_id) response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
_print_agent_response(response, render_markdown=markdown) _print_agent_response(response, render_markdown=markdown)
await agent_loop.close_mcp() await agent_loop.close_mcp()
asyncio.run(run_once()) asyncio.run(run_once())
else: else:
# Interactive mode # Interactive mode — route through bus like other channels
from nanobot.bus.events import InboundMessage
_init_prompt_session() _init_prompt_session()
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n") console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
def _exit_on_sigint(signum, frame): if ":" in session_id:
_restore_terminal() cli_channel, cli_chat_id = session_id.split(":", 1)
console.print("\nGoodbye!") else:
os._exit(0) cli_channel, cli_chat_id = "cli", session_id
signal.signal(signal.SIGINT, _exit_on_sigint) def _handle_signal(signum, frame):
sig_name = signal.Signals(signum).name
_restore_terminal()
console.print(f"\nReceived {sig_name}, goodbye!")
sys.exit(0)
signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)
# SIGHUP is not available on Windows
if hasattr(signal, 'SIGHUP'):
signal.signal(signal.SIGHUP, _handle_signal)
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
# SIGPIPE is not available on Windows
if hasattr(signal, 'SIGPIPE'):
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive(): async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run())
turn_done = asyncio.Event()
turn_done.set()
turn_response: list[str] = []
async def _consume_outbound():
while True:
try:
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
if msg.metadata.get("_progress"):
is_tool_hint = msg.metadata.get("_tool_hint", False)
ch = agent_loop.channels_config
if ch and is_tool_hint and not ch.send_tool_hints:
pass
elif ch and not is_tool_hint and not ch.send_progress:
pass
else:
console.print(f" [dim]↳ {msg.content}[/dim]")
elif not turn_done.is_set():
if msg.content:
turn_response.append(msg.content)
turn_done.set()
elif msg.content:
console.print()
_print_agent_response(msg.content, render_markdown=markdown)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
outbound_task = asyncio.create_task(_consume_outbound())
try: try:
while True: while True:
try: try:
@@ -503,9 +607,21 @@ def agent(
console.print("\nGoodbye!") console.print("\nGoodbye!")
break break
turn_done.clear()
turn_response.clear()
await bus.publish_inbound(InboundMessage(
channel=cli_channel,
sender_id="user",
chat_id=cli_chat_id,
content=user_input,
))
with _thinking_ctx(): with _thinking_ctx():
response = await agent_loop.process_direct(user_input, session_id) await turn_done.wait()
_print_agent_response(response, render_markdown=markdown)
if turn_response:
_print_agent_response(turn_response[0], render_markdown=markdown)
except KeyboardInterrupt: except KeyboardInterrupt:
_restore_terminal() _restore_terminal()
console.print("\nGoodbye!") console.print("\nGoodbye!")
@@ -515,6 +631,9 @@ def agent(
console.print("\nGoodbye!") console.print("\nGoodbye!")
break break
finally: finally:
agent_loop.stop()
outbound_task.cancel()
await asyncio.gather(bus_task, outbound_task, return_exceptions=True)
await agent_loop.close_mcp() await agent_loop.close_mcp()
asyncio.run(run_interactive()) asyncio.run(run_interactive())
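The interactive loop above now round-trips through the bus instead of calling the agent directly. The key mechanism is the turn gate: publish the user input, then block on an `asyncio.Event` that the outbound consumer sets when the final (non-progress) reply lands. A minimal sketch with a plain queue standing in for the bus:

```python
import asyncio

async def main() -> None:
    queue: asyncio.Queue[dict] = asyncio.Queue()
    turn_done = asyncio.Event()
    reply: list[str] = []

    async def consumer() -> None:
        while True:
            msg = await queue.get()
            if msg.get("_progress"):
                print("↳", msg["content"])  # streamed progress, keep waiting
            else:
                reply.append(msg["content"])
                turn_done.set()             # release the prompt loop
                return

    task = asyncio.create_task(consumer())
    await queue.put({"_progress": True, "content": "running web_search"})
    await queue.put({"content": "Done!"})
    await turn_done.wait()
    print(reply[0])  # Done!
    await task

asyncio.run(main())
```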
@@ -592,6 +711,33 @@ def channels_status():
slack_config slack_config
) )
# DingTalk
dt = config.channels.dingtalk
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
table.add_row(
"DingTalk",
"✓" if dt.enabled else "",
dt_config
)
# QQ
qq = config.channels.qq
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
table.add_row(
"QQ",
"✓" if qq.enabled else "",
qq_config
)
# Email
em = config.channels.email
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
table.add_row(
"Email",
"✓" if em.enabled else "",
em_config
)
console.print(table) console.print(table)
@@ -657,6 +803,7 @@ def _get_bridge_dir() -> Path:
def channels_login(): def channels_login():
"""Link device via QR code.""" """Link device via QR code."""
import subprocess import subprocess
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
config = load_config() config = load_config()
@@ -677,163 +824,6 @@ def channels_login():
console.print("[red]npm not found. Please install Node.js.[/red]") console.print("[red]npm not found. Please install Node.js.[/red]")
# ============================================================================
# Cron Commands
# ============================================================================
cron_app = typer.Typer(help="Manage scheduled tasks")
app.add_typer(cron_app, name="cron")
@cron_app.command("list")
def cron_list(
all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
):
"""List scheduled jobs."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
jobs = service.list_jobs(include_disabled=all)
if not jobs:
console.print("No scheduled jobs.")
return
table = Table(title="Scheduled Jobs")
table.add_column("ID", style="cyan")
table.add_column("Name")
table.add_column("Schedule")
table.add_column("Status")
table.add_column("Next Run")
import time
for job in jobs:
# Format schedule
if job.schedule.kind == "every":
sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
elif job.schedule.kind == "cron":
sched = job.schedule.expr or ""
else:
sched = "one-time"
# Format next run
next_run = ""
if job.state.next_run_at_ms:
next_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(job.state.next_run_at_ms / 1000))
next_run = next_time
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
table.add_row(job.id, job.name, sched, status, next_run)
console.print(table)
@cron_app.command("add")
def cron_add(
name: str = typer.Option(..., "--name", "-n", help="Job name"),
message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
to: str = typer.Option(None, "--to", help="Recipient for delivery"),
channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
):
"""Add a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule
# Determine schedule type
if every:
schedule = CronSchedule(kind="every", every_ms=every * 1000)
elif cron_expr:
schedule = CronSchedule(kind="cron", expr=cron_expr)
elif at:
import datetime
dt = datetime.datetime.fromisoformat(at)
schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
else:
console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
raise typer.Exit(1)
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
job = service.add_job(
name=name,
schedule=schedule,
message=message,
deliver=deliver,
to=to,
channel=channel,
)
console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
@cron_app.command("remove")
def cron_remove(
job_id: str = typer.Argument(..., help="Job ID to remove"),
):
"""Remove a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
if service.remove_job(job_id):
console.print(f"[green]✓[/green] Removed job {job_id}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("enable")
def cron_enable(
job_id: str = typer.Argument(..., help="Job ID"),
disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
):
"""Enable or disable a job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
job = service.enable_job(job_id, enabled=not disable)
if job:
status = "disabled" if disable else "enabled"
console.print(f"[green]✓[/green] Job '{job.name}' {status}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("run")
def cron_run(
job_id: str = typer.Argument(..., help="Job ID to run"),
force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
):
"""Manually run a job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
async def run():
return await service.run_job(job_id, force=force)
if asyncio.run(run()):
console.print(f"[green]✓[/green] Job executed")
else:
console.print(f"[red]Failed to run job {job_id}[/red]")
# ============================================================================ # ============================================================================
# Status Commands # Status Commands
# ============================================================================ # ============================================================================
@@ -842,7 +832,7 @@ def cron_run(
@app.command() @app.command()
def status(): def status():
"""Show nanobot status.""" """Show nanobot status."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
config_path = get_config_path() config_path = get_config_path()
config = load_config() config = load_config()
@@ -863,7 +853,9 @@ def status():
p = getattr(config.providers, spec.name, None) p = getattr(config.providers, spec.name, None)
if p is None: if p is None:
continue continue
if spec.is_local: if spec.is_oauth:
console.print(f"{spec.label}: [green]✓ (OAuth)[/green]")
elif spec.is_local:
# Local deployments show api_base instead of api_key # Local deployments show api_base instead of api_key
if p.api_base: if p.api_base:
console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]") console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]")
@@ -874,5 +866,88 @@ def status():
console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}") console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}")
# ============================================================================
# OAuth Login
# ============================================================================
provider_app = typer.Typer(help="Manage providers")
app.add_typer(provider_app, name="provider")
_LOGIN_HANDLERS: dict[str, callable] = {}
def _register_login(name: str):
def decorator(fn):
_LOGIN_HANDLERS[name] = fn
return fn
return decorator
@provider_app.command("login")
def provider_login(
provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"),
):
"""Authenticate with an OAuth provider."""
from nanobot.providers.registry import PROVIDERS
key = provider.replace("-", "_")
spec = next((s for s in PROVIDERS if s.name == key and s.is_oauth), None)
if not spec:
names = ", ".join(s.name.replace("_", "-") for s in PROVIDERS if s.is_oauth)
console.print(f"[red]Unknown OAuth provider: {provider}[/red] Supported: {names}")
raise typer.Exit(1)
handler = _LOGIN_HANDLERS.get(spec.name)
if not handler:
console.print(f"[red]Login not implemented for {spec.label}[/red]")
raise typer.Exit(1)
console.print(f"{__logo__} OAuth Login - {spec.label}\n")
handler()
@_register_login("openai_codex")
def _login_openai_codex() -> None:
try:
from oauth_cli_kit import get_token, login_oauth_interactive
token = None
try:
token = get_token()
except Exception:
pass
if not (token and token.access):
console.print("[cyan]Starting interactive OAuth login...[/cyan]\n")
token = login_oauth_interactive(
print_fn=lambda s: console.print(s),
prompt_fn=lambda s: typer.prompt(s),
)
if not (token and token.access):
console.print("[red]✗ Authentication failed[/red]")
raise typer.Exit(1)
console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]")
except ImportError:
console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]")
raise typer.Exit(1)
@_register_login("github_copilot")
def _login_github_copilot() -> None:
import asyncio
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
async def _trigger():
from litellm import acompletion
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
try:
asyncio.run(_trigger())
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
except Exception as e:
console.print(f"[red]Authentication error: {e}[/red]")
raise typer.Exit(1)
if __name__ == "__main__":
app()
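The `_register_login` decorator above keeps OAuth flows pluggable: each handler registers under a provider's registry name and `provider_login` dispatches by that key. A hedged sketch of adding one (the provider name and flow are hypothetical, and a matching `is_oauth` spec must exist in `PROVIDERS` for dispatch to reach it):

@_register_login("example_oauth")  # hypothetical registry name
def _login_example_oauth() -> None:
    console.print("[cyan]Starting Example OAuth device flow...[/cyan]")
    # ...run the provider-specific flow, then report the outcome:
    console.print("[green]✓ Authenticated with Example OAuth[/green]")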

View File

@@ -1,6 +1,6 @@
"""Configuration module for nanobot.""" """Configuration module for nanobot."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
__all__ = ["Config", "load_config", "get_config_path"] __all__ = ["Config", "load_config", "get_config_path"]

View File

@@ -2,7 +2,6 @@
import json
from pathlib import Path
- from typing import Any
from nanobot.config.schema import Config
@@ -32,10 +31,10 @@ def load_config(config_path: Path | None = None) -> Config:
if path.exists():
try:
- with open(path) as f:
+ with open(path, encoding="utf-8") as f:
data = json.load(f)
data = _migrate_config(data)
- return Config.model_validate(convert_keys(data))
+ return Config.model_validate(data)
except (json.JSONDecodeError, ValueError) as e:
print(f"Warning: Failed to load config from {path}: {e}")
print("Using default configuration.")
@@ -54,12 +53,10 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True)
- # Convert to camelCase format
- data = config.model_dump()
- data = convert_to_camel(data)
+ data = config.model_dump(by_alias=True)
- with open(path, "w") as f:
+ with open(path, "w", encoding="utf-8") as f:
- json.dump(data, f, indent=2)
+ json.dump(data, f, indent=2, ensure_ascii=False)
def _migrate_config(data: dict) -> dict:
@@ -70,37 +67,3 @@ def _migrate_config(data: dict) -> dict:
if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
return data
- def convert_keys(data: Any) -> Any:
- """Convert camelCase keys to snake_case for Pydantic."""
- if isinstance(data, dict):
- return {camel_to_snake(k): convert_keys(v) for k, v in data.items()}
- if isinstance(data, list):
- return [convert_keys(item) for item in data]
- return data
- def convert_to_camel(data: Any) -> Any:
- """Convert snake_case keys to camelCase."""
- if isinstance(data, dict):
- return {snake_to_camel(k): convert_to_camel(v) for k, v in data.items()}
- if isinstance(data, list):
- return [convert_to_camel(item) for item in data]
- return data
- def camel_to_snake(name: str) -> str:
- """Convert camelCase to snake_case."""
- result = []
- for i, char in enumerate(name):
- if char.isupper() and i > 0:
- result.append("_")
- result.append(char.lower())
- return "".join(result)
- def snake_to_camel(name: str) -> str:
- """Convert snake_case to camelCase."""
- components = name.split("_")
- return components[0] + "".join(x.title() for x in components[1:])
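With the shared `Base` model in `schema.py` (shown below) setting `alias_generator=to_camel` and `populate_by_name=True`, Pydantic now does what the deleted converters did by hand: accept camelCase keys on load and emit them via `model_dump(by_alias=True)` on save. A minimal round-trip sketch (the `DemoConfig` model is illustrative, not part of nanobot):

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

class Base(BaseModel):
    # Accept both camelCase and snake_case keys on input.
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

class DemoConfig(Base):  # hypothetical model, for illustration only
    bridge_url: str = "ws://localhost:3001"

# camelCase keys (as stored in config.json) validate directly...
cfg = DemoConfig.model_validate({"bridgeUrl": "ws://host:9000"})
# ...and snake_case still works because populate_by_name=True.
assert DemoConfig.model_validate({"bridge_url": "ws://host:9000"}) == cfg
# by_alias=True writes camelCase back out, replacing convert_to_camel().
assert cfg.model_dump(by_alias=True) == {"bridgeUrl": "ws://host:9000"}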

View File

@@ -1,54 +1,98 @@
"""Configuration schema using Pydantic.""" """Configuration schema using Pydantic."""
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, Field, ConfigDict from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
class WhatsAppConfig(BaseModel): class Base(BaseModel):
"""Base model that accepts both camelCase and snake_case keys."""
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
class WhatsAppConfig(Base):
"""WhatsApp channel configuration.""" """WhatsApp channel configuration."""
enabled: bool = False enabled: bool = False
bridge_url: str = "ws://localhost:3001" bridge_url: str = "ws://localhost:3001"
bridge_token: str = "" # Shared token for bridge auth (optional, recommended) bridge_token: str = "" # Shared token for bridge auth (optional, recommended)
allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
class TelegramConfig(BaseModel): class TelegramConfig(Base):
"""Telegram channel configuration.""" """Telegram channel configuration."""
enabled: bool = False enabled: bool = False
token: str = "" # Bot token from @BotFather token: str = "" # Bot token from @BotFather
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
reply_to_message: bool = False # If true, bot replies quote the original message
class FeishuConfig(BaseModel): class FeishuConfig(Base):
"""Feishu/Lark channel configuration using WebSocket long connection.""" """Feishu/Lark channel configuration using WebSocket long connection."""
enabled: bool = False enabled: bool = False
app_id: str = "" # App ID from Feishu Open Platform app_id: str = "" # App ID from Feishu Open Platform
app_secret: str = "" # App Secret from Feishu Open Platform app_secret: str = "" # App Secret from Feishu Open Platform
encrypt_key: str = "" # Encrypt Key for event subscription (optional) encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional) verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
react_emoji: str = (
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
)
class DingTalkConfig(BaseModel): class DingTalkConfig(Base):
"""DingTalk channel configuration using Stream mode.""" """DingTalk channel configuration using Stream mode."""
enabled: bool = False enabled: bool = False
client_id: str = "" # AppKey client_id: str = "" # AppKey
client_secret: str = "" # AppSecret client_secret: str = "" # AppSecret
allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids
class DiscordConfig(BaseModel): class DiscordConfig(Base):
"""Discord channel configuration.""" """Discord channel configuration."""
enabled: bool = False enabled: bool = False
token: str = "" # Bot token from Discord Developer Portal token: str = "" # Bot token from Discord Developer Portal
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
group_policy: Literal["mention", "open"] = "mention"
class EmailConfig(BaseModel):
class MatrixConfig(Base):
"""Matrix (Element) channel configuration."""
enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = "" # @bot:matrix.org
device_id: str = ""
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
sync_stop_grace_seconds: int = (
2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
)
max_media_bytes: int = (
20 * 1024 * 1024
) # Max attachment size accepted for Matrix media handling (inbound + outbound).
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
class EmailConfig(Base):
"""Email channel configuration (IMAP inbound + SMTP outbound).""" """Email channel configuration (IMAP inbound + SMTP outbound)."""
enabled: bool = False enabled: bool = False
consent_granted: bool = False # Explicit owner permission to access mailbox data consent_granted: bool = False # Explicit owner permission to access mailbox data
@@ -70,7 +114,9 @@ class EmailConfig(BaseModel):
from_address: str = "" from_address: str = ""
# Behavior # Behavior
auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent auto_reply_enabled: bool = (
True # If false, inbound email is read but no automatic reply is sent
)
poll_interval_seconds: int = 30 poll_interval_seconds: int = 30
mark_seen: bool = True mark_seen: bool = True
max_body_chars: int = 12000 max_body_chars: int = 12000
@@ -78,18 +124,21 @@ class EmailConfig(BaseModel):
allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses
class MochatMentionConfig(BaseModel): class MochatMentionConfig(Base):
"""Mochat mention behavior configuration.""" """Mochat mention behavior configuration."""
require_in_groups: bool = False require_in_groups: bool = False
class MochatGroupRule(BaseModel): class MochatGroupRule(Base):
"""Mochat per-group mention requirement.""" """Mochat per-group mention requirement."""
require_mention: bool = False require_mention: bool = False
class MochatConfig(BaseModel): class MochatConfig(Base):
"""Mochat channel configuration.""" """Mochat channel configuration."""
enabled: bool = False enabled: bool = False
base_url: str = "https://mochat.io" base_url: str = "https://mochat.io"
socket_url: str = "" socket_url: str = ""
@@ -114,36 +163,49 @@ class MochatConfig(BaseModel):
reply_delay_ms: int = 120000 reply_delay_ms: int = 120000
class SlackDMConfig(BaseModel): class SlackDMConfig(Base):
"""Slack DM policy configuration.""" """Slack DM policy configuration."""
enabled: bool = True enabled: bool = True
policy: str = "open" # "open" or "allowlist" policy: str = "open" # "open" or "allowlist"
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs
class SlackConfig(BaseModel): class SlackConfig(Base):
"""Slack channel configuration.""" """Slack channel configuration."""
enabled: bool = False enabled: bool = False
mode: str = "socket" # "socket" supported mode: str = "socket" # "socket" supported
webhook_path: str = "/slack/events" webhook_path: str = "/slack/events"
bot_token: str = "" # xoxb-... bot_token: str = "" # xoxb-...
app_token: str = "" # xapp-... app_token: str = "" # xapp-...
user_token_read_only: bool = True user_token_read_only: bool = True
reply_in_thread: bool = True
react_emoji: str = "eyes"
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs (sender-level)
group_policy: str = "mention" # "mention", "open", "allowlist" group_policy: str = "mention" # "mention", "open", "allowlist"
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
dm: SlackDMConfig = Field(default_factory=SlackDMConfig) dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
class QQConfig(BaseModel): class QQConfig(Base):
"""QQ channel configuration using botpy SDK.""" """QQ channel configuration using botpy SDK."""
enabled: bool = False enabled: bool = False
app_id: str = "" # 机器人 ID (AppID) from q.qq.com app_id: str = "" # 机器人 ID (AppID) from q.qq.com
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access) allow_from: list[str] = Field(
default_factory=list
) # Allowed user openids (empty = public access)
class ChannelsConfig(BaseModel):
class ChannelsConfig(Base):
"""Configuration for chat channels.""" """Configuration for chat channels."""
send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig) whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
telegram: TelegramConfig = Field(default_factory=TelegramConfig) telegram: TelegramConfig = Field(default_factory=TelegramConfig)
discord: DiscordConfig = Field(default_factory=DiscordConfig) discord: DiscordConfig = Field(default_factory=DiscordConfig)
@@ -153,33 +215,43 @@ class ChannelsConfig(BaseModel):
email: EmailConfig = Field(default_factory=EmailConfig) email: EmailConfig = Field(default_factory=EmailConfig)
slack: SlackConfig = Field(default_factory=SlackConfig) slack: SlackConfig = Field(default_factory=SlackConfig)
qq: QQConfig = Field(default_factory=QQConfig) qq: QQConfig = Field(default_factory=QQConfig)
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
class AgentDefaults(BaseModel): class AgentDefaults(Base):
"""Default agent configuration.""" """Default agent configuration."""
workspace: str = "~/.nanobot/workspace" workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5" model: str = "anthropic/claude-opus-4-5"
provider: str = (
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
)
max_tokens: int = 8192 max_tokens: int = 8192
temperature: float = 0.7 temperature: float = 0.1
max_tool_iterations: int = 20 max_tool_iterations: int = 40
memory_window: int = 50 memory_window: int = 100
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
class AgentsConfig(BaseModel): class AgentsConfig(Base):
"""Agent configuration.""" """Agent configuration."""
defaults: AgentDefaults = Field(default_factory=AgentDefaults) defaults: AgentDefaults = Field(default_factory=AgentDefaults)
class ProviderConfig(BaseModel): class ProviderConfig(Base):
"""LLM provider configuration.""" """LLM provider configuration."""
api_key: str = "" api_key: str = ""
api_base: str | None = None api_base: str | None = None
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix) extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
class ProvidersConfig(BaseModel): class ProvidersConfig(Base):
"""Configuration for LLM providers.""" """Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig) anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig) openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig) openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -192,40 +264,65 @@ class ProvidersConfig(BaseModel):
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
class GatewayConfig(BaseModel): class HeartbeatConfig(Base):
"""Heartbeat service configuration."""
enabled: bool = True
interval_s: int = 30 * 60 # 30 minutes
class GatewayConfig(Base):
"""Gateway/server configuration.""" """Gateway/server configuration."""
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 18790 port: int = 18790
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
class WebSearchConfig(BaseModel): class WebSearchConfig(Base):
"""Web search tool configuration.""" """Web search tool configuration."""
api_key: str = "" # Brave Search API key api_key: str = "" # Brave Search API key
max_results: int = 5 max_results: int = 5
class WebToolsConfig(BaseModel): class WebToolsConfig(Base):
"""Web tools configuration.""" """Web tools configuration."""
proxy: str | None = (
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
)
search: WebSearchConfig = Field(default_factory=WebSearchConfig) search: WebSearchConfig = Field(default_factory=WebSearchConfig)
class ExecToolConfig(BaseModel): class ExecToolConfig(Base):
"""Shell exec tool configuration.""" """Shell exec tool configuration."""
timeout: int = 60 timeout: int = 60
path_append: str = ""
class MCPServerConfig(BaseModel): class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP).""" """MCP server connection configuration (stdio or HTTP)."""
type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
command: str = "" # Stdio: command to run (e.g. "npx") command: str = "" # Stdio: command to run (e.g. "npx")
args: list[str] = Field(default_factory=list) # Stdio: command arguments args: list[str] = Field(default_factory=list) # Stdio: command arguments
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
url: str = "" # HTTP: streamable HTTP endpoint URL url: str = "" # HTTP/SSE: endpoint URL
headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
tool_timeout: int = 30 # seconds before a tool call is cancelled
class ToolsConfig(BaseModel): class ToolsConfig(Base):
"""Tools configuration.""" """Tools configuration."""
web: WebToolsConfig = Field(default_factory=WebToolsConfig) web: WebToolsConfig = Field(default_factory=WebToolsConfig)
exec: ExecToolConfig = Field(default_factory=ExecToolConfig) exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
@@ -234,6 +331,7 @@ class ToolsConfig(BaseModel):
class Config(BaseSettings): class Config(BaseSettings):
"""Root configuration for nanobot.""" """Root configuration for nanobot."""
agents: AgentsConfig = Field(default_factory=AgentsConfig) agents: AgentsConfig = Field(default_factory=AgentsConfig)
channels: ChannelsConfig = Field(default_factory=ChannelsConfig) channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
providers: ProvidersConfig = Field(default_factory=ProvidersConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
@@ -245,19 +343,45 @@ class Config(BaseSettings):
"""Get expanded workspace path.""" """Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser() return Path(self.agents.defaults.workspace).expanduser()
def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]: def _match_provider(
self, model: str | None = None
) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name).""" """Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS
forced = self.agents.defaults.provider
if forced != "auto":
p = getattr(self.providers, forced, None)
return (p, forced) if p else (None, None)
model_lower = (model or self.agents.defaults.model).lower() model_lower = (model or self.agents.defaults.model).lower()
model_normalized = model_lower.replace("-", "_")
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
normalized_prefix = model_prefix.replace("-", "_")
def _kw_matches(kw: str) -> bool:
kw = kw.lower()
return kw in model_lower or kw.replace("-", "_") in model_normalized
# Explicit provider prefix wins — prevents `github-copilot/...codex` matching openai_codex.
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and model_prefix and normalized_prefix == spec.name:
if spec.is_oauth or p.api_key:
return p, spec.name
# Match by keyword (order follows PROVIDERS registry) # Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key: if p and any(_kw_matches(kw) for kw in spec.keywords):
if spec.is_oauth or p.api_key:
return p, spec.name return p, spec.name
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection
for spec in PROVIDERS: for spec in PROVIDERS:
if spec.is_oauth:
continue
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and p.api_key: if p and p.api_key:
return p, spec.name return p, spec.name
@@ -281,6 +405,7 @@ class Config(BaseSettings):
def get_api_base(self, model: str | None = None) -> str | None: def get_api_base(self, model: str | None = None) -> str | None:
"""Get API base URL for the given model. Applies default URLs for known gateways.""" """Get API base URL for the given model. Applies default URLs for known gateways."""
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model) p, name = self._match_provider(model)
if p and p.api_base: if p and p.api_base:
return p.api_base return p.api_base
@@ -293,7 +418,4 @@ class Config(BaseSettings):
return spec.default_api_base return spec.default_api_base
return None return None
model_config = ConfigDict( model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")
env_prefix="NANOBOT_",
env_nested_delimiter="__"
)
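Because `Config` extends `BaseSettings` with `env_prefix="NANOBOT_"` and `env_nested_delimiter="__"`, any nested field can be overridden from the environment without touching config.json. A hedged sketch (the field paths follow the schema above; the values are illustrative):

import os
from nanobot.config.schema import Config

# NANOBOT_<SECTION>__<SUBSECTION>__<FIELD> maps onto the nested models.
os.environ["NANOBOT_AGENTS__DEFAULTS__MODEL"] = "anthropic/claude-opus-4-5"
os.environ["NANOBOT_CHANNELS__SLACK__ENABLED"] = "true"

cfg = Config()
assert cfg.agents.defaults.model == "anthropic/claude-opus-4-5"
assert cfg.channels.slack.enabled is True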

View File

@@ -30,9 +30,11 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
if schedule.kind == "cron" and schedule.expr:
try:
- from croniter import croniter
from zoneinfo import ZoneInfo
- base_time = time.time()
+ from croniter import croniter
+ # Use caller-provided reference time for deterministic scheduling
+ base_time = now_ms / 1000
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
base_dt = datetime.fromtimestamp(base_time, tz=tz)
cron = croniter(schedule.expr, base_dt)
@@ -44,6 +46,20 @@
return None
+ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
+ """Validate schedule fields that would otherwise create non-runnable jobs."""
+ if schedule.tz and schedule.kind != "cron":
+ raise ValueError("tz can only be used with cron schedules")
+ if schedule.kind == "cron" and schedule.tz:
+ try:
+ from zoneinfo import ZoneInfo
+ ZoneInfo(schedule.tz)
+ except Exception:
+ raise ValueError(f"unknown timezone '{schedule.tz}'") from None
class CronService:
"""Service for managing and executing scheduled jobs."""
@@ -53,19 +69,25 @@ class CronService:
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
):
self.store_path = store_path
- self.on_job = on_job  # Callback to execute job, returns response text
+ self.on_job = on_job
self._store: CronStore | None = None
+ self._last_mtime: float = 0.0
self._timer_task: asyncio.Task | None = None
self._running = False
def _load_store(self) -> CronStore:
- """Load jobs from disk."""
+ """Load jobs from disk. Reloads automatically if file was modified externally."""
+ if self._store and self.store_path.exists():
+ mtime = self.store_path.stat().st_mtime
+ if mtime != self._last_mtime:
+ logger.info("Cron: jobs.json modified externally, reloading")
+ self._store = None
if self._store:
return self._store
if self.store_path.exists():
try:
- data = json.loads(self.store_path.read_text())
+ data = json.loads(self.store_path.read_text(encoding="utf-8"))
jobs = []
for j in data.get("jobs", []):
jobs.append(CronJob(
@@ -98,7 +120,7 @@ class CronService:
))
self._store = CronStore(jobs=jobs)
except Exception as e:
- logger.warning(f"Failed to load cron store: {e}")
+ logger.warning("Failed to load cron store: {}", e)
self._store = CronStore()
else:
self._store = CronStore()
@@ -147,7 +169,8 @@ class CronService:
]
}
- self.store_path.write_text(json.dumps(data, indent=2))
+ self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
+ self._last_mtime = self.store_path.stat().st_mtime
async def start(self) -> None:
"""Start the cron service."""
@@ -156,7 +179,7 @@ class CronService:
self._recompute_next_runs()
self._save_store()
self._arm_timer()
- logger.info(f"Cron service started with {len(self._store.jobs if self._store else [])} jobs")
+ logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
def stop(self) -> None:
"""Stop the cron service."""
@@ -203,6 +226,7 @@ class CronService:
async def _on_timer(self) -> None:
"""Handle timer tick - run due jobs."""
+ self._load_store()
if not self._store:
return
@@ -221,7 +245,7 @@ class CronService:
async def _execute_job(self, job: CronJob) -> None:
"""Execute a single job."""
start_ms = _now_ms()
- logger.info(f"Cron: executing job '{job.name}' ({job.id})")
+ logger.info("Cron: executing job '{}' ({})", job.name, job.id)
try:
response = None
@@ -230,12 +254,12 @@
job.state.last_status = "ok"
job.state.last_error = None
- logger.info(f"Cron: job '{job.name}' completed")
+ logger.info("Cron: job '{}' completed", job.name)
except Exception as e:
job.state.last_status = "error"
job.state.last_error = str(e)
- logger.error(f"Cron: job '{job.name}' failed: {e}")
+ logger.error("Cron: job '{}' failed: {}", job.name, e)
job.state.last_run_at_ms = start_ms
job.updated_at_ms = _now_ms()
@@ -271,6 +295,7 @@ class CronService:
) -> CronJob:
"""Add a new job."""
store = self._load_store()
+ _validate_schedule_for_add(schedule)
now = _now_ms()
job = CronJob(
@@ -295,7 +320,7 @@ class CronService:
self._save_store()
self._arm_timer()
- logger.info(f"Cron: added job '{name}' ({job.id})")
+ logger.info("Cron: added job '{}' ({})", name, job.id)
return job
def remove_job(self, job_id: str) -> bool:
@@ -308,7 +333,7 @@ class CronService:
if removed:
self._save_store()
self._arm_timer()
- logger.info(f"Cron: removed job {job_id}")
+ logger.info("Cron: removed job {}", job_id)
return removed
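The `_compute_next_run` change above pins croniter to the caller's `now_ms` instead of wall-clock time, so recomputation is deterministic. The core calculation, as a standalone sketch with an illustrative cron expression:

from datetime import datetime
from zoneinfo import ZoneInfo

from croniter import croniter

now_ms = 1_700_000_000_000  # fixed reference timestamp
tz = ZoneInfo("Europe/Berlin")
base_dt = datetime.fromtimestamp(now_ms / 1000, tz=tz)

# Next weekday 09:00 in the job's timezone, as epoch milliseconds.
cron = croniter("0 9 * * 1-5", base_dt)
next_run_ms = int(cron.get_next(float) * 1000)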

View File

@@ -1,57 +1,70 @@
"""Heartbeat service - periodic agent wake-up to check for tasks.""" """Heartbeat service - periodic agent wake-up to check for tasks."""
from __future__ import annotations
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Coroutine from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger from loguru import logger
# Default interval: 30 minutes if TYPE_CHECKING:
DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60 from nanobot.providers.base import LLMProvider
# The prompt sent to agent during heartbeat _HEARTBEAT_TOOL = [
HEARTBEAT_PROMPT = """Read HEARTBEAT.md in your workspace (if it exists). {
Follow any instructions or tasks listed there. "type": "function",
If nothing needs attention, reply with just: HEARTBEAT_OK""" "function": {
"name": "heartbeat",
# Token that indicates "nothing to do" "description": "Report heartbeat decision after reviewing tasks.",
HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK" "parameters": {
"type": "object",
"properties": {
def _is_heartbeat_empty(content: str | None) -> bool: "action": {
"""Check if HEARTBEAT.md has no actionable content.""" "type": "string",
if not content: "enum": ["skip", "run"],
return True "description": "skip = nothing to do, run = has active tasks",
},
# Lines to skip: empty, headers, HTML comments, empty checkboxes "tasks": {
skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"} "type": "string",
"description": "Natural-language summary of active tasks (required for run)",
for line in content.split("\n"): },
line = line.strip() },
if not line or line.startswith("#") or line.startswith("<!--") or line in skip_patterns: "required": ["action"],
continue },
return False # Found actionable content },
}
return True ]
class HeartbeatService: class HeartbeatService:
""" """
Periodic heartbeat service that wakes the agent to check for tasks. Periodic heartbeat service that wakes the agent to check for tasks.
The agent reads HEARTBEAT.md from the workspace and executes any Phase 1 (decision): reads HEARTBEAT.md and asks the LLM — via a virtual
tasks listed there. If nothing needs attention, it replies HEARTBEAT_OK. tool call — whether there are active tasks. This avoids free-text parsing
and the unreliable HEARTBEAT_OK token.
Phase 2 (execution): only triggered when Phase 1 returns ``run``. The
``on_execute`` callback runs the task through the full agent loop and
returns the result to deliver.
""" """
def __init__( def __init__(
self, self,
workspace: Path, workspace: Path,
on_heartbeat: Callable[[str], Coroutine[Any, Any, str]] | None = None, provider: LLMProvider,
interval_s: int = DEFAULT_HEARTBEAT_INTERVAL_S, model: str,
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
interval_s: int = 30 * 60,
enabled: bool = True, enabled: bool = True,
): ):
self.workspace = workspace self.workspace = workspace
self.on_heartbeat = on_heartbeat self.provider = provider
self.model = model
self.on_execute = on_execute
self.on_notify = on_notify
self.interval_s = interval_s self.interval_s = interval_s
self.enabled = enabled self.enabled = enabled
self._running = False self._running = False
@@ -62,23 +75,48 @@ class HeartbeatService:
return self.workspace / "HEARTBEAT.md" return self.workspace / "HEARTBEAT.md"
def _read_heartbeat_file(self) -> str | None: def _read_heartbeat_file(self) -> str | None:
"""Read HEARTBEAT.md content."""
if self.heartbeat_file.exists(): if self.heartbeat_file.exists():
try: try:
return self.heartbeat_file.read_text() return self.heartbeat_file.read_text(encoding="utf-8")
except Exception: except Exception:
return None return None
return None return None
async def _decide(self, content: str) -> tuple[str, str]:
"""Phase 1: ask LLM to decide skip/run via virtual tool call.
Returns (action, tasks) where action is 'skip' or 'run'.
"""
response = await self.provider.chat(
messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": (
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
f"{content}"
)},
],
tools=_HEARTBEAT_TOOL,
model=self.model,
)
if not response.has_tool_calls:
return "skip", ""
args = response.tool_calls[0].arguments
return args.get("action", "skip"), args.get("tasks", "")
async def start(self) -> None: async def start(self) -> None:
"""Start the heartbeat service.""" """Start the heartbeat service."""
if not self.enabled: if not self.enabled:
logger.info("Heartbeat disabled") logger.info("Heartbeat disabled")
return return
if self._running:
logger.warning("Heartbeat already running")
return
self._running = True self._running = True
self._task = asyncio.create_task(self._run_loop()) self._task = asyncio.create_task(self._run_loop())
logger.info(f"Heartbeat started (every {self.interval_s}s)") logger.info("Heartbeat started (every {}s)", self.interval_s)
def stop(self) -> None: def stop(self) -> None:
"""Stop the heartbeat service.""" """Stop the heartbeat service."""
@@ -97,34 +135,39 @@ class HeartbeatService:
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"Heartbeat error: {e}") logger.error("Heartbeat error: {}", e)
async def _tick(self) -> None: async def _tick(self) -> None:
"""Execute a single heartbeat tick.""" """Execute a single heartbeat tick."""
content = self._read_heartbeat_file() content = self._read_heartbeat_file()
if not content:
# Skip if HEARTBEAT.md is empty or doesn't exist logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
if _is_heartbeat_empty(content):
logger.debug("Heartbeat: no tasks (HEARTBEAT.md empty)")
return return
logger.info("Heartbeat: checking for tasks...") logger.info("Heartbeat: checking for tasks...")
if self.on_heartbeat:
try: try:
response = await self.on_heartbeat(HEARTBEAT_PROMPT) action, tasks = await self._decide(content)
# Check if agent said "nothing to do" if action != "run":
if HEARTBEAT_OK_TOKEN.replace("_", "") in response.upper().replace("_", ""): logger.info("Heartbeat: OK (nothing to report)")
logger.info("Heartbeat: OK (no action needed)") return
else:
logger.info(f"Heartbeat: completed task")
except Exception as e: logger.info("Heartbeat: tasks found, executing...")
logger.error(f"Heartbeat execution failed: {e}") if self.on_execute:
response = await self.on_execute(tasks)
if response and self.on_notify:
logger.info("Heartbeat: completed, delivering response")
await self.on_notify(response)
except Exception:
logger.exception("Heartbeat execution failed")
async def trigger_now(self) -> str | None: async def trigger_now(self) -> str | None:
"""Manually trigger a heartbeat.""" """Manually trigger a heartbeat."""
if self.on_heartbeat: content = self._read_heartbeat_file()
return await self.on_heartbeat(HEARTBEAT_PROMPT) if not content:
return None return None
action, tasks = await self._decide(content)
if action != "run" or not self.on_execute:
return None
return await self.on_execute(tasks)
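Wiring the redesigned service takes two callbacks: `on_execute` receives the Phase 1 task summary and runs the real agent loop, while `on_notify` delivers the result. A hedged wiring sketch; `provider` stands for any configured `LLMProvider`, and both callbacks below are stand-ins rather than nanobot APIs:

from pathlib import Path

async def run_agent(tasks: str) -> str:      # stand-in for the agent loop
    return f"done: {tasks}"

async def deliver(text: str) -> None:        # stand-in for channel delivery
    print(text)

service = HeartbeatService(
    workspace=Path("~/.nanobot/workspace").expanduser(),
    provider=provider,                       # any configured LLMProvider
    model="anthropic/claude-opus-4-5",
    on_execute=run_agent,
    on_notify=deliver,
    interval_s=15 * 60,
)
# inside a running event loop: await service.start()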

View File

@@ -2,5 +2,7 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
+ from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
- __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]
+ __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]

View File

@@ -0,0 +1,210 @@
"""Azure OpenAI provider implementation with API version 2024-10-21."""
from __future__ import annotations
import uuid
from typing import Any
from urllib.parse import urljoin
import httpx
import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
Azure OpenAI provider with API version 2024-10-21 compliance.
Features:
- Hardcoded API version 2024-10-21
- Uses model field as Azure deployment name in URL path
- Uses api-key header instead of Authorization Bearer
- Uses max_completion_tokens instead of max_tokens
- Direct HTTP calls, bypasses LiteLLM
"""
def __init__(
self,
api_key: str = "",
api_base: str = "",
default_model: str = "gpt-5.2-chat",
):
super().__init__(api_key, api_base)
self.default_model = default_model
self.api_version = "2024-10-21"
# Validate required parameters
if not api_key:
raise ValueError("Azure OpenAI api_key is required")
if not api_base:
raise ValueError("Azure OpenAI api_base is required")
# Ensure api_base ends with /
if not api_base.endswith('/'):
api_base += '/'
self.api_base = api_base
def _build_chat_url(self, deployment_name: str) -> str:
"""Build the Azure OpenAI chat completions URL."""
# Azure OpenAI URL format:
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
base_url = self.api_base
if not base_url.endswith('/'):
base_url += '/'
url = urljoin(
base_url,
f"openai/deployments/{deployment_name}/chat/completions"
)
return f"{url}?api-version={self.api_version}"
def _build_headers(self) -> dict[str, str]:
"""Build headers for Azure OpenAI API with api-key header."""
return {
"Content-Type": "application/json",
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
return payload
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (used as deployment name).
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
temperature: Sampling temperature.
reasoning_effort: Optional reasoning effort parameter.
Returns:
LLMResponse with content and/or tool calls.
"""
deployment_name = model or self.default_model
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
response = await client.post(url, headers=headers, json=payload)
if response.status_code != 200:
return LLMResponse(
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
finish_reason="error",
)
response_data = response.json()
return self._parse_response(response_data)
except Exception as e:
return LLMResponse(
content=f"Error calling Azure OpenAI: {repr(e)}",
finish_reason="error",
)
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
"""Parse Azure OpenAI response into our standard format."""
try:
choice = response["choices"][0]
message = choice["message"]
tool_calls = []
if message.get("tool_calls"):
for tc in message["tool_calls"]:
# Parse arguments from JSON string if needed
args = tc["function"]["arguments"]
if isinstance(args, str):
args = json_repair.loads(args)
tool_calls.append(
ToolCallRequest(
id=tc["id"],
name=tc["function"]["name"],
arguments=args,
)
)
usage = {}
if response.get("usage"):
usage_data = response["usage"]
usage = {
"prompt_tokens": usage_data.get("prompt_tokens", 0),
"completion_tokens": usage_data.get("completion_tokens", 0),
"total_tokens": usage_data.get("total_tokens", 0),
}
reasoning_content = message.get("reasoning_content") or None
return LLMResponse(
content=message.get("content"),
tool_calls=tool_calls,
finish_reason=choice.get("finish_reason", "stop"),
usage=usage,
reasoning_content=reasoning_content,
)
except (KeyError, IndexError) as e:
return LLMResponse(
content=f"Error parsing Azure OpenAI response: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
"""Get the default model (also used as default deployment name)."""
return self.default_model
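Usage is deliberately plain: the `model` argument doubles as the Azure deployment name in the URL path. A minimal call, with placeholder credentials and deployment:

import asyncio

from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

provider = AzureOpenAIProvider(
    api_key="<azure-api-key>",                         # placeholder
    api_base="https://my-resource.openai.azure.com/",  # placeholder resource endpoint
    default_model="my-gpt-deployment",                 # deployment name, not a model id
)

async def main() -> None:
    resp = await provider.chat([{"role": "user", "content": "ping"}], max_tokens=64)
    print(resp.finish_reason, resp.content)

asyncio.run(main())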

View File

@@ -21,6 +21,7 @@ class LLMResponse:
finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None  # Kimi, DeepSeek-R1 etc.
+ thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
@property
def has_tool_calls(self) -> bool:
@@ -40,6 +41,66 @@
self.api_key = api_key
self.api_base = api_base
@staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Replace empty text content that causes provider 400 errors.
Empty content can appear when MCP tools return nothing. Most providers
reject empty-string content or empty text blocks in list content.
"""
result: list[dict[str, Any]] = []
for msg in messages:
content = msg.get("content")
if isinstance(content, str) and not content:
clean = dict(msg)
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
result.append(clean)
continue
if isinstance(content, list):
filtered = [
item for item in content
if not (
isinstance(item, dict)
and item.get("type") in ("text", "input_text", "output_text")
and not item.get("text")
)
]
if len(filtered) != len(content):
clean = dict(msg)
if filtered:
clean["content"] = filtered
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
clean["content"] = None
else:
clean["content"] = "(empty)"
result.append(clean)
continue
if isinstance(content, dict):
clean = dict(msg)
clean["content"] = [content]
result.append(clean)
continue
result.append(msg)
return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod
async def chat(
self,
@@ -48,6 +109,7 @@
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request.
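The effect of `_sanitize_empty_content` is easiest to see on a history containing empty MCP output: an assistant message that carries tool_calls keeps `content=None`, while an empty tool result becomes the `"(empty)"` placeholder. A small check:

from nanobot.providers.base import LLMProvider

messages = [
    {"role": "assistant", "content": "",
     "tool_calls": [{"id": "abc123def", "type": "function",
                     "function": {"name": "noop", "arguments": "{}"}}]},
    {"role": "tool", "tool_call_id": "abc123def", "content": ""},
]

cleaned = LLMProvider._sanitize_empty_content(messages)
assert cleaned[0]["content"] is None        # assistant with tool_calls
assert cleaned[1]["content"] == "(empty)"   # empty tool result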

View File

@@ -0,0 +1,61 @@
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
from __future__ import annotations
import uuid
from typing import Any
import json_repair
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class CustomProvider(LLMProvider):
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
super().__init__(api_key, api_base)
self.default_model = default_model
# Keep affinity stable for this provider instance to improve backend cache locality.
self._client = AsyncOpenAI(
api_key=api_key,
base_url=api_base,
default_headers={"x-session-affinity": uuid.uuid4().hex},
)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None) -> LLMResponse:
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"messages": self._sanitize_empty_content(messages),
"max_tokens": max(1, max_tokens),
"temperature": temperature,
}
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
if tools:
kwargs.update(tools=tools, tool_choice="auto")
try:
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
return LLMResponse(content=f"Error: {e}", finish_reason="error")
def _parse(self, response: Any) -> LLMResponse:
choice = response.choices[0]
msg = choice.message
tool_calls = [
ToolCallRequest(id=tc.id, name=tc.function.name,
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments)
for tc in (msg.tool_calls or [])
]
u = response.usage
return LLMResponse(
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
reasoning_content=getattr(msg, "reasoning_content", None) or None,
)
def get_default_model(self) -> str:
return self.default_model
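A usage sketch against a local OpenAI-compatible server (the module path `nanobot.providers.custom_provider` and the model name are assumptions; the endpoint is a placeholder):

import asyncio

from nanobot.providers.custom_provider import CustomProvider  # path assumed

provider = CustomProvider(api_base="http://localhost:8000/v1", default_model="local-model")

async def main() -> None:
    resp = await provider.chat([{"role": "user", "content": "hello"}], max_tokens=32)
    print(resp.content)

asyncio.run(main())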

View File

@@ -1,16 +1,28 @@
"""LiteLLM provider implementation for multi-provider support.""" """LiteLLM provider implementation for multi-provider support."""
import json import hashlib
import json_repair
import os import os
import secrets
import string
from typing import Any from typing import Any
import json_repair
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway from nanobot.providers.registry import find_by_model, find_gateway
# Standard chat-completion message keys.
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
_ALNUM = string.ascii_letters + string.digits
def _short_tool_id() -> str:
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
return "".join(secrets.choice(_ALNUM) for _ in range(9))
class LiteLLMProvider(LLMProvider): class LiteLLMProvider(LLMProvider):
""" """
@@ -55,6 +67,9 @@ class LiteLLMProvider(LLMProvider):
spec = self._gateway or find_by_model(model)
if not spec:
return
+ if not spec.env_key:
+ # OAuth/provider-only specs (for example: openai_codex)
+ return
# Gateway/local overrides existing env; standard provider doesn't
if self._gateway:
@@ -85,11 +100,55 @@ class LiteLLMProvider(LLMProvider):
# Standard mode: auto-prefix for known providers
spec = find_by_model(model)
if spec and spec.litellm_prefix:
+ model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
if not any(model.startswith(s) for s in spec.skip_prefixes):
model = f"{spec.litellm_prefix}/{model}"
return model
@staticmethod
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
"""Normalize explicit provider prefixes like `github-copilot/...`."""
if "/" not in model:
return model
prefix, remainder = model.split("/", 1)
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"
def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None:
return self._gateway.supports_prompt_caching
spec = find_by_model(model)
return spec is not None and spec.supports_prompt_caching
def _apply_cache_control(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
"""Return copies of messages and tools with cache_control injected."""
new_messages = []
for msg in messages:
if msg.get("role") == "system":
content = msg["content"]
if isinstance(content, str):
new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
else:
new_content = list(content)
new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
new_messages.append({**msg, "content": new_content})
else:
new_messages.append(msg)
new_tools = tools
if tools:
new_tools = list(tools)
new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}
return new_messages, new_tools
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
"""Apply model-specific parameter overrides from the registry."""
model_lower = model.lower()
@@ -100,6 +159,53 @@ class LiteLLMProvider(LLMProvider):
kwargs.update(overrides)
return
@staticmethod
def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
"""Return provider-specific extra keys to preserve in request messages."""
spec = find_by_model(original_model) or find_by_model(resolved_model)
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {}
def map_id(value: Any) -> Any:
if not isinstance(value, str):
return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
normalized_tool_calls = []
for tc in clean["tool_calls"]:
if not isinstance(tc, dict):
normalized_tool_calls.append(tc)
continue
tc_clean = dict(tc)
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized_tool_calls.append(tc_clean)
clean["tool_calls"] = normalized_tool_calls
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
self,
messages: list[dict[str, Any]],
@@ -107,6 +213,7 @@ class LiteLLMProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
@@ -121,7 +228,12 @@ class LiteLLMProvider(LLMProvider):
Returns:
LLMResponse with content and/or tool calls.
"""
- model = self._resolve_model(model or self.default_model)
+ original_model = model or self.default_model
+ model = self._resolve_model(original_model)
+ extra_msg_keys = self._extra_msg_keys(original_model, model)
+ if self._supports_cache_control(original_model):
+ messages, tools = self._apply_cache_control(messages, tools)
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
@@ -129,7 +241,7 @@ class LiteLLMProvider(LLMProvider):
kwargs: dict[str, Any] = {
"model": model,
- "messages": messages,
+ "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
"max_tokens": max_tokens,
"temperature": temperature,
}
@@ -149,6 +261,10 @@ class LiteLLMProvider(LLMProvider):
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
+ if reasoning_effort:
+ kwargs["reasoning_effort"] = reasoning_effort
+ kwargs["drop_params"] = True
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
@@ -167,17 +283,34 @@ class LiteLLMProvider(LLMProvider):
"""Parse LiteLLM response into our standard format.""" """Parse LiteLLM response into our standard format."""
choice = response.choices[0] choice = response.choices[0]
message = choice.message message = choice.message
content = message.content
finish_reason = choice.finish_reason
# Some providers (e.g. GitHub Copilot) split content and tool_calls
# across multiple choices. Merge them so tool_calls are not lost.
raw_tool_calls = []
for ch in response.choices:
msg = ch.message
if hasattr(msg, "tool_calls") and msg.tool_calls:
raw_tool_calls.extend(msg.tool_calls)
if ch.finish_reason in ("tool_calls", "stop"):
finish_reason = ch.finish_reason
if not content and msg.content:
content = msg.content
if len(response.choices) > 1:
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
len(response.choices), len(raw_tool_calls))
tool_calls = [] tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls: for tc in raw_tool_calls:
for tc in message.tool_calls:
# Parse arguments from JSON string if needed # Parse arguments from JSON string if needed
args = tc.function.arguments args = tc.function.arguments
if isinstance(args, str): if isinstance(args, str):
args = json_repair.loads(args) args = json_repair.loads(args)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=tc.id, id=_short_tool_id(),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
)) ))
@@ -190,14 +323,16 @@ class LiteLLMProvider(LLMProvider):
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
} }
reasoning_content = getattr(message, "reasoning_content", None) reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=content,
tool_calls=tool_calls, tool_calls=tool_calls,
finish_reason=choice.finish_reason or "stop", finish_reason=finish_reason or "stop",
usage=usage, usage=usage,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
) )
def get_default_model(self) -> str: def get_default_model(self) -> str:
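
A quick editorial illustration of the id-shortening above: both sides of a tool-call exchange must end up with the same shortened id, or strict providers reject the history. This sketch mirrors `_normalize_tool_call_id` and the `id_map` logic; the message shapes are illustrative, not the provider's exact internals.

```python
import hashlib

def normalize_id(tool_call_id: str) -> str:
    # Same scheme as _normalize_tool_call_id: 9 hex chars of SHA-1 keeps ids
    # unique within a conversation while satisfying strict length limits.
    return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

long_id = "call_" + "x" * 60  # some providers emit very long ids
assistant = {"role": "assistant", "tool_calls": [{"id": long_id}]}
tool_msg = {"role": "tool", "tool_call_id": long_id, "content": "ok"}

id_map: dict[str, str] = {}
for tc in assistant["tool_calls"]:
    tc["id"] = id_map.setdefault(tc["id"], normalize_id(tc["id"]))
tool_msg["tool_call_id"] = id_map.setdefault(
    tool_msg["tool_call_id"], normalize_id(tool_msg["tool_call_id"])
)

# The linkage survives: both references resolve to the same short id.
assert assistant["tool_calls"][0]["id"] == tool_msg["tool_call_id"]
```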

View File

@@ -0,0 +1,316 @@
"""OpenAI Codex Responses Provider."""
from __future__ import annotations
import asyncio
import hashlib
import json
from typing import Any, AsyncGenerator
import httpx
from loguru import logger
from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
DEFAULT_ORIGINATOR = "nanobot"
class OpenAICodexProvider(LLMProvider):
"""Use Codex OAuth to call the Responses API."""
def __init__(self, default_model: str = "openai-codex/gpt-5.1-codex"):
super().__init__(api_key=None, api_base=None)
self.default_model = default_model
async def chat(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
token = await asyncio.to_thread(get_codex_token)
headers = _build_headers(token.account_id, token.access)
body: dict[str, Any] = {
"model": _strip_model_prefix(model),
"store": False,
"stream": True,
"instructions": system_prompt,
"input": input_items,
"text": {"verbosity": "medium"},
"include": ["reasoning.encrypted_content"],
"prompt_cache_key": _prompt_cache_key(messages),
"tool_choice": "auto",
"parallel_tool_calls": True,
}
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools:
body["tools"] = _convert_tools(tools)
url = DEFAULT_CODEX_URL
try:
try:
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True)
except Exception as e:
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
raise
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False)
return LLMResponse(
content=content,
tool_calls=tool_calls,
finish_reason=finish_reason,
)
except Exception as e:
return LLMResponse(
content=f"Error calling Codex: {str(e)}",
finish_reason="error",
)
def get_default_model(self) -> str:
return self.default_model
def _strip_model_prefix(model: str) -> str:
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
return model.split("/", 1)[1]
return model
def _build_headers(account_id: str, token: str) -> dict[str, str]:
return {
"Authorization": f"Bearer {token}",
"chatgpt-account-id": account_id,
"OpenAI-Beta": "responses=experimental",
"originator": DEFAULT_ORIGINATOR,
"User-Agent": "nanobot (python)",
"accept": "text/event-stream",
"content-type": "application/json",
}
async def _request_codex(
url: str,
headers: dict[str, str],
body: dict[str, Any],
verify: bool,
) -> tuple[str, list[ToolCallRequest], str]:
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
async with client.stream("POST", url, headers=headers, json=body) as response:
if response.status_code != 200:
text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
return await _consume_sse(response)
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert OpenAI function-calling schema to Codex flat format."""
converted: list[dict[str, Any]] = []
for tool in tools:
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
name = fn.get("name")
if not name:
continue
params = fn.get("parameters") or {}
converted.append({
"type": "function",
"name": name,
"description": fn.get("description") or "",
"parameters": params if isinstance(params, dict) else {},
})
return converted
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
system_prompt = ""
input_items: list[dict[str, Any]] = []
for idx, msg in enumerate(messages):
role = msg.get("role")
content = msg.get("content")
if role == "system":
system_prompt = content if isinstance(content, str) else ""
continue
if role == "user":
input_items.append(_convert_user_message(content))
continue
if role == "assistant":
# Handle text first.
if isinstance(content, str) and content:
input_items.append(
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": content}],
"status": "completed",
"id": f"msg_{idx}",
}
)
# Then handle tool calls.
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
call_id = call_id or f"call_{idx}"
item_id = item_id or f"fc_{idx}"
input_items.append(
{
"type": "function_call",
"id": item_id,
"call_id": call_id,
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
}
)
continue
if role == "tool":
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
input_items.append(
{
"type": "function_call_output",
"call_id": call_id,
"output": output_text,
}
)
continue
return system_prompt, input_items
def _convert_user_message(content: Any) -> dict[str, Any]:
if isinstance(content, str):
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
if isinstance(content, list):
converted: list[dict[str, Any]] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
converted.append({"type": "input_text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url")
if url:
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
if converted:
return {"role": "user", "content": converted}
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
if isinstance(tool_call_id, str) and tool_call_id:
if "|" in tool_call_id:
call_id, item_id = tool_call_id.split("|", 1)
return call_id, item_id or None
return tool_call_id, None
return "call_0", None
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
buffer: list[str] = []
async for line in response.aiter_lines():
if line == "":
if buffer:
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
buffer = []
if not data_lines:
continue
data = "\n".join(data_lines).strip()
if not data or data == "[DONE]":
continue
try:
yield json.loads(data)
except Exception:
continue
continue
buffer.append(line)
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
content = ""
tool_calls: list[ToolCallRequest] = []
tool_call_buffers: dict[str, dict[str, Any]] = {}
finish_reason = "stop"
async for event in _iter_sse(response):
event_type = event.get("type")
if event_type == "response.output_item.added":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
tool_call_buffers[call_id] = {
"id": item.get("id") or "fc_0",
"name": item.get("name"),
"arguments": item.get("arguments") or "",
}
elif event_type == "response.output_text.delta":
content += event.get("delta") or ""
elif event_type == "response.function_call_arguments.delta":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
elif event_type == "response.function_call_arguments.done":
call_id = event.get("call_id")
if call_id and call_id in tool_call_buffers:
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
elif event_type == "response.output_item.done":
item = event.get("item") or {}
if item.get("type") == "function_call":
call_id = item.get("call_id")
if not call_id:
continue
buf = tool_call_buffers.get(call_id) or {}
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
try:
args = json.loads(args_raw)
except Exception:
args = {"raw": args_raw}
tool_calls.append(
ToolCallRequest(
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
name=buf.get("name") or item.get("name"),
arguments=args,
)
)
elif event_type == "response.completed":
status = (event.get("response") or {}).get("status")
finish_reason = _map_finish_reason(status)
elif event_type in {"error", "response.failed"}:
raise RuntimeError("Codex response failed")
return content, tool_calls, finish_reason
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
def _map_finish_reason(status: str | None) -> str:
return _FINISH_REASON_MAP.get(status or "completed", "stop")
def _friendly_error(status_code: int, raw: str) -> str:
if status_code == 429:
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
return f"HTTP {status_code}: {raw}"

View File

@@ -51,6 +51,15 @@ class ProviderSpec:
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
supports_prompt_caching: bool = False
@property @property
def label(self) -> str: def label(self) -> str:
return self.display_name or self.name.title() return self.display_name or self.name.title()
@@ -61,24 +70,27 @@ class ProviderSpec:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = ( PROVIDERS: tuple[ProviderSpec, ...] = (
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
# === Custom (user-provided OpenAI-compatible endpoint) =================
# No auto-detection — only activates when user explicitly configures "custom".
ProviderSpec( ProviderSpec(
name="custom", name="custom",
keywords=(), keywords=(),
env_key="OPENAI_API_KEY", env_key="",
display_name="Custom", display_name="Custom",
litellm_prefix="openai", litellm_prefix="",
skip_prefixes=("openai/",), is_direct=True,
is_gateway=True,
strip_model_prefix=True,
), ),
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
ProviderSpec(
name="azure_openai",
keywords=("azure", "azure-openai"),
env_key="",
display_name="Azure OpenAI",
litellm_prefix="",
is_direct=True,
),
# === Gateways (detected by api_key / api_base, not model name) ========= # === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback. # Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-" # OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec( ProviderSpec(
name="openrouter", name="openrouter",
@@ -95,8 +107,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://openrouter.ai/api/v1", default_api_base="https://openrouter.ai/api/v1",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
supports_prompt_caching=True,
), ),
# AiHubMix: global gateway, OpenAI-compatible interface. # AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3", # strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3". # so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
@@ -116,9 +128,41 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(), model_overrides=(),
), ),
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec(
name="siliconflow",
keywords=("siliconflow",),
env_key="OPENAI_API_KEY",
display_name="SiliconFlow",
litellm_prefix="openai",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="siliconflow",
default_api_base="https://api.siliconflow.cn/v1",
strip_model_prefix=False,
model_overrides=(),
),
# VolcEngine (火山引擎): OpenAI-compatible gateway
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
env_key="OPENAI_API_KEY",
display_name="VolcEngine",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="volces",
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
strip_model_prefix=False,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) =============== # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="anthropic", name="anthropic",
@@ -135,8 +179,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="", default_api_base="",
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
supports_prompt_caching=True,
), ),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(
name="openai", name="openai",
@@ -154,7 +198,42 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
keywords=("openai-codex",),
env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="codex",
default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
),
# GitHub Copilot: uses OAuth, not API key.
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
env_key="", # OAuth-based, no API key
display_name="Github Copilot",
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
is_oauth=True, # OAuth-based authentication
),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing. # DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec( ProviderSpec(
name="deepseek", name="deepseek",
@@ -172,7 +251,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Gemini: needs "gemini/" prefix for LiteLLM. # Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec( ProviderSpec(
name="gemini", name="gemini",
@@ -190,7 +268,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Zhipu: LiteLLM uses "zai/" prefix. # Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway. # skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -201,9 +278,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
display_name="Zhipu AI", display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4 litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=( env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
("ZHIPUAI_API_KEY", "{api_key}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
@@ -212,7 +287,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# DashScope: Qwen models, needs "dashscope/" prefix. # DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec( ProviderSpec(
name="dashscope", name="dashscope",
@@ -230,7 +304,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# Moonshot: Kimi models, needs "moonshot/" prefix. # Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0. # Kimi K2.5 API enforces temperature >= 1.0.
@@ -241,20 +314,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
display_name="Moonshot", display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"), skip_prefixes=("moonshot/", "openrouter/"),
env_extras=( env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
("MOONSHOT_API_BASE", "{api_base}"),
),
is_gateway=False, is_gateway=False,
is_local=False, is_local=False,
detect_by_key_prefix="", detect_by_key_prefix="",
detect_by_base_keyword="", detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=( model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
("kimi-k2.5", {"temperature": 1.0}),
), ),
),
# MiniMax: needs "minimax/" prefix for LiteLLM routing. # MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1. # Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec( ProviderSpec(
@@ -273,9 +341,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Local deployment (matched by config key, NOT by api_base) ========= # === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server. # vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm"). # Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec( ProviderSpec(
@@ -294,9 +360,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# === Auxiliary (not a primary LLM provider) ============================ # === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM. # Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec( ProviderSpec(
@@ -322,14 +386,25 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers # Lookup helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def find_by_model(model: str) -> ProviderSpec | None: def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive). """Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead.""" Skips gateways/local — those are matched by api_key/api_base instead."""
model_lower = model.lower() model_lower = model.lower()
for spec in PROVIDERS: model_normalized = model_lower.replace("-", "_")
if spec.is_gateway or spec.is_local: model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
continue normalized_prefix = model_prefix.replace("-", "_")
if any(kw in model_lower for kw in spec.keywords): std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local]
# Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex.
for spec in std_specs:
if model_prefix and normalized_prefix == spec.name:
return spec
for spec in std_specs:
if any(
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
):
return spec return spec
return None return None
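
A simplified editorial sketch of the two-pass order `find_by_model` now uses: an explicit `provider/` prefix resolves before any keyword scan, which is what keeps models like `github-copilot/gpt-5.3-codex` from being claimed by another spec's keyword. Spec names mirror the registry; the bodies are toy stand-ins.

```python
SPECS = {
    "openai_codex": ("openai-codex",),
    "github_copilot": ("github_copilot", "copilot"),
}

def find(model: str) -> str | None:
    model_lower = model.lower()
    prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
    normalized_prefix = prefix.replace("-", "_")
    for name in SPECS:  # pass 1: an explicit provider prefix wins outright
        if prefix and normalized_prefix == name:
            return name
    normalized_model = model_lower.replace("-", "_")
    for name, keywords in SPECS.items():  # pass 2: substring-keyword fallback
        if any(kw in model_lower or kw.replace("-", "_") in normalized_model
               for kw in keywords):
            return name
    return None

assert find("github-copilot/gpt-5.3-codex") == "github_copilot"
assert find("openai-codex/gpt-5.1-codex") == "openai_codex"
```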

View File

@@ -2,7 +2,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any
import httpx import httpx
from loguru import logger from loguru import logger
@@ -35,7 +34,7 @@ class GroqTranscriptionProvider:
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
logger.error(f"Audio file not found: {file_path}") logger.error("Audio file not found: {}", file_path)
return "" return ""
try: try:
@@ -61,5 +60,5 @@ class GroqTranscriptionProvider:
return data.get("text", "") return data.get("text", "")
except Exception as e: except Exception as e:
logger.error(f"Groq transcription error: {e}") logger.error("Groq transcription error: {}", e)
return "" return ""

View File

@@ -1,5 +1,5 @@
"""Session management module.""" """Session management module."""
from nanobot.session.manager import SessionManager, Session from nanobot.session.manager import Session, SessionManager
__all__ = ["SessionManager", "Session"] __all__ = ["SessionManager", "Session"]

View File

@@ -1,9 +1,10 @@
"""Session management for conversation history.""" """Session management for conversation history."""
import json import json
from pathlib import Path import shutil
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -42,8 +43,24 @@ class Session:
self.updated_at = datetime.now() self.updated_at = datetime.now()
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Get recent messages in LLM format (role + content only).""" """Return unconsolidated messages for LLM input, aligned to a user turn."""
return [{"role": m["role"], "content": m["content"]} for m in self.messages[-max_messages:]] unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid orphaned tool_result blocks
for i, m in enumerate(sliced):
if m.get("role") == "user":
sliced = sliced[i:]
break
out: list[dict[str, Any]] = []
for m in sliced:
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
for k in ("tool_calls", "tool_call_id", "name"):
if k in m:
entry[k] = m[k]
out.append(entry)
return out
def clear(self) -> None: def clear(self) -> None:
"""Clear all messages and reset session to initial state.""" """Clear all messages and reset session to initial state."""
@@ -61,7 +78,8 @@ class SessionManager:
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.workspace = workspace self.workspace = workspace
self.sessions_dir = ensure_dir(Path.home() / ".nanobot" / "sessions") self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
self._cache: dict[str, Session] = {} self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path: def _get_session_path(self, key: str) -> Path:
@@ -69,6 +87,11 @@ class SessionManager:
safe_key = safe_filename(key.replace(":", "_")) safe_key = safe_filename(key.replace(":", "_"))
return self.sessions_dir / f"{safe_key}.jsonl" return self.sessions_dir / f"{safe_key}.jsonl"
def _get_legacy_session_path(self, key: str) -> Path:
"""Legacy global session path (~/.nanobot/sessions/)."""
safe_key = safe_filename(key.replace(":", "_"))
return self.legacy_sessions_dir / f"{safe_key}.jsonl"
def get_or_create(self, key: str) -> Session: def get_or_create(self, key: str) -> Session:
""" """
Get an existing session or create a new one. Get an existing session or create a new one.
@@ -92,6 +115,14 @@ class SessionManager:
def _load(self, key: str) -> Session | None: def _load(self, key: str) -> Session | None:
"""Load a session from disk.""" """Load a session from disk."""
path = self._get_session_path(key) path = self._get_session_path(key)
if not path.exists():
legacy_path = self._get_legacy_session_path(key)
if legacy_path.exists():
try:
shutil.move(str(legacy_path), str(path))
logger.info("Migrated session {} from legacy path", key)
except Exception:
logger.exception("Failed to migrate session {}", key)
if not path.exists(): if not path.exists():
return None return None
@@ -102,7 +133,7 @@ class SessionManager:
created_at = None created_at = None
last_consolidated = 0 last_consolidated = 0
with open(path) as f: with open(path, encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: if not line:
@@ -125,24 +156,25 @@ class SessionManager:
last_consolidated=last_consolidated last_consolidated=last_consolidated
) )
except Exception as e: except Exception as e:
logger.warning(f"Failed to load session {key}: {e}") logger.warning("Failed to load session {}: {}", key, e)
return None return None
def save(self, session: Session) -> None: def save(self, session: Session) -> None:
"""Save a session to disk.""" """Save a session to disk."""
path = self._get_session_path(session.key) path = self._get_session_path(session.key)
with open(path, "w") as f: with open(path, "w", encoding="utf-8") as f:
metadata_line = { metadata_line = {
"_type": "metadata", "_type": "metadata",
"key": session.key,
"created_at": session.created_at.isoformat(), "created_at": session.created_at.isoformat(),
"updated_at": session.updated_at.isoformat(), "updated_at": session.updated_at.isoformat(),
"metadata": session.metadata, "metadata": session.metadata,
"last_consolidated": session.last_consolidated "last_consolidated": session.last_consolidated
} }
f.write(json.dumps(metadata_line) + "\n") f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
for msg in session.messages: for msg in session.messages:
f.write(json.dumps(msg) + "\n") f.write(json.dumps(msg, ensure_ascii=False) + "\n")
self._cache[session.key] = session self._cache[session.key] = session
@@ -162,13 +194,14 @@ class SessionManager:
for path in self.sessions_dir.glob("*.jsonl"): for path in self.sessions_dir.glob("*.jsonl"):
try: try:
# Read just the metadata line # Read just the metadata line
with open(path) as f: with open(path, encoding="utf-8") as f:
first_line = f.readline().strip() first_line = f.readline().strip()
if first_line: if first_line:
data = json.loads(first_line) data = json.loads(first_line)
if data.get("_type") == "metadata": if data.get("_type") == "metadata":
key = data.get("key") or path.stem.replace("_", ":", 1)
sessions.append({ sessions.append({
"key": path.stem.replace("_", ":"), "key": key,
"created_at": data.get("created_at"), "created_at": data.get("created_at"),
"updated_at": data.get("updated_at"), "updated_at": data.get("updated_at"),
"path": str(path) "path": str(path)

View File

@@ -21,4 +21,5 @@ The skill format and metadata structure follow OpenClaw's conventions to maintai
| `weather` | Get weather info using wttr.in and Open-Meteo | | `weather` | Get weather info using wttr.in and Open-Meteo |
| `summarize` | Summarize URLs, files, and YouTube videos | | `summarize` | Summarize URLs, files, and YouTube videos |
| `tmux` | Remote-control tmux sessions | | `tmux` | Remote-control tmux sessions |
| `clawhub` | Search and install skills from ClawHub registry |
| `skill-creator` | Create new skills | | `skill-creator` | Create new skills |

View File

@@ -0,0 +1,53 @@
---
name: clawhub
description: Search and install agent skills from ClawHub, the public skill registry.
homepage: https://clawhub.ai
metadata: {"nanobot":{"emoji":"🦞"}}
---
# ClawHub
Public skill registry for AI agents. Search by natural language (vector search).
## When to use
Use this skill when the user asks any of:
- "find a skill for …"
- "search for skills"
- "install a skill"
- "what skills are available?"
- "update my skills"
## Search
```bash
npx --yes clawhub@latest search "web scraping" --limit 5
```
## Install
```bash
npx --yes clawhub@latest install <slug> --workdir ~/.nanobot/workspace
```
Replace `<slug>` with the skill name from the search results. This places the skill in `~/.nanobot/workspace/skills/`, the directory nanobot loads workspace skills from. Always include `--workdir`.
## Update
```bash
npx --yes clawhub@latest update --all --workdir ~/.nanobot/workspace
```
## List installed
```bash
npx --yes clawhub@latest list --workdir ~/.nanobot/workspace
```
## Notes
- Requires Node.js (`npx` comes with it).
- No API key needed for search and install.
- Login (`npx --yes clawhub@latest login`) is only required for publishing.
- `--workdir ~/.nanobot/workspace` is critical — without it, skills install to the current directory instead of the nanobot workspace.
- After install, remind the user to start a new session to load the skill.

View File

@@ -30,6 +30,11 @@ One-time scheduled task (compute ISO datetime from current time):
cron(action="add", message="Remind me about the meeting", at="<ISO datetime>") cron(action="add", message="Remind me about the meeting", at="<ISO datetime>")
``` ```
Timezone-aware cron:
```
cron(action="add", message="Morning standup", cron_expr="0 9 * * 1-5", tz="America/Vancouver")
```
List/remove: List/remove:
``` ```
cron(action="list") cron(action="list")
@@ -44,4 +49,9 @@ cron(action="remove", job_id="abc123")
| every hour | every_seconds: 3600 | | every hour | every_seconds: 3600 |
| every day at 8am | cron_expr: "0 8 * * *" | | every day at 8am | cron_expr: "0 8 * * *" |
| weekdays at 5pm | cron_expr: "0 17 * * 1-5" | | weekdays at 5pm | cron_expr: "0 17 * * 1-5" |
| 9am Vancouver time daily | cron_expr: "0 9 * * *", tz: "America/Vancouver" |
| at a specific time | at: ISO datetime string (compute from current time) | | at a specific time | at: ISO datetime string (compute from current time) |
## Timezone
Use `tz` with `cron_expr` to schedule in a specific IANA timezone. Without `tz`, the server's local timezone is used.

View File

@@ -9,7 +9,7 @@ always: true
## Structure ## Structure
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context. - `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. - `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM].
## Search Past Events ## Search Past Events

View File

@@ -0,0 +1,21 @@
# Agent Instructions
You are a helpful AI assistant. Be concise, accurate, and friendly.
## Scheduled Reminders
Before scheduling reminders, check available skills and follow skill guidance first.
Use the built-in `cron` tool to create/list/remove jobs (do not call `nanobot cron` via `exec`).
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
## Heartbeat Tasks
`HEARTBEAT.md` is checked at each configured heartbeat interval. Use file tools to manage periodic tasks:
- **Add**: `edit_file` to append new tasks
- **Remove**: `edit_file` to delete completed tasks
- **Rewrite**: `write_file` to replace all tasks
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.

View File

@@ -0,0 +1,15 @@
# Tool Usage Notes
Tool signatures are provided automatically via function calling.
This file documents non-obvious constraints and usage patterns.
## exec — Safety Limits
- Commands have a configurable timeout (default 60s)
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
- Output is truncated at 10,000 characters
- `restrictToWorkspace` config can limit file access to the workspace
## cron — Scheduled Reminders
- Refer to the `cron` skill for usage.

View File

@@ -1,5 +1,5 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"] __all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]

View File

@@ -1,80 +1,112 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from pathlib import Path import re
from datetime import datetime from datetime import datetime
from pathlib import Path
def detect_image_mime(data: bytes) -> str | None:
"""Detect image MIME type from magic bytes, ignoring file extension."""
if data[:8] == b"\x89PNG\r\n\x1a\n":
return "image/png"
if data[:3] == b"\xff\xd8\xff":
return "image/jpeg"
if data[:6] in (b"GIF87a", b"GIF89a"):
return "image/gif"
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
return "image/webp"
return None
def ensure_dir(path: Path) -> Path: def ensure_dir(path: Path) -> Path:
"""Ensure a directory exists, creating it if necessary.""" """Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
def get_data_path() -> Path: def get_data_path() -> Path:
"""Get the nanobot data directory (~/.nanobot).""" """~/.nanobot data directory."""
return ensure_dir(Path.home() / ".nanobot") return ensure_dir(Path.home() / ".nanobot")
def get_workspace_path(workspace: str | None = None) -> Path: def get_workspace_path(workspace: str | None = None) -> Path:
""" """Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
Get the workspace path. path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
Args:
workspace: Optional workspace path. Defaults to ~/.nanobot/workspace.
Returns:
Expanded and ensured workspace path.
"""
if workspace:
path = Path(workspace).expanduser()
else:
path = Path.home() / ".nanobot" / "workspace"
return ensure_dir(path) return ensure_dir(path)
def get_sessions_path() -> Path:
"""Get the sessions storage directory."""
return ensure_dir(get_data_path() / "sessions")
def get_skills_path(workspace: Path | None = None) -> Path:
"""Get the skills directory within the workspace."""
ws = workspace or get_workspace_path()
return ensure_dir(ws / "skills")
def timestamp() -> str: def timestamp() -> str:
"""Get current timestamp in ISO format.""" """Current ISO timestamp."""
return datetime.now().isoformat() return datetime.now().isoformat()
def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str: _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
"""Truncate a string to max length, adding suffix if truncated."""
if len(s) <= max_len:
return s
return s[: max_len - len(suffix)] + suffix
def safe_filename(name: str) -> str: def safe_filename(name: str) -> str:
"""Convert a string to a safe filename.""" """Replace unsafe path characters with underscores."""
# Replace unsafe characters return _UNSAFE_CHARS.sub("_", name).strip()
unsafe = '<>:"/\\|?*'
for char in unsafe:
name = name.replace(char, "_")
return name.strip()
def parse_session_key(key: str) -> tuple[str, str]: def split_message(content: str, max_len: int = 2000) -> list[str]:
""" """
Parse a session key into channel and chat_id. Split content into chunks within max_len, preferring line breaks.
Args: Args:
key: Session key in format "channel:chat_id" content: The text content to split.
max_len: Maximum length per chunk (default 2000 for Discord compatibility).
Returns: Returns:
Tuple of (channel, chat_id) List of message chunks, each within max_len.
""" """
parts = key.split(":", 1) if not content:
if len(parts) != 2: return []
raise ValueError(f"Invalid session key: {key}") if len(content) <= max_len:
return parts[0], parts[1] return [content]
chunks: list[str] = []
while content:
if len(content) <= max_len:
chunks.append(content)
break
cut = content[:max_len]
# Try to break at newline first, then space, then hard break
pos = cut.rfind('\n')
if pos <= 0:
pos = cut.rfind(' ')
if pos <= 0:
pos = max_len
chunks.append(content[:pos])
content = content[pos:].lstrip()
return chunks
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files
try:
tpl = pkg_files("nanobot") / "templates"
except Exception:
return []
if not tpl.is_dir():
return []
added: list[str] = []
def _write(src, dest: Path):
if dest.exists():
return
dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
added.append(str(dest.relative_to(workspace)))
for item in tpl.iterdir():
if item.name.endswith(".md"):
_write(item, workspace / item.name)
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
_write(None, workspace / "memory" / "HISTORY.md")
(workspace / "skills").mkdir(exist_ok=True)
if added and not silent:
from rich.console import Console
for name in added:
Console().print(f" [dim]Created {name}[/dim]")
return added
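
A usage sketch for `split_message` above, exercising its break-preference order (newline, then space, then hard cut); the 2000-character default matches Discord's message limit:

```python
from nanobot.utils.helpers import split_message

text = "first paragraph\n" + "word " * 10
chunks = split_message(text, max_len=40)
assert chunks[0] == "first paragraph"     # broke at the newline, not mid-word
assert all(len(c) <= 40 for c in chunks)  # every chunk fits the limit
```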

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "nanobot-ai" name = "nanobot-ai"
version = "0.1.3.post7" version = "0.1.4.post3"
description = "A lightweight personal AI assistant framework" description = "A lightweight personal AI assistant framework"
requires-python = ">=3.11" requires-python = ">=3.11"
license = {text = "MIT"} license = {text = "MIT"}
@@ -17,36 +17,48 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"typer>=0.9.0", "typer>=0.20.0,<1.0.0",
"litellm>=1.0.0", "litellm>=1.81.5,<2.0.0",
"pydantic>=2.0.0", "pydantic>=2.12.0,<3.0.0",
"pydantic-settings>=2.0.0", "pydantic-settings>=2.12.0,<3.0.0",
"websockets>=12.0", "websockets>=16.0,<17.0",
"websocket-client>=1.6.0", "websocket-client>=1.9.0,<2.0.0",
"httpx[socks]>=0.25.0", "httpx>=0.28.0,<1.0.0",
"loguru>=0.7.0", "oauth-cli-kit>=0.1.3,<1.0.0",
"readability-lxml>=0.8.0", "loguru>=0.7.3,<1.0.0",
"rich>=13.0.0", "readability-lxml>=0.8.4,<1.0.0",
"croniter>=2.0.0", "rich>=14.0.0,<15.0.0",
"dingtalk-stream>=0.4.0", "croniter>=6.0.0,<7.0.0",
"python-telegram-bot[socks]>=21.0", "dingtalk-stream>=0.24.0,<1.0.0",
"lark-oapi>=1.0.0", "python-telegram-bot[socks]>=22.6,<23.0",
"socksio>=1.0.0", "lark-oapi>=1.5.0,<2.0.0",
"python-socketio>=5.11.0", "socksio>=1.0.0,<2.0.0",
"msgpack>=1.0.8", "python-socketio>=5.16.0,<6.0.0",
"slack-sdk>=3.26.0", "msgpack>=1.1.0,<2.0.0",
"qq-botpy>=1.0.0", "slack-sdk>=3.39.0,<4.0.0",
"python-socks[asyncio]>=2.4.0", "slackify-markdown>=0.2.0,<1.0.0",
"prompt-toolkit>=3.0.0", "qq-botpy>=1.2.0,<2.0.0",
"mcp>=1.0.0", "python-socks[asyncio]>=2.8.0,<3.0.0",
"json-repair>=0.30.0", "prompt-toolkit>=3.0.50,<4.0.0",
"mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0",
"openai>=2.8.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
matrix = [
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
]
dev = [ dev = [
"pytest>=7.0.0", "pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=0.21.0", "pytest-asyncio>=1.3.0,<2.0.0",
"ruff>=0.1.0", "ruff>=0.1.0",
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
] ]
[project.scripts] [project.scripts]
@@ -62,10 +74,11 @@ packages = ["nanobot"]
[tool.hatch.build.targets.wheel.sources] [tool.hatch.build.targets.wheel.sources]
"nanobot" = "nanobot" "nanobot" = "nanobot"
# Include non-Python files in skills # Include non-Python files in skills and templates
[tool.hatch.build] [tool.hatch.build]
include = [ include = [
"nanobot/**/*.py", "nanobot/**/*.py",
"nanobot/templates/**/*.md",
"nanobot/skills/**/*.md", "nanobot/skills/**/*.md",
"nanobot/skills/**/*.sh", "nanobot/skills/**/*.sh",
] ]

View File

@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse
def test_azure_openai_provider_init():
"""Test AzureOpenAIProvider initialization without deployment_name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
assert provider.api_key == "test-key"
assert provider.api_base == "https://test-resource.openai.azure.com/"
assert provider.default_model == "gpt-4o-deployment"
assert provider.api_version == "2024-10-21"
def test_azure_openai_provider_init_validation():
"""Test AzureOpenAIProvider initialization validation."""
# Missing api_key
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
AzureOpenAIProvider(api_key="", api_base="https://test.com")
# Missing api_base
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
AzureOpenAIProvider(api_key="test", api_base="")
def test_build_chat_url():
"""Test Azure OpenAI URL building with different deployment names."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test various deployment names
test_cases = [
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
]
for deployment_name, expected_url in test_cases:
url = provider._build_chat_url(deployment_name)
assert url == expected_url
def test_build_chat_url_api_base_without_slash():
"""Test URL building when api_base doesn't end with slash."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com", # No trailing slash
default_model="gpt-4o",
)
url = provider._build_chat_url("test-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
def test_build_headers():
"""Test Azure OpenAI header building with api-key authentication."""
provider = AzureOpenAIProvider(
api_key="test-api-key-123",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
headers = provider._build_headers()
assert headers["Content-Type"] == "application/json"
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
assert "x-session-affinity" in headers
def test_prepare_request_payload():
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert payload["temperature"] == 0.8
assert "tools" not in payload
# Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
"""Test Azure payload strips non-standard message keys before sending."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
messages = [
{
"role": "assistant",
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
"reasoning_content": "hidden chain-of-thought",
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
"extra_field": "should be removed",
},
]
payload = provider._prepare_request_payload("gpt-4o", messages)
assert payload["messages"] == [
{
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
},
{
"role": "tool",
"tool_call_id": "call_123",
"name": "x",
"content": "ok",
},
]
@pytest.mark.asyncio
async def test_chat_success():
"""Test successful chat request using model as deployment name."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
# Mock response data
mock_response_data = {
"choices": [{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 18,
"total_tokens": 30
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
# Test with specific model (deployment name)
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages, model="custom-deployment")
assert isinstance(result, LLMResponse)
assert result.content == "Hello! How can I help you today?"
assert result.finish_reason == "stop"
assert result.usage["prompt_tokens"] == 12
assert result.usage["completion_tokens"] == 18
assert result.usage["total_tokens"] == 30
# Verify URL was built with the provided model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
"""Test that chat uses default_model when no model is specified."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="default-deployment",
)
mock_response_data = {
"choices": [{
"message": {"content": "Response", "role": "assistant"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Test"}]
await provider.chat(messages) # No model specified
# Verify URL was built with default model as deployment name
call_args = mock_context.post.call_args
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
assert call_args[0][0] == expected_url
@pytest.mark.asyncio
async def test_chat_with_tool_calls():
"""Test chat request with tool calls in response."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Mock response with tool calls
mock_response_data = {
"choices": [{
"message": {
"content": None,
"role": "assistant",
"tool_calls": [{
"id": "call_12345",
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}'
}
}]
},
"finish_reason": "tool_calls"
}],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 15,
"total_tokens": 35
}
}
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = Mock(return_value=mock_response_data)
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "What's the weather?"}]
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
result = await provider.chat(messages, tools=tools, model="weather-model")
assert isinstance(result, LLMResponse)
assert result.content is None
assert result.finish_reason == "tool_calls"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
@pytest.mark.asyncio
async def test_chat_api_error():
"""Test chat request API error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_response = AsyncMock()
mock_response.status_code = 401
mock_response.text = "Invalid authentication credentials"
mock_context = AsyncMock()
mock_context.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Azure OpenAI API Error 401" in result.content
assert "Invalid authentication credentials" in result.content
assert result.finish_reason == "error"
@pytest.mark.asyncio
async def test_chat_connection_error():
"""Test chat request connection error handling."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
with patch("httpx.AsyncClient") as mock_client:
mock_context = AsyncMock()
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
mock_client.return_value.__aenter__.return_value = mock_context
messages = [{"role": "user", "content": "Hello"}]
result = await provider.chat(messages)
assert isinstance(result, LLMResponse)
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
assert result.finish_reason == "error"
def test_parse_response_malformed():
"""Test response parsing with malformed data."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o",
)
# Test with missing choices
malformed_response = {"usage": {"prompt_tokens": 10}}
result = provider._parse_response(malformed_response)
assert isinstance(result, LLMResponse)
assert "Error parsing Azure OpenAI response" in result.content
assert result.finish_reason == "error"
def test_get_default_model():
"""Test get_default_model method."""
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="my-custom-deployment",
)
assert provider.get_default_model() == "my-custom-deployment"
if __name__ == "__main__":
# Run basic tests
print("Running basic Azure OpenAI provider tests...")
# Test initialization
provider = AzureOpenAIProvider(
api_key="test-key",
api_base="https://test-resource.openai.azure.com",
default_model="gpt-4o-deployment",
)
print("✅ Provider initialization successful")
# Test URL building
url = provider._build_chat_url("my-deployment")
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
assert url == expected
print("✅ URL building works correctly")
# Test headers
headers = provider._build_headers()
assert headers["api-key"] == "test-key"
assert headers["Content-Type"] == "application/json"
print("✅ Header building works correctly")
# Test payload preparation
messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly")
print("✅ All basic tests passed! Updated test file is working correctly.")

View File

@@ -0,0 +1,25 @@
from types import SimpleNamespace
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel):
name = "dummy"
async def start(self) -> None:
return None
async def stop(self) -> None:
return None
async def send(self, msg: OutboundMessage) -> None:
return None
def test_is_allowed_requires_exact_match() -> None:
channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
assert channel.is_allowed("allow@email.com") is True
assert channel.is_allowed("attacker|allow@email.com") is False

View File

@@ -6,6 +6,10 @@ import pytest
from typer.testing import CliRunner from typer.testing import CliRunner
from nanobot.cli.commands import app from nanobot.cli.commands import app
from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
runner = CliRunner() runner = CliRunner()
@@ -90,3 +94,37 @@ def test_onboard_existing_workspace_safe_create(mock_paths):
assert "Created workspace" not in result.stdout assert "Created workspace" not in result.stdout
assert "Created AGENTS.md" in result.stdout assert "Created AGENTS.md" in result.stdout
assert (workspace_dir / "AGENTS.md").exists() assert (workspace_dir / "AGENTS.md").exists()
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
assert config.get_provider_name() == "github_copilot"
def test_config_matches_openai_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "openai-codex/gpt-5.1-codex"
assert config.get_provider_name() == "openai_codex"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
assert spec is not None
assert spec.name == "github_copilot"
def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex")
resolved = provider._resolve_model("github-copilot/gpt-5.3-codex")
assert resolved == "github_copilot/gpt-5.3-codex"
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"


@@ -1,5 +1,8 @@
"""Test session management with cache-friendly message handling.""" """Test session management with cache-friendly message handling."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from pathlib import Path
from nanobot.session.manager import Session, SessionManager
@@ -475,3 +478,344 @@ class TestEmptyAndBoundarySessions:
        expected_count = 60 - KEEP_COUNT - 10
        assert len(old_messages) == expected_count
        assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
"""Test that consolidation tasks are deduplicated and serialized."""
@pytest.mark.asyncio
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
"""Concurrent messages above memory_window spawn only one consolidation task."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
@pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
if archive_all:
return False
return True
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "failed" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count == 3, (
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
assert loop.sessions.get_or_create("cli:test").messages == []


@@ -0,0 +1,65 @@
"""Tests for cache-friendly prompt construction."""
from __future__ import annotations
from datetime import datetime as real_datetime
from pathlib import Path
import datetime as datetime_module
from nanobot.agent.context import ContextBuilder
class _FakeDatetime(real_datetime):
current = real_datetime(2026, 2, 24, 13, 59)
@classmethod
def now(cls, tz=None): # type: ignore[override]
return cls.current
def _make_workspace(tmp_path: Path) -> Path:
workspace = tmp_path / "workspace"
workspace.mkdir(parents=True)
return workspace
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
"""System prompt should not change just because wall clock minute changes."""
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
prompt1 = builder.build_system_prompt()
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
prompt2 = builder.build_system_prompt()
assert prompt1 == prompt2
def test_runtime_context_is_merged_into_user_message(tmp_path) -> None:
    """Runtime metadata should be merged into the final user message, not the system prompt."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
messages = builder.build_messages(
history=[],
current_message="Return exactly: OK",
channel="cli",
chat_id="direct",
)
assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"]
# Runtime context is now merged with user message into a single message
assert messages[-1]["role"] == "user"
user_content = messages[-1]["content"]
assert isinstance(user_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
assert "Current Time:" in user_content
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content


@@ -0,0 +1,61 @@
import asyncio
import pytest
from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
service.add_job(
name="tz typo",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
message="hello",
)
assert service.list_jobs(include_disabled=True) == []
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="tz ok",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
message="hello",
)
assert job.schedule.tz == "America/Vancouver"
assert job.state.next_run_at_ms is not None
@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
called: list[str] = []
async def on_job(job) -> None:
called.append(job.id)
service = CronService(store_path, on_job=on_job)
job = service.add_job(
name="external-disable",
schedule=CronSchedule(kind="every", every_ms=200),
message="hello",
)
await service.start()
try:
# Wait slightly to ensure file mtime is definitively different
await asyncio.sleep(0.05)
external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False)
assert updated is not None
assert updated.enabled is False
await asyncio.sleep(0.35)
assert called == []
finally:
service.stop()
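The timezone check presumably reduces to a stdlib zoneinfo lookup; a minimal sketch:

from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

def _validate_timezone(tz: str) -> None:
    # Surfaces typos such as "America/Vancovuer" before the job is persisted.
    try:
        ZoneInfo(tz)
    except ZoneInfoNotFoundError:
        raise ValueError(f"unknown timezone {tz!r}") from None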


@@ -0,0 +1,66 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.dingtalk import DingTalkChannel
from nanobot.config.schema import DingTalkConfig
class _FakeResponse:
def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
self.status_code = status_code
self._json_body = json_body or {}
self.text = "{}"
def json(self) -> dict:
return self._json_body
class _FakeHttp:
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url: str, json=None, headers=None):
self.calls.append({"url": url, "json": json, "headers": headers})
return _FakeResponse()
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
bus = MessageBus()
channel = DingTalkChannel(config, bus)
await channel._on_message(
"hello",
sender_id="user1",
sender_name="Alice",
conversation_type="2",
conversation_id="conv123",
)
msg = await bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
assert msg.metadata["conversation_type"] == "2"
@pytest.mark.asyncio
async def test_group_send_uses_group_messages_api() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _FakeHttp()
ok = await channel._send_batch_message(
"token",
"group:conv123",
"sampleMarkdown",
{"text": "hello", "title": "Nanobot Reply"},
)
assert ok is True
call = channel._http.calls[0]
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"


@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
@pytest.mark.asyncio
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
    """When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
    class FakeSMTP:
        def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
            self.sent_messages: list[EmailMessage] = []
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    cfg = _make_config()
    cfg.auto_reply_enabled = False
    channel = EmailChannel(cfg, MessageBus())
# Mark alice as someone who sent us an email (making this a "reply")
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
# Reply should be skipped (auto_reply_enabled=False)
    await channel.send(
        OutboundMessage(
            channel="email",
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    )
    assert fake_instances == []
# Reply with force_send=True should be sent
    await channel.send(
        OutboundMessage(
            channel="email",
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    assert len(fake_instances[0].sent_messages) == 1
@pytest.mark.asyncio
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.sent_messages: list[EmailMessage] = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def starttls(self, context=None):
return None
def login(self, _user: str, _pw: str):
return None
def send_message(self, msg: EmailMessage):
self.sent_messages.append(msg)
fake_instances: list[FakeSMTP] = []
def _smtp_factory(host: str, port: int, timeout: int = 30):
instance = FakeSMTP(host, port, timeout=timeout)
fake_instances.append(instance)
return instance
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
cfg = _make_config()
cfg.auto_reply_enabled = False
channel = EmailChannel(cfg, MessageBus())
# bob@example.com has never sent us an email (proactive send)
# This should be sent even with auto_reply_enabled=False
await channel.send(
OutboundMessage(
channel="email",
chat_id="bob@example.com",
content="Hello, this is a proactive email.",
)
)
assert len(fake_instances) == 1
assert len(fake_instances[0].sent_messages) == 1
sent = fake_instances[0].sent_messages[0]
assert sent["To"] == "bob@example.com"
@pytest.mark.asyncio
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
    class FakeSMTP:


@@ -0,0 +1,65 @@
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None:
payload = {
"post": {
"zh_cn": {
"title": "日报",
"content": [
[
{"tag": "text", "text": "完成"},
{"tag": "img", "image_key": "img_1"},
]
],
}
}
}
text, image_keys = _extract_post_content(payload)
assert text == "日报 完成"
assert image_keys == ["img_1"]
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
payload = {
"title": "Daily",
"content": [
[
{"tag": "text", "text": "report"},
{"tag": "img", "image_key": "img_a"},
{"tag": "img", "image_key": "img_b"},
]
],
}
text, image_keys = _extract_post_content(payload)
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
class Builder:
pass
builder = Builder()
same = FeishuChannel._register_optional_event(builder, "missing", object())
assert same is builder
def test_register_optional_event_calls_supported_method() -> None:
called = []
class Builder:
def register_event(self, handler):
called.append(handler)
return self
builder = Builder()
handler = object()
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
assert same is builder
assert called == [handler]


@@ -0,0 +1,104 @@
"""Tests for FeishuChannel._split_elements_by_table_limit.
Feishu cards reject messages that contain more than one table element
(API error 11310: card table number over limit). The helper splits a flat
list of card elements into groups so that each group contains at most one
table, allowing nanobot to send multiple cards instead of failing.
"""
from nanobot.channels.feishu import FeishuChannel
def _md(text: str) -> dict:
return {"tag": "markdown", "content": text}
def _table() -> dict:
return {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "v"}],
"page_size": 2,
}
split = FeishuChannel._split_elements_by_table_limit
def test_empty_list_returns_single_empty_group() -> None:
assert split([]) == [[]]
def test_no_tables_returns_single_group() -> None:
els = [_md("hello"), _md("world")]
result = split(els)
assert result == [els]
def test_single_table_stays_in_one_group() -> None:
els = [_md("intro"), _table(), _md("outro")]
result = split(els)
assert len(result) == 1
assert result[0] == els
def test_two_tables_split_into_two_groups() -> None:
# Use different row values so the two tables are not equal
t1 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
"rows": [{"c0": "table-one"}],
"page_size": 2,
}
t2 = {
"tag": "table",
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
"rows": [{"c0": "table-two"}],
"page_size": 2,
}
els = [_md("before"), t1, _md("between"), t2, _md("after")]
result = split(els)
assert len(result) == 2
# First group: text before table-1 + table-1
assert t1 in result[0]
assert t2 not in result[0]
# Second group: text between tables + table-2 + text after
assert t2 in result[1]
assert t1 not in result[1]
def test_three_tables_split_into_three_groups() -> None:
tables = [
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
for i in range(3)
]
els = tables[:]
result = split(els)
assert len(result) == 3
for i, group in enumerate(result):
assert tables[i] in group
def test_leading_markdown_stays_with_first_table() -> None:
intro = _md("intro")
t = _table()
result = split([intro, t])
assert len(result) == 1
assert result[0] == [intro, t]
def test_trailing_markdown_after_second_table() -> None:
t1, t2 = _table(), _table()
tail = _md("end")
result = split([t1, t2, tail])
assert len(result) == 2
assert result[1] == [t2, tail]
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
head = _md("head")
t1, t2 = _table(), _table()
result = split([head, t1, t2])
# head + t1 in group 0; t2 in group 1
assert result[0] == [head, t1]
assert result[1] == [t2]
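All of the grouping cases above are satisfied by a single pass that buffers text and attaches it to the next table; a sketch:

@staticmethod
def _split_elements_by_table_limit(elements: list[dict]) -> list[list[dict]]:
    # Text between two tables opens the next card together with the table
    # that follows it, so each group carries at most one table element.
    groups: list[list[dict]] = [[]]
    buffer: list[dict] = []
    group_has_table = False
    for element in elements:
        if element.get("tag") != "table":
            buffer.append(element)
            continue
        if group_has_table:
            groups.append([])
        groups[-1].extend(buffer)
        buffer = []
        groups[-1].append(element)
        group_has_table = True
    groups[-1].extend(buffer)  # trailing text stays with the last table
    return groups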


@@ -0,0 +1,117 @@
import asyncio
import pytest
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMResponse, ToolCallRequest
class DummyProvider:
def __init__(self, responses: list[LLMResponse]):
self._responses = list(responses)
async def chat(self, *args, **kwargs) -> LLMResponse:
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
provider = DummyProvider([])
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
interval_s=9999,
enabled=True,
)
await service.start()
first_task = service._task
await service.start()
assert service._task is first_task
service.stop()
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "skip"
assert tasks == ""
@pytest.mark.asyncio
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
)
])
called_with: list[str] = []
async def _on_execute(tasks: str) -> str:
called_with.append(tasks)
return "done"
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
result = await service.trigger_now()
assert result == "done"
assert called_with == ["check open tasks"]
@pytest.mark.asyncio
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "skip"},
)
],
)
])
async def _on_execute(tasks: str) -> str:
return tasks
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
assert await service.trigger_now() is None
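Both decision tests reduce to reading a single heartbeat tool call out of the response; a sketch (prompt assembly elided behind a hypothetical _build_prompt):

async def _decide(self, heartbeat_content: str) -> tuple[str, str]:
    # No tool call at all, or an explicit {"action": "skip"}, both mean skip.
    response = await self.provider.chat(self._build_prompt(heartbeat_content))
    for call in response.tool_calls:
        if call.name == "heartbeat":
            return call.arguments.get("action", "skip"), call.arguments.get("tasks", "")
    return "skip", ""

trigger_now() would then invoke on_execute only when the returned action is "run".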


@@ -0,0 +1,41 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500
return loop
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
skip=0,
)
assert session.messages == []
def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
loop = _mk_loop()
session = Session(key="test:image")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{
"role": "user",
"content": [
{"type": "text", "text": runtime},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
}],
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]

tests/test_matrix_channel.py: new file, 1318 lines (diff suppressed because it is too large)


@@ -0,0 +1,222 @@
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
Regression test for https://github.com/HKUDS/nanobot/issues/1042
When memory consolidation receives dict values instead of strings from the LLM
tool call response, it should serialize them to JSON instead of raising TypeError.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50):
"""Create a mock session with messages."""
session = MagicMock()
session.messages = [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update):
"""Create an LLMResponse with a save_memory tool call."""
return LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={
"history_entry": history_entry,
"memory_update": memory_update,
},
)
],
)
class TestMemoryConsolidationTypeHandling:
"""Test that consolidation handles various argument types correctly."""
@pytest.mark.asyncio
async def test_string_arguments_work(self, tmp_path: Path) -> None:
"""Normal case: LLM returns string arguments."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is True
assert store.history_file.exists()
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
assert "User likes testing." in store.memory_file.read_text()
@pytest.mark.asyncio
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is True
assert store.history_file.exists()
history_content = store.history_file.read_text()
parsed = json.loads(history_content.strip())
assert parsed["summary"] == "User discussed testing."
memory_content = store.memory_file.read_text()
parsed_mem = json.loads(memory_content)
assert "User likes testing" in parsed_mem["facts"]
@pytest.mark.asyncio
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
"""Some providers return arguments as a JSON string instead of parsed dict."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=json.dumps({
"history_entry": "[2026-01-01] User discussed testing.",
"memory_update": "# Memory\nUser likes testing.",
}),
)
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
"""When LLM doesn't use the save_memory tool, return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
session = _make_session(message_count=10)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is True
provider.chat.assert_not_called()
@pytest.mark.asyncio
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
"""Some providers return arguments as a list - extract first element if it's a dict."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[{
"history_entry": "[2026-01-01] User discussed testing.",
"memory_update": "# Memory\nUser likes testing.",
}],
)
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is True
assert "User discussed testing." in store.history_file.read_text()
assert "User likes testing." in store.memory_file.read_text()
@pytest.mark.asyncio
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
"""Empty list arguments should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=[],
)
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
@pytest.mark.asyncio
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
"""List with non-dict content should return False."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
response = LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments=["string", "content"],
)
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
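The shape tolerance these tests describe fits in two small normalizers; a sketch (helper names hypothetical):

import json

def _normalize_arguments(arguments) -> dict | None:
    # Providers disagree on the tool-call payload: accept a parsed dict,
    # a raw JSON string, or a list whose first element is a dict.
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            return None
    if isinstance(arguments, list):
        arguments = arguments[0] if arguments and isinstance(arguments[0], dict) else None
    return arguments if isinstance(arguments, dict) else None

def _as_text(value) -> str:
    # Issue #1042: dict values are serialized to JSON instead of being
    # written raw (which raised TypeError on file append).
    return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)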


@@ -0,0 +1,10 @@
import pytest
from nanobot.agent.tools.message import MessageTool
@pytest.mark.asyncio
async def test_message_tool_returns_error_when_no_target_context() -> None:
tool = MessageTool()
result = await tool.execute(content="test")
assert result == "Error: No target channel/chat specified"


@@ -0,0 +1,132 @@
"""Test message tool suppress logic for final replies."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
class TestMessageToolSuppressLogic:
"""Final reply suppressed only when message tool sends to the same target."""
@pytest.mark.asyncio
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
result = await loop._process_message(msg)
assert len(sent) == 1
assert result is None # suppressed
@pytest.mark.asyncio
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
result = await loop._process_message(msg)
assert len(sent) == 1
assert sent[0].channel == "email"
assert result is not None # not suppressed
assert result.channel == "feishu"
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
result = await loop._process_message(msg)
assert result is not None
assert "Hello" in result.content
    @pytest.mark.asyncio
    async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
calls = iter([
LLMResponse(
content="Visible<think>hidden</think>",
tool_calls=[tool_call],
reasoning_content="secret reasoning",
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
progress: list[tuple[str, bool]] = []
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [
("Visible", False),
('read_file("foo.txt")', True),
]
class TestMessageToolTurnTracking:
def test_sent_in_turn_tracks_same_target(self) -> None:
tool = MessageTool()
tool.set_context("feishu", "chat1")
assert not tool._sent_in_turn
tool._sent_in_turn = True
assert tool._sent_in_turn
def test_start_turn_resets(self) -> None:
tool = MessageTool()
tool._sent_in_turn = True
tool.start_turn()
assert not tool._sent_in_turn
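The suppress rule the first two tests exercise is target equality; a sketch (the per-turn target fields beyond _sent_in_turn are assumptions):

def _should_suppress_final_reply(self, inbound, message_tool) -> bool:
    # Suppress only when the message tool already delivered to the same
    # (channel, chat_id) this turn; a send to another target (e.g. email)
    # still produces a normal reply on the originating channel.
    return (
        message_tool._sent_in_turn
        and message_tool._sent_channel == inbound.channel
        and message_tool._sent_chat_id == inbound.chat_id
    )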

tests/test_qq_channel.py: new file, 66 lines

@@ -0,0 +1,66 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel
from nanobot.config.schema import QQConfig
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeClient:
def __init__(self) -> None:
self.api = _FakeApi()
@pytest.mark.asyncio
async def test_on_group_message_routes_to_group_chat_id() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["user1"]), MessageBus())
data = SimpleNamespace(
id="msg1",
content="hello",
group_openid="group123",
author=SimpleNamespace(member_openid="user1"),
)
await channel._on_message(data, is_group=True)
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group123"
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert not channel._client.api.c2c_calls

tests/test_task_cancel.py: new file, 167 lines

@@ -0,0 +1,167 @@
"""Tests for /stop task cancellation."""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
def _make_loop():
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
return loop, bus
class TestHandleStop:
@pytest.mark.asyncio
async def test_stop_no_active_task(self):
from nanobot.bus.events import InboundMessage
loop, bus = _make_loop()
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
await loop._handle_stop(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "No active task" in out.content
@pytest.mark.asyncio
async def test_stop_cancels_active_task(self):
from nanobot.bus.events import InboundMessage
loop, bus = _make_loop()
cancelled = asyncio.Event()
async def slow_task():
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
cancelled.set()
raise
task = asyncio.create_task(slow_task())
await asyncio.sleep(0)
loop._active_tasks["test:c1"] = [task]
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
await loop._handle_stop(msg)
assert cancelled.is_set()
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "stopped" in out.content.lower()
@pytest.mark.asyncio
async def test_stop_cancels_multiple_tasks(self):
from nanobot.bus.events import InboundMessage
loop, bus = _make_loop()
events = [asyncio.Event(), asyncio.Event()]
async def slow(idx):
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
events[idx].set()
raise
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
await asyncio.sleep(0)
loop._active_tasks["test:c1"] = tasks
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
await loop._handle_stop(msg)
assert all(e.is_set() for e in events)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "2 task" in out.content
class TestDispatch:
@pytest.mark.asyncio
async def test_dispatch_processes_and_publishes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
loop, bus = _make_loop()
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
loop._process_message = AsyncMock(
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
)
await loop._dispatch(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert out.content == "hi"
@pytest.mark.asyncio
async def test_processing_lock_serializes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
loop, bus = _make_loop()
order = []
async def mock_process(m, **kwargs):
order.append(f"start-{m.content}")
await asyncio.sleep(0.05)
order.append(f"end-{m.content}")
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
loop._process_message = mock_process
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
t1 = asyncio.create_task(loop._dispatch(msg1))
t2 = asyncio.create_task(loop._dispatch(msg2))
await asyncio.gather(t1, t2)
assert order == ["start-a", "end-a", "start-b", "end-b"]
class TestSubagentCancellation:
@pytest.mark.asyncio
async def test_cancel_by_session(self):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
cancelled = asyncio.Event()
async def slow():
try:
await asyncio.sleep(60)
except asyncio.CancelledError:
cancelled.set()
raise
task = asyncio.create_task(slow())
await asyncio.sleep(0)
mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"}
count = await mgr.cancel_by_session("test:c1")
assert count == 1
assert cancelled.is_set()
@pytest.mark.asyncio
async def test_cancel_by_session_no_tasks(self):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0


@@ -0,0 +1,184 @@
from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig
class _FakeHTTPXRequest:
instances: list["_FakeHTTPXRequest"] = []
def __init__(self, **kwargs) -> None:
self.kwargs = kwargs
self.__class__.instances.append(self)
class _FakeUpdater:
def __init__(self, on_start_polling) -> None:
self._on_start_polling = on_start_polling
async def start_polling(self, **kwargs) -> None:
self._on_start_polling()
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
async def get_me(self):
return SimpleNamespace(username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
class _FakeApp:
def __init__(self, on_start_polling) -> None:
self.bot = _FakeBot()
self.updater = _FakeUpdater(on_start_polling)
self.handlers = []
self.error_handlers = []
def add_error_handler(self, handler) -> None:
self.error_handlers.append(handler)
def add_handler(self, handler) -> None:
self.handlers.append(handler)
async def initialize(self) -> None:
pass
async def start(self) -> None:
pass
class _FakeBuilder:
def __init__(self, app: _FakeApp) -> None:
self.app = app
self.token_value = None
self.request_value = None
self.get_updates_request_value = None
def token(self, token: str):
self.token_value = token
return self
def request(self, request):
self.request_value = request
return self
def get_updates_request(self, request):
self.get_updates_request_value = request
return self
def proxy(self, _proxy):
raise AssertionError("builder.proxy should not be called when request is set")
def get_updates_proxy(self, _proxy):
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
def build(self):
return self.app
@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
config = TelegramConfig(
enabled=True,
token="123:abc",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
)
bus = MessageBus()
channel = TelegramChannel(config, bus)
app = _FakeApp(lambda: setattr(channel, "_running", False))
builder = _FakeBuilder(app)
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
monkeypatch.setattr(
"nanobot.channels.telegram.Application",
SimpleNamespace(builder=lambda: builder),
)
await channel.start()
assert len(_FakeHTTPXRequest.instances) == 1
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
assert builder.request_value is _FakeHTTPXRequest.instances[0]
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
chat_id=-100123,
message_thread_id=42,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
def test_get_extension_falls_back_to_original_filename() -> None:
channel = TelegramChannel(TelegramConfig(), MessageBus())
assert channel._get_extension("file", None, "report.pdf") == ".pdf"
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
assert channel.is_allowed("12345|carol") is True
assert channel.is_allowed("99999|alice") is True
assert channel.is_allowed("67890|bob") is True
def test_is_allowed_rejects_invalid_legacy_telegram_sender_shapes() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["alice"]), MessageBus())
assert channel.is_allowed("attacker|alice|extra") is False
assert channel.is_allowed("not-a-number|alice") is False
@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"_progress": True, "message_thread_id": 42},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
channel = TelegramChannel(config, MessageBus())
channel._app = _FakeApp(lambda: None)
channel._message_threads[("123", 10)] = 42
await channel.send(
OutboundMessage(
channel="telegram",
chat_id="123",
content="hello",
metadata={"message_id": 10},
)
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10


@@ -2,6 +2,7 @@ from typing import Any
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
class SampleTool(Tool):
@@ -86,3 +87,253 @@ async def test_registry_returns_validation_error() -> None:
    reg.register(SampleTool())
    result = await reg.execute("sample", {"query": "hi"})
    assert "Invalid parameters" in result
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
cmd = r"type C:\user\workspace\txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/bin/python" not in paths
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
cmd = "cat /tmp/data.txt > /tmp/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths
# --- cast_params tests ---
class CastTestTool(Tool):
"""Minimal tool for testing cast_params."""
def __init__(self, schema: dict[str, Any]) -> None:
self._schema = schema
@property
def name(self) -> str:
return "cast_test"
@property
def description(self) -> str:
return "test tool for casting"
@property
def parameters(self) -> dict[str, Any]:
return self._schema
async def execute(self, **kwargs: Any) -> str:
return "ok"
def test_cast_params_string_to_int() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "42"})
assert result["count"] == 42
assert isinstance(result["count"], int)
def test_cast_params_string_to_number() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "3.14"})
assert result["rate"] == 3.14
assert isinstance(result["rate"], float)
def test_cast_params_string_to_bool() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {"enabled": {"type": "boolean"}},
}
)
assert tool.cast_params({"enabled": "true"})["enabled"] is True
assert tool.cast_params({"enabled": "false"})["enabled"] is False
assert tool.cast_params({"enabled": "1"})["enabled"] is True
def test_cast_params_array_items() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"nums": {"type": "array", "items": {"type": "integer"}},
},
}
)
result = tool.cast_params({"nums": ["1", "2", "3"]})
assert result["nums"] == [1, 2, 3]
def test_cast_params_nested_object() -> None:
tool = CastTestTool(
{
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"port": {"type": "integer"},
"debug": {"type": "boolean"},
},
},
},
}
)
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
assert result["config"]["port"] == 8080
assert result["config"]["debug"] is True
def test_cast_params_bool_not_cast_to_int() -> None:
"""Booleans should not be silently cast to integers."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": True})
assert result["count"] is True
errors = tool.validate_params(result)
assert any("count should be integer" in e for e in errors)
def test_cast_params_preserves_empty_string() -> None:
"""Empty strings should be preserved for string type."""
tool = CastTestTool(
{
"type": "object",
"properties": {"name": {"type": "string"}},
}
)
result = tool.cast_params({"name": ""})
assert result["name"] == ""
def test_cast_params_bool_string_false() -> None:
"""Test that 'false', '0', 'no' strings convert to False."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
assert tool.cast_params({"flag": "false"})["flag"] is False
assert tool.cast_params({"flag": "False"})["flag"] is False
assert tool.cast_params({"flag": "0"})["flag"] is False
assert tool.cast_params({"flag": "no"})["flag"] is False
assert tool.cast_params({"flag": "NO"})["flag"] is False
def test_cast_params_bool_string_invalid() -> None:
"""Invalid boolean strings should not be cast."""
tool = CastTestTool(
{
"type": "object",
"properties": {"flag": {"type": "boolean"}},
}
)
# Invalid strings should be preserved (validation will catch them)
result = tool.cast_params({"flag": "random"})
assert result["flag"] == "random"
result = tool.cast_params({"flag": "maybe"})
assert result["flag"] == "maybe"
def test_cast_params_invalid_string_to_int() -> None:
"""Invalid strings should not be cast to integer."""
tool = CastTestTool(
{
"type": "object",
"properties": {"count": {"type": "integer"}},
}
)
result = tool.cast_params({"count": "abc"})
assert result["count"] == "abc" # Original value preserved
result = tool.cast_params({"count": "12.5.7"})
assert result["count"] == "12.5.7"
def test_cast_params_invalid_string_to_number() -> None:
"""Invalid strings should not be cast to number."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
result = tool.cast_params({"rate": "not_a_number"})
assert result["rate"] == "not_a_number"
def test_validate_params_bool_not_accepted_as_number() -> None:
"""Booleans should not pass number validation."""
tool = CastTestTool(
{
"type": "object",
"properties": {"rate": {"type": "number"}},
}
)
errors = tool.validate_params({"rate": False})
assert any("rate should be number" in e for e in errors)
def test_cast_params_none_values() -> None:
"""Test None handling for different types."""
tool = CastTestTool(
{
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
"items": {"type": "array"},
"config": {"type": "object"},
},
}
)
result = tool.cast_params(
{
"name": None,
"count": None,
"items": None,
"config": None,
}
)
# None should be preserved for all types
assert result["name"] is None
assert result["count"] is None
assert result["items"] is None
assert result["config"] is None
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
"""Single values should NOT be automatically wrapped into arrays."""
tool = CastTestTool(
{
"type": "object",
"properties": {"items": {"type": "array"}},
}
)
# Non-array values should be preserved (validation will catch them)
result = tool.cast_params({"items": 5})
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]

View File

@@ -1,51 +0,0 @@
# Agent Instructions
You are a helpful AI assistant. Be concise, accurate, and friendly.
## Guidelines
- Always explain what you're doing before taking actions
- Ask for clarification when the request is ambiguous
- Use tools to help accomplish tasks
- Remember important information in your memory files
## Tools Available
You have access to:
- File operations (read, write, edit, list)
- Shell commands (exec)
- Web access (search, fetch)
- Messaging (message)
- Background tasks (spawn)
## Memory
- `memory/MEMORY.md` — long-term facts (preferences, context, relationships)
- `memory/HISTORY.md` — append-only event log; search it with grep to recall past events
## Scheduled Reminders
When the user asks for a reminder at a specific time, use `exec` to run:
```
nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
```
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
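For example, for the `telegram:8281248569` session above, a filled-in command might be (name, message, and time are illustrative):
```
nanobot cron add --name "meds" --message "Time to take your medication" --at "2026-03-08T15:00:00" --deliver --to "8281248569" --channel "telegram"
```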
## Heartbeat Tasks
`HEARTBEAT.md` is checked every 30 minutes. You can manage periodic tasks by editing this file:
- **Add a task**: Use `edit_file` to append new tasks to `HEARTBEAT.md`
- **Remove a task**: Use `edit_file` to remove completed or obsolete tasks
- **Rewrite tasks**: Use `write_file` to completely rewrite the task list
Task format examples:
```
- [ ] Check calendar and remind of upcoming events
- [ ] Scan inbox for urgent emails
- [ ] Check weather forecast for today
```
When the user asks you to add a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time reminder. Keep the file small to minimize token usage.

View File

@@ -1,150 +0,0 @@
# Available Tools
This document describes the tools available to nanobot.
## File Operations
### read_file
Read the contents of a file.
```
read_file(path: str) -> str
```
### write_file
Write content to a file (creates parent directories if needed).
```
write_file(path: str, content: str) -> str
```
### edit_file
Edit a file by replacing specific text.
```
edit_file(path: str, old_text: str, new_text: str) -> str
```
### list_dir
List contents of a directory.
```
list_dir(path: str) -> str
```
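These operations compose naturally; a hypothetical sequence (paths and contents are illustrative) might be:
```python
# Inspect the memory directory, read the long-term memory file,
# then append a fact under its top heading.
list_dir(path="memory")
read_file(path="memory/MEMORY.md")
edit_file(
    path="memory/MEMORY.md",
    old_text="# Memory",
    new_text="# Memory\n- User prefers metric units",
)
```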
## Shell Execution
### exec
Execute a shell command and return output.
```
exec(command: str, working_dir: str | None = None) -> str
```
**Safety Notes:**
- Commands have a configurable timeout (default 60s)
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
- Output is truncated at 10,000 characters
- An optional `restrictToWorkspace` config can limit commands to workspace paths
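For instance, given the signature above, a call might look like this (the command and directory are illustrative):
```python
# Check disk usage from /tmp; output comes back as a string,
# truncated at 10,000 characters as noted above.
exec(command="df -h", working_dir="/tmp")
```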
## Web Access
### web_search
Search the web using Brave Search API.
```
web_search(query: str, count: int = 5) -> str
```
Returns search results with titles, URLs, and snippets. Requires `tools.web.search.apiKey` in config.
### web_fetch
Fetch and extract main content from a URL.
```
web_fetch(url: str, extractMode: str = "markdown", maxChars: int = 50000) -> str
```
**Notes:**
- Content is extracted using readability
- Supports markdown or plain text extraction
- Output is truncated at 50,000 characters by default
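For example (the URL is illustrative):
```python
# Fetch an article as plain text, capped at 2,000 characters.
web_fetch(url="https://example.com/article", extractMode="text", maxChars=2000)
```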
## Communication
### message
Send a message to the user (used internally).
```
message(content: str, channel: str | None = None, chat_id: str | None = None) -> str
```
## Background Tasks
### spawn
Spawn a subagent to handle a task in the background.
```
spawn(task: str, label: str | None = None) -> str
```
Use for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done.
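For example (task text and label are illustrative):
```python
# Hand off a long-running research task; the label is optional.
spawn(
    task="Collect this week's nanobot release notes and summarize them",
    label="release-digest",
)
```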
## Scheduled Reminders (Cron)
Use the `exec` tool to create scheduled reminders with `nanobot cron add`:
### Set a recurring reminder
```bash
# Every day at 9am
nanobot cron add --name "morning" --message "Good morning! ☀️" --cron "0 9 * * *"
# Every 2 hours
nanobot cron add --name "water" --message "Drink water! 💧" --every 7200
```
### Set a one-time reminder
```bash
# At a specific time (ISO format)
nanobot cron add --name "meeting" --message "Meeting starts now!" --at "2025-01-31T15:00:00"
```
### Manage reminders
```bash
nanobot cron list # List all jobs
nanobot cron remove <job_id> # Remove a job
```
## Heartbeat Task Management
The `HEARTBEAT.md` file in the workspace is checked every 30 minutes.
Use file operations to manage periodic tasks:
### Add a heartbeat task
```python
# Append a new task
edit_file(
path="HEARTBEAT.md",
old_text="## Example Tasks",
new_text="- [ ] New periodic task here\n\n## Example Tasks"
)
```
### Remove a heartbeat task
```python
# Remove a specific task
edit_file(
path="HEARTBEAT.md",
old_text="- [ ] Task to remove\n",
new_text=""
)
```
### Rewrite all tasks
```python
# Replace the entire file
write_file(
path="HEARTBEAT.md",
content="# Heartbeat Tasks\n\n- [ ] Task 1\n- [ ] Task 2\n"
)
```
---
## Adding Custom Tools
To add custom tools:
1. Create a class that extends `Tool` in `nanobot/agent/tools/`
2. Implement `name`, `description`, `parameters`, and `execute`
3. Register it in `AgentLoop._register_default_tools()`
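Putting the three steps together, a minimal sketch of a custom tool might look like this (the import path and the registration call are assumptions based on the steps above, not verified APIs):
```python
from typing import Any

from nanobot.agent.tools.base import Tool  # assumed module path


class EchoTool(Tool):
    """Trivial custom tool: echoes its input back."""

    @property
    def name(self) -> str:
        return "echo"

    @property
    def description(self) -> str:
        return "Echo the provided text back verbatim."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        }

    async def execute(self, **kwargs: Any) -> str:
        return str(kwargs.get("text", ""))


# Step 3, inside AgentLoop._register_default_tools() -- the exact
# registration call is an assumption:
# self.register_tool(EchoTool())
```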