diff --git a/.gitignore b/.gitignore index 36dbfc2..374875a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.worktrees/ .assets .env *.pyc @@ -14,8 +15,9 @@ docs/ *.pywz *.pyzz .venv/ +venv/ __pycache__/ poetry.lock .pytest_cache/ -tests/ botpy.log + diff --git a/README.md b/README.md index fed25c8..03f042a 100644 --- a/README.md +++ b/README.md @@ -12,24 +12,48 @@

-🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw) +🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw). -⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. +⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw. -📏 Real-time line count: **3,510 lines** (run `bash core_agent_lines.sh` to verify anytime) +📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime. ## 📢 News -- **2026-02-10** 🎉 Released v0.1.3.post6 with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431). +- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details. +- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes. +- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility. +- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync. +- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details. +- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes. +- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements. +- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details. +- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood. +- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode. + +
+Earlier news + +- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching. +- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details. +- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills. +- **2026-02-15** 🔑 nanobot now supports OpenAI Codex provider with OAuth login support. +- **2026-02-14** 🔌 nanobot now supports MCP! See [MCP section](#mcp-model-context-protocol) for details. +- **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details. +- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it! +- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support! +- **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431). - **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms! - **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers). -- **2026-02-07** 🚀 Released v0.1.3.post5 with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details. +- **2026-02-07** 🚀 Released **v0.1.3.post5** with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details. - **2026-02-06** ✨ Added Moonshot/Kimi provider, Discord integration, and enhanced security hardening! - **2026-02-05** ✨ Added Feishu channel, DeepSeek provider, and enhanced scheduled tasks support! -- **2026-02-04** 🚀 Released v0.1.3.post4 with multi-provider & Docker support! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post4) for details. +- **2026-02-04** 🚀 Released **v0.1.3.post4** with multi-provider & Docker support! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post4) for details. - **2026-02-03** ⚡ Integrated vLLM for local LLM support and improved natural language task scheduling! - **2026-02-02** 🎉 nanobot officially launched! Welcome to try 🐈 nanobot! +
+ ## Key Features of nanobot: 🪶 **Ultra-Lightweight**: Just ~4,000 lines of core agent code — 99% smaller than Clawdbot. @@ -105,17 +129,26 @@ nanobot onboard **2. Configure** (`~/.nanobot/config.json`) -For OpenRouter - recommended for global users: +Add or merge these **two parts** into your config (other options have defaults). + +*Set your API key* (e.g. OpenRouter, recommended for global users): ```json { "providers": { "openrouter": { "apiKey": "sk-or-v1-xxx" } - }, + } +} +``` + +*Set your model* (optionally pin a provider — defaults to auto-detection): +```json +{ "agents": { "defaults": { - "model": "anthropic/claude-opus-4-5" + "model": "anthropic/claude-opus-4-5", + "provider": "openrouter" } } } @@ -124,63 +157,26 @@ For OpenRouter - recommended for global users: **3. Chat** ```bash -nanobot agent -m "What is 2+2?" +nanobot agent ``` That's it! You have a working AI assistant in 2 minutes. -## 🖥️ Local Models (vLLM) - -Run nanobot with your own local models using vLLM or any OpenAI-compatible server. - -**1. Start your vLLM server** - -```bash -vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000 -``` - -**2. Configure** (`~/.nanobot/config.json`) - -```json -{ - "providers": { - "vllm": { - "apiKey": "dummy", - "apiBase": "http://localhost:8000/v1" - } - }, - "agents": { - "defaults": { - "model": "meta-llama/Llama-3.1-8B-Instruct" - } - } -} -``` - -**3. Chat** - -```bash -nanobot agent -m "Hello from my local LLM!" -``` - -> [!TIP] -> The `apiKey` can be any non-empty string for local servers that don't require authentication. - ## 💬 Chat Apps -Talk to your nanobot through Telegram, Discord, WhatsApp, Feishu, Mochat, DingTalk, Slack, Email, or QQ — anytime, anywhere. +Connect nanobot to your favorite chat platform. -| Channel | Setup | -|---------|-------| -| **Telegram** | Easy (just a token) | -| **Discord** | Easy (bot token + intents) | -| **WhatsApp** | Medium (scan QR) | -| **Feishu** | Medium (app credentials) | -| **Mochat** | Medium (claw token + websocket) | -| **DingTalk** | Medium (app credentials) | -| **Slack** | Medium (bot + app tokens) | -| **Email** | Medium (IMAP/SMTP credentials) | -| **QQ** | Easy (app credentials) | +| Channel | What you need | +|---------|---------------| +| **Telegram** | Bot token from @BotFather | +| **Discord** | Bot token + Message Content intent | +| **WhatsApp** | QR code scan | +| **Feishu** | App ID + App Secret | +| **Mochat** | Claw token (auto-setup available) | +| **DingTalk** | App Key + App Secret | +| **Slack** | Bot token + App-Level token | +| **Email** | IMAP/SMTP credentials | +| **QQ** | App ID + App Secret |
Telegram (Recommended) @@ -297,12 +293,18 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso "discord": { "enabled": true, "token": "YOUR_BOT_TOKEN", - "allowFrom": ["YOUR_USER_ID"] + "allowFrom": ["YOUR_USER_ID"], + "groupPolicy": "mention" } } } ``` +> `groupPolicy` controls how the bot responds in group channels: +> - `"mention"` (default) — Only respond when @mentioned +> - `"open"` — Respond to all messages +> DMs always respond when the sender is in `allowFrom`. + **5. Invite the bot** - OAuth2 → URL Generator - Scopes: `bot` @@ -317,6 +319,72 @@ nanobot gateway
+
+Matrix (Element) + +Install Matrix dependencies first: + +```bash +pip install nanobot-ai[matrix] +``` + +**1. Create/choose a Matrix account** + +- Create or reuse a Matrix account on your homeserver (for example `matrix.org`). +- Confirm you can log in with Element. + +**2. Get credentials** + +- You need: + - `userId` (example: `@nanobot:matrix.org`) + - `accessToken` + - `deviceId` (recommended so sync tokens can be restored across restarts) +- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings. + +**3. Configure** + +```json +{ + "channels": { + "matrix": { + "enabled": true, + "homeserver": "https://matrix.org", + "userId": "@nanobot:matrix.org", + "accessToken": "syt_xxx", + "deviceId": "NANOBOT01", + "e2eeEnabled": true, + "allowFrom": ["@your_user:matrix.org"], + "groupPolicy": "open", + "groupAllowFrom": [], + "allowRoomMentions": false, + "maxMediaBytes": 20971520 + } + } +} +``` + +> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts. + +| Option | Description | +|--------|-------------| +| `allowFrom` | User IDs allowed to interact. Empty = all senders. | +| `groupPolicy` | `open` (default), `mention`, or `allowlist`. | +| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). | +| `allowRoomMentions` | Accept `@room` mentions in mention mode. | +| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. | +| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. | + + + + +**4. Run** + +```bash +nanobot gateway +``` + +
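+If your client doesn't expose these values, you can fetch them from the homeserver login API mentioned above. A minimal sketch, assuming password login is enabled on the homeserver (user and password are placeholders; the endpoint is from the Matrix client-server spec):
+
+```python
+# Sketch: obtain accessToken / deviceId for the config above via /_matrix/client/v3/login.
+# Assumes password login is enabled; adjust homeserver, user, and password.
+import requests
+
+resp = requests.post(
+    "https://matrix.org/_matrix/client/v3/login",
+    json={
+        "type": "m.login.password",
+        "identifier": {"type": "m.id.user", "user": "nanobot"},
+        "password": "YOUR_PASSWORD",
+        "initial_device_display_name": "nanobot",
+    },
+    timeout=30,
+)
+resp.raise_for_status()
+data = resp.json()
+print(data["user_id"], data["access_token"], data["device_id"])  # copy these into config
+```
+
+Reuse the returned `device_id` as `deviceId` so encrypted session state survives restarts.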
+
WhatsApp @@ -352,6 +420,10 @@ nanobot channels login nanobot gateway ``` +> WhatsApp bridge updates are not applied automatically for existing installations. +> If you upgrade nanobot and need the latest WhatsApp bridge, run: +> `rm -rf ~/.nanobot/bridge && nanobot channels login` +
@@ -362,7 +434,7 @@ Uses **WebSocket** long connection — no public IP required. **1. Create a Feishu bot** - Visit [Feishu Open Platform](https://open.feishu.cn/app) - Create a new app → Enable **Bot** capability -- **Permissions**: Add `im:message` (send messages) +- **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages) - **Events**: Add `im.message.receive_v1` (receive messages) - Select **Long Connection** mode (requires running nanobot first to establish connection) - Get **App ID** and **App Secret** from "Credentials & Basic Info" @@ -379,14 +451,14 @@ Uses **WebSocket** long connection — no public IP required. "appSecret": "xxx", "encryptKey": "", "verificationToken": "", - "allowFrom": [] + "allowFrom": ["ou_YOUR_OPEN_ID"] } } } ``` > `encryptKey` and `verificationToken` are optional for Long Connection mode. -> `allowFrom`: Leave empty to allow all users, or add `["ou_xxx"]` to restrict access. +> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users. **3. Run** @@ -416,7 +488,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports **3. Configure** -> - `allowFrom`: Leave empty for public access, or add user openids to restrict. You can find openids in the nanobot logs when a user messages the bot. +> - `allowFrom`: Add your openid (find it in nanobot logs when you message the bot). Use `["*"]` for public access. > - For production: submit a review in the bot console and publish. See [QQ Bot Docs](https://bot.q.qq.com/wiki/) for the full publishing flow. ```json @@ -426,7 +498,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports "enabled": true, "appId": "YOUR_APP_ID", "secret": "YOUR_APP_SECRET", - "allowFrom": [] + "allowFrom": ["YOUR_OPENID"] } } } @@ -465,13 +537,13 @@ Uses **Stream Mode** — no public IP required. "enabled": true, "clientId": "YOUR_APP_KEY", "clientSecret": "YOUR_APP_SECRET", - "allowFrom": [] + "allowFrom": ["YOUR_STAFF_ID"] } } } ``` -> `allowFrom`: Leave empty to allow all users, or add `["staffId"]` to restrict access. +> `allowFrom`: Add your staff ID. Use `["*"]` to allow all users. **3. Run** @@ -506,6 +578,7 @@ Uses **Socket Mode** — no public URL required. "enabled": true, "botToken": "xoxb-...", "appToken": "xapp-...", + "allowFrom": ["YOUR_SLACK_USER_ID"], "groupPolicy": "mention" } } @@ -539,7 +612,7 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl **2. Configure** > - `consentGranted` must be `true` to allow mailbox access. This is a safety gate — set `false` to fully disable. -> - `allowFrom`: Leave empty to accept emails from anyone, or restrict to specific senders. +> - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone. > - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly. > - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies. @@ -594,21 +667,121 @@ Config file: `~/.nanobot/config.json` > - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. 
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. +> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config. +> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config. | Provider | Purpose | Get API Key | |----------|---------|-------------| +| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | -| `minimax` | LLM (MiniMax direct) | [platform.minimax.io](https://platform.minimax.io) | +| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | +| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | +| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `vllm` | LLM (local, any OpenAI-compatible server) | — | +| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` | +| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` | + +
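+The coding-plan notes above all follow the same pattern: set `apiBase` on the matching provider entry. A sketch for zhipu (placeholder key; substitute the volcengine or dashscope URLs from the notes as needed):
+
+```json
+{
+  "providers": {
+    "zhipu": {
+      "apiKey": "YOUR_ZHIPU_KEY",
+      "apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"
+    }
+  }
+}
+```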
+OpenAI Codex (OAuth) + +Codex uses OAuth instead of API keys. Requires a ChatGPT Plus or Pro account. + +**1. Login:** +```bash +nanobot provider login openai-codex +``` + +**2. Set model** (merge into `~/.nanobot/config.json`): +```json +{ + "agents": { + "defaults": { + "model": "openai-codex/gpt-5.1-codex" + } + } +} +``` + +**3. Chat:** +```bash +nanobot agent -m "Hello!" +``` + +> Docker users: use `docker run -it` for interactive OAuth login. + +
+ +
+Custom Provider (Any OpenAI-compatible API) + +Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is. + +```json +{ + "providers": { + "custom": { + "apiKey": "your-api-key", + "apiBase": "https://api.your-provider.com/v1" + } + }, + "agents": { + "defaults": { + "model": "your-model-name" + } + } +} +``` + +> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`). + +
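+Before wiring a custom endpoint into nanobot, it helps to confirm the server actually speaks the OpenAI protocol. A minimal sketch, assuming the standard `/models` and `/chat/completions` routes (URL, key, and model name are placeholders matching the config above):
+
+```python
+# Sketch: smoke-test an OpenAI-compatible endpoint before pointing nanobot at it.
+import requests
+
+API_BASE = "https://api.your-provider.com/v1"
+HEADERS = {"Authorization": "Bearer your-api-key"}
+
+# 1. The server should list its models
+models = requests.get(f"{API_BASE}/models", headers=HEADERS, timeout=30)
+models.raise_for_status()
+print([m["id"] for m in models.json()["data"]])
+
+# 2. A chat completion should round-trip
+chat = requests.post(
+    f"{API_BASE}/chat/completions",
+    headers=HEADERS,
+    json={"model": "your-model-name", "messages": [{"role": "user", "content": "ping"}]},
+    timeout=60,
+)
+chat.raise_for_status()
+print(chat.json()["choices"][0]["message"]["content"])
+```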
+ +
+vLLM (local / OpenAI-compatible) + +Run your own model with vLLM or any OpenAI-compatible server, then add to config: + +**1. Start the server** (example): +```bash +vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000 +``` + +**2. Add to config** (partial — merge into `~/.nanobot/config.json`): + +*Provider (key can be any non-empty string for local):* +```json +{ + "providers": { + "vllm": { + "apiKey": "dummy", + "apiBase": "http://localhost:8000/v1" + } + } +} +``` + +*Model:* +```json +{ + "agents": { + "defaults": { + "model": "meta-llama/Llama-3.1-8B-Instruct" + } + } +} +``` + +
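+To check the server with the exact values from the config above, here is a minimal sketch using the official `openai` client (the dummy key is fine for local servers, as noted):
+
+```python
+# Sketch: smoke-test the local vLLM server; base_url and api_key mirror the provider config.
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
+reply = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",  # must match the model being served
+    messages=[{"role": "user", "content": "Say hello in one word."}],
+)
+print(reply.choices[0].message.content)
+```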
Adding a New Provider (Developer Guide) @@ -655,16 +828,101 @@ That's it! Environment variables, model prefixing, config matching, and `nanobot
+### MCP (Model Context Protocol) + +> [!TIP] +> The config format is compatible with Claude Desktop / Cursor. You can copy MCP server configs directly from any MCP server's README. + +nanobot supports [MCP](https://modelcontextprotocol.io/) — connect external tool servers and use them as native agent tools. + +Add MCP servers to your `config.json`: + +```json +{ + "tools": { + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"] + }, + "my-remote-mcp": { + "url": "https://example.com/mcp/", + "headers": { + "Authorization": "Bearer xxxxx" + } + } + } + } +} +``` + +Two transport modes are supported: + +| Mode | Config | Example | +|------|--------|---------| +| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` | +| **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) | + +Use `toolTimeout` to override the default 30s per-call timeout for slow servers: + +```json +{ + "tools": { + "mcpServers": { + "my-slow-server": { + "url": "https://example.com/mcp/", + "toolTimeout": 120 + } + } + } +} +``` + +MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed. + + + + ### Security +> [!TIP] > For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent. +> **Behavior change after `v0.1.4.post3` (including source builds):** In `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions, **empty `allowFrom` denies all access by default**. To allow all senders, set `"allowFrom": ["*"]`. | Option | Default | Description | |--------|---------|-------------| | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. | +| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). | -| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. | +| `channels.*.allowFrom` | `[]` | Whitelist of user IDs. Empty = deny all in newer versions (allow all in `v0.1.4.post3` and earlier); set `["*"]` to allow everyone. | +## Multiple Instances + +Run multiple nanobot instances simultaneously, each with its own workspace and configuration. + +```bash +# Instance A - Telegram bot +nanobot gateway -w ~/.nanobot/botA -p 18791 + +# Instance B - Discord bot +nanobot gateway -w ~/.nanobot/botB -p 18792 + +# Instance C - Using custom config file +nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793 +``` + +| Option | Short | Description | +|--------|-------|-------------| +| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) | +| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) | +| `--port` | `-p` | Gateway port (default: `18790`) | + +Each instance has its own: +- Workspace directory (MEMORY.md, HEARTBEAT.md, session files) +- Cron jobs storage (`workspace/cron/jobs.json`) +- Configuration (if using `--config`) + + ## CLI Reference | Command | Description |
Environment variables, model prefixing, config matching, and `nanobot | `nanobot agent --logs` | Show runtime logs during chat | | `nanobot gateway` | Start the gateway | | `nanobot status` | Show status | +| `nanobot provider login openai-codex` | OAuth login for providers | | `nanobot channels login` | Link WhatsApp (scan QR) | | `nanobot channels status` | Show channel status | Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`.
-Scheduled Tasks (Cron) +Heartbeat (Periodic Tasks) -```bash -# Add a job -nanobot cron add --name "daily" --message "Good morning!" --cron "0 9 * * *" -nanobot cron add --name "hourly" --message "Check status" --every 3600 +The gateway wakes up every 30 minutes and checks `HEARTBEAT.md` in your workspace (`~/.nanobot/workspace/HEARTBEAT.md`). If the file has tasks, the agent executes them and delivers results to your most recently active chat channel. -# List jobs -nanobot cron list +**Setup:** edit `~/.nanobot/workspace/HEARTBEAT.md` (created automatically by `nanobot onboard`): -# Remove a job -nanobot cron remove +```markdown +## Periodic Tasks + +- [ ] Check weather forecast and send a summary +- [ ] Scan inbox for urgent emails ``` +The agent can also manage this file itself — ask it to "add a periodic task" and it will update `HEARTBEAT.md` for you. + +> **Note:** The gateway must be running (`nanobot gateway`) and you must have chatted with the bot at least once so it knows which channel to deliver to. +
## 🐳 Docker @@ -703,7 +965,21 @@ nanobot cron remove > [!TIP] > The `-v ~/.nanobot:/root/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts. -Build and run nanobot in a container: +### Docker Compose + +```bash +docker compose run --rm nanobot-cli onboard # first-time setup +vim ~/.nanobot/config.json # add API keys +docker compose up -d nanobot-gateway # start gateway +``` + +```bash +docker compose run --rm nanobot-cli agent -m "Hello!" # run CLI +docker compose logs -f nanobot-gateway # view logs +docker compose down # stop +``` + +### Docker ```bash # Build the image @@ -723,6 +999,59 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!" docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status ``` +## 🐧 Linux Service + +Run the gateway as a systemd user service so it starts automatically and restarts on failure. + +**1. Find the nanobot binary path:** + +```bash +which nanobot # e.g. /home/user/.local/bin/nanobot +``` + +**2. Create the service file** at `~/.config/systemd/user/nanobot-gateway.service` (replace `ExecStart` path if needed): + +```ini +[Unit] +Description=Nanobot Gateway +After=network.target + +[Service] +Type=simple +ExecStart=%h/.local/bin/nanobot gateway +Restart=always +RestartSec=10 +NoNewPrivileges=yes +ProtectSystem=strict +ReadWritePaths=%h + +[Install] +WantedBy=default.target +``` + +**3. Enable and start:** + +```bash +systemctl --user daemon-reload +systemctl --user enable --now nanobot-gateway +``` + +**Common operations:** + +```bash +systemctl --user status nanobot-gateway # check status +systemctl --user restart nanobot-gateway # restart after config changes +journalctl --user -u nanobot-gateway -f # follow logs +``` + +If you edit the `.service` file itself, run `systemctl --user daemon-reload` before restarting. + +> **Note:** User services only run while you are logged in. To keep the gateway running after logout, enable lingering: +> +> ```bash +> loginctl enable-linger $USER +> ``` + ## 📁 Project Structure ``` @@ -751,7 +1080,6 @@ PRs welcome! The codebase is intentionally small and readable. 🤗 **Roadmap** — Pick an item and [open a PR](https://github.com/HKUDS/nanobot/pulls)! -- [x] **Voice Transcription** — Support for Groq Whisper (Issue #13) - [ ] **Multi-modal** — See and hear (images, voice, video) - [ ] **Long-term memory** — Never forget important context - [ ] **Better reasoning** — Multi-step planning and reflection diff --git a/SECURITY.md b/SECURITY.md index ac15ba4..af4da71 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -5,7 +5,7 @@ If you discover a security vulnerability in nanobot, please report it by: 1. **DO NOT** open a public GitHub issue -2. Create a private security advisory on GitHub or contact the repository maintainers +2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com) 3. Include: - Description of the vulnerability - Steps to reproduce @@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json ``` **Security Notes:** -- Empty `allowFrom` list will **ALLOW ALL** users (open by default for personal use) +- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone. 
- Get your Telegram user ID from `@userinfobot` - Use full phone numbers with country code for WhatsApp - Review access logs regularly for unauthorized access attempts @@ -95,8 +95,8 @@ File operations have path traversal protection, but: - Consider using a firewall to restrict outbound connections if needed **WhatsApp Bridge:** -- The bridge runs on `localhost:3001` by default -- If exposing to network, use proper authentication and TLS +- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network) +- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js - Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700) ### 6. Dependency Security @@ -212,9 +212,8 @@ If you suspect a security breach: - Input length limits on HTTP requests ✅ **Authentication** -- Allow-list based access control +- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all) - Failed authentication attempt logging -- Open by default (configure allowFrom for production use) ✅ **Resource Protection** - Command execution timeouts (60s default) @@ -224,7 +223,7 @@ If you suspect a security breach: ✅ **Secure Communication** - HTTPS for all external API calls - TLS for Telegram API -- WebSocket security for WhatsApp bridge +- WhatsApp bridge: localhost-only binding + optional token auth ## Known Limitations diff --git a/bridge/src/index.ts b/bridge/src/index.ts index 8db63ef..e8f3db9 100644 --- a/bridge/src/index.ts +++ b/bridge/src/index.ts @@ -25,11 +25,12 @@ import { join } from 'path'; const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10); const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth'); +const TOKEN = process.env.BRIDGE_TOKEN || undefined; console.log('🐈 nanobot WhatsApp Bridge'); console.log('========================\n'); -const server = new BridgeServer(PORT, AUTH_DIR); +const server = new BridgeServer(PORT, AUTH_DIR, TOKEN); // Handle graceful shutdown process.on('SIGINT', async () => { diff --git a/bridge/src/server.ts b/bridge/src/server.ts index c6fd599..7d48f5e 100644 --- a/bridge/src/server.ts +++ b/bridge/src/server.ts @@ -1,5 +1,6 @@ /** * WebSocket server for Python-Node.js bridge communication. + * Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth. 
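+ * Auth handshake (when BRIDGE_TOKEN is set): the client's first frame must be
+ * {"type": "auth", "token": "<BRIDGE_TOKEN>"} within 5 seconds; otherwise the
+ * socket is closed (4001 on auth timeout, 4003 on a bad token); see the connection handler below.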
*/ import { WebSocketServer, WebSocket } from 'ws'; @@ -21,12 +22,13 @@ export class BridgeServer { private wa: WhatsAppClient | null = null; private clients: Set<WebSocket> = new Set(); - constructor(private port: number, private authDir: string) {} + constructor(private port: number, private authDir: string, private token?: string) {} async start(): Promise<void> { - // Create WebSocket server - this.wss = new WebSocketServer({ port: this.port }); - console.log(`🌉 Bridge server listening on ws://localhost:${this.port}`); + // Bind to localhost only — never expose to external network + this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port }); + console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`); + if (this.token) console.log('🔒 Token authentication enabled'); // Initialize WhatsApp client this.wa = new WhatsAppClient({ @@ -38,35 +40,58 @@ export class BridgeServer { // Handle WebSocket connections this.wss.on('connection', (ws) => { - console.log('🔗 Python client connected'); - this.clients.add(ws); - - ws.on('message', async (data) => { - try { - const cmd = JSON.parse(data.toString()) as SendCommand; - await this.handleCommand(cmd); - ws.send(JSON.stringify({ type: 'sent', to: cmd.to })); - } catch (error) { - console.error('Error handling command:', error); - ws.send(JSON.stringify({ type: 'error', error: String(error) })); - } - }); - - ws.on('close', () => { - console.log('🔌 Python client disconnected'); - this.clients.delete(ws); - }); - - ws.on('error', (error) => { - console.error('WebSocket error:', error); - this.clients.delete(ws); - }); + if (this.token) { + // Require auth handshake as first message + const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000); + ws.once('message', (data) => { + clearTimeout(timeout); + try { + const msg = JSON.parse(data.toString()); + if (msg.type === 'auth' && msg.token === this.token) { + console.log('🔗 Python client authenticated'); + this.setupClient(ws); + } else { + ws.close(4003, 'Invalid token'); + } + } catch { + ws.close(4003, 'Invalid auth message'); + } + }); + } else { + console.log('🔗 Python client connected'); + this.setupClient(ws); + } }); // Connect to WhatsApp await this.wa.connect(); } + private setupClient(ws: WebSocket): void { + this.clients.add(ws); + + ws.on('message', async (data) => { + try { + const cmd = JSON.parse(data.toString()) as SendCommand; + await this.handleCommand(cmd); + ws.send(JSON.stringify({ type: 'sent', to: cmd.to })); + } catch (error) { + console.error('Error handling command:', error); + ws.send(JSON.stringify({ type: 'error', error: String(error) })); + } + }); + + ws.on('close', () => { + console.log('🔌 Python client disconnected'); + this.clients.delete(ws); + }); + + ws.on('error', (error) => { + console.error('WebSocket error:', error); + this.clients.delete(ws); + }); + } + private async handleCommand(cmd: SendCommand): Promise<void> { if (cmd.type === 'send' && this.wa) { await this.wa.sendMessage(cmd.to, cmd.text); diff --git a/bridge/src/whatsapp.ts b/bridge/src/whatsapp.ts index 069d72b..b91bacc 100644 --- a/bridge/src/whatsapp.ts +++ b/bridge/src/whatsapp.ts @@ -9,11 +9,17 @@ import makeWASocket, { useMultiFileAuthState, fetchLatestBaileysVersion, makeCacheableSignalKeyStore, + downloadMediaMessage, + extractMessageContent as baileysExtractMessageContent, } from '@whiskeysockets/baileys'; import { Boom } from '@hapi/boom'; import qrcode from 'qrcode-terminal'; import pino from 'pino'; +import { writeFile, mkdir } from 'fs/promises'; +import { join } from
'path'; +import { homedir } from 'os'; +import { randomBytes } from 'crypto'; const VERSION = '0.1.0'; @@ -24,6 +30,7 @@ export interface InboundMessage { content: string; timestamp: number; isGroup: boolean; + media?: string[]; } export interface WhatsAppClientOptions { @@ -110,14 +117,33 @@ export class WhatsAppClient { if (type !== 'notify') return; for (const msg of messages) { - // Skip own messages if (msg.key.fromMe) continue; - - // Skip status updates if (msg.key.remoteJid === 'status@broadcast') continue; - const content = this.extractMessageContent(msg); - if (!content) continue; + const unwrapped = baileysExtractMessageContent(msg.message); + if (!unwrapped) continue; + + const content = this.getTextContent(unwrapped); + let fallbackContent: string | null = null; + const mediaPaths: string[] = []; + + if (unwrapped.imageMessage) { + fallbackContent = '[Image]'; + const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined); + if (path) mediaPaths.push(path); + } else if (unwrapped.documentMessage) { + fallbackContent = '[Document]'; + const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined, + unwrapped.documentMessage.fileName ?? undefined); + if (path) mediaPaths.push(path); + } else if (unwrapped.videoMessage) { + fallbackContent = '[Video]'; + const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined); + if (path) mediaPaths.push(path); + } + + const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || ''; + if (!finalContent && mediaPaths.length === 0) continue; const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false; @@ -125,18 +151,45 @@ export class WhatsAppClient { id: msg.key.id || '', sender: msg.key.remoteJid || '', pn: msg.key.remoteJidAlt || '', - content, + content: finalContent, timestamp: msg.messageTimestamp as number, isGroup, + ...(mediaPaths.length > 0 ? { media: mediaPaths } : {}), }); } }); } - private extractMessageContent(msg: any): string | null { - const message = msg.message; - if (!message) return null; + private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> { + try { + const mediaDir = join(homedir(), '.nanobot', 'media'); + await mkdir(mediaDir, { recursive: true }); + const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer; + + let outFilename: string; + if (fileName) { + // Documents have a filename — use it with a unique prefix to avoid collisions + const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`; + outFilename = prefix + fileName; + } else { + const mime = mimetype || 'application/octet-stream'; + // Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf") + const ext = '.'
+ (mime.split('/').pop()?.split(';')[0] || 'bin'); + outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`; + } + + const filepath = join(mediaDir, outFilename); + await writeFile(filepath, buffer); + + return filepath; + } catch (err) { + console.error('Failed to download media:', err); + return null; + } + } + + private getTextContent(message: any): string | null { // Text message if (message.conversation) { return message.conversation; @@ -147,19 +200,19 @@ export class WhatsAppClient { return message.extendedTextMessage.text; } - // Image with caption - if (message.imageMessage?.caption) { - return `[Image] ${message.imageMessage.caption}`; + // Image with optional caption + if (message.imageMessage) { + return message.imageMessage.caption || ''; } - // Video with caption - if (message.videoMessage?.caption) { - return `[Video] ${message.videoMessage.caption}`; + // Video with optional caption + if (message.videoMessage) { + return message.videoMessage.caption || ''; } - // Document with caption - if (message.documentMessage?.caption) { - return `[Document] ${message.documentMessage.caption}`; + // Document with optional caption + if (message.documentMessage) { + return message.documentMessage.caption || ''; } // Voice/Audio message diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..5c27f81 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,31 @@ +x-common-config: &common-config + build: + context: . + dockerfile: Dockerfile + volumes: + - ~/.nanobot:/root/.nanobot + +services: + nanobot-gateway: + container_name: nanobot-gateway + <<: *common-config + command: ["gateway"] + restart: unless-stopped + ports: + - 18790:18790 + deploy: + resources: + limits: + cpus: '1' + memory: 1G + reservations: + cpus: '0.25' + memory: 256M + + nanobot-cli: + <<: *common-config + profiles: + - cli + command: ["status"] + stdin_open: true + tty: true diff --git a/nanobot/__init__.py b/nanobot/__init__.py index ee0445b..4dba5f4 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -2,5 +2,5 @@ nanobot - A lightweight AI agent framework """ -__version__ = "0.1.0" +__version__ = "0.1.4.post3" __logo__ = "🐈" diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index c3fc97b..f9ba8b8 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -1,7 +1,7 @@ """Agent core module.""" -from nanobot.agent.loop import AgentLoop from nanobot.agent.context import ContextBuilder +from nanobot.agent.loop import AgentLoop from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index d807854..27511fa 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -3,62 +3,45 @@ import base64 import mimetypes import platform +import time +from datetime import datetime from pathlib import Path from typing import Any from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader +from nanobot.utils.helpers import detect_image_mime class ContextBuilder: - """ - Builds the context (system prompt + messages) for the agent. - - Assembles bootstrap files, memory, skills, and conversation history - into a coherent prompt for the LLM. 
- """ - + """Builds the context (system prompt + messages) for the agent.""" + BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] - + _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" + def __init__(self, workspace: Path): self.workspace = workspace self.memory = MemoryStore(workspace) self.skills = SkillsLoader(workspace) - + def build_system_prompt(self, skill_names: list[str] | None = None) -> str: - """ - Build the system prompt from bootstrap files, memory, and skills. - - Args: - skill_names: Optional list of skills to include. - - Returns: - Complete system prompt. - """ - parts = [] - - # Core identity - parts.append(self._get_identity()) - - # Bootstrap files + """Build the system prompt from identity, bootstrap files, memory, and skills.""" + parts = [self._get_identity()] + bootstrap = self._load_bootstrap_files() if bootstrap: parts.append(bootstrap) - - # Memory context + memory = self.memory.get_memory_context() if memory: parts.append(f"# Memory\n\n{memory}") - - # Skills - progressive loading - # 1. Always-loaded skills: include full content + always_skills = self.skills.get_always_skills() if always_skills: always_content = self.skills.load_skills_for_context(always_skills) if always_content: parts.append(f"# Active Skills\n\n{always_content}") - - # 2. Available skills: only show summary (agent uses read_file to load) + skills_summary = self.skills.build_skills_summary() if skills_summary: parts.append(f"""# Skills @@ -67,57 +50,59 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md Skills with available="false" need dependencies installed first - you can try installing them with apt/brew. {skills_summary}""") - + return "\n\n---\n\n".join(parts) - + def _get_identity(self) -> str: """Get the core identity section.""" - from datetime import datetime - now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") workspace_path = str(self.workspace.expanduser().resolve()) system = platform.system() runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" - + return f"""# nanobot 🐈 -You are nanobot, a helpful AI assistant. You have access to tools that allow you to: -- Read, write, and edit files -- Execute shell commands -- Search the web and fetch web pages -- Send messages to users on chat channels -- Spawn subagents for complex background tasks - -## Current Time -{now} +You are nanobot, a helpful AI assistant. ## Runtime {runtime} ## Workspace Your workspace is at: {workspace_path} -- Memory files: {workspace_path}/memory/MEMORY.md -- Daily notes: {workspace_path}/memory/YYYY-MM-DD.md +- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) +- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md -IMPORTANT: When responding to direct questions or conversations, reply directly with your text response. -Only use the 'message' tool when you need to send a message to a specific chat channel (like WhatsApp). -For normal conversation, just respond with text - do not call the message tool. +## nanobot Guidelines +- State intent before tool calls, but NEVER predict or claim results before receiving them. +- Before modifying a file, read it first. Do not assume files or directories exist. +- After writing or editing a file, re-read it if accuracy matters. 
+- If a tool call fails, analyze the error before retrying with a different approach. +- Ask for clarification when the request is ambiguous. + +Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.""" + + @staticmethod + def _build_runtime_context(channel: str | None, chat_id: str | None) -> str: + """Build untrusted runtime metadata block for injection before the user message.""" + now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") + tz = time.strftime("%Z") or "UTC" + lines = [f"Current Time: {now} ({tz})"] + if channel and chat_id: + lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] + return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) -Always be helpful, accurate, and concise. When using tools, explain what you're doing. -When remembering something, write to {workspace_path}/memory/MEMORY.md""" - def _load_bootstrap_files(self) -> str: """Load all bootstrap files from workspace.""" parts = [] - + for filename in self.BOOTSTRAP_FILES: file_path = self.workspace / filename if file_path.exists(): content = file_path.read_text(encoding="utf-8") parts.append(f"## {filename}\n\n{content}") - + return "\n\n".join(parts) if parts else "" - + def build_messages( self, history: list[dict[str, Any]], @@ -127,109 +112,67 @@ When remembering something, write to {workspace_path}/memory/MEMORY.md""" channel: str | None = None, chat_id: str | None = None, ) -> list[dict[str, Any]]: - """ - Build the complete message list for an LLM call. - - Args: - history: Previous conversation messages. - current_message: The new user message. - skill_names: Optional skills to include. - media: Optional list of local file paths for images/media. - channel: Current channel (telegram, feishu, etc.). - chat_id: Current chat/user ID. - - Returns: - List of messages including system prompt. - """ - messages = [] - - # System prompt - system_prompt = self.build_system_prompt(skill_names) - if channel and chat_id: - system_prompt += f"\n\n## Current Session\nChannel: {channel}\nChat ID: {chat_id}" - messages.append({"role": "system", "content": system_prompt}) - - # History - messages.extend(history) - - # Current message (with optional image attachments) + """Build the complete message list for an LLM call.""" + runtime_ctx = self._build_runtime_context(channel, chat_id) user_content = self._build_user_content(current_message, media) - messages.append({"role": "user", "content": user_content}) - return messages + # Merge runtime context and user content into a single user message + # to avoid consecutive same-role messages that some providers reject. 
+ if isinstance(user_content, str): + merged = f"{runtime_ctx}\n\n{user_content}" + else: + merged = [{"type": "text", "text": runtime_ctx}] + user_content + + return [ + {"role": "system", "content": self.build_system_prompt(skill_names)}, + *history, + {"role": "user", "content": merged}, + ] def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: """Build user message content with optional base64-encoded images.""" if not media: return text - + images = [] for path in media: p = Path(path) - mime, _ = mimetypes.guess_type(path) - if not p.is_file() or not mime or not mime.startswith("image/"): + if not p.is_file(): continue - b64 = base64.b64encode(p.read_bytes()).decode() + raw = p.read_bytes() + # Detect real MIME type from magic bytes; fallback to filename guess + mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0] + if not mime or not mime.startswith("image/"): + continue + b64 = base64.b64encode(raw).decode() images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) - + if not images: return text return images + [{"type": "text", "text": text}] - + def add_tool_result( - self, - messages: list[dict[str, Any]], - tool_call_id: str, - tool_name: str, - result: str + self, messages: list[dict[str, Any]], + tool_call_id: str, tool_name: str, result: str, ) -> list[dict[str, Any]]: - """ - Add a tool result to the message list. - - Args: - messages: Current message list. - tool_call_id: ID of the tool call. - tool_name: Name of the tool. - result: Tool execution result. - - Returns: - Updated message list. - """ - messages.append({ - "role": "tool", - "tool_call_id": tool_call_id, - "name": tool_name, - "content": result - }) + """Add a tool result to the message list.""" + messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result}) return messages - + def add_assistant_message( - self, - messages: list[dict[str, Any]], + self, messages: list[dict[str, Any]], content: str | None, tool_calls: list[dict[str, Any]] | None = None, reasoning_content: str | None = None, + thinking_blocks: list[dict] | None = None, ) -> list[dict[str, Any]]: - """ - Add an assistant message to the message list. - - Args: - messages: Current message list. - content: Message content. - tool_calls: Optional tool calls. - reasoning_content: Thinking output (Kimi, DeepSeek-R1, etc.). - - Returns: - Updated message list. 
- """ - msg: dict[str, Any] = {"role": "assistant", "content": content or ""} - + """Add an assistant message to the message list.""" + msg: dict[str, Any] = {"role": "assistant", "content": content} if tool_calls: msg["tool_calls"] = tool_calls - - # Thinking models reject history without this - if reasoning_content: + if reasoning_content is not None: msg["reasoning_content"] = reasoning_content - + if thinking_blocks: + msg["thinking_blocks"] = thinking_blocks messages.append(msg) return messages diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b764c3d..ca9a06e 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -1,31 +1,41 @@ """Agent loop: the core processing engine.""" +from __future__ import annotations + import asyncio import json +import re +import weakref +from contextlib import AsyncExitStack from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger +from nanobot.agent.context import ContextBuilder +from nanobot.agent.memory import MemoryStore +from nanobot.agent.subagent import SubagentManager +from nanobot.agent.tools.cron import CronTool +from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.registry import ToolRegistry +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.spawn import SpawnTool +from nanobot.agent.tools.web import WebFetchTool, WebSearchTool from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMProvider -from nanobot.agent.context import ContextBuilder -from nanobot.agent.tools.registry import ToolRegistry -from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool -from nanobot.agent.tools.shell import ExecTool -from nanobot.agent.tools.web import WebSearchTool, WebFetchTool -from nanobot.agent.tools.message import MessageTool -from nanobot.agent.tools.spawn import SpawnTool -from nanobot.agent.tools.cron import CronTool -from nanobot.agent.subagent import SubagentManager -from nanobot.session.manager import SessionManager +from nanobot.session.manager import Session, SessionManager + +if TYPE_CHECKING: + from nanobot.config.schema import ChannelsConfig, ExecToolConfig + from nanobot.cron.service import CronService class AgentLoop: """ The agent loop is the core processing engine. - + It: 1. Receives messages from the bus 2. Builds context with history, memory, skills @@ -33,32 +43,46 @@ class AgentLoop: 4. Executes tool calls 5. 
Sends responses back """ - + + _TOOL_RESULT_MAX_CHARS = 500 + def __init__( self, bus: MessageBus, provider: LLMProvider, workspace: Path, model: str | None = None, - max_iterations: int = 20, + max_iterations: int = 40, + temperature: float = 0.1, + max_tokens: int = 4096, + memory_window: int = 100, + reasoning_effort: str | None = None, brave_api_key: str | None = None, - exec_config: "ExecToolConfig | None" = None, - cron_service: "CronService | None" = None, + web_proxy: str | None = None, + exec_config: ExecToolConfig | None = None, + cron_service: CronService | None = None, restrict_to_workspace: bool = False, session_manager: SessionManager | None = None, + mcp_servers: dict | None = None, + channels_config: ChannelsConfig | None = None, ): from nanobot.config.schema import ExecToolConfig - from nanobot.cron.service import CronService self.bus = bus + self.channels_config = channels_config self.provider = provider self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = max_iterations + self.temperature = temperature + self.max_tokens = max_tokens + self.memory_window = memory_window + self.reasoning_effort = reasoning_effort self.brave_api_key = brave_api_key + self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() self.cron_service = cron_service self.restrict_to_workspace = restrict_to_workspace - + self.context = ContextBuilder(workspace) self.sessions = session_manager or SessionManager(workspace) self.tools = ToolRegistry() @@ -67,312 +91,419 @@ class AgentLoop: workspace=workspace, bus=bus, model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + reasoning_effort=reasoning_effort, brave_api_key=brave_api_key, + web_proxy=web_proxy, exec_config=self.exec_config, restrict_to_workspace=restrict_to_workspace, ) - + self._running = False + self._mcp_servers = mcp_servers or {} + self._mcp_stack: AsyncExitStack | None = None + self._mcp_connected = False + self._mcp_connecting = False + self._consolidating: set[str] = set() # Session keys with consolidation in progress + self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks + self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks + self._processing_lock = asyncio.Lock() self._register_default_tools() - + def _register_default_tools(self) -> None: """Register the default set of tools.""" - # File tools (restrict to workspace if configured) allowed_dir = self.workspace if self.restrict_to_workspace else None - self.tools.register(ReadFileTool(allowed_dir=allowed_dir)) - self.tools.register(WriteFileTool(allowed_dir=allowed_dir)) - self.tools.register(EditFileTool(allowed_dir=allowed_dir)) - self.tools.register(ListDirTool(allowed_dir=allowed_dir)) - - # Shell tool + for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool): + self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(ExecTool( working_dir=str(self.workspace), timeout=self.exec_config.timeout, restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, )) - - # Web tools - self.tools.register(WebSearchTool(api_key=self.brave_api_key)) - self.tools.register(WebFetchTool()) - - # Message tool - message_tool = MessageTool(send_callback=self.bus.publish_outbound) - self.tools.register(message_tool) - - # Spawn tool (for subagents) - 
spawn_tool = SpawnTool(manager=self.subagents) - self.tools.register(spawn_tool) - - # Cron tool (for scheduling) + self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy)) + self.tools.register(WebFetchTool(proxy=self.web_proxy)) + self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) + self.tools.register(SpawnTool(manager=self.subagents)) if self.cron_service: self.tools.register(CronTool(self.cron_service)) - + + async def _connect_mcp(self) -> None: + """Connect to configured MCP servers (one-time, lazy).""" + if self._mcp_connected or self._mcp_connecting or not self._mcp_servers: + return + self._mcp_connecting = True + from nanobot.agent.tools.mcp import connect_mcp_servers + try: + self._mcp_stack = AsyncExitStack() + await self._mcp_stack.__aenter__() + await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) + self._mcp_connected = True + except Exception as e: + logger.error("Failed to connect MCP servers (will retry next message): {}", e) + if self._mcp_stack: + try: + await self._mcp_stack.aclose() + except Exception: + pass + self._mcp_stack = None + finally: + self._mcp_connecting = False + + def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: + """Update context for all tools that need routing info.""" + for name in ("message", "spawn", "cron"): + if tool := self.tools.get(name): + if hasattr(tool, "set_context"): + tool.set_context(channel, chat_id, *([message_id] if name == "message" else [])) + + @staticmethod + def _strip_think(text: str | None) -> str | None: + """Remove <think> blocks that some models embed in content.""" + if not text: + return None + return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None + + @staticmethod + def _tool_hint(tool_calls: list) -> str: + """Format tool calls as concise hint, e.g. 'web_search("query")'.""" + def _fmt(tc): + args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {} + val = next(iter(args.values()), None) if isinstance(args, dict) else None + if not isinstance(val, str): + return tc.name + return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")' + return ", ".join(_fmt(tc) for tc in tool_calls) + + async def _run_agent_loop( + self, + initial_messages: list[dict], + on_progress: Callable[..., Awaitable[None]] | None = None, + ) -> tuple[str | None, list[str], list[dict]]: + """Run the agent iteration loop.
Returns (final_content, tools_used, messages).""" + messages = initial_messages + iteration = 0 + final_content = None + tools_used: list[str] = [] + + while iteration < self.max_iterations: + iteration += 1 + + response = await self.provider.chat( + messages=messages, + tools=self.tools.get_definitions(), + model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + reasoning_effort=self.reasoning_effort, + ) + + if response.has_tool_calls: + if on_progress: + thought = self._strip_think(response.content) + if thought: + await on_progress(thought) + await on_progress(self._tool_hint(response.tool_calls), tool_hint=True) + + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.arguments, ensure_ascii=False) + } + } + for tc in response.tool_calls + ] + messages = self.context.add_assistant_message( + messages, response.content, tool_call_dicts, + reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + ) + + for tool_call in response.tool_calls: + tools_used.append(tool_call.name) + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tool_call.name, args_str[:200]) + result = await self.tools.execute(tool_call.name, tool_call.arguments) + messages = self.context.add_tool_result( + messages, tool_call.id, tool_call.name, result + ) + else: + clean = self._strip_think(response.content) + # Don't persist error responses to session history — they can + # poison the context and cause permanent 400 loops (#1303). + if response.finish_reason == "error": + logger.error("LLM returned error: {}", (clean or "")[:200]) + final_content = clean or "Sorry, I encountered an error calling the AI model." + break + messages = self.context.add_assistant_message( + messages, clean, reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + ) + final_content = clean + break + + if final_content is None and iteration >= self.max_iterations: + logger.warning("Max iterations ({}) reached", self.max_iterations) + final_content = ( + f"I reached the maximum number of tool call iterations ({self.max_iterations}) " + "without completing the task. You can try breaking the task into smaller steps." 
+        )
+
+        return final_content, tools_used, messages
+
     async def run(self) -> None:
-        """Run the agent loop, processing messages from the bus."""
+        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
         self._running = True
+        await self._connect_mcp()
         logger.info("Agent loop started")
-
+
         while self._running:
             try:
-                # Wait for next message
-                msg = await asyncio.wait_for(
-                    self.bus.consume_inbound(),
-                    timeout=1.0
-                )
-
-                # Process it
-                try:
-                    response = await self._process_message(msg)
-                    if response:
-                        await self.bus.publish_outbound(response)
-                except Exception as e:
-                    logger.error(f"Error processing message: {e}")
-                    # Send error response
-                    await self.bus.publish_outbound(OutboundMessage(
-                        channel=msg.channel,
-                        chat_id=msg.chat_id,
-                        content=f"Sorry, I encountered an error: {str(e)}"
-                    ))
+                msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
             except asyncio.TimeoutError:
                 continue
-
+
+            if msg.content.strip().lower() == "/stop":
+                await self._handle_stop(msg)
+            else:
+                task = asyncio.create_task(self._dispatch(msg))
+                self._active_tasks.setdefault(msg.session_key, []).append(task)
+
+                def _discard(t: asyncio.Task, key: str = msg.session_key) -> None:
+                    tasks = self._active_tasks.get(key, [])
+                    if t in tasks:
+                        tasks.remove(t)
+
+                task.add_done_callback(_discard)
+
+    async def _handle_stop(self, msg: InboundMessage) -> None:
+        """Cancel all active tasks and subagents for the session."""
+        tasks = self._active_tasks.pop(msg.session_key, [])
+        cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
+        for t in tasks:
+            try:
+                await t
+            except (asyncio.CancelledError, Exception):
+                pass
+        sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
+        total = cancelled + sub_cancelled
+        content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
+        await self.bus.publish_outbound(OutboundMessage(
+            channel=msg.channel, chat_id=msg.chat_id, content=content,
+        ))
+
+    async def _dispatch(self, msg: InboundMessage) -> None:
+        """Process a message under the global lock."""
+        async with self._processing_lock:
+            try:
+                response = await self._process_message(msg)
+                if response is not None:
+                    await self.bus.publish_outbound(response)
+                elif msg.channel == "cli":
+                    await self.bus.publish_outbound(OutboundMessage(
+                        channel=msg.channel, chat_id=msg.chat_id,
+                        content="", metadata=msg.metadata or {},
+                    ))
+            except asyncio.CancelledError:
+                logger.info("Task cancelled for session {}", msg.session_key)
+                raise
+            except Exception:
+                logger.exception("Error processing message for session {}", msg.session_key)
+                await self.bus.publish_outbound(OutboundMessage(
+                    channel=msg.channel, chat_id=msg.chat_id,
+                    content="Sorry, I encountered an error.",
+                ))
+
+    async def close_mcp(self) -> None:
+        """Close MCP connections."""
+        if self._mcp_stack:
+            try:
+                await self._mcp_stack.aclose()
+            except (RuntimeError, BaseExceptionGroup):
+                pass  # MCP SDK cancel scope cleanup is noisy but harmless
+            self._mcp_stack = None
+
     def stop(self) -> None:
         """Stop the agent loop."""
         self._running = False
         logger.info("Agent loop stopping")
-
-    async def _process_message(self, msg: InboundMessage) -> OutboundMessage | None:
-        """
-        Process a single inbound message.
-
-        Args:
-            msg: The inbound message to process.
-
-        Returns:
-            The response message, or None if no response needed. 
- """ - # Handle system messages (subagent announces) - # The chat_id contains the original "channel:chat_id" to route back to + + async def _process_message( + self, + msg: InboundMessage, + session_key: str | None = None, + on_progress: Callable[[str], Awaitable[None]] | None = None, + ) -> OutboundMessage | None: + """Process a single inbound message and return the response.""" + # System messages: parse origin from chat_id ("channel:chat_id") if msg.channel == "system": - return await self._process_system_message(msg) - + channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id + else ("cli", msg.chat_id)) + logger.info("Processing system message from {}", msg.sender_id) + key = f"{channel}:{chat_id}" + session = self.sessions.get_or_create(key) + self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) + history = session.get_history(max_messages=self.memory_window) + messages = self.context.build_messages( + history=history, + current_message=msg.content, channel=channel, chat_id=chat_id, + ) + final_content, _, all_msgs = await self._run_agent_loop(messages) + self._save_turn(session, all_msgs, 1 + len(history)) + self.sessions.save(session) + return OutboundMessage(channel=channel, chat_id=chat_id, + content=final_content or "Background task completed.") + preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content - logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}") - - # Get or create session - session = self.sessions.get_or_create(msg.session_key) - - # Update tool contexts - message_tool = self.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_context(msg.channel, msg.chat_id) - - spawn_tool = self.tools.get("spawn") - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(msg.channel, msg.chat_id) - - cron_tool = self.tools.get("cron") - if isinstance(cron_tool, CronTool): - cron_tool.set_context(msg.channel, msg.chat_id) - - # Build initial messages (use get_history for LLM-formatted messages) - messages = self.context.build_messages( - history=session.get_history(), + logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) + + key = session_key or msg.session_key + session = self.sessions.get_or_create(key) + + # Slash commands + cmd = msg.content.strip().lower() + if cmd == "/new": + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) + self._consolidating.add(session.key) + try: + async with lock: + snapshot = session.messages[session.last_consolidated:] + if snapshot: + temp = Session(key=session.key) + temp.messages = list(snapshot) + if not await self._consolidate_memory(temp, archive_all=True): + return OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. Please try again.", + ) + except Exception: + logger.exception("/new archival failed for {}", session.key) + return OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. 
Please try again.", + ) + finally: + self._consolidating.discard(session.key) + + session.clear() + self.sessions.save(session) + self.sessions.invalidate(session.key) + return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, + content="New session started.") + if cmd == "/help": + return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, + content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") + + unconsolidated = len(session.messages) - session.last_consolidated + if (unconsolidated >= self.memory_window and session.key not in self._consolidating): + self._consolidating.add(session.key) + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) + + async def _consolidate_and_unlock(): + try: + async with lock: + await self._consolidate_memory(session) + finally: + self._consolidating.discard(session.key) + _task = asyncio.current_task() + if _task is not None: + self._consolidation_tasks.discard(_task) + + _task = asyncio.create_task(_consolidate_and_unlock()) + self._consolidation_tasks.add(_task) + + self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) + if message_tool := self.tools.get("message"): + if isinstance(message_tool, MessageTool): + message_tool.start_turn() + + history = session.get_history(max_messages=self.memory_window) + initial_messages = self.context.build_messages( + history=history, current_message=msg.content, media=msg.media if msg.media else None, - channel=msg.channel, - chat_id=msg.chat_id, + channel=msg.channel, chat_id=msg.chat_id, ) - - # Agent loop - iteration = 0 - final_content = None - - while iteration < self.max_iterations: - iteration += 1 - - # Call LLM - response = await self.provider.chat( - messages=messages, - tools=self.tools.get_definitions(), - model=self.model - ) - - # Handle tool calls - if response.has_tool_calls: - # Add assistant message with tool calls - tool_call_dicts = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.arguments) # Must be JSON string - } - } - for tc in response.tool_calls - ] - messages = self.context.add_assistant_message( - messages, response.content, tool_call_dicts, - reasoning_content=response.reasoning_content, - ) - - # Execute tools - for tool_call in response.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") - result = await self.tools.execute(tool_call.name, tool_call.arguments) - messages = self.context.add_tool_result( - messages, tool_call.id, tool_call.name, result - ) - else: - # No tool calls, we're done - final_content = response.content - break - + + async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: + meta = dict(msg.metadata or {}) + meta["_progress"] = True + meta["_tool_hint"] = tool_hint + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, + )) + + final_content, _, all_msgs = await self._run_agent_loop( + initial_messages, on_progress=on_progress or _bus_progress, + ) + if final_content is None: final_content = "I've completed processing but have no response to give." 
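+        # skip = 1 (system prompt) + len(history): persist only the messages
+        # generated this turn, starting from the current user message.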
- - # Log response preview + + self._save_turn(session, all_msgs, 1 + len(history)) + self.sessions.save(session) + + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + return None + preview = final_content[:120] + "..." if len(final_content) > 120 else final_content - logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}") - - # Save to session - session.add_message("user", msg.content) - session.add_message("assistant", final_content) - self.sessions.save(session) - + logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) return OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=final_content, - metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts) + channel=msg.channel, chat_id=msg.chat_id, content=final_content, + metadata=msg.metadata or {}, ) - - async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None: - """ - Process a system message (e.g., subagent announce). - - The chat_id field contains "original_channel:original_chat_id" to route - the response back to the correct destination. - """ - logger.info(f"Processing system message from {msg.sender_id}") - - # Parse origin from chat_id (format: "channel:chat_id") - if ":" in msg.chat_id: - parts = msg.chat_id.split(":", 1) - origin_channel = parts[0] - origin_chat_id = parts[1] - else: - # Fallback - origin_channel = "cli" - origin_chat_id = msg.chat_id - - # Use the origin session for context - session_key = f"{origin_channel}:{origin_chat_id}" - session = self.sessions.get_or_create(session_key) - - # Update tool contexts - message_tool = self.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_context(origin_channel, origin_chat_id) - - spawn_tool = self.tools.get("spawn") - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(origin_channel, origin_chat_id) - - cron_tool = self.tools.get("cron") - if isinstance(cron_tool, CronTool): - cron_tool.set_context(origin_channel, origin_chat_id) - - # Build messages with the announce content - messages = self.context.build_messages( - history=session.get_history(), - current_message=msg.content, - channel=origin_channel, - chat_id=origin_chat_id, + + def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: + """Save new-turn messages into session, truncating large tool results.""" + from datetime import datetime + for m in messages[skip:]: + entry = dict(m) + role, content = entry.get("role"), entry.get("content") + if role == "assistant" and not content and not entry.get("tool_calls"): + continue # skip empty assistant messages — they poison session context + if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS: + entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + elif role == "user": + if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): + # Strip the runtime-context prefix, keep only the user text. 
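+                    # (assumes build_messages separates the runtime-context block
+                    # from the user text with a blank line, hence the "\n\n" split)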
+ parts = content.split("\n\n", 1) + if len(parts) > 1 and parts[1].strip(): + entry["content"] = parts[1] + else: + continue + if isinstance(content, list): + filtered = [] + for c in content: + if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): + continue # Strip runtime context from multimodal messages + if (c.get("type") == "image_url" + and c.get("image_url", {}).get("url", "").startswith("data:image/")): + filtered.append({"type": "text", "text": "[image]"}) + else: + filtered.append(c) + if not filtered: + continue + entry["content"] = filtered + entry.setdefault("timestamp", datetime.now().isoformat()) + session.messages.append(entry) + session.updated_at = datetime.now() + + async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: + """Delegate to MemoryStore.consolidate(). Returns True on success.""" + return await MemoryStore(self.workspace).consolidate( + session, self.provider, self.model, + archive_all=archive_all, memory_window=self.memory_window, ) - - # Agent loop (limited for announce handling) - iteration = 0 - final_content = None - - while iteration < self.max_iterations: - iteration += 1 - - response = await self.provider.chat( - messages=messages, - tools=self.tools.get_definitions(), - model=self.model - ) - - if response.has_tool_calls: - tool_call_dicts = [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": json.dumps(tc.arguments) - } - } - for tc in response.tool_calls - ] - messages = self.context.add_assistant_message( - messages, response.content, tool_call_dicts, - reasoning_content=response.reasoning_content, - ) - - for tool_call in response.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.info(f"Tool call: {tool_call.name}({args_str[:200]})") - result = await self.tools.execute(tool_call.name, tool_call.arguments) - messages = self.context.add_tool_result( - messages, tool_call.id, tool_call.name, result - ) - else: - final_content = response.content - break - - if final_content is None: - final_content = "Background task completed." - - # Save to session (mark as system message in history) - session.add_message("user", f"[System: {msg.sender_id}] {msg.content}") - session.add_message("assistant", final_content) - self.sessions.save(session) - - return OutboundMessage( - channel=origin_channel, - chat_id=origin_chat_id, - content=final_content - ) - + async def process_direct( self, content: str, session_key: str = "cli:direct", channel: str = "cli", chat_id: str = "direct", + on_progress: Callable[[str], Awaitable[None]] | None = None, ) -> str: - """ - Process a message directly (for CLI or cron usage). - - Args: - content: The message content. - session_key: Session identifier. - channel: Source channel (for context). - chat_id: Source chat ID (for context). - - Returns: - The agent's response. 
- """ - msg = InboundMessage( - channel=channel, - sender_id="user", - chat_id=chat_id, - content=content - ) - - response = await self._process_message(msg) + """Process a message directly (for CLI or cron usage).""" + await self._connect_mcp() + msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) + response = await self._process_message(msg, session_key=session_key, on_progress=on_progress) return response.content if response else "" diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 453407e..21fe77d 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -1,109 +1,157 @@ """Memory system for persistent agent memory.""" -from pathlib import Path -from datetime import datetime +from __future__ import annotations -from nanobot.utils.helpers import ensure_dir, today_date +import json +from pathlib import Path +from typing import TYPE_CHECKING + +from loguru import logger + +from nanobot.utils.helpers import ensure_dir + +if TYPE_CHECKING: + from nanobot.providers.base import LLMProvider + from nanobot.session.manager import Session + + +_SAVE_MEMORY_TOOL = [ + { + "type": "function", + "function": { + "name": "save_memory", + "description": "Save the memory consolidation result to persistent storage.", + "parameters": { + "type": "object", + "properties": { + "history_entry": { + "type": "string", + "description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. " + "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", + }, + "memory_update": { + "type": "string", + "description": "Full updated long-term memory as markdown. Include all existing " + "facts plus new ones. Return unchanged if nothing new.", + }, + }, + "required": ["history_entry", "memory_update"], + }, + }, + } +] class MemoryStore: - """ - Memory system for the agent. - - Supports daily notes (memory/YYYY-MM-DD.md) and long-term memory (MEMORY.md). - """ - + """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" + def __init__(self, workspace: Path): - self.workspace = workspace self.memory_dir = ensure_dir(workspace / "memory") self.memory_file = self.memory_dir / "MEMORY.md" - - def get_today_file(self) -> Path: - """Get path to today's memory file.""" - return self.memory_dir / f"{today_date()}.md" - - def read_today(self) -> str: - """Read today's memory notes.""" - today_file = self.get_today_file() - if today_file.exists(): - return today_file.read_text(encoding="utf-8") - return "" - - def append_today(self, content: str) -> None: - """Append content to today's memory notes.""" - today_file = self.get_today_file() - - if today_file.exists(): - existing = today_file.read_text(encoding="utf-8") - content = existing + "\n" + content - else: - # Add header for new day - header = f"# {today_date()}\n\n" - content = header + content - - today_file.write_text(content, encoding="utf-8") - + self.history_file = self.memory_dir / "HISTORY.md" + def read_long_term(self) -> str: - """Read long-term memory (MEMORY.md).""" if self.memory_file.exists(): return self.memory_file.read_text(encoding="utf-8") return "" - + def write_long_term(self, content: str) -> None: - """Write to long-term memory (MEMORY.md).""" self.memory_file.write_text(content, encoding="utf-8") - - def get_recent_memories(self, days: int = 7) -> str: - """ - Get memories from the last N days. - - Args: - days: Number of days to look back. - - Returns: - Combined memory content. 
- """ - from datetime import timedelta - - memories = [] - today = datetime.now().date() - - for i in range(days): - date = today - timedelta(days=i) - date_str = date.strftime("%Y-%m-%d") - file_path = self.memory_dir / f"{date_str}.md" - - if file_path.exists(): - content = file_path.read_text(encoding="utf-8") - memories.append(content) - - return "\n\n---\n\n".join(memories) - - def list_memory_files(self) -> list[Path]: - """List all memory files sorted by date (newest first).""" - if not self.memory_dir.exists(): - return [] - - files = list(self.memory_dir.glob("????-??-??.md")) - return sorted(files, reverse=True) - + + def append_history(self, entry: str) -> None: + with open(self.history_file, "a", encoding="utf-8") as f: + f.write(entry.rstrip() + "\n\n") + def get_memory_context(self) -> str: - """ - Get memory context for the agent. - - Returns: - Formatted memory context including long-term and recent memories. - """ - parts = [] - - # Long-term memory long_term = self.read_long_term() - if long_term: - parts.append("## Long-term Memory\n" + long_term) - - # Today's notes - today = self.read_today() - if today: - parts.append("## Today's Notes\n" + today) - - return "\n\n".join(parts) if parts else "" + return f"## Long-term Memory\n{long_term}" if long_term else "" + + async def consolidate( + self, + session: Session, + provider: LLMProvider, + model: str, + *, + archive_all: bool = False, + memory_window: int = 50, + ) -> bool: + """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. + + Returns True on success (including no-op), False on failure. + """ + if archive_all: + old_messages = session.messages + keep_count = 0 + logger.info("Memory consolidation (archive_all): {} messages", len(session.messages)) + else: + keep_count = memory_window // 2 + if len(session.messages) <= keep_count: + return True + if len(session.messages) - session.last_consolidated <= 0: + return True + old_messages = session.messages[session.last_consolidated:-keep_count] + if not old_messages: + return True + logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count) + + lines = [] + for m in old_messages: + if not m.get("content"): + continue + tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" + lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") + + current_memory = self.read_long_term() + prompt = f"""Process this conversation and call the save_memory tool with your consolidation. + +## Current Long-term Memory +{current_memory or "(empty)"} + +## Conversation to Process +{chr(10).join(lines)}""" + + try: + response = await provider.chat( + messages=[ + {"role": "system", "content": "You are a memory consolidation agent. 
Call the save_memory tool with your consolidation of the conversation."}, + {"role": "user", "content": prompt}, + ], + tools=_SAVE_MEMORY_TOOL, + model=model, + ) + + if not response.has_tool_calls: + logger.warning("Memory consolidation: LLM did not call save_memory, skipping") + return False + + args = response.tool_calls[0].arguments + # Some providers return arguments as a JSON string instead of dict + if isinstance(args, str): + args = json.loads(args) + # Some providers return arguments as a list (handle edge case) + if isinstance(args, list): + if args and isinstance(args[0], dict): + args = args[0] + else: + logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list") + return False + if not isinstance(args, dict): + logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) + return False + + if entry := args.get("history_entry"): + if not isinstance(entry, str): + entry = json.dumps(entry, ensure_ascii=False) + self.append_history(entry) + if update := args.get("memory_update"): + if not isinstance(update, str): + update = json.dumps(update, ensure_ascii=False) + if update != current_memory: + self.write_long_term(update) + + session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count + logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) + return True + except Exception: + logger.exception("Memory consolidation failed") + return False diff --git a/nanobot/agent/skills.py b/nanobot/agent/skills.py index ead9f5b..9afee82 100644 --- a/nanobot/agent/skills.py +++ b/nanobot/agent/skills.py @@ -13,28 +13,28 @@ BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills" class SkillsLoader: """ Loader for agent skills. - + Skills are markdown files (SKILL.md) that teach the agent how to use specific tools or perform certain tasks. """ - + def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None): self.workspace = workspace self.workspace_skills = workspace / "skills" self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR - + def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]: """ List all available skills. - + Args: filter_unavailable: If True, filter out skills with unmet requirements. - + Returns: List of skill info dicts with 'name', 'path', 'source'. """ skills = [] - + # Workspace skills (highest priority) if self.workspace_skills.exists(): for skill_dir in self.workspace_skills.iterdir(): @@ -42,7 +42,7 @@ class SkillsLoader: skill_file = skill_dir / "SKILL.md" if skill_file.exists(): skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"}) - + # Built-in skills if self.builtin_skills and self.builtin_skills.exists(): for skill_dir in self.builtin_skills.iterdir(): @@ -50,19 +50,19 @@ class SkillsLoader: skill_file = skill_dir / "SKILL.md" if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) - + # Filter by requirements if filter_unavailable: return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))] return skills - + def load_skill(self, name: str) -> str | None: """ Load a skill by name. - + Args: name: Skill name (directory name). - + Returns: Skill content or None if not found. 
""" @@ -70,22 +70,22 @@ class SkillsLoader: workspace_skill = self.workspace_skills / name / "SKILL.md" if workspace_skill.exists(): return workspace_skill.read_text(encoding="utf-8") - + # Check built-in if self.builtin_skills: builtin_skill = self.builtin_skills / name / "SKILL.md" if builtin_skill.exists(): return builtin_skill.read_text(encoding="utf-8") - + return None - + def load_skills_for_context(self, skill_names: list[str]) -> str: """ Load specific skills for inclusion in agent context. - + Args: skill_names: List of skill names to load. - + Returns: Formatted skills content. """ @@ -95,26 +95,26 @@ class SkillsLoader: if content: content = self._strip_frontmatter(content) parts.append(f"### Skill: {name}\n\n{content}") - + return "\n\n---\n\n".join(parts) if parts else "" - + def build_skills_summary(self) -> str: """ Build a summary of all skills (name, description, path, availability). - + This is used for progressive loading - the agent can read the full skill content using read_file when needed. - + Returns: XML-formatted skills summary. """ all_skills = self.list_skills(filter_unavailable=False) if not all_skills: return "" - + def escape_xml(s: str) -> str: return s.replace("&", "&").replace("<", "<").replace(">", ">") - + lines = [""] for s in all_skills: name = escape_xml(s["name"]) @@ -122,23 +122,23 @@ class SkillsLoader: desc = escape_xml(self._get_skill_description(s["name"])) skill_meta = self._get_skill_meta(s["name"]) available = self._check_requirements(skill_meta) - + lines.append(f" ") lines.append(f" {name}") lines.append(f" {desc}") lines.append(f" {path}") - + # Show missing requirements for unavailable skills if not available: missing = self._get_missing_requirements(skill_meta) if missing: lines.append(f" {escape_xml(missing)}") - - lines.append(f" ") + + lines.append(" ") lines.append("") - + return "\n".join(lines) - + def _get_missing_requirements(self, skill_meta: dict) -> str: """Get a description of missing requirements.""" missing = [] @@ -150,14 +150,14 @@ class SkillsLoader: if not os.environ.get(env): missing.append(f"ENV: {env}") return ", ".join(missing) - + def _get_skill_description(self, name: str) -> str: """Get the description of a skill from its frontmatter.""" meta = self.get_skill_metadata(name) if meta and meta.get("description"): return meta["description"] return name # Fallback to skill name - + def _strip_frontmatter(self, content: str) -> str: """Remove YAML frontmatter from markdown content.""" if content.startswith("---"): @@ -165,15 +165,15 @@ class SkillsLoader: if match: return content[match.end():].strip() return content - + def _parse_nanobot_metadata(self, raw: str) -> dict: - """Parse nanobot metadata JSON from frontmatter.""" + """Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys).""" try: data = json.loads(raw) - return data.get("nanobot", {}) if isinstance(data, dict) else {} + return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {} except (json.JSONDecodeError, TypeError): return {} - + def _check_requirements(self, skill_meta: dict) -> bool: """Check if skill requirements are met (bins, env vars).""" requires = skill_meta.get("requires", {}) @@ -184,12 +184,12 @@ class SkillsLoader: if not os.environ.get(env): return False return True - + def _get_skill_meta(self, name: str) -> dict: """Get nanobot metadata for a skill (cached in frontmatter).""" meta = self.get_skill_metadata(name) or {} return self._parse_nanobot_metadata(meta.get("metadata", "")) - + def 
get_always_skills(self) -> list[str]: """Get skills marked as always=true that meet requirements.""" result = [] @@ -199,21 +199,21 @@ class SkillsLoader: if skill_meta.get("always") or meta.get("always"): result.append(s["name"]) return result - + def get_skill_metadata(self, name: str) -> dict | None: """ Get metadata from a skill's frontmatter. - + Args: name: Skill name. - + Returns: Metadata dict or None. """ content = self.load_skill(name) if not content: return None - + if content.startswith("---"): match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) if match: @@ -224,5 +224,5 @@ class SkillsLoader: key, value = line.split(":", 1) metadata[key.strip()] = value.strip().strip('"\'') return metadata - + return None diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 6113efb..f2d6ee5 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -8,31 +8,30 @@ from typing import Any from loguru import logger +from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.registry import ToolRegistry +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.web import WebFetchTool, WebSearchTool from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus +from nanobot.config.schema import ExecToolConfig from nanobot.providers.base import LLMProvider -from nanobot.agent.tools.registry import ToolRegistry -from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, ListDirTool -from nanobot.agent.tools.shell import ExecTool -from nanobot.agent.tools.web import WebSearchTool, WebFetchTool class SubagentManager: - """ - Manages background subagent execution. - - Subagents are lightweight agent instances that run in the background - to handle specific tasks. They share the same LLM provider but have - isolated context and a focused system prompt. - """ - + """Manages background subagent execution.""" + def __init__( self, provider: LLMProvider, workspace: Path, bus: MessageBus, model: str | None = None, + temperature: float = 0.7, + max_tokens: int = 4096, + reasoning_effort: str | None = None, brave_api_key: str | None = None, + web_proxy: str | None = None, exec_config: "ExecToolConfig | None" = None, restrict_to_workspace: bool = False, ): @@ -41,50 +40,48 @@ class SubagentManager: self.workspace = workspace self.bus = bus self.model = model or provider.get_default_model() + self.temperature = temperature + self.max_tokens = max_tokens + self.reasoning_effort = reasoning_effort self.brave_api_key = brave_api_key + self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() self.restrict_to_workspace = restrict_to_workspace self._running_tasks: dict[str, asyncio.Task[None]] = {} - + self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...} + async def spawn( self, task: str, label: str | None = None, origin_channel: str = "cli", origin_chat_id: str = "direct", + session_key: str | None = None, ) -> str: - """ - Spawn a subagent to execute a task in the background. - - Args: - task: The task description for the subagent. - label: Optional human-readable label for the task. - origin_channel: The channel to announce results to. - origin_chat_id: The chat ID to announce results to. - - Returns: - Status message indicating the subagent was started. - """ + """Spawn a subagent to execute a task in the background.""" task_id = str(uuid.uuid4())[:8] display_label = label or task[:30] + ("..." 
if len(task) > 30 else "") - - origin = { - "channel": origin_channel, - "chat_id": origin_chat_id, - } - - # Create background task + origin = {"channel": origin_channel, "chat_id": origin_chat_id} + bg_task = asyncio.create_task( self._run_subagent(task_id, task, display_label, origin) ) self._running_tasks[task_id] = bg_task - - # Cleanup when done - bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None)) - - logger.info(f"Spawned subagent [{task_id}]: {display_label}") + if session_key: + self._session_tasks.setdefault(session_key, set()).add(task_id) + + def _cleanup(_: asyncio.Task) -> None: + self._running_tasks.pop(task_id, None) + if session_key and (ids := self._session_tasks.get(session_key)): + ids.discard(task_id) + if not ids: + del self._session_tasks[session_key] + + bg_task.add_done_callback(_cleanup) + + logger.info("Spawned subagent [{}]: {}", task_id, display_label) return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes." - + async def _run_subagent( self, task_id: str, @@ -93,44 +90,48 @@ class SubagentManager: origin: dict[str, str], ) -> None: """Execute the subagent task and announce the result.""" - logger.info(f"Subagent [{task_id}] starting task: {label}") - + logger.info("Subagent [{}] starting task: {}", task_id, label) + try: # Build subagent tools (no message tool, no spawn tool) tools = ToolRegistry() allowed_dir = self.workspace if self.restrict_to_workspace else None - tools.register(ReadFileTool(allowed_dir=allowed_dir)) - tools.register(WriteFileTool(allowed_dir=allowed_dir)) - tools.register(ListDirTool(allowed_dir=allowed_dir)) + tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) + tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) + tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) + tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(ExecTool( working_dir=str(self.workspace), timeout=self.exec_config.timeout, restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, )) - tools.register(WebSearchTool(api_key=self.brave_api_key)) - tools.register(WebFetchTool()) + tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy)) + tools.register(WebFetchTool(proxy=self.web_proxy)) - # Build messages with subagent-specific prompt - system_prompt = self._build_subagent_prompt(task) + system_prompt = self._build_subagent_prompt() messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, ] - + # Run agent loop (limited iterations) max_iterations = 15 iteration = 0 final_result: str | None = None - + while iteration < max_iterations: iteration += 1 - + response = await self.provider.chat( messages=messages, tools=tools.get_definitions(), model=self.model, + temperature=self.temperature, + max_tokens=self.max_tokens, + reasoning_effort=self.reasoning_effort, ) - + if response.has_tool_calls: # Add assistant message with tool calls tool_call_dicts = [ @@ -139,7 +140,7 @@ class SubagentManager: "type": "function", "function": { "name": tc.name, - "arguments": json.dumps(tc.arguments), + "arguments": json.dumps(tc.arguments, ensure_ascii=False), }, } for tc in response.tool_calls @@ -149,11 +150,11 @@ class SubagentManager: "content": response.content or "", "tool_calls": tool_call_dicts, }) - + # Execute tools for tool_call in response.tool_calls: - args_str = 
json.dumps(tool_call.arguments) - logger.debug(f"Subagent [{task_id}] executing: {tool_call.name} with arguments: {args_str}") + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) result = await tools.execute(tool_call.name, tool_call.arguments) messages.append({ "role": "tool", @@ -164,18 +165,18 @@ class SubagentManager: else: final_result = response.content break - + if final_result is None: final_result = "Task completed but no final response was generated." - - logger.info(f"Subagent [{task_id}] completed successfully") + + logger.info("Subagent [{}] completed successfully", task_id) await self._announce_result(task_id, label, task, final_result, origin, "ok") - + except Exception as e: error_msg = f"Error: {str(e)}" - logger.error(f"Subagent [{task_id}] failed: {e}") + logger.error("Subagent [{}] failed: {}", task_id, e) await self._announce_result(task_id, label, task, error_msg, origin, "error") - + async def _announce_result( self, task_id: str, @@ -187,7 +188,7 @@ class SubagentManager: ) -> None: """Announce the subagent result to the main agent via the message bus.""" status_text = "completed successfully" if status == "ok" else "failed" - + announce_content = f"""[Subagent '{label}' {status_text}] Task: {task} @@ -196,7 +197,7 @@ Result: {result} Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs.""" - + # Inject as system message to trigger main agent msg = InboundMessage( channel="system", @@ -204,41 +205,42 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men chat_id=f"{origin['channel']}:{origin['chat_id']}", content=announce_content, ) - + await self.bus.publish_inbound(msg) - logger.debug(f"Subagent [{task_id}] announced result to {origin['channel']}:{origin['chat_id']}") + logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id']) - def _build_subagent_prompt(self, task: str) -> str: + def _build_subagent_prompt(self) -> str: """Build a focused system prompt for the subagent.""" - return f"""# Subagent + from nanobot.agent.context import ContextBuilder + from nanobot.agent.skills import SkillsLoader + + time_ctx = ContextBuilder._build_runtime_context(None, None) + parts = [f"""# Subagent + +{time_ctx} You are a subagent spawned by the main agent to complete a specific task. - -## Your Task -{task} - -## Rules -1. Stay focused - complete only the assigned task, nothing else -2. Your final response will be reported back to the main agent -3. Do not initiate conversations or take on side tasks -4. Be concise but informative in your findings - -## What You Can Do -- Read and write files in the workspace -- Execute shell commands -- Search the web and fetch web pages -- Complete the task thoroughly - -## What You Cannot Do -- Send messages directly to users (no message tool available) -- Spawn other subagents -- Access the main agent's conversation history +Stay focused on the assigned task. Your final response will be reported back to the main agent. 
## Workspace -Your workspace is at: {self.workspace} +{self.workspace}"""] -When you have completed the task, provide a clear summary of your findings or actions.""" + skills_summary = SkillsLoader(self.workspace).build_skills_summary() + if skills_summary: + parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}") + + return "\n\n".join(parts) + async def cancel_by_session(self, session_key: str) -> int: + """Cancel all subagents for the given session. Returns count cancelled.""" + tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, []) + if tid in self._running_tasks and not self._running_tasks[tid].done()] + for t in tasks: + t.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + return len(tasks) + def get_running_count(self) -> int: """Return the number of currently running subagents.""" return len(self._running_tasks) diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index ca9bcc2..06f5bdd 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -7,11 +7,11 @@ from typing import Any class Tool(ABC): """ Abstract base class for agent tools. - + Tools are capabilities that the agent can use to interact with the environment, such as reading files, executing commands, etc. """ - + _TYPE_MAP = { "string": str, "integer": int, @@ -20,40 +20,111 @@ class Tool(ABC): "array": list, "object": dict, } - + @property @abstractmethod def name(self) -> str: """Tool name used in function calls.""" pass - + @property @abstractmethod def description(self) -> str: """Description of what the tool does.""" pass - + @property @abstractmethod def parameters(self) -> dict[str, Any]: """JSON Schema for tool parameters.""" pass - + @abstractmethod async def execute(self, **kwargs: Any) -> str: """ Execute the tool with given parameters. - + Args: **kwargs: Tool-specific parameters. - + Returns: String result of the tool execution. 
""" pass + def cast_params(self, params: dict[str, Any]) -> dict[str, Any]: + """Apply safe schema-driven casts before validation.""" + schema = self.parameters or {} + if schema.get("type", "object") != "object": + return params + + return self._cast_object(params, schema) + + def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]: + """Cast an object (dict) according to schema.""" + if not isinstance(obj, dict): + return obj + + props = schema.get("properties", {}) + result = {} + + for key, value in obj.items(): + if key in props: + result[key] = self._cast_value(value, props[key]) + else: + result[key] = value + + return result + + def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any: + """Cast a single value according to schema.""" + target_type = schema.get("type") + + if target_type == "boolean" and isinstance(val, bool): + return val + if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool): + return val + if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"): + expected = self._TYPE_MAP[target_type] + if isinstance(val, expected): + return val + + if target_type == "integer" and isinstance(val, str): + try: + return int(val) + except ValueError: + return val + + if target_type == "number" and isinstance(val, str): + try: + return float(val) + except ValueError: + return val + + if target_type == "string": + return val if val is None else str(val) + + if target_type == "boolean" and isinstance(val, str): + val_lower = val.lower() + if val_lower in ("true", "1", "yes"): + return True + if val_lower in ("false", "0", "no"): + return False + return val + + if target_type == "array" and isinstance(val, list): + item_schema = schema.get("items") + return [self._cast_value(item, item_schema) for item in val] if item_schema else val + + if target_type == "object" and isinstance(val, dict): + return self._cast_object(val, schema) + + return val + def validate_params(self, params: dict[str, Any]) -> list[str]: """Validate tool parameters against JSON schema. Returns error list (empty if valid).""" + if not isinstance(params, dict): + return [f"parameters must be an object, got {type(params).__name__}"] schema = self.parameters or {} if schema.get("type", "object") != "object": raise ValueError(f"Schema must be object type, got {schema.get('type')!r}") @@ -61,9 +132,15 @@ class Tool(ABC): def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: t, label = schema.get("type"), path or "parameter" - if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): + if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)): + return [f"{label} should be integer"] + if t == "number" and ( + not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool) + ): + return [f"{label} should be number"] + if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]): return [f"{label} should be {t}"] - + errors = [] if "enum" in schema and val not in schema["enum"]: errors.append(f"{label} must be one of {schema['enum']}") @@ -84,12 +161,14 @@ class Tool(ABC): errors.append(f"missing required {path + '.' + k if path else k}") for k, v in val.items(): if k in props: - errors.extend(self._validate(v, props[k], path + '.' + k if path else k)) + errors.extend(self._validate(v, props[k], path + "." 
+ k if path else k)) if t == "array" and "items" in schema: for i, item in enumerate(val): - errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")) + errors.extend( + self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]") + ) return errors - + def to_schema(self) -> dict[str, Any]: """Convert tool to OpenAI function schema format.""" return { @@ -98,5 +177,5 @@ class Tool(ABC): "name": self.name, "description": self.description, "parameters": self.parameters, - } + }, } diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index ec0d2cd..f8e737b 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -1,5 +1,6 @@ """Cron tool for scheduling reminders and tasks.""" +from contextvars import ContextVar from typing import Any from nanobot.agent.tools.base import Tool @@ -9,25 +10,34 @@ from nanobot.cron.types import CronSchedule class CronTool(Tool): """Tool to schedule reminders and recurring tasks.""" - + def __init__(self, cron_service: CronService): self._cron = cron_service self._channel = "" self._chat_id = "" - + self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False) + def set_context(self, channel: str, chat_id: str) -> None: """Set the current session context for delivery.""" self._channel = channel self._chat_id = chat_id - + + def set_cron_context(self, active: bool): + """Mark whether the tool is executing inside a cron job callback.""" + return self._in_cron_context.set(active) + + def reset_cron_context(self, token) -> None: + """Restore previous cron context.""" + self._in_cron_context.reset(token) + @property def name(self) -> str: return "cron" - + @property def description(self) -> str: return "Schedule reminders and recurring tasks. Actions: add, list, remove." - + @property def parameters(self) -> dict[str, Any]: return { @@ -36,59 +46,92 @@ class CronTool(Tool): "action": { "type": "string", "enum": ["add", "list", "remove"], - "description": "Action to perform" - }, - "message": { - "type": "string", - "description": "Reminder message (for add)" + "description": "Action to perform", }, + "message": {"type": "string", "description": "Reminder message (for add)"}, "every_seconds": { "type": "integer", - "description": "Interval in seconds (for recurring tasks)" + "description": "Interval in seconds (for recurring tasks)", }, "cron_expr": { "type": "string", - "description": "Cron expression like '0 9 * * *' (for scheduled tasks)" + "description": "Cron expression like '0 9 * * *' (for scheduled tasks)", }, - "job_id": { + "tz": { "type": "string", - "description": "Job ID (for remove)" - } + "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')", + }, + "at": { + "type": "string", + "description": "ISO datetime for one-time execution (e.g. 
'2026-02-12T10:30:00')", + }, + "job_id": {"type": "string", "description": "Job ID (for remove)"}, }, - "required": ["action"] + "required": ["action"], } - + async def execute( self, action: str, message: str = "", every_seconds: int | None = None, cron_expr: str | None = None, + tz: str | None = None, + at: str | None = None, job_id: str | None = None, - **kwargs: Any + **kwargs: Any, ) -> str: if action == "add": - return self._add_job(message, every_seconds, cron_expr) + if self._in_cron_context.get(): + return "Error: cannot schedule new jobs from within a cron job execution" + return self._add_job(message, every_seconds, cron_expr, tz, at) elif action == "list": return self._list_jobs() elif action == "remove": return self._remove_job(job_id) return f"Unknown action: {action}" - - def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None) -> str: + + def _add_job( + self, + message: str, + every_seconds: int | None, + cron_expr: str | None, + tz: str | None, + at: str | None, + ) -> str: if not message: return "Error: message is required for add" if not self._channel or not self._chat_id: return "Error: no session context (channel/chat_id)" - + if tz and not cron_expr: + return "Error: tz can only be used with cron_expr" + if tz: + from zoneinfo import ZoneInfo + + try: + ZoneInfo(tz) + except (KeyError, Exception): + return f"Error: unknown timezone '{tz}'" + # Build schedule + delete_after = False if every_seconds: schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000) elif cron_expr: - schedule = CronSchedule(kind="cron", expr=cron_expr) + schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz) + elif at: + from datetime import datetime + + try: + dt = datetime.fromisoformat(at) + except ValueError: + return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS" + at_ms = int(dt.timestamp() * 1000) + schedule = CronSchedule(kind="at", at_ms=at_ms) + delete_after = True else: - return "Error: either every_seconds or cron_expr is required" - + return "Error: either every_seconds, cron_expr, or at is required" + job = self._cron.add_job( name=message[:30], schedule=schedule, @@ -96,16 +139,17 @@ class CronTool(Tool): deliver=True, channel=self._channel, to=self._chat_id, + delete_after_run=delete_after, ) return f"Created job '{job.name}' (id: {job.id})" - + def _list_jobs(self) -> str: jobs = self._cron.list_jobs() if not jobs: return "No scheduled jobs." 
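+        # j.schedule.kind is "every", "cron", or "at", matching the branches in _add_job above.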
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs] return "Scheduled jobs:\n" + "\n".join(lines) - + def _remove_job(self, job_id: str | None) -> str: if not job_id: return "Error: job_id is required for remove" diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 6b3254a..7b0b867 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -1,55 +1,71 @@ """File system tools: read, write, edit.""" +import difflib from pathlib import Path from typing import Any from nanobot.agent.tools.base import Tool -def _resolve_path(path: str, allowed_dir: Path | None = None) -> Path: - """Resolve path and optionally enforce directory restriction.""" - resolved = Path(path).expanduser().resolve() - if allowed_dir and not str(resolved).startswith(str(allowed_dir.resolve())): - raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") +def _resolve_path( + path: str, workspace: Path | None = None, allowed_dir: Path | None = None +) -> Path: + """Resolve path against workspace (if relative) and enforce directory restriction.""" + p = Path(path).expanduser() + if not p.is_absolute() and workspace: + p = workspace / p + resolved = p.resolve() + if allowed_dir: + try: + resolved.relative_to(allowed_dir.resolve()) + except ValueError: + raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") return resolved class ReadFileTool(Tool): """Tool to read file contents.""" - - def __init__(self, allowed_dir: Path | None = None): + + _MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context + + def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): + self._workspace = workspace self._allowed_dir = allowed_dir @property def name(self) -> str: return "read_file" - + @property def description(self) -> str: return "Read the contents of a file at the given path." - + @property def parameters(self) -> dict[str, Any]: return { "type": "object", - "properties": { - "path": { - "type": "string", - "description": "The file path to read" - } - }, - "required": ["path"] + "properties": {"path": {"type": "string", "description": "The file path to read"}}, + "required": ["path"], } - + async def execute(self, path: str, **kwargs: Any) -> str: try: - file_path = _resolve_path(path, self._allowed_dir) + file_path = _resolve_path(path, self._workspace, self._allowed_dir) if not file_path.exists(): return f"Error: File not found: {path}" if not file_path.is_file(): return f"Error: Not a file: {path}" - + + size = file_path.stat().st_size + if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes) + return ( + f"Error: File too large ({size:,} bytes). " + f"Use exec tool with head/tail/grep to read portions." + ) + content = file_path.read_text(encoding="utf-8") + if len(content) > self._MAX_CHARS: + return content[: self._MAX_CHARS] + f"\n\n... 
(truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})" return content except PermissionError as e: return f"Error: {e}" @@ -59,41 +75,36 @@ class ReadFileTool(Tool): class WriteFileTool(Tool): """Tool to write content to a file.""" - - def __init__(self, allowed_dir: Path | None = None): + + def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): + self._workspace = workspace self._allowed_dir = allowed_dir @property def name(self) -> str: return "write_file" - + @property def description(self) -> str: return "Write content to a file at the given path. Creates parent directories if needed." - + @property def parameters(self) -> dict[str, Any]: return { "type": "object", "properties": { - "path": { - "type": "string", - "description": "The file path to write to" - }, - "content": { - "type": "string", - "description": "The content to write" - } + "path": {"type": "string", "description": "The file path to write to"}, + "content": {"type": "string", "description": "The content to write"}, }, - "required": ["path", "content"] + "required": ["path", "content"], } - + async def execute(self, path: str, content: str, **kwargs: Any) -> str: try: - file_path = _resolve_path(path, self._allowed_dir) + file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_text(content, encoding="utf-8") - return f"Successfully wrote {len(content)} bytes to {path}" + return f"Successfully wrote {len(content)} bytes to {file_path}" except PermissionError as e: return f"Error: {e}" except Exception as e: @@ -102,108 +113,124 @@ class WriteFileTool(Tool): class EditFileTool(Tool): """Tool to edit a file by replacing text.""" - - def __init__(self, allowed_dir: Path | None = None): + + def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): + self._workspace = workspace self._allowed_dir = allowed_dir @property def name(self) -> str: return "edit_file" - + @property def description(self) -> str: return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." - + @property def parameters(self) -> dict[str, Any]: return { "type": "object", "properties": { - "path": { - "type": "string", - "description": "The file path to edit" - }, - "old_text": { - "type": "string", - "description": "The exact text to find and replace" - }, - "new_text": { - "type": "string", - "description": "The text to replace with" - } + "path": {"type": "string", "description": "The file path to edit"}, + "old_text": {"type": "string", "description": "The exact text to find and replace"}, + "new_text": {"type": "string", "description": "The text to replace with"}, }, - "required": ["path", "old_text", "new_text"] + "required": ["path", "old_text", "new_text"], } - + async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: try: - file_path = _resolve_path(path, self._allowed_dir) + file_path = _resolve_path(path, self._workspace, self._allowed_dir) if not file_path.exists(): return f"Error: File not found: {path}" - + content = file_path.read_text(encoding="utf-8") - + if old_text not in content: - return f"Error: old_text not found in file. Make sure it matches exactly." - + return self._not_found_message(old_text, content, path) + # Count occurrences count = content.count(old_text) if count > 1: return f"Warning: old_text appears {count} times. Please provide more context to make it unique." 
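+            # Exactly one occurrence at this point, so replacing the first match is safe.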
- + new_content = content.replace(old_text, new_text, 1) file_path.write_text(new_content, encoding="utf-8") - - return f"Successfully edited {path}" + + return f"Successfully edited {file_path}" except PermissionError as e: return f"Error: {e}" except Exception as e: return f"Error editing file: {str(e)}" + @staticmethod + def _not_found_message(old_text: str, content: str, path: str) -> str: + """Build a helpful error when old_text is not found.""" + lines = content.splitlines(keepends=True) + old_lines = old_text.splitlines(keepends=True) + window = len(old_lines) + + best_ratio, best_start = 0.0, 0 + for i in range(max(1, len(lines) - window + 1)): + ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio() + if ratio > best_ratio: + best_ratio, best_start = ratio, i + + if best_ratio > 0.5: + diff = "\n".join( + difflib.unified_diff( + old_lines, + lines[best_start : best_start + window], + fromfile="old_text (provided)", + tofile=f"{path} (actual, line {best_start + 1})", + lineterm="", + ) + ) + return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + return ( + f"Error: old_text not found in {path}. No similar text found. Verify the file content." + ) + class ListDirTool(Tool): """Tool to list directory contents.""" - - def __init__(self, allowed_dir: Path | None = None): + + def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): + self._workspace = workspace self._allowed_dir = allowed_dir @property def name(self) -> str: return "list_dir" - + @property def description(self) -> str: return "List the contents of a directory." - + @property def parameters(self) -> dict[str, Any]: return { "type": "object", - "properties": { - "path": { - "type": "string", - "description": "The directory path to list" - } - }, - "required": ["path"] + "properties": {"path": {"type": "string", "description": "The directory path to list"}}, + "required": ["path"], } - + async def execute(self, path: str, **kwargs: Any) -> str: try: - dir_path = _resolve_path(path, self._allowed_dir) + dir_path = _resolve_path(path, self._workspace, self._allowed_dir) if not dir_path.exists(): return f"Error: Directory not found: {path}" if not dir_path.is_dir(): return f"Error: Not a directory: {path}" - + items = [] for item in sorted(dir_path.iterdir()): prefix = "📁 " if item.is_dir() else "📄 " items.append(f"{prefix}{item.name}") - + if not items: return f"Directory {path} is empty" - + return "\n".join(items) except PermissionError as e: return f"Error: {e}" diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py new file mode 100644 index 0000000..2cbffd0 --- /dev/null +++ b/nanobot/agent/tools/mcp.py @@ -0,0 +1,130 @@ +"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools.""" + +import asyncio +from contextlib import AsyncExitStack +from typing import Any + +import httpx +from loguru import logger + +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry + + +class MCPToolWrapper(Tool): + """Wraps a single MCP server tool as a nanobot Tool.""" + + def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30): + self._session = session + self._original_name = tool_def.name + self._name = f"mcp_{server_name}_{tool_def.name}" + self._description = tool_def.description or tool_def.name + self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}} + self._tool_timeout = tool_timeout 
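+        # Namespacing as mcp_<server>_<tool> keeps names unique across servers
+        # and avoids collisions with nanobot's native tools.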
+ + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + @property + def parameters(self) -> dict[str, Any]: + return self._parameters + + async def execute(self, **kwargs: Any) -> str: + from mcp import types + try: + result = await asyncio.wait_for( + self._session.call_tool(self._original_name, arguments=kwargs), + timeout=self._tool_timeout, + ) + except asyncio.TimeoutError: + logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout) + return f"(MCP tool call timed out after {self._tool_timeout}s)" + parts = [] + for block in result.content: + if isinstance(block, types.TextContent): + parts.append(block.text) + else: + parts.append(str(block)) + return "\n".join(parts) or "(no output)" + + +async def connect_mcp_servers( + mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack +) -> None: + """Connect to configured MCP servers and register their tools.""" + from mcp import ClientSession, StdioServerParameters + from mcp.client.sse import sse_client + from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamable_http_client + + for name, cfg in mcp_servers.items(): + try: + transport_type = cfg.type + if not transport_type: + if cfg.command: + transport_type = "stdio" + elif cfg.url: + # Convention: URLs ending with /sse use SSE transport; others use streamableHttp + transport_type = ( + "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp" + ) + else: + logger.warning("MCP server '{}': no command or url configured, skipping", name) + continue + + if transport_type == "stdio": + params = StdioServerParameters( + command=cfg.command, args=cfg.args, env=cfg.env or None + ) + read, write = await stack.enter_async_context(stdio_client(params)) + elif transport_type == "sse": + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + merged_headers = {**(cfg.headers or {}), **(headers or {})} + return httpx.AsyncClient( + headers=merged_headers or None, + follow_redirects=True, + timeout=timeout, + auth=auth, + ) + + read, write = await stack.enter_async_context( + sse_client(cfg.url, httpx_client_factory=httpx_client_factory) + ) + elif transport_type == "streamableHttp": + # Always provide an explicit httpx client so MCP HTTP transport does not + # inherit httpx's default 5s timeout and preempt the higher-level tool timeout. 
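# --- Illustrative aside (not part of the diff) ------------------------------
# Why the explicit client matters here: httpx gives every client an implicit
# `Timeout(5.0)`, so a slow MCP tool call would be cut off by the transport
# long before the per-tool `asyncio.wait_for` above fires. A sketch of the
# options (no nanobot APIs involved):
import httpx

strict = httpx.AsyncClient()               # implicit httpx.Timeout(5.0) everywhere
patient = httpx.AsyncClient(timeout=None)  # what the code below does: defer to
                                           # the higher-level tool timeout
granular = httpx.AsyncClient(              # or keep connect/write bounded while
    timeout=httpx.Timeout(connect=10.0, read=None, write=10.0, pool=10.0)
)                                          # leaving reads open-ended
# -----------------------------------------------------------------------------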
+ http_client = await stack.enter_async_context( + httpx.AsyncClient( + headers=cfg.headers or None, + follow_redirects=True, + timeout=None, + ) + ) + read, write, _ = await stack.enter_async_context( + streamable_http_client(cfg.url, http_client=http_client) + ) + else: + logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type) + continue + + session = await stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + + tools = await session.list_tools() + for tool_def in tools.tools: + wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout) + registry.register(wrapper) + logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name) + + logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools)) + except Exception as e: + logger.error("MCP server '{}': failed to connect: {}", name, e) diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 347830f..0a52427 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -1,6 +1,6 @@ """Message tool for sending messages to users.""" -from typing import Any, Callable, Awaitable +from typing import Any, Awaitable, Callable from nanobot.agent.tools.base import Tool from nanobot.bus.events import OutboundMessage @@ -8,34 +8,42 @@ from nanobot.bus.events import OutboundMessage class MessageTool(Tool): """Tool to send messages to users on chat channels.""" - + def __init__( - self, + self, send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None, default_channel: str = "", - default_chat_id: str = "" + default_chat_id: str = "", + default_message_id: str | None = None, ): self._send_callback = send_callback self._default_channel = default_channel self._default_chat_id = default_chat_id - - def set_context(self, channel: str, chat_id: str) -> None: + self._default_message_id = default_message_id + self._sent_in_turn: bool = False + + def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" self._default_channel = channel self._default_chat_id = chat_id - + self._default_message_id = message_id + def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None: """Set the callback for sending messages.""" self._send_callback = callback - + + def start_turn(self) -> None: + """Reset per-turn send tracking.""" + self._sent_in_turn = False + @property def name(self) -> str: return "message" - + @property def description(self) -> str: return "Send a message to the user. Use this when you want to communicate something." 
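# --- Illustrative aside (not part of the diff) ------------------------------
# How the per-turn flag above is meant to be consumed. `agent_step` and
# `send_fallback` are hypothetical stand-ins, not nanobot APIs: reset the
# flag before each turn, then only synthesize a fallback reply when the
# model never messaged the originating chat itself.
async def run_turn(message_tool, inbound, agent_step, send_fallback) -> None:
    message_tool.start_turn()                     # clears _sent_in_turn
    reply_text = await agent_step(inbound)        # may invoke the message tool
    if reply_text and not message_tool._sent_in_turn:
        await send_fallback(inbound, reply_text)  # avoid double-sending
# -----------------------------------------------------------------------------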
- + @property def parameters(self) -> dict[str, Any]: return { @@ -52,35 +60,50 @@ class MessageTool(Tool): "chat_id": { "type": "string", "description": "Optional: target chat/user ID" + }, + "media": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional: list of file paths to attach (images, audio, documents)" } }, "required": ["content"] } - + async def execute( - self, - content: str, - channel: str | None = None, + self, + content: str, + channel: str | None = None, chat_id: str | None = None, + message_id: str | None = None, + media: list[str] | None = None, **kwargs: Any ) -> str: channel = channel or self._default_channel chat_id = chat_id or self._default_chat_id - + message_id = message_id or self._default_message_id + if not channel or not chat_id: return "Error: No target channel/chat specified" - + if not self._send_callback: return "Error: Message sending not configured" - + msg = OutboundMessage( channel=channel, chat_id=chat_id, - content=content + content=content, + media=media or [], + metadata={ + "message_id": message_id, + }, ) - + try: await self._send_callback(msg) - return f"Message sent to {channel}:{chat_id}" + if channel == self._default_channel and chat_id == self._default_chat_id: + self._sent_in_turn = True + media_info = f" with {len(media)} attachments" if media else "" + return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: return f"Error sending message: {str(e)}" diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index d9b33ff..896491f 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -8,66 +8,63 @@ from nanobot.agent.tools.base import Tool class ToolRegistry: """ Registry for agent tools. - + Allows dynamic registration and execution of tools. """ - + def __init__(self): self._tools: dict[str, Tool] = {} - + def register(self, tool: Tool) -> None: """Register a tool.""" self._tools[tool.name] = tool - + def unregister(self, name: str) -> None: """Unregister a tool by name.""" self._tools.pop(name, None) - + def get(self, name: str) -> Tool | None: """Get a tool by name.""" return self._tools.get(name) - + def has(self, name: str) -> bool: """Check if a tool is registered.""" return name in self._tools - + def get_definitions(self) -> list[dict[str, Any]]: """Get all tool definitions in OpenAI format.""" return [tool.to_schema() for tool in self._tools.values()] - + async def execute(self, name: str, params: dict[str, Any]) -> str: - """ - Execute a tool by name with given parameters. - - Args: - name: Tool name. - params: Tool parameters. - - Returns: - Tool execution result as string. - - Raises: - KeyError: If tool not found. - """ + """Execute a tool by name with given parameters.""" + _HINT = "\n\n[Analyze the error above and try a different approach.]" + tool = self._tools.get(name) if not tool: - return f"Error: Tool '{name}' not found" + return f"Error: Tool '{name}' not found. 
Available: {', '.join(self.tool_names)}" try: + # Attempt to cast parameters to match schema types + params = tool.cast_params(params) + + # Validate parameters errors = tool.validate_params(params) if errors: - return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) - return await tool.execute(**params) + return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT + result = await tool.execute(**params) + if isinstance(result, str) and result.startswith("Error"): + return result + _HINT + return result except Exception as e: - return f"Error executing {name}: {str(e)}" - + return f"Error executing {name}: {str(e)}" + _HINT + @property def tool_names(self) -> list[str]: """Get list of registered tool names.""" return list(self._tools.keys()) - + def __len__(self) -> int: return len(self._tools) - + def __contains__(self, name: str) -> bool: return name in self._tools diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 18eff64..ce19920 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -11,7 +11,7 @@ from nanobot.agent.tools.base import Tool class ExecTool(Tool): """Tool to execute shell commands.""" - + def __init__( self, timeout: int = 60, @@ -19,6 +19,7 @@ class ExecTool(Tool): deny_patterns: list[str] | None = None, allow_patterns: list[str] | None = None, restrict_to_workspace: bool = False, + path_append: str = "", ): self.timeout = timeout self.working_dir = working_dir @@ -26,7 +27,8 @@ class ExecTool(Tool): r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr r"\bdel\s+/[fq]\b", # del /f, del /q r"\brmdir\s+/s\b", # rmdir /s - r"\b(format|mkfs|diskpart)\b", # disk operations + r"(?:^|[;&|]\s*)format\b", # format (as standalone command only) + r"\b(mkfs|diskpart)\b", # disk operations r"\bdd\s+if=", # dd r">\s*/dev/sd", # write to disk r"\b(shutdown|reboot|poweroff)\b", # system power @@ -34,15 +36,16 @@ class ExecTool(Tool): ] self.allow_patterns = allow_patterns or [] self.restrict_to_workspace = restrict_to_workspace - + self.path_append = path_append + @property def name(self) -> str: return "exec" - + @property def description(self) -> str: return "Execute a shell command and return its output. Use with caution." - + @property def parameters(self) -> dict[str, Any]: return { @@ -66,12 +69,17 @@ class ExecTool(Tool): if guard_error: return guard_error + env = os.environ.copy() + if self.path_append: + env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append + try: process = await asyncio.create_subprocess_shell( command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=cwd, + env=env, ) try: @@ -81,6 +89,12 @@ class ExecTool(Tool): ) except asyncio.TimeoutError: process.kill() + # Wait for the process to fully terminate so pipes are + # drained and file descriptors are released. + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + pass return f"Error: Command timed out after {self.timeout} seconds" output_parts = [] @@ -127,13 +141,7 @@ class ExecTool(Tool): cwd_path = Path(cwd).resolve() - win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) - # Only match absolute paths — avoid false positives on relative - # paths like ".venv/bin/python" where "/bin/python" would be - # incorrectly extracted by the old pattern. 
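# --- Illustrative aside (not part of the diff) ------------------------------
# Quick check of the relaxed `format` rule above: the old `\bformat\b`
# blocked any command merely mentioning the word, while the new pattern
# only fires when `format` starts a command or follows ; & |.
import re

FORMAT_RE = re.compile(r"(?:^|[;&|]\s*)format\b")

assert FORMAT_RE.search("format C:")                     # standalone: blocked
assert FORMAT_RE.search("echo hi; format C:")            # after separator: blocked
assert not FORMAT_RE.search("git log --format=oneline")  # flag: allowed
assert not FORMAT_RE.search("cat format_notes.md")       # filename: allowed
# -----------------------------------------------------------------------------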
- posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd) - - for raw in win_paths + posix_paths: + for raw in self._extract_absolute_paths(cmd): try: p = Path(raw.strip()).resolve() except Exception: @@ -142,3 +150,9 @@ class ExecTool(Tool): return "Error: Command blocked by safety guard (path outside working dir)" return None + + @staticmethod + def _extract_absolute_paths(command: str) -> list[str]: + win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\... + posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only + return win_paths + posix_paths diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py index 5884a07..fc62bf8 100644 --- a/nanobot/agent/tools/spawn.py +++ b/nanobot/agent/tools/spawn.py @@ -1,6 +1,6 @@ """Spawn tool for creating background subagents.""" -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any from nanobot.agent.tools.base import Tool @@ -9,27 +9,24 @@ if TYPE_CHECKING: class SpawnTool(Tool): - """ - Tool to spawn a subagent for background task execution. - - The subagent runs asynchronously and announces its result back - to the main agent when complete. - """ - + """Tool to spawn a subagent for background task execution.""" + def __init__(self, manager: "SubagentManager"): self._manager = manager self._origin_channel = "cli" self._origin_chat_id = "direct" - + self._session_key = "cli:direct" + def set_context(self, channel: str, chat_id: str) -> None: """Set the origin context for subagent announcements.""" self._origin_channel = channel self._origin_chat_id = chat_id - + self._session_key = f"{channel}:{chat_id}" + @property def name(self) -> str: return "spawn" - + @property def description(self) -> str: return ( @@ -37,7 +34,7 @@ class SpawnTool(Tool): "Use this for complex or time-consuming tasks that can run independently. " "The subagent will complete the task and report back when done." ) - + @property def parameters(self) -> dict[str, Any]: return { @@ -54,7 +51,7 @@ class SpawnTool(Tool): }, "required": ["task"], } - + async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str: """Spawn a subagent to execute the given task.""" return await self._manager.spawn( @@ -62,4 +59,5 @@ class SpawnTool(Tool): label=label, origin_channel=self._origin_channel, origin_chat_id=self._origin_chat_id, + session_key=self._session_key, ) diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 9de1d3c..0d8f4d1 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -8,6 +8,7 @@ from typing import Any from urllib.parse import urlparse import httpx +from loguru import logger from nanobot.agent.tools.base import Tool @@ -45,7 +46,7 @@ def _validate_url(url: str) -> tuple[bool, str]: class WebSearchTool(Tool): """Search the web using Brave Search API.""" - + name = "web_search" description = "Search the web. Returns titles, URLs, and snippets." 
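# --- Illustrative aside (not part of the diff) ------------------------------
# Behavior of `_extract_absolute_paths` above: only absolute paths are
# pulled from the command, so relative invocations such as
# `.venv/bin/python` no longer trip the workspace guard. Sample inputs
# below are assumptions for illustration.
import re

def extract_abs_paths(command: str) -> list[str]:
    win = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command)     # Windows drives
    posix = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command)  # rooted POSIX
    return win + posix

assert extract_abs_paths(".venv/bin/python app.py") == []
assert extract_abs_paths("cat /etc/passwd") == ["/etc/passwd"]
assert extract_abs_paths(r"type C:\logs\bot.log") == ["C:\\logs\\bot.log"]
# -----------------------------------------------------------------------------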
parameters = { @@ -56,18 +57,29 @@ class WebSearchTool(Tool): }, "required": ["query"] } - - def __init__(self, api_key: str | None = None, max_results: int = 5): - self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "") + + def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None): + self._init_api_key = api_key self.max_results = max_results - + self.proxy = proxy + + @property + def api_key(self) -> str: + """Resolve API key at call time so env/config changes are picked up.""" + return self._init_api_key or os.environ.get("BRAVE_API_KEY", "") + async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: if not self.api_key: - return "Error: BRAVE_API_KEY not configured" - + return ( + "Error: Brave Search API key not configured. Set it in " + "~/.nanobot/config.json under tools.web.search.apiKey " + "(or export BRAVE_API_KEY), then restart the gateway." + ) + try: n = min(max(count or self.max_results, 1), 10) - async with httpx.AsyncClient() as client: + logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection") + async with httpx.AsyncClient(proxy=self.proxy) as client: r = await client.get( "https://api.search.brave.com/res/v1/web/search", params={"q": query, "count": n}, @@ -75,24 +87,28 @@ class WebSearchTool(Tool): timeout=10.0 ) r.raise_for_status() - - results = r.json().get("web", {}).get("results", []) + + results = r.json().get("web", {}).get("results", [])[:n] if not results: return f"No results for: {query}" - + lines = [f"Results for: {query}\n"] - for i, item in enumerate(results[:n], 1): + for i, item in enumerate(results, 1): lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}") if desc := item.get("description"): lines.append(f" {desc}") return "\n".join(lines) + except httpx.ProxyError as e: + logger.error("WebSearch proxy error: {}", e) + return f"Proxy error: {e}" except Exception as e: + logger.error("WebSearch error: {}", e) return f"Error: {e}" class WebFetchTool(Tool): """Fetch and extract content from a URL using Readability.""" - + name = "web_fetch" description = "Fetch URL and extract readable content (HTML → markdown/text)." 
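# --- Illustrative aside (not part of the diff) ------------------------------
# Minimal smoke test for the proxy plumbing above; the proxy URL and target
# are assumptions. Both web tools pass `proxy` straight through to
# httpx.AsyncClient, so the behavior can be checked in isolation:
import asyncio
import httpx

async def check_proxy() -> None:
    async with httpx.AsyncClient(proxy="http://127.0.0.1:7890") as client:
        r = await client.get("https://example.com", timeout=10.0)
        print(r.status_code)

if __name__ == "__main__":
    asyncio.run(check_proxy())
# -----------------------------------------------------------------------------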
parameters = { @@ -104,35 +120,34 @@ class WebFetchTool(Tool): }, "required": ["url"] } - - def __init__(self, max_chars: int = 50000): + + def __init__(self, max_chars: int = 50000, proxy: str | None = None): self.max_chars = max_chars - + self.proxy = proxy + async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str: from readability import Document max_chars = maxChars or self.max_chars - - # Validate URL before fetching is_valid, error_msg = _validate_url(url) if not is_valid: - return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}) + return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False) try: + logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection") async with httpx.AsyncClient( follow_redirects=True, max_redirects=MAX_REDIRECTS, - timeout=30.0 + timeout=30.0, + proxy=self.proxy, ) as client: r = await client.get(url, headers={"User-Agent": USER_AGENT}) r.raise_for_status() - + ctype = r.headers.get("content-type", "") - - # JSON + if "application/json" in ctype: - text, extractor = json.dumps(r.json(), indent=2), "json" - # HTML + text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json" elif "text/html" in ctype or r.text[:256].lower().startswith((" max_chars - if truncated: - text = text[:max_chars] - + if truncated: text = text[:max_chars] + return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code, - "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}) + "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False) + except httpx.ProxyError as e: + logger.error("WebFetch proxy error for {}: {}", url, e) + return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False) except Exception as e: - return json.dumps({"error": str(e), "url": url}) - + logger.error("WebFetch error for {}: {}", url, e) + return json.dumps({"error": str(e), "url": url}, ensure_ascii=False) + def _to_markdown(self, html: str) -> str: """Convert HTML to markdown.""" # Convert links, headings, lists before stripping tags diff --git a/nanobot/bus/events.py b/nanobot/bus/events.py index a149e20..018c25b 100644 --- a/nanobot/bus/events.py +++ b/nanobot/bus/events.py @@ -8,7 +8,7 @@ from typing import Any @dataclass class InboundMessage: """Message received from a chat channel.""" - + channel: str # telegram, discord, slack, whatsapp sender_id: str # User identifier chat_id: str # Chat/channel identifier @@ -16,17 +16,18 @@ class InboundMessage: timestamp: datetime = field(default_factory=datetime.now) media: list[str] = field(default_factory=list) # Media URLs metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data - + session_key_override: str | None = None # Optional override for thread-scoped sessions + @property def session_key(self) -> str: """Unique key for session identification.""" - return f"{self.channel}:{self.chat_id}" + return self.session_key_override or f"{self.channel}:{self.chat_id}" @dataclass class OutboundMessage: """Message to send to a chat channel.""" - + channel: str chat_id: str content: str diff --git a/nanobot/bus/queue.py b/nanobot/bus/queue.py index 4123d06..7c0616f 100644 --- a/nanobot/bus/queue.py +++ b/nanobot/bus/queue.py @@ -1,9 +1,6 @@ """Async message queue for decoupled channel-agent communication.""" import asyncio -from typing import Callable, Awaitable - -from loguru 
import logger from nanobot.bus.events import InboundMessage, OutboundMessage @@ -11,70 +8,36 @@ from nanobot.bus.events import InboundMessage, OutboundMessage class MessageBus: """ Async message bus that decouples chat channels from the agent core. - + Channels push messages to the inbound queue, and the agent processes them and pushes responses to the outbound queue. """ - + def __init__(self): self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue() self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue() - self._outbound_subscribers: dict[str, list[Callable[[OutboundMessage], Awaitable[None]]]] = {} - self._running = False - + async def publish_inbound(self, msg: InboundMessage) -> None: """Publish a message from a channel to the agent.""" await self.inbound.put(msg) - + async def consume_inbound(self) -> InboundMessage: """Consume the next inbound message (blocks until available).""" return await self.inbound.get() - + async def publish_outbound(self, msg: OutboundMessage) -> None: """Publish a response from the agent to channels.""" await self.outbound.put(msg) - + async def consume_outbound(self) -> OutboundMessage: """Consume the next outbound message (blocks until available).""" return await self.outbound.get() - - def subscribe_outbound( - self, - channel: str, - callback: Callable[[OutboundMessage], Awaitable[None]] - ) -> None: - """Subscribe to outbound messages for a specific channel.""" - if channel not in self._outbound_subscribers: - self._outbound_subscribers[channel] = [] - self._outbound_subscribers[channel].append(callback) - - async def dispatch_outbound(self) -> None: - """ - Dispatch outbound messages to subscribed channels. - Run this as a background task. - """ - self._running = True - while self._running: - try: - msg = await asyncio.wait_for(self.outbound.get(), timeout=1.0) - subscribers = self._outbound_subscribers.get(msg.channel, []) - for callback in subscribers: - try: - await callback(msg) - except Exception as e: - logger.error(f"Error dispatching to {msg.channel}: {e}") - except asyncio.TimeoutError: - continue - - def stop(self) -> None: - """Stop the dispatcher loop.""" - self._running = False - + @property def inbound_size(self) -> int: """Number of pending inbound messages.""" return self.inbound.qsize() - + @property def outbound_size(self) -> int: """Number of pending outbound messages.""" diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 30fcd1a..b38fcaf 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus class BaseChannel(ABC): """ Abstract base class for chat channel implementations. - + Each channel (Telegram, Discord, etc.) should implement this interface to integrate with the nanobot message bus. """ - + name: str = "base" - + def __init__(self, config: Any, bus: MessageBus): """ Initialize the channel. - + Args: config: Channel-specific configuration. bus: The message bus for communication. @@ -30,97 +30,89 @@ class BaseChannel(ABC): self.config = config self.bus = bus self._running = False - + @abstractmethod async def start(self) -> None: """ Start the channel and begin listening for messages. - + This should be a long-running async task that: 1. Connects to the chat platform 2. Listens for incoming messages 3. 
Forwards messages to the bus via _handle_message() """ pass - + @abstractmethod async def stop(self) -> None: """Stop the channel and clean up resources.""" pass - + @abstractmethod async def send(self, msg: OutboundMessage) -> None: """ Send a message through this channel. - + Args: msg: The message to send. """ pass - + def is_allowed(self, sender_id: str) -> bool: - """ - Check if a sender is allowed to use this bot. - - Args: - sender_id: The sender's identifier. - - Returns: - True if allowed, False otherwise. - """ + """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all.""" allow_list = getattr(self.config, "allow_from", []) - - # If no allow list, allow everyone if not allow_list: + logger.warning("{}: allow_from is empty — all access denied", self.name) + return False + if "*" in allow_list: return True - sender_str = str(sender_id) - if sender_str in allow_list: - return True - if "|" in sender_str: - for part in sender_str.split("|"): - if part and part in allow_list: - return True - return False - + return sender_str in allow_list or any( + p in allow_list for p in sender_str.split("|") if p + ) + async def _handle_message( self, sender_id: str, chat_id: str, content: str, media: list[str] | None = None, - metadata: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None, + session_key: str | None = None, ) -> None: """ Handle an incoming message from the chat platform. - + This method checks permissions and forwards to the bus. - + Args: sender_id: The sender's identifier. chat_id: The chat/channel identifier. content: Message text content. media: Optional list of media URLs. metadata: Optional channel-specific metadata. + session_key: Optional session key override (e.g. thread-scoped sessions). """ if not self.is_allowed(sender_id): logger.warning( - f"Access denied for sender {sender_id} on channel {self.name}. " - f"Add them to allowFrom list in config to grant access." + "Access denied for sender {} on channel {}. 
" + "Add them to allowFrom list in config to grant access.", + sender_id, self.name, ) return - + msg = InboundMessage( channel=self.name, sender_id=str(sender_id), chat_id=str(chat_id), content=content, media=media or [], - metadata=metadata or {} + metadata=metadata or {}, + session_key_override=session_key, ) - + await self.bus.publish_inbound(msg) - + @property def is_running(self) -> bool: """Check if the channel is running.""" diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 4a8cdd9..3c301a9 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -2,11 +2,15 @@ import asyncio import json +import mimetypes +import os import time +from pathlib import Path from typing import Any +from urllib.parse import unquote, urlparse -from loguru import logger import httpx +from loguru import logger from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus @@ -15,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig try: from dingtalk_stream import ( - DingTalkStreamClient, - Credential, + AckMessage, CallbackHandler, CallbackMessage, - AckMessage, + Credential, + DingTalkStreamClient, ) from dingtalk_stream.chatbot import ChatbotMessage @@ -58,19 +62,32 @@ class NanobotDingTalkHandler(CallbackHandler): if not content: logger.warning( - f"Received empty or unsupported message type: {chatbot_msg.message_type}" + "Received empty or unsupported message type: {}", + chatbot_msg.message_type, ) return AckMessage.STATUS_OK, "OK" sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id sender_name = chatbot_msg.sender_nick or "Unknown" - logger.info(f"Received DingTalk message from {sender_name} ({sender_id}): {content}") + conversation_type = message.data.get("conversationType") + conversation_id = ( + message.data.get("conversationId") + or message.data.get("openConversationId") + ) + + logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content) # Forward to Nanobot via _on_message (non-blocking). # Store reference to prevent GC before task completes. task = asyncio.create_task( - self.channel._on_message(content, sender_id, sender_name) + self.channel._on_message( + content, + sender_id, + sender_name, + conversation_type, + conversation_id, + ) ) self.channel._background_tasks.add(task) task.add_done_callback(self.channel._background_tasks.discard) @@ -78,7 +95,7 @@ class NanobotDingTalkHandler(CallbackHandler): return AckMessage.STATUS_OK, "OK" except Exception as e: - logger.error(f"Error processing DingTalk message: {e}") + logger.error("Error processing DingTalk message: {}", e) # Return OK to avoid retry loop from DingTalk server return AckMessage.STATUS_OK, "Error" @@ -90,11 +107,14 @@ class DingTalkChannel(BaseChannel): Uses WebSocket to receive events via `dingtalk-stream` SDK. Uses direct HTTP API to send messages (SDK is mainly for receiving). - Note: Currently only supports private (1:1) chat. Group messages are - received but replies are sent back as private messages to the sender. + Supports both private (1:1) and group chats. + Group chat_id is stored with a "group:" prefix to route replies back. 
""" name = "dingtalk" + _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} + _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} + _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} def __init__(self, config: DingTalkConfig, bus: MessageBus): super().__init__(config, bus) @@ -126,7 +146,8 @@ class DingTalkChannel(BaseChannel): self._http = httpx.AsyncClient() logger.info( - f"Initializing DingTalk Stream Client with Client ID: {self.config.client_id}..." + "Initializing DingTalk Stream Client with Client ID: {}...", + self.config.client_id, ) credential = Credential(self.config.client_id, self.config.client_secret) self._client = DingTalkStreamClient(credential) @@ -142,13 +163,13 @@ class DingTalkChannel(BaseChannel): try: await self._client.start() except Exception as e: - logger.warning(f"DingTalk stream error: {e}") + logger.warning("DingTalk stream error: {}", e) if self._running: logger.info("Reconnecting DingTalk stream in 5 seconds...") await asyncio.sleep(5) except Exception as e: - logger.exception(f"Failed to start DingTalk channel: {e}") + logger.exception("Failed to start DingTalk channel: {}", e) async def stop(self) -> None: """Stop the DingTalk bot.""" @@ -186,60 +207,265 @@ class DingTalkChannel(BaseChannel): self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60 return self._access_token except Exception as e: - logger.error(f"Failed to get DingTalk access token: {e}") + logger.error("Failed to get DingTalk access token: {}", e) return None + @staticmethod + def _is_http_url(value: str) -> bool: + return urlparse(value).scheme in ("http", "https") + + def _guess_upload_type(self, media_ref: str) -> str: + ext = Path(urlparse(media_ref).path).suffix.lower() + if ext in self._IMAGE_EXTS: return "image" + if ext in self._AUDIO_EXTS: return "voice" + if ext in self._VIDEO_EXTS: return "video" + return "file" + + def _guess_filename(self, media_ref: str, upload_type: str) -> str: + name = os.path.basename(urlparse(media_ref).path) + return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin") + + async def _read_media_bytes( + self, + media_ref: str, + ) -> tuple[bytes | None, str | None, str | None]: + if not media_ref: + return None, None, None + + if self._is_http_url(media_ref): + if not self._http: + return None, None, None + try: + resp = await self._http.get(media_ref, follow_redirects=True) + if resp.status_code >= 400: + logger.warning( + "DingTalk media download failed status={} ref={}", + resp.status_code, + media_ref, + ) + return None, None, None + content_type = (resp.headers.get("content-type") or "").split(";")[0].strip() + filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) + return resp.content, filename, content_type or None + except Exception as e: + logger.error("DingTalk media download error ref={} err={}", media_ref, e) + return None, None, None + + try: + if media_ref.startswith("file://"): + parsed = urlparse(media_ref) + local_path = Path(unquote(parsed.path)) + else: + local_path = Path(os.path.expanduser(media_ref)) + if not local_path.is_file(): + logger.warning("DingTalk media file not found: {}", local_path) + return None, None, None + data = await asyncio.to_thread(local_path.read_bytes) + content_type = mimetypes.guess_type(local_path.name)[0] + return data, local_path.name, content_type + except Exception as e: + logger.error("DingTalk media read error ref={} err={}", media_ref, e) + return None, None, None + + async 
def _upload_media( + self, + token: str, + data: bytes, + media_type: str, + filename: str, + content_type: str | None, + ) -> str | None: + if not self._http: + return None + url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}" + mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + files = {"media": (filename, data, mime)} + + try: + resp = await self._http.post(url, files=files) + text = resp.text + result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {} + if resp.status_code >= 400: + logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500]) + return None + errcode = result.get("errcode", 0) + if errcode != 0: + logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500]) + return None + sub = result.get("result") or {} + media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId") + if not media_id: + logger.error("DingTalk media upload missing media_id body={}", text[:500]) + return None + return str(media_id) + except Exception as e: + logger.error("DingTalk media upload error type={} err={}", media_type, e) + return None + + async def _send_batch_message( + self, + token: str, + chat_id: str, + msg_key: str, + msg_param: dict[str, Any], + ) -> bool: + if not self._http: + logger.warning("DingTalk HTTP client not initialized, cannot send") + return False + + headers = {"x-acs-dingtalk-access-token": token} + if chat_id.startswith("group:"): + # Group chat + url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send" + payload = { + "robotCode": self.config.client_id, + "openConversationId": chat_id[6:], # Remove "group:" prefix, + "msgKey": msg_key, + "msgParam": json.dumps(msg_param, ensure_ascii=False), + } + else: + # Private chat + url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend" + payload = { + "robotCode": self.config.client_id, + "userIds": [chat_id], + "msgKey": msg_key, + "msgParam": json.dumps(msg_param, ensure_ascii=False), + } + + try: + resp = await self._http.post(url, json=payload, headers=headers) + body = resp.text + if resp.status_code != 200: + logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500]) + return False + try: result = resp.json() + except Exception: result = {} + errcode = result.get("errcode") + if errcode not in (None, 0): + logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500]) + return False + logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key) + return True + except Exception as e: + logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e) + return False + + async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool: + return await self._send_batch_message( + token, + chat_id, + "sampleMarkdown", + {"text": content, "title": "Nanobot Reply"}, + ) + + async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool: + media_ref = (media_ref or "").strip() + if not media_ref: + return True + + upload_type = self._guess_upload_type(media_ref) + if upload_type == "image" and self._is_http_url(media_ref): + ok = await self._send_batch_message( + token, + chat_id, + "sampleImageMsg", + {"photoURL": media_ref}, + ) + if ok: + return True + logger.warning("DingTalk image url send failed, trying upload 
fallback: {}", media_ref) + + data, filename, content_type = await self._read_media_bytes(media_ref) + if not data: + logger.error("DingTalk media read failed: {}", media_ref) + return False + + filename = filename or self._guess_filename(media_ref, upload_type) + file_type = Path(filename).suffix.lower().lstrip(".") + if not file_type: + guessed = mimetypes.guess_extension(content_type or "") + file_type = (guessed or ".bin").lstrip(".") + if file_type == "jpeg": + file_type = "jpg" + + media_id = await self._upload_media( + token=token, + data=data, + media_type=upload_type, + filename=filename, + content_type=content_type, + ) + if not media_id: + return False + + if upload_type == "image": + # Verified in production: sampleImageMsg accepts media_id in photoURL. + ok = await self._send_batch_message( + token, + chat_id, + "sampleImageMsg", + {"photoURL": media_id}, + ) + if ok: + return True + logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref) + + return await self._send_batch_message( + token, + chat_id, + "sampleFile", + {"mediaId": media_id, "fileName": filename, "fileType": file_type}, + ) + async def send(self, msg: OutboundMessage) -> None: """Send a message through DingTalk.""" token = await self._get_access_token() if not token: return - # oToMessages/batchSend: sends to individual users (private chat) - # https://open.dingtalk.com/document/orgapp/robot-batch-send-messages - url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend" + if msg.content and msg.content.strip(): + await self._send_markdown_text(token, msg.chat_id, msg.content.strip()) - headers = {"x-acs-dingtalk-access-token": token} + for media_ref in msg.media or []: + ok = await self._send_media_ref(token, msg.chat_id, media_ref) + if ok: + continue + logger.error("DingTalk media send failed for {}", media_ref) + # Send visible fallback so failures are observable by the user. + filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) + await self._send_markdown_text( + token, + msg.chat_id, + f"[Attachment send failed: {filename}]", + ) - data = { - "robotCode": self.config.client_id, - "userIds": [msg.chat_id], # chat_id is the user's staffId - "msgKey": "sampleMarkdown", - "msgParam": json.dumps({ - "text": msg.content, - "title": "Nanobot Reply", - }), - } - - if not self._http: - logger.warning("DingTalk HTTP client not initialized, cannot send") - return - - try: - resp = await self._http.post(url, json=data, headers=headers) - if resp.status_code != 200: - logger.error(f"DingTalk send failed: {resp.text}") - else: - logger.debug(f"DingTalk message sent to {msg.chat_id}") - except Exception as e: - logger.error(f"Error sending DingTalk message: {e}") - - async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None: + async def _on_message( + self, + content: str, + sender_id: str, + sender_name: str, + conversation_type: str | None = None, + conversation_id: str | None = None, + ) -> None: """Handle incoming message (called by NanobotDingTalkHandler). Delegates to BaseChannel._handle_message() which enforces allow_from permission checks before publishing to the bus. 
""" try: - logger.info(f"DingTalk inbound: {content} from {sender_name}") + logger.info("DingTalk inbound: {} from {}", content, sender_name) + is_group = conversation_type == "2" and conversation_id + chat_id = f"group:{conversation_id}" if is_group else sender_id await self._handle_message( sender_id=sender_id, - chat_id=sender_id, # For private chat, chat_id == sender_id + chat_id=chat_id, content=str(content), metadata={ "sender_name": sender_name, "platform": "dingtalk", + "conversation_type": conversation_type, }, ) except Exception as e: - logger.error(f"Error publishing DingTalk message: {e}") + logger.error("Error publishing DingTalk message: {}", e) diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index a76d6ac..0187c62 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -13,10 +13,11 @@ from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import DiscordConfig - +from nanobot.utils.helpers import split_message DISCORD_API_BASE = "https://discord.com/api/v10" MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB +MAX_MESSAGE_LEN = 2000 # Discord message character limit class DiscordChannel(BaseChannel): @@ -32,6 +33,7 @@ class DiscordChannel(BaseChannel): self._heartbeat_task: asyncio.Task | None = None self._typing_tasks: dict[str, asyncio.Task] = {} self._http: httpx.AsyncClient | None = None + self._bot_user_id: str | None = None async def start(self) -> None: """Start the Discord gateway connection.""" @@ -51,7 +53,7 @@ class DiscordChannel(BaseChannel): except asyncio.CancelledError: break except Exception as e: - logger.warning(f"Discord gateway error: {e}") + logger.warning("Discord gateway error: {}", e) if self._running: logger.info("Reconnecting to Discord gateway in 5 seconds...") await asyncio.sleep(5) @@ -73,40 +75,118 @@ class DiscordChannel(BaseChannel): self._http = None async def send(self, msg: OutboundMessage) -> None: - """Send a message through Discord REST API.""" + """Send a message through Discord REST API, including file attachments.""" if not self._http: logger.warning("Discord HTTP client not initialized") return url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages" - payload: dict[str, Any] = {"content": msg.content} - - if msg.reply_to: - payload["message_reference"] = {"message_id": msg.reply_to} - payload["allowed_mentions"] = {"replied_user": False} - headers = {"Authorization": f"Bot {self.config.token}"} try: - for attempt in range(3): - try: - response = await self._http.post(url, headers=headers, json=payload) - if response.status_code == 429: - data = response.json() - retry_after = float(data.get("retry_after", 1.0)) - logger.warning(f"Discord rate limited, retrying in {retry_after}s") - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - return - except Exception as e: - if attempt == 2: - logger.error(f"Error sending Discord message: {e}") - else: - await asyncio.sleep(1) + sent_media = False + failed_media: list[str] = [] + + # Send file attachments first + for media_path in msg.media or []: + if await self._send_file(url, headers, media_path, reply_to=msg.reply_to): + sent_media = True + else: + failed_media.append(Path(media_path).name) + + # Send text content + chunks = split_message(msg.content or "", MAX_MESSAGE_LEN) + if not chunks and failed_media and not sent_media: + chunks = split_message( + "\n".join(f"[attachment: {name} - send failed]" for name in 
failed_media), + MAX_MESSAGE_LEN, + ) + if not chunks: + return + + for i, chunk in enumerate(chunks): + payload: dict[str, Any] = {"content": chunk} + + # Let the first successful attachment carry the reply if present. + if i == 0 and msg.reply_to and not sent_media: + payload["message_reference"] = {"message_id": msg.reply_to} + payload["allowed_mentions"] = {"replied_user": False} + + if not await self._send_payload(url, headers, payload): + break # Abort remaining chunks on failure finally: await self._stop_typing(msg.chat_id) + async def _send_payload( + self, url: str, headers: dict[str, str], payload: dict[str, Any] + ) -> bool: + """Send a single Discord API payload with retry on rate-limit. Returns True on success.""" + for attempt in range(3): + try: + response = await self._http.post(url, headers=headers, json=payload) + if response.status_code == 429: + data = response.json() + retry_after = float(data.get("retry_after", 1.0)) + logger.warning("Discord rate limited, retrying in {}s", retry_after) + await asyncio.sleep(retry_after) + continue + response.raise_for_status() + return True + except Exception as e: + if attempt == 2: + logger.error("Error sending Discord message: {}", e) + else: + await asyncio.sleep(1) + return False + + async def _send_file( + self, + url: str, + headers: dict[str, str], + file_path: str, + reply_to: str | None = None, + ) -> bool: + """Send a file attachment via Discord REST API using multipart/form-data.""" + path = Path(file_path) + if not path.is_file(): + logger.warning("Discord file not found, skipping: {}", file_path) + return False + + if path.stat().st_size > MAX_ATTACHMENT_BYTES: + logger.warning("Discord file too large (>20MB), skipping: {}", path.name) + return False + + payload_json: dict[str, Any] = {} + if reply_to: + payload_json["message_reference"] = {"message_id": reply_to} + payload_json["allowed_mentions"] = {"replied_user": False} + + for attempt in range(3): + try: + with open(path, "rb") as f: + files = {"files[0]": (path.name, f, "application/octet-stream")} + data: dict[str, Any] = {} + if payload_json: + data["payload_json"] = json.dumps(payload_json) + response = await self._http.post( + url, headers=headers, files=files, data=data + ) + if response.status_code == 429: + resp_data = response.json() + retry_after = float(resp_data.get("retry_after", 1.0)) + logger.warning("Discord rate limited, retrying in {}s", retry_after) + await asyncio.sleep(retry_after) + continue + response.raise_for_status() + logger.info("Discord file sent: {}", path.name) + return True + except Exception as e: + if attempt == 2: + logger.error("Error sending Discord file {}: {}", path.name, e) + else: + await asyncio.sleep(1) + return False + async def _gateway_loop(self) -> None: """Main gateway loop: identify, heartbeat, dispatch events.""" if not self._ws: @@ -116,7 +196,7 @@ class DiscordChannel(BaseChannel): try: data = json.loads(raw) except json.JSONDecodeError: - logger.warning(f"Invalid JSON from Discord gateway: {raw[:100]}") + logger.warning("Invalid JSON from Discord gateway: {}", raw[:100]) continue op = data.get("op") @@ -134,6 +214,10 @@ class DiscordChannel(BaseChannel): await self._identify() elif op == 0 and event_type == "READY": logger.info("Discord gateway READY") + # Capture bot user ID for mention detection + user_data = payload.get("user") or {} + self._bot_user_id = user_data.get("id") + logger.info("Discord bot connected as user {}", self._bot_user_id) elif op == 0 and event_type == "MESSAGE_CREATE": await 
self._handle_message_create(payload) elif op == 7: @@ -175,7 +259,7 @@ class DiscordChannel(BaseChannel): try: await self._ws.send(json.dumps(payload)) except Exception as e: - logger.warning(f"Discord heartbeat failed: {e}") + logger.warning("Discord heartbeat failed: {}", e) break await asyncio.sleep(interval_s) @@ -190,6 +274,7 @@ class DiscordChannel(BaseChannel): sender_id = str(author.get("id", "")) channel_id = str(payload.get("channel_id", "")) content = payload.get("content") or "" + guild_id = payload.get("guild_id") if not sender_id or not channel_id: return @@ -197,6 +282,11 @@ class DiscordChannel(BaseChannel): if not self.is_allowed(sender_id): return + # Check group channel policy (DMs always respond if is_allowed passes) + if guild_id is not None: + if not self._should_respond_in_group(payload, content): + return + content_parts = [content] if content else [] media_paths: list[str] = [] media_dir = Path.home() / ".nanobot" / "media" @@ -219,7 +309,7 @@ class DiscordChannel(BaseChannel): media_paths.append(str(file_path)) content_parts.append(f"[attachment: {file_path}]") except Exception as e: - logger.warning(f"Failed to download Discord attachment: {e}") + logger.warning("Failed to download Discord attachment: {}", e) content_parts.append(f"[attachment: {filename} - download failed]") reply_to = (payload.get("referenced_message") or {}).get("id") @@ -233,11 +323,32 @@ class DiscordChannel(BaseChannel): media=media_paths, metadata={ "message_id": str(payload.get("id", "")), - "guild_id": payload.get("guild_id"), + "guild_id": guild_id, "reply_to": reply_to, }, ) + def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool: + """Check if bot should respond in a group channel based on policy.""" + if self.config.group_policy == "open": + return True + + if self.config.group_policy == "mention": + # Check if bot was mentioned in the message + if self._bot_user_id: + # Check mentions array + mentions = payload.get("mentions") or [] + for mention in mentions: + if str(mention.get("id")) == self._bot_user_id: + return True + # Also check content for mention format <@USER_ID> + if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content: + return True + logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id")) + return False + + return True + async def _start_typing(self, channel_id: str) -> None: """Start periodic typing indicator for a channel.""" await self._stop_typing(channel_id) @@ -248,8 +359,11 @@ class DiscordChannel(BaseChannel): while self._running: try: await self._http.post(url, headers=headers) - except Exception: - pass + except asyncio.CancelledError: + return + except Exception as e: + logger.debug("Discord typing indicator failed for {}: {}", channel_id, e) + return await asyncio.sleep(8) self._typing_tasks[channel_id] = asyncio.create_task(typing_loop()) diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py index 0e47067..16771fb 100644 --- a/nanobot/channels/email.py +++ b/nanobot/channels/email.py @@ -94,7 +94,7 @@ class EmailChannel(BaseChannel): metadata=item.get("metadata", {}), ) except Exception as e: - logger.error(f"Email polling error: {e}") + logger.error("Email polling error: {}", e) await asyncio.sleep(poll_seconds) @@ -108,11 +108,6 @@ class EmailChannel(BaseChannel): logger.warning("Skip email send: consent_granted is false") return - force_send = bool((msg.metadata or {}).get("force_send")) - if not self.config.auto_reply_enabled and not force_send: - 
logger.info("Skip automatic email reply: auto_reply_enabled is false") - return - if not self.config.smtp_host: logger.warning("Email channel SMTP host not configured") return @@ -122,6 +117,15 @@ class EmailChannel(BaseChannel): logger.warning("Email channel missing recipient address") return + # Determine if this is a reply (recipient has sent us an email before) + is_reply = to_addr in self._last_subject_by_chat + force_send = bool((msg.metadata or {}).get("force_send")) + + # autoReplyEnabled only controls automatic replies, not proactive sends + if is_reply and not self.config.auto_reply_enabled and not force_send: + logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr) + return + base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply") subject = self._reply_subject(base_subject) if msg.metadata and isinstance(msg.metadata.get("subject"), str): @@ -143,7 +147,7 @@ class EmailChannel(BaseChannel): try: await asyncio.to_thread(self._smtp_send, email_msg) except Exception as e: - logger.error(f"Error sending email to {to_addr}: {e}") + logger.error("Error sending email to {}: {}", to_addr, e) raise def _validate_config(self) -> bool: @@ -162,7 +166,7 @@ class EmailChannel(BaseChannel): missing.append("smtp_password") if missing: - logger.error(f"Email channel not configured, missing: {', '.join(missing)}") + logger.error("Email channel not configured, missing: {}", ', '.join(missing)) return False return True @@ -304,7 +308,8 @@ class EmailChannel(BaseChannel): self._processed_uids.add(uid) # mark_seen is the primary dedup; this set is a safety net if len(self._processed_uids) > self._MAX_PROCESSED_UIDS: - self._processed_uids.clear() + # Evict a random half to cap memory; mark_seen is the primary dedup + self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:]) if mark_seen: client.store(imap_id, "+FLAGS", "\\Seen") diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 23d1415..2dcf710 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -2,9 +2,11 @@ import asyncio import json +import os import re import threading from collections import OrderedDict +from pathlib import Path from typing import Any from loguru import logger @@ -14,21 +16,9 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import FeishuConfig -try: - import lark_oapi as lark - from lark_oapi.api.im.v1 import ( - CreateMessageRequest, - CreateMessageRequestBody, - CreateMessageReactionRequest, - CreateMessageReactionRequestBody, - Emoji, - P2ImMessageReceiveV1, - ) - FEISHU_AVAILABLE = True -except ImportError: - FEISHU_AVAILABLE = False - lark = None - Emoji = None +import importlib.util + +FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None # Message type display mapping MSG_TYPE_MAP = { @@ -39,57 +29,276 @@ MSG_TYPE_MAP = { } +def _extract_share_card_content(content_json: dict, msg_type: str) -> str: + """Extract text representation from share cards and interactive messages.""" + parts = [] + + if msg_type == "share_chat": + parts.append(f"[shared chat: {content_json.get('chat_id', '')}]") + elif msg_type == "share_user": + parts.append(f"[shared user: {content_json.get('user_id', '')}]") + elif msg_type == "interactive": + parts.extend(_extract_interactive_content(content_json)) + elif msg_type == "share_calendar_event": + parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]") + elif msg_type 
== "system": + parts.append("[system message]") + elif msg_type == "merge_forward": + parts.append("[merged forward messages]") + + return "\n".join(parts) if parts else f"[{msg_type}]" + + +def _extract_interactive_content(content: dict) -> list[str]: + """Recursively extract text and links from interactive card content.""" + parts = [] + + if isinstance(content, str): + try: + content = json.loads(content) + except (json.JSONDecodeError, TypeError): + return [content] if content.strip() else [] + + if not isinstance(content, dict): + return parts + + if "title" in content: + title = content["title"] + if isinstance(title, dict): + title_content = title.get("content", "") or title.get("text", "") + if title_content: + parts.append(f"title: {title_content}") + elif isinstance(title, str): + parts.append(f"title: {title}") + + for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []: + for element in elements: + parts.extend(_extract_element_content(element)) + + card = content.get("card", {}) + if card: + parts.extend(_extract_interactive_content(card)) + + header = content.get("header", {}) + if header: + header_title = header.get("title", {}) + if isinstance(header_title, dict): + header_text = header_title.get("content", "") or header_title.get("text", "") + if header_text: + parts.append(f"title: {header_text}") + + return parts + + +def _extract_element_content(element: dict) -> list[str]: + """Extract content from a single card element.""" + parts = [] + + if not isinstance(element, dict): + return parts + + tag = element.get("tag", "") + + if tag in ("markdown", "lark_md"): + content = element.get("content", "") + if content: + parts.append(content) + + elif tag == "div": + text = element.get("text", {}) + if isinstance(text, dict): + text_content = text.get("content", "") or text.get("text", "") + if text_content: + parts.append(text_content) + elif isinstance(text, str): + parts.append(text) + for field in element.get("fields", []): + if isinstance(field, dict): + field_text = field.get("text", {}) + if isinstance(field_text, dict): + c = field_text.get("content", "") + if c: + parts.append(c) + + elif tag == "a": + href = element.get("href", "") + text = element.get("text", "") + if href: + parts.append(f"link: {href}") + if text: + parts.append(text) + + elif tag == "button": + text = element.get("text", {}) + if isinstance(text, dict): + c = text.get("content", "") + if c: + parts.append(c) + url = element.get("url", "") or element.get("multi_url", {}).get("url", "") + if url: + parts.append(f"link: {url}") + + elif tag == "img": + alt = element.get("alt", {}) + parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]") + + elif tag == "note": + for ne in element.get("elements", []): + parts.extend(_extract_element_content(ne)) + + elif tag == "column_set": + for col in element.get("columns", []): + for ce in col.get("elements", []): + parts.extend(_extract_element_content(ce)) + + elif tag == "plain_text": + content = element.get("content", "") + if content: + parts.append(content) + + else: + for ne in element.get("elements", []): + parts.extend(_extract_element_content(ne)) + + return parts + + +def _extract_post_content(content_json: dict) -> tuple[str, list[str]]: + """Extract text and image keys from Feishu post (rich text) message. 
+ + Handles three payload shapes: + - Direct: {"title": "...", "content": [[...]]} + - Localized: {"zh_cn": {"title": "...", "content": [...]}} + - Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}} + """ + + def _parse_block(block: dict) -> tuple[str | None, list[str]]: + if not isinstance(block, dict) or not isinstance(block.get("content"), list): + return None, [] + texts, images = [], [] + if title := block.get("title"): + texts.append(title) + for row in block["content"]: + if not isinstance(row, list): + continue + for el in row: + if not isinstance(el, dict): + continue + tag = el.get("tag") + if tag in ("text", "a"): + texts.append(el.get("text", "")) + elif tag == "at": + texts.append(f"@{el.get('user_name', 'user')}") + elif tag == "img" and (key := el.get("image_key")): + images.append(key) + return (" ".join(texts).strip() or None), images + + # Unwrap optional {"post": ...} envelope + root = content_json + if isinstance(root, dict) and isinstance(root.get("post"), dict): + root = root["post"] + if not isinstance(root, dict): + return "", [] + + # Direct format + if "content" in root: + text, imgs = _parse_block(root) + if text or imgs: + return text or "", imgs + + # Localized: prefer known locales, then fall back to any dict child + for key in ("zh_cn", "en_us", "ja_jp"): + if key in root: + text, imgs = _parse_block(root[key]) + if text or imgs: + return text or "", imgs + for val in root.values(): + if isinstance(val, dict): + text, imgs = _parse_block(val) + if text or imgs: + return text or "", imgs + + return "", [] + + +def _extract_post_text(content_json: dict) -> str: + """Extract plain text from Feishu post (rich text) message content. + + Legacy wrapper for _extract_post_content, returns only text. + """ + text, _ = _extract_post_content(content_json) + return text + + class FeishuChannel(BaseChannel): """ Feishu/Lark channel using WebSocket long connection. - + Uses WebSocket to receive events - no public IP or webhook required. - + Requires: - App ID and App Secret from Feishu Open Platform - Bot capability enabled - Event subscription enabled (im.message.receive_v1) """ - + name = "feishu" - - def __init__(self, config: FeishuConfig, bus: MessageBus): + + def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""): super().__init__(config, bus) self.config: FeishuConfig = config + self.groq_api_key = groq_api_key self._client: Any = None self._ws_client: Any = None self._ws_thread: threading.Thread | None = None self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._loop: asyncio.AbstractEventLoop | None = None - + + @staticmethod + def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any: + """Register an event handler only when the SDK supports it.""" + method = getattr(builder, method_name, None) + return method(handler) if callable(method) else builder + async def start(self) -> None: """Start the Feishu bot with WebSocket long connection.""" if not FEISHU_AVAILABLE: logger.error("Feishu SDK not installed. 
Run: pip install lark-oapi") return - + if not self.config.app_id or not self.config.app_secret: logger.error("Feishu app_id and app_secret not configured") return - + + import lark_oapi as lark self._running = True self._loop = asyncio.get_running_loop() - + # Create Lark client for sending messages self._client = lark.Client.builder() \ .app_id(self.config.app_id) \ .app_secret(self.config.app_secret) \ .log_level(lark.LogLevel.INFO) \ .build() - - # Create event handler (only register message receive, ignore other events) - event_handler = lark.EventDispatcherHandler.builder( + builder = lark.EventDispatcherHandler.builder( self.config.encrypt_key or "", self.config.verification_token or "", ).register_p2_im_message_receive_v1( self._on_message_sync - ).build() - + ) + builder = self._register_optional_event( + builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created + ) + builder = self._register_optional_event( + builder, "register_p2_im_message_message_read_v1", self._on_message_read + ) + builder = self._register_optional_event( + builder, + "register_p2_im_chat_access_event_bot_p2p_chat_entered_v1", + self._on_bot_p2p_chat_entered, + ) + event_handler = builder.build() + # Create WebSocket client for long connection self._ws_client = lark.ws.Client( self.config.app_id, @@ -97,39 +306,54 @@ class FeishuChannel(BaseChannel): event_handler=event_handler, log_level=lark.LogLevel.INFO ) - - # Start WebSocket client in a separate thread with reconnect loop + + # Start WebSocket client in a separate thread with reconnect loop. + # A dedicated event loop is created for this thread so that lark_oapi's + # module-level `loop = asyncio.get_event_loop()` picks up an idle loop + # instead of the already-running main asyncio loop, which would cause + # "This event loop is already running" errors. def run_ws(): - while self._running: - try: - self._ws_client.start() - except Exception as e: - logger.warning(f"Feishu WebSocket error: {e}") - if self._running: - import time; time.sleep(5) - + import time + import lark_oapi.ws.client as _lark_ws_client + ws_loop = asyncio.new_event_loop() + asyncio.set_event_loop(ws_loop) + # Patch the module-level loop used by lark's ws Client.start() + _lark_ws_client.loop = ws_loop + try: + while self._running: + try: + self._ws_client.start() + except Exception as e: + logger.warning("Feishu WebSocket error: {}", e) + if self._running: + time.sleep(5) + finally: + ws_loop.close() + self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread.start() - + logger.info("Feishu bot started with WebSocket long connection") logger.info("No public IP required - using WebSocket to receive events") - + # Keep running until stopped while self._running: await asyncio.sleep(1) - + async def stop(self) -> None: - """Stop the Feishu bot.""" + """ + Stop the Feishu bot. + + Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client. 
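The dedicated-loop comment in `run_ws` above is the key reliability fix in this hunk: the worker thread must own an idle event loop before lark_oapi's module-level `asyncio.get_event_loop()` runs. A minimal sketch of the same pattern, under illustrative names that are not part of this diff:

```python
# Hypothetical reduction of the run_ws pattern: the thread creates and owns
# its own event loop, so library code that grabs "the" loop gets an idle one
# instead of the already-running main loop.
import asyncio
import threading

def run_with_private_loop() -> None:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)  # this thread's get_event_loop() now returns `loop`
    try:
        loop.run_until_complete(asyncio.sleep(0.1))  # stand-in for the ws client
    finally:
        loop.close()

t = threading.Thread(target=run_with_private_loop, daemon=True)
t.start()
t.join()
```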
+ + Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86 + """ self._running = False - if self._ws_client: - try: - self._ws_client.stop() - except Exception as e: - logger.warning(f"Error stopping WebSocket client: {e}") logger.info("Feishu bot stopped") - + def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: """Sync helper for adding reaction (runs in thread pool).""" + from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji try: request = CreateMessageReactionRequest.builder() \ .message_id(message_id) \ @@ -138,43 +362,48 @@ class FeishuChannel(BaseChannel): .reaction_type(Emoji.builder().emoji_type(emoji_type).build()) .build() ).build() - + response = self._client.im.v1.message_reaction.create(request) - + if not response.success(): - logger.warning(f"Failed to add reaction: code={response.code}, msg={response.msg}") + logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg) else: - logger.debug(f"Added {emoji_type} reaction to message {message_id}") + logger.debug("Added {} reaction to message {}", emoji_type, message_id) except Exception as e: - logger.warning(f"Error adding reaction: {e}") + logger.warning("Error adding reaction: {}", e) async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None: """ Add a reaction emoji to a message (non-blocking). - + Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART """ - if not self._client or not Emoji: + if not self._client: return - + loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type) - + # Regex to match markdown tables (header + separator + data rows) _TABLE_RE = re.compile( r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)", re.MULTILINE, ) + _HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) + + _CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE) + @staticmethod def _parse_md_table(table_text: str) -> dict | None: """Parse a markdown table into a Feishu table element.""" - lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()] + lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()] if len(lines) < 3: return None - split = lambda l: [c.strip() for c in l.strip("|").split("|")] + def split(_line: str) -> list[str]: + return [c.strip() for c in _line.strip("|").split("|")] headers = split(lines[0]) - rows = [split(l) for l in lines[2:]] + rows = [split(_line) for _line in lines[2:]] columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"} for i, h in enumerate(headers)] return { @@ -185,126 +414,572 @@ class FeishuChannel(BaseChannel): } def _build_card_elements(self, content: str) -> list[dict]: - """Split content into markdown + table elements for Feishu card.""" + """Split content into div/markdown + table elements for Feishu card.""" elements, last_end = [], 0 for m in self._TABLE_RE.finditer(content): - before = content[last_end:m.start()].strip() - if before: - elements.append({"tag": "markdown", "content": before}) + before = content[last_end:m.start()] + if before.strip(): + elements.extend(self._split_headings(before)) elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)}) last_end = m.end() - remaining = content[last_end:].strip() - if remaining: - elements.append({"tag": "markdown", "content": remaining}) + remaining = 
content[last_end:]
+        if remaining.strip():
+            elements.extend(self._split_headings(remaining))
         return elements or [{"tag": "markdown", "content": content}]
 
-    async def send(self, msg: OutboundMessage) -> None:
-        """Send a message through Feishu."""
-        if not self._client:
-            logger.warning("Feishu client not initialized")
-            return
-
-        try:
-            # Determine receive_id_type based on chat_id format
-            # open_id starts with "ou_", chat_id starts with "oc_"
-            if msg.chat_id.startswith("oc_"):
-                receive_id_type = "chat_id"
+    @staticmethod
+    def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
+        """Split card elements into groups with at most *max_tables* table elements each.
+
+        Feishu cards have a hard limit of one table per card (API error 11310).
+        When the rendered content contains multiple markdown tables, each table is
+        placed in a separate card message so every table reaches the user.
+        """
+        if not elements:
+            return [[]]
+        groups: list[list[dict]] = []
+        current: list[dict] = []
+        table_count = 0
+        for el in elements:
+            if el.get("tag") == "table":
+                if table_count >= max_tables:
+                    if current:
+                        groups.append(current)
+                        current = []
+                        table_count = 0
+                current.append(el)
+                table_count += 1
             else:
-                receive_id_type = "open_id"
-
-            # Build card with markdown + table support
-            elements = self._build_card_elements(msg.content)
-            card = {
-                "config": {"wide_screen_mode": True},
-                "elements": elements,
+                current.append(el)
+        if current:
+            groups.append(current)
+        return groups or [[]]
+
+    def _split_headings(self, content: str) -> list[dict]:
+        """Split content by headings, converting headings to div elements."""
+        protected = content
+        code_blocks = []
+        for m in self._CODE_BLOCK_RE.finditer(content):
+            code_blocks.append(m.group(1))
+            protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
+
+        elements = []
+        last_end = 0
+        for m in self._HEADING_RE.finditer(protected):
+            before = protected[last_end:m.start()].strip()
+            if before:
+                elements.append({"tag": "markdown", "content": before})
+            text = m.group(2).strip()
+            elements.append({
+                "tag": "div",
+                "text": {
+                    "tag": "lark_md",
+                    "content": f"**{text}**",
+                },
+            })
+            last_end = m.end()
+        remaining = protected[last_end:].strip()
+        if remaining:
+            elements.append({"tag": "markdown", "content": remaining})
+
+        for i, cb in enumerate(code_blocks):
+            for el in elements:
+                if el.get("tag") == "markdown":
+                    el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
+
+        return elements or [{"tag": "markdown", "content": content}]
+
+    # ── Smart format detection ──────────────────────────────────────────
+    # Patterns that indicate "complex" markdown needing card rendering
+    _COMPLEX_MD_RE = re.compile(
+        r"```"                          # fenced code block
+        r"|^\|.+\|.*\n\s*\|[-:\s|]+\|"  # markdown table (header + separator)
+        r"|^#{1,6}\s+"                  # headings
+        , re.MULTILINE,
+    )
+
+    # Simple markdown patterns (bold, italic, strikethrough)
+    _SIMPLE_MD_RE = re.compile(
+        r"\*\*.+?\*\*"                  # **bold**
+        r"|__.+?__"                     # __bold__
+        r"|(?<!\*)\*[^*\n]+\*(?!\*)"    # *italic*
+        r"|~~.+?~~"                     # ~~strikethrough~~
+        , re.MULTILINE,
+    )
+
+    # List items (unordered / ordered) and markdown links
+    _LIST_RE = re.compile(r"^\s*[-*+]\s+", re.MULTILINE)
+    _OLIST_RE = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)
+    _MD_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
+
+    # Routing length limits
+    _TEXT_MAX_LEN = 200
+    _POST_MAX_LEN = 2000
+
+    @classmethod
+    def _detect_msg_format(cls, content: str) -> str:
+        """Determine the optimal Feishu message format for *content*.
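A compact, self-contained sketch of this routing; it assumes the `_TEXT_MAX_LEN`/`_POST_MAX_LEN` values defined above and omits the bold/italic and list checks for brevity:

```python
# Simplified sketch of the card/post/text routing; thresholds mirror the
# class constants above and are illustrative, not authoritative.
import re

COMPLEX_MD = re.compile(r"```|^#{1,6}\s+|^\|.+\|.*\n\s*\|[-:\s|]+\|", re.MULTILINE)
MD_LINK = re.compile(r"\[[^\]]+\]\([^)]+\)")
TEXT_MAX, POST_MAX = 200, 2000  # assumed limits

def detect(content: str) -> str:
    s = content.strip()
    if COMPLEX_MD.search(s) or len(s) > POST_MAX:
        return "interactive"  # full card rendering
    if MD_LINK.search(s):
        return "post"         # rich text, links become "a" tags
    return "text" if len(s) <= TEXT_MAX else "post"

assert detect("done.") == "text"
assert detect("see [docs](https://example.com)") == "post"
assert detect("# Summary") == "interactive"
```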
+ + Returns one of: + - ``"text"`` – plain text, short and no markdown + - ``"post"`` – rich text (links only, moderate length) + - ``"interactive"`` – card with full markdown rendering + """ + stripped = content.strip() + + # Complex markdown (code blocks, tables, headings) → always card + if cls._COMPLEX_MD_RE.search(stripped): + return "interactive" + + # Long content → card (better readability with card layout) + if len(stripped) > cls._POST_MAX_LEN: + return "interactive" + + # Has bold/italic/strikethrough → card (post format can't render these) + if cls._SIMPLE_MD_RE.search(stripped): + return "interactive" + + # Has list items → card (post format can't render list bullets well) + if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped): + return "interactive" + + # Has links → post format (supports tags) + if cls._MD_LINK_RE.search(stripped): + return "post" + + # Short plain text → text format + if len(stripped) <= cls._TEXT_MAX_LEN: + return "text" + + # Medium plain text without any formatting → post format + return "post" + + @classmethod + def _markdown_to_post(cls, content: str) -> str: + """Convert markdown content to Feishu post message JSON. + + Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags. + Each line becomes a paragraph (row) in the post body. + """ + lines = content.strip().split("\n") + paragraphs: list[list[dict]] = [] + + for line in lines: + elements: list[dict] = [] + last_end = 0 + + for m in cls._MD_LINK_RE.finditer(line): + # Text before this link + before = line[last_end:m.start()] + if before: + elements.append({"tag": "text", "text": before}) + elements.append({ + "tag": "a", + "text": m.group(1), + "href": m.group(2), + }) + last_end = m.end() + + # Remaining text after last link + remaining = line[last_end:] + if remaining: + elements.append({"tag": "text", "text": remaining}) + + # Empty line → empty paragraph for spacing + if not elements: + elements.append({"tag": "text", "text": ""}) + + paragraphs.append(elements) + + post_body = { + "zh_cn": { + "content": paragraphs, } - content = json.dumps(card, ensure_ascii=False) - + } + return json.dumps(post_body, ensure_ascii=False) + + _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"} + _AUDIO_EXTS = {".opus"} + _VIDEO_EXTS = {".mp4", ".mov", ".avi"} + _FILE_TYPE_MAP = { + ".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc", + ".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt", + } + + def _upload_image_sync(self, file_path: str) -> str | None: + """Upload an image to Feishu and return the image_key.""" + from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody + try: + with open(file_path, "rb") as f: + request = CreateImageRequest.builder() \ + .request_body( + CreateImageRequestBody.builder() + .image_type("message") + .image(f) + .build() + ).build() + response = self._client.im.v1.image.create(request) + if response.success(): + image_key = response.data.image_key + logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key) + return image_key + else: + logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg) + return None + except Exception as e: + logger.error("Error uploading image {}: {}", file_path, e) + return None + + def _upload_file_sync(self, file_path: str) -> str | None: + """Upload a file to Feishu and return the file_key.""" + from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody + ext = 
os.path.splitext(file_path)[1].lower() + file_type = self._FILE_TYPE_MAP.get(ext, "stream") + file_name = os.path.basename(file_path) + try: + with open(file_path, "rb") as f: + request = CreateFileRequest.builder() \ + .request_body( + CreateFileRequestBody.builder() + .file_type(file_type) + .file_name(file_name) + .file(f) + .build() + ).build() + response = self._client.im.v1.file.create(request) + if response.success(): + file_key = response.data.file_key + logger.debug("Uploaded file {}: {}", file_name, file_key) + return file_key + else: + logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg) + return None + except Exception as e: + logger.error("Error uploading file {}: {}", file_path, e) + return None + + def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]: + """Download an image from Feishu message by message_id and image_key.""" + from lark_oapi.api.im.v1 import GetMessageResourceRequest + try: + request = GetMessageResourceRequest.builder() \ + .message_id(message_id) \ + .file_key(image_key) \ + .type("image") \ + .build() + response = self._client.im.v1.message_resource.get(request) + if response.success(): + file_data = response.file + # GetMessageResourceRequest returns BytesIO, need to read bytes + if hasattr(file_data, 'read'): + file_data = file_data.read() + return file_data, response.file_name + else: + logger.error("Failed to download image: code={}, msg={}", response.code, response.msg) + return None, None + except Exception as e: + logger.error("Error downloading image {}: {}", image_key, e) + return None, None + + def _download_file_sync( + self, message_id: str, file_key: str, resource_type: str = "file" + ) -> tuple[bytes | None, str | None]: + """Download a file/audio/media from a Feishu message by message_id and file_key.""" + from lark_oapi.api.im.v1 import GetMessageResourceRequest + + # Feishu API only accepts 'image' or 'file' as type parameter + # Convert 'audio' to 'file' for API compatibility + if resource_type == "audio": + resource_type = "file" + + try: + request = ( + GetMessageResourceRequest.builder() + .message_id(message_id) + .file_key(file_key) + .type(resource_type) + .build() + ) + response = self._client.im.v1.message_resource.get(request) + if response.success(): + file_data = response.file + if hasattr(file_data, "read"): + file_data = file_data.read() + return file_data, response.file_name + else: + logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg) + return None, None + except Exception: + logger.exception("Error downloading {} {}", resource_type, file_key) + return None, None + + async def _download_and_save_media( + self, + msg_type: str, + content_json: dict, + message_id: str | None = None + ) -> tuple[str | None, str]: + """ + Download media from Feishu and save to local disk. 
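Both download helpers above normalise the SDK's return value the same way before persisting it; a small sketch of that rule in isolation:

```python
# The Feishu SDK may hand back raw bytes or a BytesIO-like object from
# message_resource.get(); read() when present, pass bytes through otherwise.
import io

def as_bytes(payload: bytes | io.BytesIO) -> bytes:
    return payload.read() if hasattr(payload, "read") else payload

assert as_bytes(b"abc") == b"abc"
assert as_bytes(io.BytesIO(b"abc")) == b"abc"
```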
+ + Returns: + (file_path, content_text) - file_path is None if download failed + """ + loop = asyncio.get_running_loop() + media_dir = Path.home() / ".nanobot" / "media" + media_dir.mkdir(parents=True, exist_ok=True) + + data, filename = None, None + + if msg_type == "image": + image_key = content_json.get("image_key") + if image_key and message_id: + data, filename = await loop.run_in_executor( + None, self._download_image_sync, message_id, image_key + ) + if not filename: + filename = f"{image_key[:16]}.jpg" + + elif msg_type in ("audio", "file", "media"): + file_key = content_json.get("file_key") + if file_key and message_id: + data, filename = await loop.run_in_executor( + None, self._download_file_sync, message_id, file_key, msg_type + ) + if not filename: + ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "") + filename = f"{file_key[:16]}{ext}" + + if data and filename: + file_path = media_dir / filename + file_path.write_bytes(data) + logger.debug("Downloaded {} to {}", msg_type, file_path) + return str(file_path), f"[{msg_type}: {filename}]" + + return None, f"[{msg_type}: download failed]" + + def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool: + """Send a single message (text/image/file/interactive) synchronously.""" + from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody + try: request = CreateMessageRequest.builder() \ .receive_id_type(receive_id_type) \ .request_body( CreateMessageRequestBody.builder() - .receive_id(msg.chat_id) - .msg_type("interactive") + .receive_id(receive_id) + .msg_type(msg_type) .content(content) .build() ).build() - response = self._client.im.v1.message.create(request) - if not response.success(): logger.error( - f"Failed to send Feishu message: code={response.code}, " - f"msg={response.msg}, log_id={response.get_log_id()}" + "Failed to send Feishu {} message: code={}, msg={}, log_id={}", + msg_type, response.code, response.msg, response.get_log_id() ) - else: - logger.debug(f"Feishu message sent to {msg.chat_id}") - + return False + logger.debug("Feishu {} message sent to {}", msg_type, receive_id) + return True except Exception as e: - logger.error(f"Error sending Feishu message: {e}") - - def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None: + logger.error("Error sending Feishu {} message: {}", msg_type, e) + return False + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Feishu, including media (images/files) if present.""" + if not self._client: + logger.warning("Feishu client not initialized") + return + + try: + receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id" + loop = asyncio.get_running_loop() + + for file_path in msg.media: + if not os.path.isfile(file_path): + logger.warning("Media file not found: {}", file_path) + continue + ext = os.path.splitext(file_path)[1].lower() + if ext in self._IMAGE_EXTS: + key = await loop.run_in_executor(None, self._upload_image_sync, file_path) + if key: + await loop.run_in_executor( + None, self._send_message_sync, + receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False), + ) + else: + key = await loop.run_in_executor(None, self._upload_file_sync, file_path) + if key: + # Use msg_type "media" for audio/video so users can play inline; + # "file" for everything else (documents, archives, etc.) 
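The branch above, together with the extension sets declared earlier in this file, reduces to a pure function. A sketch, with a hypothetical helper name used only for illustration:

```python
# Hypothetical helper mirroring the send() branch and the _IMAGE_EXTS /
# _AUDIO_EXTS / _VIDEO_EXTS sets defined earlier; extracted for testability.
import os

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
AUDIO_EXTS = {".opus"}
VIDEO_EXTS = {".mp4", ".mov", ".avi"}

def feishu_msg_type_for(path: str) -> str:
    """Map a local file path to the Feishu msg_type used to send it."""
    ext = os.path.splitext(path)[1].lower()
    if ext in IMAGE_EXTS:
        return "image"  # uploaded via the image API, sent as {"image_key": ...}
    if ext in AUDIO_EXTS or ext in VIDEO_EXTS:
        return "media"  # playable inline in Feishu clients
    return "file"       # documents, archives, everything else

assert feishu_msg_type_for("demo.mp4") == "media"
assert feishu_msg_type_for("report.pdf") == "file"
assert feishu_msg_type_for("logo.PNG") == "image"
```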
+ if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS: + media_type = "media" + else: + media_type = "file" + await loop.run_in_executor( + None, self._send_message_sync, + receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False), + ) + + if msg.content and msg.content.strip(): + fmt = self._detect_msg_format(msg.content) + + if fmt == "text": + # Short plain text – send as simple text message + text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False) + await loop.run_in_executor( + None, self._send_message_sync, + receive_id_type, msg.chat_id, "text", text_body, + ) + + elif fmt == "post": + # Medium content with links – send as rich-text post + post_body = self._markdown_to_post(msg.content) + await loop.run_in_executor( + None, self._send_message_sync, + receive_id_type, msg.chat_id, "post", post_body, + ) + + else: + # Complex / long content – send as interactive card + elements = self._build_card_elements(msg.content) + for chunk in self._split_elements_by_table_limit(elements): + card = {"config": {"wide_screen_mode": True}, "elements": chunk} + await loop.run_in_executor( + None, self._send_message_sync, + receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False), + ) + + except Exception as e: + logger.error("Error sending Feishu message: {}", e) + + def _on_message_sync(self, data: Any) -> None: """ Sync handler for incoming messages (called from WebSocket thread). Schedules async handling in the main event loop. """ if self._loop and self._loop.is_running(): asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop) - - async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: + + async def _on_message(self, data: Any) -> None: """Handle incoming message from Feishu.""" try: event = data.event message = event.message sender = event.sender - + # Deduplication check message_id = message.message_id if message_id in self._processed_message_ids: return self._processed_message_ids[message_id] = None - - # Trim cache: keep most recent 500 when exceeds 1000 + + # Trim cache while len(self._processed_message_ids) > 1000: self._processed_message_ids.popitem(last=False) - + # Skip bot messages - sender_type = sender.sender_type - if sender_type == "bot": + if sender.sender_type == "bot": return - + sender_id = sender.sender_id.open_id if sender.sender_id else "unknown" chat_id = message.chat_id - chat_type = message.chat_type # "p2p" or "group" + chat_type = message.chat_type msg_type = message.message_type - - # Add reaction to indicate "seen" - await self._add_reaction(message_id, "THUMBSUP") - - # Parse message content + + # Add reaction + await self._add_reaction(message_id, self.config.react_emoji) + + # Parse content + content_parts = [] + media_paths = [] + + try: + content_json = json.loads(message.content) if message.content else {} + except json.JSONDecodeError: + content_json = {} + if msg_type == "text": - try: - content = json.loads(message.content).get("text", "") - except json.JSONDecodeError: - content = message.content or "" + text = content_json.get("text", "") + if text: + content_parts.append(text) + + elif msg_type == "post": + text, image_keys = _extract_post_content(content_json) + if text: + content_parts.append(text) + # Download images embedded in post + for img_key in image_keys: + file_path, content_text = await self._download_and_save_media( + "image", {"image_key": img_key}, message_id + ) + if file_path: + media_paths.append(file_path) + 
content_parts.append(content_text) + + elif msg_type in ("image", "audio", "file", "media"): + file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id) + if file_path: + media_paths.append(file_path) + + # Transcribe audio using Groq Whisper + if msg_type == "audio" and file_path and self.groq_api_key: + try: + from nanobot.providers.transcription import GroqTranscriptionProvider + transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key) + transcription = await transcriber.transcribe(file_path) + if transcription: + content_text = f"[transcription: {transcription}]" + except Exception as e: + logger.warning("Failed to transcribe audio: {}", e) + + content_parts.append(content_text) + + elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"): + # Handle share cards and interactive messages + text = _extract_share_card_content(content_json, msg_type) + if text: + content_parts.append(text) + else: - content = MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]") - - if not content: + content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]")) + + content = "\n".join(content_parts) if content_parts else "" + + if not content and not media_paths: return - + # Forward to message bus reply_to = chat_id if chat_type == "group" else sender_id await self._handle_message( sender_id=sender_id, chat_id=reply_to, content=content, + media=media_paths, metadata={ "message_id": message_id, "chat_type": chat_type, "msg_type": msg_type, } ) - + except Exception as e: - logger.error(f"Error processing Feishu message: {e}") + logger.error("Error processing Feishu message: {}", e) + + def _on_reaction_created(self, data: Any) -> None: + """Ignore reaction events so they do not generate SDK noise.""" + pass + + def _on_message_read(self, data: Any) -> None: + """Ignore read events so they do not generate SDK noise.""" + pass + + def _on_bot_p2p_chat_entered(self, data: Any) -> None: + """Ignore p2p-enter events when a user opens a bot chat.""" + logger.debug("Bot entered p2p chat (user opened chat window)") + pass diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 464fa97..51539dd 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Any, TYPE_CHECKING +from typing import Any from loguru import logger @@ -12,32 +12,28 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Config -if TYPE_CHECKING: - from nanobot.session.manager import SessionManager - class ChannelManager: """ Manages chat channels and coordinates message routing. - + Responsibilities: - Initialize enabled channels (Telegram, WhatsApp, etc.) 
- Start/stop channels - Route outbound messages """ - - def __init__(self, config: Config, bus: MessageBus, session_manager: "SessionManager | None" = None): + + def __init__(self, config: Config, bus: MessageBus): self.config = config self.bus = bus - self.session_manager = session_manager self.channels: dict[str, BaseChannel] = {} self._dispatch_task: asyncio.Task | None = None - + self._init_channels() - + def _init_channels(self) -> None: """Initialize channels based on config.""" - + # Telegram channel if self.config.channels.telegram.enabled: try: @@ -46,12 +42,11 @@ class ChannelManager: self.config.channels.telegram, self.bus, groq_api_key=self.config.providers.groq.api_key, - session_manager=self.session_manager, ) logger.info("Telegram channel enabled") except ImportError as e: - logger.warning(f"Telegram channel not available: {e}") - + logger.warning("Telegram channel not available: {}", e) + # WhatsApp channel if self.config.channels.whatsapp.enabled: try: @@ -61,7 +56,7 @@ class ChannelManager: ) logger.info("WhatsApp channel enabled") except ImportError as e: - logger.warning(f"WhatsApp channel not available: {e}") + logger.warning("WhatsApp channel not available: {}", e) # Discord channel if self.config.channels.discord.enabled: @@ -72,18 +67,19 @@ class ChannelManager: ) logger.info("Discord channel enabled") except ImportError as e: - logger.warning(f"Discord channel not available: {e}") - + logger.warning("Discord channel not available: {}", e) + # Feishu channel if self.config.channels.feishu.enabled: try: from nanobot.channels.feishu import FeishuChannel self.channels["feishu"] = FeishuChannel( - self.config.channels.feishu, self.bus + self.config.channels.feishu, self.bus, + groq_api_key=self.config.providers.groq.api_key, ) logger.info("Feishu channel enabled") except ImportError as e: - logger.warning(f"Feishu channel not available: {e}") + logger.warning("Feishu channel not available: {}", e) # Mochat channel if self.config.channels.mochat.enabled: @@ -95,7 +91,7 @@ class ChannelManager: ) logger.info("Mochat channel enabled") except ImportError as e: - logger.warning(f"Mochat channel not available: {e}") + logger.warning("Mochat channel not available: {}", e) # DingTalk channel if self.config.channels.dingtalk.enabled: @@ -106,7 +102,7 @@ class ChannelManager: ) logger.info("DingTalk channel enabled") except ImportError as e: - logger.warning(f"DingTalk channel not available: {e}") + logger.warning("DingTalk channel not available: {}", e) # Email channel if self.config.channels.email.enabled: @@ -117,7 +113,7 @@ class ChannelManager: ) logger.info("Email channel enabled") except ImportError as e: - logger.warning(f"Email channel not available: {e}") + logger.warning("Email channel not available: {}", e) # Slack channel if self.config.channels.slack.enabled: @@ -128,7 +124,7 @@ class ChannelManager: ) logger.info("Slack channel enabled") except ImportError as e: - logger.warning(f"Slack channel not available: {e}") + logger.warning("Slack channel not available: {}", e) # QQ channel if self.config.channels.qq.enabled: @@ -140,37 +136,59 @@ class ChannelManager: ) logger.info("QQ channel enabled") except ImportError as e: - logger.warning(f"QQ channel not available: {e}") - + logger.warning("QQ channel not available: {}", e) + + # Matrix channel + if self.config.channels.matrix.enabled: + try: + from nanobot.channels.matrix import MatrixChannel + self.channels["matrix"] = MatrixChannel( + self.config.channels.matrix, + self.bus, + ) + logger.info("Matrix channel 
enabled") + except ImportError as e: + logger.warning("Matrix channel not available: {}", e) + + self._validate_allow_from() + + def _validate_allow_from(self) -> None: + for name, ch in self.channels.items(): + if getattr(ch.config, "allow_from", None) == []: + raise SystemExit( + f'Error: "{name}" has empty allowFrom (denies all). ' + f'Set ["*"] to allow everyone, or add specific user IDs.' + ) + async def _start_channel(self, name: str, channel: BaseChannel) -> None: """Start a channel and log any exceptions.""" try: await channel.start() except Exception as e: - logger.error(f"Failed to start channel {name}: {e}") + logger.error("Failed to start channel {}: {}", name, e) async def start_all(self) -> None: """Start all channels and the outbound dispatcher.""" if not self.channels: logger.warning("No channels enabled") return - + # Start outbound dispatcher self._dispatch_task = asyncio.create_task(self._dispatch_outbound()) - + # Start channels tasks = [] for name, channel in self.channels.items(): - logger.info(f"Starting {name} channel...") + logger.info("Starting {} channel...", name) tasks.append(asyncio.create_task(self._start_channel(name, channel))) - + # Wait for all to complete (they should run forever) await asyncio.gather(*tasks, return_exceptions=True) - + async def stop_all(self) -> None: """Stop all channels and the dispatcher.""" logger.info("Stopping all channels...") - + # Stop dispatcher if self._dispatch_task: self._dispatch_task.cancel() @@ -178,44 +196,50 @@ class ChannelManager: await self._dispatch_task except asyncio.CancelledError: pass - + # Stop all channels for name, channel in self.channels.items(): try: await channel.stop() - logger.info(f"Stopped {name} channel") + logger.info("Stopped {} channel", name) except Exception as e: - logger.error(f"Error stopping {name}: {e}") - + logger.error("Error stopping {}: {}", name, e) + async def _dispatch_outbound(self) -> None: """Dispatch outbound messages to the appropriate channel.""" logger.info("Outbound dispatcher started") - + while True: try: msg = await asyncio.wait_for( self.bus.consume_outbound(), timeout=1.0 ) - + + if msg.metadata.get("_progress"): + if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints: + continue + if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress: + continue + channel = self.channels.get(msg.channel) if channel: try: await channel.send(msg) except Exception as e: - logger.error(f"Error sending to {msg.channel}: {e}") + logger.error("Error sending to {}: {}", msg.channel, e) else: - logger.warning(f"Unknown channel: {msg.channel}") - + logger.warning("Unknown channel: {}", msg.channel) + except asyncio.TimeoutError: continue except asyncio.CancelledError: break - + def get_channel(self, name: str) -> BaseChannel | None: """Get a channel by name.""" return self.channels.get(name) - + def get_status(self) -> dict[str, Any]: """Get status of all channels.""" return { @@ -225,7 +249,7 @@ class ChannelManager: } for name, channel in self.channels.items() } - + @property def enabled_channels(self) -> list[str]: """Get list of enabled channel names.""" diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py new file mode 100644 index 0000000..4967ac1 --- /dev/null +++ b/nanobot/channels/matrix.py @@ -0,0 +1,699 @@ +"""Matrix (Element) channel — inbound sync + outbound message/media delivery.""" + +import asyncio +import logging +import mimetypes +from pathlib import Path +from typing import Any, TypeAlias + +from loguru 
import logger + +try: + import nh3 + from mistune import create_markdown + from nio import ( + AsyncClient, + AsyncClientConfig, + ContentRepositoryConfigError, + DownloadError, + InviteEvent, + JoinError, + MatrixRoom, + MemoryDownloadResponse, + RoomEncryptedMedia, + RoomMessage, + RoomMessageMedia, + RoomMessageText, + RoomSendError, + RoomTypingError, + SyncError, + UploadError, + ) + from nio.crypto.attachments import decrypt_attachment + from nio.exceptions import EncryptionError +except ImportError as e: + raise ImportError( + "Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]" + ) from e + +from nanobot.bus.events import OutboundMessage +from nanobot.channels.base import BaseChannel +from nanobot.config.loader import get_data_dir +from nanobot.utils.helpers import safe_filename + +TYPING_NOTICE_TIMEOUT_MS = 30_000 +# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing. +TYPING_KEEPALIVE_INTERVAL_MS = 20_000 +MATRIX_HTML_FORMAT = "org.matrix.custom.html" +_ATTACH_MARKER = "[attachment: {}]" +_ATTACH_TOO_LARGE = "[attachment: {} - too large]" +_ATTACH_FAILED = "[attachment: {} - download failed]" +_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]" +_DEFAULT_ATTACH_NAME = "attachment" +_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"} + +MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia) +MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia + +MATRIX_MARKDOWN = create_markdown( + escape=True, + plugins=["table", "strikethrough", "url", "superscript", "subscript"], +) + +MATRIX_ALLOWED_HTML_TAGS = { + "p", "a", "strong", "em", "del", "code", "pre", "blockquote", + "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6", + "hr", "br", "table", "thead", "tbody", "tr", "th", "td", + "caption", "sup", "sub", "img", +} +MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = { + "a": {"href"}, "code": {"class"}, "ol": {"start"}, + "img": {"src", "alt", "title", "width", "height"}, +} +MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"} + + +def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None: + """Filter attribute values to a safe Matrix-compatible subset.""" + if tag == "a" and attr == "href": + return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None + if tag == "img" and attr == "src": + return value if value.lower().startswith("mxc://") else None + if tag == "code" and attr == "class": + classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")] + return " ".join(classes) if classes else None + return value + + +MATRIX_HTML_CLEANER = nh3.Cleaner( + tags=MATRIX_ALLOWED_HTML_TAGS, + attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES, + attribute_filter=_filter_matrix_html_attribute, + url_schemes=MATRIX_ALLOWED_URL_SCHEMES, + strip_comments=True, + link_rel="noopener noreferrer", +) + + +def _render_markdown_html(text: str) -> str | None: + """Render markdown to sanitized HTML; returns None for plain text.""" + try: + formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip() + except Exception: + return None + if not formatted: + return None + # Skip formatted_body for plain

text to keep payload minimal.
+    if formatted.startswith("<p>") and formatted.endswith("</p>
"): + inner = formatted[3:-4] + if "<" not in inner and ">" not in inner: + return None + return formatted + + +def _build_matrix_text_content(text: str) -> dict[str, object]: + """Build Matrix m.text payload with optional HTML formatted_body.""" + content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}} + if html := _render_markdown_html(text): + content["format"] = MATRIX_HTML_FORMAT + content["formatted_body"] = html + return content + + +class _NioLoguruHandler(logging.Handler): + """Route matrix-nio stdlib logs into Loguru.""" + + def emit(self, record: logging.LogRecord) -> None: + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + frame, depth = logging.currentframe(), 2 + while frame and frame.f_code.co_filename == logging.__file__: + frame, depth = frame.f_back, depth + 1 + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + + +def _configure_nio_logging_bridge() -> None: + """Bridge matrix-nio logs to Loguru (idempotent).""" + nio_logger = logging.getLogger("nio") + if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers): + nio_logger.handlers = [_NioLoguruHandler()] + nio_logger.propagate = False + + +class MatrixChannel(BaseChannel): + """Matrix (Element) channel using long-polling sync.""" + + name = "matrix" + + def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False, + workspace: Path | None = None): + super().__init__(config, bus) + self.client: AsyncClient | None = None + self._sync_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._restrict_to_workspace = restrict_to_workspace + self._workspace = workspace.expanduser().resolve() if workspace else None + self._server_upload_limit_bytes: int | None = None + self._server_upload_limit_checked = False + + async def start(self) -> None: + """Start Matrix client and begin sync loop.""" + self._running = True + _configure_nio_logging_bridge() + + store_path = get_data_dir() / "matrix-store" + store_path.mkdir(parents=True, exist_ok=True) + + self.client = AsyncClient( + homeserver=self.config.homeserver, user=self.config.user_id, + store_path=store_path, + config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled), + ) + self.client.user_id = self.config.user_id + self.client.access_token = self.config.access_token + self.client.device_id = self.config.device_id + + self._register_event_callbacks() + self._register_response_callbacks() + + if not self.config.e2ee_enabled: + logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.") + + if self.config.device_id: + try: + self.client.load_store() + except Exception: + logger.exception("Matrix store load failed; restart may replay recent messages.") + else: + logger.warning("Matrix device_id empty; restart may replay recent messages.") + + self._sync_task = asyncio.create_task(self._sync_loop()) + + async def stop(self) -> None: + """Stop the Matrix channel with graceful sync shutdown.""" + self._running = False + for room_id in list(self._typing_tasks): + await self._stop_typing_keepalive(room_id, clear_typing=False) + if self.client: + self.client.stop_sync_forever() + if self._sync_task: + try: + await asyncio.wait_for(asyncio.shield(self._sync_task), + timeout=self.config.sync_stop_grace_seconds) + except (asyncio.TimeoutError, asyncio.CancelledError): + self._sync_task.cancel() + try: + await self._sync_task + except asyncio.CancelledError: + 
pass + if self.client: + await self.client.close() + + def _is_workspace_path_allowed(self, path: Path) -> bool: + """Check path is inside workspace (when restriction enabled).""" + if not self._restrict_to_workspace or not self._workspace: + return True + try: + path.resolve(strict=False).relative_to(self._workspace) + return True + except ValueError: + return False + + def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]: + """Deduplicate and resolve outbound attachment paths.""" + seen: set[str] = set() + candidates: list[Path] = [] + for raw in media: + if not isinstance(raw, str) or not raw.strip(): + continue + path = Path(raw.strip()).expanduser() + try: + key = str(path.resolve(strict=False)) + except OSError: + key = str(path) + if key not in seen: + seen.add(key) + candidates.append(path) + return candidates + + @staticmethod + def _build_outbound_attachment_content( + *, filename: str, mime: str, size_bytes: int, + mxc_url: str, encryption_info: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Build Matrix content payload for an uploaded file/image/audio/video.""" + prefix = mime.split("/")[0] + msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file") + content: dict[str, Any] = { + "msgtype": msgtype, "body": filename, "filename": filename, + "info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {}, + } + if encryption_info: + content["file"] = {**encryption_info, "url": mxc_url} + else: + content["url"] = mxc_url + return content + + def _is_encrypted_room(self, room_id: str) -> bool: + if not self.client: + return False + room = getattr(self.client, "rooms", {}).get(room_id) + return bool(getattr(room, "encrypted", False)) + + async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None: + """Send m.room.message with E2EE options.""" + if not self.client: + return + kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content} + if self.config.e2ee_enabled: + kwargs["ignore_unverified_devices"] = True + await self.client.room_send(**kwargs) + + async def _resolve_server_upload_limit_bytes(self) -> int | None: + """Query homeserver upload limit once per channel lifecycle.""" + if self._server_upload_limit_checked: + return self._server_upload_limit_bytes + self._server_upload_limit_checked = True + if not self.client: + return None + try: + response = await self.client.content_repository_config() + except Exception: + return None + upload_size = getattr(response, "upload_size", None) + if isinstance(upload_size, int) and upload_size > 0: + self._server_upload_limit_bytes = upload_size + return upload_size + return None + + async def _effective_media_limit_bytes(self) -> int: + """min(local config, server advertised) — 0 blocks all uploads.""" + local_limit = max(int(self.config.max_media_bytes), 0) + server_limit = await self._resolve_server_upload_limit_bytes() + if server_limit is None: + return local_limit + return min(local_limit, server_limit) if local_limit else 0 + + async def _upload_and_send_attachment( + self, room_id: str, path: Path, limit_bytes: int, + relates_to: dict[str, Any] | None = None, + ) -> str | None: + """Upload one local file to Matrix and send it as a media message. 
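The limit passed into this upload path comes from `_effective_media_limit_bytes` above, and its rule is worth pinning down. A standalone sketch of that rule, not the channel method itself:

```python
# The stricter of the local config and the server-advertised cap wins,
# and a local limit of 0 disables uploads entirely.
def effective_limit(local_limit: int, server_limit: int | None) -> int:
    local_limit = max(local_limit, 0)
    if server_limit is None:
        return local_limit
    return min(local_limit, server_limit) if local_limit else 0

assert effective_limit(10_000_000, 5_000_000) == 5_000_000  # server is stricter
assert effective_limit(1_000_000, None) == 1_000_000        # no server info
assert effective_limit(0, 5_000_000) == 0                   # uploads disabled
```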
Returns failure marker or None.""" + if not self.client: + return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME) + + resolved = path.expanduser().resolve(strict=False) + filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME + fail = _ATTACH_UPLOAD_FAILED.format(filename) + + if not resolved.is_file() or not self._is_workspace_path_allowed(resolved): + return fail + try: + size_bytes = resolved.stat().st_size + except OSError: + return fail + if limit_bytes <= 0 or size_bytes > limit_bytes: + return _ATTACH_TOO_LARGE.format(filename) + + mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream" + try: + with resolved.open("rb") as f: + upload_result = await self.client.upload( + f, content_type=mime, filename=filename, + encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id), + filesize=size_bytes, + ) + except Exception: + return fail + + upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result + encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None + if isinstance(upload_response, UploadError): + return fail + mxc_url = getattr(upload_response, "content_uri", None) + if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"): + return fail + + content = self._build_outbound_attachment_content( + filename=filename, mime=mime, size_bytes=size_bytes, + mxc_url=mxc_url, encryption_info=encryption_info, + ) + if relates_to: + content["m.relates_to"] = relates_to + try: + await self._send_room_content(room_id, content) + except Exception: + return fail + return None + + async def send(self, msg: OutboundMessage) -> None: + """Send outbound content; clear typing for non-progress messages.""" + if not self.client: + return + text = msg.content or "" + candidates = self._collect_outbound_media_candidates(msg.media) + relates_to = self._build_thread_relates_to(msg.metadata) + is_progress = bool((msg.metadata or {}).get("_progress")) + try: + failures: list[str] = [] + if candidates: + limit_bytes = await self._effective_media_limit_bytes() + for path in candidates: + if fail := await self._upload_and_send_attachment( + room_id=msg.chat_id, + path=path, + limit_bytes=limit_bytes, + relates_to=relates_to, + ): + failures.append(fail) + if failures: + text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures) + if text or not candidates: + content = _build_matrix_text_content(text) + if relates_to: + content["m.relates_to"] = relates_to + await self._send_room_content(msg.chat_id, content) + finally: + if not is_progress: + await self._stop_typing_keepalive(msg.chat_id, clear_typing=True) + + def _register_event_callbacks(self) -> None: + self.client.add_event_callback(self._on_message, RoomMessageText) + self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) + self.client.add_event_callback(self._on_room_invite, InviteEvent) + + def _register_response_callbacks(self) -> None: + self.client.add_response_callback(self._on_sync_error, SyncError) + self.client.add_response_callback(self._on_join_error, JoinError) + self.client.add_response_callback(self._on_send_error, RoomSendError) + + def _log_response_error(self, label: str, response: Any) -> None: + """Log Matrix response errors — auth errors at ERROR level, rest at WARNING.""" + code = getattr(response, "status_code", None) + is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"} + is_fatal = 
is_auth or getattr(response, "soft_logout", False) + (logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response) + + async def _on_sync_error(self, response: SyncError) -> None: + self._log_response_error("sync", response) + + async def _on_join_error(self, response: JoinError) -> None: + self._log_response_error("join", response) + + async def _on_send_error(self, response: RoomSendError) -> None: + self._log_response_error("send", response) + + async def _set_typing(self, room_id: str, typing: bool) -> None: + """Best-effort typing indicator update.""" + if not self.client: + return + try: + response = await self.client.room_typing(room_id=room_id, typing_state=typing, + timeout=TYPING_NOTICE_TIMEOUT_MS) + if isinstance(response, RoomTypingError): + logger.debug("Matrix typing failed for {}: {}", room_id, response) + except Exception: + pass + + async def _start_typing_keepalive(self, room_id: str) -> None: + """Start periodic typing refresh (spec-recommended keepalive).""" + await self._stop_typing_keepalive(room_id, clear_typing=False) + await self._set_typing(room_id, True) + if not self._running: + return + + async def loop() -> None: + try: + while self._running: + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000) + await self._set_typing(room_id, True) + except asyncio.CancelledError: + pass + + self._typing_tasks[room_id] = asyncio.create_task(loop()) + + async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None: + if task := self._typing_tasks.pop(room_id, None): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if clear_typing: + await self._set_typing(room_id, False) + + async def _sync_loop(self) -> None: + while self._running: + try: + await self.client.sync_forever(timeout=30000, full_state=True) + except asyncio.CancelledError: + break + except Exception: + await asyncio.sleep(2) + + async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None: + if self.is_allowed(event.sender): + await self.client.join(room.room_id) + + def _is_direct_room(self, room: MatrixRoom) -> bool: + count = getattr(room, "member_count", None) + return isinstance(count, int) and count <= 2 + + def _is_bot_mentioned(self, event: RoomMessage) -> bool: + """Check m.mentions payload for bot mention.""" + source = getattr(event, "source", None) + if not isinstance(source, dict): + return False + mentions = (source.get("content") or {}).get("m.mentions") + if not isinstance(mentions, dict): + return False + user_ids = mentions.get("user_ids") + if isinstance(user_ids, list) and self.config.user_id in user_ids: + return True + return bool(self.config.allow_room_mentions and mentions.get("room") is True) + + def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool: + """Apply sender and room policy checks.""" + if not self.is_allowed(event.sender): + return False + if self._is_direct_room(room): + return True + policy = self.config.group_policy + if policy == "open": + return True + if policy == "allowlist": + return room.room_id in (self.config.group_allow_from or []) + if policy == "mention": + return self._is_bot_mentioned(event) + return False + + def _media_dir(self) -> Path: + d = get_data_dir() / "media" / "matrix" + d.mkdir(parents=True, exist_ok=True) + return d + + @staticmethod + def _event_source_content(event: RoomMessage) -> dict[str, Any]: + source = getattr(event, "source", None) + if not isinstance(source, dict): + return {} + content = source.get("content") + 
return content if isinstance(content, dict) else {} + + def _event_thread_root_id(self, event: RoomMessage) -> str | None: + relates_to = self._event_source_content(event).get("m.relates_to") + if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread": + return None + root_id = relates_to.get("event_id") + return root_id if isinstance(root_id, str) and root_id else None + + def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None: + if not (root_id := self._event_thread_root_id(event)): + return None + meta: dict[str, str] = {"thread_root_event_id": root_id} + if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to: + meta["thread_reply_to_event_id"] = reply_to + return meta + + @staticmethod + def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None: + if not metadata: + return None + root_id = metadata.get("thread_root_event_id") + if not isinstance(root_id, str) or not root_id: + return None + reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id") + if not isinstance(reply_to, str) or not reply_to: + return None + return {"rel_type": "m.thread", "event_id": root_id, + "m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True} + + def _event_attachment_type(self, event: MatrixMediaEvent) -> str: + msgtype = self._event_source_content(event).get("msgtype") + return _MSGTYPE_MAP.get(msgtype, "file") + + @staticmethod + def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool: + return (isinstance(getattr(event, "key", None), dict) + and isinstance(getattr(event, "hashes", None), dict) + and isinstance(getattr(event, "iv", None), str)) + + def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None: + info = self._event_source_content(event).get("info") + size = info.get("size") if isinstance(info, dict) else None + return size if isinstance(size, int) and size >= 0 else None + + def _event_mime(self, event: MatrixMediaEvent) -> str | None: + info = self._event_source_content(event).get("info") + if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m: + return m + m = getattr(event, "mimetype", None) + return m if isinstance(m, str) and m else None + + def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str: + body = getattr(event, "body", None) + if isinstance(body, str) and body.strip(): + if candidate := safe_filename(Path(body).name): + return candidate + return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type + + def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str, + filename: str, mime: str | None) -> Path: + safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME + suffix = Path(safe_name).suffix + if not suffix and mime: + if guessed := mimetypes.guess_extension(mime, strict=False): + safe_name, suffix = f"{safe_name}{guessed}", guessed + stem = (Path(safe_name).stem or attachment_type)[:72] + suffix = suffix[:16] + event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$")) + event_prefix = (event_id[:24] or "evt").strip("_") + return self._media_dir() / f"{event_prefix}_{stem}{suffix}" + + async def _download_media_bytes(self, mxc_url: str) -> bytes | None: + if not self.client: + return None + response = await self.client.download(mxc=mxc_url) + if isinstance(response, DownloadError): + logger.warning("Matrix download failed for {}: {}", mxc_url, response) + return None + body = getattr(response, "body", 
None) + if isinstance(body, (bytes, bytearray)): + return bytes(body) + if isinstance(response, MemoryDownloadResponse): + return bytes(response.body) + if isinstance(body, (str, Path)): + path = Path(body) + if path.is_file(): + try: + return path.read_bytes() + except OSError: + return None + return None + + def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None: + key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None) + key = key_obj.get("k") if isinstance(key_obj, dict) else None + sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None + if not all(isinstance(v, str) for v in (key, sha256, iv)): + return None + try: + return decrypt_attachment(ciphertext, key, sha256, iv) + except (EncryptionError, ValueError, TypeError): + logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", "")) + return None + + async def _fetch_media_attachment( + self, room: MatrixRoom, event: MatrixMediaEvent, + ) -> tuple[dict[str, Any] | None, str]: + """Download, decrypt if needed, and persist a Matrix attachment.""" + atype = self._event_attachment_type(event) + mime = self._event_mime(event) + filename = self._event_filename(event, atype) + mxc_url = getattr(event, "url", None) + fail = _ATTACH_FAILED.format(filename) + + if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"): + return None, fail + + limit_bytes = await self._effective_media_limit_bytes() + declared = self._event_declared_size_bytes(event) + if declared is not None and declared > limit_bytes: + return None, _ATTACH_TOO_LARGE.format(filename) + + downloaded = await self._download_media_bytes(mxc_url) + if downloaded is None: + return None, fail + + encrypted = self._is_encrypted_media_event(event) + data = downloaded + if encrypted: + if (data := self._decrypt_media_bytes(event, downloaded)) is None: + return None, fail + + if len(data) > limit_bytes: + return None, _ATTACH_TOO_LARGE.format(filename) + + path = self._build_attachment_path(event, atype, filename, mime) + try: + path.write_bytes(data) + except OSError: + return None, fail + + attachment = { + "type": atype, "mime": mime, "filename": filename, + "event_id": str(getattr(event, "event_id", "") or ""), + "encrypted": encrypted, "size_bytes": len(data), + "path": str(path), "mxc_url": mxc_url, + } + return attachment, _ATTACH_MARKER.format(path) + + def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]: + """Build common metadata for text and media handlers.""" + meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)} + if isinstance(eid := getattr(event, "event_id", None), str) and eid: + meta["event_id"] = eid + if thread := self._thread_metadata(event): + meta.update(thread) + return meta + + async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None: + if event.sender == self.config.user_id or not self._should_process_message(room, event): + return + await self._start_typing_keepalive(room.room_id) + try: + await self._handle_message( + sender_id=event.sender, chat_id=room.room_id, + content=event.body, metadata=self._base_metadata(room, event), + ) + except Exception: + await self._stop_typing_keepalive(room.room_id, clear_typing=True) + raise + + async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None: + if event.sender == self.config.user_id or not self._should_process_message(room, event): + return + attachment, marker = await 
+    def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
+        """Build common metadata for text and media handlers."""
+        meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
+        if isinstance(eid := getattr(event, "event_id", None), str) and eid:
+            meta["event_id"] = eid
+        if thread := self._thread_metadata(event):
+            meta.update(thread)
+        return meta
+
+    async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
+        if event.sender == self.config.user_id or not self._should_process_message(room, event):
+            return
+        await self._start_typing_keepalive(room.room_id)
+        try:
+            await self._handle_message(
+                sender_id=event.sender, chat_id=room.room_id,
+                content=event.body, metadata=self._base_metadata(room, event),
+            )
+        except Exception:
+            await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+            raise
+
+    async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
+        if event.sender == self.config.user_id or not self._should_process_message(room, event):
+            return
+        attachment, marker = await self._fetch_media_attachment(room, event)
+        parts: list[str] = []
+        if isinstance(body := getattr(event, "body", None), str) and body.strip():
+            parts.append(body.strip())
+        if marker:
+            parts.append(marker)
+
+        await self._start_typing_keepalive(room.room_id)
+        try:
+            meta = self._base_metadata(room, event)
+            meta["attachments"] = []
+            if attachment:
+                meta["attachments"] = [attachment]
+            await self._handle_message(
+                sender_id=event.sender, chat_id=room.room_id,
+                content="\n".join(parts),
+                media=[attachment["path"]] if attachment else [],
+                metadata=meta,
+            )
+        except Exception:
+            await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+            raise
diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py
index 30c3dbf..e762dfd 100644
--- a/nanobot/channels/mochat.py
+++ b/nanobot/channels/mochat.py
@@ -322,7 +322,7 @@ class MochatChannel(BaseChannel):
             await self._api_send("/api/claw/sessions/send", "sessionId", target.id, content, msg.reply_to)
         except Exception as e:
-            logger.error(f"Failed to send Mochat message: {e}")
+            logger.error("Failed to send Mochat message: {}", e)
 
     # ---- config / init helpers ---------------------------------------------
 
@@ -380,7 +380,7 @@ class MochatChannel(BaseChannel):
 
         @client.event
         async def connect_error(data: Any) -> None:
-            logger.error(f"Mochat websocket connect error: {data}")
+            logger.error("Mochat websocket connect error: {}", data)
 
         @client.on("claw.session.events")
         async def on_session_events(payload: dict[str, Any]) -> None:
@@ -407,7 +407,7 @@ class MochatChannel(BaseChannel):
             )
             return True
         except Exception as e:
-            logger.error(f"Failed to connect Mochat websocket: {e}")
+            logger.error("Failed to connect Mochat websocket: {}", e)
             try:
                 await client.disconnect()
             except Exception:
@@ -444,7 +444,7 @@ class MochatChannel(BaseChannel):
             "limit": self.config.watch_limit,
         })
         if not ack.get("result"):
-            logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
+            logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
             return False
 
         data = ack.get("data")
@@ -466,7 +466,7 @@ class MochatChannel(BaseChannel):
             return True
         ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
         if not ack.get("result"):
-            logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
+            logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
             return False
         return True
 
@@ -488,7 +488,7 @@ class MochatChannel(BaseChannel):
             try:
                 await self._refresh_targets(subscribe_new=self._ws_ready)
             except Exception as e:
-                logger.warning(f"Mochat refresh failed: {e}")
+                logger.warning("Mochat refresh failed: {}", e)
 
             if self._fallback_mode:
                 await self._ensure_fallback_workers()
@@ -502,7 +502,7 @@ class MochatChannel(BaseChannel):
         try:
             response = await self._post_json("/api/claw/sessions/list", {})
         except Exception as e:
-            logger.warning(f"Mochat listSessions failed: {e}")
+            logger.warning("Mochat listSessions failed: {}", e)
             return
 
         sessions = response.get("sessions")
@@ -536,7 +536,7 @@ class MochatChannel(BaseChannel):
         try:
             response = await self._post_json("/api/claw/groups/get", {})
         except Exception as e:
-            logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
+            logger.warning("Mochat getWorkspaceGroup failed: {}", e)
             return
 
         raw_panels = response.get("panels")
@@ -598,7 +598,7 @@ class MochatChannel(BaseChannel):
                 except asyncio.CancelledError:
                     break
                 except Exception as e:
-                    logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
+                    logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
                     await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
 
     async def _panel_poll_worker(self, panel_id: str) -> None:
@@ -625,7 +625,7 @@ class MochatChannel(BaseChannel):
                 except asyncio.CancelledError:
                     break
                 except Exception as e:
-                    logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
+                    logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
                     await asyncio.sleep(sleep_s)
 
     # ---- inbound event processing ------------------------------------------
 
@@ -836,7 +836,7 @@ class MochatChannel(BaseChannel):
         try:
             data = json.loads(self._cursor_path.read_text("utf-8"))
         except Exception as e:
-            logger.warning(f"Failed to read Mochat cursor file: {e}")
+            logger.warning("Failed to read Mochat cursor file: {}", e)
             return
         cursors = data.get("cursors") if isinstance(data, dict) else None
         if isinstance(cursors, dict):
@@ -852,7 +852,7 @@ class MochatChannel(BaseChannel):
                 "cursors": self._session_cursor,
             }, ensure_ascii=False, indent=2) + "\n", "utf-8")
         except Exception as e:
-            logger.warning(f"Failed to save Mochat cursor file: {e}")
+            logger.warning("Failed to save Mochat cursor file: {}", e)
 
     # ---- HTTP helpers ------------------------------------------------------
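The Mochat hunks above systematically replace f-string log calls with loguru's deferred `{}` placeholders. A minimal illustration of why:

```python
from loguru import logger

payload = {"user": "alice"}

# f-string: the message is formatted eagerly, even if this level is filtered out.
logger.warning(f"Mochat refresh failed: {payload}")

# Deferred loguru style: formatting happens only when the record is actually
# emitted, and the argument is substituted safely by loguru itself.
logger.warning("Mochat refresh failed: {}", payload)
```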
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
index c0eb2b1..4809fd3 100644
--- a/nanobot/channels/qq.py
+++ b/nanobot/channels/qq.py
@@ -2,7 +2,7 @@
 
 import asyncio
 from collections import deque
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
 from loguru import logger
 
@@ -13,7 +13,7 @@ from nanobot.config.schema import QQConfig
 
 try:
     import botpy
-    from botpy.message import C2CMessage, GroupMessage  # 1. Import GroupMessage
+    from botpy.message import C2CMessage, GroupMessage
 
     QQ_AVAILABLE = True
 except ImportError:
@@ -28,27 +28,23 @@ if TYPE_CHECKING:
 
 def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
     """Create a botpy Client subclass bound to the given channel."""
-    # 2. Ensure intents enable public_messages (required for group messages)
     intents = botpy.Intents(public_messages=True, direct_message=True)
 
     class _Bot(botpy.Client):
         def __init__(self):
-            super().__init__(intents=intents)
+            # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+            super().__init__(intents=intents, ext_handlers=False)
 
         async def on_ready(self):
-            logger.info(f"QQ bot ready: {self.robot.name}")
+            logger.info("QQ bot ready: {}", self.robot.name)
 
         async def on_c2c_message_create(self, message: "C2CMessage"):
-            # C2C (Private) message
            await channel._on_message(message, is_group=False)
 
         async def on_group_at_message_create(self, message: "GroupMessage"):
-            # 3. Added: Listen for group @messages
-            # Note: Official bots only receive messages @mentioning them unless privileged
            await channel._on_message(message, is_group=True)
 
         async def on_direct_message_create(self, message):
-            # Guild Direct Message
            await channel._on_message(message, is_group=False)
 
     return _Bot
 
@@ -64,10 +60,8 @@ class QQChannel(BaseChannel):
         self.config: QQConfig = config
         self._client: "botpy.Client | None" = None
         self._processed_ids: deque = deque(maxlen=1000)
-        self._bot_task: asyncio.Task | None = None
-        # Cache to track if chat_id is a group or individual to select the correct reply API
-        # Format: {chat_id: "group" | "c2c"}
-        self._chat_type_cache: Dict[str, str] = {}
+        self._msg_seq: int = 1  # Message sequence number; avoids deduplication by the QQ API
+        self._chat_type_cache: dict[str, str] = {}
 
     async def start(self) -> None:
         """Start the QQ bot."""
@@ -82,9 +76,8 @@ class QQChannel(BaseChannel):
         self._running = True
         BotClass = _make_bot_class(self)
         self._client = BotClass()
-
-        self._bot_task = asyncio.create_task(self._run_bot())
         logger.info("QQ bot started (C2C & Group supported)")
+        await self._run_bot()
 
     async def _run_bot(self) -> None:
         """Run the bot connection with auto-reconnect."""
@@ -92,7 +85,7 @@ class QQChannel(BaseChannel):
             try:
                 await self._client.start(appid=self.config.app_id, secret=self.config.secret)
             except Exception as e:
-                logger.warning(f"QQ bot error: {e}")
+                logger.warning("QQ bot error: {}", e)
                 if self._running:
                     logger.info("Reconnecting QQ bot in 5 seconds...")
                     await asyncio.sleep(5)
 
@@ -100,11 +93,10 @@ class QQChannel(BaseChannel):
     async def stop(self) -> None:
         """Stop the QQ bot."""
         self._running = False
-        if self._bot_task:
-            self._bot_task.cancel()
+        if self._client:
             try:
-                await self._bot_task
-            except asyncio.CancelledError:
+                await self._client.close()
+            except Exception:
                 pass
         logger.info("QQ bot stopped")
 
@@ -113,29 +105,29 @@ class QQChannel(BaseChannel):
         if not self._client:
             logger.warning("QQ client not initialized")
             return
-
-        # 4. Modified send logic: Check chat_id type to call the correct API
-        msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")  # Default to c2c
         try:
+            msg_id = msg.metadata.get("message_id")
+            self._msg_seq += 1
+            msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
             if msg_type == "group":
-                # Send group message
                 await self._client.api.post_group_message(
                     group_openid=msg.chat_id,
-                    msg_type=0,
-                    msg_id=msg.metadata.get("message_id"),  # Reply to specific message ID (optional but recommended)
-                    content=msg.content
+                    msg_type=0,
+                    content=msg.content,
+                    msg_id=msg_id,
+                    msg_seq=self._msg_seq,
                 )
             else:
-                # Send C2C (private) message
                 await self._client.api.post_c2c_message(
                     openid=msg.chat_id,
                     msg_type=0,
-                    msg_id=msg.metadata.get("message_id"),
                     content=msg.content,
+                    msg_id=msg_id,
+                    msg_seq=self._msg_seq,
                 )
         except Exception as e:
-            logger.error(f"Error sending QQ message ({msg_type}): {e}")
+            logger.error("Error sending QQ message: {}", e)
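Per the comment in the diff, QQ's API deduplicates passive replies that reuse the same message sequence number, so the send path above threads a per-process counter through every reply. A minimal sketch of the pattern (the `api.post_c2c_message` call mirrors the botpy call above; counter placement is illustrative):

```python
from itertools import count

# Each passive reply to the same msg_id must carry a distinct msg_seq,
# otherwise the QQ API silently drops it as a duplicate.
_seq = count(1)

async def reply(api, chat_id: str, msg_id: str, text: str) -> None:
    await api.post_c2c_message(
        openid=chat_id, msg_type=0, content=text,
        msg_id=msg_id, msg_seq=next(_seq),
    )
```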
     async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
         """Handle incoming message from QQ."""
@@ -149,17 +141,11 @@ class QQChannel(BaseChannel):
         if not content:
             return
 
-        # 5. Extract ID and cache type
         if is_group:
-            # Group message: chat_id uses group_openid
             chat_id = data.group_openid
-            user_id = data.author.member_openid  # Sender's ID
+            user_id = data.author.member_openid
             self._chat_type_cache[chat_id] = "group"
-
-            # Remove @bot text (optional, prevents Nanobot from treating the name as prompt)
-            # content = content.replace("@BotName", "").strip()
         else:
-            # Private message: chat_id uses user_openid
             chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
             user_id = chat_id
             self._chat_type_cache[chat_id] = "c2c"
 
@@ -170,5 +156,5 @@ class QQChannel(BaseChannel):
             content=content,
             metadata={"message_id": data.id},
         )
-        except Exception as e:
-            logger.error(f"Error handling QQ message: {e}")
\ No newline at end of file
+        except Exception:
+            logger.exception("Error handling QQ message")
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
index be95dd2..afd1d2d 100644
--- a/nanobot/channels/slack.py
+++ b/nanobot/channels/slack.py
@@ -5,10 +5,11 @@ import re
 from typing import Any
 
 from loguru import logger
-from slack_sdk.socket_mode.websockets import SocketModeClient
 from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
+from slack_sdk.socket_mode.websockets import SocketModeClient
 from slack_sdk.web.async_client import AsyncWebClient
+from slackify_markdown import slackify_markdown
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 
@@ -34,7 +35,7 @@ class SlackChannel(BaseChannel):
             logger.error("Slack bot/app token not configured")
             return
         if self.config.mode != "socket":
-            logger.error(f"Unsupported Slack mode: {self.config.mode}")
+            logger.error("Unsupported Slack mode: {}", self.config.mode)
             return
 
         self._running = True
@@ -51,9 +52,9 @@ class SlackChannel(BaseChannel):
         try:
             auth = await self._web_client.auth_test()
             self._bot_user_id = auth.get("user_id")
-            logger.info(f"Slack bot connected as {self._bot_user_id}")
+            logger.info("Slack bot connected as {}", self._bot_user_id)
         except Exception as e:
-            logger.warning(f"Slack auth_test failed: {e}")
+            logger.warning("Slack auth_test failed: {}", e)
 
         logger.info("Starting Slack Socket Mode client...")
         await self._socket_client.connect()
@@ -68,7 +69,7 @@ class SlackChannel(BaseChannel):
             try:
                 await self._socket_client.close()
             except Exception as e:
-                logger.warning(f"Slack socket close failed: {e}")
+                logger.warning("Slack socket close failed: {}", e)
             self._socket_client = None
 
     async def send(self, msg: OutboundMessage) -> None:
@@ -82,13 +83,26 @@ class SlackChannel(BaseChannel):
         channel_type = slack_meta.get("channel_type")
         # Only reply in thread for channel/group messages; DMs don't use threads
         use_thread = thread_ts and channel_type != "im"
-            await self._web_client.chat_postMessage(
-                channel=msg.chat_id,
-                text=msg.content or "",
-                thread_ts=thread_ts if use_thread else None,
-            )
+        thread_ts_param = thread_ts if use_thread else None
+
+        if msg.content:
+            await self._web_client.chat_postMessage(
+                channel=msg.chat_id,
+                text=self._to_mrkdwn(msg.content),
+                thread_ts=thread_ts_param,
+            )
+
+        for media_path in msg.media or []:
+            try:
+                await self._web_client.files_upload_v2(
+                    channel=msg.chat_id,
+                    file=media_path,
+                    thread_ts=thread_ts_param,
+                )
+            except Exception as e:
+                logger.error("Failed to upload file {}: {}", media_path, e)
         except Exception as e:
-            logger.error(f"Error sending Slack message: {e}")
+            logger.error("Error sending Slack message: {}", e)
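A condensed sketch of the delivery order above: text is posted first, then each media file is attached to the same thread (or to the channel when `thread_ts` is `None` for DMs):

```python
# Sketch of the send path above, assuming a connected AsyncWebClient.
from slack_sdk.web.async_client import AsyncWebClient

async def deliver(client: AsyncWebClient, chat_id: str, text: str,
                  files: list[str], thread_ts: str | None) -> None:
    if text:
        await client.chat_postMessage(channel=chat_id, text=text, thread_ts=thread_ts)
    for path in files:
        # files_upload_v2 accepts a local path and an optional thread_ts.
        await client.files_upload_v2(channel=chat_id, file=path, thread_ts=thread_ts)
```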
     async def _on_socket_request(
         self,
@@ -150,30 +164,39 @@ class SlackChannel(BaseChannel):
 
         text = self._strip_bot_mention(text)
 
-        thread_ts = event.get("thread_ts") or event.get("ts")
+        thread_ts = event.get("thread_ts")
+        if self.config.reply_in_thread and not thread_ts:
+            thread_ts = event.get("ts")
 
         # Add :eyes: reaction to the triggering message (best-effort)
         try:
             if self._web_client and event.get("ts"):
                 await self._web_client.reactions_add(
                     channel=chat_id,
-                    name="eyes",
+                    name=self.config.react_emoji,
                     timestamp=event.get("ts"),
                 )
         except Exception as e:
-            logger.debug(f"Slack reactions_add failed: {e}")
+            logger.debug("Slack reactions_add failed: {}", e)
 
-        await self._handle_message(
-            sender_id=sender_id,
-            chat_id=chat_id,
-            content=text,
-            metadata={
-                "slack": {
-                    "event": event,
-                    "thread_ts": thread_ts,
-                    "channel_type": channel_type,
-                }
-            },
-        )
+        # Thread-scoped session key for channel/group messages
+        session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
+
+        try:
+            await self._handle_message(
+                sender_id=sender_id,
+                chat_id=chat_id,
+                content=text,
+                metadata={
+                    "slack": {
+                        "event": event,
+                        "thread_ts": thread_ts,
+                        "channel_type": channel_type,
+                    },
+                },
+                session_key=session_key,
+            )
+        except Exception:
+            logger.exception("Error handling Slack message from {}", sender_id)
 
     def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
         if channel_type == "im":
@@ -203,3 +226,55 @@ class SlackChannel(BaseChannel):
         if not text or not self._bot_user_id:
             return text
         return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
+
+    _TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
+    _CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
+    _INLINE_CODE_RE = re.compile(r"`[^`]+`")
+    _LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
+    _LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
+    _BARE_URL_RE = re.compile(r"(?<!<)(https?://[^\s<>]+)")
+
+    @classmethod
+    def _to_mrkdwn(cls, text: str) -> str:
+        """Convert Markdown to Slack mrkdwn, including tables."""
+        if not text:
+            return ""
+        text = cls._TABLE_RE.sub(cls._convert_table, text)
+        return cls._fixup_mrkdwn(slackify_markdown(text))
+
+    @classmethod
+    def _fixup_mrkdwn(cls, text: str) -> str:
+        """Fix markdown artifacts that slackify_markdown misses."""
+        code_blocks: list[str] = []
+
+        def _save_code(m: re.Match) -> str:
+            code_blocks.append(m.group(0))
+            return f"\x00CB{len(code_blocks) - 1}\x00"
+
+        text = cls._CODE_FENCE_RE.sub(_save_code, text)
+        text = cls._INLINE_CODE_RE.sub(_save_code, text)
+        text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
+        text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
+        text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&amp;"), text)
+
+        for i, block in enumerate(code_blocks):
+            text = text.replace(f"\x00CB{i}\x00", block)
+        return text
+
+    @staticmethod
+    def _convert_table(match: re.Match) -> str:
+        """Convert a Markdown table to a Slack-readable list."""
+        lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
+        if len(lines) < 2:
+            return match.group(0)
+        headers = [h.strip() for h in lines[0].strip("|").split("|")]
+        start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
+        rows: list[str] = []
+        for line in lines[start:]:
+            cells = [c.strip() for c in line.strip("|").split("|")]
+            cells = (cells + [""] * len(headers))[: len(headers)]
+            parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
+            if parts:
+                rows.append(" · ".join(parts))
+        return "\n".join(rows)
+
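Illustrative input/output for the table fallback above (hypothetical data; the exact bold form depends on `slackify_markdown`, which together with `_fixup_mrkdwn` should narrow `**` pairs to Slack's single-asterisk bold):

```python
# What _to_mrkdwn is expected to do with a simple pipe table.
table = (
    "| Name | Role |\n"
    "| --- | --- |\n"
    "| Ada | Engineer |"
)
# SlackChannel._to_mrkdwn(table) should yield roughly:
expected = "*Name*: Ada · *Role*: Engineer"
```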
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 1abd600..81cf0ca 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -4,20 +4,62 @@ from __future__ import annotations
 
 import asyncio
 import re
-from typing import TYPE_CHECKING
+import time
+import unicodedata
 
 from loguru import logger
-from telegram import BotCommand, Update
-from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
+from telegram import BotCommand, ReplyParameters, Update
+from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
 from telegram.request import HTTPXRequest
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import TelegramConfig
+from nanobot.utils.helpers import split_message
 
-if TYPE_CHECKING:
-    from nanobot.session.manager import SessionManager
+TELEGRAM_MAX_MESSAGE_LEN = 4000  # Telegram message character limit
+
+
+def _strip_md(s: str) -> str:
+    """Strip markdown inline formatting from text."""
+    s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
+    s = re.sub(r'__(.+?)__', r'\1', s)
+    s = re.sub(r'~~(.+?)~~', r'\1', s)
+    s = re.sub(r'`([^`]+)`', r'\1', s)
+    return s.strip()
+
+
+def _render_table_box(table_lines: list[str]) -> str:
+    """Convert markdown pipe-table to compact aligned text for display."""
+
+    def dw(s: str) -> int:
+        return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
+
+    rows: list[list[str]] = []
+    has_sep = False
+    for line in table_lines:
+        cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
+        if all(re.match(r'^:?-+:?$', c) for c in cells if c):
+            has_sep = True
+            continue
+        rows.append(cells)
+    if not rows or not has_sep:
+        return '\n'.join(table_lines)
+
+    ncols = max(len(r) for r in rows)
+    for r in rows:
+        r.extend([''] * (ncols - len(r)))
+    widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
+
+    def dr(cells: list[str]) -> str:
+        return '  '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
+
+    out = [dr(rows[0])]
+    out.append('  '.join('─' * w for w in widths))
+    for row in rows[1:]:
+        out.append(dr(row))
+    return '\n'.join(out)
 
 
 def _markdown_to_telegram_html(text: str) -> str:
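The `dw` helper in `_render_table_box` above pads by display width rather than `len`, so CJK cells stay aligned in monospace rendering. A standalone illustration of the stdlib call it relies on:

```python
import unicodedata

def display_width(s: str) -> int:
    # Wide (W) and Fullwidth (F) code points occupy two terminal cells.
    return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)

assert display_width("abc") == 3
assert display_width("中文") == 4  # two fullwidth characters, four cells
```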
@@ -26,275 +68,433 @@ def _markdown_to_telegram_html(text: str) -> str:
     """
     if not text:
         return ""
-    
+
     # 1. Extract and protect code blocks (preserve content from other processing)
     code_blocks: list[str] = []
     def save_code_block(m: re.Match) -> str:
         code_blocks.append(m.group(1))
         return f"\x00CB{len(code_blocks) - 1}\x00"
-    
+
     text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
-    
+
+    # 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
+    lines = text.split('\n')
+    rebuilt: list[str] = []
+    li = 0
+    while li < len(lines):
+        if re.match(r'^\s*\|.+\|', lines[li]):
+            tbl: list[str] = []
+            while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
+                tbl.append(lines[li])
+                li += 1
+            box = _render_table_box(tbl)
+            if box != '\n'.join(tbl):
+                code_blocks.append(box)
+                rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
+            else:
+                rebuilt.extend(tbl)
+        else:
+            rebuilt.append(lines[li])
+            li += 1
+    text = '\n'.join(rebuilt)
+
     # 2. Extract and protect inline code
     inline_codes: list[str] = []
     def save_inline_code(m: re.Match) -> str:
         inline_codes.append(m.group(1))
         return f"\x00IC{len(inline_codes) - 1}\x00"
-    
+
     text = re.sub(r'`([^`]+)`', save_inline_code, text)
-    
+
     # 3. Headers # Title -> just the title text
     text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
-    
+
     # 4. Blockquotes > text -> just the text (before HTML escaping)
     text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
-    
+
     # 5. Escape HTML special characters
     text = text.replace("&", "&").replace("<", "<").replace(">", ">")
-    
+
     # 6. Links [text](url) - must be before bold/italic to handle nested cases
     text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
-    
+
     # 7. Bold **text** or __text__
     text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
     text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
-    
+
     # 8. Italic _text_ (avoid matching inside words like some_var_name)
     text = re.sub(r'(?<!\w)_(.+?)_(?!\w)', r'<i>\1</i>', text)
-    
+
     # 9. Strikethrough ~~text~~
     text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
-    
+
     # 10. Bullet lists - item -> • item
     text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
-    
+
     # 11. Restore inline code with HTML tags
     for i, code in enumerate(inline_codes):
         # Escape HTML in code content
         escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
         text = text.replace(f"\x00IC{i}\x00", f"{escaped}")
-    
+
     # 12. Restore code blocks with HTML tags
     for i, code in enumerate(code_blocks):
         # Escape HTML in code content
         escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
         text = text.replace(f"\x00CB{i}\x00", f"
{escaped}
") - + return text class TelegramChannel(BaseChannel): """ Telegram channel using long polling. - + Simple and reliable - no webhook/public IP needed. """ - + name = "telegram" - + # Commands registered with Telegram's command menu BOT_COMMANDS = [ BotCommand("start", "Start the bot"), - BotCommand("reset", "Reset conversation history"), + BotCommand("new", "Start a new conversation"), + BotCommand("stop", "Stop the current task"), BotCommand("help", "Show available commands"), ] - + def __init__( self, config: TelegramConfig, bus: MessageBus, groq_api_key: str = "", - session_manager: SessionManager | None = None, ): super().__init__(config, bus) self.config: TelegramConfig = config self.groq_api_key = groq_api_key - self.session_manager = session_manager self._app: Application | None = None self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task - + self._media_group_buffers: dict[str, dict] = {} + self._media_group_tasks: dict[str, asyncio.Task] = {} + self._message_threads: dict[tuple[str, int], int] = {} + async def start(self) -> None: """Start the Telegram bot with long polling.""" if not self.config.token: logger.error("Telegram bot token not configured") return - + self._running = True - + # Build the application with larger connection pool to avoid pool-timeout on long runs - req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) + req = HTTPXRequest( + connection_pool_size=16, + pool_timeout=5.0, + connect_timeout=30.0, + read_timeout=30.0, + proxy=self.config.proxy if self.config.proxy else None, + ) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) - if self.config.proxy: - builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy) self._app = builder.build() self._app.add_error_handler(self._on_error) - + # Add command handlers self._app.add_handler(CommandHandler("start", self._on_start)) - self._app.add_handler(CommandHandler("reset", self._on_reset)) + self._app.add_handler(CommandHandler("new", self._forward_command)) + self._app.add_handler(CommandHandler("stop", self._forward_command)) self._app.add_handler(CommandHandler("help", self._on_help)) - + # Add message handler for text, photos, voice, documents self._app.add_handler( MessageHandler( - (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL) - & ~filters.COMMAND, + (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL) + & ~filters.COMMAND, self._on_message ) ) - + logger.info("Starting Telegram bot (polling mode)...") - + # Initialize and start polling await self._app.initialize() await self._app.start() - + # Get bot info and register command menu bot_info = await self._app.bot.get_me() - logger.info(f"Telegram bot @{bot_info.username} connected") - + logger.info("Telegram bot @{} connected", bot_info.username) + try: await self._app.bot.set_my_commands(self.BOT_COMMANDS) logger.debug("Telegram bot commands registered") except Exception as e: - logger.warning(f"Failed to register bot commands: {e}") - + logger.warning("Failed to register bot commands: {}", e) + # Start polling (this runs until stopped) await self._app.updater.start_polling( allowed_updates=["message"], drop_pending_updates=True # Ignore old messages on startup ) - + # Keep running until stopped while self._running: await asyncio.sleep(1) - + async def stop(self) -> None: 
"""Stop the Telegram bot.""" self._running = False - + # Cancel all typing indicators for chat_id in list(self._typing_tasks): self._stop_typing(chat_id) - + + for task in self._media_group_tasks.values(): + task.cancel() + self._media_group_tasks.clear() + self._media_group_buffers.clear() + if self._app: logger.info("Stopping Telegram bot...") await self._app.updater.stop() await self._app.stop() await self._app.shutdown() self._app = None - + + @staticmethod + def _get_media_type(path: str) -> str: + """Guess media type from file extension.""" + ext = path.rsplit(".", 1)[-1].lower() if "." in path else "" + if ext in ("jpg", "jpeg", "png", "gif", "webp"): + return "photo" + if ext == "ogg": + return "voice" + if ext in ("mp3", "m4a", "wav", "aac"): + return "audio" + return "document" + async def send(self, msg: OutboundMessage) -> None: """Send a message through Telegram.""" if not self._app: logger.warning("Telegram bot not running") return - - # Stop typing indicator for this chat - self._stop_typing(msg.chat_id) - + + # Only stop typing indicator for final responses + if not msg.metadata.get("_progress", False): + self._stop_typing(msg.chat_id) + try: - # chat_id should be the Telegram chat ID (integer) chat_id = int(msg.chat_id) - # Convert markdown to Telegram HTML - html_content = _markdown_to_telegram_html(msg.content) - await self._app.bot.send_message( - chat_id=chat_id, - text=html_content, - parse_mode="HTML" - ) except ValueError: - logger.error(f"Invalid chat_id: {msg.chat_id}") + logger.error("Invalid chat_id: {}", msg.chat_id) + return + reply_to_message_id = msg.metadata.get("message_id") + message_thread_id = msg.metadata.get("message_thread_id") + if message_thread_id is None and reply_to_message_id is not None: + message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id)) + thread_kwargs = {} + if message_thread_id is not None: + thread_kwargs["message_thread_id"] = message_thread_id + + reply_params = None + if self.config.reply_to_message: + if reply_to_message_id: + reply_params = ReplyParameters( + message_id=reply_to_message_id, + allow_sending_without_reply=True + ) + + # Send media files + for media_path in (msg.media or []): + try: + media_type = self._get_media_type(media_path) + sender = { + "photo": self._app.bot.send_photo, + "voice": self._app.bot.send_voice, + "audio": self._app.bot.send_audio, + }.get(media_type, self._app.bot.send_document) + param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document" + with open(media_path, 'rb') as f: + await sender( + chat_id=chat_id, + **{param: f}, + reply_parameters=reply_params, + **thread_kwargs, + ) + except Exception as e: + filename = media_path.rsplit("/", 1)[-1] + logger.error("Failed to send media {}: {}", media_path, e) + await self._app.bot.send_message( + chat_id=chat_id, + text=f"[Failed to send: {filename}]", + reply_parameters=reply_params, + **thread_kwargs, + ) + + # Send text content + if msg.content and msg.content != "[empty message]": + is_progress = msg.metadata.get("_progress", False) + + for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN): + # Final response: simulate streaming via draft, then persist + if not is_progress: + await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs) + else: + await self._send_text(chat_id, chunk, reply_params, thread_kwargs) + + async def _send_text( + self, + chat_id: int, + text: str, + reply_params=None, + thread_kwargs: dict | None = None, + ) -> 
+    async def _send_text(
+        self,
+        chat_id: int,
+        text: str,
+        reply_params=None,
+        thread_kwargs: dict | None = None,
+    ) -> None:
+        """Send a plain text message with HTML fallback."""
+        try:
+            html = _markdown_to_telegram_html(text)
+            await self._app.bot.send_message(
+                chat_id=chat_id, text=html, parse_mode="HTML",
+                reply_parameters=reply_params,
+                **(thread_kwargs or {}),
+            )
         except Exception as e:
-            # Fallback to plain text if HTML parsing fails
-            logger.warning(f"HTML parse failed, falling back to plain text: {e}")
+            logger.warning("HTML parse failed, falling back to plain text: {}", e)
             try:
                 await self._app.bot.send_message(
-                    chat_id=int(msg.chat_id),
-                    text=msg.content
+                    chat_id=chat_id,
+                    text=text,
+                    reply_parameters=reply_params,
+                    **(thread_kwargs or {}),
                 )
             except Exception as e2:
-                logger.error(f"Error sending Telegram message: {e2}")
-    
+                logger.error("Error sending Telegram message: {}", e2)
+
+    async def _send_with_streaming(
+        self,
+        chat_id: int,
+        text: str,
+        reply_params=None,
+        thread_kwargs: dict | None = None,
+    ) -> None:
+        """Simulate streaming via send_message_draft, then persist with send_message."""
+        draft_id = int(time.time() * 1000) % (2**31)
+        try:
+            step = max(len(text) // 8, 40)
+            for i in range(step, len(text), step):
+                await self._app.bot.send_message_draft(
+                    chat_id=chat_id, draft_id=draft_id, text=text[:i],
+                )
+                await asyncio.sleep(0.04)
+            await self._app.bot.send_message_draft(
+                chat_id=chat_id, draft_id=draft_id, text=text,
+            )
+            await asyncio.sleep(0.15)
+        except Exception:
+            pass
+        await self._send_text(chat_id, text, reply_params, thread_kwargs)
+
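Illustrative input/output for `_markdown_to_telegram_html` as reconstructed above (the placeholder pass keeps code spans out of the escaping and formatting steps):

```python
# Hypothetical input and the HTML the converter should produce.
src = "# Title\n**bold** and `code`"
# Headers are flattened to plain text; bold and inline code pick up
# Telegram-HTML tags; everything else is HTML-escaped.
expected = "Title\n<b>bold</b> and <code>code</code>"
```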
     async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Handle /start command."""
         if not update.message or not update.effective_user:
             return
-    
+
         user = update.effective_user
         await update.message.reply_text(
             f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
             "Send me a message and I'll respond!\n"
             "Type /help to see available commands."
         )
-    
-    async def _on_reset(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
-        """Handle /reset command — clear conversation history."""
-        if not update.message or not update.effective_user:
-            return
-    
-        chat_id = str(update.message.chat_id)
-        session_key = f"{self.name}:{chat_id}"
-    
-        if self.session_manager is None:
-            logger.warning("/reset called but session_manager is not available")
-            await update.message.reply_text("⚠️ Session management is not available.")
-            return
-    
-        session = self.session_manager.get_or_create(session_key)
-        msg_count = len(session.messages)
-        session.clear()
-        self.session_manager.save(session)
-    
-        logger.info(f"Session reset for {session_key} (cleared {msg_count} messages)")
-        await update.message.reply_text("🔄 Conversation history cleared. Let's start fresh!")
-    
     async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
-        """Handle /help command — show available commands."""
+        """Handle /help command, bypassing ACL so all users can access it."""
         if not update.message:
             return
-    
-        help_text = (
-            "🐈 nanobot commands\n\n"
-            "/start — Start the bot\n"
-            "/reset — Reset conversation history\n"
-            "/help — Show this help message\n\n"
-            "Just send me a text message to chat!"
+        await update.message.reply_text(
+            "🐈 nanobot commands:\n"
+            "/new — Start a new conversation\n"
+            "/stop — Stop the current task\n"
+            "/help — Show available commands"
         )
-        await update.message.reply_text(help_text, parse_mode="HTML")
-    
+
+    @staticmethod
+    def _sender_id(user) -> str:
+        """Build sender_id with username for allowlist matching."""
+        sid = str(user.id)
+        return f"{sid}|{user.username}" if user.username else sid
+
+    @staticmethod
+    def _derive_topic_session_key(message) -> str | None:
+        """Derive topic-scoped session key for non-private Telegram chats."""
+        message_thread_id = getattr(message, "message_thread_id", None)
+        if message.chat.type == "private" or message_thread_id is None:
+            return None
+        return f"telegram:{message.chat_id}:topic:{message_thread_id}"
+
+    @staticmethod
+    def _build_message_metadata(message, user) -> dict:
+        """Build common Telegram inbound metadata payload."""
+        return {
+            "message_id": message.message_id,
+            "user_id": user.id,
+            "username": user.username,
+            "first_name": user.first_name,
+            "is_group": message.chat.type != "private",
+            "message_thread_id": getattr(message, "message_thread_id", None),
+            "is_forum": bool(getattr(message.chat, "is_forum", False)),
+        }
+
+    def _remember_thread_context(self, message) -> None:
+        """Cache topic thread id by chat/message id for follow-up replies."""
+        message_thread_id = getattr(message, "message_thread_id", None)
+        if message_thread_id is None:
+            return
+        key = (str(message.chat_id), message.message_id)
+        self._message_threads[key] = message_thread_id
+        if len(self._message_threads) > 1000:
+            self._message_threads.pop(next(iter(self._message_threads)))
+
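`_remember_thread_context` above caps its cache with a simple first-in-first-out eviction; since Python dicts preserve insertion order, `next(iter(d))` is always the oldest key. A standalone sketch of the pattern:

```python
# Bounded insertion-ordered cache, as used for the thread-id map above.
cache: dict[tuple[str, int], int] = {}

def remember(key: tuple[str, int], value: int, limit: int = 1000) -> None:
    cache[key] = value
    if len(cache) > limit:
        cache.pop(next(iter(cache)))  # evict the oldest insertion
```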
+    async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+        """Forward slash commands to the bus for unified handling in AgentLoop."""
+        if not update.message or not update.effective_user:
+            return
+        message = update.message
+        user = update.effective_user
+        self._remember_thread_context(message)
+        await self._handle_message(
+            sender_id=self._sender_id(user),
+            chat_id=str(message.chat_id),
+            content=message.text,
+            metadata=self._build_message_metadata(message, user),
+            session_key=self._derive_topic_session_key(message),
+        )
+
     async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Handle incoming messages (text, photos, voice, documents)."""
         if not update.message or not update.effective_user:
             return
-    
+
         message = update.message
         user = update.effective_user
         chat_id = message.chat_id
-    
-        # Use stable numeric ID, but keep username for allowlist compatibility
-        sender_id = str(user.id)
-        if user.username:
-            sender_id = f"{sender_id}|{user.username}"
-    
+        sender_id = self._sender_id(user)
+        self._remember_thread_context(message)
+
         # Store chat_id for replies
         self._chat_ids[sender_id] = chat_id
-    
+
         # Build content from text and/or media
         content_parts = []
         media_paths = []
-    
+
         # Text content
         if message.text:
             content_parts.append(message.text)
         if message.caption:
             content_parts.append(message.caption)
-    
+
         # Handle media files
         media_file = None
         media_type = None
-    
+
         if message.photo:
             media_file = message.photo[-1]  # Largest photo
             media_type = "image"
@@ -307,77 +507,112 @@ class TelegramChannel(BaseChannel):
         elif message.document:
             media_file = message.document
             media_type = "file"
-    
+
         # Download media if present
         if media_file and self._app:
             try:
                 file = await self._app.bot.get_file(media_file.file_id)
-                ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
-    
+                ext = self._get_extension(
+                    media_type,
+                    getattr(media_file, 'mime_type', None),
+                    getattr(media_file, 'file_name', None),
+                )
                 # Save to workspace/media/
                 from pathlib import Path
                 media_dir = Path.home() / ".nanobot" / "media"
                 media_dir.mkdir(parents=True, exist_ok=True)
-    
+
                 file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
                 await file.download_to_drive(str(file_path))
-    
+
                 media_paths.append(str(file_path))
-    
+
                 # Handle voice transcription
                 if media_type == "voice" or media_type == "audio":
                     from nanobot.providers.transcription import GroqTranscriptionProvider
                     transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
                     transcription = await transcriber.transcribe(file_path)
                     if transcription:
-                        logger.info(f"Transcribed {media_type}: {transcription[:50]}...")
+                        logger.info("Transcribed {}: {}...", media_type, transcription[:50])
                         content_parts.append(f"[transcription: {transcription}]")
                     else:
                         content_parts.append(f"[{media_type}: {file_path}]")
                 else:
                     content_parts.append(f"[{media_type}: {file_path}]")
-    
-                logger.debug(f"Downloaded {media_type} to {file_path}")
+
+                logger.debug("Downloaded {} to {}", media_type, file_path)
             except Exception as e:
-                logger.error(f"Failed to download media: {e}")
+                logger.error("Failed to download media: {}", e)
                 content_parts.append(f"[{media_type}: download failed]")
-    
+
         content = "\n".join(content_parts) if content_parts else "[empty message]"
-    
-        logger.debug(f"Telegram message from {sender_id}: {content[:50]}...")
-    
+
+        logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
+
         str_chat_id = str(chat_id)
-    
+        metadata = self._build_message_metadata(message, user)
+        session_key = self._derive_topic_session_key(message)
+
+        # Telegram media groups: buffer briefly, forward as one aggregated turn.
+        if media_group_id := getattr(message, "media_group_id", None):
+            key = f"{str_chat_id}:{media_group_id}"
+            if key not in self._media_group_buffers:
+                self._media_group_buffers[key] = {
+                    "sender_id": sender_id, "chat_id": str_chat_id,
+                    "contents": [], "media": [],
+                    "metadata": metadata,
+                    "session_key": session_key,
+                }
+                self._start_typing(str_chat_id)
+            buf = self._media_group_buffers[key]
+            if content and content != "[empty message]":
+                buf["contents"].append(content)
+            buf["media"].extend(media_paths)
+            if key not in self._media_group_tasks:
+                self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
+            return
+
         # Start typing indicator before processing
         self._start_typing(str_chat_id)
-    
+
         # Forward to the message bus
         await self._handle_message(
             sender_id=sender_id,
             chat_id=str_chat_id,
             content=content,
             media=media_paths,
-            metadata={
-                "message_id": message.message_id,
-                "user_id": user.id,
-                "username": user.username,
-                "first_name": user.first_name,
-                "is_group": message.chat.type != "private"
-            }
+            metadata=metadata,
+            session_key=session_key,
         )
-    
+
+    async def _flush_media_group(self, key: str) -> None:
+        """Wait briefly, then forward buffered media-group as one turn."""
+        try:
+            await asyncio.sleep(0.6)
+            if not (buf := self._media_group_buffers.pop(key, None)):
+                return
+            content = "\n".join(buf["contents"]) or "[empty message]"
+            await self._handle_message(
+                sender_id=buf["sender_id"], chat_id=buf["chat_id"],
+                content=content, media=list(dict.fromkeys(buf["media"])),
+                metadata=buf["metadata"],
+                session_key=buf.get("session_key"),
+            )
+        finally:
+            self._media_group_tasks.pop(key, None)
+
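The media-group handling above is a debounce: the first item of an album starts a short timer, and everything that arrives within the window is forwarded as one aggregated turn. A self-contained sketch of the same pattern:

```python
import asyncio

# Debounce sketch matching _flush_media_group above; `flush` is whatever
# callable should receive the aggregated items.
buffers: dict[str, list[str]] = {}
tasks: dict[str, asyncio.Task] = {}

async def on_album_item(group_id: str, item: str, flush) -> None:
    buffers.setdefault(group_id, []).append(item)
    if group_id not in tasks:
        async def _later() -> None:
            try:
                await asyncio.sleep(0.6)  # collection window
                await flush(buffers.pop(group_id, []))
            finally:
                tasks.pop(group_id, None)
        tasks[group_id] = asyncio.create_task(_later())
```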
     def _start_typing(self, chat_id: str) -> None:
         """Start sending 'typing...' indicator for a chat."""
         # Cancel any existing typing task for this chat
         self._stop_typing(chat_id)
         self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
-    
+
     def _stop_typing(self, chat_id: str) -> None:
         """Stop the typing indicator for a chat."""
         task = self._typing_tasks.pop(chat_id, None)
         if task and not task.done():
             task.cancel()
-    
+
     async def _typing_loop(self, chat_id: str) -> None:
         """Repeatedly send 'typing' action until cancelled."""
         try:
@@ -387,14 +622,19 @@ class TelegramChannel(BaseChannel):
         except asyncio.CancelledError:
             pass
         except Exception as e:
-            logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
-    
+            logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
+
     async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
         """Log polling / handler errors instead of silently swallowing them."""
-        logger.error(f"Telegram error: {context.error}")
+        logger.error("Telegram error: {}", context.error)
 
-    def _get_extension(self, media_type: str, mime_type: str | None) -> str:
-        """Get file extension based on media type."""
+    def _get_extension(
+        self,
+        media_type: str,
+        mime_type: str | None,
+        filename: str | None = None,
+    ) -> str:
+        """Get file extension based on media type or original filename."""
         if mime_type:
             ext_map = {
                 "image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -402,6 +642,14 @@ class TelegramChannel(BaseChannel):
             }
             if mime_type in ext_map:
                 return ext_map[mime_type]
-    
+
         type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
-        return type_map.get(media_type, "")
+        if ext := type_map.get(media_type, ""):
+            return ext
+
+        if filename:
+            from pathlib import Path
+
+            return "".join(Path(filename).suffixes)
+
+        return ""
""" - + name = "whatsapp" - + def __init__(self, config: WhatsAppConfig, bus: MessageBus): super().__init__(config, bus) self.config: WhatsAppConfig = config self._ws = None self._connected = False - + self._processed_message_ids: OrderedDict[str, None] = OrderedDict() + async def start(self) -> None: """Start the WhatsApp channel by connecting to the bridge.""" import websockets - + bridge_url = self.config.bridge_url - - logger.info(f"Connecting to WhatsApp bridge at {bridge_url}...") - + + logger.info("Connecting to WhatsApp bridge at {}...", bridge_url) + self._running = True - + while self._running: try: async with websockets.connect(bridge_url) as ws: self._ws = ws + # Send auth token if configured + if self.config.bridge_token: + await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token})) self._connected = True logger.info("Connected to WhatsApp bridge") - + # Listen for messages async for message in ws: try: await self._handle_bridge_message(message) except Exception as e: - logger.error(f"Error handling bridge message: {e}") - + logger.error("Error handling bridge message: {}", e) + except asyncio.CancelledError: break except Exception as e: self._connected = False self._ws = None - logger.warning(f"WhatsApp bridge connection error: {e}") - + logger.warning("WhatsApp bridge connection error: {}", e) + if self._running: logger.info("Reconnecting in 5 seconds...") await asyncio.sleep(5) - + async def stop(self) -> None: """Stop the WhatsApp channel.""" self._running = False self._connected = False - + if self._ws: await self._ws.close() self._ws = None - + async def send(self, msg: OutboundMessage) -> None: """Send a message through WhatsApp.""" if not self._ws or not self._connected: logger.warning("WhatsApp bridge not connected") return - + try: payload = { "type": "send", "to": msg.chat_id, "text": msg.content } - await self._ws.send(json.dumps(payload)) + await self._ws.send(json.dumps(payload, ensure_ascii=False)) except Exception as e: - logger.error(f"Error sending WhatsApp message: {e}") - + logger.error("Error sending WhatsApp message: {}", e) + async def _handle_bridge_message(self, raw: str) -> None: """Handle a message from the bridge.""" try: data = json.loads(raw) except json.JSONDecodeError: - logger.warning(f"Invalid JSON from bridge: {raw[:100]}") + logger.warning("Invalid JSON from bridge: {}", raw[:100]) return - + msg_type = data.get("type") - + if msg_type == "message": # Incoming message from WhatsApp # Deprecated by whatsapp: old phone number style typically: @s.whatspp.net pn = data.get("pn", "") - # New LID sytle typically: + # New LID sytle typically: sender = data.get("sender", "") content = data.get("content", "") - + message_id = data.get("id", "") + + if message_id: + if message_id in self._processed_message_ids: + return + self._processed_message_ids[message_id] = None + while len(self._processed_message_ids) > 1000: + self._processed_message_ids.popitem(last=False) + # Extract just the phone number or lid as chat_id user_id = pn if pn else sender sender_id = user_id.split("@")[0] if "@" in user_id else user_id - logger.info(f"Sender {sender}") - + logger.info("Sender {}", sender) + # Handle voice transcription if it's a voice message if content == "[Voice Message]": - logger.info(f"Voice message received from {sender_id}, but direct download from bridge is not yet supported.") + logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id) content = "[Voice Message: Transcription not 
     async def _handle_bridge_message(self, raw: str) -> None:
         """Handle a message from the bridge."""
         try:
             data = json.loads(raw)
         except json.JSONDecodeError:
-            logger.warning(f"Invalid JSON from bridge: {raw[:100]}")
+            logger.warning("Invalid JSON from bridge: {}", raw[:100])
             return
-    
+
         msg_type = data.get("type")
-    
+
         if msg_type == "message":
             # Incoming message from WhatsApp
             # Deprecated by WhatsApp: old phone-number style, typically @s.whatsapp.net
             pn = data.get("pn", "")
-            # New LID sytle typically: 
+            # New LID style typically:
             sender = data.get("sender", "")
             content = data.get("content", "")
-    
+            message_id = data.get("id", "")
+
+            if message_id:
+                if message_id in self._processed_message_ids:
+                    return
+                self._processed_message_ids[message_id] = None
+                while len(self._processed_message_ids) > 1000:
+                    self._processed_message_ids.popitem(last=False)
+
             # Extract just the phone number or lid as chat_id
             user_id = pn if pn else sender
             sender_id = user_id.split("@")[0] if "@" in user_id else user_id
-            logger.info(f"Sender {sender}")
-    
+            logger.info("Sender {}", sender)
+
             # Handle voice transcription if it's a voice message
             if content == "[Voice Message]":
-                logger.info(f"Voice message received from {sender_id}, but direct download from bridge is not yet supported.")
+                logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
                 content = "[Voice Message: Transcription not available for WhatsApp yet]"
-    
+
+            # Extract media paths (images/documents/videos downloaded by the bridge)
+            media_paths = data.get("media") or []
+
+            # Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
+            if media_paths:
+                for p in media_paths:
+                    mime, _ = mimetypes.guess_type(p)
+                    media_type = "image" if mime and mime.startswith("image/") else "file"
+                    media_tag = f"[{media_type}: {p}]"
+                    content = f"{content}\n{media_tag}" if content else media_tag
+
             await self._handle_message(
                 sender_id=sender_id,
                 chat_id=sender,  # Use full LID for replies
                 content=content,
+                media=media_paths,
                 metadata={
-                    "message_id": data.get("id"),
+                    "message_id": message_id,
                     "timestamp": data.get("timestamp"),
                     "is_group": data.get("isGroup", False)
                 }
             )
-    
+
         elif msg_type == "status":
             # Connection status update
             status = data.get("status")
-            logger.info(f"WhatsApp status: {status}")
-    
+            logger.info("WhatsApp status: {}", status)
+
             if status == "connected":
                 self._connected = True
             elif status == "disconnected":
                 self._connected = False
-    
+
         elif msg_type == "qr":
             # QR code for authentication
             logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
-    
+
         elif msg_type == "error":
-            logger.error(f"WhatsApp bridge error: {data.get('error')}")
+            logger.error("WhatsApp bridge error: {}", data.get('error'))
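The media tagging above keys off the MIME type guessed from the file path. A minimal standalone version of the same logic:

```python
import mimetypes

# Tag construction matching the bridge handler above: images become
# [image: /path], everything else [file: /path].
def media_tag(path: str) -> str:
    mime, _ = mimetypes.guess_type(path)
    kind = "image" if mime and mime.startswith("image/") else "file"
    return f"[{kind}: {path}]"

assert media_tag("photo.png") == "[image: photo.png]"
assert media_tag("report.pdf") == "[file: report.pdf]"
```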
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields") + if typer.confirm("Overwrite?"): + config = Config() + save_config(config) + console.print(f"[green]✓[/green] Config reset to defaults at {config_path}") + else: + config = load_config() + save_config(config) + console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") + else: + save_config(Config()) + console.print(f"[green]✓[/green] Created config at {config_path}") + # Create workspace workspace = get_workspace_path() - console.print(f"[green]✓[/green] Created workspace at {workspace}") - - # Create default bootstrap files - _create_workspace_templates(workspace) - + + if not workspace.exists(): + workspace.mkdir(parents=True, exist_ok=True) + console.print(f"[green]✓[/green] Created workspace at {workspace}") + + sync_workspace_templates(workspace) + console.print(f"\n{__logo__} nanobot is ready!") console.print("\nNext steps:") console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]") @@ -188,97 +209,57 @@ def onboard(): -def _create_workspace_templates(workspace: Path): - """Create default workspace template files.""" - templates = { - "AGENTS.md": """# Agent Instructions -You are a helpful AI assistant. Be concise, accurate, and friendly. +def _make_provider(config: Config): + """Create the appropriate LLM provider from config.""" + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider -## Guidelines - -- Always explain what you're doing before taking actions -- Ask for clarification when the request is ambiguous -- Use tools to help accomplish tasks -- Remember important information in your memory files -""", - "SOUL.md": """# Soul - -I am nanobot, a lightweight AI assistant. - -## Personality - -- Helpful and friendly -- Concise and to the point -- Curious and eager to learn - -## Values - -- Accuracy over speed -- User privacy and safety -- Transparency in actions -""", - "USER.md": """# User - -Information about the user goes here. - -## Preferences - -- Communication style: (casual/formal) -- Timezone: (your timezone) -- Language: (your preferred language) -""", - } - - for filename, content in templates.items(): - file_path = workspace / filename - if not file_path.exists(): - file_path.write_text(content) - console.print(f" [dim]Created {filename}[/dim]") - - # Create memory directory and MEMORY.md - memory_dir = workspace / "memory" - memory_dir.mkdir(exist_ok=True) - memory_file = memory_dir / "MEMORY.md" - if not memory_file.exists(): - memory_file.write_text("""# Long-term Memory - -This file stores important information that should persist across sessions. - -## User Information - -(Important facts about the user) - -## Preferences - -(User preferences learned over time) - -## Important Notes - -(Things to remember) -""") - console.print(" [dim]Created memory/MEMORY.md[/dim]") - - # Create skills directory for custom user skills - skills_dir = workspace / "skills" - skills_dir.mkdir(exist_ok=True) - - -def _make_provider(config): - """Create LiteLLMProvider from config. 
@@ -188,97 +209,57 @@ def onboard():
 
-def _create_workspace_templates(workspace: Path):
-    """Create default workspace template files."""
-    templates = {
-        "AGENTS.md": """# Agent Instructions
-You are a helpful AI assistant. Be concise, accurate, and friendly.
+def _make_provider(config: Config):
+    """Create the appropriate LLM provider from config."""
+    from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+    from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 
-## Guidelines
-
-- Always explain what you're doing before taking actions
-- Ask for clarification when the request is ambiguous
-- Use tools to help accomplish tasks
-- Remember important information in your memory files
-""",
-        "SOUL.md": """# Soul
-
-I am nanobot, a lightweight AI assistant.
-
-## Personality
-
-- Helpful and friendly
-- Concise and to the point
-- Curious and eager to learn
-
-## Values
-
-- Accuracy over speed
-- User privacy and safety
-- Transparency in actions
-""",
-        "USER.md": """# User
-
-Information about the user goes here.
-
-## Preferences
-
-- Communication style: (casual/formal)
-- Timezone: (your timezone)
-- Language: (your preferred language)
-""",
-    }
-    
-    for filename, content in templates.items():
-        file_path = workspace / filename
-        if not file_path.exists():
-            file_path.write_text(content)
-            console.print(f"  [dim]Created {filename}[/dim]")
-    
-    # Create memory directory and MEMORY.md
-    memory_dir = workspace / "memory"
-    memory_dir.mkdir(exist_ok=True)
-    memory_file = memory_dir / "MEMORY.md"
-    if not memory_file.exists():
-        memory_file.write_text("""# Long-term Memory
-
-This file stores important information that should persist across sessions.
-
-## User Information
-
-(Important facts about the user)
-
-## Preferences
-
-(User preferences learned over time)
-
-## Important Notes
-
-(Things to remember)
-""")
-        console.print("  [dim]Created memory/MEMORY.md[/dim]")
-    
-    # Create skills directory for custom user skills
-    skills_dir = workspace / "skills"
-    skills_dir.mkdir(exist_ok=True)
-
-
-def _make_provider(config):
-    """Create LiteLLMProvider from config. Exits if no API key found."""
-    from nanobot.providers.litellm_provider import LiteLLMProvider
-    
-    p = config.get_provider()
     model = config.agents.defaults.model
-    if not (p and p.api_key) and not model.startswith("bedrock/"):
+    provider_name = config.get_provider_name(model)
+    p = config.get_provider(model)
+
+    # OpenAI Codex (OAuth)
+    if provider_name == "openai_codex" or model.startswith("openai-codex/"):
+        return OpenAICodexProvider(default_model=model)
+
+    # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
+    from nanobot.providers.custom_provider import CustomProvider
+    if provider_name == "custom":
+        return CustomProvider(
+            api_key=p.api_key if p else "no-key",
+            api_base=config.get_api_base(model) or "http://localhost:8000/v1",
+            default_model=model,
+        )
+
+    # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
+    if provider_name == "azure_openai":
+        if not p or not p.api_key or not p.api_base:
+            console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
+            console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
+            console.print("Use the model field to specify the deployment name.")
+            raise typer.Exit(1)
+
+        return AzureOpenAIProvider(
+            api_key=p.api_key,
+            api_base=p.api_base,
+            default_model=model,
+        )
+
+    from nanobot.providers.litellm_provider import LiteLLMProvider
+    from nanobot.providers.registry import find_by_name
+    spec = find_by_name(provider_name)
+    if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
         console.print("[red]Error: No API key configured.[/red]")
         console.print("Set one in ~/.nanobot/config.json under providers section")
         raise typer.Exit(1)
+
     return LiteLLMProvider(
         api_key=p.api_key if p else None,
-        api_base=config.get_api_base(),
+        api_base=config.get_api_base(model),
         default_model=model,
         extra_headers=p.extra_headers if p else None,
-        provider_name=config.get_provider_name(),
+        provider_name=provider_name,
    )
@@ -290,92 +271,167 @@ def _make_provider(config):
 
 @app.command()
 def gateway(
     port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
+    workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+    config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
 ):
     """Start the nanobot gateway."""
-    from nanobot.config.loader import load_config, get_data_dir
-    from nanobot.bus.queue import MessageBus
     from nanobot.agent.loop import AgentLoop
+    from nanobot.bus.queue import MessageBus
     from nanobot.channels.manager import ChannelManager
-    from nanobot.session.manager import SessionManager
+    from nanobot.config.loader import load_config
     from nanobot.cron.service import CronService
     from nanobot.cron.types import CronJob
     from nanobot.heartbeat.service import HeartbeatService
-    
+    from nanobot.session.manager import SessionManager
+
     if verbose:
         import logging
         logging.basicConfig(level=logging.DEBUG)
-    
+
+    config_path = Path(config) if config else None
+    config = load_config(config_path)
+    if workspace:
+        config.agents.defaults.workspace = workspace
+
     console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
-    
-    config = load_config()
+    sync_workspace_templates(config.workspace_path)
     bus = MessageBus()
     provider = _make_provider(config)
     session_manager = SessionManager(config.workspace_path)
-    
+
     # Create cron service first (callback set after agent creation)
-    cron_store_path = get_data_dir() / "cron" / "jobs.json"
+    # Use workspace path for per-instance cron store
+    cron_store_path = config.workspace_path / "cron" / "jobs.json"
     cron = CronService(cron_store_path)
-    
+
     # Create agent with cron service
     agent = AgentLoop(
         bus=bus,
         provider=provider,
         workspace=config.workspace_path,
         model=config.agents.defaults.model,
+        temperature=config.agents.defaults.temperature,
+        max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
+        memory_window=config.agents.defaults.memory_window,
+        reasoning_effort=config.agents.defaults.reasoning_effort,
         brave_api_key=config.tools.web.search.api_key or None,
+        web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
         cron_service=cron,
         restrict_to_workspace=config.tools.restrict_to_workspace,
         session_manager=session_manager,
+        mcp_servers=config.tools.mcp_servers,
+        channels_config=config.channels,
     )
-    
+
     # Set cron callback (needs agent)
     async def on_cron_job(job: CronJob) -> str | None:
         """Execute a cron job through the agent."""
-        response = await agent.process_direct(
-            job.payload.message,
-            session_key=f"cron:{job.id}",
-            channel=job.payload.channel or "cli",
-            chat_id=job.payload.to or "direct",
+        from nanobot.agent.tools.cron import CronTool
+        from nanobot.agent.tools.message import MessageTool
+        reminder_note = (
+            "[Scheduled Task] Timer finished.\n\n"
+            f"Task '{job.name}' has been triggered.\n"
+            f"Scheduled instruction: {job.payload.message}"
         )
-        if job.payload.deliver and job.payload.to:
+
+        # Prevent the agent from scheduling new cron jobs during execution
+        cron_tool = agent.tools.get("cron")
+        cron_token = None
+        if isinstance(cron_tool, CronTool):
+            cron_token = cron_tool.set_cron_context(True)
+        try:
+            response = await agent.process_direct(
+                reminder_note,
+                session_key=f"cron:{job.id}",
+                channel=job.payload.channel or "cli",
+                chat_id=job.payload.to or "direct",
+            )
+        finally:
+            if isinstance(cron_tool, CronTool) and cron_token is not None:
+                cron_tool.reset_cron_context(cron_token)
+
+        message_tool = agent.tools.get("message")
+        if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
+            return response
+
+        if job.payload.deliver and job.payload.to and response:
             from nanobot.bus.events import OutboundMessage
             await bus.publish_outbound(OutboundMessage(
                 channel=job.payload.channel or "cli",
                 chat_id=job.payload.to,
-                content=response or ""
+                content=response
             ))
         return response
     cron.on_job = on_cron_job
-    
+ return "cli", "direct" + # Create heartbeat service - async def on_heartbeat(prompt: str) -> str: - """Execute heartbeat through the agent.""" - return await agent.process_direct(prompt, session_key="heartbeat") - + async def on_heartbeat_execute(tasks: str) -> str: + """Phase 2: execute heartbeat tasks through the full agent loop.""" + channel, chat_id = _pick_heartbeat_target() + + async def _silent(*_args, **_kwargs): + pass + + return await agent.process_direct( + tasks, + session_key="heartbeat", + channel=channel, + chat_id=chat_id, + on_progress=_silent, + ) + + async def on_heartbeat_notify(response: str) -> None: + """Deliver a heartbeat response to the user's channel.""" + from nanobot.bus.events import OutboundMessage + channel, chat_id = _pick_heartbeat_target() + if channel == "cli": + return # No external channel available to deliver to + await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response)) + + hb_cfg = config.gateway.heartbeat heartbeat = HeartbeatService( workspace=config.workspace_path, - on_heartbeat=on_heartbeat, - interval_s=30 * 60, # 30 minutes - enabled=True + provider=provider, + model=agent.model, + on_execute=on_heartbeat_execute, + on_notify=on_heartbeat_notify, + interval_s=hb_cfg.interval_s, + enabled=hb_cfg.enabled, ) - - # Create channel manager - channels = ChannelManager(config, bus, session_manager=session_manager) - + if channels.enabled_channels: console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}") else: console.print("[yellow]Warning: No channels enabled[/yellow]") - + cron_status = cron.status() if cron_status["jobs"] > 0: console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs") - - console.print(f"[green]✓[/green] Heartbeat: every 30m") - + + console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s") + async def run(): try: await cron.start() @@ -386,11 +442,13 @@ def gateway( ) except KeyboardInterrupt: console.print("\nShutting down...") + finally: + await agent.close_mcp() heartbeat.stop() cron.stop() agent.stop() await channels.stop_all() - + asyncio.run(run()) @@ -404,35 +462,52 @@ def gateway( @app.command() def agent( message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"), - session_id: str = typer.Option("cli:default", "--session", "-s", help="Session ID"), + session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"), markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), ): """Interact with the agent directly.""" - from nanobot.config.loader import load_config - from nanobot.bus.queue import MessageBus - from nanobot.agent.loop import AgentLoop from loguru import logger - + + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + from nanobot.config.loader import get_data_dir, load_config + from nanobot.cron.service import CronService + config = load_config() - + sync_workspace_templates(config.workspace_path) + bus = MessageBus() provider = _make_provider(config) + # Create cron service for tool usage (no callback needed for CLI unless running) + cron_store_path = get_data_dir() / "cron" / "jobs.json" + cron = CronService(cron_store_path) + if logs: logger.enable("nanobot") else: logger.disable("nanobot") - + agent_loop = AgentLoop( bus=bus, provider=provider, 
workspace=config.workspace_path, + model=config.agents.defaults.model, + temperature=config.agents.defaults.temperature, + max_tokens=config.agents.defaults.max_tokens, + max_iterations=config.agents.defaults.max_tool_iterations, + memory_window=config.agents.defaults.memory_window, + reasoning_effort=config.agents.defaults.reasoning_effort, brave_api_key=config.tools.web.search.api_key or None, + web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, + cron_service=cron, restrict_to_workspace=config.tools.restrict_to_workspace, + mcp_servers=config.tools.mcp_servers, + channels_config=config.channels, ) - + # Show spinner when logs are off (no output to miss); skip when logs are on def _thinking_ctx(): if logs: @@ -441,52 +516,126 @@ def agent( # Animated spinner is safe to use with prompt_toolkit input handling return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots") + async def _cli_progress(content: str, *, tool_hint: bool = False) -> None: + ch = agent_loop.channels_config + if ch and tool_hint and not ch.send_tool_hints: + return + if ch and not tool_hint and not ch.send_progress: + return + console.print(f" [dim]↳ {content}[/dim]") + if message: - # Single message mode + # Single message mode — direct call, no bus needed async def run_once(): with _thinking_ctx(): - response = await agent_loop.process_direct(message, session_id) + response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress) _print_agent_response(response, render_markdown=markdown) - + await agent_loop.close_mcp() + asyncio.run(run_once()) else: - # Interactive mode + # Interactive mode — route through bus like other channels + from nanobot.bus.events import InboundMessage _init_prompt_session() console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n") - def _exit_on_sigint(signum, frame): + if ":" in session_id: + cli_channel, cli_chat_id = session_id.split(":", 1) + else: + cli_channel, cli_chat_id = "cli", session_id + + def _handle_signal(signum, frame): + sig_name = signal.Signals(signum).name _restore_terminal() - console.print("\nGoodbye!") - os._exit(0) + console.print(f"\nReceived {sig_name}, goodbye!") + sys.exit(0) + + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + # SIGHUP is not available on Windows + if hasattr(signal, 'SIGHUP'): + signal.signal(signal.SIGHUP, _handle_signal) + # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes + # SIGPIPE is not available on Windows + if hasattr(signal, 'SIGPIPE'): + signal.signal(signal.SIGPIPE, signal.SIG_IGN) - signal.signal(signal.SIGINT, _exit_on_sigint) - async def run_interactive(): - while True: - try: - _flush_pending_tty_input() - user_input = await _read_interactive_input_async() - command = user_input.strip() - if not command: - continue + bus_task = asyncio.create_task(agent_loop.run()) + turn_done = asyncio.Event() + turn_done.set() + turn_response: list[str] = [] - if _is_exit_command(command): + async def _consume_outbound(): + while True: + try: + msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + if msg.metadata.get("_progress"): + is_tool_hint = msg.metadata.get("_tool_hint", False) + ch = agent_loop.channels_config + if ch and is_tool_hint and not ch.send_tool_hints: + pass + elif ch and not is_tool_hint and not ch.send_progress: + pass + else: + console.print(f" [dim]↳ {msg.content}[/dim]") + elif not turn_done.is_set(): + if msg.content: + 
turn_response.append(msg.content) + turn_done.set() + elif msg.content: + console.print() + _print_agent_response(msg.content, render_markdown=markdown) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + outbound_task = asyncio.create_task(_consume_outbound()) + + try: + while True: + try: + _flush_pending_tty_input() + user_input = await _read_interactive_input_async() + command = user_input.strip() + if not command: + continue + + if _is_exit_command(command): + _restore_terminal() + console.print("\nGoodbye!") + break + + turn_done.clear() + turn_response.clear() + + await bus.publish_inbound(InboundMessage( + channel=cli_channel, + sender_id="user", + chat_id=cli_chat_id, + content=user_input, + )) + + with _thinking_ctx(): + await turn_done.wait() + + if turn_response: + _print_agent_response(turn_response[0], render_markdown=markdown) + except KeyboardInterrupt: _restore_terminal() console.print("\nGoodbye!") break - - with _thinking_ctx(): - response = await agent_loop.process_direct(user_input, session_id) - _print_agent_response(response, render_markdown=markdown) - except KeyboardInterrupt: - _restore_terminal() - console.print("\nGoodbye!") - break - except EOFError: - _restore_terminal() - console.print("\nGoodbye!") - break - + except EOFError: + _restore_terminal() + console.print("\nGoodbye!") + break + finally: + agent_loop.stop() + outbound_task.cancel() + await asyncio.gather(bus_task, outbound_task, return_exceptions=True) + await agent_loop.close_mcp() + asyncio.run(run_interactive()) @@ -543,7 +692,7 @@ def channels_status(): "✓" if mc.enabled else "✗", mc_base ) - + # Telegram tg = config.channels.telegram tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]" @@ -562,6 +711,33 @@ def channels_status(): slack_config ) + # DingTalk + dt = config.channels.dingtalk + dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]" + table.add_row( + "DingTalk", + "✓" if dt.enabled else "✗", + dt_config + ) + + # QQ + qq = config.channels.qq + qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]" + table.add_row( + "QQ", + "✓" if qq.enabled else "✗", + qq_config + ) + + # Email + em = config.channels.email + em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]" + table.add_row( + "Email", + "✓" if em.enabled else "✗", + em_config + ) + console.print(table) @@ -569,57 +745,57 @@ def _get_bridge_dir() -> Path: """Get the bridge directory, setting it up if needed.""" import shutil import subprocess - + # User's bridge location user_bridge = Path.home() / ".nanobot" / "bridge" - + # Check if already built if (user_bridge / "dist" / "index.js").exists(): return user_bridge - + # Check for npm if not shutil.which("npm"): console.print("[red]npm not found. 
Please install Node.js >= 18.[/red]") raise typer.Exit(1) - + # Find source bridge: first check package data, then source dir pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed) src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev) - + source = None if (pkg_bridge / "package.json").exists(): source = pkg_bridge elif (src_bridge / "package.json").exists(): source = src_bridge - + if not source: console.print("[red]Bridge source not found.[/red]") console.print("Try reinstalling: pip install --force-reinstall nanobot") raise typer.Exit(1) - + console.print(f"{__logo__} Setting up bridge...") - + # Copy to user directory user_bridge.parent.mkdir(parents=True, exist_ok=True) if user_bridge.exists(): shutil.rmtree(user_bridge) shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist")) - + # Install and build try: console.print(" Installing dependencies...") subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True) - + console.print(" Building...") subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True) - + console.print("[green]✓[/green] Bridge ready\n") except subprocess.CalledProcessError as e: console.print(f"[red]Build failed: {e}[/red]") if e.stderr: console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]") raise typer.Exit(1) - + return user_bridge @@ -627,177 +803,27 @@ def _get_bridge_dir() -> Path: def channels_login(): """Link device via QR code.""" import subprocess - + + from nanobot.config.loader import load_config + + config = load_config() bridge_dir = _get_bridge_dir() - + console.print(f"{__logo__} Starting bridge...") console.print("Scan the QR code to connect.\n") - + + env = {**os.environ} + if config.channels.whatsapp.bridge_token: + env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token + try: - subprocess.run(["npm", "start"], cwd=bridge_dir, check=True) + subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env) except subprocess.CalledProcessError as e: console.print(f"[red]Bridge failed: {e}[/red]") except FileNotFoundError: console.print("[red]npm not found. 
Please install Node.js.[/red]") -# ============================================================================ -# Cron Commands -# ============================================================================ - -cron_app = typer.Typer(help="Manage scheduled tasks") -app.add_typer(cron_app, name="cron") - - -@cron_app.command("list") -def cron_list( - all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"), -): - """List scheduled jobs.""" - from nanobot.config.loader import get_data_dir - from nanobot.cron.service import CronService - - store_path = get_data_dir() / "cron" / "jobs.json" - service = CronService(store_path) - - jobs = service.list_jobs(include_disabled=all) - - if not jobs: - console.print("No scheduled jobs.") - return - - table = Table(title="Scheduled Jobs") - table.add_column("ID", style="cyan") - table.add_column("Name") - table.add_column("Schedule") - table.add_column("Status") - table.add_column("Next Run") - - import time - for job in jobs: - # Format schedule - if job.schedule.kind == "every": - sched = f"every {(job.schedule.every_ms or 0) // 1000}s" - elif job.schedule.kind == "cron": - sched = job.schedule.expr or "" - else: - sched = "one-time" - - # Format next run - next_run = "" - if job.state.next_run_at_ms: - next_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(job.state.next_run_at_ms / 1000)) - next_run = next_time - - status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]" - - table.add_row(job.id, job.name, sched, status, next_run) - - console.print(table) - - -@cron_app.command("add") -def cron_add( - name: str = typer.Option(..., "--name", "-n", help="Job name"), - message: str = typer.Option(..., "--message", "-m", help="Message for agent"), - every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"), - cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"), - at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"), - deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"), - to: str = typer.Option(None, "--to", help="Recipient for delivery"), - channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 
'telegram', 'whatsapp')"), -): - """Add a scheduled job.""" - from nanobot.config.loader import get_data_dir - from nanobot.cron.service import CronService - from nanobot.cron.types import CronSchedule - - # Determine schedule type - if every: - schedule = CronSchedule(kind="every", every_ms=every * 1000) - elif cron_expr: - schedule = CronSchedule(kind="cron", expr=cron_expr) - elif at: - import datetime - dt = datetime.datetime.fromisoformat(at) - schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000)) - else: - console.print("[red]Error: Must specify --every, --cron, or --at[/red]") - raise typer.Exit(1) - - store_path = get_data_dir() / "cron" / "jobs.json" - service = CronService(store_path) - - job = service.add_job( - name=name, - schedule=schedule, - message=message, - deliver=deliver, - to=to, - channel=channel, - ) - - console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})") - - -@cron_app.command("remove") -def cron_remove( - job_id: str = typer.Argument(..., help="Job ID to remove"), -): - """Remove a scheduled job.""" - from nanobot.config.loader import get_data_dir - from nanobot.cron.service import CronService - - store_path = get_data_dir() / "cron" / "jobs.json" - service = CronService(store_path) - - if service.remove_job(job_id): - console.print(f"[green]✓[/green] Removed job {job_id}") - else: - console.print(f"[red]Job {job_id} not found[/red]") - - -@cron_app.command("enable") -def cron_enable( - job_id: str = typer.Argument(..., help="Job ID"), - disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"), -): - """Enable or disable a job.""" - from nanobot.config.loader import get_data_dir - from nanobot.cron.service import CronService - - store_path = get_data_dir() / "cron" / "jobs.json" - service = CronService(store_path) - - job = service.enable_job(job_id, enabled=not disable) - if job: - status = "disabled" if disable else "enabled" - console.print(f"[green]✓[/green] Job '{job.name}' {status}") - else: - console.print(f"[red]Job {job_id} not found[/red]") - - -@cron_app.command("run") -def cron_run( - job_id: str = typer.Argument(..., help="Job ID to run"), - force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"), -): - """Manually run a job.""" - from nanobot.config.loader import get_data_dir - from nanobot.cron.service import CronService - - store_path = get_data_dir() / "cron" / "jobs.json" - service = CronService(store_path) - - async def run(): - return await service.run_job(job_id, force=force) - - if asyncio.run(run()): - console.print(f"[green]✓[/green] Job executed") - else: - console.print(f"[red]Failed to run job {job_id}[/red]") - - # ============================================================================ # Status Commands # ============================================================================ @@ -806,7 +832,7 @@ def cron_run( @app.command() def status(): """Show nanobot status.""" - from nanobot.config.loader import load_config, get_config_path + from nanobot.config.loader import get_config_path, load_config config_path = get_config_path() config = load_config() @@ -821,13 +847,15 @@ def status(): from nanobot.providers.registry import PROVIDERS console.print(f"Model: {config.agents.defaults.model}") - + # Check API keys from registry for spec in PROVIDERS: p = getattr(config.providers, spec.name, None) if p is None: continue - if spec.is_local: + if spec.is_oauth: + console.print(f"{spec.label}: [green]✓ (OAuth)[/green]") + elif spec.is_local: # Local deployments 
show api_base instead of api_key if p.api_base: console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]") @@ -838,5 +866,88 @@ def status(): console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}") +# ============================================================================ +# OAuth Login +# ============================================================================ + +provider_app = typer.Typer(help="Manage providers") +app.add_typer(provider_app, name="provider") + + +_LOGIN_HANDLERS: dict[str, callable] = {} + + +def _register_login(name: str): + def decorator(fn): + _LOGIN_HANDLERS[name] = fn + return fn + return decorator + + +@provider_app.command("login") +def provider_login( + provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"), +): + """Authenticate with an OAuth provider.""" + from nanobot.providers.registry import PROVIDERS + + key = provider.replace("-", "_") + spec = next((s for s in PROVIDERS if s.name == key and s.is_oauth), None) + if not spec: + names = ", ".join(s.name.replace("_", "-") for s in PROVIDERS if s.is_oauth) + console.print(f"[red]Unknown OAuth provider: {provider}[/red] Supported: {names}") + raise typer.Exit(1) + + handler = _LOGIN_HANDLERS.get(spec.name) + if not handler: + console.print(f"[red]Login not implemented for {spec.label}[/red]") + raise typer.Exit(1) + + console.print(f"{__logo__} OAuth Login - {spec.label}\n") + handler() + + +@_register_login("openai_codex") +def _login_openai_codex() -> None: + try: + from oauth_cli_kit import get_token, login_oauth_interactive + token = None + try: + token = get_token() + except Exception: + pass + if not (token and token.access): + console.print("[cyan]Starting interactive OAuth login...[/cyan]\n") + token = login_oauth_interactive( + print_fn=lambda s: console.print(s), + prompt_fn=lambda s: typer.prompt(s), + ) + if not (token and token.access): + console.print("[red]✗ Authentication failed[/red]") + raise typer.Exit(1) + console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]") + except ImportError: + console.print("[red]oauth_cli_kit not installed. 
Run: pip install oauth-cli-kit[/red]") + raise typer.Exit(1) + + +@_register_login("github_copilot") +def _login_github_copilot() -> None: + import asyncio + + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") + + async def _trigger(): + from litellm import acompletion + await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1) + + try: + asyncio.run(_trigger()) + console.print("[green]✓ Authenticated with GitHub Copilot[/green]") + except Exception as e: + console.print(f"[red]Authentication error: {e}[/red]") + raise typer.Exit(1) + + if __name__ == "__main__": app() diff --git a/nanobot/config/__init__.py b/nanobot/config/__init__.py index 88e8e9b..6c59668 100644 --- a/nanobot/config/__init__.py +++ b/nanobot/config/__init__.py @@ -1,6 +1,6 @@ """Configuration module for nanobot.""" -from nanobot.config.loader import load_config, get_config_path +from nanobot.config.loader import get_config_path, load_config from nanobot.config.schema import Config __all__ = ["Config", "load_config", "get_config_path"] diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py index fd7d1e8..c789efd 100644 --- a/nanobot/config/loader.py +++ b/nanobot/config/loader.py @@ -2,7 +2,6 @@ import json from pathlib import Path -from typing import Any from nanobot.config.schema import Config @@ -21,45 +20,43 @@ def get_data_dir() -> Path: def load_config(config_path: Path | None = None) -> Config: """ Load configuration from file or create default. - + Args: config_path: Optional path to config file. Uses default if not provided. - + Returns: Loaded configuration object. """ path = config_path or get_config_path() - + if path.exists(): try: - with open(path) as f: + with open(path, encoding="utf-8") as f: data = json.load(f) data = _migrate_config(data) - return Config.model_validate(convert_keys(data)) + return Config.model_validate(data) except (json.JSONDecodeError, ValueError) as e: print(f"Warning: Failed to load config from {path}: {e}") print("Using default configuration.") - + return Config() def save_config(config: Config, config_path: Path | None = None) -> None: """ Save configuration to file. - + Args: config: Configuration to save. config_path: Optional path to save to. Uses default if not provided. 
""" path = config_path or get_config_path() path.parent.mkdir(parents=True, exist_ok=True) - - # Convert to camelCase format - data = config.model_dump() - data = convert_to_camel(data) - - with open(path, "w") as f: - json.dump(data, f, indent=2) + + data = config.model_dump(by_alias=True) + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) def _migrate_config(data: dict) -> dict: @@ -70,37 +67,3 @@ def _migrate_config(data: dict) -> dict: if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools: tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace") return data - - -def convert_keys(data: Any) -> Any: - """Convert camelCase keys to snake_case for Pydantic.""" - if isinstance(data, dict): - return {camel_to_snake(k): convert_keys(v) for k, v in data.items()} - if isinstance(data, list): - return [convert_keys(item) for item in data] - return data - - -def convert_to_camel(data: Any) -> Any: - """Convert snake_case keys to camelCase.""" - if isinstance(data, dict): - return {snake_to_camel(k): convert_to_camel(v) for k, v in data.items()} - if isinstance(data, list): - return [convert_to_camel(item) for item in data] - return data - - -def camel_to_snake(name: str) -> str: - """Convert camelCase to snake_case.""" - result = [] - for i, char in enumerate(name): - if char.isupper() and i > 0: - result.append("_") - result.append(char.lower()) - return "".join(result) - - -def snake_to_camel(name: str) -> str: - """Convert snake_case to camelCase.""" - components = name.split("_") - return components[0] + "".join(x.title() for x in components[1:]) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 19feba4..803cb61 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -1,53 +1,98 @@ """Configuration schema using Pydantic.""" from pathlib import Path -from pydantic import BaseModel, Field, ConfigDict +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field +from pydantic.alias_generators import to_camel from pydantic_settings import BaseSettings -class WhatsAppConfig(BaseModel): +class Base(BaseModel): + """Base model that accepts both camelCase and snake_case keys.""" + + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True) + + +class WhatsAppConfig(Base): """WhatsApp channel configuration.""" + enabled: bool = False bridge_url: str = "ws://localhost:3001" + bridge_token: str = "" # Shared token for bridge auth (optional, recommended) allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers -class TelegramConfig(BaseModel): +class TelegramConfig(Base): """Telegram channel configuration.""" + enabled: bool = False token: str = "" # Bot token from @BotFather allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames - proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" + proxy: str | None = ( + None # HTTP/SOCKS5 proxy URL, e.g. 
"http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" + ) + reply_to_message: bool = False # If true, bot replies quote the original message -class FeishuConfig(BaseModel): +class FeishuConfig(Base): """Feishu/Lark channel configuration using WebSocket long connection.""" + enabled: bool = False app_id: str = "" # App ID from Feishu Open Platform app_secret: str = "" # App Secret from Feishu Open Platform encrypt_key: str = "" # Encrypt Key for event subscription (optional) verification_token: str = "" # Verification Token for event subscription (optional) allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids + react_emoji: str = ( + "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE) + ) -class DingTalkConfig(BaseModel): +class DingTalkConfig(Base): """DingTalk channel configuration using Stream mode.""" + enabled: bool = False client_id: str = "" # AppKey client_secret: str = "" # AppSecret allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids -class DiscordConfig(BaseModel): +class DiscordConfig(Base): """Discord channel configuration.""" + enabled: bool = False token: str = "" # Bot token from Discord Developer Portal allow_from: list[str] = Field(default_factory=list) # Allowed user IDs gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT + group_policy: Literal["mention", "open"] = "mention" -class EmailConfig(BaseModel): + +class MatrixConfig(Base): + """Matrix (Element) channel configuration.""" + + enabled: bool = False + homeserver: str = "https://matrix.org" + access_token: str = "" + user_id: str = "" # @bot:matrix.org + device_id: str = "" + e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling). + sync_stop_grace_seconds: int = ( + 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback. + ) + max_media_bytes: int = ( + 20 * 1024 * 1024 + ) # Max attachment size accepted for Matrix media handling (inbound + outbound). 
+    allow_from: list[str] = Field(default_factory=list)
+    group_policy: Literal["open", "mention", "allowlist"] = "open"
+    group_allow_from: list[str] = Field(default_factory=list)
+    allow_room_mentions: bool = False
+
+
+class EmailConfig(Base):
     """Email channel configuration (IMAP inbound + SMTP outbound)."""
+
     enabled: bool = False
     consent_granted: bool = False  # Explicit owner permission to access mailbox data
 
@@ -69,7 +114,9 @@ class EmailConfig(BaseModel):
     from_address: str = ""
 
     # Behavior
-    auto_reply_enabled: bool = True  # If false, inbound email is read but no automatic reply is sent
+    auto_reply_enabled: bool = (
+        True  # If false, inbound email is read but no automatic reply is sent
+    )
     poll_interval_seconds: int = 30
     mark_seen: bool = True
     max_body_chars: int = 12000
@@ -77,18 +124,21 @@ class EmailConfig(BaseModel):
     allow_from: list[str] = Field(default_factory=list)  # Allowed sender email addresses
 
 
-class MochatMentionConfig(BaseModel):
+class MochatMentionConfig(Base):
     """Mochat mention behavior configuration."""
+
     require_in_groups: bool = False
 
 
-class MochatGroupRule(BaseModel):
+class MochatGroupRule(Base):
     """Mochat per-group mention requirement."""
+
     require_mention: bool = False
 
 
-class MochatConfig(BaseModel):
+class MochatConfig(Base):
     """Mochat channel configuration."""
+
     enabled: bool = False
     base_url: str = "https://mochat.io"
     socket_url: str = ""
@@ -113,36 +163,49 @@ class MochatConfig(BaseModel):
     reply_delay_ms: int = 120000
 
 
-class SlackDMConfig(BaseModel):
+class SlackDMConfig(Base):
     """Slack DM policy configuration."""
+
     enabled: bool = True
     policy: str = "open"  # "open" or "allowlist"
     allow_from: list[str] = Field(default_factory=list)  # Allowed Slack user IDs
 
 
-class SlackConfig(BaseModel):
+class SlackConfig(Base):
     """Slack channel configuration."""
+
     enabled: bool = False
     mode: str = "socket"  # "socket" supported
     webhook_path: str = "/slack/events"
     bot_token: str = ""  # xoxb-...
     app_token: str = ""  # xapp-...
     user_token_read_only: bool = True
+    reply_in_thread: bool = True
+    react_emoji: str = "eyes"
+    allow_from: list[str] = Field(default_factory=list)  # Allowed Slack user IDs (sender-level)
     group_policy: str = "mention"  # "mention", "open", "allowlist"
     group_allow_from: list[str] = Field(default_factory=list)  # Allowed channel IDs if allowlist
     dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
 
 
-class QQConfig(BaseModel):
+class QQConfig(Base):
     """QQ channel configuration using botpy SDK."""
+
     enabled: bool = False
     app_id: str = ""  # 机器人 ID (AppID) from q.qq.com
     secret: str = ""  # 机器人密钥 (AppSecret) from q.qq.com
-    allow_from: list[str] = Field(default_factory=list)  # Allowed user openids (empty = public access)
+    allow_from: list[str] = Field(
+        default_factory=list
+    )  # Allowed user openids (empty = public access)
 
-
-class ChannelsConfig(BaseModel):
+
+
+class ChannelsConfig(Base):
     """Configuration for chat channels."""
+
+    send_progress: bool = True  # stream agent's text progress to the channel
+    send_tool_hints: bool = False  # stream tool-call hints (e.g. read_file("…"))
     whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
     telegram: TelegramConfig = Field(default_factory=TelegramConfig)
     discord: DiscordConfig = Field(default_factory=DiscordConfig)
@@ -152,31 +215,43 @@ class ChannelsConfig(BaseModel):
     email: EmailConfig = Field(default_factory=EmailConfig)
     slack: SlackConfig = Field(default_factory=SlackConfig)
     qq: QQConfig = Field(default_factory=QQConfig)
+    matrix: MatrixConfig = Field(default_factory=MatrixConfig)
 
 
-class AgentDefaults(BaseModel):
+class AgentDefaults(Base):
     """Default agent configuration."""
+
     workspace: str = "~/.nanobot/workspace"
     model: str = "anthropic/claude-opus-4-5"
+    provider: str = (
+        "auto"  # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+    )
     max_tokens: int = 8192
-    temperature: float = 0.7
-    max_tool_iterations: int = 20
+    temperature: float = 0.1
+    max_tool_iterations: int = 40
+    memory_window: int = 100
+    reasoning_effort: str | None = None  # low / medium / high — enables LLM thinking mode
 
 
-class AgentsConfig(BaseModel):
+class AgentsConfig(Base):
     """Agent configuration."""
+
     defaults: AgentDefaults = Field(default_factory=AgentDefaults)
 
 
-class ProviderConfig(BaseModel):
+class ProviderConfig(Base):
     """LLM provider configuration."""
+
     api_key: str = ""
     api_base: str | None = None
     extra_headers: dict[str, str] | None = None  # Custom headers (e.g. APP-Code for AiHubMix)
 
 
-class ProvidersConfig(BaseModel):
+class ProvidersConfig(Base):
     """Configuration for LLM providers."""
+
+    custom: ProviderConfig = Field(default_factory=ProviderConfig)  # Any OpenAI-compatible endpoint
+    azure_openai: ProviderConfig = Field(default_factory=ProviderConfig)  # Azure OpenAI (model = deployment name)
     anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
     openai: ProviderConfig = Field(default_factory=ProviderConfig)
     openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -189,63 +264,124 @@ class ProvidersConfig(BaseModel):
     moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
     minimax: ProviderConfig = Field(default_factory=ProviderConfig)
     aihubmix: ProviderConfig = Field(default_factory=ProviderConfig)  # AiHubMix API gateway
+    siliconflow: ProviderConfig = Field(default_factory=ProviderConfig)  # SiliconFlow (硅基流动)
+    volcengine: ProviderConfig = Field(default_factory=ProviderConfig)  # VolcEngine (火山引擎)
+    openai_codex: ProviderConfig = Field(default_factory=ProviderConfig)  # OpenAI Codex (OAuth)
+    github_copilot: ProviderConfig = Field(default_factory=ProviderConfig)  # Github Copilot (OAuth)
 
 
-class GatewayConfig(BaseModel):
+class HeartbeatConfig(Base):
+    """Heartbeat service configuration."""
+
+    enabled: bool = True
+    interval_s: int = 30 * 60  # 30 minutes
+
+
+class GatewayConfig(Base):
     """Gateway/server configuration."""
+
     host: str = "0.0.0.0"
     port: int = 18790
+    heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
 
 
-class WebSearchConfig(BaseModel):
+class WebSearchConfig(Base):
     """Web search tool configuration."""
+
     api_key: str = ""  # Brave Search API key
     max_results: int = 5
 
 
-class WebToolsConfig(BaseModel):
+class WebToolsConfig(Base):
     """Web tools configuration."""
+
+    proxy: str | None = (
+        None  # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+    )
     search: WebSearchConfig = Field(default_factory=WebSearchConfig)
 
 
-class ExecToolConfig(BaseModel):
+class ExecToolConfig(Base):
     """Shell exec tool configuration."""
+
     timeout: int = 60
+    path_append: str = ""
 
 
-class ToolsConfig(BaseModel):
+class MCPServerConfig(Base):
+    """MCP server connection configuration (stdio or HTTP)."""
+
+    type: Literal["stdio", "sse", "streamableHttp"] | None = None  # auto-detected if omitted
+    command: str = ""  # Stdio: command to run (e.g. "npx")
+    args: list[str] = Field(default_factory=list)  # Stdio: command arguments
+    env: dict[str, str] = Field(default_factory=dict)  # Stdio: extra env vars
+    url: str = ""  # HTTP/SSE: endpoint URL
+    headers: dict[str, str] = Field(default_factory=dict)  # HTTP/SSE: custom headers
+    tool_timeout: int = 30  # seconds before a tool call is cancelled
+
+
+class ToolsConfig(Base):
     """Tools configuration."""
+
     web: WebToolsConfig = Field(default_factory=WebToolsConfig)
     exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
     restrict_to_workspace: bool = False  # If true, restrict all tool access to workspace directory
+    mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
 
 
 class Config(BaseSettings):
     """Root configuration for nanobot."""
+
     agents: AgentsConfig = Field(default_factory=AgentsConfig)
     channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
     providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
     gateway: GatewayConfig = Field(default_factory=GatewayConfig)
     tools: ToolsConfig = Field(default_factory=ToolsConfig)
-
+
     @property
     def workspace_path(self) -> Path:
         """Get expanded workspace path."""
         return Path(self.agents.defaults.workspace).expanduser()
-
-    def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
+
+    def _match_provider(
+        self, model: str | None = None
+    ) -> tuple["ProviderConfig | None", str | None]:
         """Match provider config and its registry name. Returns (config, spec_name)."""
         from nanobot.providers.registry import PROVIDERS
+
+        forced = self.agents.defaults.provider
+        if forced != "auto":
+            p = getattr(self.providers, forced, None)
+            return (p, forced) if p else (None, None)
+
         model_lower = (model or self.agents.defaults.model).lower()
+        model_normalized = model_lower.replace("-", "_")
+        model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
+        normalized_prefix = model_prefix.replace("-", "_")
+
+        def _kw_matches(kw: str) -> bool:
+            kw = kw.lower()
+            return kw in model_lower or kw.replace("-", "_") in model_normalized
+
+        # Explicit provider prefix wins — prevents `github-copilot/...codex` matching openai_codex.
+        for spec in PROVIDERS:
+            p = getattr(self.providers, spec.name, None)
+            if p and model_prefix and normalized_prefix == spec.name:
+                if spec.is_oauth or p.api_key:
+                    return p, spec.name
 
         # Match by keyword (order follows PROVIDERS registry)
         for spec in PROVIDERS:
             p = getattr(self.providers, spec.name, None)
-            if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
-                return p, spec.name
+            if p and any(_kw_matches(kw) for kw in spec.keywords):
+                if spec.is_oauth or p.api_key:
+                    return p, spec.name
 
         # Fallback: gateways first, then others (follows registry order)
+        # OAuth providers are NOT valid fallbacks — they require explicit model selection
         for spec in PROVIDERS:
+            if spec.is_oauth:
+                continue
             p = getattr(self.providers, spec.name, None)
             if p and p.api_key:
                 return p, spec.name
@@ -265,10 +401,11 @@ class Config(BaseSettings):
         """Get API key for the given model. Falls back to first available key."""
         p = self.get_provider(model)
         return p.api_key if p else None
-
+
     def get_api_base(self, model: str | None = None) -> str | None:
         """Get API base URL for the given model. Applies default URLs for known gateways."""
         from nanobot.providers.registry import find_by_name
+
         p, name = self._match_provider(model)
         if p and p.api_base:
             return p.api_base
@@ -280,8 +417,5 @@ class Config(BaseSettings):
         if spec and spec.is_gateway and spec.default_api_base:
             return spec.default_api_base
         return None
-
-    model_config = ConfigDict(
-        env_prefix="NANOBOT_",
-        env_nested_delimiter="__"
-    )
+
+    model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")
diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py
index d1965a9..1ed71f0 100644
--- a/nanobot/cron/service.py
+++ b/nanobot/cron/service.py
@@ -4,6 +4,7 @@ import asyncio
 import json
 import time
 import uuid
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Coroutine
 
@@ -20,47 +21,73 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
     """Compute next run time in ms."""
     if schedule.kind == "at":
         return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
-
+
     if schedule.kind == "every":
         if not schedule.every_ms or schedule.every_ms <= 0:
             return None
         # Next interval from now
         return now_ms + schedule.every_ms
-
+
     if schedule.kind == "cron" and schedule.expr:
         try:
+            from zoneinfo import ZoneInfo
+
             from croniter import croniter
-            cron = croniter(schedule.expr, time.time())
-            next_time = cron.get_next()
-            return int(next_time * 1000)
+
+            # Use caller-provided reference time for deterministic scheduling
+            base_time = now_ms / 1000
+            tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
+            base_dt = datetime.fromtimestamp(base_time, tz=tz)
+            cron = croniter(schedule.expr, base_dt)
+            next_dt = cron.get_next(datetime)
+            return int(next_dt.timestamp() * 1000)
         except Exception:
             return None
-
+
     return None
 
 
+def _validate_schedule_for_add(schedule: CronSchedule) -> None:
+    """Validate schedule fields that would otherwise create non-runnable jobs."""
+    if schedule.tz and schedule.kind != "cron":
+        raise ValueError("tz can only be used with cron schedules")
+
+    if schedule.kind == "cron" and schedule.tz:
+        try:
+            from zoneinfo import ZoneInfo
+
+            ZoneInfo(schedule.tz)
+        except Exception:
+            raise ValueError(f"unknown timezone '{schedule.tz}'") from None
+
+
 class CronService:
     """Service for managing and executing scheduled jobs."""
-
+
     def __init__(
         self,
         store_path: Path,
         on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
     ):
         self.store_path = store_path
-        self.on_job = on_job  # Callback to execute job, returns response text
+        self.on_job = on_job
         self._store: CronStore | None = None
+        self._last_mtime: float = 0.0
         self._timer_task: asyncio.Task | None = None
         self._running = False
-
+
     def _load_store(self) -> CronStore:
-        """Load jobs from disk."""
+        """Load jobs from disk. Reloads automatically if file was modified externally."""
+        if self._store and self.store_path.exists():
+            mtime = self.store_path.stat().st_mtime
+            if mtime != self._last_mtime:
+                logger.info("Cron: jobs.json modified externally, reloading")
+                self._store = None
         if self._store:
             return self._store
-
+
         if self.store_path.exists():
             try:
-                data = json.loads(self.store_path.read_text())
+                data = json.loads(self.store_path.read_text(encoding="utf-8"))
                 jobs = []
                 for j in data.get("jobs", []):
                     jobs.append(CronJob(
@@ -93,20 +120,20 @@ class CronService:
                     ))
                 self._store = CronStore(jobs=jobs)
             except Exception as e:
-                logger.warning(f"Failed to load cron store: {e}")
+                logger.warning("Failed to load cron store: {}", e)
                 self._store = CronStore()
         else:
             self._store = CronStore()
-
+
         return self._store
-
+
     def _save_store(self) -> None:
         """Save jobs to disk."""
         if not self._store:
            return
-
+
         self.store_path.parent.mkdir(parents=True, exist_ok=True)
-
+
         data = {
             "version": self._store.version,
             "jobs": [
@@ -141,8 +168,9 @@ class CronService:
                 for j in self._store.jobs
             ]
         }
-
-        self.store_path.write_text(json.dumps(data, indent=2))
+
+        self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
+        self._last_mtime = self.store_path.stat().st_mtime
 
     async def start(self) -> None:
         """Start the cron service."""
@@ -151,15 +179,15 @@ class CronService:
         self._recompute_next_runs()
         self._save_store()
         self._arm_timer()
-        logger.info(f"Cron service started with {len(self._store.jobs if self._store else [])} jobs")
-
+        logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
+
     def stop(self) -> None:
         """Stop the cron service."""
         self._running = False
         if self._timer_task:
             self._timer_task.cancel()
             self._timer_task = None
-
+
     def _recompute_next_runs(self) -> None:
         """Recompute next run times for all enabled jobs."""
         if not self._store:
@@ -168,73 +196,74 @@ class CronService:
         for job in self._store.jobs:
             if job.enabled:
                 job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
-
+
     def _get_next_wake_ms(self) -> int | None:
         """Get the earliest next run time across all jobs."""
         if not self._store:
             return None
-        times = [j.state.next_run_at_ms for j in self._store.jobs 
+        times = [j.state.next_run_at_ms for j in self._store.jobs
                  if j.enabled and j.state.next_run_at_ms]
         return min(times) if times else None
-
+
     def _arm_timer(self) -> None:
         """Schedule the next timer tick."""
         if self._timer_task:
             self._timer_task.cancel()
-
+
         next_wake = self._get_next_wake_ms()
         if not next_wake or not self._running:
             return
-
+
         delay_ms = max(0, next_wake - _now_ms())
         delay_s = delay_ms / 1000
-
+
         async def tick():
             await asyncio.sleep(delay_s)
             if self._running:
                 await self._on_timer()
-
+
         self._timer_task = asyncio.create_task(tick())
-
+
     async def _on_timer(self) -> None:
         """Handle timer tick - run due jobs."""
+        self._load_store()
         if not self._store:
             return
-
+
         now = _now_ms()
         due_jobs = [
             j for j in self._store.jobs
             if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
         ]
-
+
         for job in due_jobs:
             await self._execute_job(job)
-
+
         self._save_store()
         self._arm_timer()
-
+
     async def _execute_job(self, job: CronJob) -> None:
         """Execute a single job."""
         start_ms = _now_ms()
-        logger.info(f"Cron: executing job '{job.name}' ({job.id})")
-
+        logger.info("Cron: executing job '{}' ({})", job.name, job.id)
+
         try:
             response = None
             if self.on_job:
                 response = await self.on_job(job)
-
+
             job.state.last_status = "ok"
             job.state.last_error = None
-            logger.info(f"Cron: job '{job.name}' completed")
-
+            logger.info("Cron: job '{}' completed", job.name)
+
         except Exception as e:
             job.state.last_status = "error"
             job.state.last_error = str(e)
-            logger.error(f"Cron: job '{job.name}' failed: {e}")
-
+            logger.error("Cron: job '{}' failed: {}", job.name, e)
+
         job.state.last_run_at_ms = start_ms
         job.updated_at_ms = _now_ms()
-
+
         # Handle one-shot jobs
         if job.schedule.kind == "at":
             if job.delete_after_run:
@@ -245,15 +274,15 @@ class CronService:
         else:
             # Compute next run
             job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
-
+
     # ========== Public API ==========
-
+
     def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
         """List all jobs."""
         store = self._load_store()
         jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
         return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
-
+
     def add_job(
         self,
         name: str,
@@ -266,8 +295,9 @@ class CronService:
     ) -> CronJob:
         """Add a new job."""
         store = self._load_store()
+        _validate_schedule_for_add(schedule)
         now = _now_ms()
-
+
         job = CronJob(
             id=str(uuid.uuid4())[:8],
             name=name,
@@ -285,28 +315,28 @@ class CronService:
             updated_at_ms=now,
             delete_after_run=delete_after_run,
         )
-
+
         store.jobs.append(job)
         self._save_store()
         self._arm_timer()
-
-        logger.info(f"Cron: added job '{name}' ({job.id})")
+
+        logger.info("Cron: added job '{}' ({})", name, job.id)
         return job
-
+
     def remove_job(self, job_id: str) -> bool:
         """Remove a job by ID."""
         store = self._load_store()
         before = len(store.jobs)
         store.jobs = [j for j in store.jobs if j.id != job_id]
         removed = len(store.jobs) < before
-
+
         if removed:
             self._save_store()
             self._arm_timer()
-            logger.info(f"Cron: removed job {job_id}")
-
+            logger.info("Cron: removed job {}", job_id)
+
         return removed
-
+
     def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
         """Enable or disable a job."""
         store = self._load_store()
@@ -322,7 +352,7 @@ class CronService:
             self._arm_timer()
             return job
         return None
-
+
     async def run_job(self, job_id: str, force: bool = False) -> bool:
         """Manually run a job."""
         store = self._load_store()
@@ -335,7 +365,7 @@ class CronService:
             self._arm_timer()
             return True
         return False
-
+
     def status(self) -> dict:
         """Get service status."""
         store = self._load_store()
diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py
index 221ed27..e534017 100644
--- a/nanobot/heartbeat/service.py
+++ b/nanobot/heartbeat/service.py
@@ -1,92 +1,130 @@
 """Heartbeat service - periodic agent wake-up to check for tasks."""
 
+from __future__ import annotations
+
 import asyncio
 from pathlib import Path
-from typing import Any, Callable, Coroutine
+from typing import TYPE_CHECKING, Any, Callable, Coroutine
 
 from loguru import logger
 
-# Default interval: 30 minutes
-DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60
+if TYPE_CHECKING:
+    from nanobot.providers.base import LLMProvider
 
-# The prompt sent to agent during heartbeat
-HEARTBEAT_PROMPT = """Read HEARTBEAT.md in your workspace (if it exists).
-Follow any instructions or tasks listed there.
-If nothing needs attention, reply with just: HEARTBEAT_OK"""
-
-# Token that indicates "nothing to do"
-HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK"
-
-
-def _is_heartbeat_empty(content: str | None) -> bool:
-    """Check if HEARTBEAT.md has no actionable content."""
-    if not content:
-        return True
-
-    # Lines to skip: empty, headers, HTML comments, empty checkboxes
-    skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"}
-
-    for line in content.split("\n"):
-        line = line.strip()
-        if not line or line.startswith("#") or line.startswith("