diff --git a/README.md b/README.md index d2483e4..0d46b7f 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,13 @@ ⚑️ Delivers core agent functionality in just **~4,000** lines of code β€” **99% smaller** than Clawdbot's 430k+ lines. -πŸ“ Real-time line count: **3,955 lines** (run `bash core_agent_lines.sh` to verify anytime) +πŸ“ Real-time line count: **3,935 lines** (run `bash core_agent_lines.sh` to verify anytime) ## πŸ“’ News +- **2026-02-24** πŸš€ Released **v0.1.4.post2** β€” a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details. +- **2026-02-23** πŸ”§ Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes. +- **2026-02-22** πŸ›‘οΈ Slack thread isolation, Discord typing fix, agent reliability improvements. - **2026-02-21** πŸŽ‰ Released **v0.1.4.post1** β€” new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details. - **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood. - **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode. @@ -135,12 +138,13 @@ Add or merge these **two parts** into your config (other options have defaults). } ``` -*Set your model*: +*Set your model* (optionally pin a provider β€” defaults to auto-detection): ```json { "agents": { "defaults": { - "model": "anthropic/claude-opus-4-5" + "model": "anthropic/claude-opus-4-5", + "provider": "openrouter" } } } @@ -305,6 +309,72 @@ nanobot gateway +
+Matrix (Element) + +Install Matrix dependencies first: + +```bash +pip install nanobot-ai[matrix] +``` + +**1. Create/choose a Matrix account** + +- Create or reuse a Matrix account on your homeserver (for example `matrix.org`). +- Confirm you can log in with Element. + +**2. Get credentials** + +- You need: + - `userId` (example: `@nanobot:matrix.org`) + - `accessToken` + - `deviceId` (recommended so sync tokens can be restored across restarts) +- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings. + +**3. Configure** + +```json +{ + "channels": { + "matrix": { + "enabled": true, + "homeserver": "https://matrix.org", + "userId": "@nanobot:matrix.org", + "accessToken": "syt_xxx", + "deviceId": "NANOBOT01", + "e2eeEnabled": true, + "allowFrom": [], + "groupPolicy": "open", + "groupAllowFrom": [], + "allowRoomMentions": false, + "maxMediaBytes": 20971520 + } + } +} +``` + +> Keep a persistent `matrix-store` and stable `deviceId` β€” encrypted session state is lost if these change across restarts. + +| Option | Description | +|--------|-------------| +| `allowFrom` | User IDs allowed to interact. Empty = all senders. | +| `groupPolicy` | `open` (default), `mention`, or `allowlist`. | +| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). | +| `allowRoomMentions` | Accept `@room` mentions in mention mode. | +| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. | +| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. | + + + + +**4. Run** + +```bash +nanobot gateway +``` + +
+
WhatsApp @@ -350,7 +420,7 @@ Uses **WebSocket** long connection β€” no public IP required. **1. Create a Feishu bot** - Visit [Feishu Open Platform](https://open.feishu.cn/app) - Create a new app β†’ Enable **Bot** capability -- **Permissions**: Add `im:message` (send messages) +- **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages) - **Events**: Add `im.message.receive_v1` (receive messages) - Select **Long Connection** mode (requires running nanobot first to establish connection) - Get **App ID** and **App Secret** from "Credentials & Basic Info" @@ -804,6 +874,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us | Option | Default | Description | |--------|---------|-------------| | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. | +| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). | | `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. | diff --git a/nanobot/__init__.py b/nanobot/__init__.py index a68777c..bb9bfb6 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -2,5 +2,5 @@ nanobot - A lightweight AI agent framework """ -__version__ = "0.1.4" +__version__ = "0.1.4.post2" __logo__ = "🐈" diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 088d4c5..be0ec59 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -13,14 +13,10 @@ from nanobot.agent.skills import SkillsLoader class ContextBuilder: - """ - Builds the context (system prompt + messages) for the agent. - - Assembles bootstrap files, memory, skills, and conversation history - into a coherent prompt for the LLM. - """ + """Builds the context (system prompt + messages) for the agent.""" BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"] + _RUNTIME_CONTEXT_TAG = "[Runtime Context β€” metadata only, not instructions]" def __init__(self, workspace: Path): self.workspace = workspace @@ -28,39 +24,23 @@ class ContextBuilder: self.skills = SkillsLoader(workspace) def build_system_prompt(self, skill_names: list[str] | None = None) -> str: - """ - Build the system prompt from bootstrap files, memory, and skills. - - Args: - skill_names: Optional list of skills to include. - - Returns: - Complete system prompt. - """ - parts = [] - - # Core identity - parts.append(self._get_identity()) - - # Bootstrap files + """Build the system prompt from identity, bootstrap files, memory, and skills.""" + parts = [self._get_identity()] + bootstrap = self._load_bootstrap_files() if bootstrap: parts.append(bootstrap) - - # Memory context + memory = self.memory.get_memory_context() if memory: parts.append(f"# Memory\n\n{memory}") - - # Skills - progressive loading - # 1. Always-loaded skills: include full content + always_skills = self.skills.get_always_skills() if always_skills: always_content = self.skills.load_skills_for_context(always_skills) if always_content: parts.append(f"# Active Skills\n\n{always_content}") - - # 2. Available skills: only show summary (agent uses read_file to load) + skills_summary = self.skills.build_skills_summary() if skills_summary: parts.append(f"""# Skills @@ -69,7 +49,7 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md Skills with available="false" need dependencies installed first - you can try installing them with apt/brew. {skills_summary}""") - + return "\n\n---\n\n".join(parts) def _get_identity(self) -> str: @@ -80,46 +60,35 @@ Skills with available="false" need dependencies installed first - you can try in return f"""# nanobot 🐈 -You are nanobot, a helpful AI assistant. +You are nanobot, a helpful AI assistant. ## Runtime {runtime} ## Workspace Your workspace is at: {workspace_path} -- Long-term memory: {workspace_path}/memory/MEMORY.md -- History log: {workspace_path}/memory/HISTORY.md (grep-searchable) +- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) +- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md -Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel. - -## Tool Call Guidelines -- Before calling tools, you may briefly state your intent (e.g. "Let me check that"), but NEVER predict or describe the expected result before receiving it. -- Before modifying a file, read it first to confirm its current content. -- Do not assume a file or directory exists β€” use list_dir or read_file to verify. +## nanobot Guidelines +- State intent before tool calls, but NEVER predict or claim results before receiving them. +- Before modifying a file, read it first. Do not assume files or directories exist. - After writing or editing a file, re-read it if accuracy matters. - If a tool call fails, analyze the error before retrying with a different approach. +- Ask for clarification when the request is ambiguous. -## Memory -- Remember important facts: write to {workspace_path}/memory/MEMORY.md -- Recall past events: grep {workspace_path}/memory/HISTORY.md""" +Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.""" @staticmethod - def _inject_runtime_context( - user_content: str | list[dict[str, Any]], - channel: str | None, - chat_id: str | None, - ) -> str | list[dict[str, Any]]: - """Append dynamic runtime context to the tail of the user message.""" + def _build_runtime_context(channel: str | None, chat_id: str | None) -> str: + """Build untrusted runtime metadata block for injection before the user message.""" now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") tz = time.strftime("%Z") or "UTC" lines = [f"Current Time: {now} ({tz})"] if channel and chat_id: lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] - block = "[Runtime Context]\n" + "\n".join(lines) - if isinstance(user_content, str): - return f"{user_content}\n\n{block}" - return [*user_content, {"type": "text", "text": block}] + return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) def _load_bootstrap_files(self) -> str: """Load all bootstrap files from workspace.""" @@ -142,35 +111,13 @@ Reply directly with text for conversations. Only use the 'message' tool to send channel: str | None = None, chat_id: str | None = None, ) -> list[dict[str, Any]]: - """ - Build the complete message list for an LLM call. - - Args: - history: Previous conversation messages. - current_message: The new user message. - skill_names: Optional skills to include. - media: Optional list of local file paths for images/media. - channel: Current channel (telegram, feishu, etc.). - chat_id: Current chat/user ID. - - Returns: - List of messages including system prompt. - """ - messages = [] - - # System prompt - system_prompt = self.build_system_prompt(skill_names) - messages.append({"role": "system", "content": system_prompt}) - - # History - messages.extend(history) - - # Current message (with optional image attachments) - user_content = self._build_user_content(current_message, media) - user_content = self._inject_runtime_context(user_content, channel, chat_id) - messages.append({"role": "user", "content": user_content}) - - return messages + """Build the complete message list for an LLM call.""" + return [ + {"role": "system", "content": self.build_system_prompt(skill_names)}, + *history, + {"role": "user", "content": self._build_runtime_context(channel, chat_id)}, + {"role": "user", "content": self._build_user_content(current_message, media)}, + ] def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: """Build user message content with optional base64-encoded images.""" @@ -191,63 +138,24 @@ Reply directly with text for conversations. Only use the 'message' tool to send return images + [{"type": "text", "text": text}] def add_tool_result( - self, - messages: list[dict[str, Any]], - tool_call_id: str, - tool_name: str, - result: str + self, messages: list[dict[str, Any]], + tool_call_id: str, tool_name: str, result: str, ) -> list[dict[str, Any]]: - """ - Add a tool result to the message list. - - Args: - messages: Current message list. - tool_call_id: ID of the tool call. - tool_name: Name of the tool. - result: Tool execution result. - - Returns: - Updated message list. - """ - messages.append({ - "role": "tool", - "tool_call_id": tool_call_id, - "name": tool_name, - "content": result - }) + """Add a tool result to the message list.""" + messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result}) return messages def add_assistant_message( - self, - messages: list[dict[str, Any]], + self, messages: list[dict[str, Any]], content: str | None, tool_calls: list[dict[str, Any]] | None = None, reasoning_content: str | None = None, ) -> list[dict[str, Any]]: - """ - Add an assistant message to the message list. - - Args: - messages: Current message list. - content: Message content. - tool_calls: Optional tool calls. - reasoning_content: Thinking output (Kimi, DeepSeek-R1, etc.). - - Returns: - Updated message list. - """ - msg: dict[str, Any] = {"role": "assistant"} - - # Always include content β€” some providers (e.g. StepFun) reject - # assistant messages that omit the key entirely. - msg["content"] = content - + """Add an assistant message to the message list.""" + msg: dict[str, Any] = {"role": "assistant", "content": content} if tool_calls: msg["tool_calls"] = tool_calls - - # Include reasoning content when provided (required by some thinking models) if reasoning_content is not None: msg["reasoning_content"] = reasoning_content - messages.append(msg) return messages diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 8be8e51..b42c3ba 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import json import re +import weakref from contextlib import AsyncExitStack from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable @@ -43,6 +44,8 @@ class AgentLoop: 5. Sends responses back """ + _TOOL_RESULT_MAX_CHARS = 500 + def __init__( self, bus: MessageBus, @@ -53,6 +56,7 @@ class AgentLoop: temperature: float = 0.1, max_tokens: int = 4096, memory_window: int = 100, + reasoning_effort: str | None = None, brave_api_key: str | None = None, exec_config: ExecToolConfig | None = None, cron_service: CronService | None = None, @@ -71,6 +75,7 @@ class AgentLoop: self.temperature = temperature self.max_tokens = max_tokens self.memory_window = memory_window + self.reasoning_effort = reasoning_effort self.brave_api_key = brave_api_key self.exec_config = exec_config or ExecToolConfig() self.cron_service = cron_service @@ -86,6 +91,7 @@ class AgentLoop: model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, + reasoning_effort=reasoning_effort, brave_api_key=brave_api_key, exec_config=self.exec_config, restrict_to_workspace=restrict_to_workspace, @@ -98,7 +104,9 @@ class AgentLoop: self._mcp_connecting = False self._consolidating: set[str] = set() # Session keys with consolidation in progress self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks - self._consolidation_locks: dict[str, asyncio.Lock] = {} + self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks + self._processing_lock = asyncio.Lock() self._register_default_tools() def _register_default_tools(self) -> None: @@ -110,6 +118,7 @@ class AgentLoop: working_dir=str(self.workspace), timeout=self.exec_config.timeout, restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, )) self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebFetchTool()) @@ -142,17 +151,10 @@ class AgentLoop: def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Update context for all tools that need routing info.""" - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool): - message_tool.set_context(channel, chat_id, message_id) - - if spawn_tool := self.tools.get("spawn"): - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(channel, chat_id) - - if cron_tool := self.tools.get("cron"): - if isinstance(cron_tool, CronTool): - cron_tool.set_context(channel, chat_id) + for name in ("message", "spawn", "cron"): + if tool := self.tools.get(name): + if hasattr(tool, "set_context"): + tool.set_context(channel, chat_id, *([message_id] if name == "message" else [])) @staticmethod def _strip_think(text: str | None) -> str | None: @@ -165,7 +167,8 @@ class AgentLoop: def _tool_hint(tool_calls: list) -> str: """Format tool calls as concise hint, e.g. 'web_search("query")'.""" def _fmt(tc): - val = next(iter(tc.arguments.values()), None) if tc.arguments else None + args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {} + val = next(iter(args.values()), None) if isinstance(args, dict) else None if not isinstance(val, str): return tc.name return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")' @@ -191,6 +194,7 @@ class AgentLoop: model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, + reasoning_effort=self.reasoning_effort, ) if response.has_tool_calls: @@ -225,7 +229,17 @@ class AgentLoop: messages, tool_call.id, tool_call.name, result ) else: - final_content = self._strip_think(response.content) + clean = self._strip_think(response.content) + # Don't persist error responses to session history β€” they can + # poison the context and cause permanent 400 loops (#1303). + if response.finish_reason == "error": + logger.error("LLM returned error: {}", (clean or "")[:200]) + final_content = clean or "Sorry, I encountered an error calling the AI model." + break + messages = self.context.add_assistant_message( + messages, clean, reasoning_content=response.reasoning_content, + ) + final_content = clean break if final_content is None and iteration >= self.max_iterations: @@ -238,35 +252,62 @@ class AgentLoop: return final_content, tools_used, messages async def run(self) -> None: - """Run the agent loop, processing messages from the bus.""" + """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" self._running = True await self._connect_mcp() logger.info("Agent loop started") while self._running: try: - msg = await asyncio.wait_for( - self.bus.consume_inbound(), - timeout=1.0 - ) - try: - response = await self._process_message(msg) - if response is not None: - await self.bus.publish_outbound(response) - elif msg.channel == "cli": - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {}, - )) - except Exception as e: - logger.error("Error processing message: {}", e) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=f"Sorry, I encountered an error: {str(e)}" - )) + msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0) except asyncio.TimeoutError: continue + if msg.content.strip().lower() == "/stop": + await self._handle_stop(msg) + else: + task = asyncio.create_task(self._dispatch(msg)) + self._active_tasks.setdefault(msg.session_key, []).append(task) + task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + + async def _handle_stop(self, msg: InboundMessage) -> None: + """Cancel all active tasks and subagents for the session.""" + tasks = self._active_tasks.pop(msg.session_key, []) + cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) + for t in tasks: + try: + await t + except (asyncio.CancelledError, Exception): + pass + sub_cancelled = await self.subagents.cancel_by_session(msg.session_key) + total = cancelled + sub_cancelled + content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop." + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=content, + )) + + async def _dispatch(self, msg: InboundMessage) -> None: + """Process a message under the global lock.""" + async with self._processing_lock: + try: + response = await self._process_message(msg) + if response is not None: + await self.bus.publish_outbound(response) + elif msg.channel == "cli": + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="", metadata=msg.metadata or {}, + )) + except asyncio.CancelledError: + logger.info("Task cancelled for session {}", msg.session_key) + raise + except Exception: + logger.exception("Error processing message for session {}", msg.session_key) + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="Sorry, I encountered an error.", + )) + async def close_mcp(self) -> None: """Close MCP connections.""" if self._mcp_stack: @@ -281,18 +322,6 @@ class AgentLoop: self._running = False logger.info("Agent loop stopping") - def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock: - lock = self._consolidation_locks.get(session_key) - if lock is None: - lock = asyncio.Lock() - self._consolidation_locks[session_key] = lock - return lock - - def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None: - """Drop lock entry if no longer in use.""" - if not lock.locked(): - self._consolidation_locks.pop(session_key, None) - async def _process_message( self, msg: InboundMessage, @@ -328,7 +357,7 @@ class AgentLoop: # Slash commands cmd = msg.content.strip().lower() if cmd == "/new": - lock = self._get_consolidation_lock(session.key) + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) self._consolidating.add(session.key) try: async with lock: @@ -349,7 +378,6 @@ class AgentLoop: ) finally: self._consolidating.discard(session.key) - self._prune_consolidation_lock(session.key, lock) session.clear() self.sessions.save(session) @@ -358,12 +386,12 @@ class AgentLoop: content="New session started.") if cmd == "/help": return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, - content="🐈 nanobot commands:\n/new β€” Start a new conversation\n/help β€” Show available commands") + content="🐈 nanobot commands:\n/new β€” Start a new conversation\n/stop β€” Stop the current task\n/help β€” Show available commands") unconsolidated = len(session.messages) - session.last_consolidated if (unconsolidated >= self.memory_window and session.key not in self._consolidating): self._consolidating.add(session.key) - lock = self._get_consolidation_lock(session.key) + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) async def _consolidate_and_unlock(): try: @@ -371,7 +399,6 @@ class AgentLoop: await self._consolidate_memory(session) finally: self._consolidating.discard(session.key) - self._prune_consolidation_lock(session.key, lock) _task = asyncio.current_task() if _task is not None: self._consolidation_tasks.discard(_task) @@ -407,32 +434,39 @@ class AgentLoop: if final_content is None: final_content = "I've completed processing but have no response to give." - preview = final_content[:120] + "..." if len(final_content) > 120 else final_content - logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) - self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: - return None + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + return None + preview = final_content[:120] + "..." if len(final_content) > 120 else final_content + logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, metadata=msg.metadata or {}, ) - _TOOL_RESULT_MAX_CHARS = 500 - def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime for m in messages[skip:]: entry = {k: v for k, v in m.items() if k != "reasoning_content"} - if entry.get("role") == "tool" and isinstance(entry.get("content"), str): - content = entry["content"] - if len(content) > self._TOOL_RESULT_MAX_CHARS: - entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + role, content = entry.get("role"), entry.get("content") + if role == "assistant" and not content and not entry.get("tool_calls"): + continue # skip empty assistant messages β€” they poison session context + if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS: + entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + elif role == "user": + if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG): + continue + if isinstance(content, list): + entry["content"] = [ + {"type": "text", "text": "[image]"} if ( + c.get("type") == "image_url" + and c.get("image_url", {}).get("url", "").startswith("data:image/") + ) else c for c in content + ] entry.setdefault("timestamp", datetime.now().isoformat()) session.messages.append(entry) session.updated_at = datetime.now() diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index d87c61a..a99ba4d 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -18,13 +18,7 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool class SubagentManager: - """ - Manages background subagent execution. - - Subagents are lightweight agent instances that run in the background - to handle specific tasks. They share the same LLM provider but have - isolated context and a focused system prompt. - """ + """Manages background subagent execution.""" def __init__( self, @@ -34,6 +28,7 @@ class SubagentManager: model: str | None = None, temperature: float = 0.7, max_tokens: int = 4096, + reasoning_effort: str | None = None, brave_api_key: str | None = None, exec_config: "ExecToolConfig | None" = None, restrict_to_workspace: bool = False, @@ -45,10 +40,12 @@ class SubagentManager: self.model = model or provider.get_default_model() self.temperature = temperature self.max_tokens = max_tokens + self.reasoning_effort = reasoning_effort self.brave_api_key = brave_api_key self.exec_config = exec_config or ExecToolConfig() self.restrict_to_workspace = restrict_to_workspace self._running_tasks: dict[str, asyncio.Task[None]] = {} + self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...} async def spawn( self, @@ -56,35 +53,28 @@ class SubagentManager: label: str | None = None, origin_channel: str = "cli", origin_chat_id: str = "direct", + session_key: str | None = None, ) -> str: - """ - Spawn a subagent to execute a task in the background. - - Args: - task: The task description for the subagent. - label: Optional human-readable label for the task. - origin_channel: The channel to announce results to. - origin_chat_id: The chat ID to announce results to. - - Returns: - Status message indicating the subagent was started. - """ + """Spawn a subagent to execute a task in the background.""" task_id = str(uuid.uuid4())[:8] display_label = label or task[:30] + ("..." if len(task) > 30 else "") - - origin = { - "channel": origin_channel, - "chat_id": origin_chat_id, - } - - # Create background task + origin = {"channel": origin_channel, "chat_id": origin_chat_id} + bg_task = asyncio.create_task( self._run_subagent(task_id, task, display_label, origin) ) self._running_tasks[task_id] = bg_task - - # Cleanup when done - bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None)) + if session_key: + self._session_tasks.setdefault(session_key, set()).add(task_id) + + def _cleanup(_: asyncio.Task) -> None: + self._running_tasks.pop(task_id, None) + if session_key and (ids := self._session_tasks.get(session_key)): + ids.discard(task_id) + if not ids: + del self._session_tasks[session_key] + + bg_task.add_done_callback(_cleanup) logger.info("Spawned subagent [{}]: {}", task_id, display_label) return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes." @@ -111,12 +101,12 @@ class SubagentManager: working_dir=str(self.workspace), timeout=self.exec_config.timeout, restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, )) tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebFetchTool()) - # Build messages with subagent-specific prompt - system_prompt = self._build_subagent_prompt(task) + system_prompt = self._build_subagent_prompt() messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, @@ -136,6 +126,7 @@ class SubagentManager: model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, + reasoning_effort=self.reasoning_effort, ) if response.has_tool_calls: @@ -215,43 +206,38 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men await self.bus.publish_inbound(msg) logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id']) - def _build_subagent_prompt(self, task: str) -> str: + def _build_subagent_prompt(self) -> str: """Build a focused system prompt for the subagent.""" - from datetime import datetime - import time as _time - now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)") - tz = _time.strftime("%Z") or "UTC" + from nanobot.agent.context import ContextBuilder + from nanobot.agent.skills import SkillsLoader - return f"""# Subagent + time_ctx = ContextBuilder._build_runtime_context(None, None) + parts = [f"""# Subagent -## Current Time -{now} ({tz}) +{time_ctx} You are a subagent spawned by the main agent to complete a specific task. - -## Rules -1. Stay focused - complete only the assigned task, nothing else -2. Your final response will be reported back to the main agent -3. Do not initiate conversations or take on side tasks -4. Be concise but informative in your findings - -## What You Can Do -- Read and write files in the workspace -- Execute shell commands -- Search the web and fetch web pages -- Complete the task thoroughly - -## What You Cannot Do -- Send messages directly to users (no message tool available) -- Spawn other subagents -- Access the main agent's conversation history +Stay focused on the assigned task. Your final response will be reported back to the main agent. ## Workspace -Your workspace is at: {self.workspace} -Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed) +{self.workspace}"""] -When you have completed the task, provide a clear summary of your findings or actions.""" + skills_summary = SkillsLoader(self.workspace).build_skills_summary() + if skills_summary: + parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}") + + return "\n\n".join(parts) + async def cancel_by_session(self, session_key: str) -> int: + """Cancel all subagents for the given session. Returns count cancelled.""" + tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, []) + if tid in self._running_tasks and not self._running_tasks[tid].done()] + for t in tasks: + t.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + return len(tasks) + def get_running_count(self) -> int: """Return the number of currently running subagents.""" return len(self._running_tasks) diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 40e76e3..35e519a 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -101,7 +101,8 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._sent_in_turn = True + if channel == self._default_channel and chat_id == self._default_chat_id: + self._sent_in_turn = True media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index e3592a7..6b57874 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -19,6 +19,7 @@ class ExecTool(Tool): deny_patterns: list[str] | None = None, allow_patterns: list[str] | None = None, restrict_to_workspace: bool = False, + path_append: str = "", ): self.timeout = timeout self.working_dir = working_dir @@ -35,6 +36,7 @@ class ExecTool(Tool): ] self.allow_patterns = allow_patterns or [] self.restrict_to_workspace = restrict_to_workspace + self.path_append = path_append @property def name(self) -> str: @@ -67,12 +69,17 @@ class ExecTool(Tool): if guard_error: return guard_error + env = os.environ.copy() + if self.path_append: + env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append + try: process = await asyncio.create_subprocess_shell( command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=cwd, + env=env, ) try: @@ -134,13 +141,7 @@ class ExecTool(Tool): cwd_path = Path(cwd).resolve() - win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) - # Only match absolute paths β€” avoid false positives on relative - # paths like ".venv/bin/python" where "/bin/python" would be - # incorrectly extracted by the old pattern. - posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd) - - for raw in win_paths + posix_paths: + for raw in self._extract_absolute_paths(cmd): try: p = Path(raw.strip()).resolve() except Exception: @@ -149,3 +150,9 @@ class ExecTool(Tool): return "Error: Command blocked by safety guard (path outside working dir)" return None + + @staticmethod + def _extract_absolute_paths(command: str) -> list[str]: + win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\... + posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only + return win_paths + posix_paths diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py index 33cf8e7..fb816ca 100644 --- a/nanobot/agent/tools/spawn.py +++ b/nanobot/agent/tools/spawn.py @@ -15,11 +15,13 @@ class SpawnTool(Tool): self._manager = manager self._origin_channel = "cli" self._origin_chat_id = "direct" + self._session_key = "cli:direct" def set_context(self, channel: str, chat_id: str) -> None: """Set the origin context for subagent announcements.""" self._origin_channel = channel self._origin_chat_id = chat_id + self._session_key = f"{channel}:{chat_id}" @property def name(self) -> str: @@ -57,4 +59,5 @@ class SpawnTool(Tool): label=label, origin_channel=self._origin_channel, origin_chat_id=self._origin_chat_id, + session_key=self._session_key, ) diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 56956c3..7860f12 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -80,7 +80,7 @@ class WebSearchTool(Tool): r = await client.get( "https://api.search.brave.com/res/v1/web/search", params={"q": query, "count": n}, - headers={"Accept": "application/json", "X-Subscription-Token": api_key}, + headers={"Accept": "application/json", "X-Subscription-Token": self.api_key}, timeout=10.0 ) r.raise_for_status() diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 09c7714..2797029 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -2,8 +2,12 @@ import asyncio import json +import mimetypes +import os import time +from pathlib import Path from typing import Any +from urllib.parse import unquote, urlparse from loguru import logger import httpx @@ -96,6 +100,9 @@ class DingTalkChannel(BaseChannel): """ name = "dingtalk" + _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} + _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"} + _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"} def __init__(self, config: DingTalkConfig, bus: MessageBus): super().__init__(config, bus) @@ -191,40 +198,224 @@ class DingTalkChannel(BaseChannel): logger.error("Failed to get DingTalk access token: {}", e) return None + @staticmethod + def _is_http_url(value: str) -> bool: + return urlparse(value).scheme in ("http", "https") + + def _guess_upload_type(self, media_ref: str) -> str: + ext = Path(urlparse(media_ref).path).suffix.lower() + if ext in self._IMAGE_EXTS: return "image" + if ext in self._AUDIO_EXTS: return "voice" + if ext in self._VIDEO_EXTS: return "video" + return "file" + + def _guess_filename(self, media_ref: str, upload_type: str) -> str: + name = os.path.basename(urlparse(media_ref).path) + return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin") + + async def _read_media_bytes( + self, + media_ref: str, + ) -> tuple[bytes | None, str | None, str | None]: + if not media_ref: + return None, None, None + + if self._is_http_url(media_ref): + if not self._http: + return None, None, None + try: + resp = await self._http.get(media_ref, follow_redirects=True) + if resp.status_code >= 400: + logger.warning( + "DingTalk media download failed status={} ref={}", + resp.status_code, + media_ref, + ) + return None, None, None + content_type = (resp.headers.get("content-type") or "").split(";")[0].strip() + filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) + return resp.content, filename, content_type or None + except Exception as e: + logger.error("DingTalk media download error ref={} err={}", media_ref, e) + return None, None, None + + try: + if media_ref.startswith("file://"): + parsed = urlparse(media_ref) + local_path = Path(unquote(parsed.path)) + else: + local_path = Path(os.path.expanduser(media_ref)) + if not local_path.is_file(): + logger.warning("DingTalk media file not found: {}", local_path) + return None, None, None + data = await asyncio.to_thread(local_path.read_bytes) + content_type = mimetypes.guess_type(local_path.name)[0] + return data, local_path.name, content_type + except Exception as e: + logger.error("DingTalk media read error ref={} err={}", media_ref, e) + return None, None, None + + async def _upload_media( + self, + token: str, + data: bytes, + media_type: str, + filename: str, + content_type: str | None, + ) -> str | None: + if not self._http: + return None + url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}" + mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" + files = {"media": (filename, data, mime)} + + try: + resp = await self._http.post(url, files=files) + text = resp.text + result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {} + if resp.status_code >= 400: + logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500]) + return None + errcode = result.get("errcode", 0) + if errcode != 0: + logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500]) + return None + sub = result.get("result") or {} + media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId") + if not media_id: + logger.error("DingTalk media upload missing media_id body={}", text[:500]) + return None + return str(media_id) + except Exception as e: + logger.error("DingTalk media upload error type={} err={}", media_type, e) + return None + + async def _send_batch_message( + self, + token: str, + chat_id: str, + msg_key: str, + msg_param: dict[str, Any], + ) -> bool: + if not self._http: + logger.warning("DingTalk HTTP client not initialized, cannot send") + return False + + url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend" + headers = {"x-acs-dingtalk-access-token": token} + payload = { + "robotCode": self.config.client_id, + "userIds": [chat_id], + "msgKey": msg_key, + "msgParam": json.dumps(msg_param, ensure_ascii=False), + } + + try: + resp = await self._http.post(url, json=payload, headers=headers) + body = resp.text + if resp.status_code != 200: + logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500]) + return False + try: result = resp.json() + except Exception: result = {} + errcode = result.get("errcode") + if errcode not in (None, 0): + logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500]) + return False + logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key) + return True + except Exception as e: + logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e) + return False + + async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool: + return await self._send_batch_message( + token, + chat_id, + "sampleMarkdown", + {"text": content, "title": "Nanobot Reply"}, + ) + + async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool: + media_ref = (media_ref or "").strip() + if not media_ref: + return True + + upload_type = self._guess_upload_type(media_ref) + if upload_type == "image" and self._is_http_url(media_ref): + ok = await self._send_batch_message( + token, + chat_id, + "sampleImageMsg", + {"photoURL": media_ref}, + ) + if ok: + return True + logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref) + + data, filename, content_type = await self._read_media_bytes(media_ref) + if not data: + logger.error("DingTalk media read failed: {}", media_ref) + return False + + filename = filename or self._guess_filename(media_ref, upload_type) + file_type = Path(filename).suffix.lower().lstrip(".") + if not file_type: + guessed = mimetypes.guess_extension(content_type or "") + file_type = (guessed or ".bin").lstrip(".") + if file_type == "jpeg": + file_type = "jpg" + + media_id = await self._upload_media( + token=token, + data=data, + media_type=upload_type, + filename=filename, + content_type=content_type, + ) + if not media_id: + return False + + if upload_type == "image": + # Verified in production: sampleImageMsg accepts media_id in photoURL. + ok = await self._send_batch_message( + token, + chat_id, + "sampleImageMsg", + {"photoURL": media_id}, + ) + if ok: + return True + logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref) + + return await self._send_batch_message( + token, + chat_id, + "sampleFile", + {"mediaId": media_id, "fileName": filename, "fileType": file_type}, + ) + async def send(self, msg: OutboundMessage) -> None: """Send a message through DingTalk.""" token = await self._get_access_token() if not token: return - # oToMessages/batchSend: sends to individual users (private chat) - # https://open.dingtalk.com/document/orgapp/robot-batch-send-messages - url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend" + if msg.content and msg.content.strip(): + await self._send_markdown_text(token, msg.chat_id, msg.content.strip()) - headers = {"x-acs-dingtalk-access-token": token} - - data = { - "robotCode": self.config.client_id, - "userIds": [msg.chat_id], # chat_id is the user's staffId - "msgKey": "sampleMarkdown", - "msgParam": json.dumps({ - "text": msg.content, - "title": "Nanobot Reply", - }, ensure_ascii=False), - } - - if not self._http: - logger.warning("DingTalk HTTP client not initialized, cannot send") - return - - try: - resp = await self._http.post(url, json=data, headers=headers) - if resp.status_code != 200: - logger.error("DingTalk send failed: {}", resp.text) - else: - logger.debug("DingTalk message sent to {}", msg.chat_id) - except Exception as e: - logger.error("Error sending DingTalk message: {}", e) + for media_ref in msg.media or []: + ok = await self._send_media_ref(token, msg.chat_id, media_ref) + if ok: + continue + logger.error("DingTalk media send failed for {}", media_ref) + # Send visible fallback so failures are observable by the user. + filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref)) + await self._send_markdown_text( + token, + msg.chat_id, + f"[Attachment send failed: {filename}]", + ) async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None: """Handle incoming message (called by NanobotDingTalkHandler). diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 480bf7b..c632fb7 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -89,8 +89,9 @@ def _extract_interactive_content(content: dict) -> list[str]: elif isinstance(title, str): parts.append(f"title: {title}") - for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []: - parts.extend(_extract_element_content(element)) + for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []: + for element in elements: + parts.extend(_extract_element_content(element)) card = content.get("card", {}) if card: @@ -325,13 +326,14 @@ class FeishuChannel(BaseChannel): await asyncio.sleep(1) async def stop(self) -> None: - """Stop the Feishu bot.""" + """ + Stop the Feishu bot. + + Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client. + + Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86 + """ self._running = False - if self._ws_client: - try: - self._ws_client.stop() - except Exception as e: - logger.warning("Error stopping WebSocket client: {}", e) logger.info("Feishu bot stopped") def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: @@ -692,7 +694,7 @@ class FeishuChannel(BaseChannel): msg_type = message.message_type # Add reaction - await self._add_reaction(message_id, "THUMBSUP") + await self._add_reaction(message_id, self.config.react_emoji) # Parse content content_parts = [] diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 77b7294..c8df6b2 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -136,6 +136,18 @@ class ChannelManager: logger.info("QQ channel enabled") except ImportError as e: logger.warning("QQ channel not available: {}", e) + + # Matrix channel + if self.config.channels.matrix.enabled: + try: + from nanobot.channels.matrix import MatrixChannel + self.channels["matrix"] = MatrixChannel( + self.config.channels.matrix, + self.bus, + ) + logger.info("Matrix channel enabled") + except ImportError as e: + logger.warning("Matrix channel not available: {}", e) async def _start_channel(self, name: str, channel: BaseChannel) -> None: """Start a channel and log any exceptions.""" diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py new file mode 100644 index 0000000..21192e9 --- /dev/null +++ b/nanobot/channels/matrix.py @@ -0,0 +1,682 @@ +"""Matrix (Element) channel β€” inbound sync + outbound message/media delivery.""" + +import asyncio +import logging +import mimetypes +from pathlib import Path +from typing import Any, TypeAlias + +from loguru import logger + +try: + import nh3 + from mistune import create_markdown + from nio import ( + AsyncClient, AsyncClientConfig, ContentRepositoryConfigError, + DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse, + RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText, + RoomSendError, RoomTypingError, SyncError, UploadError, + ) + from nio.crypto.attachments import decrypt_attachment + from nio.exceptions import EncryptionError +except ImportError as e: + raise ImportError( + "Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]" + ) from e + +from nanobot.bus.events import OutboundMessage +from nanobot.channels.base import BaseChannel +from nanobot.config.loader import get_data_dir +from nanobot.utils.helpers import safe_filename + +TYPING_NOTICE_TIMEOUT_MS = 30_000 +# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing. +TYPING_KEEPALIVE_INTERVAL_MS = 20_000 +MATRIX_HTML_FORMAT = "org.matrix.custom.html" +_ATTACH_MARKER = "[attachment: {}]" +_ATTACH_TOO_LARGE = "[attachment: {} - too large]" +_ATTACH_FAILED = "[attachment: {} - download failed]" +_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]" +_DEFAULT_ATTACH_NAME = "attachment" +_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"} + +MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia) +MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia + +MATRIX_MARKDOWN = create_markdown( + escape=True, + plugins=["table", "strikethrough", "url", "superscript", "subscript"], +) + +MATRIX_ALLOWED_HTML_TAGS = { + "p", "a", "strong", "em", "del", "code", "pre", "blockquote", + "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6", + "hr", "br", "table", "thead", "tbody", "tr", "th", "td", + "caption", "sup", "sub", "img", +} +MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = { + "a": {"href"}, "code": {"class"}, "ol": {"start"}, + "img": {"src", "alt", "title", "width", "height"}, +} +MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"} + + +def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None: + """Filter attribute values to a safe Matrix-compatible subset.""" + if tag == "a" and attr == "href": + return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None + if tag == "img" and attr == "src": + return value if value.lower().startswith("mxc://") else None + if tag == "code" and attr == "class": + classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")] + return " ".join(classes) if classes else None + return value + + +MATRIX_HTML_CLEANER = nh3.Cleaner( + tags=MATRIX_ALLOWED_HTML_TAGS, + attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES, + attribute_filter=_filter_matrix_html_attribute, + url_schemes=MATRIX_ALLOWED_URL_SCHEMES, + strip_comments=True, + link_rel="noopener noreferrer", +) + + +def _render_markdown_html(text: str) -> str | None: + """Render markdown to sanitized HTML; returns None for plain text.""" + try: + formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip() + except Exception: + return None + if not formatted: + return None + # Skip formatted_body for plain

text

to keep payload minimal. + if formatted.startswith("

") and formatted.endswith("

"): + inner = formatted[3:-4] + if "<" not in inner and ">" not in inner: + return None + return formatted + + +def _build_matrix_text_content(text: str) -> dict[str, object]: + """Build Matrix m.text payload with optional HTML formatted_body.""" + content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}} + if html := _render_markdown_html(text): + content["format"] = MATRIX_HTML_FORMAT + content["formatted_body"] = html + return content + + +class _NioLoguruHandler(logging.Handler): + """Route matrix-nio stdlib logs into Loguru.""" + + def emit(self, record: logging.LogRecord) -> None: + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + frame, depth = logging.currentframe(), 2 + while frame and frame.f_code.co_filename == logging.__file__: + frame, depth = frame.f_back, depth + 1 + logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + + +def _configure_nio_logging_bridge() -> None: + """Bridge matrix-nio logs to Loguru (idempotent).""" + nio_logger = logging.getLogger("nio") + if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers): + nio_logger.handlers = [_NioLoguruHandler()] + nio_logger.propagate = False + + +class MatrixChannel(BaseChannel): + """Matrix (Element) channel using long-polling sync.""" + + name = "matrix" + + def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False, + workspace: Path | None = None): + super().__init__(config, bus) + self.client: AsyncClient | None = None + self._sync_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._restrict_to_workspace = restrict_to_workspace + self._workspace = workspace.expanduser().resolve() if workspace else None + self._server_upload_limit_bytes: int | None = None + self._server_upload_limit_checked = False + + async def start(self) -> None: + """Start Matrix client and begin sync loop.""" + self._running = True + _configure_nio_logging_bridge() + + store_path = get_data_dir() / "matrix-store" + store_path.mkdir(parents=True, exist_ok=True) + + self.client = AsyncClient( + homeserver=self.config.homeserver, user=self.config.user_id, + store_path=store_path, + config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled), + ) + self.client.user_id = self.config.user_id + self.client.access_token = self.config.access_token + self.client.device_id = self.config.device_id + + self._register_event_callbacks() + self._register_response_callbacks() + + if not self.config.e2ee_enabled: + logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.") + + if self.config.device_id: + try: + self.client.load_store() + except Exception: + logger.exception("Matrix store load failed; restart may replay recent messages.") + else: + logger.warning("Matrix device_id empty; restart may replay recent messages.") + + self._sync_task = asyncio.create_task(self._sync_loop()) + + async def stop(self) -> None: + """Stop the Matrix channel with graceful sync shutdown.""" + self._running = False + for room_id in list(self._typing_tasks): + await self._stop_typing_keepalive(room_id, clear_typing=False) + if self.client: + self.client.stop_sync_forever() + if self._sync_task: + try: + await asyncio.wait_for(asyncio.shield(self._sync_task), + timeout=self.config.sync_stop_grace_seconds) + except (asyncio.TimeoutError, asyncio.CancelledError): + self._sync_task.cancel() + try: + await self._sync_task + except asyncio.CancelledError: + pass + if self.client: + await self.client.close() + + def _is_workspace_path_allowed(self, path: Path) -> bool: + """Check path is inside workspace (when restriction enabled).""" + if not self._restrict_to_workspace or not self._workspace: + return True + try: + path.resolve(strict=False).relative_to(self._workspace) + return True + except ValueError: + return False + + def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]: + """Deduplicate and resolve outbound attachment paths.""" + seen: set[str] = set() + candidates: list[Path] = [] + for raw in media: + if not isinstance(raw, str) or not raw.strip(): + continue + path = Path(raw.strip()).expanduser() + try: + key = str(path.resolve(strict=False)) + except OSError: + key = str(path) + if key not in seen: + seen.add(key) + candidates.append(path) + return candidates + + @staticmethod + def _build_outbound_attachment_content( + *, filename: str, mime: str, size_bytes: int, + mxc_url: str, encryption_info: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Build Matrix content payload for an uploaded file/image/audio/video.""" + prefix = mime.split("/")[0] + msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file") + content: dict[str, Any] = { + "msgtype": msgtype, "body": filename, "filename": filename, + "info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {}, + } + if encryption_info: + content["file"] = {**encryption_info, "url": mxc_url} + else: + content["url"] = mxc_url + return content + + def _is_encrypted_room(self, room_id: str) -> bool: + if not self.client: + return False + room = getattr(self.client, "rooms", {}).get(room_id) + return bool(getattr(room, "encrypted", False)) + + async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None: + """Send m.room.message with E2EE options.""" + if not self.client: + return + kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content} + if self.config.e2ee_enabled: + kwargs["ignore_unverified_devices"] = True + await self.client.room_send(**kwargs) + + async def _resolve_server_upload_limit_bytes(self) -> int | None: + """Query homeserver upload limit once per channel lifecycle.""" + if self._server_upload_limit_checked: + return self._server_upload_limit_bytes + self._server_upload_limit_checked = True + if not self.client: + return None + try: + response = await self.client.content_repository_config() + except Exception: + return None + upload_size = getattr(response, "upload_size", None) + if isinstance(upload_size, int) and upload_size > 0: + self._server_upload_limit_bytes = upload_size + return upload_size + return None + + async def _effective_media_limit_bytes(self) -> int: + """min(local config, server advertised) β€” 0 blocks all uploads.""" + local_limit = max(int(self.config.max_media_bytes), 0) + server_limit = await self._resolve_server_upload_limit_bytes() + if server_limit is None: + return local_limit + return min(local_limit, server_limit) if local_limit else 0 + + async def _upload_and_send_attachment( + self, room_id: str, path: Path, limit_bytes: int, + relates_to: dict[str, Any] | None = None, + ) -> str | None: + """Upload one local file to Matrix and send it as a media message. Returns failure marker or None.""" + if not self.client: + return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME) + + resolved = path.expanduser().resolve(strict=False) + filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME + fail = _ATTACH_UPLOAD_FAILED.format(filename) + + if not resolved.is_file() or not self._is_workspace_path_allowed(resolved): + return fail + try: + size_bytes = resolved.stat().st_size + except OSError: + return fail + if limit_bytes <= 0 or size_bytes > limit_bytes: + return _ATTACH_TOO_LARGE.format(filename) + + mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream" + try: + with resolved.open("rb") as f: + upload_result = await self.client.upload( + f, content_type=mime, filename=filename, + encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id), + filesize=size_bytes, + ) + except Exception: + return fail + + upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result + encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None + if isinstance(upload_response, UploadError): + return fail + mxc_url = getattr(upload_response, "content_uri", None) + if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"): + return fail + + content = self._build_outbound_attachment_content( + filename=filename, mime=mime, size_bytes=size_bytes, + mxc_url=mxc_url, encryption_info=encryption_info, + ) + if relates_to: + content["m.relates_to"] = relates_to + try: + await self._send_room_content(room_id, content) + except Exception: + return fail + return None + + async def send(self, msg: OutboundMessage) -> None: + """Send outbound content; clear typing for non-progress messages.""" + if not self.client: + return + text = msg.content or "" + candidates = self._collect_outbound_media_candidates(msg.media) + relates_to = self._build_thread_relates_to(msg.metadata) + is_progress = bool((msg.metadata or {}).get("_progress")) + try: + failures: list[str] = [] + if candidates: + limit_bytes = await self._effective_media_limit_bytes() + for path in candidates: + if fail := await self._upload_and_send_attachment( + msg.chat_id, path, limit_bytes, relates_to): + failures.append(fail) + if failures: + text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures) + if text or not candidates: + content = _build_matrix_text_content(text) + if relates_to: + content["m.relates_to"] = relates_to + await self._send_room_content(msg.chat_id, content) + finally: + if not is_progress: + await self._stop_typing_keepalive(msg.chat_id, clear_typing=True) + + def _register_event_callbacks(self) -> None: + self.client.add_event_callback(self._on_message, RoomMessageText) + self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) + self.client.add_event_callback(self._on_room_invite, InviteEvent) + + def _register_response_callbacks(self) -> None: + self.client.add_response_callback(self._on_sync_error, SyncError) + self.client.add_response_callback(self._on_join_error, JoinError) + self.client.add_response_callback(self._on_send_error, RoomSendError) + + def _log_response_error(self, label: str, response: Any) -> None: + """Log Matrix response errors β€” auth errors at ERROR level, rest at WARNING.""" + code = getattr(response, "status_code", None) + is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"} + is_fatal = is_auth or getattr(response, "soft_logout", False) + (logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response) + + async def _on_sync_error(self, response: SyncError) -> None: + self._log_response_error("sync", response) + + async def _on_join_error(self, response: JoinError) -> None: + self._log_response_error("join", response) + + async def _on_send_error(self, response: RoomSendError) -> None: + self._log_response_error("send", response) + + async def _set_typing(self, room_id: str, typing: bool) -> None: + """Best-effort typing indicator update.""" + if not self.client: + return + try: + response = await self.client.room_typing(room_id=room_id, typing_state=typing, + timeout=TYPING_NOTICE_TIMEOUT_MS) + if isinstance(response, RoomTypingError): + logger.debug("Matrix typing failed for {}: {}", room_id, response) + except Exception: + pass + + async def _start_typing_keepalive(self, room_id: str) -> None: + """Start periodic typing refresh (spec-recommended keepalive).""" + await self._stop_typing_keepalive(room_id, clear_typing=False) + await self._set_typing(room_id, True) + if not self._running: + return + + async def loop() -> None: + try: + while self._running: + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000) + await self._set_typing(room_id, True) + except asyncio.CancelledError: + pass + + self._typing_tasks[room_id] = asyncio.create_task(loop()) + + async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None: + if task := self._typing_tasks.pop(room_id, None): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if clear_typing: + await self._set_typing(room_id, False) + + async def _sync_loop(self) -> None: + while self._running: + try: + await self.client.sync_forever(timeout=30000, full_state=True) + except asyncio.CancelledError: + break + except Exception: + await asyncio.sleep(2) + + async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None: + allow_from = self.config.allow_from or [] + if not allow_from or event.sender in allow_from: + await self.client.join(room.room_id) + + def _is_direct_room(self, room: MatrixRoom) -> bool: + count = getattr(room, "member_count", None) + return isinstance(count, int) and count <= 2 + + def _is_bot_mentioned(self, event: RoomMessage) -> bool: + """Check m.mentions payload for bot mention.""" + source = getattr(event, "source", None) + if not isinstance(source, dict): + return False + mentions = (source.get("content") or {}).get("m.mentions") + if not isinstance(mentions, dict): + return False + user_ids = mentions.get("user_ids") + if isinstance(user_ids, list) and self.config.user_id in user_ids: + return True + return bool(self.config.allow_room_mentions and mentions.get("room") is True) + + def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool: + """Apply sender and room policy checks.""" + if not self.is_allowed(event.sender): + return False + if self._is_direct_room(room): + return True + policy = self.config.group_policy + if policy == "open": + return True + if policy == "allowlist": + return room.room_id in (self.config.group_allow_from or []) + if policy == "mention": + return self._is_bot_mentioned(event) + return False + + def _media_dir(self) -> Path: + d = get_data_dir() / "media" / "matrix" + d.mkdir(parents=True, exist_ok=True) + return d + + @staticmethod + def _event_source_content(event: RoomMessage) -> dict[str, Any]: + source = getattr(event, "source", None) + if not isinstance(source, dict): + return {} + content = source.get("content") + return content if isinstance(content, dict) else {} + + def _event_thread_root_id(self, event: RoomMessage) -> str | None: + relates_to = self._event_source_content(event).get("m.relates_to") + if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread": + return None + root_id = relates_to.get("event_id") + return root_id if isinstance(root_id, str) and root_id else None + + def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None: + if not (root_id := self._event_thread_root_id(event)): + return None + meta: dict[str, str] = {"thread_root_event_id": root_id} + if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to: + meta["thread_reply_to_event_id"] = reply_to + return meta + + @staticmethod + def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None: + if not metadata: + return None + root_id = metadata.get("thread_root_event_id") + if not isinstance(root_id, str) or not root_id: + return None + reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id") + if not isinstance(reply_to, str) or not reply_to: + return None + return {"rel_type": "m.thread", "event_id": root_id, + "m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True} + + def _event_attachment_type(self, event: MatrixMediaEvent) -> str: + msgtype = self._event_source_content(event).get("msgtype") + return _MSGTYPE_MAP.get(msgtype, "file") + + @staticmethod + def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool: + return (isinstance(getattr(event, "key", None), dict) + and isinstance(getattr(event, "hashes", None), dict) + and isinstance(getattr(event, "iv", None), str)) + + def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None: + info = self._event_source_content(event).get("info") + size = info.get("size") if isinstance(info, dict) else None + return size if isinstance(size, int) and size >= 0 else None + + def _event_mime(self, event: MatrixMediaEvent) -> str | None: + info = self._event_source_content(event).get("info") + if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m: + return m + m = getattr(event, "mimetype", None) + return m if isinstance(m, str) and m else None + + def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str: + body = getattr(event, "body", None) + if isinstance(body, str) and body.strip(): + if candidate := safe_filename(Path(body).name): + return candidate + return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type + + def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str, + filename: str, mime: str | None) -> Path: + safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME + suffix = Path(safe_name).suffix + if not suffix and mime: + if guessed := mimetypes.guess_extension(mime, strict=False): + safe_name, suffix = f"{safe_name}{guessed}", guessed + stem = (Path(safe_name).stem or attachment_type)[:72] + suffix = suffix[:16] + event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$")) + event_prefix = (event_id[:24] or "evt").strip("_") + return self._media_dir() / f"{event_prefix}_{stem}{suffix}" + + async def _download_media_bytes(self, mxc_url: str) -> bytes | None: + if not self.client: + return None + response = await self.client.download(mxc=mxc_url) + if isinstance(response, DownloadError): + logger.warning("Matrix download failed for {}: {}", mxc_url, response) + return None + body = getattr(response, "body", None) + if isinstance(body, (bytes, bytearray)): + return bytes(body) + if isinstance(response, MemoryDownloadResponse): + return bytes(response.body) + if isinstance(body, (str, Path)): + path = Path(body) + if path.is_file(): + try: + return path.read_bytes() + except OSError: + return None + return None + + def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None: + key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None) + key = key_obj.get("k") if isinstance(key_obj, dict) else None + sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None + if not all(isinstance(v, str) for v in (key, sha256, iv)): + return None + try: + return decrypt_attachment(ciphertext, key, sha256, iv) + except (EncryptionError, ValueError, TypeError): + logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", "")) + return None + + async def _fetch_media_attachment( + self, room: MatrixRoom, event: MatrixMediaEvent, + ) -> tuple[dict[str, Any] | None, str]: + """Download, decrypt if needed, and persist a Matrix attachment.""" + atype = self._event_attachment_type(event) + mime = self._event_mime(event) + filename = self._event_filename(event, atype) + mxc_url = getattr(event, "url", None) + fail = _ATTACH_FAILED.format(filename) + + if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"): + return None, fail + + limit_bytes = await self._effective_media_limit_bytes() + declared = self._event_declared_size_bytes(event) + if declared is not None and declared > limit_bytes: + return None, _ATTACH_TOO_LARGE.format(filename) + + downloaded = await self._download_media_bytes(mxc_url) + if downloaded is None: + return None, fail + + encrypted = self._is_encrypted_media_event(event) + data = downloaded + if encrypted: + if (data := self._decrypt_media_bytes(event, downloaded)) is None: + return None, fail + + if len(data) > limit_bytes: + return None, _ATTACH_TOO_LARGE.format(filename) + + path = self._build_attachment_path(event, atype, filename, mime) + try: + path.write_bytes(data) + except OSError: + return None, fail + + attachment = { + "type": atype, "mime": mime, "filename": filename, + "event_id": str(getattr(event, "event_id", "") or ""), + "encrypted": encrypted, "size_bytes": len(data), + "path": str(path), "mxc_url": mxc_url, + } + return attachment, _ATTACH_MARKER.format(path) + + def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]: + """Build common metadata for text and media handlers.""" + meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)} + if isinstance(eid := getattr(event, "event_id", None), str) and eid: + meta["event_id"] = eid + if thread := self._thread_metadata(event): + meta.update(thread) + return meta + + async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None: + if event.sender == self.config.user_id or not self._should_process_message(room, event): + return + await self._start_typing_keepalive(room.room_id) + try: + await self._handle_message( + sender_id=event.sender, chat_id=room.room_id, + content=event.body, metadata=self._base_metadata(room, event), + ) + except Exception: + await self._stop_typing_keepalive(room.room_id, clear_typing=True) + raise + + async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None: + if event.sender == self.config.user_id or not self._should_process_message(room, event): + return + attachment, marker = await self._fetch_media_attachment(room, event) + parts: list[str] = [] + if isinstance(body := getattr(event, "body", None), str) and body.strip(): + parts.append(body.strip()) + parts.append(marker) + + await self._start_typing_keepalive(room.room_id) + try: + meta = self._base_metadata(room, event) + if attachment: + meta["attachments"] = [attachment] + await self._handle_message( + sender_id=event.sender, chat_id=room.room_id, + content="\n".join(parts), + media=[attachment["path"]] if attachment else [], + metadata=meta, + ) + except Exception: + await self._stop_typing_keepalive(room.room_id, clear_typing=True) + raise diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 5352a30..7b171bc 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -31,7 +31,8 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]": class _Bot(botpy.Client): def __init__(self): - super().__init__(intents=intents) + # Disable botpy's file log β€” nanobot uses loguru; default "botpy.log" fails on read-only fs + super().__init__(intents=intents, ext_handlers=False) async def on_ready(self): logger.info("QQ bot ready: {}", self.robot.name) @@ -100,10 +101,12 @@ class QQChannel(BaseChannel): logger.warning("QQ client not initialized") return try: + msg_id = msg.metadata.get("message_id") await self._client.api.post_c2c_message( openid=msg.chat_id, msg_type=0, content=msg.content, + msg_id=msg_id, ) except Exception as e: logger.error("Error sending QQ message: {}", e) diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index 906593b..57bfbcb 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -229,6 +229,11 @@ class SlackChannel(BaseChannel): return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip() _TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*") + _CODE_FENCE_RE = re.compile(r"```[\s\S]*?```") + _INLINE_CODE_RE = re.compile(r"`[^`]+`") + _LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*") + _LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE) + _BARE_URL_RE = re.compile(r"(? str: @@ -236,7 +241,26 @@ class SlackChannel(BaseChannel): if not text: return "" text = cls._TABLE_RE.sub(cls._convert_table, text) - return slackify_markdown(text) + return cls._fixup_mrkdwn(slackify_markdown(text)) + + @classmethod + def _fixup_mrkdwn(cls, text: str) -> str: + """Fix markdown artifacts that slackify_markdown misses.""" + code_blocks: list[str] = [] + + def _save_code(m: re.Match) -> str: + code_blocks.append(m.group(0)) + return f"\x00CB{len(code_blocks) - 1}\x00" + + text = cls._CODE_FENCE_RE.sub(_save_code, text) + text = cls._INLINE_CODE_RE.sub(_save_code, text) + text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text) + text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text) + text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text) + + for i, block in enumerate(code_blocks): + text = text.replace(f"\x00CB{i}\x00", block) + return text @staticmethod def _convert_table(match: re.Match) -> str: diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 6cd98e7..969d853 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -111,6 +111,7 @@ class TelegramChannel(BaseChannel): BOT_COMMANDS = [ BotCommand("start", "Start the bot"), BotCommand("new", "Start a new conversation"), + BotCommand("stop", "Stop the current task"), BotCommand("help", "Show available commands"), ] @@ -126,6 +127,8 @@ class TelegramChannel(BaseChannel): self._app: Application | None = None self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task + self._media_group_buffers: dict[str, dict] = {} + self._media_group_tasks: dict[str, asyncio.Task] = {} async def start(self) -> None: """Start the Telegram bot with long polling.""" @@ -190,6 +193,11 @@ class TelegramChannel(BaseChannel): # Cancel all typing indicators for chat_id in list(self._typing_tasks): self._stop_typing(chat_id) + + for task in self._media_group_tasks.values(): + task.cancel() + self._media_group_tasks.clear() + self._media_group_buffers.clear() if self._app: logger.info("Stopping Telegram bot...") @@ -299,6 +307,7 @@ class TelegramChannel(BaseChannel): await update.message.reply_text( "🐈 nanobot commands:\n" "/new β€” Start a new conversation\n" + "/stop β€” Stop the current task\n" "/help β€” Show available commands" ) @@ -397,6 +406,28 @@ class TelegramChannel(BaseChannel): logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) str_chat_id = str(chat_id) + + # Telegram media groups: buffer briefly, forward as one aggregated turn. + if media_group_id := getattr(message, "media_group_id", None): + key = f"{str_chat_id}:{media_group_id}" + if key not in self._media_group_buffers: + self._media_group_buffers[key] = { + "sender_id": sender_id, "chat_id": str_chat_id, + "contents": [], "media": [], + "metadata": { + "message_id": message.message_id, "user_id": user.id, + "username": user.username, "first_name": user.first_name, + "is_group": message.chat.type != "private", + }, + } + self._start_typing(str_chat_id) + buf = self._media_group_buffers[key] + if content and content != "[empty message]": + buf["contents"].append(content) + buf["media"].extend(media_paths) + if key not in self._media_group_tasks: + self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key)) + return # Start typing indicator before processing self._start_typing(str_chat_id) @@ -416,6 +447,21 @@ class TelegramChannel(BaseChannel): } ) + async def _flush_media_group(self, key: str) -> None: + """Wait briefly, then forward buffered media-group as one turn.""" + try: + await asyncio.sleep(0.6) + if not (buf := self._media_group_buffers.pop(key, None)): + return + content = "\n".join(buf["contents"]) or "[empty message]" + await self._handle_message( + sender_id=buf["sender_id"], chat_id=buf["chat_id"], + content=content, media=list(dict.fromkeys(buf["media"])), + metadata=buf["metadata"], + ) + finally: + self._media_group_tasks.pop(key, None) + def _start_typing(self, chat_id: str) -> None: """Start sending 'typing...' indicator for a chat.""" # Cancel any existing typing task for this chat diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index f5fb521..49d2390 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -2,6 +2,7 @@ import asyncio import json +from collections import OrderedDict from typing import Any from loguru import logger @@ -15,18 +16,19 @@ from nanobot.config.schema import WhatsAppConfig class WhatsAppChannel(BaseChannel): """ WhatsApp channel that connects to a Node.js bridge. - + The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol. Communication between Python and Node.js is via WebSocket. """ - + name = "whatsapp" - + def __init__(self, config: WhatsAppConfig, bus: MessageBus): super().__init__(config, bus) self.config: WhatsAppConfig = config self._ws = None self._connected = False + self._processed_message_ids: OrderedDict[str, None] = OrderedDict() async def start(self) -> None: """Start the WhatsApp channel by connecting to the bridge.""" @@ -105,26 +107,34 @@ class WhatsAppChannel(BaseChannel): # Incoming message from WhatsApp # Deprecated by whatsapp: old phone number style typically: @s.whatspp.net pn = data.get("pn", "") - # New LID sytle typically: + # New LID sytle typically: sender = data.get("sender", "") content = data.get("content", "") - + message_id = data.get("id", "") + + if message_id: + if message_id in self._processed_message_ids: + return + self._processed_message_ids[message_id] = None + while len(self._processed_message_ids) > 1000: + self._processed_message_ids.popitem(last=False) + # Extract just the phone number or lid as chat_id user_id = pn if pn else sender sender_id = user_id.split("@")[0] if "@" in user_id else user_id logger.info("Sender {}", sender) - + # Handle voice transcription if it's a voice message if content == "[Voice Message]": logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id) content = "[Voice Message: Transcription not available for WhatsApp yet]" - + await self._handle_message( sender_id=sender_id, chat_id=sender, # Use full LID for replies content=content, metadata={ - "message_id": data.get("id"), + "message_id": message_id, "timestamp": data.get("timestamp"), "is_group": data.get("isGroup", False) } diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 1c20b50..2e417d6 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -20,6 +20,7 @@ from prompt_toolkit.patch_stdout import patch_stdout from nanobot import __version__, __logo__ from nanobot.config.schema import Config +from nanobot.utils.helpers import sync_workspace_templates app = typer.Typer( name="nanobot", @@ -185,8 +186,7 @@ def onboard(): workspace.mkdir(parents=True, exist_ok=True) console.print(f"[green]βœ“[/green] Created workspace at {workspace}") - # Create default bootstrap files - _create_workspace_templates(workspace) + sync_workspace_templates(workspace) console.print(f"\n{__logo__} nanobot is ready!") console.print("\nNext steps:") @@ -198,36 +198,6 @@ def onboard(): -def _create_workspace_templates(workspace: Path): - """Create default workspace template files from bundled templates.""" - from importlib.resources import files as pkg_files - - templates_dir = pkg_files("nanobot") / "templates" - - for item in templates_dir.iterdir(): - if not item.name.endswith(".md"): - continue - dest = workspace / item.name - if not dest.exists(): - dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8") - console.print(f" [dim]Created {item.name}[/dim]") - - memory_dir = workspace / "memory" - memory_dir.mkdir(exist_ok=True) - - memory_template = templates_dir / "memory" / "MEMORY.md" - memory_file = memory_dir / "MEMORY.md" - if not memory_file.exists(): - memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8") - console.print(" [dim]Created memory/MEMORY.md[/dim]") - - history_file = memory_dir / "HISTORY.md" - if not history_file.exists(): - history_file.write_text("", encoding="utf-8") - console.print(" [dim]Created memory/HISTORY.md[/dim]") - - (workspace / "skills").mkdir(exist_ok=True) - def _make_provider(config: Config): """Create the appropriate LLM provider from config.""" @@ -294,6 +264,7 @@ def gateway( console.print(f"{__logo__} Starting nanobot gateway on port {port}...") config = load_config() + sync_workspace_templates(config.workspace_path) bus = MessageBus() provider = _make_provider(config) session_manager = SessionManager(config.workspace_path) @@ -312,6 +283,7 @@ def gateway( max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, memory_window=config.agents.defaults.memory_window, + reasoning_effort=config.agents.defaults.reasoning_effort, brave_api_key=config.tools.web.search.api_key or None, exec_config=config.tools.exec, cron_service=cron, @@ -447,6 +419,7 @@ def agent( from loguru import logger config = load_config() + sync_workspace_templates(config.workspace_path) bus = MessageBus() provider = _make_provider(config) @@ -469,6 +442,7 @@ def agent( max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, memory_window=config.agents.defaults.memory_window, + reasoning_effort=config.agents.defaults.reasoning_effort, brave_api_key=config.tools.web.search.api_key or None, exec_config=config.tools.exec, cron_service=cron, @@ -960,6 +934,7 @@ def cron_run( max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, memory_window=config.agents.defaults.memory_window, + reasoning_effort=config.agents.defaults.reasoning_effort, brave_api_key=config.tools.web.search.api_key or None, exec_config=config.tools.exec, restrict_to_workspace=config.tools.restrict_to_workspace, diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 215f38d..4f06ebe 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -1,6 +1,8 @@ """Configuration schema using Pydantic.""" from pathlib import Path +from typing import Literal + from pydantic import BaseModel, Field, ConfigDict from pydantic.alias_generators import to_camel from pydantic_settings import BaseSettings @@ -40,6 +42,7 @@ class FeishuConfig(Base): encrypt_key: str = "" # Encrypt Key for event subscription (optional) verification_token: str = "" # Verification Token for event subscription (optional) allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids + react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE) class DingTalkConfig(Base): @@ -61,6 +64,23 @@ class DiscordConfig(Base): intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT +class MatrixConfig(Base): + """Matrix (Element) channel configuration.""" + + enabled: bool = False + homeserver: str = "https://matrix.org" + access_token: str = "" + user_id: str = "" # @bot:matrix.org + device_id: str = "" + e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling). + sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback. + max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound). + allow_from: list[str] = Field(default_factory=list) + group_policy: Literal["open", "mention", "allowlist"] = "open" + group_allow_from: list[str] = Field(default_factory=list) + allow_room_mentions: bool = False + + class EmailConfig(Base): """Email channel configuration (IMAP inbound + SMTP outbound).""" @@ -164,6 +184,20 @@ class QQConfig(Base): secret: str = "" # ζœΊε™¨δΊΊε―†ι’₯ (AppSecret) from q.qq.com allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access) +class MatrixConfig(Base): + """Matrix (Element) channel configuration.""" + enabled: bool = False + homeserver: str = "https://matrix.org" + access_token: str = "" + user_id: str = "" # e.g. @bot:matrix.org + device_id: str = "" + e2ee_enabled: bool = True # end-to-end encryption support + sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout + max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit + allow_from: list[str] = Field(default_factory=list) + group_policy: Literal["open", "mention", "allowlist"] = "open" + group_allow_from: list[str] = Field(default_factory=list) + allow_room_mentions: bool = False class ChannelsConfig(Base): """Configuration for chat channels.""" @@ -179,6 +213,7 @@ class ChannelsConfig(Base): email: EmailConfig = Field(default_factory=EmailConfig) slack: SlackConfig = Field(default_factory=SlackConfig) qq: QQConfig = Field(default_factory=QQConfig) + matrix: MatrixConfig = Field(default_factory=MatrixConfig) class AgentDefaults(Base): @@ -186,10 +221,12 @@ class AgentDefaults(Base): workspace: str = "~/.nanobot/workspace" model: str = "anthropic/claude-opus-4-5" + provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection max_tokens: int = 8192 temperature: float = 0.1 max_tool_iterations: int = 40 memory_window: int = 100 + reasoning_effort: str | None = None # low / medium / high β€” enables LLM thinking mode class AgentsConfig(Base): @@ -260,6 +297,7 @@ class ExecToolConfig(Base): """Shell exec tool configuration.""" timeout: int = 60 + path_append: str = "" class MCPServerConfig(Base): @@ -300,6 +338,11 @@ class Config(BaseSettings): """Match provider config and its registry name. Returns (config, spec_name).""" from nanobot.providers.registry import PROVIDERS + forced = self.agents.defaults.provider + if forced != "auto": + p = getattr(self.providers, forced, None) + return (p, forced) if p else (None, None) + model_lower = (model or self.agents.defaults.model).lower() model_normalized = model_lower.replace("-", "_") model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index eb1599a..36e9938 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -88,6 +88,7 @@ class LLMProvider(ABC): model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, ) -> LLMResponse: """ Send a chat completion request. diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py index a578d14..56e6270 100644 --- a/nanobot/providers/custom_provider.py +++ b/nanobot/providers/custom_provider.py @@ -18,13 +18,16 @@ class CustomProvider(LLMProvider): self._client = AsyncOpenAI(api_key=api_key, base_url=api_base) async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse: + model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None) -> LLMResponse: kwargs: dict[str, Any] = { "model": model or self.default_model, "messages": self._sanitize_empty_content(messages), "max_tokens": max(1, max_tokens), "temperature": temperature, } + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort if tools: kwargs.update(tools=tools, tool_choice="auto") try: diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 03a6c4d..0067ae8 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -3,6 +3,8 @@ import json import json_repair import os +import secrets +import string from typing import Any import litellm @@ -12,8 +14,14 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.registry import find_by_model, find_gateway -# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers. +# Standard OpenAI chat-completion message keys plus reasoning_content for +# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.). _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) +_ALNUM = string.ascii_letters + string.digits + +def _short_tool_id() -> str: + """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral).""" + return "".join(secrets.choice(_ALNUM) for _ in range(9)) class LiteLLMProvider(LLMProvider): @@ -170,6 +178,7 @@ class LiteLLMProvider(LLMProvider): model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, ) -> LLMResponse: """ Send a chat completion request via LiteLLM. @@ -216,6 +225,10 @@ class LiteLLMProvider(LLMProvider): if self.extra_headers: kwargs["extra_headers"] = self.extra_headers + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + kwargs["drop_params"] = True + if tools: kwargs["tools"] = tools kwargs["tool_choice"] = "auto" @@ -244,7 +257,7 @@ class LiteLLMProvider(LLMProvider): args = json_repair.loads(args) tool_calls.append(ToolCallRequest( - id=tc.id, + id=_short_tool_id(), name=tc.function.name, arguments=args, )) diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index fa28593..9039202 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -31,6 +31,7 @@ class OpenAICodexProvider(LLMProvider): model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + reasoning_effort: str | None = None, ) -> LLMResponse: model = model or self.default_model system_prompt, input_items = _convert_messages(messages) diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 2766929..df915b7 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -201,7 +201,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( # OpenAI Codex: uses OAuth, not API key. ProviderSpec( name="openai_codex", - keywords=("openai-codex", "codex"), + keywords=("openai-codex",), env_key="", # OAuth-based, no API key display_name="OpenAI Codex", litellm_prefix="", # Not routed through LiteLLM diff --git a/nanobot/skills/memory/SKILL.md b/nanobot/skills/memory/SKILL.md index 39adbde..529a02d 100644 --- a/nanobot/skills/memory/SKILL.md +++ b/nanobot/skills/memory/SKILL.md @@ -9,7 +9,7 @@ always: true ## Structure - `memory/MEMORY.md` β€” Long-term facts (preferences, project context, relationships). Always loaded into your context. -- `memory/HISTORY.md` β€” Append-only event log. NOT loaded into context. Search it with grep. +- `memory/HISTORY.md` β€” Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM]. ## Search Past Events diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md index 84ba657..4c3e5b1 100644 --- a/nanobot/templates/AGENTS.md +++ b/nanobot/templates/AGENTS.md @@ -2,14 +2,6 @@ You are a helpful AI assistant. Be concise, accurate, and friendly. -## Guidelines - -- Before calling tools, briefly state your intent β€” but NEVER predict results before receiving them -- Use precise tense: "I will run X" before the call, "X returned Y" after -- NEVER claim success before a tool result confirms it -- Ask for clarification when the request is ambiguous -- Remember important information in `memory/MEMORY.md`; past events are logged in `memory/HISTORY.md` - ## Scheduled Reminders When user asks for a reminder at a specific time, use `exec` to run: diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 62f80ac..8322bc8 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -1,80 +1,67 @@ """Utility functions for nanobot.""" +import re from pathlib import Path from datetime import datetime def ensure_dir(path: Path) -> Path: - """Ensure a directory exists, creating it if necessary.""" + """Ensure directory exists, return it.""" path.mkdir(parents=True, exist_ok=True) return path def get_data_path() -> Path: - """Get the nanobot data directory (~/.nanobot).""" + """~/.nanobot data directory.""" return ensure_dir(Path.home() / ".nanobot") def get_workspace_path(workspace: str | None = None) -> Path: - """ - Get the workspace path. - - Args: - workspace: Optional workspace path. Defaults to ~/.nanobot/workspace. - - Returns: - Expanded and ensured workspace path. - """ - if workspace: - path = Path(workspace).expanduser() - else: - path = Path.home() / ".nanobot" / "workspace" + """Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace.""" + path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace" return ensure_dir(path) -def get_sessions_path() -> Path: - """Get the sessions storage directory.""" - return ensure_dir(get_data_path() / "sessions") - - -def get_skills_path(workspace: Path | None = None) -> Path: - """Get the skills directory within the workspace.""" - ws = workspace or get_workspace_path() - return ensure_dir(ws / "skills") - - def timestamp() -> str: - """Get current timestamp in ISO format.""" + """Current ISO timestamp.""" return datetime.now().isoformat() -def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str: - """Truncate a string to max length, adding suffix if truncated.""" - if len(s) <= max_len: - return s - return s[: max_len - len(suffix)] + suffix - +_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') def safe_filename(name: str) -> str: - """Convert a string to a safe filename.""" - # Replace unsafe characters - unsafe = '<>:"/\\|?*' - for char in unsafe: - name = name.replace(char, "_") - return name.strip() + """Replace unsafe path characters with underscores.""" + return _UNSAFE_CHARS.sub("_", name).strip() -def parse_session_key(key: str) -> tuple[str, str]: - """ - Parse a session key into channel and chat_id. - - Args: - key: Session key in format "channel:chat_id" - - Returns: - Tuple of (channel, chat_id) - """ - parts = key.split(":", 1) - if len(parts) != 2: - raise ValueError(f"Invalid session key: {key}") - return parts[0], parts[1] +def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: + """Sync bundled templates to workspace. Only creates missing files.""" + from importlib.resources import files as pkg_files + try: + tpl = pkg_files("nanobot") / "templates" + except Exception: + return [] + if not tpl.is_dir(): + return [] + + added: list[str] = [] + + def _write(src, dest: Path): + if dest.exists(): + return + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8") + added.append(str(dest.relative_to(workspace))) + + for item in tpl.iterdir(): + if item.name.endswith(".md"): + _write(item, workspace / item.name) + _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md") + _write(None, workspace / "memory" / "HISTORY.md") + (workspace / "skills").mkdir(exist_ok=True) + + if added and not silent: + from rich.console import Console + for name in added: + Console().print(f" [dim]Created {name}[/dim]") + return added diff --git a/pyproject.toml b/pyproject.toml index cb58ec5..20dcb1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nanobot-ai" -version = "0.1.4.post1" +version = "0.1.4.post2" description = "A lightweight personal AI assistant framework" requires-python = ">=3.11" license = {text = "MIT"} @@ -45,6 +45,11 @@ dependencies = [ ] [project.optional-dependencies] +matrix = [ + "matrix-nio[e2e]>=0.25.2", + "mistune>=3.0.0,<4.0.0", + "nh3>=0.2.17,<1.0.0", +] dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 323519e..a3213dd 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard: ) @pytest.mark.asyncio - async def test_new_cleans_up_consolidation_lock_for_invalidated_session( - self, tmp_path: Path - ) -> None: - """/new should remove lock entry for fully invalidated session key.""" + async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None: + """/new clears session and returns confirmation.""" from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus @@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard: loop = AgentLoop( bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 ) - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) @@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard: session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - # Ensure lock exists before /new. - _ = loop._get_consolidation_lock(session.key) - assert session.key in loop._consolidation_locks - async def _ok_consolidate(sess, archive_all: bool = False) -> bool: return True @@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard: assert response is not None assert "new session started" in response.content.lower() - assert session.key not in loop._consolidation_locks + assert loop.sessions.get_or_create("cli:test").messages == [] diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py index 8e2333c..9afcc7d 100644 --- a/tests/test_context_prompt_cache.py +++ b/tests/test_context_prompt_cache.py @@ -39,8 +39,8 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> assert prompt1 == prompt2 -def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None: - """Dynamic runtime details should be added at the tail user message, not system.""" +def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: + """Runtime metadata should be a separate user message before the actual user message.""" workspace = _make_workspace(tmp_path) builder = ContextBuilder(workspace) @@ -54,10 +54,13 @@ def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None: assert messages[0]["role"] == "system" assert "## Current Session" not in messages[0]["content"] + assert messages[-2]["role"] == "user" + runtime_content = messages[-2]["content"] + assert isinstance(runtime_content, str) + assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content + assert "Current Time:" in runtime_content + assert "Channel: cli" in runtime_content + assert "Chat ID: direct" in runtime_content + assert messages[-1]["role"] == "user" - user_content = messages[-1]["content"] - assert isinstance(user_content, str) - assert "Return exactly: OK" in user_content - assert "Current Time:" in user_content - assert "Channel: cli" in user_content - assert "Chat ID: direct" in user_content + assert messages[-1]["content"] == "Return exactly: OK" diff --git a/tests/test_heartbeat_service.py b/tests/test_heartbeat_service.py index ec91c6b..c5478af 100644 --- a/tests/test_heartbeat_service.py +++ b/tests/test_heartbeat_service.py @@ -2,34 +2,28 @@ import asyncio import pytest -from nanobot.heartbeat.service import ( - HEARTBEAT_OK_TOKEN, - HeartbeatService, -) +from nanobot.heartbeat.service import HeartbeatService +from nanobot.providers.base import LLMResponse, ToolCallRequest -def test_heartbeat_ok_detection() -> None: - def is_ok(response: str) -> bool: - return HEARTBEAT_OK_TOKEN in response.upper() +class DummyProvider: + def __init__(self, responses: list[LLMResponse]): + self._responses = list(responses) - assert is_ok("HEARTBEAT_OK") - assert is_ok("`HEARTBEAT_OK`") - assert is_ok("**HEARTBEAT_OK**") - assert is_ok("heartbeat_ok") - assert is_ok("HEARTBEAT_OK.") - - assert not is_ok("HEARTBEAT_NOT_OK") - assert not is_ok("all good") + async def chat(self, *args, **kwargs) -> LLMResponse: + if self._responses: + return self._responses.pop(0) + return LLMResponse(content="", tool_calls=[]) @pytest.mark.asyncio async def test_start_is_idempotent(tmp_path) -> None: - async def _on_heartbeat(_: str) -> str: - return "HEARTBEAT_OK" + provider = DummyProvider([]) service = HeartbeatService( workspace=tmp_path, - on_heartbeat=_on_heartbeat, + provider=provider, + model="openai/gpt-4o-mini", interval_s=9999, enabled=True, ) @@ -42,3 +36,82 @@ async def test_start_is_idempotent(tmp_path) -> None: service.stop() await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None: + provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])]) + service = HeartbeatService( + workspace=tmp_path, + provider=provider, + model="openai/gpt-4o-mini", + ) + + action, tasks = await service._decide("heartbeat content") + assert action == "skip" + assert tasks == "" + + +@pytest.mark.asyncio +async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None: + (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8") + + provider = DummyProvider([ + LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="hb_1", + name="heartbeat", + arguments={"action": "run", "tasks": "check open tasks"}, + ) + ], + ) + ]) + + called_with: list[str] = [] + + async def _on_execute(tasks: str) -> str: + called_with.append(tasks) + return "done" + + service = HeartbeatService( + workspace=tmp_path, + provider=provider, + model="openai/gpt-4o-mini", + on_execute=_on_execute, + ) + + result = await service.trigger_now() + assert result == "done" + assert called_with == ["check open tasks"] + + +@pytest.mark.asyncio +async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None: + (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8") + + provider = DummyProvider([ + LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="hb_1", + name="heartbeat", + arguments={"action": "skip"}, + ) + ], + ) + ]) + + async def _on_execute(tasks: str) -> str: + return tasks + + service = HeartbeatService( + workspace=tmp_path, + provider=provider, + model="openai/gpt-4o-mini", + on_execute=_on_execute, + ) + + assert await service.trigger_now() is None diff --git a/tests/test_matrix_channel.py b/tests/test_matrix_channel.py new file mode 100644 index 0000000..c6714c2 --- /dev/null +++ b/tests/test_matrix_channel.py @@ -0,0 +1,1302 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import nanobot.channels.matrix as matrix_module +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.matrix import ( + MATRIX_HTML_FORMAT, + TYPING_NOTICE_TIMEOUT_MS, + MatrixChannel, +) +from nanobot.config.schema import MatrixConfig + +_ROOM_SEND_UNSET = object() + + +class _DummyTask: + def __init__(self) -> None: + self.cancelled = False + + def cancel(self) -> None: + self.cancelled = True + + def __await__(self): + async def _done(): + return None + + return _done().__await__() + + +class _FakeAsyncClient: + def __init__(self, homeserver, user, store_path, config) -> None: + self.homeserver = homeserver + self.user = user + self.store_path = store_path + self.config = config + self.user_id: str | None = None + self.access_token: str | None = None + self.device_id: str | None = None + self.load_store_called = False + self.stop_sync_forever_called = False + self.join_calls: list[str] = [] + self.callbacks: list[tuple[object, object]] = [] + self.response_callbacks: list[tuple[object, object]] = [] + self.rooms: dict[str, object] = {} + self.room_send_calls: list[dict[str, object]] = [] + self.typing_calls: list[tuple[str, bool, int]] = [] + self.download_calls: list[dict[str, object]] = [] + self.upload_calls: list[dict[str, object]] = [] + self.download_response: object | None = None + self.download_bytes: bytes = b"media" + self.download_content_type: str = "application/octet-stream" + self.download_filename: str | None = None + self.upload_response: object | None = None + self.content_repository_config_response: object = SimpleNamespace(upload_size=None) + self.raise_on_send = False + self.raise_on_typing = False + self.raise_on_upload = False + + def add_event_callback(self, callback, event_type) -> None: + self.callbacks.append((callback, event_type)) + + def add_response_callback(self, callback, response_type) -> None: + self.response_callbacks.append((callback, response_type)) + + def load_store(self) -> None: + self.load_store_called = True + + def stop_sync_forever(self) -> None: + self.stop_sync_forever_called = True + + async def join(self, room_id: str) -> None: + self.join_calls.append(room_id) + + async def room_send( + self, + room_id: str, + message_type: str, + content: dict[str, object], + ignore_unverified_devices: object = _ROOM_SEND_UNSET, + ) -> None: + call: dict[str, object] = { + "room_id": room_id, + "message_type": message_type, + "content": content, + } + if ignore_unverified_devices is not _ROOM_SEND_UNSET: + call["ignore_unverified_devices"] = ignore_unverified_devices + self.room_send_calls.append(call) + if self.raise_on_send: + raise RuntimeError("send failed") + + async def room_typing( + self, + room_id: str, + typing_state: bool = True, + timeout: int = 30_000, + ) -> None: + self.typing_calls.append((room_id, typing_state, timeout)) + if self.raise_on_typing: + raise RuntimeError("typing failed") + + async def download(self, **kwargs): + self.download_calls.append(kwargs) + if self.download_response is not None: + return self.download_response + return matrix_module.MemoryDownloadResponse( + body=self.download_bytes, + content_type=self.download_content_type, + filename=self.download_filename, + ) + + async def upload( + self, + data_provider, + content_type: str | None = None, + filename: str | None = None, + filesize: int | None = None, + encrypt: bool = False, + ): + if self.raise_on_upload: + raise RuntimeError("upload failed") + if isinstance(data_provider, (bytes, bytearray)): + raise TypeError( + f"data_provider type {type(data_provider)!r} is not of a usable type " + "(Callable, IOBase)" + ) + self.upload_calls.append( + { + "data_provider": data_provider, + "content_type": content_type, + "filename": filename, + "filesize": filesize, + "encrypt": encrypt, + } + ) + if self.upload_response is not None: + return self.upload_response + if encrypt: + return ( + SimpleNamespace(content_uri="mxc://example.org/uploaded"), + { + "v": "v2", + "iv": "iv", + "hashes": {"sha256": "hash"}, + "key": {"alg": "A256CTR", "k": "key"}, + }, + ) + return SimpleNamespace(content_uri="mxc://example.org/uploaded"), None + + async def content_repository_config(self): + return self.content_repository_config_response + + async def close(self) -> None: + return None + + +def _make_config(**kwargs) -> MatrixConfig: + return MatrixConfig( + enabled=True, + homeserver="https://matrix.org", + access_token="token", + user_id="@bot:matrix.org", + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_start_skips_load_store_when_device_id_missing( + monkeypatch, tmp_path +) -> None: + clients: list[_FakeAsyncClient] = [] + + def _fake_client(*args, **kwargs) -> _FakeAsyncClient: + client = _FakeAsyncClient(*args, **kwargs) + clients.append(client) + return client + + def _fake_create_task(coro): + coro.close() + return _DummyTask() + + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + monkeypatch.setattr( + "nanobot.channels.matrix.AsyncClientConfig", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client) + monkeypatch.setattr( + "nanobot.channels.matrix.asyncio.create_task", _fake_create_task + ) + + channel = MatrixChannel(_make_config(device_id=""), MessageBus()) + await channel.start() + + assert len(clients) == 1 + assert clients[0].config.encryption_enabled is True + assert clients[0].load_store_called is False + assert len(clients[0].callbacks) == 3 + assert len(clients[0].response_callbacks) == 3 + + await channel.stop() + + +@pytest.mark.asyncio +async def test_register_event_callbacks_uses_media_base_filter() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._register_event_callbacks() + + assert len(client.callbacks) == 3 + assert client.callbacks[1][0] == channel._on_media_message + assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER + + +def test_media_event_filter_does_not_match_text_events() -> None: + assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER) + + +@pytest.mark.asyncio +async def test_start_disables_e2ee_when_configured( + monkeypatch, tmp_path +) -> None: + clients: list[_FakeAsyncClient] = [] + + def _fake_client(*args, **kwargs) -> _FakeAsyncClient: + client = _FakeAsyncClient(*args, **kwargs) + clients.append(client) + return client + + def _fake_create_task(coro): + coro.close() + return _DummyTask() + + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + monkeypatch.setattr( + "nanobot.channels.matrix.AsyncClientConfig", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client) + monkeypatch.setattr( + "nanobot.channels.matrix.asyncio.create_task", _fake_create_task + ) + + channel = MatrixChannel(_make_config(device_id="", e2ee_enabled=False), MessageBus()) + await channel.start() + + assert len(clients) == 1 + assert clients[0].config.encryption_enabled is False + + await channel.stop() + + +@pytest.mark.asyncio +async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None: + channel = MatrixChannel(_make_config(device_id="DEVICE"), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + task = _DummyTask() + + channel.client = client + channel._sync_task = task + channel._running = True + + await channel.stop() + + assert channel._running is False + assert client.stop_sync_forever_called is True + assert task.cancelled is False + + +@pytest.mark.asyncio +async def test_room_invite_joins_when_allow_list_is_empty() -> None: + channel = MatrixChannel(_make_config(allow_from=[]), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + room = SimpleNamespace(room_id="!room:matrix.org") + event = SimpleNamespace(sender="@alice:matrix.org") + + await channel._on_room_invite(room, event) + + assert client.join_calls == ["!room:matrix.org"] + + +@pytest.mark.asyncio +async def test_room_invite_respects_allow_list_when_configured() -> None: + channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + room = SimpleNamespace(room_id="!room:matrix.org") + event = SimpleNamespace(sender="@alice:matrix.org") + + await channel._on_room_invite(room, event) + + assert client.join_calls == [] + + +@pytest.mark.asyncio +async def test_on_message_sets_typing_for_allowed_sender() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room") + event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={}) + + await channel._on_message(room, event) + + assert handled == ["@alice:matrix.org"] + assert client.typing_calls == [ + ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS), + ] + + +@pytest.mark.asyncio +async def test_typing_keepalive_refreshes_periodically(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + channel._running = True + + monkeypatch.setattr(matrix_module, "TYPING_KEEPALIVE_INTERVAL_MS", 10) + + await channel._start_typing_keepalive("!room:matrix.org") + await asyncio.sleep(0.03) + await channel._stop_typing_keepalive("!room:matrix.org", clear_typing=True) + + true_updates = [call for call in client.typing_calls if call[1] is True] + assert len(true_updates) >= 2 + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + + +@pytest.mark.asyncio +async def test_on_message_skips_typing_for_self_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room") + event = SimpleNamespace(sender="@bot:matrix.org", body="Hello", source={}) + + await channel._on_message(room, event) + + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_on_message_skips_typing_for_denied_sender() -> None: + channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room") + event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={}) + + await channel._on_message(room, event) + + assert handled == [] + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_on_message_mention_policy_requires_mx_mentions() -> None: + channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3) + event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}}) + + await channel._on_message(room, event) + + assert handled == [] + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_on_message_mention_policy_accepts_bot_user_mentions() -> None: + channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="Hello", + source={"content": {"m.mentions": {"user_ids": ["@bot:matrix.org"]}}}, + ) + + await channel._on_message(room, event) + + assert handled == ["@alice:matrix.org"] + assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)] + + +@pytest.mark.asyncio +async def test_on_message_mention_policy_allows_direct_room_without_mentions() -> None: + channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!dm:matrix.org", display_name="DM", member_count=2) + event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}}) + + await channel._on_message(room, event) + + assert handled == ["@alice:matrix.org"] + assert client.typing_calls == [("!dm:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)] + + +@pytest.mark.asyncio +async def test_on_message_allowlist_policy_requires_room_id() -> None: + channel = MatrixChannel( + _make_config(group_policy="allowlist", group_allow_from=["!allowed:matrix.org"]), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["chat_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + denied_room = SimpleNamespace(room_id="!denied:matrix.org", display_name="Denied", member_count=3) + event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}}) + await channel._on_message(denied_room, event) + + allowed_room = SimpleNamespace( + room_id="!allowed:matrix.org", + display_name="Allowed", + member_count=3, + ) + await channel._on_message(allowed_room, event) + + assert handled == ["!allowed:matrix.org"] + assert client.typing_calls == [("!allowed:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)] + + +@pytest.mark.asyncio +async def test_on_message_room_mention_requires_opt_in() -> None: + channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[str] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs["sender_id"]) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3) + room_mention_event = SimpleNamespace( + sender="@alice:matrix.org", + body="Hello everyone", + source={"content": {"m.mentions": {"room": True}}}, + ) + + await channel._on_message(room, room_mention_event) + assert handled == [] + assert client.typing_calls == [] + + channel.config.allow_room_mentions = True + await channel._on_message(room, room_mention_event) + assert handled == ["@alice:matrix.org"] + assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)] + + +@pytest.mark.asyncio +async def test_on_message_sets_thread_metadata_when_threaded_event() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="Hello", + event_id="$reply1", + source={ + "content": { + "m.relates_to": { + "rel_type": "m.thread", + "event_id": "$root1", + } + } + }, + ) + + await channel._on_message(room, event) + + assert len(handled) == 1 + metadata = handled[0]["metadata"] + assert metadata["thread_root_event_id"] == "$root1" + assert metadata["thread_reply_to_event_id"] == "$reply1" + assert metadata["event_id"] == "$reply1" + + +@pytest.mark.asyncio +async def test_on_media_message_downloads_attachment_and_sets_metadata( + monkeypatch, tmp_path +) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.download_bytes = b"image" + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="photo.png", + url="mxc://example.org/mediaid", + event_id="$event1", + source={ + "content": { + "msgtype": "m.image", + "info": {"mimetype": "image/png", "size": 5}, + } + }, + ) + + await channel._on_media_message(room, event) + + assert len(client.download_calls) == 1 + assert len(handled) == 1 + assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)] + + media_paths = handled[0]["media"] + assert isinstance(media_paths, list) and len(media_paths) == 1 + media_path = Path(media_paths[0]) + assert media_path.is_file() + assert media_path.read_bytes() == b"image" + + metadata = handled[0]["metadata"] + attachments = metadata["attachments"] + assert isinstance(attachments, list) and len(attachments) == 1 + assert attachments[0]["type"] == "image" + assert attachments[0]["mxc_url"] == "mxc://example.org/mediaid" + assert attachments[0]["path"] == str(media_path) + assert "[attachment: " in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_media_message_sets_thread_metadata_when_threaded_event( + monkeypatch, tmp_path +) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.download_bytes = b"image" + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="photo.png", + url="mxc://example.org/mediaid", + event_id="$event1", + source={ + "content": { + "msgtype": "m.image", + "info": {"mimetype": "image/png", "size": 5}, + "m.relates_to": { + "rel_type": "m.thread", + "event_id": "$root1", + }, + } + }, + ) + + await channel._on_media_message(room, event) + + assert len(handled) == 1 + metadata = handled[0]["metadata"] + assert metadata["thread_root_event_id"] == "$root1" + assert metadata["thread_reply_to_event_id"] == "$event1" + assert metadata["event_id"] == "$event1" + + +@pytest.mark.asyncio +async def test_on_media_message_respects_declared_size_limit( + monkeypatch, tmp_path +) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + channel = MatrixChannel(_make_config(max_media_bytes=3), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="large.bin", + url="mxc://example.org/large", + event_id="$event2", + source={"content": {"msgtype": "m.file", "info": {"size": 10}}}, + ) + + await channel._on_media_message(room, event) + + assert client.download_calls == [] + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["metadata"]["attachments"] == [] + assert "[attachment: large.bin - too large]" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_media_message_uses_server_limit_when_smaller_than_local_limit( + monkeypatch, tmp_path +) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.content_repository_config_response = SimpleNamespace(upload_size=3) + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="large.bin", + url="mxc://example.org/large", + event_id="$event2_server", + source={"content": {"msgtype": "m.file", "info": {"size": 5}}}, + ) + + await channel._on_media_message(room, event) + + assert client.download_calls == [] + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["metadata"]["attachments"] == [] + assert "[attachment: large.bin - too large]" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.download_response = matrix_module.DownloadError("download failed") + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="photo.png", + url="mxc://example.org/mediaid", + event_id="$event3", + source={"content": {"msgtype": "m.image"}}, + ) + + await channel._on_media_message(room, event) + + assert len(client.download_calls) == 1 + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["metadata"]["attachments"] == [] + assert "[attachment: photo.png - download failed]" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + monkeypatch.setattr( + matrix_module, + "decrypt_attachment", + lambda ciphertext, key, sha256, iv: b"plain", + ) + + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.download_bytes = b"cipher" + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="secret.txt", + url="mxc://example.org/encrypted", + event_id="$event4", + key={"k": "key"}, + hashes={"sha256": "hash"}, + iv="iv", + source={"content": {"msgtype": "m.file", "info": {"size": 6}}}, + ) + + await channel._on_media_message(room, event) + + assert len(handled) == 1 + media_path = Path(handled[0]["media"][0]) + assert media_path.read_bytes() == b"plain" + attachment = handled[0]["metadata"]["attachments"][0] + assert attachment["encrypted"] is True + assert attachment["size_bytes"] == 5 + + +@pytest.mark.asyncio +async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) -> None: + monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path) + + def _raise(*args, **kwargs): + raise matrix_module.EncryptionError("boom") + + monkeypatch.setattr(matrix_module, "decrypt_attachment", _raise) + + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.download_bytes = b"cipher" + channel.client = client + + handled: list[dict[str, object]] = [] + + async def _fake_handle_message(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = _fake_handle_message # type: ignore[method-assign] + + room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2) + event = SimpleNamespace( + sender="@alice:matrix.org", + body="secret.txt", + url="mxc://example.org/encrypted", + event_id="$event5", + key={"k": "key"}, + hashes={"sha256": "hash"}, + iv="iv", + source={"content": {"msgtype": "m.file"}}, + ) + + await channel._on_media_message(room, event) + + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["metadata"]["attachments"] == [] + assert "[attachment: secret.txt - download failed]" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_send_clears_typing_after_send() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi") + ) + + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"] == { + "msgtype": "m.text", + "body": "Hi", + "m.mentions": {}, + } + assert client.room_send_calls[0]["ignore_unverified_devices"] is True + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + + +@pytest.mark.asyncio +async def test_send_uploads_media_and_sends_file_event(tmp_path) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + file_path = tmp_path / "test.txt" + file_path.write_text("hello", encoding="utf-8") + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="Please review.", + media=[str(file_path)], + ) + ) + + assert len(client.upload_calls) == 1 + assert not isinstance(client.upload_calls[0]["data_provider"], (bytes, bytearray)) + assert hasattr(client.upload_calls[0]["data_provider"], "read") + assert client.upload_calls[0]["filename"] == "test.txt" + assert client.upload_calls[0]["filesize"] == 5 + assert len(client.room_send_calls) == 2 + assert client.room_send_calls[0]["content"]["msgtype"] == "m.file" + assert client.room_send_calls[0]["content"]["url"] == "mxc://example.org/uploaded" + assert client.room_send_calls[1]["content"]["body"] == "Please review." + + +@pytest.mark.asyncio +async def test_send_adds_thread_relates_to_for_thread_metadata() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="Hi", + metadata=metadata, + ) + ) + + content = client.room_send_calls[0]["content"] + assert content["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + +@pytest.mark.asyncio +async def test_send_uses_encrypted_media_payload_in_encrypted_room(tmp_path) -> None: + channel = MatrixChannel(_make_config(e2ee_enabled=True), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.rooms["!encrypted:matrix.org"] = SimpleNamespace(encrypted=True) + channel.client = client + + file_path = tmp_path / "secret.txt" + file_path.write_text("topsecret", encoding="utf-8") + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!encrypted:matrix.org", + content="", + media=[str(file_path)], + ) + ) + + assert len(client.upload_calls) == 1 + assert client.upload_calls[0]["encrypt"] is True + assert len(client.room_send_calls) == 1 + content = client.room_send_calls[0]["content"] + assert content["msgtype"] == "m.file" + assert "file" in content + assert "url" not in content + assert content["file"]["url"] == "mxc://example.org/uploaded" + assert content["file"]["hashes"]["sha256"] == "hash" + + +@pytest.mark.asyncio +async def test_send_does_not_parse_attachment_marker_without_media(tmp_path) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + missing_path = tmp_path / "missing.txt" + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content=f"[attachment: {missing_path}]", + ) + ) + + assert client.upload_calls == [] + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == f"[attachment: {missing_path}]" + + +@pytest.mark.asyncio +async def test_send_passes_thread_relates_to_to_attachment_upload(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + channel._server_upload_limit_checked = True + channel._server_upload_limit_bytes = None + + captured: dict[str, object] = {} + + async def _fake_upload_and_send_attachment( + *, + room_id: str, + path: Path, + limit_bytes: int, + relates_to: dict[str, object] | None = None, + ) -> str | None: + captured["relates_to"] = relates_to + return None + + monkeypatch.setattr(channel, "_upload_and_send_attachment", _fake_upload_and_send_attachment) + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="Hi", + media=["/tmp/fake.txt"], + metadata=metadata, + ) + ) + + assert captured["relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + +@pytest.mark.asyncio +async def test_send_workspace_restriction_blocks_external_attachment(tmp_path) -> None: + workspace = tmp_path / "workspace" + workspace.mkdir() + file_path = tmp_path / "external.txt" + file_path.write_text("outside", encoding="utf-8") + + channel = MatrixChannel( + _make_config(), + MessageBus(), + restrict_to_workspace=True, + workspace=workspace, + ) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="", + media=[str(file_path)], + ) + ) + + assert client.upload_calls == [] + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "[attachment: external.txt - upload failed]" + + +@pytest.mark.asyncio +async def test_send_handles_upload_exception_and_reports_failure(tmp_path) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.raise_on_upload = True + channel.client = client + + file_path = tmp_path / "broken.txt" + file_path.write_text("hello", encoding="utf-8") + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="Please review.", + media=[str(file_path)], + ) + ) + + assert len(client.upload_calls) == 0 + assert len(client.room_send_calls) == 1 + assert ( + client.room_send_calls[0]["content"]["body"] + == "Please review.\n[attachment: broken.txt - upload failed]" + ) + + +@pytest.mark.asyncio +async def test_send_uses_server_upload_limit_when_smaller_than_local_limit(tmp_path) -> None: + channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.content_repository_config_response = SimpleNamespace(upload_size=3) + channel.client = client + + file_path = tmp_path / "tiny.txt" + file_path.write_text("hello", encoding="utf-8") + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="", + media=[str(file_path)], + ) + ) + + assert client.upload_calls == [] + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "[attachment: tiny.txt - too large]" + + +@pytest.mark.asyncio +async def test_send_blocks_all_outbound_media_when_limit_is_zero(tmp_path) -> None: + channel = MatrixChannel(_make_config(max_media_bytes=0), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + file_path = tmp_path / "empty.txt" + file_path.write_bytes(b"") + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="", + media=[str(file_path)], + ) + ) + + assert client.upload_calls == [] + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "[attachment: empty.txt - too large]" + + +@pytest.mark.asyncio +async def test_send_omits_ignore_unverified_devices_when_e2ee_disabled() -> None: + channel = MatrixChannel(_make_config(e2ee_enabled=False), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi") + ) + + assert len(client.room_send_calls) == 1 + assert "ignore_unverified_devices" not in client.room_send_calls[0] + + +@pytest.mark.asyncio +async def test_send_stops_typing_keepalive_task() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + channel._running = True + + await channel._start_typing_keepalive("!room:matrix.org") + assert "!room:matrix.org" in channel._typing_tasks + + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi") + ) + + assert "!room:matrix.org" not in channel._typing_tasks + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + + +@pytest.mark.asyncio +async def test_send_progress_keeps_typing_keepalive_running() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + channel._running = True + + await channel._start_typing_keepalive("!room:matrix.org") + assert "!room:matrix.org" in channel._typing_tasks + + await channel.send( + OutboundMessage( + channel="matrix", + chat_id="!room:matrix.org", + content="working...", + metadata={"_progress": True, "_progress_kind": "reasoning"}, + ) + ) + + assert "!room:matrix.org" in channel._typing_tasks + assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS) + + +@pytest.mark.asyncio +async def test_send_clears_typing_when_send_fails() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.raise_on_send = True + channel.client = client + + with pytest.raises(RuntimeError, match="send failed"): + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi") + ) + + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + + +@pytest.mark.asyncio +async def test_send_adds_formatted_body_for_markdown() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + markdown_text = "# Headline\n\n- [x] done\n\n| A | B |\n| - | - |\n| 1 | 2 |" + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text) + ) + + content = client.room_send_calls[0]["content"] + assert content["msgtype"] == "m.text" + assert content["body"] == markdown_text + assert content["m.mentions"] == {} + assert content["format"] == MATRIX_HTML_FORMAT + assert "

Headline

" in str(content["formatted_body"]) + assert "" in str(content["formatted_body"]) + assert "
  • [x] done
  • " in str(content["formatted_body"]) + + +@pytest.mark.asyncio +async def test_send_adds_formatted_body_for_inline_url_superscript_subscript() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + markdown_text = "Visit https://example.com and x^2^ plus H~2~O." + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text) + ) + + content = client.room_send_calls[0]["content"] + assert content["msgtype"] == "m.text" + assert content["body"] == markdown_text + assert content["m.mentions"] == {} + assert content["format"] == MATRIX_HTML_FORMAT + assert '' in str( + content["formatted_body"] + ) + assert "2" in str(content["formatted_body"]) + assert "2" in str(content["formatted_body"]) + + +@pytest.mark.asyncio +async def test_send_sanitizes_disallowed_link_scheme() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + markdown_text = "[click](javascript:alert(1))" + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text) + ) + + formatted_body = str(client.room_send_calls[0]["content"]["formatted_body"]) + assert "javascript:" not in formatted_body + assert "x' + cleaned_html = matrix_module.MATRIX_HTML_CLEANER.clean(dirty_html) + + assert " None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + markdown_text = "![ok](mxc://example.org/mediaid) ![no](https://example.com/a.png)" + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text) + ) + + formatted_body = str(client.room_send_calls[0]["content"]["formatted_body"]) + assert 'src="mxc://example.org/mediaid"' in formatted_body + assert 'src="https://example.com/a.png"' not in formatted_body + + +@pytest.mark.asyncio +async def test_send_falls_back_to_plaintext_when_markdown_render_fails(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + def _raise(text: str) -> str: + raise RuntimeError("boom") + + monkeypatch.setattr(matrix_module, "MATRIX_MARKDOWN", _raise) + markdown_text = "# Headline" + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text) + ) + + content = client.room_send_calls[0]["content"] + assert content == {"msgtype": "m.text", "body": markdown_text, "m.mentions": {}} + + +@pytest.mark.asyncio +async def test_send_keeps_plaintext_only_for_plain_text() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + text = "just a normal sentence without markdown markers" + await channel.send( + OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=text) + ) + + assert client.room_send_calls[0]["content"] == { + "msgtype": "m.text", + "body": text, + "m.mentions": {}, + } diff --git a/tests/test_message_tool.py b/tests/test_message_tool.py new file mode 100644 index 0000000..dc8e11d --- /dev/null +++ b/tests/test_message_tool.py @@ -0,0 +1,10 @@ +import pytest + +from nanobot.agent.tools.message import MessageTool + + +@pytest.mark.asyncio +async def test_message_tool_returns_error_when_no_target_context() -> None: + tool = MessageTool() + result = await tool.execute(content="test") + assert result == "Error: No target channel/chat specified" diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py new file mode 100644 index 0000000..26b8a16 --- /dev/null +++ b/tests/test_message_tool_suppress.py @@ -0,0 +1,103 @@ +"""Test message tool suppress logic for final replies.""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.message import MessageTool +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse, ToolCallRequest + + +def _make_loop(tmp_path: Path) -> AgentLoop: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + + +class TestMessageToolSuppressLogic: + """Final reply suppressed only when message tool sends to the same target.""" + + @pytest.mark.asyncio + async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send") + result = await loop._process_message(msg) + + assert len(sent) == 1 + assert result is None # suppressed + + @pytest.mark.asyncio + async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="I've sent the email.", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email") + result = await loop._process_message(msg) + + assert len(sent) == 1 + assert sent[0].channel == "email" + assert result is not None # not suppressed + assert result.channel == "feishu" + + @pytest.mark.asyncio + async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") + result = await loop._process_message(msg) + + assert result is not None + assert "Hello" in result.content + + +class TestMessageToolTurnTracking: + + def test_sent_in_turn_tracks_same_target(self) -> None: + tool = MessageTool() + tool.set_context("feishu", "chat1") + assert not tool._sent_in_turn + tool._sent_in_turn = True + assert tool._sent_in_turn + + def test_start_turn_resets(self) -> None: + tool = MessageTool() + tool._sent_in_turn = True + tool.start_turn() + assert not tool._sent_in_turn diff --git a/tests/test_task_cancel.py b/tests/test_task_cancel.py new file mode 100644 index 0000000..27a2d73 --- /dev/null +++ b/tests/test_task_cancel.py @@ -0,0 +1,167 @@ +"""Tests for /stop task cancellation.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_loop(): + """Create a minimal AgentLoop with mocked dependencies.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + workspace = MagicMock() + workspace.__truediv__ = MagicMock(return_value=MagicMock()) + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as MockSubMgr: + MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop(bus=bus, provider=provider, workspace=workspace) + return loop, bus + + +class TestHandleStop: + @pytest.mark.asyncio + async def test_stop_no_active_task(self): + from nanobot.bus.events import InboundMessage + + loop, bus = _make_loop() + msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop") + await loop._handle_stop(msg) + out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert "No active task" in out.content + + @pytest.mark.asyncio + async def test_stop_cancels_active_task(self): + from nanobot.bus.events import InboundMessage + + loop, bus = _make_loop() + cancelled = asyncio.Event() + + async def slow_task(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancelled.set() + raise + + task = asyncio.create_task(slow_task()) + await asyncio.sleep(0) + loop._active_tasks["test:c1"] = [task] + + msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop") + await loop._handle_stop(msg) + + assert cancelled.is_set() + out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert "stopped" in out.content.lower() + + @pytest.mark.asyncio + async def test_stop_cancels_multiple_tasks(self): + from nanobot.bus.events import InboundMessage + + loop, bus = _make_loop() + events = [asyncio.Event(), asyncio.Event()] + + async def slow(idx): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + events[idx].set() + raise + + tasks = [asyncio.create_task(slow(i)) for i in range(2)] + await asyncio.sleep(0) + loop._active_tasks["test:c1"] = tasks + + msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop") + await loop._handle_stop(msg) + + assert all(e.is_set() for e in events) + out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert "2 task" in out.content + + +class TestDispatch: + @pytest.mark.asyncio + async def test_dispatch_processes_and_publishes(self): + from nanobot.bus.events import InboundMessage, OutboundMessage + + loop, bus = _make_loop() + msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello") + loop._process_message = AsyncMock( + return_value=OutboundMessage(channel="test", chat_id="c1", content="hi") + ) + await loop._dispatch(msg) + out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert out.content == "hi" + + @pytest.mark.asyncio + async def test_processing_lock_serializes(self): + from nanobot.bus.events import InboundMessage, OutboundMessage + + loop, bus = _make_loop() + order = [] + + async def mock_process(m, **kwargs): + order.append(f"start-{m.content}") + await asyncio.sleep(0.05) + order.append(f"end-{m.content}") + return OutboundMessage(channel="test", chat_id="c1", content=m.content) + + loop._process_message = mock_process + msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a") + msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b") + + t1 = asyncio.create_task(loop._dispatch(msg1)) + t2 = asyncio.create_task(loop._dispatch(msg2)) + await asyncio.gather(t1, t2) + assert order == ["start-a", "end-a", "start-b", "end-b"] + + +class TestSubagentCancellation: + @pytest.mark.asyncio + async def test_cancel_by_session(self): + from nanobot.agent.subagent import SubagentManager + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + + cancelled = asyncio.Event() + + async def slow(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancelled.set() + raise + + task = asyncio.create_task(slow()) + await asyncio.sleep(0) + mgr._running_tasks["sub-1"] = task + mgr._session_tasks["test:c1"] = {"sub-1"} + + count = await mgr.cancel_by_session("test:c1") + assert count == 1 + assert cancelled.is_set() + + @pytest.mark.asyncio + async def test_cancel_by_session_no_tasks(self): + from nanobot.agent.subagent import SubagentManager + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + assert await mgr.cancel_by_session("nonexistent") == 0 diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index f11c667..cb50fb0 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -2,6 +2,7 @@ from typing import Any from nanobot.agent.tools.base import Tool from nanobot.agent.tools.registry import ToolRegistry +from nanobot.agent.tools.shell import ExecTool class SampleTool(Tool): @@ -86,3 +87,22 @@ async def test_registry_returns_validation_error() -> None: reg.register(SampleTool()) result = await reg.execute("sample", {"query": "hi"}) assert "Invalid parameters" in result + + +def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None: + cmd = r"type C:\user\workspace\txt" + paths = ExecTool._extract_absolute_paths(cmd) + assert paths == [r"C:\user\workspace\txt"] + + +def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None: + cmd = ".venv/bin/python script.py" + paths = ExecTool._extract_absolute_paths(cmd) + assert "/bin/python" not in paths + + +def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None: + cmd = "cat /tmp/data.txt > /tmp/out.txt" + paths = ExecTool._extract_absolute_paths(cmd) + assert "/tmp/data.txt" in paths + assert "/tmp/out.txt" in paths