Merge remote-tracking branch 'origin/main' into pr-420
This commit is contained in:
123
README.md
123
README.md
@@ -16,10 +16,17 @@
|
|||||||
|
|
||||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||||
|
|
||||||
📏 Real-time line count: **3,761 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
📏 Real-time line count: **3,966 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
||||||
|
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
||||||
|
- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
|
||||||
|
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
||||||
|
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
||||||
|
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
||||||
|
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
|
||||||
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
|
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
|
||||||
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
|
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
|
||||||
- **2026-02-15** 🔑 nanobot now supports OpenAI Codex provider with OAuth login support.
|
- **2026-02-15** 🔑 nanobot now supports OpenAI Codex provider with OAuth login support.
|
||||||
@@ -27,13 +34,13 @@
|
|||||||
- **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
|
- **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
|
||||||
- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
|
- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
|
||||||
- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
|
- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
|
||||||
- **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
|
|
||||||
- **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
|
|
||||||
- **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Earlier news</summary>
|
<summary>Earlier news</summary>
|
||||||
|
|
||||||
|
- **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
|
||||||
|
- **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
|
||||||
|
- **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
|
||||||
- **2026-02-07** 🚀 Released **v0.1.3.post5** with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details.
|
- **2026-02-07** 🚀 Released **v0.1.3.post5** with Qwen support & several key improvements! Check [here](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post5) for details.
|
||||||
- **2026-02-06** ✨ Added Moonshot/Kimi provider, Discord integration, and enhanced security hardening!
|
- **2026-02-06** ✨ Added Moonshot/Kimi provider, Discord integration, and enhanced security hardening!
|
||||||
- **2026-02-05** ✨ Added Feishu channel, DeepSeek provider, and enhanced scheduled tasks support!
|
- **2026-02-05** ✨ Added Feishu channel, DeepSeek provider, and enhanced scheduled tasks support!
|
||||||
@@ -131,12 +138,13 @@ Add or merge these **two parts** into your config (other options have defaults).
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
*Set your model*:
|
*Set your model* (optionally pin a provider — defaults to auto-detection):
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"agents": {
|
"agents": {
|
||||||
"defaults": {
|
"defaults": {
|
||||||
"model": "anthropic/claude-opus-4-5"
|
"model": "anthropic/claude-opus-4-5",
|
||||||
|
"provider": "openrouter"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -645,6 +653,7 @@ Config file: `~/.nanobot/config.json`
|
|||||||
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
||||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||||
|
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
|
||||||
|
|
||||||
| Provider | Purpose | Get API Key |
|
| Provider | Purpose | Get API Key |
|
||||||
|----------|---------|-------------|
|
|----------|---------|-------------|
|
||||||
@@ -655,9 +664,10 @@ Config file: `~/.nanobot/config.json`
|
|||||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||||
| `minimax` | LLM (MiniMax direct) | [platform.minimax.io](https://platform.minimax.io) |
|
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
||||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||||
| `siliconflow` | LLM (SiliconFlow/硅基流动, API gateway) | [siliconflow.cn](https://siliconflow.cn) |
|
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||||
|
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
|
||||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||||
@@ -818,6 +828,12 @@ Add MCP servers to your `config.json`:
|
|||||||
"filesystem": {
|
"filesystem": {
|
||||||
"command": "npx",
|
"command": "npx",
|
||||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"]
|
||||||
|
},
|
||||||
|
"my-remote-mcp": {
|
||||||
|
"url": "https://example.com/mcp/",
|
||||||
|
"headers": {
|
||||||
|
"Authorization": "Bearer xxxxx"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -829,7 +845,22 @@ Two transport modes are supported:
|
|||||||
| Mode | Config | Example |
|
| Mode | Config | Example |
|
||||||
|------|--------|---------|
|
|------|--------|---------|
|
||||||
| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
|
| **Stdio** | `command` + `args` | Local process via `npx` / `uvx` |
|
||||||
| **HTTP** | `url` | Remote endpoint (`https://mcp.example.com/sse`) |
|
| **HTTP** | `url` + `headers` (optional) | Remote endpoint (`https://mcp.example.com/sse`) |
|
||||||
|
|
||||||
|
Use `toolTimeout` to override the default 30s per-call timeout for slow servers:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": {
|
||||||
|
"mcpServers": {
|
||||||
|
"my-slow-server": {
|
||||||
|
"url": "https://example.com/mcp/",
|
||||||
|
"toolTimeout": 120
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
|
MCP tools are automatically discovered and registered on startup. The LLM can use them alongside built-in tools — no extra configuration needed.
|
||||||
|
|
||||||
@@ -844,6 +875,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
|
|||||||
| Option | Default | Description |
|
| Option | Default | Description |
|
||||||
|--------|---------|-------------|
|
|--------|---------|-------------|
|
||||||
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
|
||||||
|
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
|
||||||
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
|
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
|
||||||
|
|
||||||
|
|
||||||
@@ -881,6 +913,26 @@ nanobot cron remove <job_id>
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Heartbeat (Periodic Tasks)</b></summary>
|
||||||
|
|
||||||
|
The gateway wakes up every 30 minutes and checks `HEARTBEAT.md` in your workspace (`~/.nanobot/workspace/HEARTBEAT.md`). If the file has tasks, the agent executes them and delivers results to your most recently active chat channel.
|
||||||
|
|
||||||
|
**Setup:** edit `~/.nanobot/workspace/HEARTBEAT.md` (created automatically by `nanobot onboard`):
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Periodic Tasks
|
||||||
|
|
||||||
|
- [ ] Check weather forecast and send a summary
|
||||||
|
- [ ] Scan inbox for urgent emails
|
||||||
|
```
|
||||||
|
|
||||||
|
The agent can also manage this file itself — ask it to "add a periodic task" and it will update `HEARTBEAT.md` for you.
|
||||||
|
|
||||||
|
> **Note:** The gateway must be running (`nanobot gateway`) and you must have chatted with the bot at least once so it knows which channel to deliver to.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## 🐳 Docker
|
## 🐳 Docker
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
@@ -920,6 +972,59 @@ docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
|
|||||||
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
|
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## 🐧 Linux Service
|
||||||
|
|
||||||
|
Run the gateway as a systemd user service so it starts automatically and restarts on failure.
|
||||||
|
|
||||||
|
**1. Find the nanobot binary path:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
which nanobot # e.g. /home/user/.local/bin/nanobot
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Create the service file** at `~/.config/systemd/user/nanobot-gateway.service` (replace `ExecStart` path if needed):
|
||||||
|
|
||||||
|
```ini
|
||||||
|
[Unit]
|
||||||
|
Description=Nanobot Gateway
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
ExecStart=%h/.local/bin/nanobot gateway
|
||||||
|
Restart=always
|
||||||
|
RestartSec=10
|
||||||
|
NoNewPrivileges=yes
|
||||||
|
ProtectSystem=strict
|
||||||
|
ReadWritePaths=%h
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=default.target
|
||||||
|
```
|
||||||
|
|
||||||
|
**3. Enable and start:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
systemctl --user daemon-reload
|
||||||
|
systemctl --user enable --now nanobot-gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
**Common operations:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
systemctl --user status nanobot-gateway # check status
|
||||||
|
systemctl --user restart nanobot-gateway # restart after config changes
|
||||||
|
journalctl --user -u nanobot-gateway -f # follow logs
|
||||||
|
```
|
||||||
|
|
||||||
|
If you edit the `.service` file itself, run `systemctl --user daemon-reload` before restarting.
|
||||||
|
|
||||||
|
> **Note:** User services only run while you are logged in. To keep the gateway running after logout, enable lingering:
|
||||||
|
>
|
||||||
|
> ```bash
|
||||||
|
> loginctl enable-linger $USER
|
||||||
|
> ```
|
||||||
|
|
||||||
## 📁 Project Structure
|
## 📁 Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
nanobot - A lightweight AI agent framework
|
nanobot - A lightweight AI agent framework
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.4"
|
__version__ = "0.1.4.post2"
|
||||||
__logo__ = "🐈"
|
__logo__ = "🐈"
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -11,14 +13,10 @@ from nanobot.agent.skills import SkillsLoader
|
|||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
"""
|
"""Builds the context (system prompt + messages) for the agent."""
|
||||||
Builds the context (system prompt + messages) for the agent.
|
|
||||||
|
|
||||||
Assembles bootstrap files, memory, skills, and conversation history
|
|
||||||
into a coherent prompt for the LLM.
|
|
||||||
"""
|
|
||||||
|
|
||||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
||||||
|
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@@ -26,39 +24,23 @@ class ContextBuilder:
|
|||||||
self.skills = SkillsLoader(workspace)
|
self.skills = SkillsLoader(workspace)
|
||||||
|
|
||||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||||
"""
|
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||||
Build the system prompt from bootstrap files, memory, and skills.
|
parts = [self._get_identity()]
|
||||||
|
|
||||||
Args:
|
|
||||||
skill_names: Optional list of skills to include.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Complete system prompt.
|
|
||||||
"""
|
|
||||||
parts = []
|
|
||||||
|
|
||||||
# Core identity
|
|
||||||
parts.append(self._get_identity())
|
|
||||||
|
|
||||||
# Bootstrap files
|
|
||||||
bootstrap = self._load_bootstrap_files()
|
bootstrap = self._load_bootstrap_files()
|
||||||
if bootstrap:
|
if bootstrap:
|
||||||
parts.append(bootstrap)
|
parts.append(bootstrap)
|
||||||
|
|
||||||
# Memory context
|
|
||||||
memory = self.memory.get_memory_context()
|
memory = self.memory.get_memory_context()
|
||||||
if memory:
|
if memory:
|
||||||
parts.append(f"# Memory\n\n{memory}")
|
parts.append(f"# Memory\n\n{memory}")
|
||||||
|
|
||||||
# Skills - progressive loading
|
|
||||||
# 1. Always-loaded skills: include full content
|
|
||||||
always_skills = self.skills.get_always_skills()
|
always_skills = self.skills.get_always_skills()
|
||||||
if always_skills:
|
if always_skills:
|
||||||
always_content = self.skills.load_skills_for_context(always_skills)
|
always_content = self.skills.load_skills_for_context(always_skills)
|
||||||
if always_content:
|
if always_content:
|
||||||
parts.append(f"# Active Skills\n\n{always_content}")
|
parts.append(f"# Active Skills\n\n{always_content}")
|
||||||
|
|
||||||
# 2. Available skills: only show summary (agent uses read_file to load)
|
|
||||||
skills_summary = self.skills.build_skills_summary()
|
skills_summary = self.skills.build_skills_summary()
|
||||||
if skills_summary:
|
if skills_summary:
|
||||||
parts.append(f"""# Skills
|
parts.append(f"""# Skills
|
||||||
@@ -67,47 +49,46 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md
|
|||||||
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
||||||
|
|
||||||
{skills_summary}""")
|
{skills_summary}""")
|
||||||
|
|
||||||
return "\n\n---\n\n".join(parts)
|
return "\n\n---\n\n".join(parts)
|
||||||
|
|
||||||
def _get_identity(self) -> str:
|
def _get_identity(self) -> str:
|
||||||
"""Get the core identity section."""
|
"""Get the core identity section."""
|
||||||
from datetime import datetime
|
|
||||||
import time as _time
|
|
||||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
|
||||||
tz = _time.strftime("%Z") or "UTC"
|
|
||||||
workspace_path = str(self.workspace.expanduser().resolve())
|
workspace_path = str(self.workspace.expanduser().resolve())
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||||
|
|
||||||
return f"""# nanobot 🐈
|
return f"""# nanobot 🐈
|
||||||
|
|
||||||
You are nanobot, a helpful AI assistant. You have access to tools that allow you to:
|
You are nanobot, a helpful AI assistant.
|
||||||
- Read, write, and edit files
|
|
||||||
- Execute shell commands
|
|
||||||
- Search the web and fetch web pages
|
|
||||||
- Send messages to users on chat channels
|
|
||||||
- Spawn subagents for complex background tasks
|
|
||||||
|
|
||||||
## Current Time
|
|
||||||
{now} ({tz})
|
|
||||||
|
|
||||||
## Runtime
|
## Runtime
|
||||||
{runtime}
|
{runtime}
|
||||||
|
|
||||||
## Workspace
|
## Workspace
|
||||||
Your workspace is at: {workspace_path}
|
Your workspace is at: {workspace_path}
|
||||||
- Long-term memory: {workspace_path}/memory/MEMORY.md
|
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
|
||||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
|
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
|
||||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||||
|
|
||||||
IMPORTANT: When responding to direct questions or conversations, reply directly with your text response.
|
## nanobot Guidelines
|
||||||
Only use the 'message' tool when you need to send a message to a specific chat channel (like WhatsApp).
|
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||||
For normal conversation, just respond with text - do not call the message tool.
|
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||||
|
- After writing or editing a file, re-read it if accuracy matters.
|
||||||
|
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||||
|
- Ask for clarification when the request is ambiguous.
|
||||||
|
|
||||||
Always be helpful, accurate, and concise. Before calling tools, briefly tell the user what you're about to do (one short sentence in the user's language).
|
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""
|
||||||
When remembering something important, write to {workspace_path}/memory/MEMORY.md
|
|
||||||
To recall past events, grep {workspace_path}/memory/HISTORY.md"""
|
@staticmethod
|
||||||
|
def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
|
||||||
|
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||||
|
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||||
|
tz = time.strftime("%Z") or "UTC"
|
||||||
|
lines = [f"Current Time: {now} ({tz})"]
|
||||||
|
if channel and chat_id:
|
||||||
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
@@ -130,36 +111,13 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
|
|||||||
channel: str | None = None,
|
channel: str | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Build the complete message list for an LLM call."""
|
||||||
Build the complete message list for an LLM call.
|
return [
|
||||||
|
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||||
Args:
|
*history,
|
||||||
history: Previous conversation messages.
|
{"role": "user", "content": self._build_runtime_context(channel, chat_id)},
|
||||||
current_message: The new user message.
|
{"role": "user", "content": self._build_user_content(current_message, media)},
|
||||||
skill_names: Optional skills to include.
|
]
|
||||||
media: Optional list of local file paths for images/media.
|
|
||||||
channel: Current channel (telegram, feishu, etc.).
|
|
||||||
chat_id: Current chat/user ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of messages including system prompt.
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# System prompt
|
|
||||||
system_prompt = self.build_system_prompt(skill_names)
|
|
||||||
if channel and chat_id:
|
|
||||||
system_prompt += f"\n\n## Current Session\nChannel: {channel}\nChat ID: {chat_id}"
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
|
||||||
|
|
||||||
# History
|
|
||||||
messages.extend(history)
|
|
||||||
|
|
||||||
# Current message (with optional image attachments)
|
|
||||||
user_content = self._build_user_content(current_message, media)
|
|
||||||
messages.append({"role": "user", "content": user_content})
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||||
"""Build user message content with optional base64-encoded images."""
|
"""Build user message content with optional base64-encoded images."""
|
||||||
@@ -180,63 +138,24 @@ To recall past events, grep {workspace_path}/memory/HISTORY.md"""
|
|||||||
return images + [{"type": "text", "text": text}]
|
return images + [{"type": "text", "text": text}]
|
||||||
|
|
||||||
def add_tool_result(
|
def add_tool_result(
|
||||||
self,
|
self, messages: list[dict[str, Any]],
|
||||||
messages: list[dict[str, Any]],
|
tool_call_id: str, tool_name: str, result: str,
|
||||||
tool_call_id: str,
|
|
||||||
tool_name: str,
|
|
||||||
result: str
|
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Add a tool result to the message list."""
|
||||||
Add a tool result to the message list.
|
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Current message list.
|
|
||||||
tool_call_id: ID of the tool call.
|
|
||||||
tool_name: Name of the tool.
|
|
||||||
result: Tool execution result.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated message list.
|
|
||||||
"""
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
"name": tool_name,
|
|
||||||
"content": result
|
|
||||||
})
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_assistant_message(
|
def add_assistant_message(
|
||||||
self,
|
self, messages: list[dict[str, Any]],
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
content: str | None,
|
content: str | None,
|
||||||
tool_calls: list[dict[str, Any]] | None = None,
|
tool_calls: list[dict[str, Any]] | None = None,
|
||||||
reasoning_content: str | None = None,
|
reasoning_content: str | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Add an assistant message to the message list."""
|
||||||
Add an assistant message to the message list.
|
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Current message list.
|
|
||||||
content: Message content.
|
|
||||||
tool_calls: Optional tool calls.
|
|
||||||
reasoning_content: Thinking output (Kimi, DeepSeek-R1, etc.).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated message list.
|
|
||||||
"""
|
|
||||||
msg: dict[str, Any] = {"role": "assistant"}
|
|
||||||
|
|
||||||
# Omit empty content — some backends reject empty text blocks
|
|
||||||
if content:
|
|
||||||
msg["content"] = content
|
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
msg["tool_calls"] = tool_calls
|
msg["tool_calls"] = tool_calls
|
||||||
|
if reasoning_content is not None:
|
||||||
# Include reasoning content when provided (required by some thinking models)
|
|
||||||
if reasoning_content:
|
|
||||||
msg["reasoning_content"] = reasoning_content
|
msg["reasoning_content"] = reasoning_content
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -1,30 +1,35 @@
|
|||||||
"""Agent loop: the core processing engine."""
|
"""Agent loop: the core processing engine."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
|
||||||
import json
|
import json
|
||||||
import json_repair
|
|
||||||
from pathlib import Path
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Awaitable, Callable
|
from contextlib import AsyncExitStack
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
from nanobot.agent.memory import MemoryStore
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.agent.tools.cron import CronTool
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.agent.context import ContextBuilder
|
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
|
||||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
|
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
|
||||||
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
|
||||||
from nanobot.agent.tools.message import MessageTool
|
|
||||||
from nanobot.agent.tools.spawn import SpawnTool
|
|
||||||
from nanobot.agent.tools.cron import CronTool
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
|
||||||
from nanobot.agent.subagent import SubagentManager
|
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
|
||||||
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
|
|
||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""
|
"""
|
||||||
@@ -44,20 +49,21 @@ class AgentLoop:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 20,
|
max_iterations: int = 40,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.1,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
memory_window: int = 50,
|
memory_window: int = 100,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
cron_service: "CronService | None" = None,
|
cron_service: CronService | None = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
session_manager: SessionManager | None = None,
|
session_manager: SessionManager | None = None,
|
||||||
mcp_servers: dict | None = None,
|
mcp_servers: dict | None = None,
|
||||||
|
channels_config: ChannelsConfig | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.cron.service import CronService
|
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
@@ -84,60 +90,64 @@ class AgentLoop:
|
|||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
restrict_to_workspace=restrict_to_workspace,
|
restrict_to_workspace=restrict_to_workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._mcp_servers = mcp_servers or {}
|
self._mcp_servers = mcp_servers or {}
|
||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
|
self._mcp_connecting = False
|
||||||
|
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||||
|
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||||
|
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
|
self._processing_lock = asyncio.Lock()
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
# File tools (restrict to workspace if configured)
|
|
||||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||||
self.tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||||
self.tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
self.tools.register(EditFileTool(allowed_dir=allowed_dir))
|
|
||||||
self.tools.register(ListDirTool(allowed_dir=allowed_dir))
|
|
||||||
|
|
||||||
# Shell tool
|
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
|
path_append=self.exec_config.path_append,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Web tools
|
|
||||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
self.tools.register(WebFetchTool())
|
self.tools.register(WebFetchTool())
|
||||||
|
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||||
# Message tool
|
self.tools.register(SpawnTool(manager=self.subagents))
|
||||||
message_tool = MessageTool(send_callback=self.bus.publish_outbound)
|
|
||||||
self.tools.register(message_tool)
|
|
||||||
|
|
||||||
# Spawn tool (for subagents)
|
|
||||||
spawn_tool = SpawnTool(manager=self.subagents)
|
|
||||||
self.tools.register(spawn_tool)
|
|
||||||
|
|
||||||
# Cron tool (for scheduling)
|
|
||||||
if self.cron_service:
|
if self.cron_service:
|
||||||
self.tools.register(CronTool(self.cron_service))
|
self.tools.register(CronTool(self.cron_service))
|
||||||
|
|
||||||
async def _connect_mcp(self) -> None:
|
async def _connect_mcp(self) -> None:
|
||||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||||
if self._mcp_connected or not self._mcp_servers:
|
if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
|
||||||
return
|
return
|
||||||
self._mcp_connected = True
|
self._mcp_connecting = True
|
||||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||||
self._mcp_stack = AsyncExitStack()
|
try:
|
||||||
await self._mcp_stack.__aenter__()
|
self._mcp_stack = AsyncExitStack()
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
await self._mcp_stack.__aenter__()
|
||||||
|
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||||
|
self._mcp_connected = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
|
if self._mcp_stack:
|
||||||
|
try:
|
||||||
|
await self._mcp_stack.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._mcp_stack = None
|
||||||
|
finally:
|
||||||
|
self._mcp_connecting = False
|
||||||
|
|
||||||
def _set_tool_context(self, channel: str, chat_id: str) -> None:
|
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Update context for all tools that need routing info."""
|
"""Update context for all tools that need routing info."""
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if isinstance(message_tool, MessageTool):
|
||||||
message_tool.set_context(channel, chat_id)
|
message_tool.set_context(channel, chat_id, message_id)
|
||||||
|
|
||||||
if spawn_tool := self.tools.get("spawn"):
|
if spawn_tool := self.tools.get("spawn"):
|
||||||
if isinstance(spawn_tool, SpawnTool):
|
if isinstance(spawn_tool, SpawnTool):
|
||||||
@@ -167,18 +177,9 @@ class AgentLoop:
|
|||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> tuple[str | None, list[str]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""
|
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||||
Run the agent iteration loop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
initial_messages: Starting messages for the LLM conversation.
|
|
||||||
on_progress: Optional callback to push intermediate content to the user.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (final_content, list_of_tools_used).
|
|
||||||
"""
|
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_content = None
|
final_content = None
|
||||||
@@ -198,7 +199,9 @@ class AgentLoop:
|
|||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
if on_progress:
|
if on_progress:
|
||||||
clean = self._strip_think(response.content)
|
clean = self._strip_think(response.content)
|
||||||
await on_progress(clean or self._tool_hint(response.tool_calls))
|
if clean:
|
||||||
|
await on_progress(clean)
|
||||||
|
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||||
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
{
|
{
|
||||||
@@ -206,7 +209,7 @@ class AgentLoop:
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc.name,
|
"name": tc.name,
|
||||||
"arguments": json.dumps(tc.arguments)
|
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
@@ -219,43 +222,87 @@ class AgentLoop:
|
|||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
tools_used.append(tool_call.name)
|
tools_used.append(tool_call.name)
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||||
messages = self.context.add_tool_result(
|
messages = self.context.add_tool_result(
|
||||||
messages, tool_call.id, tool_call.name, result
|
messages, tool_call.id, tool_call.name, result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_content = self._strip_think(response.content)
|
clean = self._strip_think(response.content)
|
||||||
|
if on_progress and clean:
|
||||||
|
await on_progress(clean)
|
||||||
|
messages = self.context.add_assistant_message(
|
||||||
|
messages, clean, reasoning_content=response.reasoning_content,
|
||||||
|
)
|
||||||
|
final_content = clean
|
||||||
break
|
break
|
||||||
|
|
||||||
return final_content, tools_used
|
if final_content is None and iteration >= self.max_iterations:
|
||||||
|
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||||
|
final_content = (
|
||||||
|
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||||
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_content, tools_used, messages
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, processing messages from the bus."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
self._running = True
|
self._running = True
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
logger.info("Agent loop started")
|
logger.info("Agent loop started")
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(
|
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||||
self.bus.consume_inbound(),
|
|
||||||
timeout=1.0
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await self._process_message(msg)
|
|
||||||
if response:
|
|
||||||
await self.bus.publish_outbound(response)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing message: {e}")
|
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
|
||||||
channel=msg.channel,
|
|
||||||
chat_id=msg.chat_id,
|
|
||||||
content=f"Sorry, I encountered an error: {str(e)}"
|
|
||||||
))
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if msg.content.strip().lower() == "/stop":
|
||||||
|
await self._handle_stop(msg)
|
||||||
|
else:
|
||||||
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
|
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||||
|
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||||
|
|
||||||
|
async def _handle_stop(self, msg: InboundMessage) -> None:
|
||||||
|
"""Cancel all active tasks and subagents for the session."""
|
||||||
|
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||||
|
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||||
|
for t in tasks:
|
||||||
|
try:
|
||||||
|
await t
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||||
|
total = cancelled + sub_cancelled
|
||||||
|
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
|
"""Process a message under the global lock."""
|
||||||
|
async with self._processing_lock:
|
||||||
|
try:
|
||||||
|
response = await self._process_message(msg)
|
||||||
|
if response is not None:
|
||||||
|
await self.bus.publish_outbound(response)
|
||||||
|
elif msg.channel == "cli":
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata=msg.metadata or {},
|
||||||
|
))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing message for session {}", msg.session_key)
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="Sorry, I encountered an error.",
|
||||||
|
))
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Close MCP connections."""
|
"""Close MCP connections."""
|
||||||
if self._mcp_stack:
|
if self._mcp_stack:
|
||||||
@@ -269,222 +316,177 @@ class AgentLoop:
|
|||||||
"""Stop the agent loop."""
|
"""Stop the agent loop."""
|
||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
|
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
|
lock = self._consolidation_locks.get(session_key)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._consolidation_locks[session_key] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
|
||||||
|
"""Drop lock entry if no longer in use."""
|
||||||
|
if not lock.locked():
|
||||||
|
self._consolidation_locks.pop(session_key, None)
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""
|
"""Process a single inbound message and return the response."""
|
||||||
Process a single inbound message.
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
|
|
||||||
Args:
|
|
||||||
msg: The inbound message to process.
|
|
||||||
session_key: Override session key (used by process_direct).
|
|
||||||
on_progress: Optional callback for intermediate output (defaults to bus publish).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The response message, or None if no response needed.
|
|
||||||
"""
|
|
||||||
# System messages route back via chat_id ("channel:chat_id")
|
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
return await self._process_system_message(msg)
|
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||||
|
else ("cli", msg.chat_id))
|
||||||
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
|
key = f"{channel}:{chat_id}"
|
||||||
|
session = self.sessions.get_or_create(key)
|
||||||
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
|
history = session.get_history(max_messages=self.memory_window)
|
||||||
|
messages = self.context.build_messages(
|
||||||
|
history=history,
|
||||||
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
|
)
|
||||||
|
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||||
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
|
self.sessions.save(session)
|
||||||
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||||
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
key = session_key or msg.session_key
|
key = session_key or msg.session_key
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
|
||||||
# Handle slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
# Capture messages before clearing (avoid race condition with background task)
|
lock = self._get_consolidation_lock(session.key)
|
||||||
messages_to_archive = session.messages.copy()
|
self._consolidating.add(session.key)
|
||||||
|
try:
|
||||||
|
async with lock:
|
||||||
|
snapshot = session.messages[session.last_consolidated:]
|
||||||
|
if snapshot:
|
||||||
|
temp = Session(key=session.key)
|
||||||
|
temp.messages = list(snapshot)
|
||||||
|
if not await self._consolidate_memory(temp, archive_all=True):
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("/new archival failed for {}", session.key)
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._consolidating.discard(session.key)
|
||||||
|
self._prune_consolidation_lock(session.key, lock)
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self.sessions.invalidate(session.key)
|
self.sessions.invalidate(session.key)
|
||||||
|
|
||||||
async def _consolidate_and_cleanup():
|
|
||||||
temp_session = Session(key=session.key)
|
|
||||||
temp_session.messages = messages_to_archive
|
|
||||||
await self._consolidate_memory(temp_session, archive_all=True)
|
|
||||||
|
|
||||||
asyncio.create_task(_consolidate_and_cleanup())
|
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="New session started. Memory consolidation in progress.")
|
content="New session started.")
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||||
|
|
||||||
if len(session.messages) > self.memory_window:
|
|
||||||
asyncio.create_task(self._consolidate_memory(session))
|
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id)
|
unconsolidated = len(session.messages) - session.last_consolidated
|
||||||
|
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||||
|
self._consolidating.add(session.key)
|
||||||
|
lock = self._get_consolidation_lock(session.key)
|
||||||
|
|
||||||
|
async def _consolidate_and_unlock():
|
||||||
|
try:
|
||||||
|
async with lock:
|
||||||
|
await self._consolidate_memory(session)
|
||||||
|
finally:
|
||||||
|
self._consolidating.discard(session.key)
|
||||||
|
self._prune_consolidation_lock(session.key, lock)
|
||||||
|
_task = asyncio.current_task()
|
||||||
|
if _task is not None:
|
||||||
|
self._consolidation_tasks.discard(_task)
|
||||||
|
|
||||||
|
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||||
|
self._consolidation_tasks.add(_task)
|
||||||
|
|
||||||
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
|
if message_tool := self.tools.get("message"):
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_tool.start_turn()
|
||||||
|
|
||||||
|
history = session.get_history(max_messages=self.memory_window)
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=session.get_history(max_messages=self.memory_window),
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
chat_id=msg.chat_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(content: str) -> None:
|
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
|
meta = dict(msg.metadata or {})
|
||||||
|
meta["_progress"] = True
|
||||||
|
meta["_tool_hint"] = tool_hint
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||||
metadata=msg.metadata or {},
|
|
||||||
))
|
))
|
||||||
|
|
||||||
final_content, tools_used = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages, on_progress=on_progress or _bus_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
session.add_message("user", msg.content)
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
session.add_message("assistant", final_content,
|
|
||||||
tools_used=tools_used if tools_used else None)
|
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
if message_tool := self.tools.get("message"):
|
||||||
|
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||||
|
return None
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
chat_id=msg.chat_id,
|
metadata=msg.metadata or {},
|
||||||
content=final_content,
|
|
||||||
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
_TOOL_RESULT_MAX_CHARS = 500
|
||||||
"""
|
|
||||||
Process a system message (e.g., subagent announce).
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
The chat_id field contains "original_channel:original_chat_id" to route
|
from datetime import datetime
|
||||||
the response back to the correct destination.
|
for m in messages[skip:]:
|
||||||
"""
|
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
|
||||||
logger.info(f"Processing system message from {msg.sender_id}")
|
if entry.get("role") == "tool" and isinstance(entry.get("content"), str):
|
||||||
|
content = entry["content"]
|
||||||
# Parse origin from chat_id (format: "channel:chat_id")
|
if len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||||
if ":" in msg.chat_id:
|
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||||
parts = msg.chat_id.split(":", 1)
|
if entry.get("role") == "user" and isinstance(entry.get("content"), list):
|
||||||
origin_channel = parts[0]
|
entry["content"] = [
|
||||||
origin_chat_id = parts[1]
|
{"type": "text", "text": "[image]"} if (
|
||||||
else:
|
c.get("type") == "image_url"
|
||||||
# Fallback
|
and c.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||||
origin_channel = "cli"
|
) else c
|
||||||
origin_chat_id = msg.chat_id
|
for c in entry["content"]
|
||||||
|
]
|
||||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||||
session = self.sessions.get_or_create(session_key)
|
session.messages.append(entry)
|
||||||
self._set_tool_context(origin_channel, origin_chat_id)
|
session.updated_at = datetime.now()
|
||||||
initial_messages = self.context.build_messages(
|
|
||||||
history=session.get_history(max_messages=self.memory_window),
|
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||||
current_message=msg.content,
|
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
||||||
channel=origin_channel,
|
return await MemoryStore(self.workspace).consolidate(
|
||||||
chat_id=origin_chat_id,
|
session, self.provider, self.model,
|
||||||
|
archive_all=archive_all, memory_window=self.memory_window,
|
||||||
)
|
)
|
||||||
final_content, _ = await self._run_agent_loop(initial_messages)
|
|
||||||
|
|
||||||
if final_content is None:
|
|
||||||
final_content = "Background task completed."
|
|
||||||
|
|
||||||
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
|
|
||||||
session.add_message("assistant", final_content)
|
|
||||||
self.sessions.save(session)
|
|
||||||
|
|
||||||
return OutboundMessage(
|
|
||||||
channel=origin_channel,
|
|
||||||
chat_id=origin_chat_id,
|
|
||||||
content=final_content
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
archive_all: If True, clear all messages and reset session (for /new command).
|
|
||||||
If False, only write to files without modifying session.
|
|
||||||
"""
|
|
||||||
memory = MemoryStore(self.workspace)
|
|
||||||
|
|
||||||
if archive_all:
|
|
||||||
old_messages = session.messages
|
|
||||||
keep_count = 0
|
|
||||||
logger.info(f"Memory consolidation (archive_all): {len(session.messages)} total messages archived")
|
|
||||||
else:
|
|
||||||
keep_count = self.memory_window // 2
|
|
||||||
if len(session.messages) <= keep_count:
|
|
||||||
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
|
|
||||||
return
|
|
||||||
|
|
||||||
messages_to_process = len(session.messages) - session.last_consolidated
|
|
||||||
if messages_to_process <= 0:
|
|
||||||
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
|
|
||||||
return
|
|
||||||
|
|
||||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
|
||||||
if not old_messages:
|
|
||||||
return
|
|
||||||
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for m in old_messages:
|
|
||||||
if not m.get("content"):
|
|
||||||
continue
|
|
||||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
|
||||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
|
||||||
conversation = "\n".join(lines)
|
|
||||||
current_memory = memory.read_long_term()
|
|
||||||
|
|
||||||
prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys:
|
|
||||||
|
|
||||||
1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.
|
|
||||||
|
|
||||||
2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
|
|
||||||
|
|
||||||
## Current Long-term Memory
|
|
||||||
{current_memory or "(empty)"}
|
|
||||||
|
|
||||||
## Conversation to Process
|
|
||||||
{conversation}
|
|
||||||
|
|
||||||
Respond with ONLY valid JSON, no markdown fences."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.provider.chat(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
model=self.model,
|
|
||||||
)
|
|
||||||
text = (response.content or "").strip()
|
|
||||||
if not text:
|
|
||||||
logger.warning("Memory consolidation: LLM returned empty response, skipping")
|
|
||||||
return
|
|
||||||
if text.startswith("```"):
|
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
|
||||||
result = json_repair.loads(text)
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
|
|
||||||
return
|
|
||||||
|
|
||||||
if entry := result.get("history_entry"):
|
|
||||||
memory.append_history(entry)
|
|
||||||
if update := result.get("memory_update"):
|
|
||||||
if update != current_memory:
|
|
||||||
memory.write_long_term(update)
|
|
||||||
|
|
||||||
if archive_all:
|
|
||||||
session.last_consolidated = 0
|
|
||||||
else:
|
|
||||||
session.last_consolidated = len(session.messages) - keep_count
|
|
||||||
logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Memory consolidation failed: {e}")
|
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
@@ -494,26 +496,8 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Process a message directly (for CLI or cron usage)."""
|
||||||
Process a message directly (for CLI or cron usage).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The message content.
|
|
||||||
session_key: Session identifier (overrides channel:chat_id for session lookup).
|
|
||||||
channel: Source channel (for tool context routing).
|
|
||||||
chat_id: Source chat ID (for tool context routing).
|
|
||||||
on_progress: Optional callback for intermediate output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The agent's response.
|
|
||||||
"""
|
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
channel=channel,
|
|
||||||
sender_id="user",
|
|
||||||
chat_id=chat_id,
|
|
||||||
content=content
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||||
return response.content if response else ""
|
return response.content if response else ""
|
||||||
|
|||||||
@@ -1,9 +1,46 @@
|
|||||||
"""Memory system for persistent agent memory."""
|
"""Memory system for persistent agent memory."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.utils.helpers import ensure_dir
|
from nanobot.utils.helpers import ensure_dir
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.session.manager import Session
|
||||||
|
|
||||||
|
|
||||||
|
_SAVE_MEMORY_TOOL = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "save_memory",
|
||||||
|
"description": "Save the memory consolidation result to persistent storage.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"history_entry": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
||||||
|
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||||
|
},
|
||||||
|
"memory_update": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Full updated long-term memory as markdown. Include all existing "
|
||||||
|
"facts plus new ones. Return unchanged if nothing new.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["history_entry", "memory_update"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
@@ -28,3 +65,86 @@ class MemoryStore:
|
|||||||
def get_memory_context(self) -> str:
|
def get_memory_context(self) -> str:
|
||||||
long_term = self.read_long_term()
|
long_term = self.read_long_term()
|
||||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||||
|
|
||||||
|
async def consolidate(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
archive_all: bool = False,
|
||||||
|
memory_window: int = 50,
|
||||||
|
) -> bool:
|
||||||
|
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||||
|
|
||||||
|
Returns True on success (including no-op), False on failure.
|
||||||
|
"""
|
||||||
|
if archive_all:
|
||||||
|
old_messages = session.messages
|
||||||
|
keep_count = 0
|
||||||
|
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
||||||
|
else:
|
||||||
|
keep_count = memory_window // 2
|
||||||
|
if len(session.messages) <= keep_count:
|
||||||
|
return True
|
||||||
|
if len(session.messages) - session.last_consolidated <= 0:
|
||||||
|
return True
|
||||||
|
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||||
|
if not old_messages:
|
||||||
|
return True
|
||||||
|
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for m in old_messages:
|
||||||
|
if not m.get("content"):
|
||||||
|
continue
|
||||||
|
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||||
|
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
||||||
|
|
||||||
|
current_memory = self.read_long_term()
|
||||||
|
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||||
|
|
||||||
|
## Current Long-term Memory
|
||||||
|
{current_memory or "(empty)"}
|
||||||
|
|
||||||
|
## Conversation to Process
|
||||||
|
{chr(10).join(lines)}"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await provider.chat(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.has_tool_calls:
|
||||||
|
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||||
|
return False
|
||||||
|
|
||||||
|
args = response.tool_calls[0].arguments
|
||||||
|
# Some providers return arguments as a JSON string instead of dict
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
if not isinstance(args, dict):
|
||||||
|
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if entry := args.get("history_entry"):
|
||||||
|
if not isinstance(entry, str):
|
||||||
|
entry = json.dumps(entry, ensure_ascii=False)
|
||||||
|
self.append_history(entry)
|
||||||
|
if update := args.get("memory_update"):
|
||||||
|
if not isinstance(update, str):
|
||||||
|
update = json.dumps(update, ensure_ascii=False)
|
||||||
|
if update != current_memory:
|
||||||
|
self.write_long_term(update)
|
||||||
|
|
||||||
|
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||||
|
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Memory consolidation failed")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -18,13 +18,7 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
|||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
"""
|
"""Manages background subagent execution."""
|
||||||
Manages background subagent execution.
|
|
||||||
|
|
||||||
Subagents are lightweight agent instances that run in the background
|
|
||||||
to handle specific tasks. They share the same LLM provider but have
|
|
||||||
isolated context and a focused system prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -49,6 +43,7 @@ class SubagentManager:
|
|||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self,
|
self,
|
||||||
@@ -56,37 +51,30 @@ class SubagentManager:
|
|||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
origin_channel: str = "cli",
|
origin_channel: str = "cli",
|
||||||
origin_chat_id: str = "direct",
|
origin_chat_id: str = "direct",
|
||||||
|
session_key: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Spawn a subagent to execute a task in the background."""
|
||||||
Spawn a subagent to execute a task in the background.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The task description for the subagent.
|
|
||||||
label: Optional human-readable label for the task.
|
|
||||||
origin_channel: The channel to announce results to.
|
|
||||||
origin_chat_id: The chat ID to announce results to.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Status message indicating the subagent was started.
|
|
||||||
"""
|
|
||||||
task_id = str(uuid.uuid4())[:8]
|
task_id = str(uuid.uuid4())[:8]
|
||||||
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
||||||
|
origin = {"channel": origin_channel, "chat_id": origin_chat_id}
|
||||||
origin = {
|
|
||||||
"channel": origin_channel,
|
|
||||||
"chat_id": origin_chat_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create background task
|
|
||||||
bg_task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
self._run_subagent(task_id, task, display_label, origin)
|
self._run_subagent(task_id, task, display_label, origin)
|
||||||
)
|
)
|
||||||
self._running_tasks[task_id] = bg_task
|
self._running_tasks[task_id] = bg_task
|
||||||
|
if session_key:
|
||||||
|
self._session_tasks.setdefault(session_key, set()).add(task_id)
|
||||||
|
|
||||||
|
def _cleanup(_: asyncio.Task) -> None:
|
||||||
|
self._running_tasks.pop(task_id, None)
|
||||||
|
if session_key and (ids := self._session_tasks.get(session_key)):
|
||||||
|
ids.discard(task_id)
|
||||||
|
if not ids:
|
||||||
|
del self._session_tasks[session_key]
|
||||||
|
|
||||||
|
bg_task.add_done_callback(_cleanup)
|
||||||
|
|
||||||
# Cleanup when done
|
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||||
bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None))
|
|
||||||
|
|
||||||
logger.info(f"Spawned subagent [{task_id}]: {display_label}")
|
|
||||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||||
|
|
||||||
async def _run_subagent(
|
async def _run_subagent(
|
||||||
@@ -97,20 +85,21 @@ class SubagentManager:
|
|||||||
origin: dict[str, str],
|
origin: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the subagent task and announce the result."""
|
"""Execute the subagent task and announce the result."""
|
||||||
logger.info(f"Subagent [{task_id}] starting task: {label}")
|
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build subagent tools (no message tool, no spawn tool)
|
# Build subagent tools (no message tool, no spawn tool)
|
||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||||
tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
tools.register(EditFileTool(allowed_dir=allowed_dir))
|
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
tools.register(ListDirTool(allowed_dir=allowed_dir))
|
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
tools.register(ExecTool(
|
tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
|
path_append=self.exec_config.path_append,
|
||||||
))
|
))
|
||||||
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
tools.register(WebFetchTool())
|
tools.register(WebFetchTool())
|
||||||
@@ -146,7 +135,7 @@ class SubagentManager:
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc.name,
|
"name": tc.name,
|
||||||
"arguments": json.dumps(tc.arguments),
|
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
@@ -159,8 +148,8 @@ class SubagentManager:
|
|||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
args_str = json.dumps(tool_call.arguments)
|
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||||
logger.debug(f"Subagent [{task_id}] executing: {tool_call.name} with arguments: {args_str}")
|
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
@@ -175,12 +164,12 @@ class SubagentManager:
|
|||||||
if final_result is None:
|
if final_result is None:
|
||||||
final_result = "Task completed but no final response was generated."
|
final_result = "Task completed but no final response was generated."
|
||||||
|
|
||||||
logger.info(f"Subagent [{task_id}] completed successfully")
|
logger.info("Subagent [{}] completed successfully", task_id)
|
||||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error: {str(e)}"
|
error_msg = f"Error: {str(e)}"
|
||||||
logger.error(f"Subagent [{task_id}] failed: {e}")
|
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||||
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
||||||
|
|
||||||
async def _announce_result(
|
async def _announce_result(
|
||||||
@@ -213,7 +202,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
logger.debug(f"Subagent [{task_id}] announced result to {origin['channel']}:{origin['chat_id']}")
|
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||||
|
|
||||||
def _build_subagent_prompt(self, task: str) -> str:
|
def _build_subagent_prompt(self, task: str) -> str:
|
||||||
"""Build a focused system prompt for the subagent."""
|
"""Build a focused system prompt for the subagent."""
|
||||||
@@ -252,6 +241,16 @@ Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed
|
|||||||
|
|
||||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
When you have completed the task, provide a clear summary of your findings or actions."""
|
||||||
|
|
||||||
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
|
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||||
|
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||||
|
if tid in self._running_tasks and not self._running_tasks[tid].done()]
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
return len(tasks)
|
||||||
|
|
||||||
def get_running_count(self) -> int:
|
def get_running_count(self) -> int:
|
||||||
"""Return the number of currently running subagents."""
|
"""Return the number of currently running subagents."""
|
||||||
return len(self._running_tasks)
|
return len(self._running_tasks)
|
||||||
|
|||||||
@@ -1,23 +1,31 @@
|
|||||||
"""File system tools: read, write, edit."""
|
"""File system tools: read, write, edit."""
|
||||||
|
|
||||||
|
import difflib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(path: str, allowed_dir: Path | None = None) -> Path:
|
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
|
||||||
"""Resolve path and optionally enforce directory restriction."""
|
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||||
resolved = Path(path).expanduser().resolve()
|
p = Path(path).expanduser()
|
||||||
if allowed_dir and not str(resolved).startswith(str(allowed_dir.resolve())):
|
if not p.is_absolute() and workspace:
|
||||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
p = workspace / p
|
||||||
|
resolved = p.resolve()
|
||||||
|
if allowed_dir:
|
||||||
|
try:
|
||||||
|
resolved.relative_to(allowed_dir.resolve())
|
||||||
|
except ValueError:
|
||||||
|
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
class ReadFileTool(Tool):
|
class ReadFileTool(Tool):
|
||||||
"""Tool to read file contents."""
|
"""Tool to read file contents."""
|
||||||
|
|
||||||
def __init__(self, allowed_dir: Path | None = None):
|
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||||
|
self._workspace = workspace
|
||||||
self._allowed_dir = allowed_dir
|
self._allowed_dir = allowed_dir
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -43,12 +51,12 @@ class ReadFileTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
return f"Error: File not found: {path}"
|
return f"Error: File not found: {path}"
|
||||||
if not file_path.is_file():
|
if not file_path.is_file():
|
||||||
return f"Error: Not a file: {path}"
|
return f"Error: Not a file: {path}"
|
||||||
|
|
||||||
content = file_path.read_text(encoding="utf-8")
|
content = file_path.read_text(encoding="utf-8")
|
||||||
return content
|
return content
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
@@ -59,8 +67,9 @@ class ReadFileTool(Tool):
|
|||||||
|
|
||||||
class WriteFileTool(Tool):
|
class WriteFileTool(Tool):
|
||||||
"""Tool to write content to a file."""
|
"""Tool to write content to a file."""
|
||||||
|
|
||||||
def __init__(self, allowed_dir: Path | None = None):
|
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||||
|
self._workspace = workspace
|
||||||
self._allowed_dir = allowed_dir
|
self._allowed_dir = allowed_dir
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -90,10 +99,10 @@ class WriteFileTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
file_path.write_text(content, encoding="utf-8")
|
file_path.write_text(content, encoding="utf-8")
|
||||||
return f"Successfully wrote {len(content)} bytes to {path}"
|
return f"Successfully wrote {len(content)} bytes to {file_path}"
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -102,8 +111,9 @@ class WriteFileTool(Tool):
|
|||||||
|
|
||||||
class EditFileTool(Tool):
|
class EditFileTool(Tool):
|
||||||
"""Tool to edit a file by replacing text."""
|
"""Tool to edit a file by replacing text."""
|
||||||
|
|
||||||
def __init__(self, allowed_dir: Path | None = None):
|
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||||
|
self._workspace = workspace
|
||||||
self._allowed_dir = allowed_dir
|
self._allowed_dir = allowed_dir
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -137,34 +147,57 @@ class EditFileTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
file_path = _resolve_path(path, self._allowed_dir)
|
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
if not file_path.exists():
|
if not file_path.exists():
|
||||||
return f"Error: File not found: {path}"
|
return f"Error: File not found: {path}"
|
||||||
|
|
||||||
content = file_path.read_text(encoding="utf-8")
|
content = file_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
if old_text not in content:
|
if old_text not in content:
|
||||||
return f"Error: old_text not found in file. Make sure it matches exactly."
|
return self._not_found_message(old_text, content, path)
|
||||||
|
|
||||||
# Count occurrences
|
# Count occurrences
|
||||||
count = content.count(old_text)
|
count = content.count(old_text)
|
||||||
if count > 1:
|
if count > 1:
|
||||||
return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
|
return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
|
||||||
|
|
||||||
new_content = content.replace(old_text, new_text, 1)
|
new_content = content.replace(old_text, new_text, 1)
|
||||||
file_path.write_text(new_content, encoding="utf-8")
|
file_path.write_text(new_content, encoding="utf-8")
|
||||||
|
|
||||||
return f"Successfully edited {path}"
|
return f"Successfully edited {file_path}"
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error editing file: {str(e)}"
|
return f"Error editing file: {str(e)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _not_found_message(old_text: str, content: str, path: str) -> str:
|
||||||
|
"""Build a helpful error when old_text is not found."""
|
||||||
|
lines = content.splitlines(keepends=True)
|
||||||
|
old_lines = old_text.splitlines(keepends=True)
|
||||||
|
window = len(old_lines)
|
||||||
|
|
||||||
|
best_ratio, best_start = 0.0, 0
|
||||||
|
for i in range(max(1, len(lines) - window + 1)):
|
||||||
|
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||||
|
if ratio > best_ratio:
|
||||||
|
best_ratio, best_start = ratio, i
|
||||||
|
|
||||||
|
if best_ratio > 0.5:
|
||||||
|
diff = "\n".join(difflib.unified_diff(
|
||||||
|
old_lines, lines[best_start : best_start + window],
|
||||||
|
fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})",
|
||||||
|
lineterm="",
|
||||||
|
))
|
||||||
|
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||||
|
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||||
|
|
||||||
|
|
||||||
class ListDirTool(Tool):
|
class ListDirTool(Tool):
|
||||||
"""Tool to list directory contents."""
|
"""Tool to list directory contents."""
|
||||||
|
|
||||||
def __init__(self, allowed_dir: Path | None = None):
|
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||||
|
self._workspace = workspace
|
||||||
self._allowed_dir = allowed_dir
|
self._allowed_dir = allowed_dir
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -190,20 +223,20 @@ class ListDirTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
dir_path = _resolve_path(path, self._allowed_dir)
|
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||||
if not dir_path.exists():
|
if not dir_path.exists():
|
||||||
return f"Error: Directory not found: {path}"
|
return f"Error: Directory not found: {path}"
|
||||||
if not dir_path.is_dir():
|
if not dir_path.is_dir():
|
||||||
return f"Error: Not a directory: {path}"
|
return f"Error: Not a directory: {path}"
|
||||||
|
|
||||||
items = []
|
items = []
|
||||||
for item in sorted(dir_path.iterdir()):
|
for item in sorted(dir_path.iterdir()):
|
||||||
prefix = "📁 " if item.is_dir() else "📄 "
|
prefix = "📁 " if item.is_dir() else "📄 "
|
||||||
items.append(f"{prefix}{item.name}")
|
items.append(f"{prefix}{item.name}")
|
||||||
|
|
||||||
if not items:
|
if not items:
|
||||||
return f"Directory {path} is empty"
|
return f"Directory {path} is empty"
|
||||||
|
|
||||||
return "\n".join(items)
|
return "\n".join(items)
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
|
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
@@ -12,12 +14,13 @@ from nanobot.agent.tools.registry import ToolRegistry
|
|||||||
class MCPToolWrapper(Tool):
|
class MCPToolWrapper(Tool):
|
||||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||||
|
|
||||||
def __init__(self, session, server_name: str, tool_def):
|
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||||
self._session = session
|
self._session = session
|
||||||
self._original_name = tool_def.name
|
self._original_name = tool_def.name
|
||||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||||
self._description = tool_def.description or tool_def.name
|
self._description = tool_def.description or tool_def.name
|
||||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||||
|
self._tool_timeout = tool_timeout
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -33,7 +36,14 @@ class MCPToolWrapper(Tool):
|
|||||||
|
|
||||||
async def execute(self, **kwargs: Any) -> str:
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
from mcp import types
|
from mcp import types
|
||||||
result = await self._session.call_tool(self._original_name, arguments=kwargs)
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||||
|
timeout=self._tool_timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||||
|
return f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||||
parts = []
|
parts = []
|
||||||
for block in result.content:
|
for block in result.content:
|
||||||
if isinstance(block, types.TextContent):
|
if isinstance(block, types.TextContent):
|
||||||
@@ -59,11 +69,20 @@ async def connect_mcp_servers(
|
|||||||
read, write = await stack.enter_async_context(stdio_client(params))
|
read, write = await stack.enter_async_context(stdio_client(params))
|
||||||
elif cfg.url:
|
elif cfg.url:
|
||||||
from mcp.client.streamable_http import streamable_http_client
|
from mcp.client.streamable_http import streamable_http_client
|
||||||
|
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||||
|
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||||
|
http_client = await stack.enter_async_context(
|
||||||
|
httpx.AsyncClient(
|
||||||
|
headers=cfg.headers or None,
|
||||||
|
follow_redirects=True,
|
||||||
|
timeout=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
read, write, _ = await stack.enter_async_context(
|
read, write, _ = await stack.enter_async_context(
|
||||||
streamable_http_client(cfg.url)
|
streamable_http_client(cfg.url, http_client=http_client)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"MCP server '{name}': no command or url configured, skipping")
|
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
session = await stack.enter_async_context(ClientSession(read, write))
|
session = await stack.enter_async_context(ClientSession(read, write))
|
||||||
@@ -71,10 +90,10 @@ async def connect_mcp_servers(
|
|||||||
|
|
||||||
tools = await session.list_tools()
|
tools = await session.list_tools()
|
||||||
for tool_def in tools.tools:
|
for tool_def in tools.tools:
|
||||||
wrapper = MCPToolWrapper(session, name, tool_def)
|
wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
|
||||||
registry.register(wrapper)
|
registry.register(wrapper)
|
||||||
logger.debug(f"MCP: registered tool '{wrapper.name}' from server '{name}'")
|
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||||
|
|
||||||
logger.info(f"MCP server '{name}': connected, {len(tools.tools)} tools registered")
|
logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"MCP server '{name}': failed to connect: {e}")
|
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Message tool for sending messages to users."""
|
"""Message tool for sending messages to users."""
|
||||||
|
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
@@ -8,34 +8,42 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
|
|
||||||
class MessageTool(Tool):
|
class MessageTool(Tool):
|
||||||
"""Tool to send messages to users on chat channels."""
|
"""Tool to send messages to users on chat channels."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||||
default_channel: str = "",
|
default_channel: str = "",
|
||||||
default_chat_id: str = ""
|
default_chat_id: str = "",
|
||||||
|
default_message_id: str | None = None,
|
||||||
):
|
):
|
||||||
self._send_callback = send_callback
|
self._send_callback = send_callback
|
||||||
self._default_channel = default_channel
|
self._default_channel = default_channel
|
||||||
self._default_chat_id = default_chat_id
|
self._default_chat_id = default_chat_id
|
||||||
|
self._default_message_id = default_message_id
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
self._sent_in_turn: bool = False
|
||||||
|
|
||||||
|
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Set the current message context."""
|
"""Set the current message context."""
|
||||||
self._default_channel = channel
|
self._default_channel = channel
|
||||||
self._default_chat_id = chat_id
|
self._default_chat_id = chat_id
|
||||||
|
self._default_message_id = message_id
|
||||||
|
|
||||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||||
"""Set the callback for sending messages."""
|
"""Set the callback for sending messages."""
|
||||||
self._send_callback = callback
|
self._send_callback = callback
|
||||||
|
|
||||||
|
def start_turn(self) -> None:
|
||||||
|
"""Reset per-turn send tracking."""
|
||||||
|
self._sent_in_turn = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "message"
|
return "message"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Send a message to the user. Use this when you want to communicate something."
|
return "Send a message to the user. Use this when you want to communicate something."
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@@ -61,33 +69,39 @@ class MessageTool(Tool):
|
|||||||
},
|
},
|
||||||
"required": ["content"]
|
"required": ["content"]
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
channel: str | None = None,
|
channel: str | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
media: list[str] | None = None,
|
media: list[str] | None = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> str:
|
) -> str:
|
||||||
channel = channel or self._default_channel
|
channel = channel or self._default_channel
|
||||||
chat_id = chat_id or self._default_chat_id
|
chat_id = chat_id or self._default_chat_id
|
||||||
|
message_id = message_id or self._default_message_id
|
||||||
|
|
||||||
if not channel or not chat_id:
|
if not channel or not chat_id:
|
||||||
return "Error: No target channel/chat specified"
|
return "Error: No target channel/chat specified"
|
||||||
|
|
||||||
if not self._send_callback:
|
if not self._send_callback:
|
||||||
return "Error: Message sending not configured"
|
return "Error: Message sending not configured"
|
||||||
|
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
media=media or []
|
media=media or [],
|
||||||
|
metadata={
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
|
self._sent_in_turn = True
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -36,30 +36,23 @@ class ToolRegistry:
|
|||||||
return [tool.to_schema() for tool in self._tools.values()]
|
return [tool.to_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||||
"""
|
"""Execute a tool by name with given parameters."""
|
||||||
Execute a tool by name with given parameters.
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Tool name.
|
|
||||||
params: Tool parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tool execution result as string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If tool not found.
|
|
||||||
"""
|
|
||||||
tool = self._tools.get(name)
|
tool = self._tools.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
return f"Error: Tool '{name}' not found"
|
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
errors = tool.validate_params(params)
|
errors = tool.validate_params(params)
|
||||||
if errors:
|
if errors:
|
||||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
|
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||||
return await tool.execute(**params)
|
result = await tool.execute(**params)
|
||||||
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
|
return result + _HINT
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing {name}: {str(e)}"
|
return f"Error executing {name}: {str(e)}" + _HINT
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_names(self) -> list[str]:
|
def tool_names(self) -> list[str]:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class ExecTool(Tool):
|
|||||||
deny_patterns: list[str] | None = None,
|
deny_patterns: list[str] | None = None,
|
||||||
allow_patterns: list[str] | None = None,
|
allow_patterns: list[str] | None = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
|
path_append: str = "",
|
||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
@@ -35,6 +36,7 @@ class ExecTool(Tool):
|
|||||||
]
|
]
|
||||||
self.allow_patterns = allow_patterns or []
|
self.allow_patterns = allow_patterns or []
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
self.path_append = path_append
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -67,12 +69,17 @@ class ExecTool(Tool):
|
|||||||
if guard_error:
|
if guard_error:
|
||||||
return guard_error
|
return guard_error
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
if self.path_append:
|
||||||
|
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
command,
|
command,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -9,22 +9,19 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class SpawnTool(Tool):
|
class SpawnTool(Tool):
|
||||||
"""
|
"""Tool to spawn a subagent for background task execution."""
|
||||||
Tool to spawn a subagent for background task execution.
|
|
||||||
|
|
||||||
The subagent runs asynchronously and announces its result back
|
|
||||||
to the main agent when complete.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, manager: "SubagentManager"):
|
def __init__(self, manager: "SubagentManager"):
|
||||||
self._manager = manager
|
self._manager = manager
|
||||||
self._origin_channel = "cli"
|
self._origin_channel = "cli"
|
||||||
self._origin_chat_id = "direct"
|
self._origin_chat_id = "direct"
|
||||||
|
self._session_key = "cli:direct"
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
def set_context(self, channel: str, chat_id: str) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
self._origin_channel = channel
|
self._origin_channel = channel
|
||||||
self._origin_chat_id = chat_id
|
self._origin_chat_id = chat_id
|
||||||
|
self._session_key = f"{channel}:{chat_id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -62,4 +59,5 @@ class SpawnTool(Tool):
|
|||||||
label=label,
|
label=label,
|
||||||
origin_channel=self._origin_channel,
|
origin_channel=self._origin_channel,
|
||||||
origin_chat_id=self._origin_chat_id,
|
origin_chat_id=self._origin_chat_id,
|
||||||
|
session_key=self._session_key,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -58,12 +58,21 @@ class WebSearchTool(Tool):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
||||||
self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
|
self._init_api_key = api_key
|
||||||
self.max_results = max_results
|
self.max_results = max_results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_key(self) -> str:
|
||||||
|
"""Resolve API key at call time so env/config changes are picked up."""
|
||||||
|
return self._init_api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
return "Error: BRAVE_API_KEY not configured"
|
return (
|
||||||
|
"Error: Brave Search API key not configured. "
|
||||||
|
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
|
||||||
|
"(or export BRAVE_API_KEY), then restart the gateway."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
n = min(max(count or self.max_results, 1), 10)
|
n = min(max(count or self.max_results, 1), 10)
|
||||||
@@ -71,7 +80,7 @@ class WebSearchTool(Tool):
|
|||||||
r = await client.get(
|
r = await client.get(
|
||||||
"https://api.search.brave.com/res/v1/web/search",
|
"https://api.search.brave.com/res/v1/web/search",
|
||||||
params={"q": query, "count": n},
|
params={"q": query, "count": n},
|
||||||
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
headers={"Accept": "application/json", "X-Subscription-Token": api_key},
|
||||||
timeout=10.0
|
timeout=10.0
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
@@ -116,7 +125,7 @@ class WebFetchTool(Tool):
|
|||||||
# Validate URL before fetching
|
# Validate URL before fetching
|
||||||
is_valid, error_msg = _validate_url(url)
|
is_valid, error_msg = _validate_url(url)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url})
|
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
@@ -131,7 +140,7 @@ class WebFetchTool(Tool):
|
|||||||
|
|
||||||
# JSON
|
# JSON
|
||||||
if "application/json" in ctype:
|
if "application/json" in ctype:
|
||||||
text, extractor = json.dumps(r.json(), indent=2), "json"
|
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
||||||
# HTML
|
# HTML
|
||||||
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
|
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
|
||||||
doc = Document(r.text)
|
doc = Document(r.text)
|
||||||
@@ -146,9 +155,9 @@ class WebFetchTool(Tool):
|
|||||||
text = text[:max_chars]
|
text = text[:max_chars]
|
||||||
|
|
||||||
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
||||||
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text})
|
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return json.dumps({"error": str(e), "url": url})
|
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||||
|
|
||||||
def _to_markdown(self, html: str) -> str:
|
def _to_markdown(self, html: str) -> str:
|
||||||
"""Convert HTML to markdown."""
|
"""Convert HTML to markdown."""
|
||||||
|
|||||||
@@ -16,11 +16,12 @@ class InboundMessage:
|
|||||||
timestamp: datetime = field(default_factory=datetime.now)
|
timestamp: datetime = field(default_factory=datetime.now)
|
||||||
media: list[str] = field(default_factory=list) # Media URLs
|
media: list[str] = field(default_factory=list) # Media URLs
|
||||||
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
||||||
|
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session_key(self) -> str:
|
def session_key(self) -> str:
|
||||||
"""Unique key for session identification."""
|
"""Unique key for session identification."""
|
||||||
return f"{self.channel}:{self.chat_id}"
|
return self.session_key_override or f"{self.channel}:{self.chat_id}"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
"""Async message queue for decoupled channel-agent communication."""
|
"""Async message queue for decoupled channel-agent communication."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Callable, Awaitable
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
@@ -11,70 +8,36 @@ from nanobot.bus.events import InboundMessage, OutboundMessage
|
|||||||
class MessageBus:
|
class MessageBus:
|
||||||
"""
|
"""
|
||||||
Async message bus that decouples chat channels from the agent core.
|
Async message bus that decouples chat channels from the agent core.
|
||||||
|
|
||||||
Channels push messages to the inbound queue, and the agent processes
|
Channels push messages to the inbound queue, and the agent processes
|
||||||
them and pushes responses to the outbound queue.
|
them and pushes responses to the outbound queue.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||||
self._outbound_subscribers: dict[str, list[Callable[[OutboundMessage], Awaitable[None]]]] = {}
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
async def publish_inbound(self, msg: InboundMessage) -> None:
|
async def publish_inbound(self, msg: InboundMessage) -> None:
|
||||||
"""Publish a message from a channel to the agent."""
|
"""Publish a message from a channel to the agent."""
|
||||||
await self.inbound.put(msg)
|
await self.inbound.put(msg)
|
||||||
|
|
||||||
async def consume_inbound(self) -> InboundMessage:
|
async def consume_inbound(self) -> InboundMessage:
|
||||||
"""Consume the next inbound message (blocks until available)."""
|
"""Consume the next inbound message (blocks until available)."""
|
||||||
return await self.inbound.get()
|
return await self.inbound.get()
|
||||||
|
|
||||||
async def publish_outbound(self, msg: OutboundMessage) -> None:
|
async def publish_outbound(self, msg: OutboundMessage) -> None:
|
||||||
"""Publish a response from the agent to channels."""
|
"""Publish a response from the agent to channels."""
|
||||||
await self.outbound.put(msg)
|
await self.outbound.put(msg)
|
||||||
|
|
||||||
async def consume_outbound(self) -> OutboundMessage:
|
async def consume_outbound(self) -> OutboundMessage:
|
||||||
"""Consume the next outbound message (blocks until available)."""
|
"""Consume the next outbound message (blocks until available)."""
|
||||||
return await self.outbound.get()
|
return await self.outbound.get()
|
||||||
|
|
||||||
def subscribe_outbound(
|
|
||||||
self,
|
|
||||||
channel: str,
|
|
||||||
callback: Callable[[OutboundMessage], Awaitable[None]]
|
|
||||||
) -> None:
|
|
||||||
"""Subscribe to outbound messages for a specific channel."""
|
|
||||||
if channel not in self._outbound_subscribers:
|
|
||||||
self._outbound_subscribers[channel] = []
|
|
||||||
self._outbound_subscribers[channel].append(callback)
|
|
||||||
|
|
||||||
async def dispatch_outbound(self) -> None:
|
|
||||||
"""
|
|
||||||
Dispatch outbound messages to subscribed channels.
|
|
||||||
Run this as a background task.
|
|
||||||
"""
|
|
||||||
self._running = True
|
|
||||||
while self._running:
|
|
||||||
try:
|
|
||||||
msg = await asyncio.wait_for(self.outbound.get(), timeout=1.0)
|
|
||||||
subscribers = self._outbound_subscribers.get(msg.channel, [])
|
|
||||||
for callback in subscribers:
|
|
||||||
try:
|
|
||||||
await callback(msg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error dispatching to {msg.channel}: {e}")
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the dispatcher loop."""
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inbound_size(self) -> int:
|
def inbound_size(self) -> int:
|
||||||
"""Number of pending inbound messages."""
|
"""Number of pending inbound messages."""
|
||||||
return self.inbound.qsize()
|
return self.inbound.qsize()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def outbound_size(self) -> int:
|
def outbound_size(self) -> int:
|
||||||
"""Number of pending outbound messages."""
|
"""Number of pending outbound messages."""
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ class BaseChannel(ABC):
|
|||||||
chat_id: str,
|
chat_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
media: list[str] | None = None,
|
media: list[str] | None = None,
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None,
|
||||||
|
session_key: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Handle an incoming message from the chat platform.
|
Handle an incoming message from the chat platform.
|
||||||
@@ -102,11 +103,13 @@ class BaseChannel(ABC):
|
|||||||
content: Message text content.
|
content: Message text content.
|
||||||
media: Optional list of media URLs.
|
media: Optional list of media URLs.
|
||||||
metadata: Optional channel-specific metadata.
|
metadata: Optional channel-specific metadata.
|
||||||
|
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||||
"""
|
"""
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Access denied for sender {sender_id} on channel {self.name}. "
|
"Access denied for sender {} on channel {}. "
|
||||||
f"Add them to allowFrom list in config to grant access."
|
"Add them to allowFrom list in config to grant access.",
|
||||||
|
sender_id, self.name,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -116,7 +119,8 @@ class BaseChannel(ABC):
|
|||||||
chat_id=str(chat_id),
|
chat_id=str(chat_id),
|
||||||
content=content,
|
content=content,
|
||||||
media=media or [],
|
media=media or [],
|
||||||
metadata=metadata or {}
|
metadata=metadata or {},
|
||||||
|
session_key_override=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
|
|||||||
@@ -58,14 +58,15 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Received empty or unsupported message type: {chatbot_msg.message_type}"
|
"Received empty or unsupported message type: {}",
|
||||||
|
chatbot_msg.message_type,
|
||||||
)
|
)
|
||||||
return AckMessage.STATUS_OK, "OK"
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||||
|
|
||||||
logger.info(f"Received DingTalk message from {sender_name} ({sender_id}): {content}")
|
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||||
|
|
||||||
# Forward to Nanobot via _on_message (non-blocking).
|
# Forward to Nanobot via _on_message (non-blocking).
|
||||||
# Store reference to prevent GC before task completes.
|
# Store reference to prevent GC before task completes.
|
||||||
@@ -78,7 +79,7 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
return AckMessage.STATUS_OK, "OK"
|
return AckMessage.STATUS_OK, "OK"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing DingTalk message: {e}")
|
logger.error("Error processing DingTalk message: {}", e)
|
||||||
# Return OK to avoid retry loop from DingTalk server
|
# Return OK to avoid retry loop from DingTalk server
|
||||||
return AckMessage.STATUS_OK, "Error"
|
return AckMessage.STATUS_OK, "Error"
|
||||||
|
|
||||||
@@ -126,7 +127,8 @@ class DingTalkChannel(BaseChannel):
|
|||||||
self._http = httpx.AsyncClient()
|
self._http = httpx.AsyncClient()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initializing DingTalk Stream Client with Client ID: {self.config.client_id}..."
|
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||||
|
self.config.client_id,
|
||||||
)
|
)
|
||||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||||
self._client = DingTalkStreamClient(credential)
|
self._client = DingTalkStreamClient(credential)
|
||||||
@@ -142,13 +144,13 @@ class DingTalkChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._client.start()
|
await self._client.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"DingTalk stream error: {e}")
|
logger.warning("DingTalk stream error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Failed to start DingTalk channel: {e}")
|
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the DingTalk bot."""
|
"""Stop the DingTalk bot."""
|
||||||
@@ -186,7 +188,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||||
return self._access_token
|
return self._access_token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get DingTalk access token: {e}")
|
logger.error("Failed to get DingTalk access token: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
@@ -208,7 +210,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"msgParam": json.dumps({
|
"msgParam": json.dumps({
|
||||||
"text": msg.content,
|
"text": msg.content,
|
||||||
"title": "Nanobot Reply",
|
"title": "Nanobot Reply",
|
||||||
}),
|
}, ensure_ascii=False),
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self._http:
|
if not self._http:
|
||||||
@@ -218,11 +220,11 @@ class DingTalkChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
resp = await self._http.post(url, json=data, headers=headers)
|
resp = await self._http.post(url, json=data, headers=headers)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
logger.error(f"DingTalk send failed: {resp.text}")
|
logger.error("DingTalk send failed: {}", resp.text)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"DingTalk message sent to {msg.chat_id}")
|
logger.debug("DingTalk message sent to {}", msg.chat_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending DingTalk message: {e}")
|
logger.error("Error sending DingTalk message: {}", e)
|
||||||
|
|
||||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||||
@@ -231,7 +233,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
permission checks before publishing to the bus.
|
permission checks before publishing to the bus.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"DingTalk inbound: {content} from {sender_name}")
|
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
chat_id=sender_id, # For private chat, chat_id == sender_id
|
||||||
@@ -242,4 +244,4 @@ class DingTalkChannel(BaseChannel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error publishing DingTalk message: {e}")
|
logger.error("Error publishing DingTalk message: {}", e)
|
||||||
|
|||||||
@@ -17,6 +17,29 @@ from nanobot.config.schema import DiscordConfig
|
|||||||
|
|
||||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||||
|
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||||
|
|
||||||
|
|
||||||
|
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
|
||||||
|
"""Split content into chunks within max_len, preferring line breaks."""
|
||||||
|
if not content:
|
||||||
|
return []
|
||||||
|
if len(content) <= max_len:
|
||||||
|
return [content]
|
||||||
|
chunks: list[str] = []
|
||||||
|
while content:
|
||||||
|
if len(content) <= max_len:
|
||||||
|
chunks.append(content)
|
||||||
|
break
|
||||||
|
cut = content[:max_len]
|
||||||
|
pos = cut.rfind('\n')
|
||||||
|
if pos <= 0:
|
||||||
|
pos = cut.rfind(' ')
|
||||||
|
if pos <= 0:
|
||||||
|
pos = max_len
|
||||||
|
chunks.append(content[:pos])
|
||||||
|
content = content[pos:].lstrip()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(BaseChannel):
|
class DiscordChannel(BaseChannel):
|
||||||
@@ -51,7 +74,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Discord gateway error: {e}")
|
logger.warning("Discord gateway error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
@@ -79,34 +102,48 @@ class DiscordChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||||
payload: dict[str, Any] = {"content": msg.content}
|
|
||||||
|
|
||||||
if msg.reply_to:
|
|
||||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
|
||||||
payload["allowed_mentions"] = {"replied_user": False}
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for attempt in range(3):
|
chunks = _split_message(msg.content or "")
|
||||||
try:
|
if not chunks:
|
||||||
response = await self._http.post(url, headers=headers, json=payload)
|
return
|
||||||
if response.status_code == 429:
|
|
||||||
data = response.json()
|
for i, chunk in enumerate(chunks):
|
||||||
retry_after = float(data.get("retry_after", 1.0))
|
payload: dict[str, Any] = {"content": chunk}
|
||||||
logger.warning(f"Discord rate limited, retrying in {retry_after}s")
|
|
||||||
await asyncio.sleep(retry_after)
|
# Only set reply reference on the first chunk
|
||||||
continue
|
if i == 0 and msg.reply_to:
|
||||||
response.raise_for_status()
|
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||||
return
|
payload["allowed_mentions"] = {"replied_user": False}
|
||||||
except Exception as e:
|
|
||||||
if attempt == 2:
|
if not await self._send_payload(url, headers, payload):
|
||||||
logger.error(f"Error sending Discord message: {e}")
|
break # Abort remaining chunks on failure
|
||||||
else:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
finally:
|
finally:
|
||||||
await self._stop_typing(msg.chat_id)
|
await self._stop_typing(msg.chat_id)
|
||||||
|
|
||||||
|
async def _send_payload(
|
||||||
|
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||||
|
for attempt in range(3):
|
||||||
|
try:
|
||||||
|
response = await self._http.post(url, headers=headers, json=payload)
|
||||||
|
if response.status_code == 429:
|
||||||
|
data = response.json()
|
||||||
|
retry_after = float(data.get("retry_after", 1.0))
|
||||||
|
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||||
|
await asyncio.sleep(retry_after)
|
||||||
|
continue
|
||||||
|
response.raise_for_status()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
if attempt == 2:
|
||||||
|
logger.error("Error sending Discord message: {}", e)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _gateway_loop(self) -> None:
|
async def _gateway_loop(self) -> None:
|
||||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||||
if not self._ws:
|
if not self._ws:
|
||||||
@@ -116,7 +153,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"Invalid JSON from Discord gateway: {raw[:100]}")
|
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
op = data.get("op")
|
op = data.get("op")
|
||||||
@@ -175,7 +212,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._ws.send(json.dumps(payload))
|
await self._ws.send(json.dumps(payload))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Discord heartbeat failed: {e}")
|
logger.warning("Discord heartbeat failed: {}", e)
|
||||||
break
|
break
|
||||||
await asyncio.sleep(interval_s)
|
await asyncio.sleep(interval_s)
|
||||||
|
|
||||||
@@ -219,7 +256,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
media_paths.append(str(file_path))
|
media_paths.append(str(file_path))
|
||||||
content_parts.append(f"[attachment: {file_path}]")
|
content_parts.append(f"[attachment: {file_path}]")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to download Discord attachment: {e}")
|
logger.warning("Failed to download Discord attachment: {}", e)
|
||||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||||
|
|
||||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||||
@@ -248,8 +285,11 @@ class DiscordChannel(BaseChannel):
|
|||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
await self._http.post(url, headers=headers)
|
await self._http.post(url, headers=headers)
|
||||||
except Exception:
|
except asyncio.CancelledError:
|
||||||
pass
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||||
|
return
|
||||||
await asyncio.sleep(8)
|
await asyncio.sleep(8)
|
||||||
|
|
||||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class EmailChannel(BaseChannel):
|
|||||||
metadata=item.get("metadata", {}),
|
metadata=item.get("metadata", {}),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Email polling error: {e}")
|
logger.error("Email polling error: {}", e)
|
||||||
|
|
||||||
await asyncio.sleep(poll_seconds)
|
await asyncio.sleep(poll_seconds)
|
||||||
|
|
||||||
@@ -108,11 +108,6 @@ class EmailChannel(BaseChannel):
|
|||||||
logger.warning("Skip email send: consent_granted is false")
|
logger.warning("Skip email send: consent_granted is false")
|
||||||
return
|
return
|
||||||
|
|
||||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
|
||||||
if not self.config.auto_reply_enabled and not force_send:
|
|
||||||
logger.info("Skip automatic email reply: auto_reply_enabled is false")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.config.smtp_host:
|
if not self.config.smtp_host:
|
||||||
logger.warning("Email channel SMTP host not configured")
|
logger.warning("Email channel SMTP host not configured")
|
||||||
return
|
return
|
||||||
@@ -122,6 +117,15 @@ class EmailChannel(BaseChannel):
|
|||||||
logger.warning("Email channel missing recipient address")
|
logger.warning("Email channel missing recipient address")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Determine if this is a reply (recipient has sent us an email before)
|
||||||
|
is_reply = to_addr in self._last_subject_by_chat
|
||||||
|
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||||
|
|
||||||
|
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||||
|
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||||
|
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
||||||
|
return
|
||||||
|
|
||||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||||
subject = self._reply_subject(base_subject)
|
subject = self._reply_subject(base_subject)
|
||||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||||
@@ -143,7 +147,7 @@ class EmailChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending email to {to_addr}: {e}")
|
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _validate_config(self) -> bool:
|
def _validate_config(self) -> bool:
|
||||||
@@ -162,7 +166,7 @@ class EmailChannel(BaseChannel):
|
|||||||
missing.append("smtp_password")
|
missing.append("smtp_password")
|
||||||
|
|
||||||
if missing:
|
if missing:
|
||||||
logger.error(f"Email channel not configured, missing: {', '.join(missing)}")
|
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -304,7 +308,8 @@ class EmailChannel(BaseChannel):
|
|||||||
self._processed_uids.add(uid)
|
self._processed_uids.add(uid)
|
||||||
# mark_seen is the primary dedup; this set is a safety net
|
# mark_seen is the primary dedup; this set is a safety net
|
||||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||||
self._processed_uids.clear()
|
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||||
|
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||||
|
|
||||||
if mark_seen:
|
if mark_seen:
|
||||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -27,6 +28,8 @@ try:
|
|||||||
CreateMessageReactionRequest,
|
CreateMessageReactionRequest,
|
||||||
CreateMessageReactionRequestBody,
|
CreateMessageReactionRequestBody,
|
||||||
Emoji,
|
Emoji,
|
||||||
|
GetFileRequest,
|
||||||
|
GetMessageResourceRequest,
|
||||||
P2ImMessageReceiveV1,
|
P2ImMessageReceiveV1,
|
||||||
)
|
)
|
||||||
FEISHU_AVAILABLE = True
|
FEISHU_AVAILABLE = True
|
||||||
@@ -44,21 +47,158 @@ MSG_TYPE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_post_text(content_json: dict) -> str:
|
def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||||
"""Extract plain text from Feishu post (rich text) message content.
|
"""Extract text representation from share cards and interactive messages."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if msg_type == "share_chat":
|
||||||
|
parts.append(f"[shared chat: {content_json.get('chat_id', '')}]")
|
||||||
|
elif msg_type == "share_user":
|
||||||
|
parts.append(f"[shared user: {content_json.get('user_id', '')}]")
|
||||||
|
elif msg_type == "interactive":
|
||||||
|
parts.extend(_extract_interactive_content(content_json))
|
||||||
|
elif msg_type == "share_calendar_event":
|
||||||
|
parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]")
|
||||||
|
elif msg_type == "system":
|
||||||
|
parts.append("[system message]")
|
||||||
|
elif msg_type == "merge_forward":
|
||||||
|
parts.append("[merged forward messages]")
|
||||||
|
|
||||||
|
return "\n".join(parts) if parts else f"[{msg_type}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_interactive_content(content: dict) -> list[str]:
|
||||||
|
"""Recursively extract text and links from interactive card content."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
try:
|
||||||
|
content = json.loads(content)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return [content] if content.strip() else []
|
||||||
|
|
||||||
|
if not isinstance(content, dict):
|
||||||
|
return parts
|
||||||
|
|
||||||
|
if "title" in content:
|
||||||
|
title = content["title"]
|
||||||
|
if isinstance(title, dict):
|
||||||
|
title_content = title.get("content", "") or title.get("text", "")
|
||||||
|
if title_content:
|
||||||
|
parts.append(f"title: {title_content}")
|
||||||
|
elif isinstance(title, str):
|
||||||
|
parts.append(f"title: {title}")
|
||||||
|
|
||||||
|
for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||||
|
parts.extend(_extract_element_content(element))
|
||||||
|
|
||||||
|
card = content.get("card", {})
|
||||||
|
if card:
|
||||||
|
parts.extend(_extract_interactive_content(card))
|
||||||
|
|
||||||
|
header = content.get("header", {})
|
||||||
|
if header:
|
||||||
|
header_title = header.get("title", {})
|
||||||
|
if isinstance(header_title, dict):
|
||||||
|
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||||
|
if header_text:
|
||||||
|
parts.append(f"title: {header_text}")
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_element_content(element: dict) -> list[str]:
|
||||||
|
"""Extract content from a single card element."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if not isinstance(element, dict):
|
||||||
|
return parts
|
||||||
|
|
||||||
|
tag = element.get("tag", "")
|
||||||
|
|
||||||
|
if tag in ("markdown", "lark_md"):
|
||||||
|
content = element.get("content", "")
|
||||||
|
if content:
|
||||||
|
parts.append(content)
|
||||||
|
|
||||||
|
elif tag == "div":
|
||||||
|
text = element.get("text", {})
|
||||||
|
if isinstance(text, dict):
|
||||||
|
text_content = text.get("content", "") or text.get("text", "")
|
||||||
|
if text_content:
|
||||||
|
parts.append(text_content)
|
||||||
|
elif isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
for field in element.get("fields", []):
|
||||||
|
if isinstance(field, dict):
|
||||||
|
field_text = field.get("text", {})
|
||||||
|
if isinstance(field_text, dict):
|
||||||
|
c = field_text.get("content", "")
|
||||||
|
if c:
|
||||||
|
parts.append(c)
|
||||||
|
|
||||||
|
elif tag == "a":
|
||||||
|
href = element.get("href", "")
|
||||||
|
text = element.get("text", "")
|
||||||
|
if href:
|
||||||
|
parts.append(f"link: {href}")
|
||||||
|
if text:
|
||||||
|
parts.append(text)
|
||||||
|
|
||||||
|
elif tag == "button":
|
||||||
|
text = element.get("text", {})
|
||||||
|
if isinstance(text, dict):
|
||||||
|
c = text.get("content", "")
|
||||||
|
if c:
|
||||||
|
parts.append(c)
|
||||||
|
url = element.get("url", "") or element.get("multi_url", {}).get("url", "")
|
||||||
|
if url:
|
||||||
|
parts.append(f"link: {url}")
|
||||||
|
|
||||||
|
elif tag == "img":
|
||||||
|
alt = element.get("alt", {})
|
||||||
|
parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]")
|
||||||
|
|
||||||
|
elif tag == "note":
|
||||||
|
for ne in element.get("elements", []):
|
||||||
|
parts.extend(_extract_element_content(ne))
|
||||||
|
|
||||||
|
elif tag == "column_set":
|
||||||
|
for col in element.get("columns", []):
|
||||||
|
for ce in col.get("elements", []):
|
||||||
|
parts.extend(_extract_element_content(ce))
|
||||||
|
|
||||||
|
elif tag == "plain_text":
|
||||||
|
content = element.get("content", "")
|
||||||
|
if content:
|
||||||
|
parts.append(content)
|
||||||
|
|
||||||
|
else:
|
||||||
|
for ne in element.get("elements", []):
|
||||||
|
parts.extend(_extract_element_content(ne))
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||||
|
"""Extract text and image keys from Feishu post (rich text) message content.
|
||||||
|
|
||||||
Supports two formats:
|
Supports two formats:
|
||||||
1. Direct format: {"title": "...", "content": [...]}
|
1. Direct format: {"title": "...", "content": [...]}
|
||||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(text, image_keys) - extracted text and list of image keys
|
||||||
"""
|
"""
|
||||||
def extract_from_lang(lang_content: dict) -> str | None:
|
def extract_from_lang(lang_content: dict) -> tuple[str | None, list[str]]:
|
||||||
if not isinstance(lang_content, dict):
|
if not isinstance(lang_content, dict):
|
||||||
return None
|
return None, []
|
||||||
title = lang_content.get("title", "")
|
title = lang_content.get("title", "")
|
||||||
content_blocks = lang_content.get("content", [])
|
content_blocks = lang_content.get("content", [])
|
||||||
if not isinstance(content_blocks, list):
|
if not isinstance(content_blocks, list):
|
||||||
return None
|
return None, []
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
image_keys = []
|
||||||
if title:
|
if title:
|
||||||
text_parts.append(title)
|
text_parts.append(title)
|
||||||
for block in content_blocks:
|
for block in content_blocks:
|
||||||
@@ -73,22 +213,36 @@ def _extract_post_text(content_json: dict) -> str:
|
|||||||
text_parts.append(element.get("text", ""))
|
text_parts.append(element.get("text", ""))
|
||||||
elif tag == "at":
|
elif tag == "at":
|
||||||
text_parts.append(f"@{element.get('user_name', 'user')}")
|
text_parts.append(f"@{element.get('user_name', 'user')}")
|
||||||
return " ".join(text_parts).strip() if text_parts else None
|
elif tag == "img":
|
||||||
|
img_key = element.get("image_key")
|
||||||
|
if img_key:
|
||||||
|
image_keys.append(img_key)
|
||||||
|
text = " ".join(text_parts).strip() if text_parts else None
|
||||||
|
return text, image_keys
|
||||||
|
|
||||||
# Try direct format first
|
# Try direct format first
|
||||||
if "content" in content_json:
|
if "content" in content_json:
|
||||||
result = extract_from_lang(content_json)
|
text, images = extract_from_lang(content_json)
|
||||||
if result:
|
if text or images:
|
||||||
return result
|
return text or "", images
|
||||||
|
|
||||||
# Try localized format
|
# Try localized format
|
||||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||||
lang_content = content_json.get(lang_key)
|
lang_content = content_json.get(lang_key)
|
||||||
result = extract_from_lang(lang_content)
|
text, images = extract_from_lang(lang_content)
|
||||||
if result:
|
if text or images:
|
||||||
return result
|
return text or "", images
|
||||||
|
|
||||||
return ""
|
return "", []
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_post_text(content_json: dict) -> str:
|
||||||
|
"""Extract plain text from Feishu post (rich text) message content.
|
||||||
|
|
||||||
|
Legacy wrapper for _extract_post_content, returns only text.
|
||||||
|
"""
|
||||||
|
text, _ = _extract_post_content(content_json)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class FeishuChannel(BaseChannel):
|
class FeishuChannel(BaseChannel):
|
||||||
@@ -156,7 +310,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
self._ws_client.start()
|
self._ws_client.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Feishu WebSocket error: {e}")
|
logger.warning("Feishu WebSocket error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
import time; time.sleep(5)
|
import time; time.sleep(5)
|
||||||
|
|
||||||
@@ -177,7 +331,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
self._ws_client.stop()
|
self._ws_client.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error stopping WebSocket client: {e}")
|
logger.warning("Error stopping WebSocket client: {}", e)
|
||||||
logger.info("Feishu bot stopped")
|
logger.info("Feishu bot stopped")
|
||||||
|
|
||||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||||
@@ -194,11 +348,11 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.message_reaction.create(request)
|
response = self._client.im.v1.message_reaction.create(request)
|
||||||
|
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.warning(f"Failed to add reaction: code={response.code}, msg={response.msg}")
|
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Added {emoji_type} reaction to message {message_id}")
|
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error adding reaction: {e}")
|
logger.warning("Error adding reaction: {}", e)
|
||||||
|
|
||||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||||
"""
|
"""
|
||||||
@@ -309,13 +463,13 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.image.create(request)
|
response = self._client.im.v1.image.create(request)
|
||||||
if response.success():
|
if response.success():
|
||||||
image_key = response.data.image_key
|
image_key = response.data.image_key
|
||||||
logger.debug(f"Uploaded image {os.path.basename(file_path)}: {image_key}")
|
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||||
return image_key
|
return image_key
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to upload image: code={response.code}, msg={response.msg}")
|
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error uploading image {file_path}: {e}")
|
logger.error("Error uploading image {}: {}", file_path, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||||
@@ -336,15 +490,107 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.file.create(request)
|
response = self._client.im.v1.file.create(request)
|
||||||
if response.success():
|
if response.success():
|
||||||
file_key = response.data.file_key
|
file_key = response.data.file_key
|
||||||
logger.debug(f"Uploaded file {file_name}: {file_key}")
|
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||||
return file_key
|
return file_key
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to upload file: code={response.code}, msg={response.msg}")
|
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error uploading file {file_path}: {e}")
|
logger.error("Error uploading file {}: {}", file_path, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
|
||||||
|
"""Download an image from Feishu message by message_id and image_key."""
|
||||||
|
try:
|
||||||
|
request = GetMessageResourceRequest.builder() \
|
||||||
|
.message_id(message_id) \
|
||||||
|
.file_key(image_key) \
|
||||||
|
.type("image") \
|
||||||
|
.build()
|
||||||
|
response = self._client.im.v1.message_resource.get(request)
|
||||||
|
if response.success():
|
||||||
|
file_data = response.file
|
||||||
|
# GetMessageResourceRequest returns BytesIO, need to read bytes
|
||||||
|
if hasattr(file_data, 'read'):
|
||||||
|
file_data = file_data.read()
|
||||||
|
return file_data, response.file_name
|
||||||
|
else:
|
||||||
|
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading image {}: {}", image_key, e)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _download_file_sync(
|
||||||
|
self, message_id: str, file_key: str, resource_type: str = "file"
|
||||||
|
) -> tuple[bytes | None, str | None]:
|
||||||
|
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
|
||||||
|
try:
|
||||||
|
request = (
|
||||||
|
GetMessageResourceRequest.builder()
|
||||||
|
.message_id(message_id)
|
||||||
|
.file_key(file_key)
|
||||||
|
.type(resource_type)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
response = self._client.im.v1.message_resource.get(request)
|
||||||
|
if response.success():
|
||||||
|
file_data = response.file
|
||||||
|
if hasattr(file_data, "read"):
|
||||||
|
file_data = file_data.read()
|
||||||
|
return file_data, response.file_name
|
||||||
|
else:
|
||||||
|
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
|
||||||
|
return None, None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _download_and_save_media(
|
||||||
|
self,
|
||||||
|
msg_type: str,
|
||||||
|
content_json: dict,
|
||||||
|
message_id: str | None = None
|
||||||
|
) -> tuple[str | None, str]:
|
||||||
|
"""
|
||||||
|
Download media from Feishu and save to local disk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(file_path, content_text) - file_path is None if download failed
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
media_dir = Path.home() / ".nanobot" / "media"
|
||||||
|
media_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
data, filename = None, None
|
||||||
|
|
||||||
|
if msg_type == "image":
|
||||||
|
image_key = content_json.get("image_key")
|
||||||
|
if image_key and message_id:
|
||||||
|
data, filename = await loop.run_in_executor(
|
||||||
|
None, self._download_image_sync, message_id, image_key
|
||||||
|
)
|
||||||
|
if not filename:
|
||||||
|
filename = f"{image_key[:16]}.jpg"
|
||||||
|
|
||||||
|
elif msg_type in ("audio", "file", "media"):
|
||||||
|
file_key = content_json.get("file_key")
|
||||||
|
if file_key and message_id:
|
||||||
|
data, filename = await loop.run_in_executor(
|
||||||
|
None, self._download_file_sync, message_id, file_key, msg_type
|
||||||
|
)
|
||||||
|
if not filename:
|
||||||
|
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
||||||
|
filename = f"{file_key[:16]}{ext}"
|
||||||
|
|
||||||
|
if data and filename:
|
||||||
|
file_path = media_dir / filename
|
||||||
|
file_path.write_bytes(data)
|
||||||
|
logger.debug("Downloaded {} to {}", msg_type, file_path)
|
||||||
|
return str(file_path), f"[{msg_type}: {filename}]"
|
||||||
|
|
||||||
|
return None, f"[{msg_type}: download failed]"
|
||||||
|
|
||||||
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
||||||
"""Send a single message (text/image/file/interactive) synchronously."""
|
"""Send a single message (text/image/file/interactive) synchronously."""
|
||||||
try:
|
try:
|
||||||
@@ -360,14 +606,14 @@ class FeishuChannel(BaseChannel):
|
|||||||
response = self._client.im.v1.message.create(request)
|
response = self._client.im.v1.message.create(request)
|
||||||
if not response.success():
|
if not response.success():
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to send Feishu {msg_type} message: code={response.code}, "
|
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
||||||
f"msg={response.msg}, log_id={response.get_log_id()}"
|
msg_type, response.code, response.msg, response.get_log_id()
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
logger.debug(f"Feishu {msg_type} message sent to {receive_id}")
|
logger.debug("Feishu {} message sent to {}", msg_type, receive_id)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending Feishu {msg_type} message: {e}")
|
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
@@ -382,7 +628,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
for file_path in msg.media:
|
for file_path in msg.media:
|
||||||
if not os.path.isfile(file_path):
|
if not os.path.isfile(file_path):
|
||||||
logger.warning(f"Media file not found: {file_path}")
|
logger.warning("Media file not found: {}", file_path)
|
||||||
continue
|
continue
|
||||||
ext = os.path.splitext(file_path)[1].lower()
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
if ext in self._IMAGE_EXTS:
|
if ext in self._IMAGE_EXTS:
|
||||||
@@ -390,7 +636,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
if key:
|
if key:
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None, self._send_message_sync,
|
None, self._send_message_sync,
|
||||||
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}),
|
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||||
@@ -398,7 +644,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
|
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None, self._send_message_sync,
|
None, self._send_message_sync,
|
||||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}),
|
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg.content and msg.content.strip():
|
if msg.content and msg.content.strip():
|
||||||
@@ -409,7 +655,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending Feishu message: {e}")
|
logger.error("Error sending Feishu message: {}", e)
|
||||||
|
|
||||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||||
"""
|
"""
|
||||||
@@ -425,60 +671,89 @@ class FeishuChannel(BaseChannel):
|
|||||||
event = data.event
|
event = data.event
|
||||||
message = event.message
|
message = event.message
|
||||||
sender = event.sender
|
sender = event.sender
|
||||||
|
|
||||||
# Deduplication check
|
# Deduplication check
|
||||||
message_id = message.message_id
|
message_id = message.message_id
|
||||||
if message_id in self._processed_message_ids:
|
if message_id in self._processed_message_ids:
|
||||||
return
|
return
|
||||||
self._processed_message_ids[message_id] = None
|
self._processed_message_ids[message_id] = None
|
||||||
|
|
||||||
# Trim cache: keep most recent 500 when exceeds 1000
|
# Trim cache
|
||||||
while len(self._processed_message_ids) > 1000:
|
while len(self._processed_message_ids) > 1000:
|
||||||
self._processed_message_ids.popitem(last=False)
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
# Skip bot messages
|
# Skip bot messages
|
||||||
sender_type = sender.sender_type
|
if sender.sender_type == "bot":
|
||||||
if sender_type == "bot":
|
|
||||||
return
|
return
|
||||||
|
|
||||||
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
||||||
chat_id = message.chat_id
|
chat_id = message.chat_id
|
||||||
chat_type = message.chat_type # "p2p" or "group"
|
chat_type = message.chat_type
|
||||||
msg_type = message.message_type
|
msg_type = message.message_type
|
||||||
|
|
||||||
# Add reaction to indicate "seen"
|
# Add reaction
|
||||||
await self._add_reaction(message_id, "THUMBSUP")
|
await self._add_reaction(message_id, "THUMBSUP")
|
||||||
|
|
||||||
# Parse message content
|
# Parse content
|
||||||
|
content_parts = []
|
||||||
|
media_paths = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_json = json.loads(message.content) if message.content else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
content_json = {}
|
||||||
|
|
||||||
if msg_type == "text":
|
if msg_type == "text":
|
||||||
try:
|
text = content_json.get("text", "")
|
||||||
content = json.loads(message.content).get("text", "")
|
if text:
|
||||||
except json.JSONDecodeError:
|
content_parts.append(text)
|
||||||
content = message.content or ""
|
|
||||||
elif msg_type == "post":
|
elif msg_type == "post":
|
||||||
try:
|
text, image_keys = _extract_post_content(content_json)
|
||||||
content_json = json.loads(message.content)
|
if text:
|
||||||
content = _extract_post_text(content_json)
|
content_parts.append(text)
|
||||||
except (json.JSONDecodeError, TypeError):
|
# Download images embedded in post
|
||||||
content = message.content or ""
|
for img_key in image_keys:
|
||||||
|
file_path, content_text = await self._download_and_save_media(
|
||||||
|
"image", {"image_key": img_key}, message_id
|
||||||
|
)
|
||||||
|
if file_path:
|
||||||
|
media_paths.append(file_path)
|
||||||
|
content_parts.append(content_text)
|
||||||
|
|
||||||
|
elif msg_type in ("image", "audio", "file", "media"):
|
||||||
|
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||||
|
if file_path:
|
||||||
|
media_paths.append(file_path)
|
||||||
|
content_parts.append(content_text)
|
||||||
|
|
||||||
|
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||||
|
# Handle share cards and interactive messages
|
||||||
|
text = _extract_share_card_content(content_json, msg_type)
|
||||||
|
if text:
|
||||||
|
content_parts.append(text)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
content = MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]")
|
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||||
|
|
||||||
if not content:
|
content = "\n".join(content_parts) if content_parts else ""
|
||||||
|
|
||||||
|
if not content and not media_paths:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Forward to message bus
|
# Forward to message bus
|
||||||
reply_to = chat_id if chat_type == "group" else sender_id
|
reply_to = chat_id if chat_type == "group" else sender_id
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=reply_to,
|
chat_id=reply_to,
|
||||||
content=content,
|
content=content,
|
||||||
|
media=media_paths,
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"chat_type": chat_type,
|
"chat_type": chat_type,
|
||||||
"msg_type": msg_type,
|
"msg_type": msg_type,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing Feishu message: {e}")
|
logger.error("Error processing Feishu message: {}", e)
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Telegram channel enabled")
|
logger.info("Telegram channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Telegram channel not available: {e}")
|
logger.warning("Telegram channel not available: {}", e)
|
||||||
|
|
||||||
# WhatsApp channel
|
# WhatsApp channel
|
||||||
if self.config.channels.whatsapp.enabled:
|
if self.config.channels.whatsapp.enabled:
|
||||||
@@ -56,7 +56,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("WhatsApp channel enabled")
|
logger.info("WhatsApp channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"WhatsApp channel not available: {e}")
|
logger.warning("WhatsApp channel not available: {}", e)
|
||||||
|
|
||||||
# Discord channel
|
# Discord channel
|
||||||
if self.config.channels.discord.enabled:
|
if self.config.channels.discord.enabled:
|
||||||
@@ -67,7 +67,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Discord channel enabled")
|
logger.info("Discord channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Discord channel not available: {e}")
|
logger.warning("Discord channel not available: {}", e)
|
||||||
|
|
||||||
# Feishu channel
|
# Feishu channel
|
||||||
if self.config.channels.feishu.enabled:
|
if self.config.channels.feishu.enabled:
|
||||||
@@ -78,7 +78,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Feishu channel enabled")
|
logger.info("Feishu channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Feishu channel not available: {e}")
|
logger.warning("Feishu channel not available: {}", e)
|
||||||
|
|
||||||
# Mochat channel
|
# Mochat channel
|
||||||
if self.config.channels.mochat.enabled:
|
if self.config.channels.mochat.enabled:
|
||||||
@@ -90,7 +90,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Mochat channel enabled")
|
logger.info("Mochat channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Mochat channel not available: {e}")
|
logger.warning("Mochat channel not available: {}", e)
|
||||||
|
|
||||||
# DingTalk channel
|
# DingTalk channel
|
||||||
if self.config.channels.dingtalk.enabled:
|
if self.config.channels.dingtalk.enabled:
|
||||||
@@ -101,7 +101,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("DingTalk channel enabled")
|
logger.info("DingTalk channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"DingTalk channel not available: {e}")
|
logger.warning("DingTalk channel not available: {}", e)
|
||||||
|
|
||||||
# Email channel
|
# Email channel
|
||||||
if self.config.channels.email.enabled:
|
if self.config.channels.email.enabled:
|
||||||
@@ -112,7 +112,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Email channel enabled")
|
logger.info("Email channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Email channel not available: {e}")
|
logger.warning("Email channel not available: {}", e)
|
||||||
|
|
||||||
# Slack channel
|
# Slack channel
|
||||||
if self.config.channels.slack.enabled:
|
if self.config.channels.slack.enabled:
|
||||||
@@ -123,7 +123,7 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("Slack channel enabled")
|
logger.info("Slack channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Slack channel not available: {e}")
|
logger.warning("Slack channel not available: {}", e)
|
||||||
|
|
||||||
# QQ channel
|
# QQ channel
|
||||||
if self.config.channels.qq.enabled:
|
if self.config.channels.qq.enabled:
|
||||||
@@ -135,14 +135,14 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
logger.info("QQ channel enabled")
|
logger.info("QQ channel enabled")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"QQ channel not available: {e}")
|
logger.warning("QQ channel not available: {}", e)
|
||||||
|
|
||||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||||
"""Start a channel and log any exceptions."""
|
"""Start a channel and log any exceptions."""
|
||||||
try:
|
try:
|
||||||
await channel.start()
|
await channel.start()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start channel {name}: {e}")
|
logger.error("Failed to start channel {}: {}", name, e)
|
||||||
|
|
||||||
async def start_all(self) -> None:
|
async def start_all(self) -> None:
|
||||||
"""Start all channels and the outbound dispatcher."""
|
"""Start all channels and the outbound dispatcher."""
|
||||||
@@ -156,7 +156,7 @@ class ChannelManager:
|
|||||||
# Start channels
|
# Start channels
|
||||||
tasks = []
|
tasks = []
|
||||||
for name, channel in self.channels.items():
|
for name, channel in self.channels.items():
|
||||||
logger.info(f"Starting {name} channel...")
|
logger.info("Starting {} channel...", name)
|
||||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||||
|
|
||||||
# Wait for all to complete (they should run forever)
|
# Wait for all to complete (they should run forever)
|
||||||
@@ -178,9 +178,9 @@ class ChannelManager:
|
|||||||
for name, channel in self.channels.items():
|
for name, channel in self.channels.items():
|
||||||
try:
|
try:
|
||||||
await channel.stop()
|
await channel.stop()
|
||||||
logger.info(f"Stopped {name} channel")
|
logger.info("Stopped {} channel", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error stopping {name}: {e}")
|
logger.error("Error stopping {}: {}", name, e)
|
||||||
|
|
||||||
async def _dispatch_outbound(self) -> None:
|
async def _dispatch_outbound(self) -> None:
|
||||||
"""Dispatch outbound messages to the appropriate channel."""
|
"""Dispatch outbound messages to the appropriate channel."""
|
||||||
@@ -193,14 +193,20 @@ class ChannelManager:
|
|||||||
timeout=1.0
|
timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if msg.metadata.get("_progress"):
|
||||||
|
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||||
|
continue
|
||||||
|
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||||
|
continue
|
||||||
|
|
||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
try:
|
try:
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending to {msg.channel}: {e}")
|
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unknown channel: {msg.channel}")
|
logger.warning("Unknown channel: {}", msg.channel)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ class MochatChannel(BaseChannel):
|
|||||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||||
content, msg.reply_to)
|
content, msg.reply_to)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send Mochat message: {e}")
|
logger.error("Failed to send Mochat message: {}", e)
|
||||||
|
|
||||||
# ---- config / init helpers ---------------------------------------------
|
# ---- config / init helpers ---------------------------------------------
|
||||||
|
|
||||||
@@ -380,7 +380,7 @@ class MochatChannel(BaseChannel):
|
|||||||
|
|
||||||
@client.event
|
@client.event
|
||||||
async def connect_error(data: Any) -> None:
|
async def connect_error(data: Any) -> None:
|
||||||
logger.error(f"Mochat websocket connect error: {data}")
|
logger.error("Mochat websocket connect error: {}", data)
|
||||||
|
|
||||||
@client.on("claw.session.events")
|
@client.on("claw.session.events")
|
||||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||||
@@ -407,7 +407,7 @@ class MochatChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect Mochat websocket: {e}")
|
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||||
try:
|
try:
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -444,7 +444,7 @@ class MochatChannel(BaseChannel):
|
|||||||
"limit": self.config.watch_limit,
|
"limit": self.config.watch_limit,
|
||||||
})
|
})
|
||||||
if not ack.get("result"):
|
if not ack.get("result"):
|
||||||
logger.error(f"Mochat subscribeSessions failed: {ack.get('message', 'unknown error')}")
|
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
data = ack.get("data")
|
data = ack.get("data")
|
||||||
@@ -466,7 +466,7 @@ class MochatChannel(BaseChannel):
|
|||||||
return True
|
return True
|
||||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||||
if not ack.get("result"):
|
if not ack.get("result"):
|
||||||
logger.error(f"Mochat subscribePanels failed: {ack.get('message', 'unknown error')}")
|
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -488,7 +488,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mochat refresh failed: {e}")
|
logger.warning("Mochat refresh failed: {}", e)
|
||||||
if self._fallback_mode:
|
if self._fallback_mode:
|
||||||
await self._ensure_fallback_workers()
|
await self._ensure_fallback_workers()
|
||||||
|
|
||||||
@@ -502,7 +502,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
response = await self._post_json("/api/claw/sessions/list", {})
|
response = await self._post_json("/api/claw/sessions/list", {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mochat listSessions failed: {e}")
|
logger.warning("Mochat listSessions failed: {}", e)
|
||||||
return
|
return
|
||||||
|
|
||||||
sessions = response.get("sessions")
|
sessions = response.get("sessions")
|
||||||
@@ -536,7 +536,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
response = await self._post_json("/api/claw/groups/get", {})
|
response = await self._post_json("/api/claw/groups/get", {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mochat getWorkspaceGroup failed: {e}")
|
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||||
return
|
return
|
||||||
|
|
||||||
raw_panels = response.get("panels")
|
raw_panels = response.get("panels")
|
||||||
@@ -598,7 +598,7 @@ class MochatChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mochat watch fallback error ({session_id}): {e}")
|
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||||
|
|
||||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||||
@@ -625,7 +625,7 @@ class MochatChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Mochat panel polling error ({panel_id}): {e}")
|
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||||
await asyncio.sleep(sleep_s)
|
await asyncio.sleep(sleep_s)
|
||||||
|
|
||||||
# ---- inbound event processing ------------------------------------------
|
# ---- inbound event processing ------------------------------------------
|
||||||
@@ -836,7 +836,7 @@ class MochatChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to read Mochat cursor file: {e}")
|
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||||
return
|
return
|
||||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||||
if isinstance(cursors, dict):
|
if isinstance(cursors, dict):
|
||||||
@@ -852,7 +852,7 @@ class MochatChannel(BaseChannel):
|
|||||||
"cursors": self._session_cursor,
|
"cursors": self._session_cursor,
|
||||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to save Mochat cursor file: {e}")
|
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||||
|
|
||||||
# ---- HTTP helpers ------------------------------------------------------
|
# ---- HTTP helpers ------------------------------------------------------
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
|||||||
super().__init__(intents=intents)
|
super().__init__(intents=intents)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
logger.info(f"QQ bot ready: {self.robot.name}")
|
logger.info("QQ bot ready: {}", self.robot.name)
|
||||||
|
|
||||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||||
await channel._on_message(message)
|
await channel._on_message(message)
|
||||||
@@ -55,7 +55,6 @@ class QQChannel(BaseChannel):
|
|||||||
self.config: QQConfig = config
|
self.config: QQConfig = config
|
||||||
self._client: "botpy.Client | None" = None
|
self._client: "botpy.Client | None" = None
|
||||||
self._processed_ids: deque = deque(maxlen=1000)
|
self._processed_ids: deque = deque(maxlen=1000)
|
||||||
self._bot_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the QQ bot."""
|
"""Start the QQ bot."""
|
||||||
@@ -71,8 +70,8 @@ class QQChannel(BaseChannel):
|
|||||||
BotClass = _make_bot_class(self)
|
BotClass = _make_bot_class(self)
|
||||||
self._client = BotClass()
|
self._client = BotClass()
|
||||||
|
|
||||||
self._bot_task = asyncio.create_task(self._run_bot())
|
|
||||||
logger.info("QQ bot started (C2C private message)")
|
logger.info("QQ bot started (C2C private message)")
|
||||||
|
await self._run_bot()
|
||||||
|
|
||||||
async def _run_bot(self) -> None:
|
async def _run_bot(self) -> None:
|
||||||
"""Run the bot connection with auto-reconnect."""
|
"""Run the bot connection with auto-reconnect."""
|
||||||
@@ -80,7 +79,7 @@ class QQChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"QQ bot error: {e}")
|
logger.warning("QQ bot error: {}", e)
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
@@ -88,11 +87,10 @@ class QQChannel(BaseChannel):
|
|||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the QQ bot."""
|
"""Stop the QQ bot."""
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._bot_task:
|
if self._client:
|
||||||
self._bot_task.cancel()
|
|
||||||
try:
|
try:
|
||||||
await self._bot_task
|
await self._client.close()
|
||||||
except asyncio.CancelledError:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info("QQ bot stopped")
|
logger.info("QQ bot stopped")
|
||||||
|
|
||||||
@@ -108,7 +106,7 @@ class QQChannel(BaseChannel):
|
|||||||
content=msg.content,
|
content=msg.content,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending QQ message: {e}")
|
logger.error("Error sending QQ message: {}", e)
|
||||||
|
|
||||||
async def _on_message(self, data: "C2CMessage") -> None:
|
async def _on_message(self, data: "C2CMessage") -> None:
|
||||||
"""Handle incoming message from QQ."""
|
"""Handle incoming message from QQ."""
|
||||||
@@ -130,5 +128,5 @@ class QQChannel(BaseChannel):
|
|||||||
content=content,
|
content=content,
|
||||||
metadata={"message_id": data.id},
|
metadata={"message_id": data.id},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Error handling QQ message: {e}")
|
logger.exception("Error handling QQ message")
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class SlackChannel(BaseChannel):
|
|||||||
logger.error("Slack bot/app token not configured")
|
logger.error("Slack bot/app token not configured")
|
||||||
return
|
return
|
||||||
if self.config.mode != "socket":
|
if self.config.mode != "socket":
|
||||||
logger.error(f"Unsupported Slack mode: {self.config.mode}")
|
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
@@ -53,9 +53,9 @@ class SlackChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
auth = await self._web_client.auth_test()
|
auth = await self._web_client.auth_test()
|
||||||
self._bot_user_id = auth.get("user_id")
|
self._bot_user_id = auth.get("user_id")
|
||||||
logger.info(f"Slack bot connected as {self._bot_user_id}")
|
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Slack auth_test failed: {e}")
|
logger.warning("Slack auth_test failed: {}", e)
|
||||||
|
|
||||||
logger.info("Starting Slack Socket Mode client...")
|
logger.info("Starting Slack Socket Mode client...")
|
||||||
await self._socket_client.connect()
|
await self._socket_client.connect()
|
||||||
@@ -70,7 +70,7 @@ class SlackChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._socket_client.close()
|
await self._socket_client.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Slack socket close failed: {e}")
|
logger.warning("Slack socket close failed: {}", e)
|
||||||
self._socket_client = None
|
self._socket_client = None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
@@ -84,13 +84,26 @@ class SlackChannel(BaseChannel):
|
|||||||
channel_type = slack_meta.get("channel_type")
|
channel_type = slack_meta.get("channel_type")
|
||||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||||
use_thread = thread_ts and channel_type != "im"
|
use_thread = thread_ts and channel_type != "im"
|
||||||
await self._web_client.chat_postMessage(
|
thread_ts_param = thread_ts if use_thread else None
|
||||||
channel=msg.chat_id,
|
|
||||||
text=self._to_mrkdwn(msg.content),
|
if msg.content:
|
||||||
thread_ts=thread_ts if use_thread else None,
|
await self._web_client.chat_postMessage(
|
||||||
)
|
channel=msg.chat_id,
|
||||||
|
text=self._to_mrkdwn(msg.content),
|
||||||
|
thread_ts=thread_ts_param,
|
||||||
|
)
|
||||||
|
|
||||||
|
for media_path in msg.media or []:
|
||||||
|
try:
|
||||||
|
await self._web_client.files_upload_v2(
|
||||||
|
channel=msg.chat_id,
|
||||||
|
file=media_path,
|
||||||
|
thread_ts=thread_ts_param,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending Slack message: {e}")
|
logger.error("Error sending Slack message: {}", e)
|
||||||
|
|
||||||
async def _on_socket_request(
|
async def _on_socket_request(
|
||||||
self,
|
self,
|
||||||
@@ -164,20 +177,27 @@ class SlackChannel(BaseChannel):
|
|||||||
timestamp=event.get("ts"),
|
timestamp=event.get("ts"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Slack reactions_add failed: {e}")
|
logger.debug("Slack reactions_add failed: {}", e)
|
||||||
|
|
||||||
await self._handle_message(
|
# Thread-scoped session key for channel/group messages
|
||||||
sender_id=sender_id,
|
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||||
chat_id=chat_id,
|
|
||||||
content=text,
|
try:
|
||||||
metadata={
|
await self._handle_message(
|
||||||
"slack": {
|
sender_id=sender_id,
|
||||||
"event": event,
|
chat_id=chat_id,
|
||||||
"thread_ts": thread_ts,
|
content=text,
|
||||||
"channel_type": channel_type,
|
metadata={
|
||||||
}
|
"slack": {
|
||||||
},
|
"event": event,
|
||||||
)
|
"thread_ts": thread_ts,
|
||||||
|
"channel_type": channel_type,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
session_key=session_key,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error handling Slack message from {}", sender_id)
|
||||||
|
|
||||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||||
if channel_type == "im":
|
if channel_type == "im":
|
||||||
@@ -209,6 +229,11 @@ class SlackChannel(BaseChannel):
|
|||||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||||
|
|
||||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||||
|
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||||
|
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||||
|
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||||
|
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||||
|
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _to_mrkdwn(cls, text: str) -> str:
|
def _to_mrkdwn(cls, text: str) -> str:
|
||||||
@@ -216,7 +241,26 @@ class SlackChannel(BaseChannel):
|
|||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||||
return slackify_markdown(text)
|
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||||
|
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||||
|
code_blocks: list[str] = []
|
||||||
|
|
||||||
|
def _save_code(m: re.Match) -> str:
|
||||||
|
code_blocks.append(m.group(0))
|
||||||
|
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||||
|
|
||||||
|
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||||
|
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||||
|
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||||
|
|
||||||
|
for i, block in enumerate(code_blocks):
|
||||||
|
text = text.replace(f"\x00CB{i}\x00", block)
|
||||||
|
return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_table(match: re.Match) -> str:
|
def _convert_table(match: re.Match) -> str:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from telegram import BotCommand, Update
|
from telegram import BotCommand, Update, ReplyParameters
|
||||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||||
from telegram.request import HTTPXRequest
|
from telegram.request import HTTPXRequest
|
||||||
|
|
||||||
@@ -111,6 +111,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
BotCommand("start", "Start the bot"),
|
BotCommand("start", "Start the bot"),
|
||||||
BotCommand("new", "Start a new conversation"),
|
BotCommand("new", "Start a new conversation"),
|
||||||
|
BotCommand("stop", "Stop the current task"),
|
||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -146,7 +147,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Add command handlers
|
# Add command handlers
|
||||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||||
self._app.add_handler(CommandHandler("help", self._forward_command))
|
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||||
|
|
||||||
# Add message handler for text, photos, voice, documents
|
# Add message handler for text, photos, voice, documents
|
||||||
self._app.add_handler(
|
self._app.add_handler(
|
||||||
@@ -165,13 +166,13 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
# Get bot info and register command menu
|
# Get bot info and register command menu
|
||||||
bot_info = await self._app.bot.get_me()
|
bot_info = await self._app.bot.get_me()
|
||||||
logger.info(f"Telegram bot @{bot_info.username} connected")
|
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||||
logger.debug("Telegram bot commands registered")
|
logger.debug("Telegram bot commands registered")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to register bot commands: {e}")
|
logger.warning("Failed to register bot commands: {}", e)
|
||||||
|
|
||||||
# Start polling (this runs until stopped)
|
# Start polling (this runs until stopped)
|
||||||
await self._app.updater.start_polling(
|
await self._app.updater.start_polling(
|
||||||
@@ -221,9 +222,18 @@ class TelegramChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
chat_id = int(msg.chat_id)
|
chat_id = int(msg.chat_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error(f"Invalid chat_id: {msg.chat_id}")
|
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
reply_params = None
|
||||||
|
if self.config.reply_to_message:
|
||||||
|
reply_to_message_id = msg.metadata.get("message_id")
|
||||||
|
if reply_to_message_id:
|
||||||
|
reply_params = ReplyParameters(
|
||||||
|
message_id=reply_to_message_id,
|
||||||
|
allow_sending_without_reply=True
|
||||||
|
)
|
||||||
|
|
||||||
# Send media files
|
# Send media files
|
||||||
for media_path in (msg.media or []):
|
for media_path in (msg.media or []):
|
||||||
try:
|
try:
|
||||||
@@ -235,37 +245,65 @@ class TelegramChannel(BaseChannel):
|
|||||||
}.get(media_type, self._app.bot.send_document)
|
}.get(media_type, self._app.bot.send_document)
|
||||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||||
with open(media_path, 'rb') as f:
|
with open(media_path, 'rb') as f:
|
||||||
await sender(chat_id=chat_id, **{param: f})
|
await sender(
|
||||||
|
chat_id=chat_id,
|
||||||
|
**{param: f},
|
||||||
|
reply_parameters=reply_params
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
filename = media_path.rsplit("/", 1)[-1]
|
filename = media_path.rsplit("/", 1)[-1]
|
||||||
logger.error(f"Failed to send media {media_path}: {e}")
|
logger.error("Failed to send media {}: {}", media_path, e)
|
||||||
await self._app.bot.send_message(chat_id=chat_id, text=f"[Failed to send: {filename}]")
|
await self._app.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"[Failed to send: {filename}]",
|
||||||
|
reply_parameters=reply_params
|
||||||
|
)
|
||||||
|
|
||||||
# Send text content
|
# Send text content
|
||||||
if msg.content and msg.content != "[empty message]":
|
if msg.content and msg.content != "[empty message]":
|
||||||
for chunk in _split_message(msg.content):
|
for chunk in _split_message(msg.content):
|
||||||
try:
|
try:
|
||||||
html = _markdown_to_telegram_html(chunk)
|
html = _markdown_to_telegram_html(chunk)
|
||||||
await self._app.bot.send_message(chat_id=chat_id, text=html, parse_mode="HTML")
|
await self._app.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=html,
|
||||||
|
parse_mode="HTML",
|
||||||
|
reply_parameters=reply_params
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"HTML parse failed, falling back to plain text: {e}")
|
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
try:
|
try:
|
||||||
await self._app.bot.send_message(chat_id=chat_id, text=chunk)
|
await self._app.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=chunk,
|
||||||
|
reply_parameters=reply_params
|
||||||
|
)
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
logger.error(f"Error sending Telegram message: {e2}")
|
logger.error("Error sending Telegram message: {}", e2)
|
||||||
|
|
||||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
|
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
await update.message.reply_text(
|
await update.message.reply_text(
|
||||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||||
"Send me a message and I'll respond!\n"
|
"Send me a message and I'll respond!\n"
|
||||||
"Type /help to see available commands."
|
"Type /help to see available commands."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||||
|
if not update.message:
|
||||||
|
return
|
||||||
|
await update.message.reply_text(
|
||||||
|
"🐈 nanobot commands:\n"
|
||||||
|
"/new — Start a new conversation\n"
|
||||||
|
"/stop — Stop the current task\n"
|
||||||
|
"/help — Show available commands"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sender_id(user) -> str:
|
def _sender_id(user) -> str:
|
||||||
"""Build sender_id with username for allowlist matching."""
|
"""Build sender_id with username for allowlist matching."""
|
||||||
@@ -344,21 +382,21 @@ class TelegramChannel(BaseChannel):
|
|||||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||||
transcription = await transcriber.transcribe(file_path)
|
transcription = await transcriber.transcribe(file_path)
|
||||||
if transcription:
|
if transcription:
|
||||||
logger.info(f"Transcribed {media_type}: {transcription[:50]}...")
|
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||||
content_parts.append(f"[transcription: {transcription}]")
|
content_parts.append(f"[transcription: {transcription}]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
content_parts.append(f"[{media_type}: {file_path}]")
|
||||||
else:
|
else:
|
||||||
content_parts.append(f"[{media_type}: {file_path}]")
|
content_parts.append(f"[{media_type}: {file_path}]")
|
||||||
|
|
||||||
logger.debug(f"Downloaded {media_type} to {file_path}")
|
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download media: {e}")
|
logger.error("Failed to download media: {}", e)
|
||||||
content_parts.append(f"[{media_type}: download failed]")
|
content_parts.append(f"[{media_type}: download failed]")
|
||||||
|
|
||||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||||
|
|
||||||
logger.debug(f"Telegram message from {sender_id}: {content[:50]}...")
|
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
|
||||||
@@ -401,11 +439,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Typing indicator stopped for {chat_id}: {e}")
|
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||||
|
|
||||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
"""Log polling / handler errors instead of silently swallowing them."""
|
"""Log polling / handler errors instead of silently swallowing them."""
|
||||||
logger.error(f"Telegram error: {context.error}")
|
logger.error("Telegram error: {}", context.error)
|
||||||
|
|
||||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||||
"""Get file extension based on media type."""
|
"""Get file extension based on media type."""
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
|
|
||||||
bridge_url = self.config.bridge_url
|
bridge_url = self.config.bridge_url
|
||||||
|
|
||||||
logger.info(f"Connecting to WhatsApp bridge at {bridge_url}...")
|
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
@@ -53,14 +53,14 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
await self._handle_bridge_message(message)
|
await self._handle_bridge_message(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling bridge message: {e}")
|
logger.error("Error handling bridge message: {}", e)
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self._ws = None
|
self._ws = None
|
||||||
logger.warning(f"WhatsApp bridge connection error: {e}")
|
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||||
|
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.info("Reconnecting in 5 seconds...")
|
logger.info("Reconnecting in 5 seconds...")
|
||||||
@@ -87,16 +87,16 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
"to": msg.chat_id,
|
"to": msg.chat_id,
|
||||||
"text": msg.content
|
"text": msg.content
|
||||||
}
|
}
|
||||||
await self._ws.send(json.dumps(payload))
|
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending WhatsApp message: {e}")
|
logger.error("Error sending WhatsApp message: {}", e)
|
||||||
|
|
||||||
async def _handle_bridge_message(self, raw: str) -> None:
|
async def _handle_bridge_message(self, raw: str) -> None:
|
||||||
"""Handle a message from the bridge."""
|
"""Handle a message from the bridge."""
|
||||||
try:
|
try:
|
||||||
data = json.loads(raw)
|
data = json.loads(raw)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning(f"Invalid JSON from bridge: {raw[:100]}")
|
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||||
return
|
return
|
||||||
|
|
||||||
msg_type = data.get("type")
|
msg_type = data.get("type")
|
||||||
@@ -112,11 +112,11 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
# Extract just the phone number or lid as chat_id
|
# Extract just the phone number or lid as chat_id
|
||||||
user_id = pn if pn else sender
|
user_id = pn if pn else sender
|
||||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||||
logger.info(f"Sender {sender}")
|
logger.info("Sender {}", sender)
|
||||||
|
|
||||||
# Handle voice transcription if it's a voice message
|
# Handle voice transcription if it's a voice message
|
||||||
if content == "[Voice Message]":
|
if content == "[Voice Message]":
|
||||||
logger.info(f"Voice message received from {sender_id}, but direct download from bridge is not yet supported.")
|
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
@@ -133,7 +133,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
elif msg_type == "status":
|
elif msg_type == "status":
|
||||||
# Connection status update
|
# Connection status update
|
||||||
status = data.get("status")
|
status = data.get("status")
|
||||||
logger.info(f"WhatsApp status: {status}")
|
logger.info("WhatsApp status: {}", status)
|
||||||
|
|
||||||
if status == "connected":
|
if status == "connected":
|
||||||
self._connected = True
|
self._connected = True
|
||||||
@@ -145,4 +145,4 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||||
|
|
||||||
elif msg_type == "error":
|
elif msg_type == "error":
|
||||||
logger.error(f"WhatsApp bridge error: {data.get('error')}")
|
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||||
|
|||||||
@@ -199,84 +199,34 @@ def onboard():
|
|||||||
|
|
||||||
|
|
||||||
def _create_workspace_templates(workspace: Path):
|
def _create_workspace_templates(workspace: Path):
|
||||||
"""Create default workspace template files."""
|
"""Create default workspace template files from bundled templates."""
|
||||||
templates = {
|
from importlib.resources import files as pkg_files
|
||||||
"AGENTS.md": """# Agent Instructions
|
|
||||||
|
|
||||||
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
templates_dir = pkg_files("nanobot") / "templates"
|
||||||
|
|
||||||
## Guidelines
|
for item in templates_dir.iterdir():
|
||||||
|
if not item.name.endswith(".md"):
|
||||||
|
continue
|
||||||
|
dest = workspace / item.name
|
||||||
|
if not dest.exists():
|
||||||
|
dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
|
||||||
|
console.print(f" [dim]Created {item.name}[/dim]")
|
||||||
|
|
||||||
- Always explain what you're doing before taking actions
|
|
||||||
- Ask for clarification when the request is ambiguous
|
|
||||||
- Use tools to help accomplish tasks
|
|
||||||
- Remember important information in memory/MEMORY.md; past events are logged in memory/HISTORY.md
|
|
||||||
""",
|
|
||||||
"SOUL.md": """# Soul
|
|
||||||
|
|
||||||
I am nanobot, a lightweight AI assistant.
|
|
||||||
|
|
||||||
## Personality
|
|
||||||
|
|
||||||
- Helpful and friendly
|
|
||||||
- Concise and to the point
|
|
||||||
- Curious and eager to learn
|
|
||||||
|
|
||||||
## Values
|
|
||||||
|
|
||||||
- Accuracy over speed
|
|
||||||
- User privacy and safety
|
|
||||||
- Transparency in actions
|
|
||||||
""",
|
|
||||||
"USER.md": """# User
|
|
||||||
|
|
||||||
Information about the user goes here.
|
|
||||||
|
|
||||||
## Preferences
|
|
||||||
|
|
||||||
- Communication style: (casual/formal)
|
|
||||||
- Timezone: (your timezone)
|
|
||||||
- Language: (your preferred language)
|
|
||||||
""",
|
|
||||||
}
|
|
||||||
|
|
||||||
for filename, content in templates.items():
|
|
||||||
file_path = workspace / filename
|
|
||||||
if not file_path.exists():
|
|
||||||
file_path.write_text(content)
|
|
||||||
console.print(f" [dim]Created {filename}[/dim]")
|
|
||||||
|
|
||||||
# Create memory directory and MEMORY.md
|
|
||||||
memory_dir = workspace / "memory"
|
memory_dir = workspace / "memory"
|
||||||
memory_dir.mkdir(exist_ok=True)
|
memory_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
memory_template = templates_dir / "memory" / "MEMORY.md"
|
||||||
memory_file = memory_dir / "MEMORY.md"
|
memory_file = memory_dir / "MEMORY.md"
|
||||||
if not memory_file.exists():
|
if not memory_file.exists():
|
||||||
memory_file.write_text("""# Long-term Memory
|
memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
|
||||||
|
|
||||||
This file stores important information that should persist across sessions.
|
|
||||||
|
|
||||||
## User Information
|
|
||||||
|
|
||||||
(Important facts about the user)
|
|
||||||
|
|
||||||
## Preferences
|
|
||||||
|
|
||||||
(User preferences learned over time)
|
|
||||||
|
|
||||||
## Important Notes
|
|
||||||
|
|
||||||
(Things to remember)
|
|
||||||
""")
|
|
||||||
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
||||||
|
|
||||||
history_file = memory_dir / "HISTORY.md"
|
history_file = memory_dir / "HISTORY.md"
|
||||||
if not history_file.exists():
|
if not history_file.exists():
|
||||||
history_file.write_text("")
|
history_file.write_text("", encoding="utf-8")
|
||||||
console.print(" [dim]Created memory/HISTORY.md[/dim]")
|
console.print(" [dim]Created memory/HISTORY.md[/dim]")
|
||||||
|
|
||||||
# Create skills directory for custom user skills
|
(workspace / "skills").mkdir(exist_ok=True)
|
||||||
skills_dir = workspace / "skills"
|
|
||||||
skills_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
@@ -368,6 +318,7 @@ def gateway(
|
|||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
|
channels_config=config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set cron callback (needs agent)
|
# Set cron callback (needs agent)
|
||||||
@@ -389,20 +340,59 @@ def gateway(
|
|||||||
return response
|
return response
|
||||||
cron.on_job = on_cron_job
|
cron.on_job = on_cron_job
|
||||||
|
|
||||||
# Create heartbeat service
|
|
||||||
async def on_heartbeat(prompt: str) -> str:
|
|
||||||
"""Execute heartbeat through the agent."""
|
|
||||||
return await agent.process_direct(prompt, session_key="heartbeat")
|
|
||||||
|
|
||||||
heartbeat = HeartbeatService(
|
|
||||||
workspace=config.workspace_path,
|
|
||||||
on_heartbeat=on_heartbeat,
|
|
||||||
interval_s=30 * 60, # 30 minutes
|
|
||||||
enabled=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create channel manager
|
# Create channel manager
|
||||||
channels = ChannelManager(config, bus)
|
channels = ChannelManager(config, bus)
|
||||||
|
|
||||||
|
def _pick_heartbeat_target() -> tuple[str, str]:
|
||||||
|
"""Pick a routable channel/chat target for heartbeat-triggered messages."""
|
||||||
|
enabled = set(channels.enabled_channels)
|
||||||
|
# Prefer the most recently updated non-internal session on an enabled channel.
|
||||||
|
for item in session_manager.list_sessions():
|
||||||
|
key = item.get("key") or ""
|
||||||
|
if ":" not in key:
|
||||||
|
continue
|
||||||
|
channel, chat_id = key.split(":", 1)
|
||||||
|
if channel in {"cli", "system"}:
|
||||||
|
continue
|
||||||
|
if channel in enabled and chat_id:
|
||||||
|
return channel, chat_id
|
||||||
|
# Fallback keeps prior behavior but remains explicit.
|
||||||
|
return "cli", "direct"
|
||||||
|
|
||||||
|
# Create heartbeat service
|
||||||
|
async def on_heartbeat_execute(tasks: str) -> str:
|
||||||
|
"""Phase 2: execute heartbeat tasks through the full agent loop."""
|
||||||
|
channel, chat_id = _pick_heartbeat_target()
|
||||||
|
|
||||||
|
async def _silent(*_args, **_kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return await agent.process_direct(
|
||||||
|
tasks,
|
||||||
|
session_key="heartbeat",
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
on_progress=_silent,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_heartbeat_notify(response: str) -> None:
|
||||||
|
"""Deliver a heartbeat response to the user's channel."""
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
channel, chat_id = _pick_heartbeat_target()
|
||||||
|
if channel == "cli":
|
||||||
|
return # No external channel available to deliver to
|
||||||
|
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||||
|
|
||||||
|
hb_cfg = config.gateway.heartbeat
|
||||||
|
heartbeat = HeartbeatService(
|
||||||
|
workspace=config.workspace_path,
|
||||||
|
provider=provider,
|
||||||
|
model=agent.model,
|
||||||
|
on_execute=on_heartbeat_execute,
|
||||||
|
on_notify=on_heartbeat_notify,
|
||||||
|
interval_s=hb_cfg.interval_s,
|
||||||
|
enabled=hb_cfg.enabled,
|
||||||
|
)
|
||||||
|
|
||||||
if channels.enabled_channels:
|
if channels.enabled_channels:
|
||||||
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
||||||
@@ -413,7 +403,7 @@ def gateway(
|
|||||||
if cron_status["jobs"] > 0:
|
if cron_status["jobs"] > 0:
|
||||||
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Heartbeat: every 30m")
|
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
try:
|
try:
|
||||||
@@ -484,6 +474,7 @@ def agent(
|
|||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
|
channels_config=config.channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
# Show spinner when logs are off (no output to miss); skip when logs are on
|
||||||
@@ -494,31 +485,74 @@ def agent(
|
|||||||
# Animated spinner is safe to use with prompt_toolkit input handling
|
# Animated spinner is safe to use with prompt_toolkit input handling
|
||||||
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
|
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
|
||||||
|
|
||||||
async def _cli_progress(content: str) -> None:
|
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
|
ch = agent_loop.channels_config
|
||||||
|
if ch and tool_hint and not ch.send_tool_hints:
|
||||||
|
return
|
||||||
|
if ch and not tool_hint and not ch.send_progress:
|
||||||
|
return
|
||||||
console.print(f" [dim]↳ {content}[/dim]")
|
console.print(f" [dim]↳ {content}[/dim]")
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
# Single message mode
|
# Single message mode — direct call, no bus needed
|
||||||
async def run_once():
|
async def run_once():
|
||||||
with _thinking_ctx():
|
with _thinking_ctx():
|
||||||
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
||||||
_print_agent_response(response, render_markdown=markdown)
|
_print_agent_response(response, render_markdown=markdown)
|
||||||
await agent_loop.close_mcp()
|
await agent_loop.close_mcp()
|
||||||
|
|
||||||
asyncio.run(run_once())
|
asyncio.run(run_once())
|
||||||
else:
|
else:
|
||||||
# Interactive mode
|
# Interactive mode — route through bus like other channels
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
_init_prompt_session()
|
_init_prompt_session()
|
||||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||||
|
|
||||||
|
if ":" in session_id:
|
||||||
|
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||||
|
else:
|
||||||
|
cli_channel, cli_chat_id = "cli", session_id
|
||||||
|
|
||||||
def _exit_on_sigint(signum, frame):
|
def _exit_on_sigint(signum, frame):
|
||||||
_restore_terminal()
|
_restore_terminal()
|
||||||
console.print("\nGoodbye!")
|
console.print("\nGoodbye!")
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, _exit_on_sigint)
|
signal.signal(signal.SIGINT, _exit_on_sigint)
|
||||||
|
|
||||||
async def run_interactive():
|
async def run_interactive():
|
||||||
|
bus_task = asyncio.create_task(agent_loop.run())
|
||||||
|
turn_done = asyncio.Event()
|
||||||
|
turn_done.set()
|
||||||
|
turn_response: list[str] = []
|
||||||
|
|
||||||
|
async def _consume_outbound():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
if msg.metadata.get("_progress"):
|
||||||
|
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||||
|
ch = agent_loop.channels_config
|
||||||
|
if ch and is_tool_hint and not ch.send_tool_hints:
|
||||||
|
pass
|
||||||
|
elif ch and not is_tool_hint and not ch.send_progress:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
console.print(f" [dim]↳ {msg.content}[/dim]")
|
||||||
|
elif not turn_done.is_set():
|
||||||
|
if msg.content:
|
||||||
|
turn_response.append(msg.content)
|
||||||
|
turn_done.set()
|
||||||
|
elif msg.content:
|
||||||
|
console.print()
|
||||||
|
_print_agent_response(msg.content, render_markdown=markdown)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
outbound_task = asyncio.create_task(_consume_outbound())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -532,10 +566,22 @@ def agent(
|
|||||||
_restore_terminal()
|
_restore_terminal()
|
||||||
console.print("\nGoodbye!")
|
console.print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
turn_done.clear()
|
||||||
|
turn_response.clear()
|
||||||
|
|
||||||
|
await bus.publish_inbound(InboundMessage(
|
||||||
|
channel=cli_channel,
|
||||||
|
sender_id="user",
|
||||||
|
chat_id=cli_chat_id,
|
||||||
|
content=user_input,
|
||||||
|
))
|
||||||
|
|
||||||
with _thinking_ctx():
|
with _thinking_ctx():
|
||||||
response = await agent_loop.process_direct(user_input, session_id, on_progress=_cli_progress)
|
await turn_done.wait()
|
||||||
_print_agent_response(response, render_markdown=markdown)
|
|
||||||
|
if turn_response:
|
||||||
|
_print_agent_response(turn_response[0], render_markdown=markdown)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
_restore_terminal()
|
_restore_terminal()
|
||||||
console.print("\nGoodbye!")
|
console.print("\nGoodbye!")
|
||||||
@@ -545,8 +591,11 @@ def agent(
|
|||||||
console.print("\nGoodbye!")
|
console.print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
finally:
|
finally:
|
||||||
|
agent_loop.stop()
|
||||||
|
outbound_task.cancel()
|
||||||
|
await asyncio.gather(bus_task, outbound_task, return_exceptions=True)
|
||||||
await agent_loop.close_mcp()
|
await agent_loop.close_mcp()
|
||||||
|
|
||||||
asyncio.run(run_interactive())
|
asyncio.run(run_interactive())
|
||||||
|
|
||||||
|
|
||||||
@@ -622,6 +671,33 @@ def channels_status():
|
|||||||
slack_config
|
slack_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# DingTalk
|
||||||
|
dt = config.channels.dingtalk
|
||||||
|
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
|
||||||
|
table.add_row(
|
||||||
|
"DingTalk",
|
||||||
|
"✓" if dt.enabled else "✗",
|
||||||
|
dt_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# QQ
|
||||||
|
qq = config.channels.qq
|
||||||
|
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
|
||||||
|
table.add_row(
|
||||||
|
"QQ",
|
||||||
|
"✓" if qq.enabled else "✗",
|
||||||
|
qq_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Email
|
||||||
|
em = config.channels.email
|
||||||
|
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
|
||||||
|
table.add_row(
|
||||||
|
"Email",
|
||||||
|
"✓" if em.enabled else "✗",
|
||||||
|
em_config
|
||||||
|
)
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
@@ -805,15 +881,19 @@ def cron_add(
|
|||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
job = service.add_job(
|
try:
|
||||||
name=name,
|
job = service.add_job(
|
||||||
schedule=schedule,
|
name=name,
|
||||||
message=message,
|
schedule=schedule,
|
||||||
deliver=deliver,
|
message=message,
|
||||||
to=to,
|
deliver=deliver,
|
||||||
channel=channel,
|
to=to,
|
||||||
)
|
channel=channel,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
console.print(f"[red]Error: {e}[/red]")
|
||||||
|
raise typer.Exit(1) from e
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
|
console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
|
||||||
|
|
||||||
|
|
||||||
@@ -860,17 +940,57 @@ def cron_run(
|
|||||||
force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
|
force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
|
||||||
):
|
):
|
||||||
"""Manually run a job."""
|
"""Manually run a job."""
|
||||||
from nanobot.config.loader import get_data_dir
|
from loguru import logger
|
||||||
|
from nanobot.config.loader import load_config, get_data_dir
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
from nanobot.cron.types import CronJob
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
logger.disable("nanobot")
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
provider = _make_provider(config)
|
||||||
|
bus = MessageBus()
|
||||||
|
agent_loop = AgentLoop(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=config.workspace_path,
|
||||||
|
model=config.agents.defaults.model,
|
||||||
|
temperature=config.agents.defaults.temperature,
|
||||||
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
|
memory_window=config.agents.defaults.memory_window,
|
||||||
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
|
exec_config=config.tools.exec,
|
||||||
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
mcp_servers=config.tools.mcp_servers,
|
||||||
|
channels_config=config.channels,
|
||||||
|
)
|
||||||
|
|
||||||
store_path = get_data_dir() / "cron" / "jobs.json"
|
store_path = get_data_dir() / "cron" / "jobs.json"
|
||||||
service = CronService(store_path)
|
service = CronService(store_path)
|
||||||
|
|
||||||
|
result_holder = []
|
||||||
|
|
||||||
|
async def on_job(job: CronJob) -> str | None:
|
||||||
|
response = await agent_loop.process_direct(
|
||||||
|
job.payload.message,
|
||||||
|
session_key=f"cron:{job.id}",
|
||||||
|
channel=job.payload.channel or "cli",
|
||||||
|
chat_id=job.payload.to or "direct",
|
||||||
|
)
|
||||||
|
result_holder.append(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
service.on_job = on_job
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
return await service.run_job(job_id, force=force)
|
return await service.run_job(job_id, force=force)
|
||||||
|
|
||||||
if asyncio.run(run()):
|
if asyncio.run(run()):
|
||||||
console.print(f"[green]✓[/green] Job executed")
|
console.print("[green]✓[/green] Job executed")
|
||||||
|
if result_holder:
|
||||||
|
_print_agent_response(result_holder[0], render_markdown=True)
|
||||||
else:
|
else:
|
||||||
console.print(f"[red]Failed to run job {job_id}[/red]")
|
console.print(f"[red]Failed to run job {job_id}[/red]")
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ def load_config(config_path: Path | None = None) -> Config:
|
|||||||
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
try:
|
try:
|
||||||
with open(path) as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
data = _migrate_config(data)
|
data = _migrate_config(data)
|
||||||
return Config.model_validate(data)
|
return Config.model_validate(data)
|
||||||
@@ -55,8 +55,8 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
|
|||||||
|
|
||||||
data = config.model_dump(by_alias=True)
|
data = config.model_dump(by_alias=True)
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def _migrate_config(data: dict) -> dict:
|
def _migrate_config(data: dict) -> dict:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class TelegramConfig(Base):
|
|||||||
token: str = "" # Bot token from @BotFather
|
token: str = "" # Bot token from @BotFather
|
||||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
|
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
|
||||||
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||||
|
reply_to_message: bool = False # If true, bot replies quote the original message
|
||||||
|
|
||||||
|
|
||||||
class FeishuConfig(Base):
|
class FeishuConfig(Base):
|
||||||
@@ -189,6 +190,8 @@ class QQConfig(Base):
|
|||||||
class ChannelsConfig(Base):
|
class ChannelsConfig(Base):
|
||||||
"""Configuration for chat channels."""
|
"""Configuration for chat channels."""
|
||||||
|
|
||||||
|
send_progress: bool = True # stream agent's text progress to the channel
|
||||||
|
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||||
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
|
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
|
||||||
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
|
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
|
||||||
discord: DiscordConfig = Field(default_factory=DiscordConfig)
|
discord: DiscordConfig = Field(default_factory=DiscordConfig)
|
||||||
@@ -206,10 +209,11 @@ class AgentDefaults(Base):
|
|||||||
|
|
||||||
workspace: str = "~/.nanobot/workspace"
|
workspace: str = "~/.nanobot/workspace"
|
||||||
model: str = "anthropic/claude-opus-4-5"
|
model: str = "anthropic/claude-opus-4-5"
|
||||||
|
provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
temperature: float = 0.7
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 20
|
max_tool_iterations: int = 40
|
||||||
memory_window: int = 50
|
memory_window: int = 100
|
||||||
|
|
||||||
|
|
||||||
class AgentsConfig(Base):
|
class AgentsConfig(Base):
|
||||||
@@ -243,15 +247,24 @@ class ProvidersConfig(Base):
|
|||||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
|
||||||
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway
|
||||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||||
|
|
||||||
|
|
||||||
|
class HeartbeatConfig(Base):
|
||||||
|
"""Heartbeat service configuration."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
interval_s: int = 30 * 60 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
class GatewayConfig(Base):
|
class GatewayConfig(Base):
|
||||||
"""Gateway/server configuration."""
|
"""Gateway/server configuration."""
|
||||||
|
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 18790
|
port: int = 18790
|
||||||
|
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||||
|
|
||||||
|
|
||||||
class WebSearchConfig(Base):
|
class WebSearchConfig(Base):
|
||||||
@@ -271,6 +284,7 @@ class ExecToolConfig(Base):
|
|||||||
"""Shell exec tool configuration."""
|
"""Shell exec tool configuration."""
|
||||||
|
|
||||||
timeout: int = 60
|
timeout: int = 60
|
||||||
|
path_append: str = ""
|
||||||
|
|
||||||
|
|
||||||
class MCPServerConfig(Base):
|
class MCPServerConfig(Base):
|
||||||
@@ -280,6 +294,8 @@ class MCPServerConfig(Base):
|
|||||||
args: list[str] = Field(default_factory=list) # Stdio: command arguments
|
args: list[str] = Field(default_factory=list) # Stdio: command arguments
|
||||||
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
|
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
|
||||||
url: str = "" # HTTP: streamable HTTP endpoint URL
|
url: str = "" # HTTP: streamable HTTP endpoint URL
|
||||||
|
headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
|
||||||
|
tool_timeout: int = 30 # Seconds before a tool call is cancelled
|
||||||
|
|
||||||
|
|
||||||
class ToolsConfig(Base):
|
class ToolsConfig(Base):
|
||||||
@@ -309,6 +325,11 @@ class Config(BaseSettings):
|
|||||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS
|
||||||
|
|
||||||
|
forced = self.agents.defaults.provider
|
||||||
|
if forced != "auto":
|
||||||
|
p = getattr(self.providers, forced, None)
|
||||||
|
return (p, forced) if p else (None, None)
|
||||||
|
|
||||||
model_lower = (model or self.agents.defaults.model).lower()
|
model_lower = (model or self.agents.defaults.model).lower()
|
||||||
model_normalized = model_lower.replace("-", "_")
|
model_normalized = model_lower.replace("-", "_")
|
||||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||||
|
|||||||
@@ -45,6 +45,20 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_schedule_for_add(schedule: CronSchedule) -> None:
|
||||||
|
"""Validate schedule fields that would otherwise create non-runnable jobs."""
|
||||||
|
if schedule.tz and schedule.kind != "cron":
|
||||||
|
raise ValueError("tz can only be used with cron schedules")
|
||||||
|
|
||||||
|
if schedule.kind == "cron" and schedule.tz:
|
||||||
|
try:
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
ZoneInfo(schedule.tz)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f"unknown timezone '{schedule.tz}'") from None
|
||||||
|
|
||||||
|
|
||||||
class CronService:
|
class CronService:
|
||||||
"""Service for managing and executing scheduled jobs."""
|
"""Service for managing and executing scheduled jobs."""
|
||||||
|
|
||||||
@@ -66,7 +80,7 @@ class CronService:
|
|||||||
|
|
||||||
if self.store_path.exists():
|
if self.store_path.exists():
|
||||||
try:
|
try:
|
||||||
data = json.loads(self.store_path.read_text())
|
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||||
jobs = []
|
jobs = []
|
||||||
for j in data.get("jobs", []):
|
for j in data.get("jobs", []):
|
||||||
jobs.append(CronJob(
|
jobs.append(CronJob(
|
||||||
@@ -99,7 +113,7 @@ class CronService:
|
|||||||
))
|
))
|
||||||
self._store = CronStore(jobs=jobs)
|
self._store = CronStore(jobs=jobs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load cron store: {e}")
|
logger.warning("Failed to load cron store: {}", e)
|
||||||
self._store = CronStore()
|
self._store = CronStore()
|
||||||
else:
|
else:
|
||||||
self._store = CronStore()
|
self._store = CronStore()
|
||||||
@@ -148,7 +162,7 @@ class CronService:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
self.store_path.write_text(json.dumps(data, indent=2))
|
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the cron service."""
|
"""Start the cron service."""
|
||||||
@@ -157,7 +171,7 @@ class CronService:
|
|||||||
self._recompute_next_runs()
|
self._recompute_next_runs()
|
||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
logger.info(f"Cron service started with {len(self._store.jobs if self._store else [])} jobs")
|
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the cron service."""
|
"""Stop the cron service."""
|
||||||
@@ -222,7 +236,7 @@ class CronService:
|
|||||||
async def _execute_job(self, job: CronJob) -> None:
|
async def _execute_job(self, job: CronJob) -> None:
|
||||||
"""Execute a single job."""
|
"""Execute a single job."""
|
||||||
start_ms = _now_ms()
|
start_ms = _now_ms()
|
||||||
logger.info(f"Cron: executing job '{job.name}' ({job.id})")
|
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = None
|
response = None
|
||||||
@@ -231,12 +245,12 @@ class CronService:
|
|||||||
|
|
||||||
job.state.last_status = "ok"
|
job.state.last_status = "ok"
|
||||||
job.state.last_error = None
|
job.state.last_error = None
|
||||||
logger.info(f"Cron: job '{job.name}' completed")
|
logger.info("Cron: job '{}' completed", job.name)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
job.state.last_status = "error"
|
job.state.last_status = "error"
|
||||||
job.state.last_error = str(e)
|
job.state.last_error = str(e)
|
||||||
logger.error(f"Cron: job '{job.name}' failed: {e}")
|
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
||||||
|
|
||||||
job.state.last_run_at_ms = start_ms
|
job.state.last_run_at_ms = start_ms
|
||||||
job.updated_at_ms = _now_ms()
|
job.updated_at_ms = _now_ms()
|
||||||
@@ -272,6 +286,7 @@ class CronService:
|
|||||||
) -> CronJob:
|
) -> CronJob:
|
||||||
"""Add a new job."""
|
"""Add a new job."""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
|
_validate_schedule_for_add(schedule)
|
||||||
now = _now_ms()
|
now = _now_ms()
|
||||||
|
|
||||||
job = CronJob(
|
job = CronJob(
|
||||||
@@ -296,7 +311,7 @@ class CronService:
|
|||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
|
|
||||||
logger.info(f"Cron: added job '{name}' ({job.id})")
|
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||||
return job
|
return job
|
||||||
|
|
||||||
def remove_job(self, job_id: str) -> bool:
|
def remove_job(self, job_id: str) -> bool:
|
||||||
@@ -309,7 +324,7 @@ class CronService:
|
|||||||
if removed:
|
if removed:
|
||||||
self._save_store()
|
self._save_store()
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
logger.info(f"Cron: removed job {job_id}")
|
logger.info("Cron: removed job {}", job_id)
|
||||||
|
|
||||||
return removed
|
return removed
|
||||||
|
|
||||||
|
|||||||
@@ -1,92 +1,130 @@
|
|||||||
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# Default interval: 30 minutes
|
if TYPE_CHECKING:
|
||||||
DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
# The prompt sent to agent during heartbeat
|
_HEARTBEAT_TOOL = [
|
||||||
HEARTBEAT_PROMPT = """Read HEARTBEAT.md in your workspace (if it exists).
|
{
|
||||||
Follow any instructions or tasks listed there.
|
"type": "function",
|
||||||
If nothing needs attention, reply with just: HEARTBEAT_OK"""
|
"function": {
|
||||||
|
"name": "heartbeat",
|
||||||
# Token that indicates "nothing to do"
|
"description": "Report heartbeat decision after reviewing tasks.",
|
||||||
HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK"
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
def _is_heartbeat_empty(content: str | None) -> bool:
|
"action": {
|
||||||
"""Check if HEARTBEAT.md has no actionable content."""
|
"type": "string",
|
||||||
if not content:
|
"enum": ["skip", "run"],
|
||||||
return True
|
"description": "skip = nothing to do, run = has active tasks",
|
||||||
|
},
|
||||||
# Lines to skip: empty, headers, HTML comments, empty checkboxes
|
"tasks": {
|
||||||
skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"}
|
"type": "string",
|
||||||
|
"description": "Natural-language summary of active tasks (required for run)",
|
||||||
for line in content.split("\n"):
|
},
|
||||||
line = line.strip()
|
},
|
||||||
if not line or line.startswith("#") or line.startswith("<!--") or line in skip_patterns:
|
"required": ["action"],
|
||||||
continue
|
},
|
||||||
return False # Found actionable content
|
},
|
||||||
|
}
|
||||||
return True
|
]
|
||||||
|
|
||||||
|
|
||||||
class HeartbeatService:
|
class HeartbeatService:
|
||||||
"""
|
"""
|
||||||
Periodic heartbeat service that wakes the agent to check for tasks.
|
Periodic heartbeat service that wakes the agent to check for tasks.
|
||||||
|
|
||||||
The agent reads HEARTBEAT.md from the workspace and executes any
|
Phase 1 (decision): reads HEARTBEAT.md and asks the LLM — via a virtual
|
||||||
tasks listed there. If nothing needs attention, it replies HEARTBEAT_OK.
|
tool call — whether there are active tasks. This avoids free-text parsing
|
||||||
|
and the unreliable HEARTBEAT_OK token.
|
||||||
|
|
||||||
|
Phase 2 (execution): only triggered when Phase 1 returns ``run``. The
|
||||||
|
``on_execute`` callback runs the task through the full agent loop and
|
||||||
|
returns the result to deliver.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
on_heartbeat: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
provider: LLMProvider,
|
||||||
interval_s: int = DEFAULT_HEARTBEAT_INTERVAL_S,
|
model: str,
|
||||||
|
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||||
|
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||||
|
interval_s: int = 30 * 60,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
):
|
):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.on_heartbeat = on_heartbeat
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.on_execute = on_execute
|
||||||
|
self.on_notify = on_notify
|
||||||
self.interval_s = interval_s
|
self.interval_s = interval_s
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def heartbeat_file(self) -> Path:
|
def heartbeat_file(self) -> Path:
|
||||||
return self.workspace / "HEARTBEAT.md"
|
return self.workspace / "HEARTBEAT.md"
|
||||||
|
|
||||||
def _read_heartbeat_file(self) -> str | None:
|
def _read_heartbeat_file(self) -> str | None:
|
||||||
"""Read HEARTBEAT.md content."""
|
|
||||||
if self.heartbeat_file.exists():
|
if self.heartbeat_file.exists():
|
||||||
try:
|
try:
|
||||||
return self.heartbeat_file.read_text()
|
return self.heartbeat_file.read_text(encoding="utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _decide(self, content: str) -> tuple[str, str]:
|
||||||
|
"""Phase 1: ask LLM to decide skip/run via virtual tool call.
|
||||||
|
|
||||||
|
Returns (action, tasks) where action is 'skip' or 'run'.
|
||||||
|
"""
|
||||||
|
response = await self.provider.chat(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||||
|
{"role": "user", "content": (
|
||||||
|
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
|
||||||
|
f"{content}"
|
||||||
|
)},
|
||||||
|
],
|
||||||
|
tools=_HEARTBEAT_TOOL,
|
||||||
|
model=self.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.has_tool_calls:
|
||||||
|
return "skip", ""
|
||||||
|
|
||||||
|
args = response.tool_calls[0].arguments
|
||||||
|
return args.get("action", "skip"), args.get("tasks", "")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the heartbeat service."""
|
"""Start the heartbeat service."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
logger.info("Heartbeat disabled")
|
logger.info("Heartbeat disabled")
|
||||||
return
|
return
|
||||||
|
if self._running:
|
||||||
|
logger.warning("Heartbeat already running")
|
||||||
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
self._task = asyncio.create_task(self._run_loop())
|
||||||
logger.info(f"Heartbeat started (every {self.interval_s}s)")
|
logger.info("Heartbeat started (every {}s)", self.interval_s)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the heartbeat service."""
|
"""Stop the heartbeat service."""
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
self._task = None
|
self._task = None
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
async def _run_loop(self) -> None:
|
||||||
"""Main heartbeat loop."""
|
"""Main heartbeat loop."""
|
||||||
while self._running:
|
while self._running:
|
||||||
@@ -97,34 +135,39 @@ class HeartbeatService:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Heartbeat error: {e}")
|
logger.error("Heartbeat error: {}", e)
|
||||||
|
|
||||||
async def _tick(self) -> None:
|
async def _tick(self) -> None:
|
||||||
"""Execute a single heartbeat tick."""
|
"""Execute a single heartbeat tick."""
|
||||||
content = self._read_heartbeat_file()
|
content = self._read_heartbeat_file()
|
||||||
|
if not content:
|
||||||
# Skip if HEARTBEAT.md is empty or doesn't exist
|
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
|
||||||
if _is_heartbeat_empty(content):
|
|
||||||
logger.debug("Heartbeat: no tasks (HEARTBEAT.md empty)")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Heartbeat: checking for tasks...")
|
logger.info("Heartbeat: checking for tasks...")
|
||||||
|
|
||||||
if self.on_heartbeat:
|
try:
|
||||||
try:
|
action, tasks = await self._decide(content)
|
||||||
response = await self.on_heartbeat(HEARTBEAT_PROMPT)
|
|
||||||
|
if action != "run":
|
||||||
# Check if agent said "nothing to do"
|
logger.info("Heartbeat: OK (nothing to report)")
|
||||||
if HEARTBEAT_OK_TOKEN.replace("_", "") in response.upper().replace("_", ""):
|
return
|
||||||
logger.info("Heartbeat: OK (no action needed)")
|
|
||||||
else:
|
logger.info("Heartbeat: tasks found, executing...")
|
||||||
logger.info(f"Heartbeat: completed task")
|
if self.on_execute:
|
||||||
|
response = await self.on_execute(tasks)
|
||||||
except Exception as e:
|
if response and self.on_notify:
|
||||||
logger.error(f"Heartbeat execution failed: {e}")
|
logger.info("Heartbeat: completed, delivering response")
|
||||||
|
await self.on_notify(response)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Heartbeat execution failed")
|
||||||
|
|
||||||
async def trigger_now(self) -> str | None:
|
async def trigger_now(self) -> str | None:
|
||||||
"""Manually trigger a heartbeat."""
|
"""Manually trigger a heartbeat."""
|
||||||
if self.on_heartbeat:
|
content = self._read_heartbeat_file()
|
||||||
return await self.on_heartbeat(HEARTBEAT_PROMPT)
|
if not content:
|
||||||
return None
|
return None
|
||||||
|
action, tasks = await self._decide(content)
|
||||||
|
if action != "run" or not self.on_execute:
|
||||||
|
return None
|
||||||
|
return await self.on_execute(tasks)
|
||||||
|
|||||||
@@ -39,6 +39,46 @@ class LLMProvider(ABC):
|
|||||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Replace empty text content that causes provider 400 errors.
|
||||||
|
|
||||||
|
Empty content can appear when MCP tools return nothing. Most providers
|
||||||
|
reject empty-string content or empty text blocks in list content.
|
||||||
|
"""
|
||||||
|
result: list[dict[str, Any]] = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
if isinstance(content, str) and not content:
|
||||||
|
clean = dict(msg)
|
||||||
|
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
||||||
|
result.append(clean)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
filtered = [
|
||||||
|
item for item in content
|
||||||
|
if not (
|
||||||
|
isinstance(item, dict)
|
||||||
|
and item.get("type") in ("text", "input_text", "output_text")
|
||||||
|
and not item.get("text")
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if len(filtered) != len(content):
|
||||||
|
clean = dict(msg)
|
||||||
|
if filtered:
|
||||||
|
clean["content"] = filtered
|
||||||
|
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||||
|
clean["content"] = None
|
||||||
|
else:
|
||||||
|
clean["content"] = "(empty)"
|
||||||
|
result.append(clean)
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(msg)
|
||||||
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|||||||
@@ -19,8 +19,12 @@ class CustomProvider(LLMProvider):
|
|||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
|
||||||
kwargs: dict[str, Any] = {"model": model or self.default_model, "messages": messages,
|
kwargs: dict[str, Any] = {
|
||||||
"max_tokens": max(1, max_tokens), "temperature": temperature}
|
"model": model or self.default_model,
|
||||||
|
"messages": self._sanitize_empty_content(messages),
|
||||||
|
"max_tokens": max(1, max_tokens),
|
||||||
|
"temperature": temperature,
|
||||||
|
}
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice="auto")
|
kwargs.update(tools=tools, tool_choice="auto")
|
||||||
try:
|
try:
|
||||||
@@ -40,8 +44,9 @@ class CustomProvider(LLMProvider):
|
|||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
||||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||||
reasoning_content=getattr(msg, "reasoning_content", None),
|
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
return self.default_model
|
return self.default_model
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
|
|
||||||
|
|
||||||
|
# Standard OpenAI chat-completion message keys plus reasoning_content for
|
||||||
|
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
|
||||||
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
"""
|
"""
|
||||||
LLM provider using LiteLLM for multi-provider support.
|
LLM provider using LiteLLM for multi-provider support.
|
||||||
@@ -104,6 +109,39 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
return model
|
return model
|
||||||
return f"{canonical_prefix}/{remainder}"
|
return f"{canonical_prefix}/{remainder}"
|
||||||
|
|
||||||
|
def _supports_cache_control(self, model: str) -> bool:
|
||||||
|
"""Return True when the provider supports cache_control on content blocks."""
|
||||||
|
if self._gateway is not None:
|
||||||
|
return self._gateway.supports_prompt_caching
|
||||||
|
spec = find_by_model(model)
|
||||||
|
return spec is not None and spec.supports_prompt_caching
|
||||||
|
|
||||||
|
def _apply_cache_control(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None,
|
||||||
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||||
|
"""Return copies of messages and tools with cache_control injected."""
|
||||||
|
new_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") == "system":
|
||||||
|
content = msg["content"]
|
||||||
|
if isinstance(content, str):
|
||||||
|
new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||||
|
else:
|
||||||
|
new_content = list(content)
|
||||||
|
new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
|
||||||
|
new_messages.append({**msg, "content": new_content})
|
||||||
|
else:
|
||||||
|
new_messages.append(msg)
|
||||||
|
|
||||||
|
new_tools = tools
|
||||||
|
if tools:
|
||||||
|
new_tools = list(tools)
|
||||||
|
new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}
|
||||||
|
|
||||||
|
return new_messages, new_tools
|
||||||
|
|
||||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||||
"""Apply model-specific parameter overrides from the registry."""
|
"""Apply model-specific parameter overrides from the registry."""
|
||||||
model_lower = model.lower()
|
model_lower = model.lower()
|
||||||
@@ -114,6 +152,18 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
kwargs.update(overrides)
|
kwargs.update(overrides)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||||
|
sanitized = []
|
||||||
|
for msg in messages:
|
||||||
|
clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
|
||||||
|
# Strict providers require "content" even when assistant only has tool_calls
|
||||||
|
if clean.get("role") == "assistant" and "content" not in clean:
|
||||||
|
clean["content"] = None
|
||||||
|
sanitized.append(clean)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
@@ -135,15 +185,19 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
Returns:
|
Returns:
|
||||||
LLMResponse with content and/or tool calls.
|
LLMResponse with content and/or tool calls.
|
||||||
"""
|
"""
|
||||||
model = self._resolve_model(model or self.default_model)
|
original_model = model or self.default_model
|
||||||
|
model = self._resolve_model(original_model)
|
||||||
|
|
||||||
|
if self._supports_cache_control(original_model):
|
||||||
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||||
max_tokens = max(1, max_tokens)
|
max_tokens = max(1, max_tokens)
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
@@ -204,7 +258,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
"total_tokens": response.usage.total_tokens,
|
"total_tokens": response.usage.total_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
reasoning_content = getattr(message, "reasoning_content", None)
|
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=message.content,
|
content=message.content,
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st
|
|||||||
|
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||||
output_text = content if isinstance(content, str) else json.dumps(content)
|
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||||
input_items.append(
|
input_items.append(
|
||||||
{
|
{
|
||||||
"type": "function_call_output",
|
"type": "function_call_output",
|
||||||
|
|||||||
@@ -57,6 +57,9 @@ class ProviderSpec:
|
|||||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||||
is_direct: bool = False
|
is_direct: bool = False
|
||||||
|
|
||||||
|
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||||
|
supports_prompt_caching: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def label(self) -> str:
|
def label(self) -> str:
|
||||||
return self.display_name or self.name.title()
|
return self.display_name or self.name.title()
|
||||||
@@ -97,6 +100,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
default_api_base="https://openrouter.ai/api/v1",
|
default_api_base="https://openrouter.ai/api/v1",
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||||
@@ -137,6 +141,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
||||||
|
ProviderSpec(
|
||||||
|
name="volcengine",
|
||||||
|
keywords=("volcengine", "volces", "ark"),
|
||||||
|
env_key="OPENAI_API_KEY",
|
||||||
|
display_name="VolcEngine",
|
||||||
|
litellm_prefix="volcengine",
|
||||||
|
skip_prefixes=(),
|
||||||
|
env_extras=(),
|
||||||
|
is_gateway=True,
|
||||||
|
is_local=False,
|
||||||
|
detect_by_key_prefix="",
|
||||||
|
detect_by_base_keyword="volces",
|
||||||
|
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||||
|
strip_model_prefix=False,
|
||||||
|
model_overrides=(),
|
||||||
|
),
|
||||||
|
|
||||||
# === Standard providers (matched by model-name keywords) ===============
|
# === Standard providers (matched by model-name keywords) ===============
|
||||||
|
|
||||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||||
@@ -155,6 +177,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
default_api_base="",
|
default_api_base="",
|
||||||
strip_model_prefix=False,
|
strip_model_prefix=False,
|
||||||
model_overrides=(),
|
model_overrides=(),
|
||||||
|
supports_prompt_caching=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class GroqTranscriptionProvider:
|
|||||||
|
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
logger.error(f"Audio file not found: {file_path}")
|
logger.error("Audio file not found: {}", file_path)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -61,5 +61,5 @@ class GroqTranscriptionProvider:
|
|||||||
return data.get("text", "")
|
return data.get("text", "")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Groq transcription error: {e}")
|
logger.error("Groq transcription error: {}", e)
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Session management for conversation history."""
|
"""Session management for conversation history."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -42,9 +43,18 @@ class Session:
|
|||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Get recent messages in LLM format, preserving tool metadata."""
|
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
||||||
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
|
sliced = unconsolidated[-max_messages:]
|
||||||
|
|
||||||
|
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
||||||
|
for i, m in enumerate(sliced):
|
||||||
|
if m.get("role") == "user":
|
||||||
|
sliced = sliced[i:]
|
||||||
|
break
|
||||||
|
|
||||||
out: list[dict[str, Any]] = []
|
out: list[dict[str, Any]] = []
|
||||||
for m in self.messages[-max_messages:]:
|
for m in sliced:
|
||||||
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
|
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
|
||||||
for k in ("tool_calls", "tool_call_id", "name"):
|
for k in ("tool_calls", "tool_call_id", "name"):
|
||||||
if k in m:
|
if k in m:
|
||||||
@@ -108,9 +118,11 @@ class SessionManager:
|
|||||||
if not path.exists():
|
if not path.exists():
|
||||||
legacy_path = self._get_legacy_session_path(key)
|
legacy_path = self._get_legacy_session_path(key)
|
||||||
if legacy_path.exists():
|
if legacy_path.exists():
|
||||||
import shutil
|
try:
|
||||||
shutil.move(str(legacy_path), str(path))
|
shutil.move(str(legacy_path), str(path))
|
||||||
logger.info(f"Migrated session {key} from legacy path")
|
logger.info("Migrated session {} from legacy path", key)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to migrate session {}", key)
|
||||||
|
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return None
|
return None
|
||||||
@@ -121,7 +133,7 @@ class SessionManager:
|
|||||||
created_at = None
|
created_at = None
|
||||||
last_consolidated = 0
|
last_consolidated = 0
|
||||||
|
|
||||||
with open(path) as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line:
|
if not line:
|
||||||
@@ -144,24 +156,25 @@ class SessionManager:
|
|||||||
last_consolidated=last_consolidated
|
last_consolidated=last_consolidated
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load session {key}: {e}")
|
logger.warning("Failed to load session {}: {}", key, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save(self, session: Session) -> None:
|
def save(self, session: Session) -> None:
|
||||||
"""Save a session to disk."""
|
"""Save a session to disk."""
|
||||||
path = self._get_session_path(session.key)
|
path = self._get_session_path(session.key)
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
metadata_line = {
|
metadata_line = {
|
||||||
"_type": "metadata",
|
"_type": "metadata",
|
||||||
|
"key": session.key,
|
||||||
"created_at": session.created_at.isoformat(),
|
"created_at": session.created_at.isoformat(),
|
||||||
"updated_at": session.updated_at.isoformat(),
|
"updated_at": session.updated_at.isoformat(),
|
||||||
"metadata": session.metadata,
|
"metadata": session.metadata,
|
||||||
"last_consolidated": session.last_consolidated
|
"last_consolidated": session.last_consolidated
|
||||||
}
|
}
|
||||||
f.write(json.dumps(metadata_line) + "\n")
|
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
||||||
for msg in session.messages:
|
for msg in session.messages:
|
||||||
f.write(json.dumps(msg) + "\n")
|
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
self._cache[session.key] = session
|
self._cache[session.key] = session
|
||||||
|
|
||||||
@@ -181,13 +194,14 @@ class SessionManager:
|
|||||||
for path in self.sessions_dir.glob("*.jsonl"):
|
for path in self.sessions_dir.glob("*.jsonl"):
|
||||||
try:
|
try:
|
||||||
# Read just the metadata line
|
# Read just the metadata line
|
||||||
with open(path) as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
first_line = f.readline().strip()
|
first_line = f.readline().strip()
|
||||||
if first_line:
|
if first_line:
|
||||||
data = json.loads(first_line)
|
data = json.loads(first_line)
|
||||||
if data.get("_type") == "metadata":
|
if data.get("_type") == "metadata":
|
||||||
|
key = data.get("key") or path.stem.replace("_", ":", 1)
|
||||||
sessions.append({
|
sessions.append({
|
||||||
"key": path.stem.replace("_", ":"),
|
"key": key,
|
||||||
"created_at": data.get("created_at"),
|
"created_at": data.get("created_at"),
|
||||||
"updated_at": data.get("updated_at"),
|
"updated_at": data.get("updated_at"),
|
||||||
"path": str(path)
|
"path": str(path)
|
||||||
|
|||||||
23
nanobot/templates/AGENTS.md
Normal file
23
nanobot/templates/AGENTS.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Agent Instructions
|
||||||
|
|
||||||
|
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
||||||
|
|
||||||
|
## Scheduled Reminders
|
||||||
|
|
||||||
|
When user asks for a reminder at a specific time, use `exec` to run:
|
||||||
|
```
|
||||||
|
nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
|
||||||
|
```
|
||||||
|
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
|
||||||
|
|
||||||
|
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
|
||||||
|
|
||||||
|
## Heartbeat Tasks
|
||||||
|
|
||||||
|
`HEARTBEAT.md` is checked every 30 minutes. Use file tools to manage periodic tasks:
|
||||||
|
|
||||||
|
- **Add**: `edit_file` to append new tasks
|
||||||
|
- **Remove**: `edit_file` to delete completed tasks
|
||||||
|
- **Rewrite**: `write_file` to replace all tasks
|
||||||
|
|
||||||
|
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
|
||||||
15
nanobot/templates/TOOLS.md
Normal file
15
nanobot/templates/TOOLS.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Tool Usage Notes
|
||||||
|
|
||||||
|
Tool signatures are provided automatically via function calling.
|
||||||
|
This file documents non-obvious constraints and usage patterns.
|
||||||
|
|
||||||
|
## exec — Safety Limits
|
||||||
|
|
||||||
|
- Commands have a configurable timeout (default 60s)
|
||||||
|
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
|
||||||
|
- Output is truncated at 10,000 characters
|
||||||
|
- `restrictToWorkspace` config can limit file access to the workspace
|
||||||
|
|
||||||
|
## cron — Scheduled Reminders
|
||||||
|
|
||||||
|
- Please refer to cron skill for usage.
|
||||||
0
nanobot/templates/__init__.py
Normal file
0
nanobot/templates/__init__.py
Normal file
0
nanobot/templates/memory/__init__.py
Normal file
0
nanobot/templates/memory/__init__.py
Normal file
@@ -3,7 +3,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
def ensure_dir(path: Path) -> Path:
|
def ensure_dir(path: Path) -> Path:
|
||||||
"""Ensure a directory exists, creating it if necessary."""
|
"""Ensure a directory exists, creating it if necessary."""
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -77,4 +76,4 @@ def parse_session_key(key: str) -> tuple[str, str]:
|
|||||||
parts = key.split(":", 1)
|
parts = key.split(":", 1)
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
raise ValueError(f"Invalid session key: {key}")
|
raise ValueError(f"Invalid session key: {key}")
|
||||||
return parts[0], parts[1]
|
return parts[0], parts[1]
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "nanobot-ai"
|
name = "nanobot-ai"
|
||||||
version = "0.1.4"
|
version = "0.1.4.post2"
|
||||||
description = "A lightweight personal AI assistant framework"
|
description = "A lightweight personal AI assistant framework"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
@@ -67,10 +67,11 @@ packages = ["nanobot"]
|
|||||||
[tool.hatch.build.targets.wheel.sources]
|
[tool.hatch.build.targets.wheel.sources]
|
||||||
"nanobot" = "nanobot"
|
"nanobot" = "nanobot"
|
||||||
|
|
||||||
# Include non-Python files in skills
|
# Include non-Python files in skills and templates
|
||||||
[tool.hatch.build]
|
[tool.hatch.build]
|
||||||
include = [
|
include = [
|
||||||
"nanobot/**/*.py",
|
"nanobot/**/*.py",
|
||||||
|
"nanobot/templates/**/*.md",
|
||||||
"nanobot/skills/**/*.md",
|
"nanobot/skills/**/*.md",
|
||||||
"nanobot/skills/**/*.sh",
|
"nanobot/skills/**/*.sh",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Test session management with cache-friendly message handling."""
|
"""Test session management with cache-friendly message handling."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
@@ -475,3 +478,351 @@ class TestEmptyAndBoundarySessions:
|
|||||||
expected_count = 60 - KEEP_COUNT - 10
|
expected_count = 60 - KEEP_COUNT - 10
|
||||||
assert len(old_messages) == expected_count
|
assert len(old_messages) == expected_count
|
||||||
assert_messages_content(old_messages, 10, 34)
|
assert_messages_content(old_messages, 10, 34)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsolidationDeduplicationGuard:
|
||||||
|
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||||
|
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls
|
||||||
|
consolidation_calls += 1
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 1, (
|
||||||
|
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
active = 0
|
||||||
|
max_active = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls, active, max_active
|
||||||
|
consolidation_calls += 1
|
||||||
|
active += 1
|
||||||
|
max_active = max(max_active, active)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
active -= 1
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
await loop._process_message(new_msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 2, (
|
||||||
|
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
assert max_active == 1, (
|
||||||
|
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||||
|
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
await started.wait()
|
||||||
|
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||||
|
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
assert len(loop._consolidation_tasks) == 0, (
|
||||||
|
"Task reference must be removed after completion"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new waits for in-flight consolidation and archives before clear."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
archived_count = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
nonlocal archived_count
|
||||||
|
if archive_all:
|
||||||
|
archived_count = len(sess.messages)
|
||||||
|
return True
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||||
|
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
response = await pending_new
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||||
|
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||||
|
"""/new must keep session data if archive step reports failure."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
before_count = len(session.messages)
|
||||||
|
|
||||||
|
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
if archive_all:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "failed" in response.content.lower()
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert len(session_after.messages) == before_count, (
|
||||||
|
"Session must remain intact when /new archival fails"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new should archive only messages not yet consolidated by prior task."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
archived_count = -1
|
||||||
|
|
||||||
|
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
nonlocal archived_count
|
||||||
|
if archive_all:
|
||||||
|
archived_count = len(sess.messages)
|
||||||
|
return True
|
||||||
|
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
sess.last_consolidated = len(sess.messages) - 3
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
assert not pending_new.done()
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
response = await pending_new
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert archived_count == 3, (
|
||||||
|
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new should remove lock entry for fully invalidated session key."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(3):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
# Ensure lock exists before /new.
|
||||||
|
_ = loop._get_consolidation_lock(session.key)
|
||||||
|
assert session.key in loop._consolidation_locks
|
||||||
|
|
||||||
|
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert session.key not in loop._consolidation_locks
|
||||||
|
|||||||
66
tests/test_context_prompt_cache.py
Normal file
66
tests/test_context_prompt_cache.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""Tests for cache-friendly prompt construction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime as real_datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import datetime as datetime_module
|
||||||
|
|
||||||
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDatetime(real_datetime):
|
||||||
|
current = real_datetime(2026, 2, 24, 13, 59)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def now(cls, tz=None): # type: ignore[override]
|
||||||
|
return cls.current
|
||||||
|
|
||||||
|
|
||||||
|
def _make_workspace(tmp_path: Path) -> Path:
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
workspace.mkdir(parents=True)
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||||
|
"""System prompt should not change just because wall clock minute changes."""
|
||||||
|
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||||
|
|
||||||
|
workspace = _make_workspace(tmp_path)
|
||||||
|
builder = ContextBuilder(workspace)
|
||||||
|
|
||||||
|
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
|
||||||
|
prompt1 = builder.build_system_prompt()
|
||||||
|
|
||||||
|
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
|
||||||
|
prompt2 = builder.build_system_prompt()
|
||||||
|
|
||||||
|
assert prompt1 == prompt2
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||||
|
"""Runtime metadata should be a separate user message before the actual user message."""
|
||||||
|
workspace = _make_workspace(tmp_path)
|
||||||
|
builder = ContextBuilder(workspace)
|
||||||
|
|
||||||
|
messages = builder.build_messages(
|
||||||
|
history=[],
|
||||||
|
current_message="Return exactly: OK",
|
||||||
|
channel="cli",
|
||||||
|
chat_id="direct",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert messages[0]["role"] == "system"
|
||||||
|
assert "## Current Session" not in messages[0]["content"]
|
||||||
|
|
||||||
|
assert messages[-2]["role"] == "user"
|
||||||
|
runtime_content = messages[-2]["content"]
|
||||||
|
assert isinstance(runtime_content, str)
|
||||||
|
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
|
||||||
|
assert "Current Time:" in runtime_content
|
||||||
|
assert "Channel: cli" in runtime_content
|
||||||
|
assert "Chat ID: direct" in runtime_content
|
||||||
|
|
||||||
|
assert messages[-1]["role"] == "user"
|
||||||
|
assert messages[-1]["content"] == "Return exactly: OK"
|
||||||
29
tests/test_cron_commands.py
Normal file
29
tests/test_cron_commands.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from nanobot.cli.commands import app
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
|
||||||
|
|
||||||
|
result = runner.invoke(
|
||||||
|
app,
|
||||||
|
[
|
||||||
|
"cron",
|
||||||
|
"add",
|
||||||
|
"--name",
|
||||||
|
"demo",
|
||||||
|
"--message",
|
||||||
|
"hello",
|
||||||
|
"--cron",
|
||||||
|
"0 9 * * *",
|
||||||
|
"--tz",
|
||||||
|
"America/Vancovuer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
|
||||||
|
assert not (tmp_path / "cron" / "jobs.json").exists()
|
||||||
30
tests/test_cron_service.py
Normal file
30
tests/test_cron_service.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.cron.service import CronService
|
||||||
|
from nanobot.cron.types import CronSchedule
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
|
||||||
|
service.add_job(
|
||||||
|
name="tz typo",
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service.list_jobs(include_disabled=True) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
|
||||||
|
job = service.add_job(
|
||||||
|
name="tz ok",
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert job.schedule.tz == "America/Vancouver"
|
||||||
|
assert job.state.next_run_at_ms is not None
|
||||||
@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||||
|
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||||
class FakeSMTP:
|
class FakeSMTP:
|
||||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||||
self.sent_messages: list[EmailMessage] = []
|
self.sent_messages: list[EmailMessage] = []
|
||||||
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
cfg = _make_config()
|
cfg = _make_config()
|
||||||
cfg.auto_reply_enabled = False
|
cfg.auto_reply_enabled = False
|
||||||
channel = EmailChannel(cfg, MessageBus())
|
channel = EmailChannel(cfg, MessageBus())
|
||||||
|
|
||||||
|
# Mark alice as someone who sent us an email (making this a "reply")
|
||||||
|
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||||
|
|
||||||
|
# Reply should be skipped (auto_reply_enabled=False)
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel="email",
|
channel="email",
|
||||||
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
)
|
)
|
||||||
assert fake_instances == []
|
assert fake_instances == []
|
||||||
|
|
||||||
|
# Reply with force_send=True should be sent
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel="email",
|
channel="email",
|
||||||
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
assert len(fake_instances[0].sent_messages) == 1
|
assert len(fake_instances[0].sent_messages) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||||
|
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||||
|
class FakeSMTP:
|
||||||
|
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||||
|
self.sent_messages: list[EmailMessage] = []
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def starttls(self, context=None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def login(self, _user: str, _pw: str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_message(self, msg: EmailMessage):
|
||||||
|
self.sent_messages.append(msg)
|
||||||
|
|
||||||
|
fake_instances: list[FakeSMTP] = []
|
||||||
|
|
||||||
|
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||||
|
instance = FakeSMTP(host, port, timeout=timeout)
|
||||||
|
fake_instances.append(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||||
|
|
||||||
|
cfg = _make_config()
|
||||||
|
cfg.auto_reply_enabled = False
|
||||||
|
channel = EmailChannel(cfg, MessageBus())
|
||||||
|
|
||||||
|
# bob@example.com has never sent us an email (proactive send)
|
||||||
|
# This should be sent even with auto_reply_enabled=False
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="email",
|
||||||
|
chat_id="bob@example.com",
|
||||||
|
content="Hello, this is a proactive email.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(fake_instances) == 1
|
||||||
|
assert len(fake_instances[0].sent_messages) == 1
|
||||||
|
sent = fake_instances[0].sent_messages[0]
|
||||||
|
assert sent["To"] == "bob@example.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||||
class FakeSMTP:
|
class FakeSMTP:
|
||||||
|
|||||||
44
tests/test_heartbeat_service.py
Normal file
44
tests/test_heartbeat_service.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.heartbeat.service import (
|
||||||
|
HEARTBEAT_OK_TOKEN,
|
||||||
|
HeartbeatService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_ok_detection() -> None:
|
||||||
|
def is_ok(response: str) -> bool:
|
||||||
|
return HEARTBEAT_OK_TOKEN in response.upper()
|
||||||
|
|
||||||
|
assert is_ok("HEARTBEAT_OK")
|
||||||
|
assert is_ok("`HEARTBEAT_OK`")
|
||||||
|
assert is_ok("**HEARTBEAT_OK**")
|
||||||
|
assert is_ok("heartbeat_ok")
|
||||||
|
assert is_ok("HEARTBEAT_OK.")
|
||||||
|
|
||||||
|
assert not is_ok("HEARTBEAT_NOT_OK")
|
||||||
|
assert not is_ok("all good")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_is_idempotent(tmp_path) -> None:
|
||||||
|
async def _on_heartbeat(_: str) -> str:
|
||||||
|
return "HEARTBEAT_OK"
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
on_heartbeat=_on_heartbeat,
|
||||||
|
interval_s=9999,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await service.start()
|
||||||
|
first_task = service._task
|
||||||
|
await service.start()
|
||||||
|
|
||||||
|
assert service._task is first_task
|
||||||
|
|
||||||
|
service.stop()
|
||||||
|
await asyncio.sleep(0)
|
||||||
147
tests/test_memory_consolidation_types.py
Normal file
147
tests/test_memory_consolidation_types.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
|
||||||
|
|
||||||
|
Regression test for https://github.com/HKUDS/nanobot/issues/1042
|
||||||
|
When memory consolidation receives dict values instead of strings from the LLM
|
||||||
|
tool call response, it should serialize them to JSON instead of raising TypeError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.memory import MemoryStore
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||||
|
"""Create a mock session with messages."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||||
|
for i in range(message_count)
|
||||||
|
]
|
||||||
|
session.last_consolidated = 0
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_response(history_entry, memory_update):
|
||||||
|
"""Create an LLMResponse with a save_memory tool call."""
|
||||||
|
return LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={
|
||||||
|
"history_entry": history_entry,
|
||||||
|
"memory_update": memory_update,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryConsolidationTypeHandling:
|
||||||
|
"""Test that consolidation handles various argument types correctly."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_string_arguments_work(self, tmp_path: Path) -> None:
|
||||||
|
"""Normal case: LLM returns string arguments."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry="[2026-01-01] User discussed testing.",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert store.history_file.exists()
|
||||||
|
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
|
||||||
|
assert "User likes testing." in store.memory_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
|
||||||
|
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
|
||||||
|
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert store.history_file.exists()
|
||||||
|
history_content = store.history_file.read_text()
|
||||||
|
parsed = json.loads(history_content.strip())
|
||||||
|
assert parsed["summary"] == "User discussed testing."
|
||||||
|
|
||||||
|
memory_content = store.memory_file.read_text()
|
||||||
|
parsed_mem = json.loads(memory_content)
|
||||||
|
assert "User likes testing" in parsed_mem["facts"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
|
||||||
|
"""Some providers return arguments as a JSON string instead of parsed dict."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
|
||||||
|
# Simulate arguments being a JSON string (not yet parsed)
|
||||||
|
response = LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"history_entry": "[2026-01-01] User discussed testing.",
|
||||||
|
"memory_update": "# Memory\nUser likes testing.",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
provider.chat = AsyncMock(return_value=response)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
|
||||||
|
"""When LLM doesn't use the save_memory tool, return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||||
|
"""Consolidation should be a no-op when messages < keep_count."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
session = _make_session(message_count=10)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
provider.chat.assert_not_called()
|
||||||
167
tests/test_task_cancel.py
Normal file
167
tests/test_task_cancel.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""Tests for /stop task cancellation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop():
|
||||||
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
workspace = MagicMock()
|
||||||
|
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||||
|
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||||
|
return loop, bus
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleStop:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_no_active_task(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "No active task" in out.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_active_task(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_task():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow_task())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
loop._active_tasks["test:c1"] = [task]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
|
||||||
|
assert cancelled.is_set()
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "stopped" in out.content.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_multiple_tasks(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
events = [asyncio.Event(), asyncio.Event()]
|
||||||
|
|
||||||
|
async def slow(idx):
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
events[idx].set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
loop._active_tasks["test:c1"] = tasks
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
|
||||||
|
assert all(e.is_set() for e in events)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "2 task" in out.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestDispatch:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_processes_and_publishes(self):
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
|
||||||
|
loop._process_message = AsyncMock(
|
||||||
|
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||||
|
)
|
||||||
|
await loop._dispatch(msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert out.content == "hi"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_processing_lock_serializes(self):
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def mock_process(m, **kwargs):
|
||||||
|
order.append(f"start-{m.content}")
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
order.append(f"end-{m.content}")
|
||||||
|
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||||
|
|
||||||
|
loop._process_message = mock_process
|
||||||
|
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||||
|
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||||
|
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||||
|
await asyncio.gather(t1, t2)
|
||||||
|
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentCancellation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session(self):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mgr._running_tasks["sub-1"] = task
|
||||||
|
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||||
|
|
||||||
|
count = await mgr.cancel_by_session("test:c1")
|
||||||
|
assert count == 1
|
||||||
|
assert cancelled.is_set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session_no_tasks(self):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
# Agent Instructions
|
|
||||||
|
|
||||||
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
|
||||||
|
|
||||||
## Guidelines
|
|
||||||
|
|
||||||
- Always explain what you're doing before taking actions
|
|
||||||
- Ask for clarification when the request is ambiguous
|
|
||||||
- Use tools to help accomplish tasks
|
|
||||||
- Remember important information in your memory files
|
|
||||||
|
|
||||||
## Tools Available
|
|
||||||
|
|
||||||
You have access to:
|
|
||||||
- File operations (read, write, edit, list)
|
|
||||||
- Shell commands (exec)
|
|
||||||
- Web access (search, fetch)
|
|
||||||
- Messaging (message)
|
|
||||||
- Background tasks (spawn)
|
|
||||||
|
|
||||||
## Memory
|
|
||||||
|
|
||||||
- `memory/MEMORY.md` — long-term facts (preferences, context, relationships)
|
|
||||||
- `memory/HISTORY.md` — append-only event log, search with grep to recall past events
|
|
||||||
|
|
||||||
## Scheduled Reminders
|
|
||||||
|
|
||||||
When user asks for a reminder at a specific time, use `exec` to run:
|
|
||||||
```
|
|
||||||
nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
|
|
||||||
```
|
|
||||||
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
|
|
||||||
|
|
||||||
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
|
|
||||||
|
|
||||||
## Heartbeat Tasks
|
|
||||||
|
|
||||||
`HEARTBEAT.md` is checked every 30 minutes. You can manage periodic tasks by editing this file:
|
|
||||||
|
|
||||||
- **Add a task**: Use `edit_file` to append new tasks to `HEARTBEAT.md`
|
|
||||||
- **Remove a task**: Use `edit_file` to remove completed or obsolete tasks
|
|
||||||
- **Rewrite tasks**: Use `write_file` to completely rewrite the task list
|
|
||||||
|
|
||||||
Task format examples:
|
|
||||||
```
|
|
||||||
- [ ] Check calendar and remind of upcoming events
|
|
||||||
- [ ] Scan inbox for urgent emails
|
|
||||||
- [ ] Check weather forecast for today
|
|
||||||
```
|
|
||||||
|
|
||||||
When the user asks you to add a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time reminder. Keep the file small to minimize token usage.
|
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
# Available Tools
|
|
||||||
|
|
||||||
This document describes the tools available to nanobot.
|
|
||||||
|
|
||||||
## File Operations
|
|
||||||
|
|
||||||
### read_file
|
|
||||||
Read the contents of a file.
|
|
||||||
```
|
|
||||||
read_file(path: str) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
### write_file
|
|
||||||
Write content to a file (creates parent directories if needed).
|
|
||||||
```
|
|
||||||
write_file(path: str, content: str) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
### edit_file
|
|
||||||
Edit a file by replacing specific text.
|
|
||||||
```
|
|
||||||
edit_file(path: str, old_text: str, new_text: str) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
### list_dir
|
|
||||||
List contents of a directory.
|
|
||||||
```
|
|
||||||
list_dir(path: str) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
## Shell Execution
|
|
||||||
|
|
||||||
### exec
|
|
||||||
Execute a shell command and return output.
|
|
||||||
```
|
|
||||||
exec(command: str, working_dir: str = None) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
**Safety Notes:**
|
|
||||||
- Commands have a configurable timeout (default 60s)
|
|
||||||
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
|
|
||||||
- Output is truncated at 10,000 characters
|
|
||||||
- Optional `restrictToWorkspace` config to limit paths
|
|
||||||
|
|
||||||
## Web Access
|
|
||||||
|
|
||||||
### web_search
|
|
||||||
Search the web using Brave Search API.
|
|
||||||
```
|
|
||||||
web_search(query: str, count: int = 5) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
Returns search results with titles, URLs, and snippets. Requires `tools.web.search.apiKey` in config.
|
|
||||||
|
|
||||||
### web_fetch
|
|
||||||
Fetch and extract main content from a URL.
|
|
||||||
```
|
|
||||||
web_fetch(url: str, extractMode: str = "markdown", maxChars: int = 50000) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
**Notes:**
|
|
||||||
- Content is extracted using readability
|
|
||||||
- Supports markdown or plain text extraction
|
|
||||||
- Output is truncated at 50,000 characters by default
|
|
||||||
|
|
||||||
## Communication
|
|
||||||
|
|
||||||
### message
|
|
||||||
Send a message to the user (used internally).
|
|
||||||
```
|
|
||||||
message(content: str, channel: str = None, chat_id: str = None) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
## Background Tasks
|
|
||||||
|
|
||||||
### spawn
|
|
||||||
Spawn a subagent to handle a task in the background.
|
|
||||||
```
|
|
||||||
spawn(task: str, label: str = None) -> str
|
|
||||||
```
|
|
||||||
|
|
||||||
Use for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done.
|
|
||||||
|
|
||||||
## Scheduled Reminders (Cron)
|
|
||||||
|
|
||||||
Use the `exec` tool to create scheduled reminders with `nanobot cron add`:
|
|
||||||
|
|
||||||
### Set a recurring reminder
|
|
||||||
```bash
|
|
||||||
# Every day at 9am
|
|
||||||
nanobot cron add --name "morning" --message "Good morning! ☀️" --cron "0 9 * * *"
|
|
||||||
|
|
||||||
# Every 2 hours
|
|
||||||
nanobot cron add --name "water" --message "Drink water! 💧" --every 7200
|
|
||||||
```
|
|
||||||
|
|
||||||
### Set a one-time reminder
|
|
||||||
```bash
|
|
||||||
# At a specific time (ISO format)
|
|
||||||
nanobot cron add --name "meeting" --message "Meeting starts now!" --at "2025-01-31T15:00:00"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Manage reminders
|
|
||||||
```bash
|
|
||||||
nanobot cron list # List all jobs
|
|
||||||
nanobot cron remove <job_id> # Remove a job
|
|
||||||
```
|
|
||||||
|
|
||||||
## Heartbeat Task Management
|
|
||||||
|
|
||||||
The `HEARTBEAT.md` file in the workspace is checked every 30 minutes.
|
|
||||||
Use file operations to manage periodic tasks:
|
|
||||||
|
|
||||||
### Add a heartbeat task
|
|
||||||
```python
|
|
||||||
# Append a new task
|
|
||||||
edit_file(
|
|
||||||
path="HEARTBEAT.md",
|
|
||||||
old_text="## Example Tasks",
|
|
||||||
new_text="- [ ] New periodic task here\n\n## Example Tasks"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Remove a heartbeat task
|
|
||||||
```python
|
|
||||||
# Remove a specific task
|
|
||||||
edit_file(
|
|
||||||
path="HEARTBEAT.md",
|
|
||||||
old_text="- [ ] Task to remove\n",
|
|
||||||
new_text=""
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Rewrite all tasks
|
|
||||||
```python
|
|
||||||
# Replace the entire file
|
|
||||||
write_file(
|
|
||||||
path="HEARTBEAT.md",
|
|
||||||
content="# Heartbeat Tasks\n\n- [ ] Task 1\n- [ ] Task 2\n"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Adding Custom Tools
|
|
||||||
|
|
||||||
To add custom tools:
|
|
||||||
1. Create a class that extends `Tool` in `nanobot/agent/tools/`
|
|
||||||
2. Implement `name`, `description`, `parameters`, and `execute`
|
|
||||||
3. Register it in `AgentLoop._register_default_tools()`
|
|
||||||
Reference in New Issue
Block a user