Merge branch 'HKUDS:main' into feat-volcengine-tuning

This commit is contained in:
gaoyiman
2026-03-05 14:14:33 +08:00
committed by GitHub
55 changed files with 1640 additions and 1217 deletions

2
.gitignore vendored
View File

@@ -19,4 +19,4 @@ __pycache__/
poetry.lock poetry.lock
.pytest_cache/ .pytest_cache/
botpy.log botpy.log
tests/

View File

@@ -16,16 +16,24 @@
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
📏 Real-time line count: **3,966 lines** (run `bash core_agent_lines.sh` to verify anytime) 📏 Real-time line count: **3,935 lines** (run `bash core_agent_lines.sh` to verify anytime)
## 📢 News ## 📢 News
- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details. - **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes. - **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements. - **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details. - **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood. - **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode. - **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
<details>
<summary>Earlier news</summary>
- **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching. - **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
- **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details. - **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
- **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills. - **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
@@ -34,10 +42,6 @@
- **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details. - **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
- **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it! - **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
- **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support! - **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
<details>
<summary>Earlier news</summary>
- **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431). - **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
- **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms! - **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
- **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers). - **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
@@ -343,7 +347,7 @@ pip install nanobot-ai[matrix]
"accessToken": "syt_xxx", "accessToken": "syt_xxx",
"deviceId": "NANOBOT01", "deviceId": "NANOBOT01",
"e2eeEnabled": true, "e2eeEnabled": true,
"allowFrom": [], "allowFrom": ["@your_user:matrix.org"],
"groupPolicy": "open", "groupPolicy": "open",
"groupAllowFrom": [], "groupAllowFrom": [],
"allowRoomMentions": false, "allowRoomMentions": false,
@@ -420,7 +424,7 @@ Uses **WebSocket** long connection — no public IP required.
**1. Create a Feishu bot** **1. Create a Feishu bot**
- Visit [Feishu Open Platform](https://open.feishu.cn/app) - Visit [Feishu Open Platform](https://open.feishu.cn/app)
- Create a new app → Enable **Bot** capability - Create a new app → Enable **Bot** capability
- **Permissions**: Add `im:message` (send messages) - **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages)
- **Events**: Add `im.message.receive_v1` (receive messages) - **Events**: Add `im.message.receive_v1` (receive messages)
- Select **Long Connection** mode (requires running nanobot first to establish connection) - Select **Long Connection** mode (requires running nanobot first to establish connection)
- Get **App ID** and **App Secret** from "Credentials & Basic Info" - Get **App ID** and **App Secret** from "Credentials & Basic Info"
@@ -437,14 +441,14 @@ Uses **WebSocket** long connection — no public IP required.
"appSecret": "xxx", "appSecret": "xxx",
"encryptKey": "", "encryptKey": "",
"verificationToken": "", "verificationToken": "",
"allowFrom": [] "allowFrom": ["ou_YOUR_OPEN_ID"]
} }
} }
} }
``` ```
> `encryptKey` and `verificationToken` are optional for Long Connection mode. > `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Leave empty to allow all users, or add `["ou_xxx"]` to restrict access. > `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
**3. Run** **3. Run**
@@ -474,7 +478,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
**3. Configure** **3. Configure**
> - `allowFrom`: Leave empty for public access, or add user openids to restrict. You can find openids in the nanobot logs when a user messages the bot. > - `allowFrom`: Add your openid (find it in nanobot logs when you message the bot). Use `["*"]` for public access.
> - For production: submit a review in the bot console and publish. See [QQ Bot Docs](https://bot.q.qq.com/wiki/) for the full publishing flow. > - For production: submit a review in the bot console and publish. See [QQ Bot Docs](https://bot.q.qq.com/wiki/) for the full publishing flow.
```json ```json
@@ -484,7 +488,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
"enabled": true, "enabled": true,
"appId": "YOUR_APP_ID", "appId": "YOUR_APP_ID",
"secret": "YOUR_APP_SECRET", "secret": "YOUR_APP_SECRET",
"allowFrom": [] "allowFrom": ["YOUR_OPENID"]
} }
} }
} }
@@ -523,13 +527,13 @@ Uses **Stream Mode** — no public IP required.
"enabled": true, "enabled": true,
"clientId": "YOUR_APP_KEY", "clientId": "YOUR_APP_KEY",
"clientSecret": "YOUR_APP_SECRET", "clientSecret": "YOUR_APP_SECRET",
"allowFrom": [] "allowFrom": ["YOUR_STAFF_ID"]
} }
} }
} }
``` ```
> `allowFrom`: Leave empty to allow all users, or add `["staffId"]` to restrict access. > `allowFrom`: Add your staff ID. Use `["*"]` to allow all users.
**3. Run** **3. Run**
@@ -564,6 +568,7 @@ Uses **Socket Mode** — no public URL required.
"enabled": true, "enabled": true,
"botToken": "xoxb-...", "botToken": "xoxb-...",
"appToken": "xapp-...", "appToken": "xapp-...",
"allowFrom": ["YOUR_SLACK_USER_ID"],
"groupPolicy": "mention" "groupPolicy": "mention"
} }
} }
@@ -597,7 +602,7 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
**2. Configure** **2. Configure**
> - `consentGranted` must be `true` to allow mailbox access. This is a safety gate — set `false` to fully disable. > - `consentGranted` must be `true` to allow mailbox access. This is a safety gate — set `false` to fully disable.
> - `allowFrom`: Leave empty to accept emails from anyone, or restrict to specific senders. > - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
> - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly. > - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
> - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies. > - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
@@ -870,6 +875,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
> [!TIP] > [!TIP]
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent. > For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
> **Change in source / post-`v0.1.4.post3`:** In `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions (including building from source), **empty `allowFrom` denies all access by default**. To allow all senders, set `"allowFrom": ["*"]`.
| Option | Default | Description | | Option | Default | Description |
|--------|---------|-------------| |--------|---------|-------------|
@@ -895,23 +901,6 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`. Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`.
<details>
<summary><b>Scheduled Tasks (Cron)</b></summary>
```bash
# Add a job
nanobot cron add --name "daily" --message "Good morning!" --cron "0 9 * * *"
nanobot cron add --name "hourly" --message "Check status" --every 3600
# List jobs
nanobot cron list
# Remove a job
nanobot cron remove <job_id>
```
</details>
<details> <details>
<summary><b>Heartbeat (Periodic Tasks)</b></summary> <summary><b>Heartbeat (Periodic Tasks)</b></summary>

View File

@@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json
``` ```
**Security Notes:** **Security Notes:**
- Empty `allowFrom` list will **ALLOW ALL** users (open by default for personal use) - In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone.
- Get your Telegram user ID from `@userinfobot` - Get your Telegram user ID from `@userinfobot`
- Use full phone numbers with country code for WhatsApp - Use full phone numbers with country code for WhatsApp
- Review access logs regularly for unauthorized access attempts - Review access logs regularly for unauthorized access attempts
@@ -212,9 +212,8 @@ If you suspect a security breach:
- Input length limits on HTTP requests - Input length limits on HTTP requests
✅ **Authentication** ✅ **Authentication**
- Allow-list based access control - Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all)
- Failed authentication attempt logging - Failed authentication attempt logging
- Open by default (configure allowFrom for production use)
✅ **Resource Protection** ✅ **Resource Protection**
- Command execution timeouts (60s default) - Command execution timeouts (60s default)

View File

@@ -2,5 +2,5 @@
nanobot - A lightweight AI agent framework nanobot - A lightweight AI agent framework
""" """
__version__ = "0.1.4.post2" __version__ = "0.1.4.post3"
__logo__ = "🐈" __logo__ = "🐈"

View File

@@ -1,7 +1,7 @@
"""Agent core module.""" """Agent core module."""
from nanobot.agent.loop import AgentLoop
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader

View File

@@ -14,15 +14,15 @@ from nanobot.agent.skills import SkillsLoader
class ContextBuilder: class ContextBuilder:
"""Builds the context (system prompt + messages) for the agent.""" """Builds the context (system prompt + messages) for the agent."""
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.workspace = workspace self.workspace = workspace
self.memory = MemoryStore(workspace) self.memory = MemoryStore(workspace)
self.skills = SkillsLoader(workspace) self.skills = SkillsLoader(workspace)
def build_system_prompt(self, skill_names: list[str] | None = None) -> str: def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills.""" """Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity()] parts = [self._get_identity()]
@@ -51,13 +51,13 @@ Skills with available="false" need dependencies installed first - you can try in
{skills_summary}""") {skills_summary}""")
return "\n\n---\n\n".join(parts) return "\n\n---\n\n".join(parts)
def _get_identity(self) -> str: def _get_identity(self) -> str:
"""Get the core identity section.""" """Get the core identity section."""
workspace_path = str(self.workspace.expanduser().resolve()) workspace_path = str(self.workspace.expanduser().resolve())
system = platform.system() system = platform.system()
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
return f"""# nanobot 🐈 return f"""# nanobot 🐈
You are nanobot, a helpful AI assistant. You are nanobot, a helpful AI assistant.
@@ -68,7 +68,7 @@ You are nanobot, a helpful AI assistant.
## Workspace ## Workspace
Your workspace is at: {workspace_path} Your workspace is at: {workspace_path}
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) - Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable) - History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
## nanobot Guidelines ## nanobot Guidelines
@@ -89,19 +89,19 @@ Reply directly with text for conversations. Only use the 'message' tool to send
if channel and chat_id: if channel and chat_id:
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
def _load_bootstrap_files(self) -> str: def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace.""" """Load all bootstrap files from workspace."""
parts = [] parts = []
for filename in self.BOOTSTRAP_FILES: for filename in self.BOOTSTRAP_FILES:
file_path = self.workspace / filename file_path = self.workspace / filename
if file_path.exists(): if file_path.exists():
content = file_path.read_text(encoding="utf-8") content = file_path.read_text(encoding="utf-8")
parts.append(f"## {filename}\n\n{content}") parts.append(f"## {filename}\n\n{content}")
return "\n\n".join(parts) if parts else "" return "\n\n".join(parts) if parts else ""
def build_messages( def build_messages(
self, self,
history: list[dict[str, Any]], history: list[dict[str, Any]],
@@ -112,18 +112,27 @@ Reply directly with text for conversations. Only use the 'message' tool to send
chat_id: str | None = None, chat_id: str | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Build the complete message list for an LLM call.""" """Build the complete message list for an LLM call."""
runtime_ctx = self._build_runtime_context(channel, chat_id)
user_content = self._build_user_content(current_message, media)
# Merge runtime context and user content into a single user message
# to avoid consecutive same-role messages that some providers reject.
if isinstance(user_content, str):
merged = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
return [ return [
{"role": "system", "content": self.build_system_prompt(skill_names)}, {"role": "system", "content": self.build_system_prompt(skill_names)},
*history, *history,
{"role": "user", "content": self._build_runtime_context(channel, chat_id)}, {"role": "user", "content": merged},
{"role": "user", "content": self._build_user_content(current_message, media)},
] ]
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images.""" """Build user message content with optional base64-encoded images."""
if not media: if not media:
return text return text
images = [] images = []
for path in media: for path in media:
p = Path(path) p = Path(path)
@@ -132,11 +141,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send
continue continue
b64 = base64.b64encode(p.read_bytes()).decode() b64 = base64.b64encode(p.read_bytes()).decode()
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}) images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
if not images: if not images:
return text return text
return images + [{"type": "text", "text": text}] return images + [{"type": "text", "text": text}]
def add_tool_result( def add_tool_result(
self, messages: list[dict[str, Any]], self, messages: list[dict[str, Any]],
tool_call_id: str, tool_name: str, result: str, tool_call_id: str, tool_name: str, result: str,
@@ -144,12 +153,13 @@ Reply directly with text for conversations. Only use the 'message' tool to send
"""Add a tool result to the message list.""" """Add a tool result to the message list."""
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result}) messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
return messages return messages
def add_assistant_message( def add_assistant_message(
self, messages: list[dict[str, Any]], self, messages: list[dict[str, Any]],
content: str | None, content: str | None,
tool_calls: list[dict[str, Any]] | None = None, tool_calls: list[dict[str, Any]] | None = None,
reasoning_content: str | None = None, reasoning_content: str | None = None,
thinking_blocks: list[dict] | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Add an assistant message to the message list.""" """Add an assistant message to the message list."""
msg: dict[str, Any] = {"role": "assistant", "content": content} msg: dict[str, Any] = {"role": "assistant", "content": content}
@@ -157,5 +167,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
msg["tool_calls"] = tool_calls msg["tool_calls"] = tool_calls
if reasoning_content is not None: if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content msg["reasoning_content"] = reasoning_content
if thinking_blocks:
msg["thinking_blocks"] = thinking_blocks
messages.append(msg) messages.append(msg)
return messages return messages

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import re import re
import weakref
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -43,6 +44,8 @@ class AgentLoop:
5. Sends responses back 5. Sends responses back
""" """
_TOOL_RESULT_MAX_CHARS = 500
def __init__( def __init__(
self, self,
bus: MessageBus, bus: MessageBus,
@@ -53,7 +56,9 @@ class AgentLoop:
temperature: float = 0.1, temperature: float = 0.1,
max_tokens: int = 4096, max_tokens: int = 4096,
memory_window: int = 100, memory_window: int = 100,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None, exec_config: ExecToolConfig | None = None,
cron_service: CronService | None = None, cron_service: CronService | None = None,
restrict_to_workspace: bool = False, restrict_to_workspace: bool = False,
@@ -71,7 +76,9 @@ class AgentLoop:
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.memory_window = memory_window self.memory_window = memory_window
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
self.cron_service = cron_service self.cron_service = cron_service
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
@@ -86,7 +93,9 @@ class AgentLoop:
model=self.model, model=self.model,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key, brave_api_key=brave_api_key,
web_proxy=web_proxy,
exec_config=self.exec_config, exec_config=self.exec_config,
restrict_to_workspace=restrict_to_workspace, restrict_to_workspace=restrict_to_workspace,
) )
@@ -98,7 +107,7 @@ class AgentLoop:
self._mcp_connecting = False self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {} self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._processing_lock = asyncio.Lock() self._processing_lock = asyncio.Lock()
self._register_default_tools() self._register_default_tools()
@@ -114,8 +123,8 @@ class AgentLoop:
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append, path_append=self.exec_config.path_append,
)) ))
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
self.tools.register(WebFetchTool()) self.tools.register(WebFetchTool(proxy=self.web_proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound)) self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(SpawnTool(manager=self.subagents)) self.tools.register(SpawnTool(manager=self.subagents))
if self.cron_service: if self.cron_service:
@@ -145,17 +154,10 @@ class AgentLoop:
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Update context for all tools that need routing info.""" """Update context for all tools that need routing info."""
if message_tool := self.tools.get("message"): for name in ("message", "spawn", "cron"):
if isinstance(message_tool, MessageTool): if tool := self.tools.get(name):
message_tool.set_context(channel, chat_id, message_id) if hasattr(tool, "set_context"):
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
if spawn_tool := self.tools.get("spawn"):
if isinstance(spawn_tool, SpawnTool):
spawn_tool.set_context(channel, chat_id)
if cron_tool := self.tools.get("cron"):
if isinstance(cron_tool, CronTool):
cron_tool.set_context(channel, chat_id)
@staticmethod @staticmethod
def _strip_think(text: str | None) -> str | None: def _strip_think(text: str | None) -> str | None:
@@ -168,7 +170,8 @@ class AgentLoop:
def _tool_hint(tool_calls: list) -> str: def _tool_hint(tool_calls: list) -> str:
"""Format tool calls as concise hint, e.g. 'web_search("query")'.""" """Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc): def _fmt(tc):
val = next(iter(tc.arguments.values()), None) if tc.arguments else None args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str): if not isinstance(val, str):
return tc.name return tc.name
return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")' return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")'
@@ -194,6 +197,7 @@ class AgentLoop:
model=self.model, model=self.model,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
@@ -217,6 +221,7 @@ class AgentLoop:
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, tool_call_dicts, messages, response.content, tool_call_dicts,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
) )
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
@@ -229,8 +234,15 @@ class AgentLoop:
) )
else: else:
clean = self._strip_think(response.content) clean = self._strip_think(response.content)
# Don't persist error responses to session history — they can
# poison the context and cause permanent 400 loops (#1303).
if response.finish_reason == "error":
logger.error("LLM returned error: {}", (clean or "")[:200])
final_content = clean or "Sorry, I encountered an error calling the AI model."
break
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, clean, reasoning_content=response.reasoning_content, messages, clean, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
) )
final_content = clean final_content = clean
break break
@@ -315,18 +327,6 @@ class AgentLoop:
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
lock = self._consolidation_locks.get(session_key)
if lock is None:
lock = asyncio.Lock()
self._consolidation_locks[session_key] = lock
return lock
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
"""Drop lock entry if no longer in use."""
if not lock.locked():
self._consolidation_locks.pop(session_key, None)
async def _process_message( async def _process_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
@@ -362,7 +362,7 @@ class AgentLoop:
# Slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
lock = self._get_consolidation_lock(session.key) lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
self._consolidating.add(session.key) self._consolidating.add(session.key)
try: try:
async with lock: async with lock:
@@ -383,7 +383,6 @@ class AgentLoop:
) )
finally: finally:
self._consolidating.discard(session.key) self._consolidating.discard(session.key)
self._prune_consolidation_lock(session.key, lock)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
@@ -397,7 +396,7 @@ class AgentLoop:
unconsolidated = len(session.messages) - session.last_consolidated unconsolidated = len(session.messages) - session.last_consolidated
if (unconsolidated >= self.memory_window and session.key not in self._consolidating): if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
self._consolidating.add(session.key) self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key) lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
async def _consolidate_and_unlock(): async def _consolidate_and_unlock():
try: try:
@@ -405,7 +404,6 @@ class AgentLoop:
await self._consolidate_memory(session) await self._consolidate_memory(session)
finally: finally:
self._consolidating.discard(session.key) self._consolidating.discard(session.key)
self._prune_consolidation_lock(session.key, lock)
_task = asyncio.current_task() _task = asyncio.current_task()
if _task is not None: if _task is not None:
self._consolidation_tasks.discard(_task) self._consolidation_tasks.discard(_task)
@@ -441,40 +439,50 @@ class AgentLoop:
if final_content is None: if final_content is None:
final_content = "I've completed processing but have no response to give." final_content = "I've completed processing but have no response to give."
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
if message_tool := self.tools.get("message"): if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: return None
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content, channel=msg.channel, chat_id=msg.chat_id, content=final_content,
metadata=msg.metadata or {}, metadata=msg.metadata or {},
) )
_TOOL_RESULT_MAX_CHARS = 500
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results.""" """Save new-turn messages into session, truncating large tool results."""
from datetime import datetime from datetime import datetime
for m in messages[skip:]: for m in messages[skip:]:
entry = {k: v for k, v in m.items() if k != "reasoning_content"} entry = dict(m)
if entry.get("role") == "tool" and isinstance(entry.get("content"), str): role, content = entry.get("role"), entry.get("content")
content = entry["content"] if role == "assistant" and not content and not entry.get("tool_calls"):
if len(content) > self._TOOL_RESULT_MAX_CHARS: continue # skip empty assistant messages — they poison session context
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
if entry.get("role") == "user" and isinstance(entry.get("content"), list): entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
entry["content"] = [ elif role == "user":
{"type": "text", "text": "[image]"} if ( if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
c.get("type") == "image_url" # Strip the runtime-context prefix, keep only the user text.
and c.get("image_url", {}).get("url", "").startswith("data:image/") parts = content.split("\n\n", 1)
) else c if len(parts) > 1 and parts[1].strip():
for c in entry["content"] entry["content"] = parts[1]
] else:
continue
if isinstance(content, list):
filtered = []
for c in content:
if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
continue # Strip runtime context from multimodal messages
if (c.get("type") == "image_url"
and c.get("image_url", {}).get("url", "").startswith("data:image/")):
filtered.append({"type": "text", "text": "[image]"})
else:
filtered.append(c)
if not filtered:
continue
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat()) entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry) session.messages.append(entry)
session.updated_at = datetime.now() session.updated_at = datetime.now()

View File

@@ -13,28 +13,28 @@ BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
class SkillsLoader: class SkillsLoader:
""" """
Loader for agent skills. Loader for agent skills.
Skills are markdown files (SKILL.md) that teach the agent how to use Skills are markdown files (SKILL.md) that teach the agent how to use
specific tools or perform certain tasks. specific tools or perform certain tasks.
""" """
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None): def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
self.workspace = workspace self.workspace = workspace
self.workspace_skills = workspace / "skills" self.workspace_skills = workspace / "skills"
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]: def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
""" """
List all available skills. List all available skills.
Args: Args:
filter_unavailable: If True, filter out skills with unmet requirements. filter_unavailable: If True, filter out skills with unmet requirements.
Returns: Returns:
List of skill info dicts with 'name', 'path', 'source'. List of skill info dicts with 'name', 'path', 'source'.
""" """
skills = [] skills = []
# Workspace skills (highest priority) # Workspace skills (highest priority)
if self.workspace_skills.exists(): if self.workspace_skills.exists():
for skill_dir in self.workspace_skills.iterdir(): for skill_dir in self.workspace_skills.iterdir():
@@ -42,7 +42,7 @@ class SkillsLoader:
skill_file = skill_dir / "SKILL.md" skill_file = skill_dir / "SKILL.md"
if skill_file.exists(): if skill_file.exists():
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"}) skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
# Built-in skills # Built-in skills
if self.builtin_skills and self.builtin_skills.exists(): if self.builtin_skills and self.builtin_skills.exists():
for skill_dir in self.builtin_skills.iterdir(): for skill_dir in self.builtin_skills.iterdir():
@@ -50,19 +50,19 @@ class SkillsLoader:
skill_file = skill_dir / "SKILL.md" skill_file = skill_dir / "SKILL.md"
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
# Filter by requirements # Filter by requirements
if filter_unavailable: if filter_unavailable:
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))] return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
return skills return skills
def load_skill(self, name: str) -> str | None: def load_skill(self, name: str) -> str | None:
""" """
Load a skill by name. Load a skill by name.
Args: Args:
name: Skill name (directory name). name: Skill name (directory name).
Returns: Returns:
Skill content or None if not found. Skill content or None if not found.
""" """
@@ -70,22 +70,22 @@ class SkillsLoader:
workspace_skill = self.workspace_skills / name / "SKILL.md" workspace_skill = self.workspace_skills / name / "SKILL.md"
if workspace_skill.exists(): if workspace_skill.exists():
return workspace_skill.read_text(encoding="utf-8") return workspace_skill.read_text(encoding="utf-8")
# Check built-in # Check built-in
if self.builtin_skills: if self.builtin_skills:
builtin_skill = self.builtin_skills / name / "SKILL.md" builtin_skill = self.builtin_skills / name / "SKILL.md"
if builtin_skill.exists(): if builtin_skill.exists():
return builtin_skill.read_text(encoding="utf-8") return builtin_skill.read_text(encoding="utf-8")
return None return None
def load_skills_for_context(self, skill_names: list[str]) -> str: def load_skills_for_context(self, skill_names: list[str]) -> str:
""" """
Load specific skills for inclusion in agent context. Load specific skills for inclusion in agent context.
Args: Args:
skill_names: List of skill names to load. skill_names: List of skill names to load.
Returns: Returns:
Formatted skills content. Formatted skills content.
""" """
@@ -95,26 +95,26 @@ class SkillsLoader:
if content: if content:
content = self._strip_frontmatter(content) content = self._strip_frontmatter(content)
parts.append(f"### Skill: {name}\n\n{content}") parts.append(f"### Skill: {name}\n\n{content}")
return "\n\n---\n\n".join(parts) if parts else "" return "\n\n---\n\n".join(parts) if parts else ""
def build_skills_summary(self) -> str: def build_skills_summary(self) -> str:
""" """
Build a summary of all skills (name, description, path, availability). Build a summary of all skills (name, description, path, availability).
This is used for progressive loading - the agent can read the full This is used for progressive loading - the agent can read the full
skill content using read_file when needed. skill content using read_file when needed.
Returns: Returns:
XML-formatted skills summary. XML-formatted skills summary.
""" """
all_skills = self.list_skills(filter_unavailable=False) all_skills = self.list_skills(filter_unavailable=False)
if not all_skills: if not all_skills:
return "" return ""
def escape_xml(s: str) -> str: def escape_xml(s: str) -> str:
return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
lines = ["<skills>"] lines = ["<skills>"]
for s in all_skills: for s in all_skills:
name = escape_xml(s["name"]) name = escape_xml(s["name"])
@@ -122,23 +122,23 @@ class SkillsLoader:
desc = escape_xml(self._get_skill_description(s["name"])) desc = escape_xml(self._get_skill_description(s["name"]))
skill_meta = self._get_skill_meta(s["name"]) skill_meta = self._get_skill_meta(s["name"])
available = self._check_requirements(skill_meta) available = self._check_requirements(skill_meta)
lines.append(f" <skill available=\"{str(available).lower()}\">") lines.append(f" <skill available=\"{str(available).lower()}\">")
lines.append(f" <name>{name}</name>") lines.append(f" <name>{name}</name>")
lines.append(f" <description>{desc}</description>") lines.append(f" <description>{desc}</description>")
lines.append(f" <location>{path}</location>") lines.append(f" <location>{path}</location>")
# Show missing requirements for unavailable skills # Show missing requirements for unavailable skills
if not available: if not available:
missing = self._get_missing_requirements(skill_meta) missing = self._get_missing_requirements(skill_meta)
if missing: if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>") lines.append(f" <requires>{escape_xml(missing)}</requires>")
lines.append(f" </skill>") lines.append(" </skill>")
lines.append("</skills>") lines.append("</skills>")
return "\n".join(lines) return "\n".join(lines)
def _get_missing_requirements(self, skill_meta: dict) -> str: def _get_missing_requirements(self, skill_meta: dict) -> str:
"""Get a description of missing requirements.""" """Get a description of missing requirements."""
missing = [] missing = []
@@ -150,14 +150,14 @@ class SkillsLoader:
if not os.environ.get(env): if not os.environ.get(env):
missing.append(f"ENV: {env}") missing.append(f"ENV: {env}")
return ", ".join(missing) return ", ".join(missing)
def _get_skill_description(self, name: str) -> str: def _get_skill_description(self, name: str) -> str:
"""Get the description of a skill from its frontmatter.""" """Get the description of a skill from its frontmatter."""
meta = self.get_skill_metadata(name) meta = self.get_skill_metadata(name)
if meta and meta.get("description"): if meta and meta.get("description"):
return meta["description"] return meta["description"]
return name # Fallback to skill name return name # Fallback to skill name
def _strip_frontmatter(self, content: str) -> str: def _strip_frontmatter(self, content: str) -> str:
"""Remove YAML frontmatter from markdown content.""" """Remove YAML frontmatter from markdown content."""
if content.startswith("---"): if content.startswith("---"):
@@ -165,7 +165,7 @@ class SkillsLoader:
if match: if match:
return content[match.end():].strip() return content[match.end():].strip()
return content return content
def _parse_nanobot_metadata(self, raw: str) -> dict: def _parse_nanobot_metadata(self, raw: str) -> dict:
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys).""" """Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
try: try:
@@ -173,7 +173,7 @@ class SkillsLoader:
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {} return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}
def _check_requirements(self, skill_meta: dict) -> bool: def _check_requirements(self, skill_meta: dict) -> bool:
"""Check if skill requirements are met (bins, env vars).""" """Check if skill requirements are met (bins, env vars)."""
requires = skill_meta.get("requires", {}) requires = skill_meta.get("requires", {})
@@ -184,12 +184,12 @@ class SkillsLoader:
if not os.environ.get(env): if not os.environ.get(env):
return False return False
return True return True
def _get_skill_meta(self, name: str) -> dict: def _get_skill_meta(self, name: str) -> dict:
"""Get nanobot metadata for a skill (cached in frontmatter).""" """Get nanobot metadata for a skill (cached in frontmatter)."""
meta = self.get_skill_metadata(name) or {} meta = self.get_skill_metadata(name) or {}
return self._parse_nanobot_metadata(meta.get("metadata", "")) return self._parse_nanobot_metadata(meta.get("metadata", ""))
def get_always_skills(self) -> list[str]: def get_always_skills(self) -> list[str]:
"""Get skills marked as always=true that meet requirements.""" """Get skills marked as always=true that meet requirements."""
result = [] result = []
@@ -199,21 +199,21 @@ class SkillsLoader:
if skill_meta.get("always") or meta.get("always"): if skill_meta.get("always") or meta.get("always"):
result.append(s["name"]) result.append(s["name"])
return result return result
def get_skill_metadata(self, name: str) -> dict | None: def get_skill_metadata(self, name: str) -> dict | None:
""" """
Get metadata from a skill's frontmatter. Get metadata from a skill's frontmatter.
Args: Args:
name: Skill name. name: Skill name.
Returns: Returns:
Metadata dict or None. Metadata dict or None.
""" """
content = self.load_skill(name) content = self.load_skill(name)
if not content: if not content:
return None return None
if content.startswith("---"): if content.startswith("---"):
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if match: if match:
@@ -224,5 +224,5 @@ class SkillsLoader:
key, value = line.split(":", 1) key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'') metadata[key.strip()] = value.strip().strip('"\'')
return metadata return metadata
return None return None

View File

@@ -8,18 +8,19 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
class SubagentManager: class SubagentManager:
"""Manages background subagent execution.""" """Manages background subagent execution."""
def __init__( def __init__(
self, self,
provider: LLMProvider, provider: LLMProvider,
@@ -28,7 +29,9 @@ class SubagentManager:
model: str | None = None, model: str | None = None,
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 4096, max_tokens: int = 4096,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None, exec_config: "ExecToolConfig | None" = None,
restrict_to_workspace: bool = False, restrict_to_workspace: bool = False,
): ):
@@ -39,12 +42,14 @@ class SubagentManager:
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.temperature = temperature self.temperature = temperature
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self._running_tasks: dict[str, asyncio.Task[None]] = {} self._running_tasks: dict[str, asyncio.Task[None]] = {}
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...} self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
async def spawn( async def spawn(
self, self,
task: str, task: str,
@@ -73,10 +78,10 @@ class SubagentManager:
del self._session_tasks[session_key] del self._session_tasks[session_key]
bg_task.add_done_callback(_cleanup) bg_task.add_done_callback(_cleanup)
logger.info("Spawned subagent [{}]: {}", task_id, display_label) logger.info("Spawned subagent [{}]: {}", task_id, display_label)
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes." return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
async def _run_subagent( async def _run_subagent(
self, self,
task_id: str, task_id: str,
@@ -86,7 +91,7 @@ class SubagentManager:
) -> None: ) -> None:
"""Execute the subagent task and announce the result.""" """Execute the subagent task and announce the result."""
logger.info("Subagent [{}] starting task: {}", task_id, label) logger.info("Subagent [{}] starting task: {}", task_id, label)
try: try:
# Build subagent tools (no message tool, no spawn tool) # Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry() tools = ToolRegistry()
@@ -101,32 +106,32 @@ class SubagentManager:
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append, path_append=self.exec_config.path_append,
)) ))
tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
tools.register(WebFetchTool()) tools.register(WebFetchTool(proxy=self.web_proxy))
# Build messages with subagent-specific prompt system_prompt = self._build_subagent_prompt()
system_prompt = self._build_subagent_prompt(task)
messages: list[dict[str, Any]] = [ messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": task}, {"role": "user", "content": task},
] ]
# Run agent loop (limited iterations) # Run agent loop (limited iterations)
max_iterations = 15 max_iterations = 15
iteration = 0 iteration = 0
final_result: str | None = None final_result: str | None = None
while iteration < max_iterations: while iteration < max_iterations:
iteration += 1 iteration += 1
response = await self.provider.chat( response = await self.provider.chat(
messages=messages, messages=messages,
tools=tools.get_definitions(), tools=tools.get_definitions(),
model=self.model, model=self.model,
temperature=self.temperature, temperature=self.temperature,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:
# Add assistant message with tool calls # Add assistant message with tool calls
tool_call_dicts = [ tool_call_dicts = [
@@ -145,7 +150,7 @@ class SubagentManager:
"content": response.content or "", "content": response.content or "",
"tool_calls": tool_call_dicts, "tool_calls": tool_call_dicts,
}) })
# Execute tools # Execute tools
for tool_call in response.tool_calls: for tool_call in response.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False) args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
@@ -160,18 +165,18 @@ class SubagentManager:
else: else:
final_result = response.content final_result = response.content
break break
if final_result is None: if final_result is None:
final_result = "Task completed but no final response was generated." final_result = "Task completed but no final response was generated."
logger.info("Subagent [{}] completed successfully", task_id) logger.info("Subagent [{}] completed successfully", task_id)
await self._announce_result(task_id, label, task, final_result, origin, "ok") await self._announce_result(task_id, label, task, final_result, origin, "ok")
except Exception as e: except Exception as e:
error_msg = f"Error: {str(e)}" error_msg = f"Error: {str(e)}"
logger.error("Subagent [{}] failed: {}", task_id, e) logger.error("Subagent [{}] failed: {}", task_id, e)
await self._announce_result(task_id, label, task, error_msg, origin, "error") await self._announce_result(task_id, label, task, error_msg, origin, "error")
async def _announce_result( async def _announce_result(
self, self,
task_id: str, task_id: str,
@@ -183,7 +188,7 @@ class SubagentManager:
) -> None: ) -> None:
"""Announce the subagent result to the main agent via the message bus.""" """Announce the subagent result to the main agent via the message bus."""
status_text = "completed successfully" if status == "ok" else "failed" status_text = "completed successfully" if status == "ok" else "failed"
announce_content = f"""[Subagent '{label}' {status_text}] announce_content = f"""[Subagent '{label}' {status_text}]
Task: {task} Task: {task}
@@ -192,7 +197,7 @@ Result:
{result} {result}
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs.""" Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
# Inject as system message to trigger main agent # Inject as system message to trigger main agent
msg = InboundMessage( msg = InboundMessage(
channel="system", channel="system",
@@ -200,46 +205,31 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
chat_id=f"{origin['channel']}:{origin['chat_id']}", chat_id=f"{origin['channel']}:{origin['chat_id']}",
content=announce_content, content=announce_content,
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id']) logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
def _build_subagent_prompt(self, task: str) -> str: def _build_subagent_prompt(self) -> str:
"""Build a focused system prompt for the subagent.""" """Build a focused system prompt for the subagent."""
from datetime import datetime from nanobot.agent.context import ContextBuilder
import time as _time from nanobot.agent.skills import SkillsLoader
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
tz = _time.strftime("%Z") or "UTC"
return f"""# Subagent time_ctx = ContextBuilder._build_runtime_context(None, None)
parts = [f"""# Subagent
## Current Time {time_ctx}
{now} ({tz})
You are a subagent spawned by the main agent to complete a specific task. You are a subagent spawned by the main agent to complete a specific task.
Stay focused on the assigned task. Your final response will be reported back to the main agent.
## Rules
1. Stay focused - complete only the assigned task, nothing else
2. Your final response will be reported back to the main agent
3. Do not initiate conversations or take on side tasks
4. Be concise but informative in your findings
## What You Can Do
- Read and write files in the workspace
- Execute shell commands
- Search the web and fetch web pages
- Complete the task thoroughly
## What You Cannot Do
- Send messages directly to users (no message tool available)
- Spawn other subagents
- Access the main agent's conversation history
## Workspace ## Workspace
Your workspace is at: {self.workspace} {self.workspace}"""]
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
When you have completed the task, provide a clear summary of your findings or actions.""" skills_summary = SkillsLoader(self.workspace).build_skills_summary()
if skills_summary:
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
return "\n\n".join(parts)
async def cancel_by_session(self, session_key: str) -> int: async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents for the given session. Returns count cancelled.""" """Cancel all subagents for the given session. Returns count cancelled."""

View File

@@ -7,11 +7,11 @@ from typing import Any
class Tool(ABC): class Tool(ABC):
""" """
Abstract base class for agent tools. Abstract base class for agent tools.
Tools are capabilities that the agent can use to interact with Tools are capabilities that the agent can use to interact with
the environment, such as reading files, executing commands, etc. the environment, such as reading files, executing commands, etc.
""" """
_TYPE_MAP = { _TYPE_MAP = {
"string": str, "string": str,
"integer": int, "integer": int,
@@ -20,33 +20,33 @@ class Tool(ABC):
"array": list, "array": list,
"object": dict, "object": dict,
} }
@property @property
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
"""Tool name used in function calls.""" """Tool name used in function calls."""
pass pass
@property @property
@abstractmethod @abstractmethod
def description(self) -> str: def description(self) -> str:
"""Description of what the tool does.""" """Description of what the tool does."""
pass pass
@property @property
@abstractmethod @abstractmethod
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
"""JSON Schema for tool parameters.""" """JSON Schema for tool parameters."""
pass pass
@abstractmethod @abstractmethod
async def execute(self, **kwargs: Any) -> str: async def execute(self, **kwargs: Any) -> str:
""" """
Execute the tool with given parameters. Execute the tool with given parameters.
Args: Args:
**kwargs: Tool-specific parameters. **kwargs: Tool-specific parameters.
Returns: Returns:
String result of the tool execution. String result of the tool execution.
""" """
@@ -54,6 +54,8 @@ class Tool(ABC):
def validate_params(self, params: dict[str, Any]) -> list[str]: def validate_params(self, params: dict[str, Any]) -> list[str]:
"""Validate tool parameters against JSON schema. Returns error list (empty if valid).""" """Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
if not isinstance(params, dict):
return [f"parameters must be an object, got {type(params).__name__}"]
schema = self.parameters or {} schema = self.parameters or {}
if schema.get("type", "object") != "object": if schema.get("type", "object") != "object":
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}") raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
@@ -63,7 +65,7 @@ class Tool(ABC):
t, label = schema.get("type"), path or "parameter" t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]): if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"] return [f"{label} should be {t}"]
errors = [] errors = []
if "enum" in schema and val not in schema["enum"]: if "enum" in schema and val not in schema["enum"]:
errors.append(f"{label} must be one of {schema['enum']}") errors.append(f"{label} must be one of {schema['enum']}")
@@ -84,12 +86,14 @@ class Tool(ABC):
errors.append(f"missing required {path + '.' + k if path else k}") errors.append(f"missing required {path + '.' + k if path else k}")
for k, v in val.items(): for k, v in val.items():
if k in props: if k in props:
errors.extend(self._validate(v, props[k], path + '.' + k if path else k)) errors.extend(self._validate(v, props[k], path + "." + k if path else k))
if t == "array" and "items" in schema: if t == "array" and "items" in schema:
for i, item in enumerate(val): for i, item in enumerate(val):
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")) errors.extend(
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
)
return errors return errors
def to_schema(self) -> dict[str, Any]: def to_schema(self) -> dict[str, Any]:
"""Convert tool to OpenAI function schema format.""" """Convert tool to OpenAI function schema format."""
return { return {
@@ -98,5 +102,5 @@ class Tool(ABC):
"name": self.name, "name": self.name,
"description": self.description, "description": self.description,
"parameters": self.parameters, "parameters": self.parameters,
} },
} }

View File

@@ -1,5 +1,6 @@
"""Cron tool for scheduling reminders and tasks.""" """Cron tool for scheduling reminders and tasks."""
from contextvars import ContextVar
from typing import Any from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -9,25 +10,34 @@ from nanobot.cron.types import CronSchedule
class CronTool(Tool): class CronTool(Tool):
"""Tool to schedule reminders and recurring tasks.""" """Tool to schedule reminders and recurring tasks."""
def __init__(self, cron_service: CronService): def __init__(self, cron_service: CronService):
self._cron = cron_service self._cron = cron_service
self._channel = "" self._channel = ""
self._chat_id = "" self._chat_id = ""
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current session context for delivery.""" """Set the current session context for delivery."""
self._channel = channel self._channel = channel
self._chat_id = chat_id self._chat_id = chat_id
def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback."""
return self._in_cron_context.set(active)
def reset_cron_context(self, token) -> None:
"""Restore previous cron context."""
self._in_cron_context.reset(token)
@property @property
def name(self) -> str: def name(self) -> str:
return "cron" return "cron"
@property @property
def description(self) -> str: def description(self) -> str:
return "Schedule reminders and recurring tasks. Actions: add, list, remove." return "Schedule reminders and recurring tasks. Actions: add, list, remove."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@@ -36,36 +46,30 @@ class CronTool(Tool):
"action": { "action": {
"type": "string", "type": "string",
"enum": ["add", "list", "remove"], "enum": ["add", "list", "remove"],
"description": "Action to perform" "description": "Action to perform",
},
"message": {
"type": "string",
"description": "Reminder message (for add)"
}, },
"message": {"type": "string", "description": "Reminder message (for add)"},
"every_seconds": { "every_seconds": {
"type": "integer", "type": "integer",
"description": "Interval in seconds (for recurring tasks)" "description": "Interval in seconds (for recurring tasks)",
}, },
"cron_expr": { "cron_expr": {
"type": "string", "type": "string",
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)" "description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
}, },
"tz": { "tz": {
"type": "string", "type": "string",
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')" "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
}, },
"at": { "at": {
"type": "string", "type": "string",
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')" "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
}, },
"job_id": { "job_id": {"type": "string", "description": "Job ID (for remove)"},
"type": "string",
"description": "Job ID (for remove)"
}
}, },
"required": ["action"] "required": ["action"],
} }
async def execute( async def execute(
self, self,
action: str, action: str,
@@ -75,16 +79,18 @@ class CronTool(Tool):
tz: str | None = None, tz: str | None = None,
at: str | None = None, at: str | None = None,
job_id: str | None = None, job_id: str | None = None,
**kwargs: Any **kwargs: Any,
) -> str: ) -> str:
if action == "add": if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at) return self._add_job(message, every_seconds, cron_expr, tz, at)
elif action == "list": elif action == "list":
return self._list_jobs() return self._list_jobs()
elif action == "remove": elif action == "remove":
return self._remove_job(job_id) return self._remove_job(job_id)
return f"Unknown action: {action}" return f"Unknown action: {action}"
def _add_job( def _add_job(
self, self,
message: str, message: str,
@@ -101,11 +107,12 @@ class CronTool(Tool):
return "Error: tz can only be used with cron_expr" return "Error: tz can only be used with cron_expr"
if tz: if tz:
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
try: try:
ZoneInfo(tz) ZoneInfo(tz)
except (KeyError, Exception): except (KeyError, Exception):
return f"Error: unknown timezone '{tz}'" return f"Error: unknown timezone '{tz}'"
# Build schedule # Build schedule
delete_after = False delete_after = False
if every_seconds: if every_seconds:
@@ -114,13 +121,17 @@ class CronTool(Tool):
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz) schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at: elif at:
from datetime import datetime from datetime import datetime
dt = datetime.fromisoformat(at)
try:
dt = datetime.fromisoformat(at)
except ValueError:
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
at_ms = int(dt.timestamp() * 1000) at_ms = int(dt.timestamp() * 1000)
schedule = CronSchedule(kind="at", at_ms=at_ms) schedule = CronSchedule(kind="at", at_ms=at_ms)
delete_after = True delete_after = True
else: else:
return "Error: either every_seconds, cron_expr, or at is required" return "Error: either every_seconds, cron_expr, or at is required"
job = self._cron.add_job( job = self._cron.add_job(
name=message[:30], name=message[:30],
schedule=schedule, schedule=schedule,
@@ -131,14 +142,14 @@ class CronTool(Tool):
delete_after_run=delete_after, delete_after_run=delete_after,
) )
return f"Created job '{job.name}' (id: {job.id})" return f"Created job '{job.name}' (id: {job.id})"
def _list_jobs(self) -> str: def _list_jobs(self) -> str:
jobs = self._cron.list_jobs() jobs = self._cron.list_jobs()
if not jobs: if not jobs:
return "No scheduled jobs." return "No scheduled jobs."
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs] lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
return "Scheduled jobs:\n" + "\n".join(lines) return "Scheduled jobs:\n" + "\n".join(lines)
def _remove_job(self, job_id: str | None) -> str: def _remove_job(self, job_id: str | None) -> str:
if not job_id: if not job_id:
return "Error: job_id is required for remove" return "Error: job_id is required for remove"

View File

@@ -7,7 +7,9 @@ from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path: def _resolve_path(
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
) -> Path:
"""Resolve path against workspace (if relative) and enforce directory restriction.""" """Resolve path against workspace (if relative) and enforce directory restriction."""
p = Path(path).expanduser() p = Path(path).expanduser()
if not p.is_absolute() and workspace: if not p.is_absolute() and workspace:
@@ -24,6 +26,8 @@ def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path |
class ReadFileTool(Tool): class ReadFileTool(Tool):
"""Tool to read file contents.""" """Tool to read file contents."""
_MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
@@ -31,24 +35,19 @@ class ReadFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "read_file" return "read_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Read the contents of a file at the given path." return "Read the contents of a file at the given path."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The file path to read"}},
"path": { "required": ["path"],
"type": "string",
"description": "The file path to read"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -57,7 +56,16 @@ class ReadFileTool(Tool):
if not file_path.is_file(): if not file_path.is_file():
return f"Error: Not a file: {path}" return f"Error: Not a file: {path}"
size = file_path.stat().st_size
if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes)
return (
f"Error: File too large ({size:,} bytes). "
f"Use exec tool with head/tail/grep to read portions."
)
content = file_path.read_text(encoding="utf-8") content = file_path.read_text(encoding="utf-8")
if len(content) > self._MAX_CHARS:
return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
return content return content
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
@@ -75,28 +83,22 @@ class WriteFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "write_file" return "write_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Write content to a file at the given path. Creates parent directories if needed." return "Write content to a file at the given path. Creates parent directories if needed."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to write to"},
"type": "string", "content": {"type": "string", "description": "The content to write"},
"description": "The file path to write to"
},
"content": {
"type": "string",
"description": "The content to write"
}
}, },
"required": ["path", "content"] "required": ["path", "content"],
} }
async def execute(self, path: str, content: str, **kwargs: Any) -> str: async def execute(self, path: str, content: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -119,32 +121,23 @@ class EditFileTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "edit_file" return "edit_file"
@property @property
def description(self) -> str: def description(self) -> str:
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"path": { "path": {"type": "string", "description": "The file path to edit"},
"type": "string", "old_text": {"type": "string", "description": "The exact text to find and replace"},
"description": "The file path to edit" "new_text": {"type": "string", "description": "The text to replace with"},
},
"old_text": {
"type": "string",
"description": "The exact text to find and replace"
},
"new_text": {
"type": "string",
"description": "The text to replace with"
}
}, },
"required": ["path", "old_text", "new_text"] "required": ["path", "old_text", "new_text"],
} }
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
try: try:
file_path = _resolve_path(path, self._workspace, self._allowed_dir) file_path = _resolve_path(path, self._workspace, self._allowed_dir)
@@ -184,13 +177,19 @@ class EditFileTool(Tool):
best_ratio, best_start = ratio, i best_ratio, best_start = ratio, i
if best_ratio > 0.5: if best_ratio > 0.5:
diff = "\n".join(difflib.unified_diff( diff = "\n".join(
old_lines, lines[best_start : best_start + window], difflib.unified_diff(
fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})", old_lines,
lineterm="", lines[best_start : best_start + window],
)) fromfile="old_text (provided)",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
)
)
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
return f"Error: old_text not found in {path}. No similar text found. Verify the file content." return (
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
)
class ListDirTool(Tool): class ListDirTool(Tool):
@@ -203,24 +202,19 @@ class ListDirTool(Tool):
@property @property
def name(self) -> str: def name(self) -> str:
return "list_dir" return "list_dir"
@property @property
def description(self) -> str: def description(self) -> str:
return "List the contents of a directory." return "List the contents of a directory."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "The directory path to list"}},
"path": { "required": ["path"],
"type": "string",
"description": "The directory path to list"
}
},
"required": ["path"]
} }
async def execute(self, path: str, **kwargs: Any) -> str: async def execute(self, path: str, **kwargs: Any) -> str:
try: try:
dir_path = _resolve_path(path, self._workspace, self._allowed_dir) dir_path = _resolve_path(path, self._workspace, self._allowed_dir)

View File

@@ -101,7 +101,8 @@ class MessageTool(Tool):
try: try:
await self._send_callback(msg) await self._send_callback(msg)
self._sent_in_turn = True if channel == self._default_channel and chat_id == self._default_chat_id:
self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else "" media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}" return f"Message sent to {channel}:{chat_id}{media_info}"
except Exception as e: except Exception as e:

View File

@@ -8,33 +8,33 @@ from nanobot.agent.tools.base import Tool
class ToolRegistry: class ToolRegistry:
""" """
Registry for agent tools. Registry for agent tools.
Allows dynamic registration and execution of tools. Allows dynamic registration and execution of tools.
""" """
def __init__(self): def __init__(self):
self._tools: dict[str, Tool] = {} self._tools: dict[str, Tool] = {}
def register(self, tool: Tool) -> None: def register(self, tool: Tool) -> None:
"""Register a tool.""" """Register a tool."""
self._tools[tool.name] = tool self._tools[tool.name] = tool
def unregister(self, name: str) -> None: def unregister(self, name: str) -> None:
"""Unregister a tool by name.""" """Unregister a tool by name."""
self._tools.pop(name, None) self._tools.pop(name, None)
def get(self, name: str) -> Tool | None: def get(self, name: str) -> Tool | None:
"""Get a tool by name.""" """Get a tool by name."""
return self._tools.get(name) return self._tools.get(name)
def has(self, name: str) -> bool: def has(self, name: str) -> bool:
"""Check if a tool is registered.""" """Check if a tool is registered."""
return name in self._tools return name in self._tools
def get_definitions(self) -> list[dict[str, Any]]: def get_definitions(self) -> list[dict[str, Any]]:
"""Get all tool definitions in OpenAI format.""" """Get all tool definitions in OpenAI format."""
return [tool.to_schema() for tool in self._tools.values()] return [tool.to_schema() for tool in self._tools.values()]
async def execute(self, name: str, params: dict[str, Any]) -> str: async def execute(self, name: str, params: dict[str, Any]) -> str:
"""Execute a tool by name with given parameters.""" """Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]" _HINT = "\n\n[Analyze the error above and try a different approach.]"
@@ -53,14 +53,14 @@ class ToolRegistry:
return result return result
except Exception as e: except Exception as e:
return f"Error executing {name}: {str(e)}" + _HINT return f"Error executing {name}: {str(e)}" + _HINT
@property @property
def tool_names(self) -> list[str]: def tool_names(self) -> list[str]:
"""Get list of registered tool names.""" """Get list of registered tool names."""
return list(self._tools.keys()) return list(self._tools.keys())
def __len__(self) -> int: def __len__(self) -> int:
return len(self._tools) return len(self._tools)
def __contains__(self, name: str) -> bool: def __contains__(self, name: str) -> bool:
return name in self._tools return name in self._tools

View File

@@ -11,7 +11,7 @@ from nanobot.agent.tools.base import Tool
class ExecTool(Tool): class ExecTool(Tool):
"""Tool to execute shell commands.""" """Tool to execute shell commands."""
def __init__( def __init__(
self, self,
timeout: int = 60, timeout: int = 60,
@@ -37,15 +37,15 @@ class ExecTool(Tool):
self.allow_patterns = allow_patterns or [] self.allow_patterns = allow_patterns or []
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append self.path_append = path_append
@property @property
def name(self) -> str: def name(self) -> str:
return "exec" return "exec"
@property @property
def description(self) -> str: def description(self) -> str:
return "Execute a shell command and return its output. Use with caution." return "Execute a shell command and return its output. Use with caution."
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@@ -141,13 +141,7 @@ class ExecTool(Tool):
cwd_path = Path(cwd).resolve() cwd_path = Path(cwd).resolve()
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) for raw in self._extract_absolute_paths(cmd):
# Only match absolute paths — avoid false positives on relative
# paths like ".venv/bin/python" where "/bin/python" would be
# incorrectly extracted by the old pattern.
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
for raw in win_paths + posix_paths:
try: try:
p = Path(raw.strip()).resolve() p = Path(raw.strip()).resolve()
except Exception: except Exception:
@@ -156,3 +150,9 @@ class ExecTool(Tool):
return "Error: Command blocked by safety guard (path outside working dir)" return "Error: Command blocked by safety guard (path outside working dir)"
return None return None
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
return win_paths + posix_paths

View File

@@ -1,6 +1,6 @@
"""Spawn tool for creating background subagents.""" """Spawn tool for creating background subagents."""
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -10,23 +10,23 @@ if TYPE_CHECKING:
class SpawnTool(Tool): class SpawnTool(Tool):
"""Tool to spawn a subagent for background task execution.""" """Tool to spawn a subagent for background task execution."""
def __init__(self, manager: "SubagentManager"): def __init__(self, manager: "SubagentManager"):
self._manager = manager self._manager = manager
self._origin_channel = "cli" self._origin_channel = "cli"
self._origin_chat_id = "direct" self._origin_chat_id = "direct"
self._session_key = "cli:direct" self._session_key = "cli:direct"
def set_context(self, channel: str, chat_id: str) -> None: def set_context(self, channel: str, chat_id: str) -> None:
"""Set the origin context for subagent announcements.""" """Set the origin context for subagent announcements."""
self._origin_channel = channel self._origin_channel = channel
self._origin_chat_id = chat_id self._origin_chat_id = chat_id
self._session_key = f"{channel}:{chat_id}" self._session_key = f"{channel}:{chat_id}"
@property @property
def name(self) -> str: def name(self) -> str:
return "spawn" return "spawn"
@property @property
def description(self) -> str: def description(self) -> str:
return ( return (
@@ -34,7 +34,7 @@ class SpawnTool(Tool):
"Use this for complex or time-consuming tasks that can run independently. " "Use this for complex or time-consuming tasks that can run independently. "
"The subagent will complete the task and report back when done." "The subagent will complete the task and report back when done."
) )
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@@ -51,7 +51,7 @@ class SpawnTool(Tool):
}, },
"required": ["task"], "required": ["task"],
} }
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str: async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
"""Spawn a subagent to execute the given task.""" """Spawn a subagent to execute the given task."""
return await self._manager.spawn( return await self._manager.spawn(

View File

@@ -8,6 +8,7 @@ from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from loguru import logger
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
@@ -45,7 +46,7 @@ def _validate_url(url: str) -> tuple[bool, str]:
class WebSearchTool(Tool): class WebSearchTool(Tool):
"""Search the web using Brave Search API.""" """Search the web using Brave Search API."""
name = "web_search" name = "web_search"
description = "Search the web. Returns titles, URLs, and snippets." description = "Search the web. Returns titles, URLs, and snippets."
parameters = { parameters = {
@@ -56,10 +57,11 @@ class WebSearchTool(Tool):
}, },
"required": ["query"] "required": ["query"]
} }
def __init__(self, api_key: str | None = None, max_results: int = 5): def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None):
self._init_api_key = api_key self._init_api_key = api_key
self.max_results = max_results self.max_results = max_results
self.proxy = proxy
@property @property
def api_key(self) -> str: def api_key(self) -> str:
@@ -69,39 +71,44 @@ class WebSearchTool(Tool):
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
if not self.api_key: if not self.api_key:
return ( return (
"Error: Brave Search API key not configured. " "Error: Brave Search API key not configured. Set it in "
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey " "~/.nanobot/config.json under tools.web.search.apiKey "
"(or export BRAVE_API_KEY), then restart the gateway." "(or export BRAVE_API_KEY), then restart the gateway."
) )
try: try:
n = min(max(count or self.max_results, 1), 10) n = min(max(count or self.max_results, 1), 10)
async with httpx.AsyncClient() as client: logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection")
async with httpx.AsyncClient(proxy=self.proxy) as client:
r = await client.get( r = await client.get(
"https://api.search.brave.com/res/v1/web/search", "https://api.search.brave.com/res/v1/web/search",
params={"q": query, "count": n}, params={"q": query, "count": n},
headers={"Accept": "application/json", "X-Subscription-Token": api_key}, headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
timeout=10.0 timeout=10.0
) )
r.raise_for_status() r.raise_for_status()
results = r.json().get("web", {}).get("results", []) results = r.json().get("web", {}).get("results", [])[:n]
if not results: if not results:
return f"No results for: {query}" return f"No results for: {query}"
lines = [f"Results for: {query}\n"] lines = [f"Results for: {query}\n"]
for i, item in enumerate(results[:n], 1): for i, item in enumerate(results, 1):
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}") lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
if desc := item.get("description"): if desc := item.get("description"):
lines.append(f" {desc}") lines.append(f" {desc}")
return "\n".join(lines) return "\n".join(lines)
except httpx.ProxyError as e:
logger.error("WebSearch proxy error: {}", e)
return f"Proxy error: {e}"
except Exception as e: except Exception as e:
logger.error("WebSearch error: {}", e)
return f"Error: {e}" return f"Error: {e}"
class WebFetchTool(Tool): class WebFetchTool(Tool):
"""Fetch and extract content from a URL using Readability.""" """Fetch and extract content from a URL using Readability."""
name = "web_fetch" name = "web_fetch"
description = "Fetch URL and extract readable content (HTML → markdown/text)." description = "Fetch URL and extract readable content (HTML → markdown/text)."
parameters = { parameters = {
@@ -113,35 +120,34 @@ class WebFetchTool(Tool):
}, },
"required": ["url"] "required": ["url"]
} }
def __init__(self, max_chars: int = 50000): def __init__(self, max_chars: int = 50000, proxy: str | None = None):
self.max_chars = max_chars self.max_chars = max_chars
self.proxy = proxy
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str: async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
from readability import Document from readability import Document
max_chars = maxChars or self.max_chars max_chars = maxChars or self.max_chars
# Validate URL before fetching
is_valid, error_msg = _validate_url(url) is_valid, error_msg = _validate_url(url)
if not is_valid: if not is_valid:
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False) return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
try: try:
logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection")
async with httpx.AsyncClient( async with httpx.AsyncClient(
follow_redirects=True, follow_redirects=True,
max_redirects=MAX_REDIRECTS, max_redirects=MAX_REDIRECTS,
timeout=30.0 timeout=30.0,
proxy=self.proxy,
) as client: ) as client:
r = await client.get(url, headers={"User-Agent": USER_AGENT}) r = await client.get(url, headers={"User-Agent": USER_AGENT})
r.raise_for_status() r.raise_for_status()
ctype = r.headers.get("content-type", "") ctype = r.headers.get("content-type", "")
# JSON
if "application/json" in ctype: if "application/json" in ctype:
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json" text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
# HTML
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")): elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
doc = Document(r.text) doc = Document(r.text)
content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary()) content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
@@ -149,16 +155,19 @@ class WebFetchTool(Tool):
extractor = "readability" extractor = "readability"
else: else:
text, extractor = r.text, "raw" text, extractor = r.text, "raw"
truncated = len(text) > max_chars truncated = len(text) > max_chars
if truncated: if truncated: text = text[:max_chars]
text = text[:max_chars]
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code, return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False) "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
except httpx.ProxyError as e:
logger.error("WebFetch proxy error for {}: {}", url, e)
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
except Exception as e: except Exception as e:
logger.error("WebFetch error for {}: {}", url, e)
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False) return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
def _to_markdown(self, html: str) -> str: def _to_markdown(self, html: str) -> str:
"""Convert HTML to markdown.""" """Convert HTML to markdown."""
# Convert links, headings, lists before stripping tags # Convert links, headings, lists before stripping tags

View File

@@ -8,7 +8,7 @@ from typing import Any
@dataclass @dataclass
class InboundMessage: class InboundMessage:
"""Message received from a chat channel.""" """Message received from a chat channel."""
channel: str # telegram, discord, slack, whatsapp channel: str # telegram, discord, slack, whatsapp
sender_id: str # User identifier sender_id: str # User identifier
chat_id: str # Chat/channel identifier chat_id: str # Chat/channel identifier
@@ -17,7 +17,7 @@ class InboundMessage:
media: list[str] = field(default_factory=list) # Media URLs media: list[str] = field(default_factory=list) # Media URLs
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
session_key_override: str | None = None # Optional override for thread-scoped sessions session_key_override: str | None = None # Optional override for thread-scoped sessions
@property @property
def session_key(self) -> str: def session_key(self) -> str:
"""Unique key for session identification.""" """Unique key for session identification."""
@@ -27,7 +27,7 @@ class InboundMessage:
@dataclass @dataclass
class OutboundMessage: class OutboundMessage:
"""Message to send to a chat channel.""" """Message to send to a chat channel."""
channel: str channel: str
chat_id: str chat_id: str
content: str content: str

View File

@@ -12,17 +12,17 @@ from nanobot.bus.queue import MessageBus
class BaseChannel(ABC): class BaseChannel(ABC):
""" """
Abstract base class for chat channel implementations. Abstract base class for chat channel implementations.
Each channel (Telegram, Discord, etc.) should implement this interface Each channel (Telegram, Discord, etc.) should implement this interface
to integrate with the nanobot message bus. to integrate with the nanobot message bus.
""" """
name: str = "base" name: str = "base"
def __init__(self, config: Any, bus: MessageBus): def __init__(self, config: Any, bus: MessageBus):
""" """
Initialize the channel. Initialize the channel.
Args: Args:
config: Channel-specific configuration. config: Channel-specific configuration.
bus: The message bus for communication. bus: The message bus for communication.
@@ -30,59 +30,47 @@ class BaseChannel(ABC):
self.config = config self.config = config
self.bus = bus self.bus = bus
self._running = False self._running = False
@abstractmethod @abstractmethod
async def start(self) -> None: async def start(self) -> None:
""" """
Start the channel and begin listening for messages. Start the channel and begin listening for messages.
This should be a long-running async task that: This should be a long-running async task that:
1. Connects to the chat platform 1. Connects to the chat platform
2. Listens for incoming messages 2. Listens for incoming messages
3. Forwards messages to the bus via _handle_message() 3. Forwards messages to the bus via _handle_message()
""" """
pass pass
@abstractmethod @abstractmethod
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the channel and clean up resources.""" """Stop the channel and clean up resources."""
pass pass
@abstractmethod @abstractmethod
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
""" """
Send a message through this channel. Send a message through this channel.
Args: Args:
msg: The message to send. msg: The message to send.
""" """
pass pass
def is_allowed(self, sender_id: str) -> bool: def is_allowed(self, sender_id: str) -> bool:
""" """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
Check if a sender is allowed to use this bot.
Args:
sender_id: The sender's identifier.
Returns:
True if allowed, False otherwise.
"""
allow_list = getattr(self.config, "allow_from", []) allow_list = getattr(self.config, "allow_from", [])
# If no allow list, allow everyone
if not allow_list: if not allow_list:
logger.warning("{}: allow_from is empty — all access denied", self.name)
return False
if "*" in allow_list:
return True return True
sender_str = str(sender_id) sender_str = str(sender_id)
if sender_str in allow_list: return sender_str in allow_list or any(
return True p in allow_list for p in sender_str.split("|") if p
if "|" in sender_str: )
for part in sender_str.split("|"):
if part and part in allow_list:
return True
return False
async def _handle_message( async def _handle_message(
self, self,
sender_id: str, sender_id: str,
@@ -94,9 +82,9 @@ class BaseChannel(ABC):
) -> None: ) -> None:
""" """
Handle an incoming message from the chat platform. Handle an incoming message from the chat platform.
This method checks permissions and forwards to the bus. This method checks permissions and forwards to the bus.
Args: Args:
sender_id: The sender's identifier. sender_id: The sender's identifier.
chat_id: The chat/channel identifier. chat_id: The chat/channel identifier.
@@ -112,7 +100,7 @@ class BaseChannel(ABC):
sender_id, self.name, sender_id, self.name,
) )
return return
msg = InboundMessage( msg = InboundMessage(
channel=self.name, channel=self.name,
sender_id=str(sender_id), sender_id=str(sender_id),
@@ -122,9 +110,9 @@ class BaseChannel(ABC):
metadata=metadata or {}, metadata=metadata or {},
session_key_override=session_key, session_key_override=session_key,
) )
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""Check if the channel is running.""" """Check if the channel is running."""

View File

@@ -2,11 +2,15 @@
import asyncio import asyncio
import json import json
import mimetypes
import os
import time import time
from pathlib import Path
from typing import Any from typing import Any
from urllib.parse import unquote, urlparse
from loguru import logger
import httpx import httpx
from loguru import logger
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -15,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig
try: try:
from dingtalk_stream import ( from dingtalk_stream import (
DingTalkStreamClient, AckMessage,
Credential,
CallbackHandler, CallbackHandler,
CallbackMessage, CallbackMessage,
AckMessage, Credential,
DingTalkStreamClient,
) )
from dingtalk_stream.chatbot import ChatbotMessage from dingtalk_stream.chatbot import ChatbotMessage
@@ -96,6 +100,9 @@ class DingTalkChannel(BaseChannel):
""" """
name = "dingtalk" name = "dingtalk"
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
def __init__(self, config: DingTalkConfig, bus: MessageBus): def __init__(self, config: DingTalkConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
@@ -191,40 +198,224 @@ class DingTalkChannel(BaseChannel):
logger.error("Failed to get DingTalk access token: {}", e) logger.error("Failed to get DingTalk access token: {}", e)
return None return None
@staticmethod
def _is_http_url(value: str) -> bool:
return urlparse(value).scheme in ("http", "https")
def _guess_upload_type(self, media_ref: str) -> str:
ext = Path(urlparse(media_ref).path).suffix.lower()
if ext in self._IMAGE_EXTS: return "image"
if ext in self._AUDIO_EXTS: return "voice"
if ext in self._VIDEO_EXTS: return "video"
return "file"
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
name = os.path.basename(urlparse(media_ref).path)
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
async def _read_media_bytes(
self,
media_ref: str,
) -> tuple[bytes | None, str | None, str | None]:
if not media_ref:
return None, None, None
if self._is_http_url(media_ref):
if not self._http:
return None, None, None
try:
resp = await self._http.get(media_ref, follow_redirects=True)
if resp.status_code >= 400:
logger.warning(
"DingTalk media download failed status={} ref={}",
resp.status_code,
media_ref,
)
return None, None, None
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
return resp.content, filename, content_type or None
except Exception as e:
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
return None, None, None
try:
if media_ref.startswith("file://"):
parsed = urlparse(media_ref)
local_path = Path(unquote(parsed.path))
else:
local_path = Path(os.path.expanduser(media_ref))
if not local_path.is_file():
logger.warning("DingTalk media file not found: {}", local_path)
return None, None, None
data = await asyncio.to_thread(local_path.read_bytes)
content_type = mimetypes.guess_type(local_path.name)[0]
return data, local_path.name, content_type
except Exception as e:
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
return None, None, None
async def _upload_media(
self,
token: str,
data: bytes,
media_type: str,
filename: str,
content_type: str | None,
) -> str | None:
if not self._http:
return None
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
files = {"media": (filename, data, mime)}
try:
resp = await self._http.post(url, files=files)
text = resp.text
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
if resp.status_code >= 400:
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
return None
errcode = result.get("errcode", 0)
if errcode != 0:
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
return None
sub = result.get("result") or {}
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
if not media_id:
logger.error("DingTalk media upload missing media_id body={}", text[:500])
return None
return str(media_id)
except Exception as e:
logger.error("DingTalk media upload error type={} err={}", media_type, e)
return None
async def _send_batch_message(
self,
token: str,
chat_id: str,
msg_key: str,
msg_param: dict[str, Any],
) -> bool:
if not self._http:
logger.warning("DingTalk HTTP client not initialized, cannot send")
return False
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {"x-acs-dingtalk-access-token": token}
payload = {
"robotCode": self.config.client_id,
"userIds": [chat_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
try:
resp = await self._http.post(url, json=payload, headers=headers)
body = resp.text
if resp.status_code != 200:
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
return False
try: result = resp.json()
except Exception: result = {}
errcode = result.get("errcode")
if errcode not in (None, 0):
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
return False
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
return True
except Exception as e:
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
return False
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
return await self._send_batch_message(
token,
chat_id,
"sampleMarkdown",
{"text": content, "title": "Nanobot Reply"},
)
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
media_ref = (media_ref or "").strip()
if not media_ref:
return True
upload_type = self._guess_upload_type(media_ref)
if upload_type == "image" and self._is_http_url(media_ref):
ok = await self._send_batch_message(
token,
chat_id,
"sampleImageMsg",
{"photoURL": media_ref},
)
if ok:
return True
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
data, filename, content_type = await self._read_media_bytes(media_ref)
if not data:
logger.error("DingTalk media read failed: {}", media_ref)
return False
filename = filename or self._guess_filename(media_ref, upload_type)
file_type = Path(filename).suffix.lower().lstrip(".")
if not file_type:
guessed = mimetypes.guess_extension(content_type or "")
file_type = (guessed or ".bin").lstrip(".")
if file_type == "jpeg":
file_type = "jpg"
media_id = await self._upload_media(
token=token,
data=data,
media_type=upload_type,
filename=filename,
content_type=content_type,
)
if not media_id:
return False
if upload_type == "image":
# Verified in production: sampleImageMsg accepts media_id in photoURL.
ok = await self._send_batch_message(
token,
chat_id,
"sampleImageMsg",
{"photoURL": media_id},
)
if ok:
return True
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
return await self._send_batch_message(
token,
chat_id,
"sampleFile",
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
)
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through DingTalk.""" """Send a message through DingTalk."""
token = await self._get_access_token() token = await self._get_access_token()
if not token: if not token:
return return
# oToMessages/batchSend: sends to individual users (private chat) if msg.content and msg.content.strip():
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
headers = {"x-acs-dingtalk-access-token": token} for media_ref in msg.media or []:
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
data = { if ok:
"robotCode": self.config.client_id, continue
"userIds": [msg.chat_id], # chat_id is the user's staffId logger.error("DingTalk media send failed for {}", media_ref)
"msgKey": "sampleMarkdown", # Send visible fallback so failures are observable by the user.
"msgParam": json.dumps({ filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
"text": msg.content, await self._send_markdown_text(
"title": "Nanobot Reply", token,
}, ensure_ascii=False), msg.chat_id,
} f"[Attachment send failed: {filename}]",
)
if not self._http:
logger.warning("DingTalk HTTP client not initialized, cannot send")
return
try:
resp = await self._http.post(url, json=data, headers=headers)
if resp.status_code != 200:
logger.error("DingTalk send failed: {}", resp.text)
else:
logger.debug("DingTalk message sent to {}", msg.chat_id)
except Exception as e:
logger.error("Error sending DingTalk message: {}", e)
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None: async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
"""Handle incoming message (called by NanobotDingTalkHandler). """Handle incoming message (called by NanobotDingTalkHandler).

View File

@@ -14,7 +14,6 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import DiscordConfig from nanobot.config.schema import DiscordConfig
DISCORD_API_BASE = "https://discord.com/api/v10" DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit MAX_MESSAGE_LEN = 2000 # Discord message character limit

View File

@@ -16,27 +16,9 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.schema import FeishuConfig from nanobot.config.schema import FeishuConfig
try: import importlib.util
import lark_oapi as lark
from lark_oapi.api.im.v1 import ( FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
CreateFileRequest,
CreateFileRequestBody,
CreateImageRequest,
CreateImageRequestBody,
CreateMessageRequest,
CreateMessageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
GetFileRequest,
GetMessageResourceRequest,
P2ImMessageReceiveV1,
)
FEISHU_AVAILABLE = True
except ImportError:
FEISHU_AVAILABLE = False
lark = None
Emoji = None
# Message type display mapping # Message type display mapping
MSG_TYPE_MAP = { MSG_TYPE_MAP = {
@@ -70,7 +52,7 @@ def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
def _extract_interactive_content(content: dict) -> list[str]: def _extract_interactive_content(content: dict) -> list[str]:
"""Recursively extract text and links from interactive card content.""" """Recursively extract text and links from interactive card content."""
parts = [] parts = []
if isinstance(content, str): if isinstance(content, str):
try: try:
content = json.loads(content) content = json.loads(content)
@@ -89,8 +71,9 @@ def _extract_interactive_content(content: dict) -> list[str]:
elif isinstance(title, str): elif isinstance(title, str):
parts.append(f"title: {title}") parts.append(f"title: {title}")
for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []: for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
parts.extend(_extract_element_content(element)) for element in elements:
parts.extend(_extract_element_content(element))
card = content.get("card", {}) card = content.get("card", {})
if card: if card:
@@ -103,19 +86,19 @@ def _extract_interactive_content(content: dict) -> list[str]:
header_text = header_title.get("content", "") or header_title.get("text", "") header_text = header_title.get("content", "") or header_title.get("text", "")
if header_text: if header_text:
parts.append(f"title: {header_text}") parts.append(f"title: {header_text}")
return parts return parts
def _extract_element_content(element: dict) -> list[str]: def _extract_element_content(element: dict) -> list[str]:
"""Extract content from a single card element.""" """Extract content from a single card element."""
parts = [] parts = []
if not isinstance(element, dict): if not isinstance(element, dict):
return parts return parts
tag = element.get("tag", "") tag = element.get("tag", "")
if tag in ("markdown", "lark_md"): if tag in ("markdown", "lark_md"):
content = element.get("content", "") content = element.get("content", "")
if content: if content:
@@ -176,69 +159,71 @@ def _extract_element_content(element: dict) -> list[str]:
else: else:
for ne in element.get("elements", []): for ne in element.get("elements", []):
parts.extend(_extract_element_content(ne)) parts.extend(_extract_element_content(ne))
return parts return parts
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]: def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
"""Extract text and image keys from Feishu post (rich text) message content. """Extract text and image keys from Feishu post (rich text) message.
Supports two formats: Handles three payload shapes:
1. Direct format: {"title": "...", "content": [...]} - Direct: {"title": "...", "content": [[...]]}
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}} - Localized: {"zh_cn": {"title": "...", "content": [...]}}
- Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
Returns:
(text, image_keys) - extracted text and list of image keys
""" """
def extract_from_lang(lang_content: dict) -> tuple[str | None, list[str]]:
if not isinstance(lang_content, dict): def _parse_block(block: dict) -> tuple[str | None, list[str]]:
if not isinstance(block, dict) or not isinstance(block.get("content"), list):
return None, [] return None, []
title = lang_content.get("title", "") texts, images = [], []
content_blocks = lang_content.get("content", []) if title := block.get("title"):
if not isinstance(content_blocks, list): texts.append(title)
return None, [] for row in block["content"]:
text_parts = [] if not isinstance(row, list):
image_keys = []
if title:
text_parts.append(title)
for block in content_blocks:
if not isinstance(block, list):
continue continue
for element in block: for el in row:
if isinstance(element, dict): if not isinstance(el, dict):
tag = element.get("tag") continue
if tag == "text": tag = el.get("tag")
text_parts.append(element.get("text", "")) if tag in ("text", "a"):
elif tag == "a": texts.append(el.get("text", ""))
text_parts.append(element.get("text", "")) elif tag == "at":
elif tag == "at": texts.append(f"@{el.get('user_name', 'user')}")
text_parts.append(f"@{element.get('user_name', 'user')}") elif tag == "img" and (key := el.get("image_key")):
elif tag == "img": images.append(key)
img_key = element.get("image_key") return (" ".join(texts).strip() or None), images
if img_key:
image_keys.append(img_key) # Unwrap optional {"post": ...} envelope
text = " ".join(text_parts).strip() if text_parts else None root = content_json
return text, image_keys if isinstance(root, dict) and isinstance(root.get("post"), dict):
root = root["post"]
# Try direct format first if not isinstance(root, dict):
if "content" in content_json: return "", []
text, images = extract_from_lang(content_json)
if text or images: # Direct format
return text or "", images if "content" in root:
text, imgs = _parse_block(root)
# Try localized format if text or imgs:
for lang_key in ("zh_cn", "en_us", "ja_jp"): return text or "", imgs
lang_content = content_json.get(lang_key)
text, images = extract_from_lang(lang_content) # Localized: prefer known locales, then fall back to any dict child
if text or images: for key in ("zh_cn", "en_us", "ja_jp"):
return text or "", images if key in root:
text, imgs = _parse_block(root[key])
if text or imgs:
return text or "", imgs
for val in root.values():
if isinstance(val, dict):
text, imgs = _parse_block(val)
if text or imgs:
return text or "", imgs
return "", [] return "", []
def _extract_post_text(content_json: dict) -> str: def _extract_post_text(content_json: dict) -> str:
"""Extract plain text from Feishu post (rich text) message content. """Extract plain text from Feishu post (rich text) message content.
Legacy wrapper for _extract_post_content, returns only text. Legacy wrapper for _extract_post_content, returns only text.
""" """
text, _ = _extract_post_content(content_json) text, _ = _extract_post_content(content_json)
@@ -248,17 +233,17 @@ def _extract_post_text(content_json: dict) -> str:
class FeishuChannel(BaseChannel): class FeishuChannel(BaseChannel):
""" """
Feishu/Lark channel using WebSocket long connection. Feishu/Lark channel using WebSocket long connection.
Uses WebSocket to receive events - no public IP or webhook required. Uses WebSocket to receive events - no public IP or webhook required.
Requires: Requires:
- App ID and App Secret from Feishu Open Platform - App ID and App Secret from Feishu Open Platform
- Bot capability enabled - Bot capability enabled
- Event subscription enabled (im.message.receive_v1) - Event subscription enabled (im.message.receive_v1)
""" """
name = "feishu" name = "feishu"
def __init__(self, config: FeishuConfig, bus: MessageBus): def __init__(self, config: FeishuConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: FeishuConfig = config self.config: FeishuConfig = config
@@ -267,27 +252,28 @@ class FeishuChannel(BaseChannel):
self._ws_thread: threading.Thread | None = None self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
async def start(self) -> None: async def start(self) -> None:
"""Start the Feishu bot with WebSocket long connection.""" """Start the Feishu bot with WebSocket long connection."""
if not FEISHU_AVAILABLE: if not FEISHU_AVAILABLE:
logger.error("Feishu SDK not installed. Run: pip install lark-oapi") logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
return return
if not self.config.app_id or not self.config.app_secret: if not self.config.app_id or not self.config.app_secret:
logger.error("Feishu app_id and app_secret not configured") logger.error("Feishu app_id and app_secret not configured")
return return
import lark_oapi as lark
self._running = True self._running = True
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
# Create Lark client for sending messages # Create Lark client for sending messages
self._client = lark.Client.builder() \ self._client = lark.Client.builder() \
.app_id(self.config.app_id) \ .app_id(self.config.app_id) \
.app_secret(self.config.app_secret) \ .app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \ .log_level(lark.LogLevel.INFO) \
.build() .build()
# Create event handler (only register message receive, ignore other events) # Create event handler (only register message receive, ignore other events)
event_handler = lark.EventDispatcherHandler.builder( event_handler = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "", self.config.encrypt_key or "",
@@ -295,7 +281,7 @@ class FeishuChannel(BaseChannel):
).register_p2_im_message_receive_v1( ).register_p2_im_message_receive_v1(
self._on_message_sync self._on_message_sync
).build() ).build()
# Create WebSocket client for long connection # Create WebSocket client for long connection
self._ws_client = lark.ws.Client( self._ws_client = lark.ws.Client(
self.config.app_id, self.config.app_id,
@@ -303,7 +289,7 @@ class FeishuChannel(BaseChannel):
event_handler=event_handler, event_handler=event_handler,
log_level=lark.LogLevel.INFO log_level=lark.LogLevel.INFO
) )
# Start WebSocket client in a separate thread with reconnect loop # Start WebSocket client in a separate thread with reconnect loop
def run_ws(): def run_ws():
while self._running: while self._running:
@@ -312,30 +298,33 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.warning("Feishu WebSocket error: {}", e) logger.warning("Feishu WebSocket error: {}", e)
if self._running: if self._running:
import time; time.sleep(5) import time
time.sleep(5)
self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread = threading.Thread(target=run_ws, daemon=True)
self._ws_thread.start() self._ws_thread.start()
logger.info("Feishu bot started with WebSocket long connection") logger.info("Feishu bot started with WebSocket long connection")
logger.info("No public IP required - using WebSocket to receive events") logger.info("No public IP required - using WebSocket to receive events")
# Keep running until stopped # Keep running until stopped
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the Feishu bot.""" """
Stop the Feishu bot.
Notice: lark.ws.Client does not expose stop method simply exiting the program will close the client.
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
"""
self._running = False self._running = False
if self._ws_client:
try:
self._ws_client.stop()
except Exception as e:
logger.warning("Error stopping WebSocket client: {}", e)
logger.info("Feishu bot stopped") logger.info("Feishu bot stopped")
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool).""" """Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
try: try:
request = CreateMessageReactionRequest.builder() \ request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \ .message_id(message_id) \
@@ -344,9 +333,9 @@ class FeishuChannel(BaseChannel):
.reaction_type(Emoji.builder().emoji_type(emoji_type).build()) .reaction_type(Emoji.builder().emoji_type(emoji_type).build())
.build() .build()
).build() ).build()
response = self._client.im.v1.message_reaction.create(request) response = self._client.im.v1.message_reaction.create(request)
if not response.success(): if not response.success():
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg) logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
else: else:
@@ -357,15 +346,15 @@ class FeishuChannel(BaseChannel):
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None: async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
""" """
Add a reaction emoji to a message (non-blocking). Add a reaction emoji to a message (non-blocking).
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
""" """
if not self._client or not Emoji: if not self._client:
return return
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type) await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
# Regex to match markdown tables (header + separator + data rows) # Regex to match markdown tables (header + separator + data rows)
_TABLE_RE = re.compile( _TABLE_RE = re.compile(
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)", r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
@@ -379,12 +368,13 @@ class FeishuChannel(BaseChannel):
@staticmethod @staticmethod
def _parse_md_table(table_text: str) -> dict | None: def _parse_md_table(table_text: str) -> dict | None:
"""Parse a markdown table into a Feishu table element.""" """Parse a markdown table into a Feishu table element."""
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()] lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
if len(lines) < 3: if len(lines) < 3:
return None return None
split = lambda l: [c.strip() for c in l.strip("|").split("|")] def split(_line: str) -> list[str]:
return [c.strip() for c in _line.strip("|").split("|")]
headers = split(lines[0]) headers = split(lines[0])
rows = [split(l) for l in lines[2:]] rows = [split(_line) for _line in lines[2:]]
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"} columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)] for i, h in enumerate(headers)]
return { return {
@@ -451,6 +441,7 @@ class FeishuChannel(BaseChannel):
def _upload_image_sync(self, file_path: str) -> str | None: def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key.""" """Upload an image to Feishu and return the image_key."""
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
try: try:
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \ request = CreateImageRequest.builder() \
@@ -474,6 +465,7 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None: def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key.""" """Upload a file to Feishu and return the file_key."""
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
ext = os.path.splitext(file_path)[1].lower() ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream") file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
@@ -501,6 +493,7 @@ class FeishuChannel(BaseChannel):
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]: def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key.""" """Download an image from Feishu message by message_id and image_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
try: try:
request = GetMessageResourceRequest.builder() \ request = GetMessageResourceRequest.builder() \
.message_id(message_id) \ .message_id(message_id) \
@@ -525,6 +518,13 @@ class FeishuChannel(BaseChannel):
self, message_id: str, file_key: str, resource_type: str = "file" self, message_id: str, file_key: str, resource_type: str = "file"
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
"""Download a file/audio/media from a Feishu message by message_id and file_key.""" """Download a file/audio/media from a Feishu message by message_id and file_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
# Feishu API only accepts 'image' or 'file' as type parameter
# Convert 'audio' to 'file' for API compatibility
if resource_type == "audio":
resource_type = "file"
try: try:
request = ( request = (
GetMessageResourceRequest.builder() GetMessageResourceRequest.builder()
@@ -593,6 +593,7 @@ class FeishuChannel(BaseChannel):
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool: def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
"""Send a single message (text/image/file/interactive) synchronously.""" """Send a single message (text/image/file/interactive) synchronously."""
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
try: try:
request = CreateMessageRequest.builder() \ request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \ .receive_id_type(receive_id_type) \
@@ -656,7 +657,7 @@ class FeishuChannel(BaseChannel):
except Exception as e: except Exception as e:
logger.error("Error sending Feishu message: {}", e) logger.error("Error sending Feishu message: {}", e)
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None: def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
""" """
Sync handler for incoming messages (called from WebSocket thread). Sync handler for incoming messages (called from WebSocket thread).
@@ -664,7 +665,7 @@ class FeishuChannel(BaseChannel):
""" """
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop) asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None: async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
"""Handle incoming message from Feishu.""" """Handle incoming message from Feishu."""
try: try:
@@ -692,7 +693,7 @@ class FeishuChannel(BaseChannel):
msg_type = message.message_type msg_type = message.message_type
# Add reaction # Add reaction
await self._add_reaction(message_id, "THUMBSUP") await self._add_reaction(message_id, self.config.react_emoji)
# Parse content # Parse content
content_parts = [] content_parts = []

View File

@@ -16,24 +16,24 @@ from nanobot.config.schema import Config
class ChannelManager: class ChannelManager:
""" """
Manages chat channels and coordinates message routing. Manages chat channels and coordinates message routing.
Responsibilities: Responsibilities:
- Initialize enabled channels (Telegram, WhatsApp, etc.) - Initialize enabled channels (Telegram, WhatsApp, etc.)
- Start/stop channels - Start/stop channels
- Route outbound messages - Route outbound messages
""" """
def __init__(self, config: Config, bus: MessageBus): def __init__(self, config: Config, bus: MessageBus):
self.config = config self.config = config
self.bus = bus self.bus = bus
self.channels: dict[str, BaseChannel] = {} self.channels: dict[str, BaseChannel] = {}
self._dispatch_task: asyncio.Task | None = None self._dispatch_task: asyncio.Task | None = None
self._init_channels() self._init_channels()
def _init_channels(self) -> None: def _init_channels(self) -> None:
"""Initialize channels based on config.""" """Initialize channels based on config."""
# Telegram channel # Telegram channel
if self.config.channels.telegram.enabled: if self.config.channels.telegram.enabled:
try: try:
@@ -46,7 +46,7 @@ class ChannelManager:
logger.info("Telegram channel enabled") logger.info("Telegram channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Telegram channel not available: {}", e) logger.warning("Telegram channel not available: {}", e)
# WhatsApp channel # WhatsApp channel
if self.config.channels.whatsapp.enabled: if self.config.channels.whatsapp.enabled:
try: try:
@@ -68,7 +68,7 @@ class ChannelManager:
logger.info("Discord channel enabled") logger.info("Discord channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("Discord channel not available: {}", e) logger.warning("Discord channel not available: {}", e)
# Feishu channel # Feishu channel
if self.config.channels.feishu.enabled: if self.config.channels.feishu.enabled:
try: try:
@@ -136,7 +136,29 @@ class ChannelManager:
logger.info("QQ channel enabled") logger.info("QQ channel enabled")
except ImportError as e: except ImportError as e:
logger.warning("QQ channel not available: {}", e) logger.warning("QQ channel not available: {}", e)
# Matrix channel
if self.config.channels.matrix.enabled:
try:
from nanobot.channels.matrix import MatrixChannel
self.channels["matrix"] = MatrixChannel(
self.config.channels.matrix,
self.bus,
)
logger.info("Matrix channel enabled")
except ImportError as e:
logger.warning("Matrix channel not available: {}", e)
self._validate_allow_from()
def _validate_allow_from(self) -> None:
for name, ch in self.channels.items():
if getattr(ch.config, "allow_from", None) == []:
raise SystemExit(
f'Error: "{name}" has empty allowFrom (denies all). '
f'Set ["*"] to allow everyone, or add specific user IDs.'
)
async def _start_channel(self, name: str, channel: BaseChannel) -> None: async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions.""" """Start a channel and log any exceptions."""
try: try:
@@ -149,23 +171,23 @@ class ChannelManager:
if not self.channels: if not self.channels:
logger.warning("No channels enabled") logger.warning("No channels enabled")
return return
# Start outbound dispatcher # Start outbound dispatcher
self._dispatch_task = asyncio.create_task(self._dispatch_outbound()) self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
# Start channels # Start channels
tasks = [] tasks = []
for name, channel in self.channels.items(): for name, channel in self.channels.items():
logger.info("Starting {} channel...", name) logger.info("Starting {} channel...", name)
tasks.append(asyncio.create_task(self._start_channel(name, channel))) tasks.append(asyncio.create_task(self._start_channel(name, channel)))
# Wait for all to complete (they should run forever) # Wait for all to complete (they should run forever)
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
async def stop_all(self) -> None: async def stop_all(self) -> None:
"""Stop all channels and the dispatcher.""" """Stop all channels and the dispatcher."""
logger.info("Stopping all channels...") logger.info("Stopping all channels...")
# Stop dispatcher # Stop dispatcher
if self._dispatch_task: if self._dispatch_task:
self._dispatch_task.cancel() self._dispatch_task.cancel()
@@ -173,7 +195,7 @@ class ChannelManager:
await self._dispatch_task await self._dispatch_task
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Stop all channels # Stop all channels
for name, channel in self.channels.items(): for name, channel in self.channels.items():
try: try:
@@ -181,24 +203,24 @@ class ChannelManager:
logger.info("Stopped {} channel", name) logger.info("Stopped {} channel", name)
except Exception as e: except Exception as e:
logger.error("Error stopping {}: {}", name, e) logger.error("Error stopping {}: {}", name, e)
async def _dispatch_outbound(self) -> None: async def _dispatch_outbound(self) -> None:
"""Dispatch outbound messages to the appropriate channel.""" """Dispatch outbound messages to the appropriate channel."""
logger.info("Outbound dispatcher started") logger.info("Outbound dispatcher started")
while True: while True:
try: try:
msg = await asyncio.wait_for( msg = await asyncio.wait_for(
self.bus.consume_outbound(), self.bus.consume_outbound(),
timeout=1.0 timeout=1.0
) )
if msg.metadata.get("_progress"): if msg.metadata.get("_progress"):
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints: if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
continue continue
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress: if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
continue continue
channel = self.channels.get(msg.channel) channel = self.channels.get(msg.channel)
if channel: if channel:
try: try:
@@ -207,16 +229,16 @@ class ChannelManager:
logger.error("Error sending to {}: {}", msg.channel, e) logger.error("Error sending to {}: {}", msg.channel, e)
else: else:
logger.warning("Unknown channel: {}", msg.channel) logger.warning("Unknown channel: {}", msg.channel)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except asyncio.CancelledError: except asyncio.CancelledError:
break break
def get_channel(self, name: str) -> BaseChannel | None: def get_channel(self, name: str) -> BaseChannel | None:
"""Get a channel by name.""" """Get a channel by name."""
return self.channels.get(name) return self.channels.get(name)
def get_status(self) -> dict[str, Any]: def get_status(self) -> dict[str, Any]:
"""Get status of all channels.""" """Get status of all channels."""
return { return {
@@ -226,7 +248,7 @@ class ChannelManager:
} }
for name, channel in self.channels.items() for name, channel in self.channels.items()
} }
@property @property
def enabled_channels(self) -> list[str]: def enabled_channels(self) -> list[str]:
"""Get list of enabled channel names.""" """Get list of enabled channel names."""

View File

@@ -12,10 +12,22 @@ try:
import nh3 import nh3
from mistune import create_markdown from mistune import create_markdown
from nio import ( from nio import (
AsyncClient, AsyncClientConfig, ContentRepositoryConfigError, AsyncClient,
DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse, AsyncClientConfig,
RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText, ContentRepositoryConfigError,
RoomSendError, RoomTypingError, SyncError, UploadError, DownloadError,
InviteEvent,
JoinError,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
RoomMessage,
RoomMessageMedia,
RoomMessageText,
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
) )
from nio.crypto.attachments import decrypt_attachment from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError from nio.exceptions import EncryptionError
@@ -350,7 +362,11 @@ class MatrixChannel(BaseChannel):
limit_bytes = await self._effective_media_limit_bytes() limit_bytes = await self._effective_media_limit_bytes()
for path in candidates: for path in candidates:
if fail := await self._upload_and_send_attachment( if fail := await self._upload_and_send_attachment(
msg.chat_id, path, limit_bytes, relates_to): room_id=msg.chat_id,
path=path,
limit_bytes=limit_bytes,
relates_to=relates_to,
):
failures.append(fail) failures.append(fail)
if failures: if failures:
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures) text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
@@ -438,8 +454,7 @@ class MatrixChannel(BaseChannel):
await asyncio.sleep(2) await asyncio.sleep(2)
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None: async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
allow_from = self.config.allow_from or [] if self.is_allowed(event.sender):
if not allow_from or event.sender in allow_from:
await self.client.join(room.room_id) await self.client.join(room.room_id)
def _is_direct_room(self, room: MatrixRoom) -> bool: def _is_direct_room(self, room: MatrixRoom) -> bool:
@@ -664,11 +679,13 @@ class MatrixChannel(BaseChannel):
parts: list[str] = [] parts: list[str] = []
if isinstance(body := getattr(event, "body", None), str) and body.strip(): if isinstance(body := getattr(event, "body", None), str) and body.strip():
parts.append(body.strip()) parts.append(body.strip())
parts.append(marker) if marker:
parts.append(marker)
await self._start_typing_keepalive(room.room_id) await self._start_typing_keepalive(room.room_id)
try: try:
meta = self._base_metadata(room, event) meta = self._base_metadata(room, event)
meta["attachments"] = []
if attachment: if attachment:
meta["attachments"] = [attachment] meta["attachments"] = [attachment]
await self._handle_message( await self._handle_message(

View File

@@ -31,7 +31,8 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
class _Bot(botpy.Client): class _Bot(botpy.Client):
def __init__(self): def __init__(self):
super().__init__(intents=intents) # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
super().__init__(intents=intents, ext_handlers=False)
async def on_ready(self): async def on_ready(self):
logger.info("QQ bot ready: {}", self.robot.name) logger.info("QQ bot ready: {}", self.robot.name)
@@ -55,6 +56,7 @@ class QQChannel(BaseChannel):
self.config: QQConfig = config self.config: QQConfig = config
self._client: "botpy.Client | None" = None self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000) self._processed_ids: deque = deque(maxlen=1000)
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
async def start(self) -> None: async def start(self) -> None:
"""Start the QQ bot.""" """Start the QQ bot."""
@@ -100,10 +102,14 @@ class QQChannel(BaseChannel):
logger.warning("QQ client not initialized") logger.warning("QQ client not initialized")
return return
try: try:
msg_id = msg.metadata.get("message_id")
self._msg_seq += 1 # 递增序列号
await self._client.api.post_c2c_message( await self._client.api.post_c2c_message(
openid=msg.chat_id, openid=msg.chat_id,
msg_type=0, msg_type=0,
content=msg.content, content=msg.content,
msg_id=msg_id,
msg_seq=self._msg_seq, # 添加序列号避免去重
) )
except Exception as e: except Exception as e:
logger.error("Error sending QQ message: {}", e) logger.error("Error sending QQ message: {}", e)
@@ -130,3 +136,4 @@ class QQChannel(BaseChannel):
) )
except Exception: except Exception:
logger.exception("Error handling QQ message") logger.exception("Error handling QQ message")

View File

@@ -5,11 +5,10 @@ import re
from typing import Any from typing import Any
from loguru import logger from loguru import logger
from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
from slackify_markdown import slackify_markdown from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import asyncio import asyncio
import re import re
from loguru import logger from loguru import logger
from telegram import BotCommand, Update, ReplyParameters from telegram import BotCommand, ReplyParameters, Update
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
@@ -21,60 +22,60 @@ def _markdown_to_telegram_html(text: str) -> str:
""" """
if not text: if not text:
return "" return ""
# 1. Extract and protect code blocks (preserve content from other processing) # 1. Extract and protect code blocks (preserve content from other processing)
code_blocks: list[str] = [] code_blocks: list[str] = []
def save_code_block(m: re.Match) -> str: def save_code_block(m: re.Match) -> str:
code_blocks.append(m.group(1)) code_blocks.append(m.group(1))
return f"\x00CB{len(code_blocks) - 1}\x00" return f"\x00CB{len(code_blocks) - 1}\x00"
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text) text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
# 2. Extract and protect inline code # 2. Extract and protect inline code
inline_codes: list[str] = [] inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str: def save_inline_code(m: re.Match) -> str:
inline_codes.append(m.group(1)) inline_codes.append(m.group(1))
return f"\x00IC{len(inline_codes) - 1}\x00" return f"\x00IC{len(inline_codes) - 1}\x00"
text = re.sub(r'`([^`]+)`', save_inline_code, text) text = re.sub(r'`([^`]+)`', save_inline_code, text)
# 3. Headers # Title -> just the title text # 3. Headers # Title -> just the title text
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
# 4. Blockquotes > text -> just the text (before HTML escaping) # 4. Blockquotes > text -> just the text (before HTML escaping)
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE) text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
# 5. Escape HTML special characters # 5. Escape HTML special characters
text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
# 6. Links [text](url) - must be before bold/italic to handle nested cases # 6. Links [text](url) - must be before bold/italic to handle nested cases
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text) text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
# 7. Bold **text** or __text__ # 7. Bold **text** or __text__
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text) text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text) text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
# 8. Italic _text_ (avoid matching inside words like some_var_name) # 8. Italic _text_ (avoid matching inside words like some_var_name)
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text) text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
# 9. Strikethrough ~~text~~ # 9. Strikethrough ~~text~~
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text) text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
# 10. Bullet lists - item -> • item # 10. Bullet lists - item -> • item
text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE) text = re.sub(r'^[-*]\s+', '', text, flags=re.MULTILINE)
# 11. Restore inline code with HTML tags # 11. Restore inline code with HTML tags
for i, code in enumerate(inline_codes): for i, code in enumerate(inline_codes):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>") text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
# 12. Restore code blocks with HTML tags # 12. Restore code blocks with HTML tags
for i, code in enumerate(code_blocks): for i, code in enumerate(code_blocks):
# Escape HTML in code content # Escape HTML in code content
escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;") escaped = code.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>") text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
return text return text
@@ -101,12 +102,12 @@ def _split_message(content: str, max_len: int = 4000) -> list[str]:
class TelegramChannel(BaseChannel): class TelegramChannel(BaseChannel):
""" """
Telegram channel using long polling. Telegram channel using long polling.
Simple and reliable - no webhook/public IP needed. Simple and reliable - no webhook/public IP needed.
""" """
name = "telegram" name = "telegram"
# Commands registered with Telegram's command menu # Commands registered with Telegram's command menu
BOT_COMMANDS = [ BOT_COMMANDS = [
BotCommand("start", "Start the bot"), BotCommand("start", "Start the bot"),
@@ -114,7 +115,7 @@ class TelegramChannel(BaseChannel):
BotCommand("stop", "Stop the current task"), BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
] ]
def __init__( def __init__(
self, self,
config: TelegramConfig, config: TelegramConfig,
@@ -127,15 +128,17 @@ class TelegramChannel(BaseChannel):
self._app: Application | None = None self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
self._media_group_buffers: dict[str, dict] = {}
self._media_group_tasks: dict[str, asyncio.Task] = {}
async def start(self) -> None: async def start(self) -> None:
"""Start the Telegram bot with long polling.""" """Start the Telegram bot with long polling."""
if not self.config.token: if not self.config.token:
logger.error("Telegram bot token not configured") logger.error("Telegram bot token not configured")
return return
self._running = True self._running = True
# Build the application with larger connection pool to avoid pool-timeout on long runs # Build the application with larger connection pool to avoid pool-timeout on long runs
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0) req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req) builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
@@ -143,62 +146,67 @@ class TelegramChannel(BaseChannel):
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy) builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build() self._app = builder.build()
self._app.add_error_handler(self._on_error) self._app.add_error_handler(self._on_error)
# Add command handlers # Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help)) self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
self._app.add_handler( self._app.add_handler(
MessageHandler( MessageHandler(
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL) (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
& ~filters.COMMAND, & ~filters.COMMAND,
self._on_message self._on_message
) )
) )
logger.info("Starting Telegram bot (polling mode)...") logger.info("Starting Telegram bot (polling mode)...")
# Initialize and start polling # Initialize and start polling
await self._app.initialize() await self._app.initialize()
await self._app.start() await self._app.start()
# Get bot info and register command menu # Get bot info and register command menu
bot_info = await self._app.bot.get_me() bot_info = await self._app.bot.get_me()
logger.info("Telegram bot @{} connected", bot_info.username) logger.info("Telegram bot @{} connected", bot_info.username)
try: try:
await self._app.bot.set_my_commands(self.BOT_COMMANDS) await self._app.bot.set_my_commands(self.BOT_COMMANDS)
logger.debug("Telegram bot commands registered") logger.debug("Telegram bot commands registered")
except Exception as e: except Exception as e:
logger.warning("Failed to register bot commands: {}", e) logger.warning("Failed to register bot commands: {}", e)
# Start polling (this runs until stopped) # Start polling (this runs until stopped)
await self._app.updater.start_polling( await self._app.updater.start_polling(
allowed_updates=["message"], allowed_updates=["message"],
drop_pending_updates=True # Ignore old messages on startup drop_pending_updates=True # Ignore old messages on startup
) )
# Keep running until stopped # Keep running until stopped
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the Telegram bot.""" """Stop the Telegram bot."""
self._running = False self._running = False
# Cancel all typing indicators # Cancel all typing indicators
for chat_id in list(self._typing_tasks): for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id) self._stop_typing(chat_id)
for task in self._media_group_tasks.values():
task.cancel()
self._media_group_tasks.clear()
self._media_group_buffers.clear()
if self._app: if self._app:
logger.info("Stopping Telegram bot...") logger.info("Stopping Telegram bot...")
await self._app.updater.stop() await self._app.updater.stop()
await self._app.stop() await self._app.stop()
await self._app.shutdown() await self._app.shutdown()
self._app = None self._app = None
@staticmethod @staticmethod
def _get_media_type(path: str) -> str: def _get_media_type(path: str) -> str:
"""Guess media type from file extension.""" """Guess media type from file extension."""
@@ -246,7 +254,7 @@ class TelegramChannel(BaseChannel):
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document" param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
with open(media_path, 'rb') as f: with open(media_path, 'rb') as f:
await sender( await sender(
chat_id=chat_id, chat_id=chat_id,
**{param: f}, **{param: f},
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -265,8 +273,8 @@ class TelegramChannel(BaseChannel):
try: try:
html = _markdown_to_telegram_html(chunk) html = _markdown_to_telegram_html(chunk)
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=html, text=html,
parse_mode="HTML", parse_mode="HTML",
reply_parameters=reply_params reply_parameters=reply_params
) )
@@ -274,13 +282,13 @@ class TelegramChannel(BaseChannel):
logger.warning("HTML parse failed, falling back to plain text: {}", e) logger.warning("HTML parse failed, falling back to plain text: {}", e)
try: try:
await self._app.bot.send_message( await self._app.bot.send_message(
chat_id=chat_id, chat_id=chat_id,
text=chunk, text=chunk,
reply_parameters=reply_params reply_parameters=reply_params
) )
except Exception as e2: except Exception as e2:
logger.error("Error sending Telegram message: {}", e2) logger.error("Error sending Telegram message: {}", e2)
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command.""" """Handle /start command."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
@@ -319,34 +327,34 @@ class TelegramChannel(BaseChannel):
chat_id=str(update.message.chat_id), chat_id=str(update.message.chat_id),
content=update.message.text, content=update.message.text,
) )
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming messages (text, photos, voice, documents).""" """Handle incoming messages (text, photos, voice, documents)."""
if not update.message or not update.effective_user: if not update.message or not update.effective_user:
return return
message = update.message message = update.message
user = update.effective_user user = update.effective_user
chat_id = message.chat_id chat_id = message.chat_id
sender_id = self._sender_id(user) sender_id = self._sender_id(user)
# Store chat_id for replies # Store chat_id for replies
self._chat_ids[sender_id] = chat_id self._chat_ids[sender_id] = chat_id
# Build content from text and/or media # Build content from text and/or media
content_parts = [] content_parts = []
media_paths = [] media_paths = []
# Text content # Text content
if message.text: if message.text:
content_parts.append(message.text) content_parts.append(message.text)
if message.caption: if message.caption:
content_parts.append(message.caption) content_parts.append(message.caption)
# Handle media files # Handle media files
media_file = None media_file = None
media_type = None media_type = None
if message.photo: if message.photo:
media_file = message.photo[-1] # Largest photo media_file = message.photo[-1] # Largest photo
media_type = "image" media_type = "image"
@@ -359,23 +367,23 @@ class TelegramChannel(BaseChannel):
elif message.document: elif message.document:
media_file = message.document media_file = message.document
media_type = "file" media_type = "file"
# Download media if present # Download media if present
if media_file and self._app: if media_file and self._app:
try: try:
file = await self._app.bot.get_file(media_file.file_id) file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None)) ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
# Save to workspace/media/ # Save to workspace/media/
from pathlib import Path from pathlib import Path
media_dir = Path.home() / ".nanobot" / "media" media_dir = Path.home() / ".nanobot" / "media"
media_dir.mkdir(parents=True, exist_ok=True) media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{media_file.file_id[:16]}{ext}" file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path)) await file.download_to_drive(str(file_path))
media_paths.append(str(file_path)) media_paths.append(str(file_path))
# Handle voice transcription # Handle voice transcription
if media_type == "voice" or media_type == "audio": if media_type == "voice" or media_type == "audio":
from nanobot.providers.transcription import GroqTranscriptionProvider from nanobot.providers.transcription import GroqTranscriptionProvider
@@ -388,21 +396,43 @@ class TelegramChannel(BaseChannel):
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
else: else:
content_parts.append(f"[{media_type}: {file_path}]") content_parts.append(f"[{media_type}: {file_path}]")
logger.debug("Downloaded {} to {}", media_type, file_path) logger.debug("Downloaded {} to {}", media_type, file_path)
except Exception as e: except Exception as e:
logger.error("Failed to download media: {}", e) logger.error("Failed to download media: {}", e)
content_parts.append(f"[{media_type}: download failed]") content_parts.append(f"[{media_type}: download failed]")
content = "\n".join(content_parts) if content_parts else "[empty message]" content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id) str_chat_id = str(chat_id)
# Telegram media groups: buffer briefly, forward as one aggregated turn.
if media_group_id := getattr(message, "media_group_id", None):
key = f"{str_chat_id}:{media_group_id}"
if key not in self._media_group_buffers:
self._media_group_buffers[key] = {
"sender_id": sender_id, "chat_id": str_chat_id,
"contents": [], "media": [],
"metadata": {
"message_id": message.message_id, "user_id": user.id,
"username": user.username, "first_name": user.first_name,
"is_group": message.chat.type != "private",
},
}
self._start_typing(str_chat_id)
buf = self._media_group_buffers[key]
if content and content != "[empty message]":
buf["contents"].append(content)
buf["media"].extend(media_paths)
if key not in self._media_group_tasks:
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
return
# Start typing indicator before processing # Start typing indicator before processing
self._start_typing(str_chat_id) self._start_typing(str_chat_id)
# Forward to the message bus # Forward to the message bus
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
@@ -417,19 +447,34 @@ class TelegramChannel(BaseChannel):
"is_group": message.chat.type != "private" "is_group": message.chat.type != "private"
} }
) )
async def _flush_media_group(self, key: str) -> None:
"""Wait briefly, then forward buffered media-group as one turn."""
try:
await asyncio.sleep(0.6)
if not (buf := self._media_group_buffers.pop(key, None)):
return
content = "\n".join(buf["contents"]) or "[empty message]"
await self._handle_message(
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
content=content, media=list(dict.fromkeys(buf["media"])),
metadata=buf["metadata"],
)
finally:
self._media_group_tasks.pop(key, None)
def _start_typing(self, chat_id: str) -> None: def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat.""" """Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat # Cancel any existing typing task for this chat
self._stop_typing(chat_id) self._stop_typing(chat_id)
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
def _stop_typing(self, chat_id: str) -> None: def _stop_typing(self, chat_id: str) -> None:
"""Stop the typing indicator for a chat.""" """Stop the typing indicator for a chat."""
task = self._typing_tasks.pop(chat_id, None) task = self._typing_tasks.pop(chat_id, None)
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
async def _typing_loop(self, chat_id: str) -> None: async def _typing_loop(self, chat_id: str) -> None:
"""Repeatedly send 'typing' action until cancelled.""" """Repeatedly send 'typing' action until cancelled."""
try: try:
@@ -440,7 +485,7 @@ class TelegramChannel(BaseChannel):
pass pass
except Exception as e: except Exception as e:
logger.debug("Typing indicator stopped for {}: {}", chat_id, e) logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None: async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Log polling / handler errors instead of silently swallowing them.""" """Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error) logger.error("Telegram error: {}", context.error)
@@ -454,6 +499,6 @@ class TelegramChannel(BaseChannel):
} }
if mime_type in ext_map: if mime_type in ext_map:
return ext_map[mime_type] return ext_map[mime_type]
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""} type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
return type_map.get(media_type, "") return type_map.get(media_type, "")

View File

@@ -2,7 +2,7 @@
import asyncio import asyncio
import json import json
from typing import Any from collections import OrderedDict
from loguru import logger from loguru import logger
@@ -15,29 +15,30 @@ from nanobot.config.schema import WhatsAppConfig
class WhatsAppChannel(BaseChannel): class WhatsAppChannel(BaseChannel):
""" """
WhatsApp channel that connects to a Node.js bridge. WhatsApp channel that connects to a Node.js bridge.
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol. The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
Communication between Python and Node.js is via WebSocket. Communication between Python and Node.js is via WebSocket.
""" """
name = "whatsapp" name = "whatsapp"
def __init__(self, config: WhatsAppConfig, bus: MessageBus): def __init__(self, config: WhatsAppConfig, bus: MessageBus):
super().__init__(config, bus) super().__init__(config, bus)
self.config: WhatsAppConfig = config self.config: WhatsAppConfig = config
self._ws = None self._ws = None
self._connected = False self._connected = False
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
async def start(self) -> None: async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge.""" """Start the WhatsApp channel by connecting to the bridge."""
import websockets import websockets
bridge_url = self.config.bridge_url bridge_url = self.config.bridge_url
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url) logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
self._running = True self._running = True
while self._running: while self._running:
try: try:
async with websockets.connect(bridge_url) as ws: async with websockets.connect(bridge_url) as ws:
@@ -47,40 +48,40 @@ class WhatsAppChannel(BaseChannel):
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token})) await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
self._connected = True self._connected = True
logger.info("Connected to WhatsApp bridge") logger.info("Connected to WhatsApp bridge")
# Listen for messages # Listen for messages
async for message in ws: async for message in ws:
try: try:
await self._handle_bridge_message(message) await self._handle_bridge_message(message)
except Exception as e: except Exception as e:
logger.error("Error handling bridge message: {}", e) logger.error("Error handling bridge message: {}", e)
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
self._connected = False self._connected = False
self._ws = None self._ws = None
logger.warning("WhatsApp bridge connection error: {}", e) logger.warning("WhatsApp bridge connection error: {}", e)
if self._running: if self._running:
logger.info("Reconnecting in 5 seconds...") logger.info("Reconnecting in 5 seconds...")
await asyncio.sleep(5) await asyncio.sleep(5)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the WhatsApp channel.""" """Stop the WhatsApp channel."""
self._running = False self._running = False
self._connected = False self._connected = False
if self._ws: if self._ws:
await self._ws.close() await self._ws.close()
self._ws = None self._ws = None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WhatsApp.""" """Send a message through WhatsApp."""
if not self._ws or not self._connected: if not self._ws or not self._connected:
logger.warning("WhatsApp bridge not connected") logger.warning("WhatsApp bridge not connected")
return return
try: try:
payload = { payload = {
"type": "send", "type": "send",
@@ -90,7 +91,7 @@ class WhatsAppChannel(BaseChannel):
await self._ws.send(json.dumps(payload, ensure_ascii=False)) await self._ws.send(json.dumps(payload, ensure_ascii=False))
except Exception as e: except Exception as e:
logger.error("Error sending WhatsApp message: {}", e) logger.error("Error sending WhatsApp message: {}", e)
async def _handle_bridge_message(self, raw: str) -> None: async def _handle_bridge_message(self, raw: str) -> None:
"""Handle a message from the bridge.""" """Handle a message from the bridge."""
try: try:
@@ -98,51 +99,59 @@ class WhatsAppChannel(BaseChannel):
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning("Invalid JSON from bridge: {}", raw[:100]) logger.warning("Invalid JSON from bridge: {}", raw[:100])
return return
msg_type = data.get("type") msg_type = data.get("type")
if msg_type == "message": if msg_type == "message":
# Incoming message from WhatsApp # Incoming message from WhatsApp
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net # Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
pn = data.get("pn", "") pn = data.get("pn", "")
# New LID sytle typically: # New LID sytle typically:
sender = data.get("sender", "") sender = data.get("sender", "")
content = data.get("content", "") content = data.get("content", "")
message_id = data.get("id", "")
if message_id:
if message_id in self._processed_message_ids:
return
self._processed_message_ids[message_id] = None
while len(self._processed_message_ids) > 1000:
self._processed_message_ids.popitem(last=False)
# Extract just the phone number or lid as chat_id # Extract just the phone number or lid as chat_id
user_id = pn if pn else sender user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info("Sender {}", sender) logger.info("Sender {}", sender)
# Handle voice transcription if it's a voice message # Handle voice transcription if it's a voice message
if content == "[Voice Message]": if content == "[Voice Message]":
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id) logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]" content = "[Voice Message: Transcription not available for WhatsApp yet]"
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=sender, # Use full LID for replies chat_id=sender, # Use full LID for replies
content=content, content=content,
metadata={ metadata={
"message_id": data.get("id"), "message_id": message_id,
"timestamp": data.get("timestamp"), "timestamp": data.get("timestamp"),
"is_group": data.get("isGroup", False) "is_group": data.get("isGroup", False)
} }
) )
elif msg_type == "status": elif msg_type == "status":
# Connection status update # Connection status update
status = data.get("status") status = data.get("status")
logger.info("WhatsApp status: {}", status) logger.info("WhatsApp status: {}", status)
if status == "connected": if status == "connected":
self._connected = True self._connected = True
elif status == "disconnected": elif status == "disconnected":
self._connected = False self._connected = False
elif msg_type == "qr": elif msg_type == "qr":
# QR code for authentication # QR code for authentication
logger.info("Scan QR code in the bridge terminal to connect WhatsApp") logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
elif msg_type == "error": elif msg_type == "error":
logger.error("WhatsApp bridge error: {}", data.get('error')) logger.error("WhatsApp bridge error: {}", data.get('error'))

View File

@@ -2,24 +2,24 @@
import asyncio import asyncio
import os import os
import signal
from pathlib import Path
import select import select
import signal
import sys import sys
from pathlib import Path
import typer import typer
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
from rich.text import Text from rich.text import Text
from prompt_toolkit import PromptSession from nanobot import __logo__, __version__
from prompt_toolkit.formatted_text import HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from nanobot import __version__, __logo__
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates
app = typer.Typer( app = typer.Typer(
name="nanobot", name="nanobot",
@@ -159,9 +159,9 @@ def onboard():
from nanobot.config.loader import get_config_path, load_config, save_config from nanobot.config.loader import get_config_path, load_config, save_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.utils.helpers import get_workspace_path from nanobot.utils.helpers import get_workspace_path
config_path = get_config_path() config_path = get_config_path()
if config_path.exists(): if config_path.exists():
console.print(f"[yellow]Config already exists at {config_path}[/yellow]") console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
@@ -177,17 +177,16 @@ def onboard():
else: else:
save_config(Config()) save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}") console.print(f"[green]✓[/green] Created config at {config_path}")
# Create workspace # Create workspace
workspace = get_workspace_path() workspace = get_workspace_path()
if not workspace.exists(): if not workspace.exists():
workspace.mkdir(parents=True, exist_ok=True) workspace.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}") console.print(f"[green]✓[/green] Created workspace at {workspace}")
# Create default bootstrap files sync_workspace_templates(workspace)
_create_workspace_templates(workspace)
console.print(f"\n{__logo__} nanobot is ready!") console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:") console.print("\nNext steps:")
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]") console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
@@ -198,42 +197,12 @@ def onboard():
def _create_workspace_templates(workspace: Path):
"""Create default workspace template files from bundled templates."""
from importlib.resources import files as pkg_files
templates_dir = pkg_files("nanobot") / "templates"
for item in templates_dir.iterdir():
if not item.name.endswith(".md"):
continue
dest = workspace / item.name
if not dest.exists():
dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
console.print(f" [dim]Created {item.name}[/dim]")
memory_dir = workspace / "memory"
memory_dir.mkdir(exist_ok=True)
memory_template = templates_dir / "memory" / "MEMORY.md"
memory_file = memory_dir / "MEMORY.md"
if not memory_file.exists():
memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
console.print(" [dim]Created memory/MEMORY.md[/dim]")
history_file = memory_dir / "HISTORY.md"
if not history_file.exists():
history_file.write_text("", encoding="utf-8")
console.print(" [dim]Created memory/HISTORY.md[/dim]")
(workspace / "skills").mkdir(exist_ok=True)
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.custom_provider import CustomProvider
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.custom_provider import CustomProvider
model = config.agents.defaults.model model = config.agents.defaults.model
provider_name = config.get_provider_name(model) provider_name = config.get_provider_name(model)
@@ -278,30 +247,31 @@ def gateway(
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
): ):
"""Start the nanobot gateway.""" """Start the nanobot gateway."""
from nanobot.config.loader import load_config, get_data_dir
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager from nanobot.channels.manager import ChannelManager
from nanobot.session.manager import SessionManager from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.session.manager import SessionManager
if verbose: if verbose:
import logging import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
console.print(f"{__logo__} Starting nanobot gateway on port {port}...") console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
config = load_config() config = load_config()
sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path) session_manager = SessionManager(config.workspace_path)
# Create cron service first (callback set after agent creation) # Create cron service first (callback set after agent creation)
cron_store_path = get_data_dir() / "cron" / "jobs.json" cron_store_path = get_data_dir() / "cron" / "jobs.json"
cron = CronService(cron_store_path) cron = CronService(cron_store_path)
# Create agent with cron service # Create agent with cron service
agent = AgentLoop( agent = AgentLoop(
bus=bus, bus=bus,
@@ -312,7 +282,9 @@ def gateway(
max_tokens=config.agents.defaults.max_tokens, max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, memory_window=config.agents.defaults.memory_window,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
cron_service=cron, cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace, restrict_to_workspace=config.tools.restrict_to_workspace,
@@ -320,26 +292,48 @@ def gateway(
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels, channels_config=config.channels,
) )
# Set cron callback (needs agent) # Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None: async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent.""" """Execute a cron job through the agent."""
response = await agent.process_direct( from nanobot.agent.tools.cron import CronTool
job.payload.message, from nanobot.agent.tools.message import MessageTool
session_key=f"cron:{job.id}", reminder_note = (
channel=job.payload.channel or "cli", "[Scheduled Task] Timer finished.\n\n"
chat_id=job.payload.to or "direct", f"Task '{job.name}' has been triggered.\n"
f"Scheduled instruction: {job.payload.message}"
) )
if job.payload.deliver and job.payload.to:
# Prevent the agent from scheduling new cron jobs during execution
cron_tool = agent.tools.get("cron")
cron_token = None
if isinstance(cron_tool, CronTool):
cron_token = cron_tool.set_cron_context(True)
try:
response = await agent.process_direct(
reminder_note,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
)
finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
cron_tool.reset_cron_context(cron_token)
message_tool = agent.tools.get("message")
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
return response
if job.payload.deliver and job.payload.to and response:
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
await bus.publish_outbound(OutboundMessage( await bus.publish_outbound(OutboundMessage(
channel=job.payload.channel or "cli", channel=job.payload.channel or "cli",
chat_id=job.payload.to, chat_id=job.payload.to,
content=response or "" content=response
)) ))
return response return response
cron.on_job = on_cron_job cron.on_job = on_cron_job
# Create channel manager # Create channel manager
channels = ChannelManager(config, bus) channels = ChannelManager(config, bus)
@@ -393,18 +387,18 @@ def gateway(
interval_s=hb_cfg.interval_s, interval_s=hb_cfg.interval_s,
enabled=hb_cfg.enabled, enabled=hb_cfg.enabled,
) )
if channels.enabled_channels: if channels.enabled_channels:
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}") console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
else: else:
console.print("[yellow]Warning: No channels enabled[/yellow]") console.print("[yellow]Warning: No channels enabled[/yellow]")
cron_status = cron.status() cron_status = cron.status()
if cron_status["jobs"] > 0: if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs") console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s") console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
async def run(): async def run():
try: try:
await cron.start() await cron.start()
@@ -421,7 +415,7 @@ def gateway(
cron.stop() cron.stop()
agent.stop() agent.stop()
await channels.stop_all() await channels.stop_all()
asyncio.run(run()) asyncio.run(run())
@@ -440,14 +434,16 @@ def agent(
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
): ):
"""Interact with the agent directly.""" """Interact with the agent directly."""
from nanobot.config.loader import load_config, get_data_dir
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
from nanobot.cron.service import CronService
from loguru import logger from loguru import logger
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import get_data_dir, load_config
from nanobot.cron.service import CronService
config = load_config() config = load_config()
sync_workspace_templates(config.workspace_path)
bus = MessageBus() bus = MessageBus()
provider = _make_provider(config) provider = _make_provider(config)
@@ -459,7 +455,7 @@ def agent(
logger.enable("nanobot") logger.enable("nanobot")
else: else:
logger.disable("nanobot") logger.disable("nanobot")
agent_loop = AgentLoop( agent_loop = AgentLoop(
bus=bus, bus=bus,
provider=provider, provider=provider,
@@ -469,14 +465,16 @@ def agent(
max_tokens=config.agents.defaults.max_tokens, max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window, memory_window=config.agents.defaults.memory_window,
reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
cron_service=cron, cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace, restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers, mcp_servers=config.tools.mcp_servers,
channels_config=config.channels, channels_config=config.channels,
) )
# Show spinner when logs are off (no output to miss); skip when logs are on # Show spinner when logs are off (no output to miss); skip when logs are on
def _thinking_ctx(): def _thinking_ctx():
if logs: if logs:
@@ -652,7 +650,7 @@ def channels_status():
"" if mc.enabled else "", "" if mc.enabled else "",
mc_base mc_base
) )
# Telegram # Telegram
tg = config.channels.telegram tg = config.channels.telegram
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]" tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
@@ -705,57 +703,57 @@ def _get_bridge_dir() -> Path:
"""Get the bridge directory, setting it up if needed.""" """Get the bridge directory, setting it up if needed."""
import shutil import shutil
import subprocess import subprocess
# User's bridge location # User's bridge location
user_bridge = Path.home() / ".nanobot" / "bridge" user_bridge = Path.home() / ".nanobot" / "bridge"
# Check if already built # Check if already built
if (user_bridge / "dist" / "index.js").exists(): if (user_bridge / "dist" / "index.js").exists():
return user_bridge return user_bridge
# Check for npm # Check for npm
if not shutil.which("npm"): if not shutil.which("npm"):
console.print("[red]npm not found. Please install Node.js >= 18.[/red]") console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
raise typer.Exit(1) raise typer.Exit(1)
# Find source bridge: first check package data, then source dir # Find source bridge: first check package data, then source dir
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed) pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev) src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
source = None source = None
if (pkg_bridge / "package.json").exists(): if (pkg_bridge / "package.json").exists():
source = pkg_bridge source = pkg_bridge
elif (src_bridge / "package.json").exists(): elif (src_bridge / "package.json").exists():
source = src_bridge source = src_bridge
if not source: if not source:
console.print("[red]Bridge source not found.[/red]") console.print("[red]Bridge source not found.[/red]")
console.print("Try reinstalling: pip install --force-reinstall nanobot") console.print("Try reinstalling: pip install --force-reinstall nanobot")
raise typer.Exit(1) raise typer.Exit(1)
console.print(f"{__logo__} Setting up bridge...") console.print(f"{__logo__} Setting up bridge...")
# Copy to user directory # Copy to user directory
user_bridge.parent.mkdir(parents=True, exist_ok=True) user_bridge.parent.mkdir(parents=True, exist_ok=True)
if user_bridge.exists(): if user_bridge.exists():
shutil.rmtree(user_bridge) shutil.rmtree(user_bridge)
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist")) shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
# Install and build # Install and build
try: try:
console.print(" Installing dependencies...") console.print(" Installing dependencies...")
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True) subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
console.print(" Building...") console.print(" Building...")
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True) subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
console.print("[green]✓[/green] Bridge ready\n") console.print("[green]✓[/green] Bridge ready\n")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
console.print(f"[red]Build failed: {e}[/red]") console.print(f"[red]Build failed: {e}[/red]")
if e.stderr: if e.stderr:
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]") console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
raise typer.Exit(1) raise typer.Exit(1)
return user_bridge return user_bridge
@@ -763,18 +761,19 @@ def _get_bridge_dir() -> Path:
def channels_login(): def channels_login():
"""Link device via QR code.""" """Link device via QR code."""
import subprocess import subprocess
from nanobot.config.loader import load_config from nanobot.config.loader import load_config
config = load_config() config = load_config()
bridge_dir = _get_bridge_dir() bridge_dir = _get_bridge_dir()
console.print(f"{__logo__} Starting bridge...") console.print(f"{__logo__} Starting bridge...")
console.print("Scan the QR code to connect.\n") console.print("Scan the QR code to connect.\n")
env = {**os.environ} env = {**os.environ}
if config.channels.whatsapp.bridge_token: if config.channels.whatsapp.bridge_token:
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
try: try:
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env) subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
@@ -783,218 +782,6 @@ def channels_login():
console.print("[red]npm not found. Please install Node.js.[/red]") console.print("[red]npm not found. Please install Node.js.[/red]")
# ============================================================================
# Cron Commands
# ============================================================================
cron_app = typer.Typer(help="Manage scheduled tasks")
app.add_typer(cron_app, name="cron")
@cron_app.command("list")
def cron_list(
all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
):
"""List scheduled jobs."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
jobs = service.list_jobs(include_disabled=all)
if not jobs:
console.print("No scheduled jobs.")
return
table = Table(title="Scheduled Jobs")
table.add_column("ID", style="cyan")
table.add_column("Name")
table.add_column("Schedule")
table.add_column("Status")
table.add_column("Next Run")
import time
from datetime import datetime as _dt
from zoneinfo import ZoneInfo
for job in jobs:
# Format schedule
if job.schedule.kind == "every":
sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
elif job.schedule.kind == "cron":
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
else:
sched = "one-time"
# Format next run
next_run = ""
if job.state.next_run_at_ms:
ts = job.state.next_run_at_ms / 1000
try:
tz = ZoneInfo(job.schedule.tz) if job.schedule.tz else None
next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
except Exception:
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
table.add_row(job.id, job.name, sched, status, next_run)
console.print(table)
@cron_app.command("add")
def cron_add(
name: str = typer.Option(..., "--name", "-n", help="Job name"),
message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
tz: str | None = typer.Option(None, "--tz", help="IANA timezone for cron (e.g. 'America/Vancouver')"),
at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
to: str = typer.Option(None, "--to", help="Recipient for delivery"),
channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
):
"""Add a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule
if tz and not cron_expr:
console.print("[red]Error: --tz can only be used with --cron[/red]")
raise typer.Exit(1)
# Determine schedule type
if every:
schedule = CronSchedule(kind="every", every_ms=every * 1000)
elif cron_expr:
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at:
import datetime
dt = datetime.datetime.fromisoformat(at)
schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
else:
console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
raise typer.Exit(1)
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
try:
job = service.add_job(
name=name,
schedule=schedule,
message=message,
deliver=deliver,
to=to,
channel=channel,
)
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1) from e
console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
@cron_app.command("remove")
def cron_remove(
job_id: str = typer.Argument(..., help="Job ID to remove"),
):
"""Remove a scheduled job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
if service.remove_job(job_id):
console.print(f"[green]✓[/green] Removed job {job_id}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("enable")
def cron_enable(
job_id: str = typer.Argument(..., help="Job ID"),
disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
):
"""Enable or disable a job."""
from nanobot.config.loader import get_data_dir
from nanobot.cron.service import CronService
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
job = service.enable_job(job_id, enabled=not disable)
if job:
status = "disabled" if disable else "enabled"
console.print(f"[green]✓[/green] Job '{job.name}' {status}")
else:
console.print(f"[red]Job {job_id} not found[/red]")
@cron_app.command("run")
def cron_run(
job_id: str = typer.Argument(..., help="Job ID to run"),
force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
):
"""Manually run a job."""
from loguru import logger
from nanobot.config.loader import load_config, get_data_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
logger.disable("nanobot")
config = load_config()
provider = _make_provider(config)
bus = MessageBus()
agent_loop = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window,
brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
)
store_path = get_data_dir() / "cron" / "jobs.json"
service = CronService(store_path)
result_holder = []
async def on_job(job: CronJob) -> str | None:
response = await agent_loop.process_direct(
job.payload.message,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
)
result_holder.append(response)
return response
service.on_job = on_job
async def run():
return await service.run_job(job_id, force=force)
if asyncio.run(run()):
console.print("[green]✓[/green] Job executed")
if result_holder:
_print_agent_response(result_holder[0], render_markdown=True)
else:
console.print(f"[red]Failed to run job {job_id}[/red]")
# ============================================================================ # ============================================================================
# Status Commands # Status Commands
# ============================================================================ # ============================================================================
@@ -1003,7 +790,7 @@ def cron_run(
@app.command() @app.command()
def status(): def status():
"""Show nanobot status.""" """Show nanobot status."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
config_path = get_config_path() config_path = get_config_path()
config = load_config() config = load_config()
@@ -1018,7 +805,7 @@ def status():
from nanobot.providers.registry import PROVIDERS from nanobot.providers.registry import PROVIDERS
console.print(f"Model: {config.agents.defaults.model}") console.print(f"Model: {config.agents.defaults.model}")
# Check API keys from registry # Check API keys from registry
for spec in PROVIDERS: for spec in PROVIDERS:
p = getattr(config.providers, spec.name, None) p = getattr(config.providers, spec.name, None)

View File

@@ -1,6 +1,6 @@
"""Configuration module for nanobot.""" """Configuration module for nanobot."""
from nanobot.config.loader import load_config, get_config_path from nanobot.config.loader import get_config_path, load_config
from nanobot.config.schema import Config from nanobot.config.schema import Config
__all__ = ["Config", "load_config", "get_config_path"] __all__ = ["Config", "load_config", "get_config_path"]

View File

@@ -3,7 +3,7 @@
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -42,6 +42,7 @@ class FeishuConfig(Base):
encrypt_key: str = "" # Encrypt Key for event subscription (optional) encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional) verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
class DingTalkConfig(Base): class DingTalkConfig(Base):
@@ -170,6 +171,7 @@ class SlackConfig(Base):
user_token_read_only: bool = True user_token_read_only: bool = True
reply_in_thread: bool = True reply_in_thread: bool = True
react_emoji: str = "eyes" react_emoji: str = "eyes"
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs (sender-level)
group_policy: str = "mention" # "mention", "open", "allowlist" group_policy: str = "mention" # "mention", "open", "allowlist"
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
dm: SlackDMConfig = Field(default_factory=SlackDMConfig) dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
@@ -183,6 +185,20 @@ class QQConfig(Base):
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access) allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
class MatrixConfig(Base):
"""Matrix (Element) channel configuration."""
enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = "" # e.g. @bot:matrix.org
device_id: str = ""
e2ee_enabled: bool = True # end-to-end encryption support
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
class ChannelsConfig(Base): class ChannelsConfig(Base):
"""Configuration for chat channels.""" """Configuration for chat channels."""
@@ -211,6 +227,7 @@ class AgentDefaults(Base):
temperature: float = 0.1 temperature: float = 0.1
max_tool_iterations: int = 40 max_tool_iterations: int = 40
memory_window: int = 100 memory_window: int = 100
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
class AgentsConfig(Base): class AgentsConfig(Base):
@@ -277,6 +294,7 @@ class WebSearchConfig(Base):
class WebToolsConfig(Base): class WebToolsConfig(Base):
"""Web tools configuration.""" """Web tools configuration."""
proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
search: WebSearchConfig = Field(default_factory=WebSearchConfig) search: WebSearchConfig = Field(default_factory=WebSearchConfig)

View File

@@ -21,17 +21,18 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
"""Compute next run time in ms.""" """Compute next run time in ms."""
if schedule.kind == "at": if schedule.kind == "at":
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
if schedule.kind == "every": if schedule.kind == "every":
if not schedule.every_ms or schedule.every_ms <= 0: if not schedule.every_ms or schedule.every_ms <= 0:
return None return None
# Next interval from now # Next interval from now
return now_ms + schedule.every_ms return now_ms + schedule.every_ms
if schedule.kind == "cron" and schedule.expr: if schedule.kind == "cron" and schedule.expr:
try: try:
from croniter import croniter
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from croniter import croniter
# Use caller-provided reference time for deterministic scheduling # Use caller-provided reference time for deterministic scheduling
base_time = now_ms / 1000 base_time = now_ms / 1000
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
@@ -41,7 +42,7 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
return int(next_dt.timestamp() * 1000) return int(next_dt.timestamp() * 1000)
except Exception: except Exception:
return None return None
return None return None
@@ -61,23 +62,29 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
class CronService: class CronService:
"""Service for managing and executing scheduled jobs.""" """Service for managing and executing scheduled jobs."""
def __init__( def __init__(
self, self,
store_path: Path, store_path: Path,
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
): ):
self.store_path = store_path self.store_path = store_path
self.on_job = on_job # Callback to execute job, returns response text self.on_job = on_job
self._store: CronStore | None = None self._store: CronStore | None = None
self._last_mtime: float = 0.0
self._timer_task: asyncio.Task | None = None self._timer_task: asyncio.Task | None = None
self._running = False self._running = False
def _load_store(self) -> CronStore: def _load_store(self) -> CronStore:
"""Load jobs from disk.""" """Load jobs from disk. Reloads automatically if file was modified externally."""
if self._store and self.store_path.exists():
mtime = self.store_path.stat().st_mtime
if mtime != self._last_mtime:
logger.info("Cron: jobs.json modified externally, reloading")
self._store = None
if self._store: if self._store:
return self._store return self._store
if self.store_path.exists(): if self.store_path.exists():
try: try:
data = json.loads(self.store_path.read_text(encoding="utf-8")) data = json.loads(self.store_path.read_text(encoding="utf-8"))
@@ -117,16 +124,16 @@ class CronService:
self._store = CronStore() self._store = CronStore()
else: else:
self._store = CronStore() self._store = CronStore()
return self._store return self._store
def _save_store(self) -> None: def _save_store(self) -> None:
"""Save jobs to disk.""" """Save jobs to disk."""
if not self._store: if not self._store:
return return
self.store_path.parent.mkdir(parents=True, exist_ok=True) self.store_path.parent.mkdir(parents=True, exist_ok=True)
data = { data = {
"version": self._store.version, "version": self._store.version,
"jobs": [ "jobs": [
@@ -161,8 +168,9 @@ class CronService:
for j in self._store.jobs for j in self._store.jobs
] ]
} }
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8") self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
self._last_mtime = self.store_path.stat().st_mtime
async def start(self) -> None: async def start(self) -> None:
"""Start the cron service.""" """Start the cron service."""
@@ -172,14 +180,14 @@ class CronService:
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else [])) logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
def stop(self) -> None: def stop(self) -> None:
"""Stop the cron service.""" """Stop the cron service."""
self._running = False self._running = False
if self._timer_task: if self._timer_task:
self._timer_task.cancel() self._timer_task.cancel()
self._timer_task = None self._timer_task = None
def _recompute_next_runs(self) -> None: def _recompute_next_runs(self) -> None:
"""Recompute next run times for all enabled jobs.""" """Recompute next run times for all enabled jobs."""
if not self._store: if not self._store:
@@ -188,73 +196,74 @@ class CronService:
for job in self._store.jobs: for job in self._store.jobs:
if job.enabled: if job.enabled:
job.state.next_run_at_ms = _compute_next_run(job.schedule, now) job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
def _get_next_wake_ms(self) -> int | None: def _get_next_wake_ms(self) -> int | None:
"""Get the earliest next run time across all jobs.""" """Get the earliest next run time across all jobs."""
if not self._store: if not self._store:
return None return None
times = [j.state.next_run_at_ms for j in self._store.jobs times = [j.state.next_run_at_ms for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms] if j.enabled and j.state.next_run_at_ms]
return min(times) if times else None return min(times) if times else None
def _arm_timer(self) -> None: def _arm_timer(self) -> None:
"""Schedule the next timer tick.""" """Schedule the next timer tick."""
if self._timer_task: if self._timer_task:
self._timer_task.cancel() self._timer_task.cancel()
next_wake = self._get_next_wake_ms() next_wake = self._get_next_wake_ms()
if not next_wake or not self._running: if not next_wake or not self._running:
return return
delay_ms = max(0, next_wake - _now_ms()) delay_ms = max(0, next_wake - _now_ms())
delay_s = delay_ms / 1000 delay_s = delay_ms / 1000
async def tick(): async def tick():
await asyncio.sleep(delay_s) await asyncio.sleep(delay_s)
if self._running: if self._running:
await self._on_timer() await self._on_timer()
self._timer_task = asyncio.create_task(tick()) self._timer_task = asyncio.create_task(tick())
async def _on_timer(self) -> None: async def _on_timer(self) -> None:
"""Handle timer tick - run due jobs.""" """Handle timer tick - run due jobs."""
self._load_store()
if not self._store: if not self._store:
return return
now = _now_ms() now = _now_ms()
due_jobs = [ due_jobs = [
j for j in self._store.jobs j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
] ]
for job in due_jobs: for job in due_jobs:
await self._execute_job(job) await self._execute_job(job)
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
async def _execute_job(self, job: CronJob) -> None: async def _execute_job(self, job: CronJob) -> None:
"""Execute a single job.""" """Execute a single job."""
start_ms = _now_ms() start_ms = _now_ms()
logger.info("Cron: executing job '{}' ({})", job.name, job.id) logger.info("Cron: executing job '{}' ({})", job.name, job.id)
try: try:
response = None response = None
if self.on_job: if self.on_job:
response = await self.on_job(job) response = await self.on_job(job)
job.state.last_status = "ok" job.state.last_status = "ok"
job.state.last_error = None job.state.last_error = None
logger.info("Cron: job '{}' completed", job.name) logger.info("Cron: job '{}' completed", job.name)
except Exception as e: except Exception as e:
job.state.last_status = "error" job.state.last_status = "error"
job.state.last_error = str(e) job.state.last_error = str(e)
logger.error("Cron: job '{}' failed: {}", job.name, e) logger.error("Cron: job '{}' failed: {}", job.name, e)
job.state.last_run_at_ms = start_ms job.state.last_run_at_ms = start_ms
job.updated_at_ms = _now_ms() job.updated_at_ms = _now_ms()
# Handle one-shot jobs # Handle one-shot jobs
if job.schedule.kind == "at": if job.schedule.kind == "at":
if job.delete_after_run: if job.delete_after_run:
@@ -265,15 +274,15 @@ class CronService:
else: else:
# Compute next run # Compute next run
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms()) job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
# ========== Public API ========== # ========== Public API ==========
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]: def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
"""List all jobs.""" """List all jobs."""
store = self._load_store() store = self._load_store()
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled] jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf')) return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
def add_job( def add_job(
self, self,
name: str, name: str,
@@ -288,7 +297,7 @@ class CronService:
store = self._load_store() store = self._load_store()
_validate_schedule_for_add(schedule) _validate_schedule_for_add(schedule)
now = _now_ms() now = _now_ms()
job = CronJob( job = CronJob(
id=str(uuid.uuid4())[:8], id=str(uuid.uuid4())[:8],
name=name, name=name,
@@ -306,28 +315,28 @@ class CronService:
updated_at_ms=now, updated_at_ms=now,
delete_after_run=delete_after_run, delete_after_run=delete_after_run,
) )
store.jobs.append(job) store.jobs.append(job)
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron: added job '{}' ({})", name, job.id) logger.info("Cron: added job '{}' ({})", name, job.id)
return job return job
def remove_job(self, job_id: str) -> bool: def remove_job(self, job_id: str) -> bool:
"""Remove a job by ID.""" """Remove a job by ID."""
store = self._load_store() store = self._load_store()
before = len(store.jobs) before = len(store.jobs)
store.jobs = [j for j in store.jobs if j.id != job_id] store.jobs = [j for j in store.jobs if j.id != job_id]
removed = len(store.jobs) < before removed = len(store.jobs) < before
if removed: if removed:
self._save_store() self._save_store()
self._arm_timer() self._arm_timer()
logger.info("Cron: removed job {}", job_id) logger.info("Cron: removed job {}", job_id)
return removed return removed
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None: def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
"""Enable or disable a job.""" """Enable or disable a job."""
store = self._load_store() store = self._load_store()
@@ -343,7 +352,7 @@ class CronService:
self._arm_timer() self._arm_timer()
return job return job
return None return None
async def run_job(self, job_id: str, force: bool = False) -> bool: async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job.""" """Manually run a job."""
store = self._load_store() store = self._load_store()
@@ -356,7 +365,7 @@ class CronService:
self._arm_timer() self._arm_timer()
return True return True
return False return False
def status(self) -> dict: def status(self) -> dict:
"""Get service status.""" """Get service status."""
store = self._load_store() store = self._load_store()

View File

@@ -21,6 +21,7 @@ class LLMResponse:
finish_reason: str = "stop" finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict) usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc. reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@property @property
def has_tool_calls(self) -> bool: def has_tool_calls(self) -> bool:
@@ -35,7 +36,7 @@ class LLMProvider(ABC):
Implementations should handle the specifics of each provider's API Implementations should handle the specifics of each provider's API
while maintaining a consistent interface. while maintaining a consistent interface.
""" """
def __init__(self, api_key: str | None = None, api_base: str | None = None): def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
@@ -77,9 +78,15 @@ class LLMProvider(ABC):
result.append(clean) result.append(clean)
continue continue
if isinstance(content, dict):
clean = dict(msg)
clean["content"] = [content]
result.append(clean)
continue
result.append(msg) result.append(msg)
return result return result
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,
@@ -88,6 +95,7 @@ class LLMProvider(ABC):
model: str | None = None, model: str | None = None,
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request. Send a chat completion request.
@@ -103,7 +111,7 @@ class LLMProvider(ABC):
LLMResponse with content and/or tool calls. LLMResponse with content and/or tool calls.
""" """
pass pass
@abstractmethod @abstractmethod
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model for this provider.""" """Get the default model for this provider."""

View File

@@ -18,13 +18,16 @@ class CustomProvider(LLMProvider):
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base) self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse: model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None) -> LLMResponse:
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model or self.default_model, "model": model or self.default_model,
"messages": self._sanitize_empty_content(messages), "messages": self._sanitize_empty_content(messages),
"max_tokens": max(1, max_tokens), "max_tokens": max(1, max_tokens),
"temperature": temperature, "temperature": temperature,
} }
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
if tools: if tools:
kwargs.update(tools=tools, tool_choice="auto") kwargs.update(tools=tools, tool_choice="auto")
try: try:

View File

@@ -1,20 +1,25 @@
"""LiteLLM provider implementation for multi-provider support.""" """LiteLLM provider implementation for multi-provider support."""
import json
import json_repair
import os import os
import secrets
import string
from typing import Any from typing import Any
import json_repair
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway from nanobot.providers.registry import find_by_model, find_gateway
# Standard chat-completion message keys.
# Standard OpenAI chat-completion message keys plus reasoning_content for
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
_ALNUM = string.ascii_letters + string.digits
def _short_tool_id() -> str:
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
return "".join(secrets.choice(_ALNUM) for _ in range(9))
class LiteLLMProvider(LLMProvider): class LiteLLMProvider(LLMProvider):
@@ -25,10 +30,10 @@ class LiteLLMProvider(LLMProvider):
a unified interface. Provider-specific logic is driven by the registry a unified interface. Provider-specific logic is driven by the registry
(see providers/registry.py) — no if-elif chains needed here. (see providers/registry.py) — no if-elif chains needed here.
""" """
def __init__( def __init__(
self, self,
api_key: str | None = None, api_key: str | None = None,
api_base: str | None = None, api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5", default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None,
@@ -37,24 +42,24 @@ class LiteLLMProvider(LLMProvider):
super().__init__(api_key, api_base) super().__init__(api_key, api_base)
self.default_model = default_model self.default_model = default_model
self.extra_headers = extra_headers or {} self.extra_headers = extra_headers or {}
# Detect gateway / local deployment. # Detect gateway / local deployment.
# provider_name (from config key) is the primary signal; # provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection. # api_key / api_base are fallback for auto-detection.
self._gateway = find_gateway(provider_name, api_key, api_base) self._gateway = find_gateway(provider_name, api_key, api_base)
# Configure environment variables # Configure environment variables
if api_key: if api_key:
self._setup_env(api_key, api_base, default_model) self._setup_env(api_key, api_base, default_model)
if api_base: if api_base:
litellm.api_base = api_base litellm.api_base = api_base
# Disable LiteLLM logging noise # Disable LiteLLM logging noise
litellm.suppress_debug_info = True litellm.suppress_debug_info = True
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params) # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
litellm.drop_params = True litellm.drop_params = True
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider.""" """Set environment variables based on detected provider."""
spec = self._gateway or find_by_model(model) spec = self._gateway or find_by_model(model)
@@ -78,7 +83,7 @@ class LiteLLMProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key) resolved = env_val.replace("{api_key}", api_key)
resolved = resolved.replace("{api_base}", effective_base) resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved) os.environ.setdefault(env_name, resolved)
def _resolve_model(self, model: str) -> str: def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes.""" """Resolve model name by applying provider/gateway prefixes."""
if self._gateway: if self._gateway:
@@ -89,7 +94,7 @@ class LiteLLMProvider(LLMProvider):
if prefix and not model.startswith(f"{prefix}/"): if prefix and not model.startswith(f"{prefix}/"):
model = f"{prefix}/{model}" model = f"{prefix}/{model}"
return model return model
# Standard mode: auto-prefix for known providers # Standard mode: auto-prefix for known providers
spec = find_by_model(model) spec = find_by_model(model)
if spec and spec.litellm_prefix: if spec and spec.litellm_prefix:
@@ -108,7 +113,7 @@ class LiteLLMProvider(LLMProvider):
if prefix.lower().replace("-", "_") != spec_name: if prefix.lower().replace("-", "_") != spec_name:
return model return model
return f"{canonical_prefix}/{remainder}" return f"{canonical_prefix}/{remainder}"
def _supports_cache_control(self, model: str) -> bool: def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks.""" """Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None: if self._gateway is not None:
@@ -151,13 +156,22 @@ class LiteLLMProvider(LLMProvider):
if pattern in model_lower: if pattern in model_lower:
kwargs.update(overrides) kwargs.update(overrides)
return return
@staticmethod @staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
"""Return provider-specific extra keys to preserve in request messages."""
spec = find_by_model(original_model) or find_by_model(resolved_model)
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
return _ANTHROPIC_EXTRA_KEYS
return frozenset()
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key.""" """Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = [] sanitized = []
for msg in messages: for msg in messages:
clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS} clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls # Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean: if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None clean["content"] = None
@@ -171,22 +185,24 @@ class LiteLLMProvider(LLMProvider):
model: str | None = None, model: str | None = None,
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request via LiteLLM. Send a chat completion request via LiteLLM.
Args: Args:
messages: List of message dicts with 'role' and 'content'. messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format. tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5'). model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response. max_tokens: Maximum tokens in response.
temperature: Sampling temperature. temperature: Sampling temperature.
Returns: Returns:
LLMResponse with content and/or tool calls. LLMResponse with content and/or tool calls.
""" """
original_model = model or self.default_model original_model = model or self.default_model
model = self._resolve_model(original_model) model = self._resolve_model(original_model)
extra_msg_keys = self._extra_msg_keys(original_model, model)
if self._supports_cache_control(original_model): if self._supports_cache_control(original_model):
messages, tools = self._apply_cache_control(messages, tools) messages, tools = self._apply_cache_control(messages, tools)
@@ -194,33 +210,37 @@ class LiteLLMProvider(LLMProvider):
# Clamp max_tokens to at least 1 — negative or zero values cause # Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1". # LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens) max_tokens = max(1, max_tokens)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model, "model": model,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)), "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
"max_tokens": max_tokens, "max_tokens": max_tokens,
"temperature": temperature, "temperature": temperature,
} }
# Apply model-specific overrides (e.g. kimi-k2.5 temperature) # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs) self._apply_model_overrides(model, kwargs)
# Pass api_key directly — more reliable than env vars alone # Pass api_key directly — more reliable than env vars alone
if self.api_key: if self.api_key:
kwargs["api_key"] = self.api_key kwargs["api_key"] = self.api_key
# Pass api_base for custom endpoints # Pass api_base for custom endpoints
if self.api_base: if self.api_base:
kwargs["api_base"] = self.api_base kwargs["api_base"] = self.api_base
# Pass extra headers (e.g. APP-Code for AiHubMix) # Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers: if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers kwargs["extra_headers"] = self.extra_headers
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
kwargs["drop_params"] = True
if tools: if tools:
kwargs["tools"] = tools kwargs["tools"] = tools
kwargs["tool_choice"] = "auto" kwargs["tool_choice"] = "auto"
try: try:
response = await acompletion(**kwargs) response = await acompletion(**kwargs)
return self._parse_response(response) return self._parse_response(response)
@@ -230,12 +250,12 @@ class LiteLLMProvider(LLMProvider):
content=f"Error calling LLM: {str(e)}", content=f"Error calling LLM: {str(e)}",
finish_reason="error", finish_reason="error",
) )
def _parse_response(self, response: Any) -> LLMResponse: def _parse_response(self, response: Any) -> LLMResponse:
"""Parse LiteLLM response into our standard format.""" """Parse LiteLLM response into our standard format."""
choice = response.choices[0] choice = response.choices[0]
message = choice.message message = choice.message
tool_calls = [] tool_calls = []
if hasattr(message, "tool_calls") and message.tool_calls: if hasattr(message, "tool_calls") and message.tool_calls:
for tc in message.tool_calls: for tc in message.tool_calls:
@@ -243,13 +263,13 @@ class LiteLLMProvider(LLMProvider):
args = tc.function.arguments args = tc.function.arguments
if isinstance(args, str): if isinstance(args, str):
args = json_repair.loads(args) args = json_repair.loads(args)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=tc.id, id=_short_tool_id(),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
)) ))
usage = {} usage = {}
if hasattr(response, "usage") and response.usage: if hasattr(response, "usage") and response.usage:
usage = { usage = {
@@ -257,8 +277,9 @@ class LiteLLMProvider(LLMProvider):
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens, "total_tokens": response.usage.total_tokens,
} }
reasoning_content = getattr(message, "reasoning_content", None) or None reasoning_content = getattr(message, "reasoning_content", None) or None
thinking_blocks = getattr(message, "thinking_blocks", None) or None
return LLMResponse( return LLMResponse(
content=message.content, content=message.content,
@@ -266,8 +287,9 @@ class LiteLLMProvider(LLMProvider):
finish_reason=choice.finish_reason or "stop", finish_reason=choice.finish_reason or "stop",
usage=usage, usage=usage,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
thinking_blocks=thinking_blocks,
) )
def get_default_model(self) -> str: def get_default_model(self) -> str:
"""Get the default model.""" """Get the default model."""
return self.default_model return self.default_model

View File

@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
import httpx import httpx
from loguru import logger from loguru import logger
from oauth_cli_kit import get_token as get_codex_token from oauth_cli_kit import get_token as get_codex_token
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses" DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
@@ -31,6 +31,7 @@ class OpenAICodexProvider(LLMProvider):
model: str | None = None, model: str | None = None,
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse: ) -> LLMResponse:
model = model or self.default_model model = model or self.default_model
system_prompt, input_items = _convert_messages(messages) system_prompt, input_items = _convert_messages(messages)
@@ -51,6 +52,9 @@ class OpenAICodexProvider(LLMProvider):
"parallel_tool_calls": True, "parallel_tool_calls": True,
} }
if reasoning_effort:
body["reasoning"] = {"effort": reasoning_effort}
if tools: if tools:
body["tools"] = _convert_tools(tools) body["tools"] = _convert_tools(tools)

View File

@@ -255,7 +255,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# OpenAI Codex: uses OAuth, not API key. # OpenAI Codex: uses OAuth, not API key.
ProviderSpec( ProviderSpec(
name="openai_codex", name="openai_codex",
keywords=("openai-codex", "codex"), keywords=("openai-codex",),
env_key="", # OAuth-based, no API key env_key="", # OAuth-based, no API key
display_name="OpenAI Codex", display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM litellm_prefix="", # Not routed through LiteLLM

View File

@@ -2,7 +2,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import Any
import httpx import httpx
from loguru import logger from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider: class GroqTranscriptionProvider:
""" """
Voice transcription provider using Groq's Whisper API. Voice transcription provider using Groq's Whisper API.
Groq offers extremely fast transcription with a generous free tier. Groq offers extremely fast transcription with a generous free tier.
""" """
def __init__(self, api_key: str | None = None): def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("GROQ_API_KEY") self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions" self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
async def transcribe(self, file_path: str | Path) -> str: async def transcribe(self, file_path: str | Path) -> str:
""" """
Transcribe an audio file using Groq. Transcribe an audio file using Groq.
Args: Args:
file_path: Path to the audio file. file_path: Path to the audio file.
Returns: Returns:
Transcribed text. Transcribed text.
""" """
if not self.api_key: if not self.api_key:
logger.warning("Groq API key not configured for transcription") logger.warning("Groq API key not configured for transcription")
return "" return ""
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
logger.error("Audio file not found: {}", file_path) logger.error("Audio file not found: {}", file_path)
return "" return ""
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
with open(path, "rb") as f: with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
} }
response = await client.post( response = await client.post(
self.api_url, self.api_url,
headers=headers, headers=headers,
files=files, files=files,
timeout=60.0 timeout=60.0
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
return data.get("text", "") return data.get("text", "")
except Exception as e: except Exception as e:
logger.error("Groq transcription error: {}", e) logger.error("Groq transcription error: {}", e)
return "" return ""

View File

@@ -1,5 +1,5 @@
"""Session management module.""" """Session management module."""
from nanobot.session.manager import SessionManager, Session from nanobot.session.manager import Session, SessionManager
__all__ = ["SessionManager", "Session"] __all__ = ["SessionManager", "Session"]

View File

@@ -2,9 +2,9 @@
import json import json
import shutil import shutil
from pathlib import Path
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@@ -30,7 +30,7 @@ class Session:
updated_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now)
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files last_consolidated: int = 0 # Number of messages already consolidated to files
def add_message(self, role: str, content: str, **kwargs: Any) -> None: def add_message(self, role: str, content: str, **kwargs: Any) -> None:
"""Add a message to the session.""" """Add a message to the session."""
msg = { msg = {
@@ -41,7 +41,7 @@ class Session:
} }
self.messages.append(msg) self.messages.append(msg)
self.updated_at = datetime.now() self.updated_at = datetime.now()
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a user turn.""" """Return unconsolidated messages for LLM input, aligned to a user turn."""
unconsolidated = self.messages[self.last_consolidated:] unconsolidated = self.messages[self.last_consolidated:]
@@ -61,7 +61,7 @@ class Session:
entry[k] = m[k] entry[k] = m[k]
out.append(entry) out.append(entry)
return out return out
def clear(self) -> None: def clear(self) -> None:
"""Clear all messages and reset session to initial state.""" """Clear all messages and reset session to initial state."""
self.messages = [] self.messages = []
@@ -81,7 +81,7 @@ class SessionManager:
self.sessions_dir = ensure_dir(self.workspace / "sessions") self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
self._cache: dict[str, Session] = {} self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path: def _get_session_path(self, key: str) -> Path:
"""Get the file path for a session.""" """Get the file path for a session."""
safe_key = safe_filename(key.replace(":", "_")) safe_key = safe_filename(key.replace(":", "_"))
@@ -91,27 +91,27 @@ class SessionManager:
"""Legacy global session path (~/.nanobot/sessions/).""" """Legacy global session path (~/.nanobot/sessions/)."""
safe_key = safe_filename(key.replace(":", "_")) safe_key = safe_filename(key.replace(":", "_"))
return self.legacy_sessions_dir / f"{safe_key}.jsonl" return self.legacy_sessions_dir / f"{safe_key}.jsonl"
def get_or_create(self, key: str) -> Session: def get_or_create(self, key: str) -> Session:
""" """
Get an existing session or create a new one. Get an existing session or create a new one.
Args: Args:
key: Session key (usually channel:chat_id). key: Session key (usually channel:chat_id).
Returns: Returns:
The session. The session.
""" """
if key in self._cache: if key in self._cache:
return self._cache[key] return self._cache[key]
session = self._load(key) session = self._load(key)
if session is None: if session is None:
session = Session(key=key) session = Session(key=key)
self._cache[key] = session self._cache[key] = session
return session return session
def _load(self, key: str) -> Session | None: def _load(self, key: str) -> Session | None:
"""Load a session from disk.""" """Load a session from disk."""
path = self._get_session_path(key) path = self._get_session_path(key)
@@ -158,7 +158,7 @@ class SessionManager:
except Exception as e: except Exception as e:
logger.warning("Failed to load session {}: {}", key, e) logger.warning("Failed to load session {}: {}", key, e)
return None return None
def save(self, session: Session) -> None: def save(self, session: Session) -> None:
"""Save a session to disk.""" """Save a session to disk."""
path = self._get_session_path(session.key) path = self._get_session_path(session.key)
@@ -177,20 +177,20 @@ class SessionManager:
f.write(json.dumps(msg, ensure_ascii=False) + "\n") f.write(json.dumps(msg, ensure_ascii=False) + "\n")
self._cache[session.key] = session self._cache[session.key] = session
def invalidate(self, key: str) -> None: def invalidate(self, key: str) -> None:
"""Remove a session from the in-memory cache.""" """Remove a session from the in-memory cache."""
self._cache.pop(key, None) self._cache.pop(key, None)
def list_sessions(self) -> list[dict[str, Any]]: def list_sessions(self) -> list[dict[str, Any]]:
""" """
List all sessions. List all sessions.
Returns: Returns:
List of session info dicts. List of session info dicts.
""" """
sessions = [] sessions = []
for path in self.sessions_dir.glob("*.jsonl"): for path in self.sessions_dir.glob("*.jsonl"):
try: try:
# Read just the metadata line # Read just the metadata line
@@ -208,5 +208,5 @@ class SessionManager:
}) })
except Exception: except Exception:
continue continue
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True) return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)

View File

@@ -9,7 +9,7 @@ always: true
## Structure ## Structure
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context. - `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. - `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM].
## Search Past Events ## Search Past Events

View File

@@ -4,17 +4,15 @@ You are a helpful AI assistant. Be concise, accurate, and friendly.
## Scheduled Reminders ## Scheduled Reminders
When user asks for a reminder at a specific time, use `exec` to run: Before scheduling reminders, check available skills and follow skill guidance first.
``` Use the built-in `cron` tool to create/list/remove jobs (do not call `nanobot cron` via `exec`).
nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
```
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`). Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications. **Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
## Heartbeat Tasks ## Heartbeat Tasks
`HEARTBEAT.md` is checked every 30 minutes. Use file tools to manage periodic tasks: `HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
- **Add**: `edit_file` to append new tasks - **Add**: `edit_file` to append new tasks
- **Remove**: `edit_file` to delete completed tasks - **Remove**: `edit_file` to delete completed tasks

View File

@@ -1,5 +1,5 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path
__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"] __all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]

View File

@@ -1,79 +1,67 @@
"""Utility functions for nanobot.""" """Utility functions for nanobot."""
from pathlib import Path import re
from datetime import datetime from datetime import datetime
from pathlib import Path
def ensure_dir(path: Path) -> Path: def ensure_dir(path: Path) -> Path:
"""Ensure a directory exists, creating it if necessary.""" """Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
def get_data_path() -> Path: def get_data_path() -> Path:
"""Get the nanobot data directory (~/.nanobot).""" """~/.nanobot data directory."""
return ensure_dir(Path.home() / ".nanobot") return ensure_dir(Path.home() / ".nanobot")
def get_workspace_path(workspace: str | None = None) -> Path: def get_workspace_path(workspace: str | None = None) -> Path:
""" """Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
Get the workspace path. path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
Args:
workspace: Optional workspace path. Defaults to ~/.nanobot/workspace.
Returns:
Expanded and ensured workspace path.
"""
if workspace:
path = Path(workspace).expanduser()
else:
path = Path.home() / ".nanobot" / "workspace"
return ensure_dir(path) return ensure_dir(path)
def get_sessions_path() -> Path:
"""Get the sessions storage directory."""
return ensure_dir(get_data_path() / "sessions")
def get_skills_path(workspace: Path | None = None) -> Path:
"""Get the skills directory within the workspace."""
ws = workspace or get_workspace_path()
return ensure_dir(ws / "skills")
def timestamp() -> str: def timestamp() -> str:
"""Get current timestamp in ISO format.""" """Current ISO timestamp."""
return datetime.now().isoformat() return datetime.now().isoformat()
def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str: _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
"""Truncate a string to max length, adding suffix if truncated."""
if len(s) <= max_len:
return s
return s[: max_len - len(suffix)] + suffix
def safe_filename(name: str) -> str: def safe_filename(name: str) -> str:
"""Convert a string to a safe filename.""" """Replace unsafe path characters with underscores."""
# Replace unsafe characters return _UNSAFE_CHARS.sub("_", name).strip()
unsafe = '<>:"/\\|?*'
for char in unsafe:
name = name.replace(char, "_")
return name.strip()
def parse_session_key(key: str) -> tuple[str, str]: def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
""" """Sync bundled templates to workspace. Only creates missing files."""
Parse a session key into channel and chat_id. from importlib.resources import files as pkg_files
try:
Args: tpl = pkg_files("nanobot") / "templates"
key: Session key in format "channel:chat_id" except Exception:
return []
Returns: if not tpl.is_dir():
Tuple of (channel, chat_id) return []
"""
parts = key.split(":", 1) added: list[str] = []
if len(parts) != 2:
raise ValueError(f"Invalid session key: {key}") def _write(src, dest: Path):
return parts[0], parts[1] if dest.exists():
return
dest.parent.mkdir(parents=True, exist_ok=True)
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
added.append(str(dest.relative_to(workspace)))
for item in tpl.iterdir():
if item.name.endswith(".md"):
_write(item, workspace / item.name)
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
_write(None, workspace / "memory" / "HISTORY.md")
(workspace / "skills").mkdir(exist_ok=True)
if added and not silent:
from rich.console import Console
for name in added:
Console().print(f" [dim]Created {name}[/dim]")
return added

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "nanobot-ai" name = "nanobot-ai"
version = "0.1.4.post2" version = "0.1.4.post3"
description = "A lightweight personal AI assistant framework" description = "A lightweight personal AI assistant framework"
requires-python = ">=3.11" requires-python = ">=3.11"
license = {text = "MIT"} license = {text = "MIT"}
@@ -42,6 +42,8 @@ dependencies = [
"prompt-toolkit>=3.0.50,<4.0.0", "prompt-toolkit>=3.0.50,<4.0.0",
"mcp>=1.26.0,<2.0.0", "mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0", "json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0",
"openai>=2.8.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -54,6 +56,9 @@ dev = [
"pytest>=9.0.0,<10.0.0", "pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0", "pytest-asyncio>=1.3.0,<2.0.0",
"ruff>=0.1.0", "ruff>=0.1.0",
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
] ]
[project.scripts] [project.scripts]

View File

@@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_cleans_up_consolidation_lock_for_invalidated_session( async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
self, tmp_path: Path """/new clears session and returns confirmation."""
) -> None:
"""/new should remove lock entry for fully invalidated session key."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard:
loop = AgentLoop( loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
) )
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
@@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard:
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
loop.sessions.save(session) loop.sessions.save(session)
# Ensure lock exists before /new.
_ = loop._get_consolidation_lock(session.key)
assert session.key in loop._consolidation_locks
async def _ok_consolidate(sess, archive_all: bool = False) -> bool: async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
return True return True
@@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard:
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
assert session.key not in loop._consolidation_locks assert loop.sessions.get_or_create("cli:test").messages == []

View File

@@ -40,7 +40,7 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
"""Runtime metadata should be a separate user message before the actual user message.""" """Runtime metadata should be merged with the user message."""
workspace = _make_workspace(tmp_path) workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace) builder = ContextBuilder(workspace)
@@ -54,13 +54,12 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert messages[0]["role"] == "system" assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"] assert "## Current Session" not in messages[0]["content"]
assert messages[-2]["role"] == "user" # Runtime context is now merged with user message into a single message
runtime_content = messages[-2]["content"]
assert isinstance(runtime_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
assert "Current Time:" in runtime_content
assert "Channel: cli" in runtime_content
assert "Chat ID: direct" in runtime_content
assert messages[-1]["role"] == "user" assert messages[-1]["role"] == "user"
assert messages[-1]["content"] == "Return exactly: OK" user_content = messages[-1]["content"]
assert isinstance(user_content, str)
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
assert "Current Time:" in user_content
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content

View File

@@ -1,29 +0,0 @@
from typer.testing import CliRunner
from nanobot.cli.commands import app
runner = CliRunner()
def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
result = runner.invoke(
app,
[
"cron",
"add",
"--name",
"demo",
"--message",
"hello",
"--cron",
"0 9 * * *",
"--tz",
"America/Vancovuer",
],
)
assert result.exit_code == 1
assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
assert not (tmp_path / "cron" / "jobs.json").exists()

View File

@@ -1,3 +1,5 @@
import asyncio
import pytest import pytest
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
@@ -28,3 +30,32 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.schedule.tz == "America/Vancouver" assert job.schedule.tz == "America/Vancouver"
assert job.state.next_run_at_ms is not None assert job.state.next_run_at_ms is not None
@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
called: list[str] = []
async def on_job(job) -> None:
called.append(job.id)
service = CronService(store_path, on_job=on_job)
job = service.add_job(
name="external-disable",
schedule=CronSchedule(kind="every", every_ms=200),
message="hello",
)
await service.start()
try:
# Wait slightly to ensure file mtime is definitively different
await asyncio.sleep(0.05)
external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False)
assert updated is not None
assert updated.enabled is False
await asyncio.sleep(0.35)
assert called == []
finally:
service.stop()

View File

@@ -0,0 +1,40 @@
from nanobot.channels.feishu import _extract_post_content
def test_extract_post_content_supports_post_wrapper_shape() -> None:
payload = {
"post": {
"zh_cn": {
"title": "日报",
"content": [
[
{"tag": "text", "text": "完成"},
{"tag": "img", "image_key": "img_1"},
]
],
}
}
}
text, image_keys = _extract_post_content(payload)
assert text == "日报 完成"
assert image_keys == ["img_1"]
def test_extract_post_content_keeps_direct_shape_behavior() -> None:
payload = {
"title": "Daily",
"content": [
[
{"tag": "text", "text": "report"},
{"tag": "img", "image_key": "img_a"},
{"tag": "img", "image_key": "img_b"},
]
],
}
text, image_keys = _extract_post_content(payload)
assert text == "Daily report"
assert image_keys == ["img_a", "img_b"]

View File

@@ -2,34 +2,28 @@ import asyncio
import pytest import pytest
from nanobot.heartbeat.service import ( from nanobot.heartbeat.service import HeartbeatService
HEARTBEAT_OK_TOKEN, from nanobot.providers.base import LLMResponse, ToolCallRequest
HeartbeatService,
)
def test_heartbeat_ok_detection() -> None: class DummyProvider:
def is_ok(response: str) -> bool: def __init__(self, responses: list[LLMResponse]):
return HEARTBEAT_OK_TOKEN in response.upper() self._responses = list(responses)
assert is_ok("HEARTBEAT_OK") async def chat(self, *args, **kwargs) -> LLMResponse:
assert is_ok("`HEARTBEAT_OK`") if self._responses:
assert is_ok("**HEARTBEAT_OK**") return self._responses.pop(0)
assert is_ok("heartbeat_ok") return LLMResponse(content="", tool_calls=[])
assert is_ok("HEARTBEAT_OK.")
assert not is_ok("HEARTBEAT_NOT_OK")
assert not is_ok("all good")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None: async def test_start_is_idempotent(tmp_path) -> None:
async def _on_heartbeat(_: str) -> str: provider = DummyProvider([])
return "HEARTBEAT_OK"
service = HeartbeatService( service = HeartbeatService(
workspace=tmp_path, workspace=tmp_path,
on_heartbeat=_on_heartbeat, provider=provider,
model="openai/gpt-4o-mini",
interval_s=9999, interval_s=9999,
enabled=True, enabled=True,
) )
@@ -42,3 +36,82 @@ async def test_start_is_idempotent(tmp_path) -> None:
service.stop() service.stop()
await asyncio.sleep(0) await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "skip"
assert tasks == ""
@pytest.mark.asyncio
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
)
])
called_with: list[str] = []
async def _on_execute(tasks: str) -> str:
called_with.append(tasks)
return "done"
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
result = await service.trigger_now()
assert result == "done"
assert called_with == ["check open tasks"]
@pytest.mark.asyncio
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "skip"},
)
],
)
])
async def _on_execute(tasks: str) -> str:
return tasks
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
)
assert await service.trigger_now() is None

View File

@@ -0,0 +1,41 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500
return loop
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{"role": "user", "content": [{"type": "text", "text": runtime}]}],
skip=0,
)
assert session.messages == []
def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
loop = _mk_loop()
session = Session(key="test:image")
runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
loop._save_turn(
session,
[{
"role": "user",
"content": [
{"type": "text", "text": runtime},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
}],
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]

View File

@@ -159,6 +159,7 @@ class _FakeAsyncClient:
def _make_config(**kwargs) -> MatrixConfig: def _make_config(**kwargs) -> MatrixConfig:
kwargs.setdefault("allow_from", ["*"])
return MatrixConfig( return MatrixConfig(
enabled=True, enabled=True,
homeserver="https://matrix.org", homeserver="https://matrix.org",
@@ -274,7 +275,7 @@ async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_room_invite_joins_when_allow_list_is_empty() -> None: async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
channel = MatrixChannel(_make_config(allow_from=[]), MessageBus()) channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
client = _FakeAsyncClient("", "", "", None) client = _FakeAsyncClient("", "", "", None)
channel.client = client channel.client = client
@@ -284,9 +285,22 @@ async def test_room_invite_joins_when_allow_list_is_empty() -> None:
await channel._on_room_invite(room, event) await channel._on_room_invite(room, event)
assert client.join_calls == ["!room:matrix.org"] assert client.join_calls == []
@pytest.mark.asyncio
async def test_room_invite_joins_when_sender_allowed() -> None:
channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
room = SimpleNamespace(room_id="!room:matrix.org")
event = SimpleNamespace(sender="@alice:matrix.org")
await channel._on_room_invite(room, event)
assert client.join_calls == ["!room:matrix.org"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_room_invite_respects_allow_list_when_configured() -> None: async def test_room_invite_respects_allow_list_when_configured() -> None:
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus()) channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
@@ -1163,6 +1177,8 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
assert "!room:matrix.org" in channel._typing_tasks assert "!room:matrix.org" in channel._typing_tasks
assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS) assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
await channel.stop()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_clears_typing_when_send_fails() -> None: async def test_send_clears_typing_when_send_fails() -> None:

View File

@@ -0,0 +1,103 @@
"""Test message tool suppress logic for final replies."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
class TestMessageToolSuppressLogic:
"""Final reply suppressed only when message tool sends to the same target."""
@pytest.mark.asyncio
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
result = await loop._process_message(msg)
assert len(sent) == 1
assert result is None # suppressed
@pytest.mark.asyncio
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
result = await loop._process_message(msg)
assert len(sent) == 1
assert sent[0].channel == "email"
assert result is not None # not suppressed
assert result.channel == "feishu"
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
result = await loop._process_message(msg)
assert result is not None
assert "Hello" in result.content
class TestMessageToolTurnTracking:
def test_sent_in_turn_tracks_same_target(self) -> None:
tool = MessageTool()
tool.set_context("feishu", "chat1")
assert not tool._sent_in_turn
tool._sent_in_turn = True
assert tool._sent_in_turn
def test_start_turn_resets(self) -> None:
tool = MessageTool()
tool._sent_in_turn = True
tool.start_turn()
assert not tool._sent_in_turn

View File

@@ -2,6 +2,7 @@ from typing import Any
from nanobot.agent.tools.base import Tool from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
class SampleTool(Tool): class SampleTool(Tool):
@@ -86,3 +87,22 @@ async def test_registry_returns_validation_error() -> None:
reg.register(SampleTool()) reg.register(SampleTool())
result = await reg.execute("sample", {"query": "hi"}) result = await reg.execute("sample", {"query": "hi"})
assert "Invalid parameters" in result assert "Invalid parameters" in result
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
cmd = r"type C:\user\workspace\txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/bin/python" not in paths
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
cmd = "cat /tmp/data.txt > /tmp/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "/tmp/out.txt" in paths