Merge branch 'main' into pr-1330
Made-with: Cursor # Conflicts: # nanobot/providers/litellm_provider.py
This commit is contained in:
79
README.md
79
README.md
@@ -16,10 +16,13 @@
|
|||||||
|
|
||||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||||
|
|
||||||
📏 Real-time line count: **3,955 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
📏 Real-time line count: **3,935 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
||||||
|
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
||||||
|
- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
|
||||||
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
||||||
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
||||||
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
||||||
@@ -135,12 +138,13 @@ Add or merge these **two parts** into your config (other options have defaults).
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
*Set your model*:
|
*Set your model* (optionally pin a provider — defaults to auto-detection):
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"agents": {
|
"agents": {
|
||||||
"defaults": {
|
"defaults": {
|
||||||
"model": "anthropic/claude-opus-4-5"
|
"model": "anthropic/claude-opus-4-5",
|
||||||
|
"provider": "openrouter"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -305,6 +309,72 @@ nanobot gateway
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Matrix (Element)</b></summary>
|
||||||
|
|
||||||
|
Install Matrix dependencies first:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install nanobot-ai[matrix]
|
||||||
|
```
|
||||||
|
|
||||||
|
**1. Create/choose a Matrix account**
|
||||||
|
|
||||||
|
- Create or reuse a Matrix account on your homeserver (for example `matrix.org`).
|
||||||
|
- Confirm you can log in with Element.
|
||||||
|
|
||||||
|
**2. Get credentials**
|
||||||
|
|
||||||
|
- You need:
|
||||||
|
- `userId` (example: `@nanobot:matrix.org`)
|
||||||
|
- `accessToken`
|
||||||
|
- `deviceId` (recommended so sync tokens can be restored across restarts)
|
||||||
|
- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings.
|
||||||
|
|
||||||
|
**3. Configure**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"matrix": {
|
||||||
|
"enabled": true,
|
||||||
|
"homeserver": "https://matrix.org",
|
||||||
|
"userId": "@nanobot:matrix.org",
|
||||||
|
"accessToken": "syt_xxx",
|
||||||
|
"deviceId": "NANOBOT01",
|
||||||
|
"e2eeEnabled": true,
|
||||||
|
"allowFrom": [],
|
||||||
|
"groupPolicy": "open",
|
||||||
|
"groupAllowFrom": [],
|
||||||
|
"allowRoomMentions": false,
|
||||||
|
"maxMediaBytes": 20971520
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts.
|
||||||
|
|
||||||
|
| Option | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `allowFrom` | User IDs allowed to interact. Empty = all senders. |
|
||||||
|
| `groupPolicy` | `open` (default), `mention`, or `allowlist`. |
|
||||||
|
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
|
||||||
|
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
|
||||||
|
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
|
||||||
|
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
**4. Run**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nanobot gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>WhatsApp</b></summary>
|
<summary><b>WhatsApp</b></summary>
|
||||||
|
|
||||||
@@ -350,7 +420,7 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
**1. Create a Feishu bot**
|
**1. Create a Feishu bot**
|
||||||
- Visit [Feishu Open Platform](https://open.feishu.cn/app)
|
- Visit [Feishu Open Platform](https://open.feishu.cn/app)
|
||||||
- Create a new app → Enable **Bot** capability
|
- Create a new app → Enable **Bot** capability
|
||||||
- **Permissions**: Add `im:message` (send messages)
|
- **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages)
|
||||||
- **Events**: Add `im.message.receive_v1` (receive messages)
|
- **Events**: Add `im.message.receive_v1` (receive messages)
|
||||||
- Select **Long Connection** mode (requires running nanobot first to establish connection)
|
- Select **Long Connection** mode (requires running nanobot first to establish connection)
|
||||||
- Get **App ID** and **App Secret** from "Credentials & Basic Info"
|
- Get **App ID** and **App Secret** from "Credentials & Basic Info"
|
||||||
@@ -804,6 +874,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
|
|||||||
| Option | Default | Description |
|
| Option | Default | Description |
|
||||||
|--------|---------|-------------|
|
|--------|---------|-------------|
|
||||||
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
|
||||||
|
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
|
||||||
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
|
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
nanobot - A lightweight AI agent framework
|
nanobot - A lightweight AI agent framework
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.4"
|
__version__ = "0.1.4.post2"
|
||||||
__logo__ = "🐈"
|
__logo__ = "🐈"
|
||||||
|
|||||||
@@ -13,14 +13,10 @@ from nanobot.agent.skills import SkillsLoader
|
|||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
"""
|
"""Builds the context (system prompt + messages) for the agent."""
|
||||||
Builds the context (system prompt + messages) for the agent.
|
|
||||||
|
|
||||||
Assembles bootstrap files, memory, skills, and conversation history
|
|
||||||
into a coherent prompt for the LLM.
|
|
||||||
"""
|
|
||||||
|
|
||||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
|
||||||
|
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||||
|
|
||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@@ -28,39 +24,23 @@ class ContextBuilder:
|
|||||||
self.skills = SkillsLoader(workspace)
|
self.skills = SkillsLoader(workspace)
|
||||||
|
|
||||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||||
"""
|
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||||
Build the system prompt from bootstrap files, memory, and skills.
|
parts = [self._get_identity()]
|
||||||
|
|
||||||
Args:
|
|
||||||
skill_names: Optional list of skills to include.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Complete system prompt.
|
|
||||||
"""
|
|
||||||
parts = []
|
|
||||||
|
|
||||||
# Core identity
|
|
||||||
parts.append(self._get_identity())
|
|
||||||
|
|
||||||
# Bootstrap files
|
|
||||||
bootstrap = self._load_bootstrap_files()
|
bootstrap = self._load_bootstrap_files()
|
||||||
if bootstrap:
|
if bootstrap:
|
||||||
parts.append(bootstrap)
|
parts.append(bootstrap)
|
||||||
|
|
||||||
# Memory context
|
|
||||||
memory = self.memory.get_memory_context()
|
memory = self.memory.get_memory_context()
|
||||||
if memory:
|
if memory:
|
||||||
parts.append(f"# Memory\n\n{memory}")
|
parts.append(f"# Memory\n\n{memory}")
|
||||||
|
|
||||||
# Skills - progressive loading
|
|
||||||
# 1. Always-loaded skills: include full content
|
|
||||||
always_skills = self.skills.get_always_skills()
|
always_skills = self.skills.get_always_skills()
|
||||||
if always_skills:
|
if always_skills:
|
||||||
always_content = self.skills.load_skills_for_context(always_skills)
|
always_content = self.skills.load_skills_for_context(always_skills)
|
||||||
if always_content:
|
if always_content:
|
||||||
parts.append(f"# Active Skills\n\n{always_content}")
|
parts.append(f"# Active Skills\n\n{always_content}")
|
||||||
|
|
||||||
# 2. Available skills: only show summary (agent uses read_file to load)
|
|
||||||
skills_summary = self.skills.build_skills_summary()
|
skills_summary = self.skills.build_skills_summary()
|
||||||
if skills_summary:
|
if skills_summary:
|
||||||
parts.append(f"""# Skills
|
parts.append(f"""# Skills
|
||||||
@@ -87,39 +67,28 @@ You are nanobot, a helpful AI assistant.
|
|||||||
|
|
||||||
## Workspace
|
## Workspace
|
||||||
Your workspace is at: {workspace_path}
|
Your workspace is at: {workspace_path}
|
||||||
- Long-term memory: {workspace_path}/memory/MEMORY.md
|
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
|
||||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
|
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||||
|
|
||||||
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.
|
## nanobot Guidelines
|
||||||
|
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||||
## Tool Call Guidelines
|
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||||
- Before calling tools, you may briefly state your intent (e.g. "Let me check that"), but NEVER predict or describe the expected result before receiving it.
|
|
||||||
- Before modifying a file, read it first to confirm its current content.
|
|
||||||
- Do not assume a file or directory exists — use list_dir or read_file to verify.
|
|
||||||
- After writing or editing a file, re-read it if accuracy matters.
|
- After writing or editing a file, re-read it if accuracy matters.
|
||||||
- If a tool call fails, analyze the error before retrying with a different approach.
|
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||||
|
- Ask for clarification when the request is ambiguous.
|
||||||
|
|
||||||
## Memory
|
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""
|
||||||
- Remember important facts: write to {workspace_path}/memory/MEMORY.md
|
|
||||||
- Recall past events: grep {workspace_path}/memory/HISTORY.md"""
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _inject_runtime_context(
|
def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
|
||||||
user_content: str | list[dict[str, Any]],
|
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||||
channel: str | None,
|
|
||||||
chat_id: str | None,
|
|
||||||
) -> str | list[dict[str, Any]]:
|
|
||||||
"""Append dynamic runtime context to the tail of the user message."""
|
|
||||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||||
tz = time.strftime("%Z") or "UTC"
|
tz = time.strftime("%Z") or "UTC"
|
||||||
lines = [f"Current Time: {now} ({tz})"]
|
lines = [f"Current Time: {now} ({tz})"]
|
||||||
if channel and chat_id:
|
if channel and chat_id:
|
||||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
block = "[Runtime Context]\n" + "\n".join(lines)
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||||
if isinstance(user_content, str):
|
|
||||||
return f"{user_content}\n\n{block}"
|
|
||||||
return [*user_content, {"type": "text", "text": block}]
|
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
@@ -142,35 +111,13 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
channel: str | None = None,
|
channel: str | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Build the complete message list for an LLM call."""
|
||||||
Build the complete message list for an LLM call.
|
return [
|
||||||
|
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||||
Args:
|
*history,
|
||||||
history: Previous conversation messages.
|
{"role": "user", "content": self._build_runtime_context(channel, chat_id)},
|
||||||
current_message: The new user message.
|
{"role": "user", "content": self._build_user_content(current_message, media)},
|
||||||
skill_names: Optional skills to include.
|
]
|
||||||
media: Optional list of local file paths for images/media.
|
|
||||||
channel: Current channel (telegram, feishu, etc.).
|
|
||||||
chat_id: Current chat/user ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of messages including system prompt.
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# System prompt
|
|
||||||
system_prompt = self.build_system_prompt(skill_names)
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
|
||||||
|
|
||||||
# History
|
|
||||||
messages.extend(history)
|
|
||||||
|
|
||||||
# Current message (with optional image attachments)
|
|
||||||
user_content = self._build_user_content(current_message, media)
|
|
||||||
user_content = self._inject_runtime_context(user_content, channel, chat_id)
|
|
||||||
messages.append({"role": "user", "content": user_content})
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||||
"""Build user message content with optional base64-encoded images."""
|
"""Build user message content with optional base64-encoded images."""
|
||||||
@@ -191,63 +138,24 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
return images + [{"type": "text", "text": text}]
|
return images + [{"type": "text", "text": text}]
|
||||||
|
|
||||||
def add_tool_result(
|
def add_tool_result(
|
||||||
self,
|
self, messages: list[dict[str, Any]],
|
||||||
messages: list[dict[str, Any]],
|
tool_call_id: str, tool_name: str, result: str,
|
||||||
tool_call_id: str,
|
|
||||||
tool_name: str,
|
|
||||||
result: str
|
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Add a tool result to the message list."""
|
||||||
Add a tool result to the message list.
|
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Current message list.
|
|
||||||
tool_call_id: ID of the tool call.
|
|
||||||
tool_name: Name of the tool.
|
|
||||||
result: Tool execution result.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated message list.
|
|
||||||
"""
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call_id,
|
|
||||||
"name": tool_name,
|
|
||||||
"content": result
|
|
||||||
})
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_assistant_message(
|
def add_assistant_message(
|
||||||
self,
|
self, messages: list[dict[str, Any]],
|
||||||
messages: list[dict[str, Any]],
|
|
||||||
content: str | None,
|
content: str | None,
|
||||||
tool_calls: list[dict[str, Any]] | None = None,
|
tool_calls: list[dict[str, Any]] | None = None,
|
||||||
reasoning_content: str | None = None,
|
reasoning_content: str | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Add an assistant message to the message list."""
|
||||||
Add an assistant message to the message list.
|
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Current message list.
|
|
||||||
content: Message content.
|
|
||||||
tool_calls: Optional tool calls.
|
|
||||||
reasoning_content: Thinking output (Kimi, DeepSeek-R1, etc.).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated message list.
|
|
||||||
"""
|
|
||||||
msg: dict[str, Any] = {"role": "assistant"}
|
|
||||||
|
|
||||||
# Always include content — some providers (e.g. StepFun) reject
|
|
||||||
# assistant messages that omit the key entirely.
|
|
||||||
msg["content"] = content
|
|
||||||
|
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
msg["tool_calls"] = tool_calls
|
msg["tool_calls"] = tool_calls
|
||||||
|
|
||||||
# Include reasoning content when provided (required by some thinking models)
|
|
||||||
if reasoning_content is not None:
|
if reasoning_content is not None:
|
||||||
msg["reasoning_content"] = reasoning_content
|
msg["reasoning_content"] = reasoning_content
|
||||||
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import weakref
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
@@ -43,6 +44,8 @@ class AgentLoop:
|
|||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_TOOL_RESULT_MAX_CHARS = 500
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
@@ -53,6 +56,7 @@ class AgentLoop:
|
|||||||
temperature: float = 0.1,
|
temperature: float = 0.1,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
memory_window: int = 100,
|
memory_window: int = 100,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
cron_service: CronService | None = None,
|
cron_service: CronService | None = None,
|
||||||
@@ -71,6 +75,7 @@ class AgentLoop:
|
|||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.memory_window = memory_window
|
self.memory_window = memory_window
|
||||||
|
self.reasoning_effort = reasoning_effort
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.cron_service = cron_service
|
self.cron_service = cron_service
|
||||||
@@ -86,6 +91,7 @@ class AgentLoop:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
brave_api_key=brave_api_key,
|
brave_api_key=brave_api_key,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
restrict_to_workspace=restrict_to_workspace,
|
restrict_to_workspace=restrict_to_workspace,
|
||||||
@@ -98,7 +104,9 @@ class AgentLoop:
|
|||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||||
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||||
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
|
self._processing_lock = asyncio.Lock()
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@@ -110,6 +118,7 @@ class AgentLoop:
|
|||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
|
path_append=self.exec_config.path_append,
|
||||||
))
|
))
|
||||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
self.tools.register(WebFetchTool())
|
self.tools.register(WebFetchTool())
|
||||||
@@ -142,17 +151,10 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Update context for all tools that need routing info."""
|
"""Update context for all tools that need routing info."""
|
||||||
if message_tool := self.tools.get("message"):
|
for name in ("message", "spawn", "cron"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if tool := self.tools.get(name):
|
||||||
message_tool.set_context(channel, chat_id, message_id)
|
if hasattr(tool, "set_context"):
|
||||||
|
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||||
if spawn_tool := self.tools.get("spawn"):
|
|
||||||
if isinstance(spawn_tool, SpawnTool):
|
|
||||||
spawn_tool.set_context(channel, chat_id)
|
|
||||||
|
|
||||||
if cron_tool := self.tools.get("cron"):
|
|
||||||
if isinstance(cron_tool, CronTool):
|
|
||||||
cron_tool.set_context(channel, chat_id)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _strip_think(text: str | None) -> str | None:
|
def _strip_think(text: str | None) -> str | None:
|
||||||
@@ -165,7 +167,8 @@ class AgentLoop:
|
|||||||
def _tool_hint(tool_calls: list) -> str:
|
def _tool_hint(tool_calls: list) -> str:
|
||||||
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
||||||
def _fmt(tc):
|
def _fmt(tc):
|
||||||
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
|
||||||
|
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||||
if not isinstance(val, str):
|
if not isinstance(val, str):
|
||||||
return tc.name
|
return tc.name
|
||||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||||
@@ -191,6 +194,7 @@ class AgentLoop:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
reasoning_effort=self.reasoning_effort,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
@@ -225,7 +229,17 @@ class AgentLoop:
|
|||||||
messages, tool_call.id, tool_call.name, result
|
messages, tool_call.id, tool_call.name, result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
final_content = self._strip_think(response.content)
|
clean = self._strip_think(response.content)
|
||||||
|
# Don't persist error responses to session history — they can
|
||||||
|
# poison the context and cause permanent 400 loops (#1303).
|
||||||
|
if response.finish_reason == "error":
|
||||||
|
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||||
|
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||||
|
break
|
||||||
|
messages = self.context.add_assistant_message(
|
||||||
|
messages, clean, reasoning_content=response.reasoning_content,
|
||||||
|
)
|
||||||
|
final_content = clean
|
||||||
break
|
break
|
||||||
|
|
||||||
if final_content is None and iteration >= self.max_iterations:
|
if final_content is None and iteration >= self.max_iterations:
|
||||||
@@ -238,34 +252,61 @@ class AgentLoop:
|
|||||||
return final_content, tools_used, messages
|
return final_content, tools_used, messages
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, processing messages from the bus."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
self._running = True
|
self._running = True
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
logger.info("Agent loop started")
|
logger.info("Agent loop started")
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(
|
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||||
self.bus.consume_inbound(),
|
except asyncio.TimeoutError:
|
||||||
timeout=1.0
|
continue
|
||||||
)
|
|
||||||
|
if msg.content.strip().lower() == "/stop":
|
||||||
|
await self._handle_stop(msg)
|
||||||
|
else:
|
||||||
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
|
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||||
|
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||||
|
|
||||||
|
async def _handle_stop(self, msg: InboundMessage) -> None:
|
||||||
|
"""Cancel all active tasks and subagents for the session."""
|
||||||
|
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||||
|
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||||
|
for t in tasks:
|
||||||
|
try:
|
||||||
|
await t
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||||
|
total = cancelled + sub_cancelled
|
||||||
|
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
|
"""Process a message under the global lock."""
|
||||||
|
async with self._processing_lock:
|
||||||
try:
|
try:
|
||||||
response = await self._process_message(msg)
|
response = await self._process_message(msg)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {},
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata=msg.metadata or {},
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except asyncio.CancelledError:
|
||||||
logger.error("Error processing message: {}", e)
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing message for session {}", msg.session_key)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
chat_id=msg.chat_id,
|
content="Sorry, I encountered an error.",
|
||||||
content=f"Sorry, I encountered an error: {str(e)}"
|
|
||||||
))
|
))
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Close MCP connections."""
|
"""Close MCP connections."""
|
||||||
@@ -281,18 +322,6 @@ class AgentLoop:
|
|||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
|
|
||||||
lock = self._consolidation_locks.get(session_key)
|
|
||||||
if lock is None:
|
|
||||||
lock = asyncio.Lock()
|
|
||||||
self._consolidation_locks[session_key] = lock
|
|
||||||
return lock
|
|
||||||
|
|
||||||
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
|
|
||||||
"""Drop lock entry if no longer in use."""
|
|
||||||
if not lock.locked():
|
|
||||||
self._consolidation_locks.pop(session_key, None)
|
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
@@ -328,7 +357,7 @@ class AgentLoop:
|
|||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
lock = self._get_consolidation_lock(session.key)
|
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||||
self._consolidating.add(session.key)
|
self._consolidating.add(session.key)
|
||||||
try:
|
try:
|
||||||
async with lock:
|
async with lock:
|
||||||
@@ -349,7 +378,6 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
self._prune_consolidation_lock(session.key, lock)
|
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
@@ -358,12 +386,12 @@ class AgentLoop:
|
|||||||
content="New session started.")
|
content="New session started.")
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||||
|
|
||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
unconsolidated = len(session.messages) - session.last_consolidated
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||||
self._consolidating.add(session.key)
|
self._consolidating.add(session.key)
|
||||||
lock = self._get_consolidation_lock(session.key)
|
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
async def _consolidate_and_unlock():
|
||||||
try:
|
try:
|
||||||
@@ -371,7 +399,6 @@ class AgentLoop:
|
|||||||
await self._consolidate_memory(session)
|
await self._consolidate_memory(session)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
self._prune_consolidation_lock(session.key, lock)
|
|
||||||
_task = asyncio.current_task()
|
_task = asyncio.current_task()
|
||||||
if _task is not None:
|
if _task is not None:
|
||||||
self._consolidation_tasks.discard(_task)
|
self._consolidation_tasks.discard(_task)
|
||||||
@@ -407,32 +434,39 @@ class AgentLoop:
|
|||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
if message_tool := self.tools.get("message"):
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
metadata=msg.metadata or {},
|
metadata=msg.metadata or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
_TOOL_RESULT_MAX_CHARS = 500
|
|
||||||
|
|
||||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
"""Save new-turn messages into session, truncating large tool results."""
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
for m in messages[skip:]:
|
for m in messages[skip:]:
|
||||||
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
|
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
|
||||||
if entry.get("role") == "tool" and isinstance(entry.get("content"), str):
|
role, content = entry.get("role"), entry.get("content")
|
||||||
content = entry["content"]
|
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||||
if len(content) > self._TOOL_RESULT_MAX_CHARS:
|
continue # skip empty assistant messages — they poison session context
|
||||||
|
if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||||
|
elif role == "user":
|
||||||
|
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||||
|
continue
|
||||||
|
if isinstance(content, list):
|
||||||
|
entry["content"] = [
|
||||||
|
{"type": "text", "text": "[image]"} if (
|
||||||
|
c.get("type") == "image_url"
|
||||||
|
and c.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||||
|
) else c for c in content
|
||||||
|
]
|
||||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|||||||
@@ -18,13 +18,7 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
|||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
"""
|
"""Manages background subagent execution."""
|
||||||
Manages background subagent execution.
|
|
||||||
|
|
||||||
Subagents are lightweight agent instances that run in the background
|
|
||||||
to handle specific tasks. They share the same LLM provider but have
|
|
||||||
isolated context and a focused system prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -34,6 +28,7 @@ class SubagentManager:
|
|||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
@@ -45,10 +40,12 @@ class SubagentManager:
|
|||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.reasoning_effort = reasoning_effort
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self,
|
self,
|
||||||
@@ -56,35 +53,28 @@ class SubagentManager:
|
|||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
origin_channel: str = "cli",
|
origin_channel: str = "cli",
|
||||||
origin_chat_id: str = "direct",
|
origin_chat_id: str = "direct",
|
||||||
|
session_key: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""Spawn a subagent to execute a task in the background."""
|
||||||
Spawn a subagent to execute a task in the background.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The task description for the subagent.
|
|
||||||
label: Optional human-readable label for the task.
|
|
||||||
origin_channel: The channel to announce results to.
|
|
||||||
origin_chat_id: The chat ID to announce results to.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Status message indicating the subagent was started.
|
|
||||||
"""
|
|
||||||
task_id = str(uuid.uuid4())[:8]
|
task_id = str(uuid.uuid4())[:8]
|
||||||
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
||||||
|
origin = {"channel": origin_channel, "chat_id": origin_chat_id}
|
||||||
|
|
||||||
origin = {
|
|
||||||
"channel": origin_channel,
|
|
||||||
"chat_id": origin_chat_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create background task
|
|
||||||
bg_task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
self._run_subagent(task_id, task, display_label, origin)
|
self._run_subagent(task_id, task, display_label, origin)
|
||||||
)
|
)
|
||||||
self._running_tasks[task_id] = bg_task
|
self._running_tasks[task_id] = bg_task
|
||||||
|
if session_key:
|
||||||
|
self._session_tasks.setdefault(session_key, set()).add(task_id)
|
||||||
|
|
||||||
# Cleanup when done
|
def _cleanup(_: asyncio.Task) -> None:
|
||||||
bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None))
|
self._running_tasks.pop(task_id, None)
|
||||||
|
if session_key and (ids := self._session_tasks.get(session_key)):
|
||||||
|
ids.discard(task_id)
|
||||||
|
if not ids:
|
||||||
|
del self._session_tasks[session_key]
|
||||||
|
|
||||||
|
bg_task.add_done_callback(_cleanup)
|
||||||
|
|
||||||
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||||
@@ -111,12 +101,12 @@ class SubagentManager:
|
|||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
|
path_append=self.exec_config.path_append,
|
||||||
))
|
))
|
||||||
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
tools.register(WebFetchTool())
|
tools.register(WebFetchTool())
|
||||||
|
|
||||||
# Build messages with subagent-specific prompt
|
system_prompt = self._build_subagent_prompt()
|
||||||
system_prompt = self._build_subagent_prompt(task)
|
|
||||||
messages: list[dict[str, Any]] = [
|
messages: list[dict[str, Any]] = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": task},
|
{"role": "user", "content": task},
|
||||||
@@ -136,6 +126,7 @@ class SubagentManager:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
|
reasoning_effort=self.reasoning_effort,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
@@ -215,42 +206,37 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
|||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||||
|
|
||||||
def _build_subagent_prompt(self, task: str) -> str:
|
def _build_subagent_prompt(self) -> str:
|
||||||
"""Build a focused system prompt for the subagent."""
|
"""Build a focused system prompt for the subagent."""
|
||||||
from datetime import datetime
|
from nanobot.agent.context import ContextBuilder
|
||||||
import time as _time
|
from nanobot.agent.skills import SkillsLoader
|
||||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
|
||||||
tz = _time.strftime("%Z") or "UTC"
|
|
||||||
|
|
||||||
return f"""# Subagent
|
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
||||||
|
parts = [f"""# Subagent
|
||||||
|
|
||||||
## Current Time
|
{time_ctx}
|
||||||
{now} ({tz})
|
|
||||||
|
|
||||||
You are a subagent spawned by the main agent to complete a specific task.
|
You are a subagent spawned by the main agent to complete a specific task.
|
||||||
|
Stay focused on the assigned task. Your final response will be reported back to the main agent.
|
||||||
## Rules
|
|
||||||
1. Stay focused - complete only the assigned task, nothing else
|
|
||||||
2. Your final response will be reported back to the main agent
|
|
||||||
3. Do not initiate conversations or take on side tasks
|
|
||||||
4. Be concise but informative in your findings
|
|
||||||
|
|
||||||
## What You Can Do
|
|
||||||
- Read and write files in the workspace
|
|
||||||
- Execute shell commands
|
|
||||||
- Search the web and fetch web pages
|
|
||||||
- Complete the task thoroughly
|
|
||||||
|
|
||||||
## What You Cannot Do
|
|
||||||
- Send messages directly to users (no message tool available)
|
|
||||||
- Spawn other subagents
|
|
||||||
- Access the main agent's conversation history
|
|
||||||
|
|
||||||
## Workspace
|
## Workspace
|
||||||
Your workspace is at: {self.workspace}
|
{self.workspace}"""]
|
||||||
Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
|
|
||||||
|
|
||||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
|
||||||
|
if skills_summary:
|
||||||
|
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
|
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||||
|
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||||
|
if tid in self._running_tasks and not self._running_tasks[tid].done()]
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
return len(tasks)
|
||||||
|
|
||||||
def get_running_count(self) -> int:
|
def get_running_count(self) -> int:
|
||||||
"""Return the number of currently running subagents."""
|
"""Return the number of currently running subagents."""
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
|
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||||
self._sent_in_turn = True
|
self._sent_in_turn = True
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class ExecTool(Tool):
|
|||||||
deny_patterns: list[str] | None = None,
|
deny_patterns: list[str] | None = None,
|
||||||
allow_patterns: list[str] | None = None,
|
allow_patterns: list[str] | None = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
|
path_append: str = "",
|
||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
@@ -35,6 +36,7 @@ class ExecTool(Tool):
|
|||||||
]
|
]
|
||||||
self.allow_patterns = allow_patterns or []
|
self.allow_patterns = allow_patterns or []
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
self.path_append = path_append
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -67,12 +69,17 @@ class ExecTool(Tool):
|
|||||||
if guard_error:
|
if guard_error:
|
||||||
return guard_error
|
return guard_error
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
if self.path_append:
|
||||||
|
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
command,
|
command,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -134,13 +141,7 @@ class ExecTool(Tool):
|
|||||||
|
|
||||||
cwd_path = Path(cwd).resolve()
|
cwd_path = Path(cwd).resolve()
|
||||||
|
|
||||||
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
|
for raw in self._extract_absolute_paths(cmd):
|
||||||
# Only match absolute paths — avoid false positives on relative
|
|
||||||
# paths like ".venv/bin/python" where "/bin/python" would be
|
|
||||||
# incorrectly extracted by the old pattern.
|
|
||||||
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
|
|
||||||
|
|
||||||
for raw in win_paths + posix_paths:
|
|
||||||
try:
|
try:
|
||||||
p = Path(raw.strip()).resolve()
|
p = Path(raw.strip()).resolve()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -149,3 +150,9 @@ class ExecTool(Tool):
|
|||||||
return "Error: Command blocked by safety guard (path outside working dir)"
|
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_absolute_paths(command: str) -> list[str]:
|
||||||
|
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||||
|
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
|
||||||
|
return win_paths + posix_paths
|
||||||
|
|||||||
@@ -15,11 +15,13 @@ class SpawnTool(Tool):
|
|||||||
self._manager = manager
|
self._manager = manager
|
||||||
self._origin_channel = "cli"
|
self._origin_channel = "cli"
|
||||||
self._origin_chat_id = "direct"
|
self._origin_chat_id = "direct"
|
||||||
|
self._session_key = "cli:direct"
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
def set_context(self, channel: str, chat_id: str) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
self._origin_channel = channel
|
self._origin_channel = channel
|
||||||
self._origin_chat_id = chat_id
|
self._origin_chat_id = chat_id
|
||||||
|
self._session_key = f"{channel}:{chat_id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -57,4 +59,5 @@ class SpawnTool(Tool):
|
|||||||
label=label,
|
label=label,
|
||||||
origin_channel=self._origin_channel,
|
origin_channel=self._origin_channel,
|
||||||
origin_chat_id=self._origin_chat_id,
|
origin_chat_id=self._origin_chat_id,
|
||||||
|
session_key=self._session_key,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class WebSearchTool(Tool):
|
|||||||
r = await client.get(
|
r = await client.get(
|
||||||
"https://api.search.brave.com/res/v1/web/search",
|
"https://api.search.brave.com/res/v1/web/search",
|
||||||
params={"q": query, "count": n},
|
params={"q": query, "count": n},
|
||||||
headers={"Accept": "application/json", "X-Subscription-Token": api_key},
|
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
||||||
timeout=10.0
|
timeout=10.0
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|||||||
@@ -2,8 +2,12 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import httpx
|
import httpx
|
||||||
@@ -96,6 +100,9 @@ class DingTalkChannel(BaseChannel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name = "dingtalk"
|
name = "dingtalk"
|
||||||
|
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||||
|
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||||
|
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||||
|
|
||||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||||
super().__init__(config, bus)
|
super().__init__(config, bus)
|
||||||
@@ -191,40 +198,224 @@ class DingTalkChannel(BaseChannel):
|
|||||||
logger.error("Failed to get DingTalk access token: {}", e)
|
logger.error("Failed to get DingTalk access token: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_http_url(value: str) -> bool:
|
||||||
|
return urlparse(value).scheme in ("http", "https")
|
||||||
|
|
||||||
|
def _guess_upload_type(self, media_ref: str) -> str:
|
||||||
|
ext = Path(urlparse(media_ref).path).suffix.lower()
|
||||||
|
if ext in self._IMAGE_EXTS: return "image"
|
||||||
|
if ext in self._AUDIO_EXTS: return "voice"
|
||||||
|
if ext in self._VIDEO_EXTS: return "video"
|
||||||
|
return "file"
|
||||||
|
|
||||||
|
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
|
||||||
|
name = os.path.basename(urlparse(media_ref).path)
|
||||||
|
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
||||||
|
|
||||||
|
async def _read_media_bytes(
|
||||||
|
self,
|
||||||
|
media_ref: str,
|
||||||
|
) -> tuple[bytes | None, str | None, str | None]:
|
||||||
|
if not media_ref:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
if self._is_http_url(media_ref):
|
||||||
|
if not self._http:
|
||||||
|
return None, None, None
|
||||||
|
try:
|
||||||
|
resp = await self._http.get(media_ref, follow_redirects=True)
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
logger.warning(
|
||||||
|
"DingTalk media download failed status={} ref={}",
|
||||||
|
resp.status_code,
|
||||||
|
media_ref,
|
||||||
|
)
|
||||||
|
return None, None, None
|
||||||
|
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||||
|
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||||
|
return resp.content, filename, content_type or None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if media_ref.startswith("file://"):
|
||||||
|
parsed = urlparse(media_ref)
|
||||||
|
local_path = Path(unquote(parsed.path))
|
||||||
|
else:
|
||||||
|
local_path = Path(os.path.expanduser(media_ref))
|
||||||
|
if not local_path.is_file():
|
||||||
|
logger.warning("DingTalk media file not found: {}", local_path)
|
||||||
|
return None, None, None
|
||||||
|
data = await asyncio.to_thread(local_path.read_bytes)
|
||||||
|
content_type = mimetypes.guess_type(local_path.name)[0]
|
||||||
|
return data, local_path.name, content_type
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
async def _upload_media(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
data: bytes,
|
||||||
|
media_type: str,
|
||||||
|
filename: str,
|
||||||
|
content_type: str | None,
|
||||||
|
) -> str | None:
|
||||||
|
if not self._http:
|
||||||
|
return None
|
||||||
|
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
|
||||||
|
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||||
|
files = {"media": (filename, data, mime)}
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await self._http.post(url, files=files)
|
||||||
|
text = resp.text
|
||||||
|
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
||||||
|
return None
|
||||||
|
errcode = result.get("errcode", 0)
|
||||||
|
if errcode != 0:
|
||||||
|
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
||||||
|
return None
|
||||||
|
sub = result.get("result") or {}
|
||||||
|
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
||||||
|
if not media_id:
|
||||||
|
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||||
|
return None
|
||||||
|
return str(media_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _send_batch_message(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
chat_id: str,
|
||||||
|
msg_key: str,
|
||||||
|
msg_param: dict[str, Any],
|
||||||
|
) -> bool:
|
||||||
|
if not self._http:
|
||||||
|
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||||
|
return False
|
||||||
|
|
||||||
|
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||||
|
headers = {"x-acs-dingtalk-access-token": token}
|
||||||
|
payload = {
|
||||||
|
"robotCode": self.config.client_id,
|
||||||
|
"userIds": [chat_id],
|
||||||
|
"msgKey": msg_key,
|
||||||
|
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await self._http.post(url, json=payload, headers=headers)
|
||||||
|
body = resp.text
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||||
|
return False
|
||||||
|
try: result = resp.json()
|
||||||
|
except Exception: result = {}
|
||||||
|
errcode = result.get("errcode")
|
||||||
|
if errcode not in (None, 0):
|
||||||
|
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||||
|
return False
|
||||||
|
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
||||||
|
return await self._send_batch_message(
|
||||||
|
token,
|
||||||
|
chat_id,
|
||||||
|
"sampleMarkdown",
|
||||||
|
{"text": content, "title": "Nanobot Reply"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
|
||||||
|
media_ref = (media_ref or "").strip()
|
||||||
|
if not media_ref:
|
||||||
|
return True
|
||||||
|
|
||||||
|
upload_type = self._guess_upload_type(media_ref)
|
||||||
|
if upload_type == "image" and self._is_http_url(media_ref):
|
||||||
|
ok = await self._send_batch_message(
|
||||||
|
token,
|
||||||
|
chat_id,
|
||||||
|
"sampleImageMsg",
|
||||||
|
{"photoURL": media_ref},
|
||||||
|
)
|
||||||
|
if ok:
|
||||||
|
return True
|
||||||
|
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
|
||||||
|
|
||||||
|
data, filename, content_type = await self._read_media_bytes(media_ref)
|
||||||
|
if not data:
|
||||||
|
logger.error("DingTalk media read failed: {}", media_ref)
|
||||||
|
return False
|
||||||
|
|
||||||
|
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||||
|
file_type = Path(filename).suffix.lower().lstrip(".")
|
||||||
|
if not file_type:
|
||||||
|
guessed = mimetypes.guess_extension(content_type or "")
|
||||||
|
file_type = (guessed or ".bin").lstrip(".")
|
||||||
|
if file_type == "jpeg":
|
||||||
|
file_type = "jpg"
|
||||||
|
|
||||||
|
media_id = await self._upload_media(
|
||||||
|
token=token,
|
||||||
|
data=data,
|
||||||
|
media_type=upload_type,
|
||||||
|
filename=filename,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
if not media_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if upload_type == "image":
|
||||||
|
# Verified in production: sampleImageMsg accepts media_id in photoURL.
|
||||||
|
ok = await self._send_batch_message(
|
||||||
|
token,
|
||||||
|
chat_id,
|
||||||
|
"sampleImageMsg",
|
||||||
|
{"photoURL": media_id},
|
||||||
|
)
|
||||||
|
if ok:
|
||||||
|
return True
|
||||||
|
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
|
||||||
|
|
||||||
|
return await self._send_batch_message(
|
||||||
|
token,
|
||||||
|
chat_id,
|
||||||
|
"sampleFile",
|
||||||
|
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
|
||||||
|
)
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through DingTalk."""
|
"""Send a message through DingTalk."""
|
||||||
token = await self._get_access_token()
|
token = await self._get_access_token()
|
||||||
if not token:
|
if not token:
|
||||||
return
|
return
|
||||||
|
|
||||||
# oToMessages/batchSend: sends to individual users (private chat)
|
if msg.content and msg.content.strip():
|
||||||
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
|
await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
|
||||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
|
||||||
|
|
||||||
headers = {"x-acs-dingtalk-access-token": token}
|
for media_ref in msg.media or []:
|
||||||
|
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
||||||
data = {
|
if ok:
|
||||||
"robotCode": self.config.client_id,
|
continue
|
||||||
"userIds": [msg.chat_id], # chat_id is the user's staffId
|
logger.error("DingTalk media send failed for {}", media_ref)
|
||||||
"msgKey": "sampleMarkdown",
|
# Send visible fallback so failures are observable by the user.
|
||||||
"msgParam": json.dumps({
|
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||||
"text": msg.content,
|
await self._send_markdown_text(
|
||||||
"title": "Nanobot Reply",
|
token,
|
||||||
}, ensure_ascii=False),
|
msg.chat_id,
|
||||||
}
|
f"[Attachment send failed: {filename}]",
|
||||||
|
)
|
||||||
if not self._http:
|
|
||||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
resp = await self._http.post(url, json=data, headers=headers)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
logger.error("DingTalk send failed: {}", resp.text)
|
|
||||||
else:
|
|
||||||
logger.debug("DingTalk message sent to {}", msg.chat_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error sending DingTalk message: {}", e)
|
|
||||||
|
|
||||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ def _extract_interactive_content(content: dict) -> list[str]:
|
|||||||
elif isinstance(title, str):
|
elif isinstance(title, str):
|
||||||
parts.append(f"title: {title}")
|
parts.append(f"title: {title}")
|
||||||
|
|
||||||
for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||||
|
for element in elements:
|
||||||
parts.extend(_extract_element_content(element))
|
parts.extend(_extract_element_content(element))
|
||||||
|
|
||||||
card = content.get("card", {})
|
card = content.get("card", {})
|
||||||
@@ -325,13 +326,14 @@ class FeishuChannel(BaseChannel):
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the Feishu bot."""
|
"""
|
||||||
|
Stop the Feishu bot.
|
||||||
|
|
||||||
|
Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client.
|
||||||
|
|
||||||
|
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
|
||||||
|
"""
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._ws_client:
|
|
||||||
try:
|
|
||||||
self._ws_client.stop()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Error stopping WebSocket client: {}", e)
|
|
||||||
logger.info("Feishu bot stopped")
|
logger.info("Feishu bot stopped")
|
||||||
|
|
||||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||||
@@ -692,7 +694,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
msg_type = message.message_type
|
msg_type = message.message_type
|
||||||
|
|
||||||
# Add reaction
|
# Add reaction
|
||||||
await self._add_reaction(message_id, "THUMBSUP")
|
await self._add_reaction(message_id, self.config.react_emoji)
|
||||||
|
|
||||||
# Parse content
|
# Parse content
|
||||||
content_parts = []
|
content_parts = []
|
||||||
|
|||||||
@@ -137,6 +137,18 @@ class ChannelManager:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("QQ channel not available: {}", e)
|
logger.warning("QQ channel not available: {}", e)
|
||||||
|
|
||||||
|
# Matrix channel
|
||||||
|
if self.config.channels.matrix.enabled:
|
||||||
|
try:
|
||||||
|
from nanobot.channels.matrix import MatrixChannel
|
||||||
|
self.channels["matrix"] = MatrixChannel(
|
||||||
|
self.config.channels.matrix,
|
||||||
|
self.bus,
|
||||||
|
)
|
||||||
|
logger.info("Matrix channel enabled")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning("Matrix channel not available: {}", e)
|
||||||
|
|
||||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||||
"""Start a channel and log any exceptions."""
|
"""Start a channel and log any exceptions."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
682
nanobot/channels/matrix.py
Normal file
682
nanobot/channels/matrix.py
Normal file
@@ -0,0 +1,682 @@
|
|||||||
|
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, TypeAlias
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import nh3
|
||||||
|
from mistune import create_markdown
|
||||||
|
from nio import (
|
||||||
|
AsyncClient, AsyncClientConfig, ContentRepositoryConfigError,
|
||||||
|
DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse,
|
||||||
|
RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText,
|
||||||
|
RoomSendError, RoomTypingError, SyncError, UploadError,
|
||||||
|
)
|
||||||
|
from nio.crypto.attachments import decrypt_attachment
|
||||||
|
from nio.exceptions import EncryptionError
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.config.loader import get_data_dir
|
||||||
|
from nanobot.utils.helpers import safe_filename
|
||||||
|
|
||||||
|
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||||
|
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||||
|
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||||
|
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||||
|
_ATTACH_MARKER = "[attachment: {}]"
|
||||||
|
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||||
|
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||||
|
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||||
|
_DEFAULT_ATTACH_NAME = "attachment"
|
||||||
|
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||||
|
|
||||||
|
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||||
|
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||||
|
|
||||||
|
MATRIX_MARKDOWN = create_markdown(
|
||||||
|
escape=True,
|
||||||
|
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||||
|
)
|
||||||
|
|
||||||
|
MATRIX_ALLOWED_HTML_TAGS = {
|
||||||
|
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||||
|
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||||
|
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||||
|
"caption", "sup", "sub", "img",
|
||||||
|
}
|
||||||
|
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||||
|
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||||
|
"img": {"src", "alt", "title", "width", "height"},
|
||||||
|
}
|
||||||
|
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||||
|
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||||
|
if tag == "a" and attr == "href":
|
||||||
|
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||||
|
if tag == "img" and attr == "src":
|
||||||
|
return value if value.lower().startswith("mxc://") else None
|
||||||
|
if tag == "code" and attr == "class":
|
||||||
|
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||||
|
return " ".join(classes) if classes else None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||||
|
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||||
|
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||||
|
attribute_filter=_filter_matrix_html_attribute,
|
||||||
|
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||||
|
strip_comments=True,
|
||||||
|
link_rel="noopener noreferrer",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_markdown_html(text: str) -> str | None:
|
||||||
|
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||||
|
try:
|
||||||
|
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if not formatted:
|
||||||
|
return None
|
||||||
|
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||||
|
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||||
|
inner = formatted[3:-4]
|
||||||
|
if "<" not in inner and ">" not in inner:
|
||||||
|
return None
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
|
||||||
|
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||||
|
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||||
|
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||||
|
if html := _render_markdown_html(text):
|
||||||
|
content["format"] = MATRIX_HTML_FORMAT
|
||||||
|
content["formatted_body"] = html
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
class _NioLoguruHandler(logging.Handler):
|
||||||
|
"""Route matrix-nio stdlib logs into Loguru."""
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord) -> None:
|
||||||
|
try:
|
||||||
|
level = logger.level(record.levelname).name
|
||||||
|
except ValueError:
|
||||||
|
level = record.levelno
|
||||||
|
frame, depth = logging.currentframe(), 2
|
||||||
|
while frame and frame.f_code.co_filename == logging.__file__:
|
||||||
|
frame, depth = frame.f_back, depth + 1
|
||||||
|
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_nio_logging_bridge() -> None:
|
||||||
|
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||||
|
nio_logger = logging.getLogger("nio")
|
||||||
|
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||||
|
nio_logger.handlers = [_NioLoguruHandler()]
|
||||||
|
nio_logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixChannel(BaseChannel):
|
||||||
|
"""Matrix (Element) channel using long-polling sync."""
|
||||||
|
|
||||||
|
name = "matrix"
|
||||||
|
|
||||||
|
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
|
||||||
|
workspace: Path | None = None):
|
||||||
|
super().__init__(config, bus)
|
||||||
|
self.client: AsyncClient | None = None
|
||||||
|
self._sync_task: asyncio.Task | None = None
|
||||||
|
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._restrict_to_workspace = restrict_to_workspace
|
||||||
|
self._workspace = workspace.expanduser().resolve() if workspace else None
|
||||||
|
self._server_upload_limit_bytes: int | None = None
|
||||||
|
self._server_upload_limit_checked = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start Matrix client and begin sync loop."""
|
||||||
|
self._running = True
|
||||||
|
_configure_nio_logging_bridge()
|
||||||
|
|
||||||
|
store_path = get_data_dir() / "matrix-store"
|
||||||
|
store_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.client = AsyncClient(
|
||||||
|
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||||
|
store_path=store_path,
|
||||||
|
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||||
|
)
|
||||||
|
self.client.user_id = self.config.user_id
|
||||||
|
self.client.access_token = self.config.access_token
|
||||||
|
self.client.device_id = self.config.device_id
|
||||||
|
|
||||||
|
self._register_event_callbacks()
|
||||||
|
self._register_response_callbacks()
|
||||||
|
|
||||||
|
if not self.config.e2ee_enabled:
|
||||||
|
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||||
|
|
||||||
|
if self.config.device_id:
|
||||||
|
try:
|
||||||
|
self.client.load_store()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||||
|
else:
|
||||||
|
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||||
|
|
||||||
|
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||||
|
self._running = False
|
||||||
|
for room_id in list(self._typing_tasks):
|
||||||
|
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||||
|
if self.client:
|
||||||
|
self.client.stop_sync_forever()
|
||||||
|
if self._sync_task:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||||
|
timeout=self.config.sync_stop_grace_seconds)
|
||||||
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
|
self._sync_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._sync_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if self.client:
|
||||||
|
await self.client.close()
|
||||||
|
|
||||||
|
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||||
|
"""Check path is inside workspace (when restriction enabled)."""
|
||||||
|
if not self._restrict_to_workspace or not self._workspace:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
path.resolve(strict=False).relative_to(self._workspace)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||||
|
"""Deduplicate and resolve outbound attachment paths."""
|
||||||
|
seen: set[str] = set()
|
||||||
|
candidates: list[Path] = []
|
||||||
|
for raw in media:
|
||||||
|
if not isinstance(raw, str) or not raw.strip():
|
||||||
|
continue
|
||||||
|
path = Path(raw.strip()).expanduser()
|
||||||
|
try:
|
||||||
|
key = str(path.resolve(strict=False))
|
||||||
|
except OSError:
|
||||||
|
key = str(path)
|
||||||
|
if key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
candidates.append(path)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_outbound_attachment_content(
|
||||||
|
*, filename: str, mime: str, size_bytes: int,
|
||||||
|
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||||
|
prefix = mime.split("/")[0]
|
||||||
|
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||||
|
content: dict[str, Any] = {
|
||||||
|
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||||
|
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||||
|
}
|
||||||
|
if encryption_info:
|
||||||
|
content["file"] = {**encryption_info, "url": mxc_url}
|
||||||
|
else:
|
||||||
|
content["url"] = mxc_url
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||||
|
if not self.client:
|
||||||
|
return False
|
||||||
|
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||||
|
return bool(getattr(room, "encrypted", False))
|
||||||
|
|
||||||
|
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||||
|
"""Send m.room.message with E2EE options."""
|
||||||
|
if not self.client:
|
||||||
|
return
|
||||||
|
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||||
|
if self.config.e2ee_enabled:
|
||||||
|
kwargs["ignore_unverified_devices"] = True
|
||||||
|
await self.client.room_send(**kwargs)
|
||||||
|
|
||||||
|
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||||
|
"""Query homeserver upload limit once per channel lifecycle."""
|
||||||
|
if self._server_upload_limit_checked:
|
||||||
|
return self._server_upload_limit_bytes
|
||||||
|
self._server_upload_limit_checked = True
|
||||||
|
if not self.client:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
response = await self.client.content_repository_config()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
upload_size = getattr(response, "upload_size", None)
|
||||||
|
if isinstance(upload_size, int) and upload_size > 0:
|
||||||
|
self._server_upload_limit_bytes = upload_size
|
||||||
|
return upload_size
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _effective_media_limit_bytes(self) -> int:
|
||||||
|
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||||
|
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||||
|
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||||
|
if server_limit is None:
|
||||||
|
return local_limit
|
||||||
|
return min(local_limit, server_limit) if local_limit else 0
|
||||||
|
|
||||||
|
async def _upload_and_send_attachment(
|
||||||
|
self, room_id: str, path: Path, limit_bytes: int,
|
||||||
|
relates_to: dict[str, Any] | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||||
|
if not self.client:
|
||||||
|
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||||
|
|
||||||
|
resolved = path.expanduser().resolve(strict=False)
|
||||||
|
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||||
|
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||||
|
|
||||||
|
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||||
|
return fail
|
||||||
|
try:
|
||||||
|
size_bytes = resolved.stat().st_size
|
||||||
|
except OSError:
|
||||||
|
return fail
|
||||||
|
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||||
|
return _ATTACH_TOO_LARGE.format(filename)
|
||||||
|
|
||||||
|
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||||
|
try:
|
||||||
|
with resolved.open("rb") as f:
|
||||||
|
upload_result = await self.client.upload(
|
||||||
|
f, content_type=mime, filename=filename,
|
||||||
|
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||||
|
filesize=size_bytes,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return fail
|
||||||
|
|
||||||
|
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||||
|
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||||
|
if isinstance(upload_response, UploadError):
|
||||||
|
return fail
|
||||||
|
mxc_url = getattr(upload_response, "content_uri", None)
|
||||||
|
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||||
|
return fail
|
||||||
|
|
||||||
|
content = self._build_outbound_attachment_content(
|
||||||
|
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||||
|
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||||
|
)
|
||||||
|
if relates_to:
|
||||||
|
content["m.relates_to"] = relates_to
|
||||||
|
try:
|
||||||
|
await self._send_room_content(room_id, content)
|
||||||
|
except Exception:
|
||||||
|
return fail
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
|
"""Send outbound content; clear typing for non-progress messages."""
|
||||||
|
if not self.client:
|
||||||
|
return
|
||||||
|
text = msg.content or ""
|
||||||
|
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||||
|
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||||
|
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||||
|
try:
|
||||||
|
failures: list[str] = []
|
||||||
|
if candidates:
|
||||||
|
limit_bytes = await self._effective_media_limit_bytes()
|
||||||
|
for path in candidates:
|
||||||
|
if fail := await self._upload_and_send_attachment(
|
||||||
|
msg.chat_id, path, limit_bytes, relates_to):
|
||||||
|
failures.append(fail)
|
||||||
|
if failures:
|
||||||
|
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||||
|
if text or not candidates:
|
||||||
|
content = _build_matrix_text_content(text)
|
||||||
|
if relates_to:
|
||||||
|
content["m.relates_to"] = relates_to
|
||||||
|
await self._send_room_content(msg.chat_id, content)
|
||||||
|
finally:
|
||||||
|
if not is_progress:
|
||||||
|
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||||
|
|
||||||
|
def _register_event_callbacks(self) -> None:
|
||||||
|
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||||
|
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||||
|
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||||
|
|
||||||
|
def _register_response_callbacks(self) -> None:
|
||||||
|
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||||
|
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||||
|
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||||
|
|
||||||
|
def _log_response_error(self, label: str, response: Any) -> None:
|
||||||
|
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||||
|
code = getattr(response, "status_code", None)
|
||||||
|
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||||
|
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||||
|
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||||
|
|
||||||
|
async def _on_sync_error(self, response: SyncError) -> None:
|
||||||
|
self._log_response_error("sync", response)
|
||||||
|
|
||||||
|
async def _on_join_error(self, response: JoinError) -> None:
|
||||||
|
self._log_response_error("join", response)
|
||||||
|
|
||||||
|
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||||
|
self._log_response_error("send", response)
|
||||||
|
|
||||||
|
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||||
|
"""Best-effort typing indicator update."""
|
||||||
|
if not self.client:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||||
|
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||||
|
if isinstance(response, RoomTypingError):
|
||||||
|
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||||
|
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||||
|
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||||
|
await self._set_typing(room_id, True)
|
||||||
|
if not self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def loop() -> None:
|
||||||
|
try:
|
||||||
|
while self._running:
|
||||||
|
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||||
|
await self._set_typing(room_id, True)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||||
|
|
||||||
|
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||||
|
if task := self._typing_tasks.pop(room_id, None):
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if clear_typing:
|
||||||
|
await self._set_typing(room_id, False)
|
||||||
|
|
||||||
|
async def _sync_loop(self) -> None:
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||||
|
allow_from = self.config.allow_from or []
|
||||||
|
if not allow_from or event.sender in allow_from:
|
||||||
|
await self.client.join(room.room_id)
|
||||||
|
|
||||||
|
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||||
|
count = getattr(room, "member_count", None)
|
||||||
|
return isinstance(count, int) and count <= 2
|
||||||
|
|
||||||
|
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||||
|
"""Check m.mentions payload for bot mention."""
|
||||||
|
source = getattr(event, "source", None)
|
||||||
|
if not isinstance(source, dict):
|
||||||
|
return False
|
||||||
|
mentions = (source.get("content") or {}).get("m.mentions")
|
||||||
|
if not isinstance(mentions, dict):
|
||||||
|
return False
|
||||||
|
user_ids = mentions.get("user_ids")
|
||||||
|
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||||
|
return True
|
||||||
|
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||||
|
|
||||||
|
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||||
|
"""Apply sender and room policy checks."""
|
||||||
|
if not self.is_allowed(event.sender):
|
||||||
|
return False
|
||||||
|
if self._is_direct_room(room):
|
||||||
|
return True
|
||||||
|
policy = self.config.group_policy
|
||||||
|
if policy == "open":
|
||||||
|
return True
|
||||||
|
if policy == "allowlist":
|
||||||
|
return room.room_id in (self.config.group_allow_from or [])
|
||||||
|
if policy == "mention":
|
||||||
|
return self._is_bot_mentioned(event)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _media_dir(self) -> Path:
|
||||||
|
d = get_data_dir() / "media" / "matrix"
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
return d
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||||
|
source = getattr(event, "source", None)
|
||||||
|
if not isinstance(source, dict):
|
||||||
|
return {}
|
||||||
|
content = source.get("content")
|
||||||
|
return content if isinstance(content, dict) else {}
|
||||||
|
|
||||||
|
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||||
|
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||||
|
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||||
|
return None
|
||||||
|
root_id = relates_to.get("event_id")
|
||||||
|
return root_id if isinstance(root_id, str) and root_id else None
|
||||||
|
|
||||||
|
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||||
|
if not (root_id := self._event_thread_root_id(event)):
|
||||||
|
return None
|
||||||
|
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||||
|
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||||
|
meta["thread_reply_to_event_id"] = reply_to
|
||||||
|
return meta
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
|
if not metadata:
|
||||||
|
return None
|
||||||
|
root_id = metadata.get("thread_root_event_id")
|
||||||
|
if not isinstance(root_id, str) or not root_id:
|
||||||
|
return None
|
||||||
|
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||||
|
if not isinstance(reply_to, str) or not reply_to:
|
||||||
|
return None
|
||||||
|
return {"rel_type": "m.thread", "event_id": root_id,
|
||||||
|
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||||
|
|
||||||
|
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||||
|
msgtype = self._event_source_content(event).get("msgtype")
|
||||||
|
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||||
|
return (isinstance(getattr(event, "key", None), dict)
|
||||||
|
and isinstance(getattr(event, "hashes", None), dict)
|
||||||
|
and isinstance(getattr(event, "iv", None), str))
|
||||||
|
|
||||||
|
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||||
|
info = self._event_source_content(event).get("info")
|
||||||
|
size = info.get("size") if isinstance(info, dict) else None
|
||||||
|
return size if isinstance(size, int) and size >= 0 else None
|
||||||
|
|
||||||
|
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||||
|
info = self._event_source_content(event).get("info")
|
||||||
|
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||||
|
return m
|
||||||
|
m = getattr(event, "mimetype", None)
|
||||||
|
return m if isinstance(m, str) and m else None
|
||||||
|
|
||||||
|
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||||
|
body = getattr(event, "body", None)
|
||||||
|
if isinstance(body, str) and body.strip():
|
||||||
|
if candidate := safe_filename(Path(body).name):
|
||||||
|
return candidate
|
||||||
|
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||||
|
|
||||||
|
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||||
|
filename: str, mime: str | None) -> Path:
|
||||||
|
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||||
|
suffix = Path(safe_name).suffix
|
||||||
|
if not suffix and mime:
|
||||||
|
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||||
|
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||||
|
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||||
|
suffix = suffix[:16]
|
||||||
|
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||||
|
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||||
|
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||||
|
|
||||||
|
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||||
|
if not self.client:
|
||||||
|
return None
|
||||||
|
response = await self.client.download(mxc=mxc_url)
|
||||||
|
if isinstance(response, DownloadError):
|
||||||
|
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||||
|
return None
|
||||||
|
body = getattr(response, "body", None)
|
||||||
|
if isinstance(body, (bytes, bytearray)):
|
||||||
|
return bytes(body)
|
||||||
|
if isinstance(response, MemoryDownloadResponse):
|
||||||
|
return bytes(response.body)
|
||||||
|
if isinstance(body, (str, Path)):
|
||||||
|
path = Path(body)
|
||||||
|
if path.is_file():
|
||||||
|
try:
|
||||||
|
return path.read_bytes()
|
||||||
|
except OSError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||||
|
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||||
|
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||||
|
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||||
|
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||||
|
except (EncryptionError, ValueError, TypeError):
|
||||||
|
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _fetch_media_attachment(
|
||||||
|
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||||
|
) -> tuple[dict[str, Any] | None, str]:
|
||||||
|
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||||
|
atype = self._event_attachment_type(event)
|
||||||
|
mime = self._event_mime(event)
|
||||||
|
filename = self._event_filename(event, atype)
|
||||||
|
mxc_url = getattr(event, "url", None)
|
||||||
|
fail = _ATTACH_FAILED.format(filename)
|
||||||
|
|
||||||
|
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||||
|
return None, fail
|
||||||
|
|
||||||
|
limit_bytes = await self._effective_media_limit_bytes()
|
||||||
|
declared = self._event_declared_size_bytes(event)
|
||||||
|
if declared is not None and declared > limit_bytes:
|
||||||
|
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||||
|
|
||||||
|
downloaded = await self._download_media_bytes(mxc_url)
|
||||||
|
if downloaded is None:
|
||||||
|
return None, fail
|
||||||
|
|
||||||
|
encrypted = self._is_encrypted_media_event(event)
|
||||||
|
data = downloaded
|
||||||
|
if encrypted:
|
||||||
|
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||||
|
return None, fail
|
||||||
|
|
||||||
|
if len(data) > limit_bytes:
|
||||||
|
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||||
|
|
||||||
|
path = self._build_attachment_path(event, atype, filename, mime)
|
||||||
|
try:
|
||||||
|
path.write_bytes(data)
|
||||||
|
except OSError:
|
||||||
|
return None, fail
|
||||||
|
|
||||||
|
attachment = {
|
||||||
|
"type": atype, "mime": mime, "filename": filename,
|
||||||
|
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||||
|
"encrypted": encrypted, "size_bytes": len(data),
|
||||||
|
"path": str(path), "mxc_url": mxc_url,
|
||||||
|
}
|
||||||
|
return attachment, _ATTACH_MARKER.format(path)
|
||||||
|
|
||||||
|
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||||
|
"""Build common metadata for text and media handlers."""
|
||||||
|
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||||
|
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||||
|
meta["event_id"] = eid
|
||||||
|
if thread := self._thread_metadata(event):
|
||||||
|
meta.update(thread)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||||
|
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||||
|
return
|
||||||
|
await self._start_typing_keepalive(room.room_id)
|
||||||
|
try:
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=event.sender, chat_id=room.room_id,
|
||||||
|
content=event.body, metadata=self._base_metadata(room, event),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||||
|
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||||
|
return
|
||||||
|
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||||
|
parts: list[str] = []
|
||||||
|
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||||
|
parts.append(body.strip())
|
||||||
|
parts.append(marker)
|
||||||
|
|
||||||
|
await self._start_typing_keepalive(room.room_id)
|
||||||
|
try:
|
||||||
|
meta = self._base_metadata(room, event)
|
||||||
|
if attachment:
|
||||||
|
meta["attachments"] = [attachment]
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=event.sender, chat_id=room.room_id,
|
||||||
|
content="\n".join(parts),
|
||||||
|
media=[attachment["path"]] if attachment else [],
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||||
|
raise
|
||||||
@@ -31,7 +31,8 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
|||||||
|
|
||||||
class _Bot(botpy.Client):
|
class _Bot(botpy.Client):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(intents=intents)
|
# Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
|
||||||
|
super().__init__(intents=intents, ext_handlers=False)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
logger.info("QQ bot ready: {}", self.robot.name)
|
logger.info("QQ bot ready: {}", self.robot.name)
|
||||||
@@ -100,10 +101,12 @@ class QQChannel(BaseChannel):
|
|||||||
logger.warning("QQ client not initialized")
|
logger.warning("QQ client not initialized")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
msg_id = msg.metadata.get("message_id")
|
||||||
await self._client.api.post_c2c_message(
|
await self._client.api.post_c2c_message(
|
||||||
openid=msg.chat_id,
|
openid=msg.chat_id,
|
||||||
msg_type=0,
|
msg_type=0,
|
||||||
content=msg.content,
|
content=msg.content,
|
||||||
|
msg_id=msg_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending QQ message: {}", e)
|
logger.error("Error sending QQ message: {}", e)
|
||||||
|
|||||||
@@ -229,6 +229,11 @@ class SlackChannel(BaseChannel):
|
|||||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||||
|
|
||||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||||
|
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||||
|
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||||
|
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||||
|
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||||
|
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _to_mrkdwn(cls, text: str) -> str:
|
def _to_mrkdwn(cls, text: str) -> str:
|
||||||
@@ -236,7 +241,26 @@ class SlackChannel(BaseChannel):
|
|||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||||
return slackify_markdown(text)
|
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||||
|
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||||
|
code_blocks: list[str] = []
|
||||||
|
|
||||||
|
def _save_code(m: re.Match) -> str:
|
||||||
|
code_blocks.append(m.group(0))
|
||||||
|
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||||
|
|
||||||
|
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||||
|
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||||
|
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||||
|
|
||||||
|
for i, block in enumerate(code_blocks):
|
||||||
|
text = text.replace(f"\x00CB{i}\x00", block)
|
||||||
|
return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_table(match: re.Match) -> str:
|
def _convert_table(match: re.Match) -> str:
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
BOT_COMMANDS = [
|
BOT_COMMANDS = [
|
||||||
BotCommand("start", "Start the bot"),
|
BotCommand("start", "Start the bot"),
|
||||||
BotCommand("new", "Start a new conversation"),
|
BotCommand("new", "Start a new conversation"),
|
||||||
|
BotCommand("stop", "Stop the current task"),
|
||||||
BotCommand("help", "Show available commands"),
|
BotCommand("help", "Show available commands"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -126,6 +127,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._app: Application | None = None
|
self._app: Application | None = None
|
||||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
@@ -191,6 +194,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
for chat_id in list(self._typing_tasks):
|
for chat_id in list(self._typing_tasks):
|
||||||
self._stop_typing(chat_id)
|
self._stop_typing(chat_id)
|
||||||
|
|
||||||
|
for task in self._media_group_tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
self._media_group_tasks.clear()
|
||||||
|
self._media_group_buffers.clear()
|
||||||
|
|
||||||
if self._app:
|
if self._app:
|
||||||
logger.info("Stopping Telegram bot...")
|
logger.info("Stopping Telegram bot...")
|
||||||
await self._app.updater.stop()
|
await self._app.updater.stop()
|
||||||
@@ -299,6 +307,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
await update.message.reply_text(
|
await update.message.reply_text(
|
||||||
"🐈 nanobot commands:\n"
|
"🐈 nanobot commands:\n"
|
||||||
"/new — Start a new conversation\n"
|
"/new — Start a new conversation\n"
|
||||||
|
"/stop — Stop the current task\n"
|
||||||
"/help — Show available commands"
|
"/help — Show available commands"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -398,6 +407,28 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
|
||||||
|
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||||
|
if media_group_id := getattr(message, "media_group_id", None):
|
||||||
|
key = f"{str_chat_id}:{media_group_id}"
|
||||||
|
if key not in self._media_group_buffers:
|
||||||
|
self._media_group_buffers[key] = {
|
||||||
|
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||||
|
"contents": [], "media": [],
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message.message_id, "user_id": user.id,
|
||||||
|
"username": user.username, "first_name": user.first_name,
|
||||||
|
"is_group": message.chat.type != "private",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self._start_typing(str_chat_id)
|
||||||
|
buf = self._media_group_buffers[key]
|
||||||
|
if content and content != "[empty message]":
|
||||||
|
buf["contents"].append(content)
|
||||||
|
buf["media"].extend(media_paths)
|
||||||
|
if key not in self._media_group_tasks:
|
||||||
|
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||||
|
return
|
||||||
|
|
||||||
# Start typing indicator before processing
|
# Start typing indicator before processing
|
||||||
self._start_typing(str_chat_id)
|
self._start_typing(str_chat_id)
|
||||||
|
|
||||||
@@ -416,6 +447,21 @@ class TelegramChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _flush_media_group(self, key: str) -> None:
|
||||||
|
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.6)
|
||||||
|
if not (buf := self._media_group_buffers.pop(key, None)):
|
||||||
|
return
|
||||||
|
content = "\n".join(buf["contents"]) or "[empty message]"
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||||
|
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||||
|
metadata=buf["metadata"],
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._media_group_tasks.pop(key, None)
|
||||||
|
|
||||||
def _start_typing(self, chat_id: str) -> None:
|
def _start_typing(self, chat_id: str) -> None:
|
||||||
"""Start sending 'typing...' indicator for a chat."""
|
"""Start sending 'typing...' indicator for a chat."""
|
||||||
# Cancel any existing typing task for this chat
|
# Cancel any existing typing task for this chat
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -27,6 +28,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
self.config: WhatsAppConfig = config
|
self.config: WhatsAppConfig = config
|
||||||
self._ws = None
|
self._ws = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||||
@@ -108,6 +110,14 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
# New LID sytle typically:
|
# New LID sytle typically:
|
||||||
sender = data.get("sender", "")
|
sender = data.get("sender", "")
|
||||||
content = data.get("content", "")
|
content = data.get("content", "")
|
||||||
|
message_id = data.get("id", "")
|
||||||
|
|
||||||
|
if message_id:
|
||||||
|
if message_id in self._processed_message_ids:
|
||||||
|
return
|
||||||
|
self._processed_message_ids[message_id] = None
|
||||||
|
while len(self._processed_message_ids) > 1000:
|
||||||
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
# Extract just the phone number or lid as chat_id
|
# Extract just the phone number or lid as chat_id
|
||||||
user_id = pn if pn else sender
|
user_id = pn if pn else sender
|
||||||
@@ -124,7 +134,7 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
chat_id=sender, # Use full LID for replies
|
chat_id=sender, # Use full LID for replies
|
||||||
content=content,
|
content=content,
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": data.get("id"),
|
"message_id": message_id,
|
||||||
"timestamp": data.get("timestamp"),
|
"timestamp": data.get("timestamp"),
|
||||||
"is_group": data.get("isGroup", False)
|
"is_group": data.get("isGroup", False)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from prompt_toolkit.patch_stdout import patch_stdout
|
|||||||
|
|
||||||
from nanobot import __version__, __logo__
|
from nanobot import __version__, __logo__
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
|
from nanobot.utils.helpers import sync_workspace_templates
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="nanobot",
|
name="nanobot",
|
||||||
@@ -185,8 +186,7 @@ def onboard():
|
|||||||
workspace.mkdir(parents=True, exist_ok=True)
|
workspace.mkdir(parents=True, exist_ok=True)
|
||||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||||
|
|
||||||
# Create default bootstrap files
|
sync_workspace_templates(workspace)
|
||||||
_create_workspace_templates(workspace)
|
|
||||||
|
|
||||||
console.print(f"\n{__logo__} nanobot is ready!")
|
console.print(f"\n{__logo__} nanobot is ready!")
|
||||||
console.print("\nNext steps:")
|
console.print("\nNext steps:")
|
||||||
@@ -198,36 +198,6 @@ def onboard():
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _create_workspace_templates(workspace: Path):
|
|
||||||
"""Create default workspace template files from bundled templates."""
|
|
||||||
from importlib.resources import files as pkg_files
|
|
||||||
|
|
||||||
templates_dir = pkg_files("nanobot") / "templates"
|
|
||||||
|
|
||||||
for item in templates_dir.iterdir():
|
|
||||||
if not item.name.endswith(".md"):
|
|
||||||
continue
|
|
||||||
dest = workspace / item.name
|
|
||||||
if not dest.exists():
|
|
||||||
dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
|
|
||||||
console.print(f" [dim]Created {item.name}[/dim]")
|
|
||||||
|
|
||||||
memory_dir = workspace / "memory"
|
|
||||||
memory_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
memory_template = templates_dir / "memory" / "MEMORY.md"
|
|
||||||
memory_file = memory_dir / "MEMORY.md"
|
|
||||||
if not memory_file.exists():
|
|
||||||
memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
|
|
||||||
console.print(" [dim]Created memory/MEMORY.md[/dim]")
|
|
||||||
|
|
||||||
history_file = memory_dir / "HISTORY.md"
|
|
||||||
if not history_file.exists():
|
|
||||||
history_file.write_text("", encoding="utf-8")
|
|
||||||
console.print(" [dim]Created memory/HISTORY.md[/dim]")
|
|
||||||
|
|
||||||
(workspace / "skills").mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
@@ -294,6 +264,7 @@ def gateway(
|
|||||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
|
sync_workspace_templates(config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
provider = _make_provider(config)
|
||||||
session_manager = SessionManager(config.workspace_path)
|
session_manager = SessionManager(config.workspace_path)
|
||||||
@@ -312,6 +283,7 @@ def gateway(
|
|||||||
max_tokens=config.agents.defaults.max_tokens,
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
memory_window=config.agents.defaults.memory_window,
|
||||||
|
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
@@ -447,6 +419,7 @@ def agent(
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
config = load_config()
|
config = load_config()
|
||||||
|
sync_workspace_templates(config.workspace_path)
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
provider = _make_provider(config)
|
||||||
@@ -469,6 +442,7 @@ def agent(
|
|||||||
max_tokens=config.agents.defaults.max_tokens,
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
memory_window=config.agents.defaults.memory_window,
|
||||||
|
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
cron_service=cron,
|
cron_service=cron,
|
||||||
@@ -960,6 +934,7 @@ def cron_run(
|
|||||||
max_tokens=config.agents.defaults.max_tokens,
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
memory_window=config.agents.defaults.memory_window,
|
||||||
|
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Configuration schema using Pydantic."""
|
"""Configuration schema using Pydantic."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from pydantic.alias_generators import to_camel
|
from pydantic.alias_generators import to_camel
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
@@ -40,6 +42,7 @@ class FeishuConfig(Base):
|
|||||||
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
|
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
|
||||||
verification_token: str = "" # Verification Token for event subscription (optional)
|
verification_token: str = "" # Verification Token for event subscription (optional)
|
||||||
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
|
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
|
||||||
|
react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
||||||
|
|
||||||
|
|
||||||
class DingTalkConfig(Base):
|
class DingTalkConfig(Base):
|
||||||
@@ -61,6 +64,23 @@ class DiscordConfig(Base):
|
|||||||
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
|
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
class MatrixConfig(Base):
|
||||||
|
"""Matrix (Element) channel configuration."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
homeserver: str = "https://matrix.org"
|
||||||
|
access_token: str = ""
|
||||||
|
user_id: str = "" # @bot:matrix.org
|
||||||
|
device_id: str = ""
|
||||||
|
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
|
||||||
|
sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
|
||||||
|
max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound).
|
||||||
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
|
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||||
|
group_allow_from: list[str] = Field(default_factory=list)
|
||||||
|
allow_room_mentions: bool = False
|
||||||
|
|
||||||
|
|
||||||
class EmailConfig(Base):
|
class EmailConfig(Base):
|
||||||
"""Email channel configuration (IMAP inbound + SMTP outbound)."""
|
"""Email channel configuration (IMAP inbound + SMTP outbound)."""
|
||||||
|
|
||||||
@@ -164,6 +184,20 @@ class QQConfig(Base):
|
|||||||
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
||||||
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
|
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
|
||||||
|
|
||||||
|
class MatrixConfig(Base):
|
||||||
|
"""Matrix (Element) channel configuration."""
|
||||||
|
enabled: bool = False
|
||||||
|
homeserver: str = "https://matrix.org"
|
||||||
|
access_token: str = ""
|
||||||
|
user_id: str = "" # e.g. @bot:matrix.org
|
||||||
|
device_id: str = ""
|
||||||
|
e2ee_enabled: bool = True # end-to-end encryption support
|
||||||
|
sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
|
||||||
|
max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
|
||||||
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
|
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||||
|
group_allow_from: list[str] = Field(default_factory=list)
|
||||||
|
allow_room_mentions: bool = False
|
||||||
|
|
||||||
class ChannelsConfig(Base):
|
class ChannelsConfig(Base):
|
||||||
"""Configuration for chat channels."""
|
"""Configuration for chat channels."""
|
||||||
@@ -179,6 +213,7 @@ class ChannelsConfig(Base):
|
|||||||
email: EmailConfig = Field(default_factory=EmailConfig)
|
email: EmailConfig = Field(default_factory=EmailConfig)
|
||||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||||
qq: QQConfig = Field(default_factory=QQConfig)
|
qq: QQConfig = Field(default_factory=QQConfig)
|
||||||
|
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
||||||
|
|
||||||
|
|
||||||
class AgentDefaults(Base):
|
class AgentDefaults(Base):
|
||||||
@@ -186,10 +221,12 @@ class AgentDefaults(Base):
|
|||||||
|
|
||||||
workspace: str = "~/.nanobot/workspace"
|
workspace: str = "~/.nanobot/workspace"
|
||||||
model: str = "anthropic/claude-opus-4-5"
|
model: str = "anthropic/claude-opus-4-5"
|
||||||
|
provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 40
|
||||||
memory_window: int = 100
|
memory_window: int = 100
|
||||||
|
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||||
|
|
||||||
|
|
||||||
class AgentsConfig(Base):
|
class AgentsConfig(Base):
|
||||||
@@ -260,6 +297,7 @@ class ExecToolConfig(Base):
|
|||||||
"""Shell exec tool configuration."""
|
"""Shell exec tool configuration."""
|
||||||
|
|
||||||
timeout: int = 60
|
timeout: int = 60
|
||||||
|
path_append: str = ""
|
||||||
|
|
||||||
|
|
||||||
class MCPServerConfig(Base):
|
class MCPServerConfig(Base):
|
||||||
@@ -300,6 +338,11 @@ class Config(BaseSettings):
|
|||||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||||
from nanobot.providers.registry import PROVIDERS
|
from nanobot.providers.registry import PROVIDERS
|
||||||
|
|
||||||
|
forced = self.agents.defaults.provider
|
||||||
|
if forced != "auto":
|
||||||
|
p = getattr(self.providers, forced, None)
|
||||||
|
return (p, forced) if p else (None, None)
|
||||||
|
|
||||||
model_lower = (model or self.agents.defaults.model).lower()
|
model_lower = (model or self.agents.defaults.model).lower()
|
||||||
model_normalized = model_lower.replace("-", "_")
|
model_normalized = model_lower.replace("-", "_")
|
||||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ class LLMProvider(ABC):
|
|||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request.
|
Send a chat completion request.
|
||||||
|
|||||||
@@ -18,13 +18,16 @@ class CustomProvider(LLMProvider):
|
|||||||
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None) -> LLMResponse:
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"messages": self._sanitize_empty_content(messages),
|
"messages": self._sanitize_empty_content(messages),
|
||||||
"max_tokens": max(1, max_tokens),
|
"max_tokens": max(1, max_tokens),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
if reasoning_effort:
|
||||||
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice="auto")
|
kwargs.update(tools=tools, tool_choice="auto")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import json
|
import json
|
||||||
import json_repair
|
import json_repair
|
||||||
import os
|
import os
|
||||||
|
import secrets
|
||||||
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
@@ -12,8 +14,14 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
|
|
||||||
|
|
||||||
# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers.
|
# Standard OpenAI chat-completion message keys plus reasoning_content for
|
||||||
|
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
|
||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
|
def _short_tool_id() -> str:
|
||||||
|
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||||
|
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
@@ -170,6 +178,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Send a chat completion request via LiteLLM.
|
Send a chat completion request via LiteLLM.
|
||||||
@@ -216,6 +225,10 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if self.extra_headers:
|
if self.extra_headers:
|
||||||
kwargs["extra_headers"] = self.extra_headers
|
kwargs["extra_headers"] = self.extra_headers
|
||||||
|
|
||||||
|
if reasoning_effort:
|
||||||
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
kwargs["drop_params"] = True
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
kwargs["tool_choice"] = "auto"
|
kwargs["tool_choice"] = "auto"
|
||||||
@@ -244,7 +257,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=tc.id,
|
id=_short_tool_id(),
|
||||||
name=tc.function.name,
|
name=tc.function.name,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
))
|
))
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
model = model or self.default_model
|
model = model or self.default_model
|
||||||
system_prompt, input_items = _convert_messages(messages)
|
system_prompt, input_items = _convert_messages(messages)
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
# OpenAI Codex: uses OAuth, not API key.
|
# OpenAI Codex: uses OAuth, not API key.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="openai_codex",
|
name="openai_codex",
|
||||||
keywords=("openai-codex", "codex"),
|
keywords=("openai-codex",),
|
||||||
env_key="", # OAuth-based, no API key
|
env_key="", # OAuth-based, no API key
|
||||||
display_name="OpenAI Codex",
|
display_name="OpenAI Codex",
|
||||||
litellm_prefix="", # Not routed through LiteLLM
|
litellm_prefix="", # Not routed through LiteLLM
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ always: true
|
|||||||
## Structure
|
## Structure
|
||||||
|
|
||||||
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
|
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
|
||||||
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep.
|
- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM].
|
||||||
|
|
||||||
## Search Past Events
|
## Search Past Events
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,6 @@
|
|||||||
|
|
||||||
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
||||||
|
|
||||||
## Guidelines
|
|
||||||
|
|
||||||
- Before calling tools, briefly state your intent — but NEVER predict results before receiving them
|
|
||||||
- Use precise tense: "I will run X" before the call, "X returned Y" after
|
|
||||||
- NEVER claim success before a tool result confirms it
|
|
||||||
- Ask for clarification when the request is ambiguous
|
|
||||||
- Remember important information in `memory/MEMORY.md`; past events are logged in `memory/HISTORY.md`
|
|
||||||
|
|
||||||
## Scheduled Reminders
|
## Scheduled Reminders
|
||||||
|
|
||||||
When user asks for a reminder at a specific time, use `exec` to run:
|
When user asks for a reminder at a specific time, use `exec` to run:
|
||||||
|
|||||||
@@ -1,80 +1,67 @@
|
|||||||
"""Utility functions for nanobot."""
|
"""Utility functions for nanobot."""
|
||||||
|
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
def ensure_dir(path: Path) -> Path:
|
def ensure_dir(path: Path) -> Path:
|
||||||
"""Ensure a directory exists, creating it if necessary."""
|
"""Ensure directory exists, return it."""
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def get_data_path() -> Path:
|
def get_data_path() -> Path:
|
||||||
"""Get the nanobot data directory (~/.nanobot)."""
|
"""~/.nanobot data directory."""
|
||||||
return ensure_dir(Path.home() / ".nanobot")
|
return ensure_dir(Path.home() / ".nanobot")
|
||||||
|
|
||||||
|
|
||||||
def get_workspace_path(workspace: str | None = None) -> Path:
|
def get_workspace_path(workspace: str | None = None) -> Path:
|
||||||
"""
|
"""Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
|
||||||
Get the workspace path.
|
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace: Optional workspace path. Defaults to ~/.nanobot/workspace.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Expanded and ensured workspace path.
|
|
||||||
"""
|
|
||||||
if workspace:
|
|
||||||
path = Path(workspace).expanduser()
|
|
||||||
else:
|
|
||||||
path = Path.home() / ".nanobot" / "workspace"
|
|
||||||
return ensure_dir(path)
|
return ensure_dir(path)
|
||||||
|
|
||||||
|
|
||||||
def get_sessions_path() -> Path:
|
|
||||||
"""Get the sessions storage directory."""
|
|
||||||
return ensure_dir(get_data_path() / "sessions")
|
|
||||||
|
|
||||||
|
|
||||||
def get_skills_path(workspace: Path | None = None) -> Path:
|
|
||||||
"""Get the skills directory within the workspace."""
|
|
||||||
ws = workspace or get_workspace_path()
|
|
||||||
return ensure_dir(ws / "skills")
|
|
||||||
|
|
||||||
|
|
||||||
def timestamp() -> str:
|
def timestamp() -> str:
|
||||||
"""Get current timestamp in ISO format."""
|
"""Current ISO timestamp."""
|
||||||
return datetime.now().isoformat()
|
return datetime.now().isoformat()
|
||||||
|
|
||||||
|
|
||||||
def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str:
|
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
||||||
"""Truncate a string to max length, adding suffix if truncated."""
|
|
||||||
if len(s) <= max_len:
|
|
||||||
return s
|
|
||||||
return s[: max_len - len(suffix)] + suffix
|
|
||||||
|
|
||||||
|
|
||||||
def safe_filename(name: str) -> str:
|
def safe_filename(name: str) -> str:
|
||||||
"""Convert a string to a safe filename."""
|
"""Replace unsafe path characters with underscores."""
|
||||||
# Replace unsafe characters
|
return _UNSAFE_CHARS.sub("_", name).strip()
|
||||||
unsafe = '<>:"/\\|?*'
|
|
||||||
for char in unsafe:
|
|
||||||
name = name.replace(char, "_")
|
|
||||||
return name.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def parse_session_key(key: str) -> tuple[str, str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""
|
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||||
Parse a session key into channel and chat_id.
|
from importlib.resources import files as pkg_files
|
||||||
|
try:
|
||||||
|
tpl = pkg_files("nanobot") / "templates"
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
if not tpl.is_dir():
|
||||||
|
return []
|
||||||
|
|
||||||
Args:
|
added: list[str] = []
|
||||||
key: Session key in format "channel:chat_id"
|
|
||||||
|
|
||||||
Returns:
|
def _write(src, dest: Path):
|
||||||
Tuple of (channel, chat_id)
|
if dest.exists():
|
||||||
"""
|
return
|
||||||
parts = key.split(":", 1)
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if len(parts) != 2:
|
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
|
||||||
raise ValueError(f"Invalid session key: {key}")
|
added.append(str(dest.relative_to(workspace)))
|
||||||
return parts[0], parts[1]
|
|
||||||
|
for item in tpl.iterdir():
|
||||||
|
if item.name.endswith(".md"):
|
||||||
|
_write(item, workspace / item.name)
|
||||||
|
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
||||||
|
_write(None, workspace / "memory" / "HISTORY.md")
|
||||||
|
(workspace / "skills").mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
if added and not silent:
|
||||||
|
from rich.console import Console
|
||||||
|
for name in added:
|
||||||
|
Console().print(f" [dim]Created {name}[/dim]")
|
||||||
|
return added
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "nanobot-ai"
|
name = "nanobot-ai"
|
||||||
version = "0.1.4.post1"
|
version = "0.1.4.post2"
|
||||||
description = "A lightweight personal AI assistant framework"
|
description = "A lightweight personal AI assistant framework"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
@@ -45,6 +45,11 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
matrix = [
|
||||||
|
"matrix-nio[e2e]>=0.25.2",
|
||||||
|
"mistune>=3.0.0,<4.0.0",
|
||||||
|
"nh3>=0.2.17,<1.0.0",
|
||||||
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=9.0.0,<10.0.0",
|
"pytest>=9.0.0,<10.0.0",
|
||||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||||
|
|||||||
@@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
|
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
"""/new clears session and returns confirmation."""
|
||||||
) -> None:
|
|
||||||
"""/new should remove lock entry for fully invalidated session key."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
)
|
)
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
@@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
# Ensure lock exists before /new.
|
|
||||||
_ = loop._get_consolidation_lock(session.key)
|
|
||||||
assert session.key in loop._consolidation_locks
|
|
||||||
|
|
||||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "new session started" in response.content.lower()
|
assert "new session started" in response.content.lower()
|
||||||
assert session.key not in loop._consolidation_locks
|
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||||
|
|||||||
@@ -39,8 +39,8 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
|
|||||||
assert prompt1 == prompt2
|
assert prompt1 == prompt2
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
|
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||||
"""Dynamic runtime details should be added at the tail user message, not system."""
|
"""Runtime metadata should be a separate user message before the actual user message."""
|
||||||
workspace = _make_workspace(tmp_path)
|
workspace = _make_workspace(tmp_path)
|
||||||
builder = ContextBuilder(workspace)
|
builder = ContextBuilder(workspace)
|
||||||
|
|
||||||
@@ -54,10 +54,13 @@ def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
|
|||||||
assert messages[0]["role"] == "system"
|
assert messages[0]["role"] == "system"
|
||||||
assert "## Current Session" not in messages[0]["content"]
|
assert "## Current Session" not in messages[0]["content"]
|
||||||
|
|
||||||
|
assert messages[-2]["role"] == "user"
|
||||||
|
runtime_content = messages[-2]["content"]
|
||||||
|
assert isinstance(runtime_content, str)
|
||||||
|
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
|
||||||
|
assert "Current Time:" in runtime_content
|
||||||
|
assert "Channel: cli" in runtime_content
|
||||||
|
assert "Chat ID: direct" in runtime_content
|
||||||
|
|
||||||
assert messages[-1]["role"] == "user"
|
assert messages[-1]["role"] == "user"
|
||||||
user_content = messages[-1]["content"]
|
assert messages[-1]["content"] == "Return exactly: OK"
|
||||||
assert isinstance(user_content, str)
|
|
||||||
assert "Return exactly: OK" in user_content
|
|
||||||
assert "Current Time:" in user_content
|
|
||||||
assert "Channel: cli" in user_content
|
|
||||||
assert "Chat ID: direct" in user_content
|
|
||||||
|
|||||||
@@ -2,34 +2,28 @@ import asyncio
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.heartbeat.service import (
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
HEARTBEAT_OK_TOKEN,
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
HeartbeatService,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_heartbeat_ok_detection() -> None:
|
class DummyProvider:
|
||||||
def is_ok(response: str) -> bool:
|
def __init__(self, responses: list[LLMResponse]):
|
||||||
return HEARTBEAT_OK_TOKEN in response.upper()
|
self._responses = list(responses)
|
||||||
|
|
||||||
assert is_ok("HEARTBEAT_OK")
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
assert is_ok("`HEARTBEAT_OK`")
|
if self._responses:
|
||||||
assert is_ok("**HEARTBEAT_OK**")
|
return self._responses.pop(0)
|
||||||
assert is_ok("heartbeat_ok")
|
return LLMResponse(content="", tool_calls=[])
|
||||||
assert is_ok("HEARTBEAT_OK.")
|
|
||||||
|
|
||||||
assert not is_ok("HEARTBEAT_NOT_OK")
|
|
||||||
assert not is_ok("all good")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_is_idempotent(tmp_path) -> None:
|
async def test_start_is_idempotent(tmp_path) -> None:
|
||||||
async def _on_heartbeat(_: str) -> str:
|
provider = DummyProvider([])
|
||||||
return "HEARTBEAT_OK"
|
|
||||||
|
|
||||||
service = HeartbeatService(
|
service = HeartbeatService(
|
||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
on_heartbeat=_on_heartbeat,
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
interval_s=9999,
|
interval_s=9999,
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
@@ -42,3 +36,82 @@ async def test_start_is_idempotent(tmp_path) -> None:
|
|||||||
|
|
||||||
service.stop()
|
service.stop()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
|
||||||
|
provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
)
|
||||||
|
|
||||||
|
action, tasks = await service._decide("heartbeat content")
|
||||||
|
assert action == "skip"
|
||||||
|
assert tasks == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
|
||||||
|
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||||
|
|
||||||
|
provider = DummyProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "run", "tasks": "check open tasks"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
called_with: list[str] = []
|
||||||
|
|
||||||
|
async def _on_execute(tasks: str) -> str:
|
||||||
|
called_with.append(tasks)
|
||||||
|
return "done"
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
on_execute=_on_execute,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.trigger_now()
|
||||||
|
assert result == "done"
|
||||||
|
assert called_with == ["check open tasks"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
||||||
|
(tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
|
||||||
|
|
||||||
|
provider = DummyProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "skip"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
async def _on_execute(tasks: str) -> str:
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
on_execute=_on_execute,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await service.trigger_now() is None
|
||||||
|
|||||||
1302
tests/test_matrix_channel.py
Normal file
1302
tests/test_matrix_channel.py
Normal file
File diff suppressed because it is too large
Load Diff
10
tests/test_message_tool.py
Normal file
10
tests/test_message_tool.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_tool_returns_error_when_no_target_context() -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
result = await tool.execute(content="test")
|
||||||
|
assert result == "Error: No target channel/chat specified"
|
||||||
103
tests/test_message_tool_suppress.py
Normal file
103
tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolSuppressLogic:
|
||||||
|
"""Final reply suppressed only when message tool sends to the same target."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert result is None # suppressed
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert sent[0].channel == "email"
|
||||||
|
assert result is not None # not suppressed
|
||||||
|
assert result.channel == "feishu"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolTurnTracking:
|
||||||
|
|
||||||
|
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool.set_context("feishu", "chat1")
|
||||||
|
assert not tool._sent_in_turn
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
assert tool._sent_in_turn
|
||||||
|
|
||||||
|
def test_start_turn_resets(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
tool.start_turn()
|
||||||
|
assert not tool._sent_in_turn
|
||||||
167
tests/test_task_cancel.py
Normal file
167
tests/test_task_cancel.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""Tests for /stop task cancellation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop():
|
||||||
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
workspace = MagicMock()
|
||||||
|
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||||
|
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||||
|
return loop, bus
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleStop:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_no_active_task(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "No active task" in out.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_active_task(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_task():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow_task())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
loop._active_tasks["test:c1"] = [task]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
|
||||||
|
assert cancelled.is_set()
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "stopped" in out.content.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_multiple_tasks(self):
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
events = [asyncio.Event(), asyncio.Event()]
|
||||||
|
|
||||||
|
async def slow(idx):
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
events[idx].set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
tasks = [asyncio.create_task(slow(i)) for i in range(2)]
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
loop._active_tasks["test:c1"] = tasks
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
|
||||||
|
await loop._handle_stop(msg)
|
||||||
|
|
||||||
|
assert all(e.is_set() for e in events)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "2 task" in out.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestDispatch:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_processes_and_publishes(self):
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
|
||||||
|
loop._process_message = AsyncMock(
|
||||||
|
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||||
|
)
|
||||||
|
await loop._dispatch(msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert out.content == "hi"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_processing_lock_serializes(self):
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = _make_loop()
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def mock_process(m, **kwargs):
|
||||||
|
order.append(f"start-{m.content}")
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
order.append(f"end-{m.content}")
|
||||||
|
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||||
|
|
||||||
|
loop._process_message = mock_process
|
||||||
|
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||||
|
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||||
|
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||||
|
await asyncio.gather(t1, t2)
|
||||||
|
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentCancellation:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session(self):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
mgr._running_tasks["sub-1"] = task
|
||||||
|
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||||
|
|
||||||
|
count = await mgr.cancel_by_session("test:c1")
|
||||||
|
assert count == 1
|
||||||
|
assert cancelled.is_set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session_no_tasks(self):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
@@ -2,6 +2,7 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
|
||||||
|
|
||||||
class SampleTool(Tool):
|
class SampleTool(Tool):
|
||||||
@@ -86,3 +87,22 @@ async def test_registry_returns_validation_error() -> None:
|
|||||||
reg.register(SampleTool())
|
reg.register(SampleTool())
|
||||||
result = await reg.execute("sample", {"query": "hi"})
|
result = await reg.execute("sample", {"query": "hi"})
|
||||||
assert "Invalid parameters" in result
|
assert "Invalid parameters" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||||
|
cmd = r"type C:\user\workspace\txt"
|
||||||
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
|
assert paths == [r"C:\user\workspace\txt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||||
|
cmd = ".venv/bin/python script.py"
|
||||||
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
|
assert "/bin/python" not in paths
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||||
|
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||||
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
|
assert "/tmp/data.txt" in paths
|
||||||
|
assert "/tmp/out.txt" in paths
|
||||||
|
|||||||
Reference in New Issue
Block a user