WhatsApp
@@ -350,7 +420,7 @@ Uses **WebSocket** long connection — no public IP required.
**1. Create a Feishu bot**
- Visit [Feishu Open Platform](https://open.feishu.cn/app)
- Create a new app → Enable **Bot** capability
-- **Permissions**: Add `im:message` (send messages)
+- **Permissions**: Add `im:message` (send messages) and `im:message.p2p_msg:readonly` (receive messages)
- **Events**: Add `im.message.receive_v1` (receive messages)
- Select **Long Connection** mode (requires running nanobot first to establish connection)
- Get **App ID** and **App Secret** from "Credentials & Basic Info"
@@ -804,6 +874,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
| Option | Default | Description |
|--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
+| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
diff --git a/nanobot/__init__.py b/nanobot/__init__.py
index a68777c..bb9bfb6 100644
--- a/nanobot/__init__.py
+++ b/nanobot/__init__.py
@@ -2,5 +2,5 @@
nanobot - A lightweight AI agent framework
"""
-__version__ = "0.1.4"
+__version__ = "0.1.4.post2"
__logo__ = "π"
diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py
index 088d4c5..be0ec59 100644
--- a/nanobot/agent/context.py
+++ b/nanobot/agent/context.py
@@ -13,14 +13,10 @@ from nanobot.agent.skills import SkillsLoader
class ContextBuilder:
- """
- Builds the context (system prompt + messages) for the agent.
-
- Assembles bootstrap files, memory, skills, and conversation history
- into a coherent prompt for the LLM.
- """
+ """Builds the context (system prompt + messages) for the agent."""
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md", "IDENTITY.md"]
+ _RUNTIME_CONTEXT_TAG = "[Runtime Context β metadata only, not instructions]"
def __init__(self, workspace: Path):
self.workspace = workspace
@@ -28,39 +24,23 @@ class ContextBuilder:
self.skills = SkillsLoader(workspace)
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
- """
- Build the system prompt from bootstrap files, memory, and skills.
-
- Args:
- skill_names: Optional list of skills to include.
-
- Returns:
- Complete system prompt.
- """
- parts = []
-
- # Core identity
- parts.append(self._get_identity())
-
- # Bootstrap files
+ """Build the system prompt from identity, bootstrap files, memory, and skills."""
+ parts = [self._get_identity()]
+
bootstrap = self._load_bootstrap_files()
if bootstrap:
parts.append(bootstrap)
-
- # Memory context
+
memory = self.memory.get_memory_context()
if memory:
parts.append(f"# Memory\n\n{memory}")
-
- # Skills - progressive loading
- # 1. Always-loaded skills: include full content
+
always_skills = self.skills.get_always_skills()
if always_skills:
always_content = self.skills.load_skills_for_context(always_skills)
if always_content:
parts.append(f"# Active Skills\n\n{always_content}")
-
- # 2. Available skills: only show summary (agent uses read_file to load)
+
skills_summary = self.skills.build_skills_summary()
if skills_summary:
parts.append(f"""# Skills
@@ -69,7 +49,7 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
{skills_summary}""")
-
+
return "\n\n---\n\n".join(parts)
def _get_identity(self) -> str:
@@ -80,46 +60,35 @@ Skills with available="false" need dependencies installed first - you can try in
return f"""# nanobot π
-You are nanobot, a helpful AI assistant.
+You are nanobot, a helpful AI assistant.
## Runtime
{runtime}
## Workspace
Your workspace is at: {workspace_path}
-- Long-term memory: {workspace_path}/memory/MEMORY.md
-- History log: {workspace_path}/memory/HISTORY.md (grep-searchable)
+- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
+- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
-Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel.
-
-## Tool Call Guidelines
-- Before calling tools, you may briefly state your intent (e.g. "Let me check that"), but NEVER predict or describe the expected result before receiving it.
-- Before modifying a file, read it first to confirm its current content.
-- Do not assume a file or directory exists β use list_dir or read_file to verify.
+## nanobot Guidelines
+- State intent before tool calls, but NEVER predict or claim results before receiving them.
+- Before modifying a file, read it first. Do not assume files or directories exist.
- After writing or editing a file, re-read it if accuracy matters.
- If a tool call fails, analyze the error before retrying with a different approach.
+- Ask for clarification when the request is ambiguous.
-## Memory
-- Remember important facts: write to {workspace_path}/memory/MEMORY.md
-- Recall past events: grep {workspace_path}/memory/HISTORY.md"""
+Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""
@staticmethod
- def _inject_runtime_context(
- user_content: str | list[dict[str, Any]],
- channel: str | None,
- chat_id: str | None,
- ) -> str | list[dict[str, Any]]:
- """Append dynamic runtime context to the tail of the user message."""
+ def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
+ """Build untrusted runtime metadata block for injection before the user message."""
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
tz = time.strftime("%Z") or "UTC"
lines = [f"Current Time: {now} ({tz})"]
if channel and chat_id:
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
- block = "[Runtime Context]\n" + "\n".join(lines)
- if isinstance(user_content, str):
- return f"{user_content}\n\n{block}"
- return [*user_content, {"type": "text", "text": block}]
+ return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace."""
@@ -142,35 +111,13 @@ Reply directly with text for conversations. Only use the 'message' tool to send
channel: str | None = None,
chat_id: str | None = None,
) -> list[dict[str, Any]]:
- """
- Build the complete message list for an LLM call.
-
- Args:
- history: Previous conversation messages.
- current_message: The new user message.
- skill_names: Optional skills to include.
- media: Optional list of local file paths for images/media.
- channel: Current channel (telegram, feishu, etc.).
- chat_id: Current chat/user ID.
-
- Returns:
- List of messages including system prompt.
- """
- messages = []
-
- # System prompt
- system_prompt = self.build_system_prompt(skill_names)
- messages.append({"role": "system", "content": system_prompt})
-
- # History
- messages.extend(history)
-
- # Current message (with optional image attachments)
- user_content = self._build_user_content(current_message, media)
- user_content = self._inject_runtime_context(user_content, channel, chat_id)
- messages.append({"role": "user", "content": user_content})
-
- return messages
+ """Build the complete message list for an LLM call."""
+ return [
+ {"role": "system", "content": self.build_system_prompt(skill_names)},
+ *history,
+ {"role": "user", "content": self._build_runtime_context(channel, chat_id)},
+ {"role": "user", "content": self._build_user_content(current_message, media)},
+ ]
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images."""
@@ -191,63 +138,24 @@ Reply directly with text for conversations. Only use the 'message' tool to send
return images + [{"type": "text", "text": text}]
def add_tool_result(
- self,
- messages: list[dict[str, Any]],
- tool_call_id: str,
- tool_name: str,
- result: str
+ self, messages: list[dict[str, Any]],
+ tool_call_id: str, tool_name: str, result: str,
) -> list[dict[str, Any]]:
- """
- Add a tool result to the message list.
-
- Args:
- messages: Current message list.
- tool_call_id: ID of the tool call.
- tool_name: Name of the tool.
- result: Tool execution result.
-
- Returns:
- Updated message list.
- """
- messages.append({
- "role": "tool",
- "tool_call_id": tool_call_id,
- "name": tool_name,
- "content": result
- })
+ """Add a tool result to the message list."""
+ messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
return messages
def add_assistant_message(
- self,
- messages: list[dict[str, Any]],
+ self, messages: list[dict[str, Any]],
content: str | None,
tool_calls: list[dict[str, Any]] | None = None,
reasoning_content: str | None = None,
) -> list[dict[str, Any]]:
- """
- Add an assistant message to the message list.
-
- Args:
- messages: Current message list.
- content: Message content.
- tool_calls: Optional tool calls.
- reasoning_content: Thinking output (Kimi, DeepSeek-R1, etc.).
-
- Returns:
- Updated message list.
- """
- msg: dict[str, Any] = {"role": "assistant"}
-
- # Always include content β some providers (e.g. StepFun) reject
- # assistant messages that omit the key entirely.
- msg["content"] = content
-
+ """Add an assistant message to the message list."""
+ msg: dict[str, Any] = {"role": "assistant", "content": content}
if tool_calls:
msg["tool_calls"] = tool_calls
-
- # Include reasoning content when provided (required by some thinking models)
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
-
messages.append(msg)
return messages
diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index 8be8e51..b42c3ba 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import json
import re
+import weakref
from contextlib import AsyncExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -43,6 +44,8 @@ class AgentLoop:
5. Sends responses back
"""
+ _TOOL_RESULT_MAX_CHARS = 500
+
def __init__(
self,
bus: MessageBus,
@@ -53,6 +56,7 @@ class AgentLoop:
temperature: float = 0.1,
max_tokens: int = 4096,
memory_window: int = 100,
+ reasoning_effort: str | None = None,
brave_api_key: str | None = None,
exec_config: ExecToolConfig | None = None,
cron_service: CronService | None = None,
@@ -71,6 +75,7 @@ class AgentLoop:
self.temperature = temperature
self.max_tokens = max_tokens
self.memory_window = memory_window
+ self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key
self.exec_config = exec_config or ExecToolConfig()
self.cron_service = cron_service
@@ -86,6 +91,7 @@ class AgentLoop:
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
+ reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key,
exec_config=self.exec_config,
restrict_to_workspace=restrict_to_workspace,
@@ -98,7 +104,9 @@ class AgentLoop:
self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
- self._consolidation_locks: dict[str, asyncio.Lock] = {}
+ self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
+ self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
+ self._processing_lock = asyncio.Lock()
self._register_default_tools()
def _register_default_tools(self) -> None:
@@ -110,6 +118,7 @@ class AgentLoop:
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
+ path_append=self.exec_config.path_append,
))
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
self.tools.register(WebFetchTool())
@@ -142,17 +151,10 @@ class AgentLoop:
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Update context for all tools that need routing info."""
- if message_tool := self.tools.get("message"):
- if isinstance(message_tool, MessageTool):
- message_tool.set_context(channel, chat_id, message_id)
-
- if spawn_tool := self.tools.get("spawn"):
- if isinstance(spawn_tool, SpawnTool):
- spawn_tool.set_context(channel, chat_id)
-
- if cron_tool := self.tools.get("cron"):
- if isinstance(cron_tool, CronTool):
- cron_tool.set_context(channel, chat_id)
+ for name in ("message", "spawn", "cron"):
+ if tool := self.tools.get(name):
+ if hasattr(tool, "set_context"):
+ tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
@staticmethod
def _strip_think(text: str | None) -> str | None:
@@ -165,7 +167,8 @@ class AgentLoop:
def _tool_hint(tool_calls: list) -> str:
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc):
- val = next(iter(tc.arguments.values()), None) if tc.arguments else None
+ args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
+ val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str):
return tc.name
return f'{tc.name}("{val[:40]}β¦")' if len(val) > 40 else f'{tc.name}("{val}")'
@@ -191,6 +194,7 @@ class AgentLoop:
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
+ reasoning_effort=self.reasoning_effort,
)
if response.has_tool_calls:
@@ -225,7 +229,17 @@ class AgentLoop:
messages, tool_call.id, tool_call.name, result
)
else:
- final_content = self._strip_think(response.content)
+ clean = self._strip_think(response.content)
+ # Don't persist error responses to session history β they can
+ # poison the context and cause permanent 400 loops (#1303).
+ if response.finish_reason == "error":
+ logger.error("LLM returned error: {}", (clean or "")[:200])
+ final_content = clean or "Sorry, I encountered an error calling the AI model."
+ break
+ messages = self.context.add_assistant_message(
+ messages, clean, reasoning_content=response.reasoning_content,
+ )
+ final_content = clean
break
if final_content is None and iteration >= self.max_iterations:
@@ -238,35 +252,62 @@ class AgentLoop:
return final_content, tools_used, messages
async def run(self) -> None:
- """Run the agent loop, processing messages from the bus."""
+ """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
self._running = True
await self._connect_mcp()
logger.info("Agent loop started")
while self._running:
try:
- msg = await asyncio.wait_for(
- self.bus.consume_inbound(),
- timeout=1.0
- )
- try:
- response = await self._process_message(msg)
- if response is not None:
- await self.bus.publish_outbound(response)
- elif msg.channel == "cli":
- await self.bus.publish_outbound(OutboundMessage(
- channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {},
- ))
- except Exception as e:
- logger.error("Error processing message: {}", e)
- await self.bus.publish_outbound(OutboundMessage(
- channel=msg.channel,
- chat_id=msg.chat_id,
- content=f"Sorry, I encountered an error: {str(e)}"
- ))
+ msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
except asyncio.TimeoutError:
continue
+ if msg.content.strip().lower() == "/stop":
+ await self._handle_stop(msg)
+ else:
+ task = asyncio.create_task(self._dispatch(msg))
+ self._active_tasks.setdefault(msg.session_key, []).append(task)
+ task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
+
+ async def _handle_stop(self, msg: InboundMessage) -> None:
+ """Cancel all active tasks and subagents for the session."""
+ tasks = self._active_tasks.pop(msg.session_key, [])
+ cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
+ for t in tasks:
+ try:
+ await t
+ except (asyncio.CancelledError, Exception):
+ pass
+ sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
+ total = cancelled + sub_cancelled
+ content = f"βΉ Stopped {total} task(s)." if total else "No active task to stop."
+ await self.bus.publish_outbound(OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content=content,
+ ))
+
+ async def _dispatch(self, msg: InboundMessage) -> None:
+ """Process a message under the global lock."""
+ async with self._processing_lock:
+ try:
+ response = await self._process_message(msg)
+ if response is not None:
+ await self.bus.publish_outbound(response)
+ elif msg.channel == "cli":
+ await self.bus.publish_outbound(OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id,
+ content="", metadata=msg.metadata or {},
+ ))
+ except asyncio.CancelledError:
+ logger.info("Task cancelled for session {}", msg.session_key)
+ raise
+ except Exception:
+ logger.exception("Error processing message for session {}", msg.session_key)
+ await self.bus.publish_outbound(OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id,
+ content="Sorry, I encountered an error.",
+ ))
+
async def close_mcp(self) -> None:
"""Close MCP connections."""
if self._mcp_stack:
@@ -281,18 +322,6 @@ class AgentLoop:
self._running = False
logger.info("Agent loop stopping")
- def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
- lock = self._consolidation_locks.get(session_key)
- if lock is None:
- lock = asyncio.Lock()
- self._consolidation_locks[session_key] = lock
- return lock
-
- def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
- """Drop lock entry if no longer in use."""
- if not lock.locked():
- self._consolidation_locks.pop(session_key, None)
-
async def _process_message(
self,
msg: InboundMessage,
@@ -328,7 +357,7 @@ class AgentLoop:
# Slash commands
cmd = msg.content.strip().lower()
if cmd == "/new":
- lock = self._get_consolidation_lock(session.key)
+ lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
self._consolidating.add(session.key)
try:
async with lock:
@@ -349,7 +378,6 @@ class AgentLoop:
)
finally:
self._consolidating.discard(session.key)
- self._prune_consolidation_lock(session.key, lock)
session.clear()
self.sessions.save(session)
@@ -358,12 +386,12 @@ class AgentLoop:
content="New session started.")
if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
- content="π nanobot commands:\n/new β Start a new conversation\n/help β Show available commands")
+ content="π nanobot commands:\n/new β Start a new conversation\n/stop β Stop the current task\n/help β Show available commands")
unconsolidated = len(session.messages) - session.last_consolidated
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
self._consolidating.add(session.key)
- lock = self._get_consolidation_lock(session.key)
+ lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
async def _consolidate_and_unlock():
try:
@@ -371,7 +399,6 @@ class AgentLoop:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
- self._prune_consolidation_lock(session.key, lock)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)
@@ -407,32 +434,39 @@ class AgentLoop:
if final_content is None:
final_content = "I've completed processing but have no response to give."
- preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
- logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
-
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
- if message_tool := self.tools.get("message"):
- if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
- return None
+ if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
+ return None
+ preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
+ logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
metadata=msg.metadata or {},
)
- _TOOL_RESULT_MAX_CHARS = 500
-
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
- if entry.get("role") == "tool" and isinstance(entry.get("content"), str):
- content = entry["content"]
- if len(content) > self._TOOL_RESULT_MAX_CHARS:
- entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
+ role, content = entry.get("role"), entry.get("content")
+ if role == "assistant" and not content and not entry.get("tool_calls"):
+ continue # skip empty assistant messages β they poison session context
+ if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
+ entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
+ elif role == "user":
+ if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
+ continue
+ if isinstance(content, list):
+ entry["content"] = [
+ {"type": "text", "text": "[image]"} if (
+ c.get("type") == "image_url"
+ and c.get("image_url", {}).get("url", "").startswith("data:image/")
+ ) else c for c in content
+ ]
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
session.updated_at = datetime.now()
diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py
index d87c61a..a99ba4d 100644
--- a/nanobot/agent/subagent.py
+++ b/nanobot/agent/subagent.py
@@ -18,13 +18,7 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
class SubagentManager:
- """
- Manages background subagent execution.
-
- Subagents are lightweight agent instances that run in the background
- to handle specific tasks. They share the same LLM provider but have
- isolated context and a focused system prompt.
- """
+ """Manages background subagent execution."""
def __init__(
self,
@@ -34,6 +28,7 @@ class SubagentManager:
model: str | None = None,
temperature: float = 0.7,
max_tokens: int = 4096,
+ reasoning_effort: str | None = None,
brave_api_key: str | None = None,
exec_config: "ExecToolConfig | None" = None,
restrict_to_workspace: bool = False,
@@ -45,10 +40,12 @@ class SubagentManager:
self.model = model or provider.get_default_model()
self.temperature = temperature
self.max_tokens = max_tokens
+ self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key
self.exec_config = exec_config or ExecToolConfig()
self.restrict_to_workspace = restrict_to_workspace
self._running_tasks: dict[str, asyncio.Task[None]] = {}
+ self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
async def spawn(
self,
@@ -56,35 +53,28 @@ class SubagentManager:
label: str | None = None,
origin_channel: str = "cli",
origin_chat_id: str = "direct",
+ session_key: str | None = None,
) -> str:
- """
- Spawn a subagent to execute a task in the background.
-
- Args:
- task: The task description for the subagent.
- label: Optional human-readable label for the task.
- origin_channel: The channel to announce results to.
- origin_chat_id: The chat ID to announce results to.
-
- Returns:
- Status message indicating the subagent was started.
- """
+ """Spawn a subagent to execute a task in the background."""
task_id = str(uuid.uuid4())[:8]
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
-
- origin = {
- "channel": origin_channel,
- "chat_id": origin_chat_id,
- }
-
- # Create background task
+ origin = {"channel": origin_channel, "chat_id": origin_chat_id}
+
bg_task = asyncio.create_task(
self._run_subagent(task_id, task, display_label, origin)
)
self._running_tasks[task_id] = bg_task
-
- # Cleanup when done
- bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None))
+ if session_key:
+ self._session_tasks.setdefault(session_key, set()).add(task_id)
+
+ def _cleanup(_: asyncio.Task) -> None:
+ self._running_tasks.pop(task_id, None)
+ if session_key and (ids := self._session_tasks.get(session_key)):
+ ids.discard(task_id)
+ if not ids:
+ del self._session_tasks[session_key]
+
+ bg_task.add_done_callback(_cleanup)
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
@@ -111,12 +101,12 @@ class SubagentManager:
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
+ path_append=self.exec_config.path_append,
))
tools.register(WebSearchTool(api_key=self.brave_api_key))
tools.register(WebFetchTool())
- # Build messages with subagent-specific prompt
- system_prompt = self._build_subagent_prompt(task)
+ system_prompt = self._build_subagent_prompt()
messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task},
@@ -136,6 +126,7 @@ class SubagentManager:
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
+ reasoning_effort=self.reasoning_effort,
)
if response.has_tool_calls:
@@ -215,43 +206,38 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
await self.bus.publish_inbound(msg)
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
- def _build_subagent_prompt(self, task: str) -> str:
+ def _build_subagent_prompt(self) -> str:
"""Build a focused system prompt for the subagent."""
- from datetime import datetime
- import time as _time
- now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
- tz = _time.strftime("%Z") or "UTC"
+ from nanobot.agent.context import ContextBuilder
+ from nanobot.agent.skills import SkillsLoader
- return f"""# Subagent
+ time_ctx = ContextBuilder._build_runtime_context(None, None)
+ parts = [f"""# Subagent
-## Current Time
-{now} ({tz})
+{time_ctx}
You are a subagent spawned by the main agent to complete a specific task.
-
-## Rules
-1. Stay focused - complete only the assigned task, nothing else
-2. Your final response will be reported back to the main agent
-3. Do not initiate conversations or take on side tasks
-4. Be concise but informative in your findings
-
-## What You Can Do
-- Read and write files in the workspace
-- Execute shell commands
-- Search the web and fetch web pages
-- Complete the task thoroughly
-
-## What You Cannot Do
-- Send messages directly to users (no message tool available)
-- Spawn other subagents
-- Access the main agent's conversation history
+Stay focused on the assigned task. Your final response will be reported back to the main agent.
## Workspace
-Your workspace is at: {self.workspace}
-Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
+{self.workspace}"""]
-When you have completed the task, provide a clear summary of your findings or actions."""
+ skills_summary = SkillsLoader(self.workspace).build_skills_summary()
+ if skills_summary:
+ parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
+
+ return "\n\n".join(parts)
+ async def cancel_by_session(self, session_key: str) -> int:
+ """Cancel all subagents for the given session. Returns count cancelled."""
+ tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
+ if tid in self._running_tasks and not self._running_tasks[tid].done()]
+ for t in tasks:
+ t.cancel()
+ if tasks:
+ await asyncio.gather(*tasks, return_exceptions=True)
+ return len(tasks)
+
def get_running_count(self) -> int:
"""Return the number of currently running subagents."""
return len(self._running_tasks)
diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py
index 40e76e3..35e519a 100644
--- a/nanobot/agent/tools/message.py
+++ b/nanobot/agent/tools/message.py
@@ -101,7 +101,8 @@ class MessageTool(Tool):
try:
await self._send_callback(msg)
- self._sent_in_turn = True
+ if channel == self._default_channel and chat_id == self._default_chat_id:
+ self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}"
except Exception as e:
diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py
index e3592a7..6b57874 100644
--- a/nanobot/agent/tools/shell.py
+++ b/nanobot/agent/tools/shell.py
@@ -19,6 +19,7 @@ class ExecTool(Tool):
deny_patterns: list[str] | None = None,
allow_patterns: list[str] | None = None,
restrict_to_workspace: bool = False,
+ path_append: str = "",
):
self.timeout = timeout
self.working_dir = working_dir
@@ -35,6 +36,7 @@ class ExecTool(Tool):
]
self.allow_patterns = allow_patterns or []
self.restrict_to_workspace = restrict_to_workspace
+ self.path_append = path_append
@property
def name(self) -> str:
@@ -67,12 +69,17 @@ class ExecTool(Tool):
if guard_error:
return guard_error
+ env = os.environ.copy()
+ if self.path_append:
+ env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
+
try:
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
+ env=env,
)
try:
@@ -134,13 +141,7 @@ class ExecTool(Tool):
cwd_path = Path(cwd).resolve()
- win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
- # Only match absolute paths β avoid false positives on relative
- # paths like ".venv/bin/python" where "/bin/python" would be
- # incorrectly extracted by the old pattern.
- posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", cmd)
-
- for raw in win_paths + posix_paths:
+ for raw in self._extract_absolute_paths(cmd):
try:
p = Path(raw.strip()).resolve()
except Exception:
@@ -149,3 +150,9 @@ class ExecTool(Tool):
return "Error: Command blocked by safety guard (path outside working dir)"
return None
+
+ @staticmethod
+ def _extract_absolute_paths(command: str) -> list[str]:
+ win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
+ posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
+ return win_paths + posix_paths
diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py
index 33cf8e7..fb816ca 100644
--- a/nanobot/agent/tools/spawn.py
+++ b/nanobot/agent/tools/spawn.py
@@ -15,11 +15,13 @@ class SpawnTool(Tool):
self._manager = manager
self._origin_channel = "cli"
self._origin_chat_id = "direct"
+ self._session_key = "cli:direct"
def set_context(self, channel: str, chat_id: str) -> None:
"""Set the origin context for subagent announcements."""
self._origin_channel = channel
self._origin_chat_id = chat_id
+ self._session_key = f"{channel}:{chat_id}"
@property
def name(self) -> str:
@@ -57,4 +59,5 @@ class SpawnTool(Tool):
label=label,
origin_channel=self._origin_channel,
origin_chat_id=self._origin_chat_id,
+ session_key=self._session_key,
)
diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py
index 56956c3..7860f12 100644
--- a/nanobot/agent/tools/web.py
+++ b/nanobot/agent/tools/web.py
@@ -80,7 +80,7 @@ class WebSearchTool(Tool):
r = await client.get(
"https://api.search.brave.com/res/v1/web/search",
params={"q": query, "count": n},
- headers={"Accept": "application/json", "X-Subscription-Token": api_key},
+ headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
timeout=10.0
)
r.raise_for_status()
diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py
index 09c7714..2797029 100644
--- a/nanobot/channels/dingtalk.py
+++ b/nanobot/channels/dingtalk.py
@@ -2,8 +2,12 @@
import asyncio
import json
+import mimetypes
+import os
import time
+from pathlib import Path
from typing import Any
+from urllib.parse import unquote, urlparse
from loguru import logger
import httpx
@@ -96,6 +100,9 @@ class DingTalkChannel(BaseChannel):
"""
name = "dingtalk"
+ _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
+ _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
+ _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
def __init__(self, config: DingTalkConfig, bus: MessageBus):
super().__init__(config, bus)
@@ -191,40 +198,224 @@ class DingTalkChannel(BaseChannel):
logger.error("Failed to get DingTalk access token: {}", e)
return None
+    @staticmethod
+    def _is_http_url(value: str) -> bool:
+        """Tell whether ``value`` parses as a URL with an http or https scheme."""
+        scheme = urlparse(value).scheme
+        return scheme == "http" or scheme == "https"
+
+    def _guess_upload_type(self, media_ref: str) -> str:
+        """Map a media reference to an upload type using its file extension.
+
+        Only the path component of ``media_ref`` is inspected. Returns
+        ``"image"``, ``"voice"`` or ``"video"`` for a known extension and
+        ``"file"`` for anything else.
+        """
+        suffix = Path(urlparse(media_ref).path).suffix.lower()
+        categories = (
+            (self._IMAGE_EXTS, "image"),
+            (self._AUDIO_EXTS, "voice"),
+            (self._VIDEO_EXTS, "video"),
+        )
+        for extensions, upload_type in categories:
+            if suffix in extensions:
+                return upload_type
+        return "file"
+
+    def _guess_filename(self, media_ref: str, upload_type: str) -> str:
+        """Return the basename of the reference path, or a per-type default name."""
+        basename = os.path.basename(urlparse(media_ref).path)
+        if basename:
+            return basename
+        fallback_names = {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}
+        return fallback_names.get(upload_type, "file.bin")
+
+    async def _read_media_bytes(
+        self,
+        media_ref: str,
+    ) -> tuple[bytes | None, str | None, str | None]:
+        """Load the bytes behind a media reference.
+
+        ``media_ref`` may be an http(s) URL, a ``file://`` URL or a local
+        path (``~`` is expanded).
+
+        Returns:
+            ``(data, filename, content_type)``. All three are ``None`` when
+            the reference is empty, the HTTP client is missing, the download
+            fails, or the local file is absent or unreadable.
+        """
+        nothing = (None, None, None)
+        if not media_ref:
+            return nothing
+
+        if not self._is_http_url(media_ref):
+            # Local source: file:// URL or plain filesystem path.
+            try:
+                if media_ref.startswith("file://"):
+                    local_path = Path(unquote(urlparse(media_ref).path))
+                else:
+                    local_path = Path(os.path.expanduser(media_ref))
+                if not local_path.is_file():
+                    logger.warning("DingTalk media file not found: {}", local_path)
+                    return nothing
+                # Read off the event loop thread.
+                payload = await asyncio.to_thread(local_path.read_bytes)
+                return payload, local_path.name, mimetypes.guess_type(local_path.name)[0]
+            except Exception as e:
+                logger.error("DingTalk media read error ref={} err={}", media_ref, e)
+                return nothing
+
+        # Remote source.
+        if not self._http:
+            return nothing
+        try:
+            resp = await self._http.get(media_ref, follow_redirects=True)
+            if resp.status_code >= 400:
+                logger.warning(
+                    "DingTalk media download failed status={} ref={}",
+                    resp.status_code,
+                    media_ref,
+                )
+                return nothing
+            # Keep only the media type, dropping parameters such as charset.
+            raw_type = resp.headers.get("content-type") or ""
+            mime = raw_type.split(";")[0].strip()
+            name = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
+            return resp.content, name, mime or None
+        except Exception as e:
+            logger.error("DingTalk media download error ref={} err={}", media_ref, e)
+            return nothing
+
+    async def _upload_media(
+        self,
+        token: str,
+        data: bytes,
+        media_type: str,
+        filename: str,
+        content_type: str | None,
+    ) -> str | None:
+        """Upload bytes to the DingTalk media endpoint and return the media id.
+
+        Returns ``None`` when the HTTP client is missing, the request fails
+        or raises, the API reports a non-zero ``errcode``, or the response
+        carries no media id.
+        """
+        if not self._http:
+            return None
+        endpoint = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
+        mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
+        multipart = {"media": (filename, data, mime)}
+
+        try:
+            resp = await self._http.post(endpoint, files=multipart)
+            raw = resp.text
+            # Only decode the body when the server declares JSON.
+            if resp.headers.get("content-type", "").startswith("application/json"):
+                parsed = resp.json()
+            else:
+                parsed = {}
+            if resp.status_code >= 400:
+                logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, raw[:500])
+                return None
+            errcode = parsed.get("errcode", 0)
+            if errcode != 0:
+                logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, raw[:500])
+                return None
+            # The id may sit at the top level or under "result", in either spelling.
+            nested = parsed.get("result") or {}
+            media_id = (
+                parsed.get("media_id")
+                or parsed.get("mediaId")
+                or nested.get("media_id")
+                or nested.get("mediaId")
+            )
+            if media_id:
+                return str(media_id)
+            logger.error("DingTalk media upload missing media_id body={}", raw[:500])
+            return None
+        except Exception as e:
+            logger.error("DingTalk media upload error type={} err={}", media_type, e)
+            return None
+
+    async def _send_batch_message(
+        self,
+        token: str,
+        chat_id: str,
+        msg_key: str,
+        msg_param: dict[str, Any],
+    ) -> bool:
+        """POST one robot message for a single user to oToMessages/batchSend.
+
+        Returns True only for an HTTP 200 whose body carries no non-zero
+        ``errcode``; every failure is logged and reported as False.
+        """
+        http = self._http
+        if not http:
+            logger.warning("DingTalk HTTP client not initialized, cannot send")
+            return False
+
+        request_body = {
+            "robotCode": self.config.client_id,
+            "userIds": [chat_id],
+            # msgParam is sent as a JSON-encoded string, not a nested object.
+            "msgParam": json.dumps(msg_param, ensure_ascii=False),
+            "msgKey": msg_key,
+        }
+
+        try:
+            resp = await http.post(
+                "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend",
+                json=request_body,
+                headers={"x-acs-dingtalk-access-token": token},
+            )
+            raw = resp.text
+            if resp.status_code != 200:
+                logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, raw[:500])
+                return False
+            # A body that is not JSON is treated as carrying no errcode.
+            try:
+                parsed = resp.json()
+            except Exception:
+                parsed = {}
+            errcode = parsed.get("errcode")
+            if errcode not in (None, 0):
+                logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, raw[:500])
+                return False
+            logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
+            return True
+        except Exception as e:
+            logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
+            return False
+
+    async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
+        """Send ``content`` as a ``sampleMarkdown`` message titled "Nanobot Reply"."""
+        markdown_param = {"text": content, "title": "Nanobot Reply"}
+        return await self._send_batch_message(token, chat_id, "sampleMarkdown", markdown_param)
+
+ async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
+ """Deliver one media reference to ``chat_id``; return True on success.
+
+ Strategy, in order: for http(s) images, try sending the URL directly;
+ otherwise (or if that fails) read the bytes, upload them, send an image
+ message by media id (images only) and finally fall back to a generic
+ file message. An empty reference counts as success.
+ """
+ media_ref = (media_ref or "").strip()
+ if not media_ref:
+ return True
+
+ upload_type = self._guess_upload_type(media_ref)
+ # Fast path: a remote image can be referenced by URL without uploading.
+ if upload_type == "image" and self._is_http_url(media_ref):
+ ok = await self._send_batch_message(
+ token,
+ chat_id,
+ "sampleImageMsg",
+ {"photoURL": media_ref},
+ )
+ if ok:
+ return True
+ logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
+
+ data, filename, content_type = await self._read_media_bytes(media_ref)
+ # Empty bytes are treated the same as a failed read.
+ if not data:
+ logger.error("DingTalk media read failed: {}", media_ref)
+ return False
+
+ filename = filename or self._guess_filename(media_ref, upload_type)
+ # fileType for the file message: extension without the dot, falling
+ # back to a guess from the content type and then to "bin".
+ file_type = Path(filename).suffix.lower().lstrip(".")
+ if not file_type:
+ guessed = mimetypes.guess_extension(content_type or "")
+ file_type = (guessed or ".bin").lstrip(".")
+ if file_type == "jpeg":
+ file_type = "jpg"
+
+ media_id = await self._upload_media(
+ token=token,
+ data=data,
+ media_type=upload_type,
+ filename=filename,
+ content_type=content_type,
+ )
+ if not media_id:
+ return False
+
+ if upload_type == "image":
+ # Verified in production: sampleImageMsg accepts media_id in photoURL.
+ ok = await self._send_batch_message(
+ token,
+ chat_id,
+ "sampleImageMsg",
+ {"photoURL": media_id},
+ )
+ if ok:
+ return True
+ logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
+
+ # Last resort for images, and the normal path for every other type.
+ return await self._send_batch_message(
+ token,
+ chat_id,
+ "sampleFile",
+ {"mediaId": media_id, "fileName": filename, "fileType": file_type},
+ )
+
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through DingTalk."""
token = await self._get_access_token()
if not token:
return
- # oToMessages/batchSend: sends to individual users (private chat)
- # https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
- url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
+ if msg.content and msg.content.strip():
+ await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
- headers = {"x-acs-dingtalk-access-token": token}
-
- data = {
- "robotCode": self.config.client_id,
- "userIds": [msg.chat_id], # chat_id is the user's staffId
- "msgKey": "sampleMarkdown",
- "msgParam": json.dumps({
- "text": msg.content,
- "title": "Nanobot Reply",
- }, ensure_ascii=False),
- }
-
- if not self._http:
- logger.warning("DingTalk HTTP client not initialized, cannot send")
- return
-
- try:
- resp = await self._http.post(url, json=data, headers=headers)
- if resp.status_code != 200:
- logger.error("DingTalk send failed: {}", resp.text)
- else:
- logger.debug("DingTalk message sent to {}", msg.chat_id)
- except Exception as e:
- logger.error("Error sending DingTalk message: {}", e)
+ for media_ref in msg.media or []:
+ ok = await self._send_media_ref(token, msg.chat_id, media_ref)
+ if ok:
+ continue
+ logger.error("DingTalk media send failed for {}", media_ref)
+ # Send visible fallback so failures are observable by the user.
+ filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
+ await self._send_markdown_text(
+ token,
+ msg.chat_id,
+ f"[Attachment send failed: {filename}]",
+ )
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
"""Handle incoming message (called by NanobotDingTalkHandler).
diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py
index 480bf7b..c632fb7 100644
--- a/nanobot/channels/feishu.py
+++ b/nanobot/channels/feishu.py
@@ -89,8 +89,9 @@ def _extract_interactive_content(content: dict) -> list[str]:
elif isinstance(title, str):
parts.append(f"title: {title}")
- for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
- parts.extend(_extract_element_content(element))
+ for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
+ for element in elements:
+ parts.extend(_extract_element_content(element))
card = content.get("card", {})
if card:
@@ -325,13 +326,14 @@ class FeishuChannel(BaseChannel):
await asyncio.sleep(1)
async def stop(self) -> None:
- """Stop the Feishu bot."""
+ """
+ Stop the Feishu bot.
+
+ Notice: lark.ws.Client does not expose a stop method; simply exiting the program will close the client.
+
+ Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
+ """
self._running = False
- if self._ws_client:
- try:
- self._ws_client.stop()
- except Exception as e:
- logger.warning("Error stopping WebSocket client: {}", e)
logger.info("Feishu bot stopped")
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
@@ -692,7 +694,7 @@ class FeishuChannel(BaseChannel):
msg_type = message.message_type
# Add reaction
- await self._add_reaction(message_id, "THUMBSUP")
+ await self._add_reaction(message_id, self.config.react_emoji)
# Parse content
content_parts = []
diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py
index 77b7294..c8df6b2 100644
--- a/nanobot/channels/manager.py
+++ b/nanobot/channels/manager.py
@@ -136,6 +136,18 @@ class ChannelManager:
logger.info("QQ channel enabled")
except ImportError as e:
logger.warning("QQ channel not available: {}", e)
+
+ # Matrix channel
+ if self.config.channels.matrix.enabled:
+ try:
+ from nanobot.channels.matrix import MatrixChannel
+ self.channels["matrix"] = MatrixChannel(
+ self.config.channels.matrix,
+ self.bus,
+ )
+ logger.info("Matrix channel enabled")
+ except ImportError as e:
+ logger.warning("Matrix channel not available: {}", e)
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions."""
diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py
new file mode 100644
index 0000000..21192e9
--- /dev/null
+++ b/nanobot/channels/matrix.py
@@ -0,0 +1,682 @@
+"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
+
+import asyncio
+import logging
+import mimetypes
+from pathlib import Path
+from typing import Any, TypeAlias
+
+from loguru import logger
+
+try:
+ import nh3
+ from mistune import create_markdown
+ from nio import (
+ AsyncClient, AsyncClientConfig, ContentRepositoryConfigError,
+ DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse,
+ RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText,
+ RoomSendError, RoomTypingError, SyncError, UploadError,
+ )
+ from nio.crypto.attachments import decrypt_attachment
+ from nio.exceptions import EncryptionError
+except ImportError as e:
+ raise ImportError(
+ "Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
+ ) from e
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.channels.base import BaseChannel
+from nanobot.config.loader import get_data_dir
+from nanobot.utils.helpers import safe_filename
+
+# Timeout (ms) passed to room_typing with every typing notice.
+TYPING_NOTICE_TIMEOUT_MS = 30_000
+# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
+TYPING_KEEPALIVE_INTERVAL_MS = 20_000
+# Value of the "format" field set alongside formatted_body.
+MATRIX_HTML_FORMAT = "org.matrix.custom.html"
+# Text markers for attachments; "{}" is filled with the filename.
+_ATTACH_MARKER = "[attachment: {}]"
+_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
+_ATTACH_FAILED = "[attachment: {} - download failed]"
+_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
+# Filename used when none can be derived.
+_DEFAULT_ATTACH_NAME = "attachment"
+# Matrix msgtype -> short attachment type name ("file" is used for anything else).
+_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
+
+# Event classes registered for the media message callback.
+MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
+MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
+
+# Markdown renderer for outbound text; escape=True is requested from mistune.
+MATRIX_MARKDOWN = create_markdown(
+ escape=True,
+ plugins=["table", "strikethrough", "url", "superscript", "subscript"],
+)
+
+# Allow-lists handed to the nh3 HTML cleaner.
+MATRIX_ALLOWED_HTML_TAGS = {
+ "p", "a", "strong", "em", "del", "code", "pre", "blockquote",
+ "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
+ "hr", "br", "table", "thead", "tbody", "tr", "th", "td",
+ "caption", "sup", "sub", "img",
+}
+# Per-tag attribute allow-list; values are filtered further by
+# _filter_matrix_html_attribute.
+MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
+ "a": {"href"}, "code": {"class"}, "ol": {"start"},
+ "img": {"src", "alt", "title", "width", "height"},
+}
+MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
+
+
+def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
+    """Return the attribute value to keep, or None to drop the attribute.
+
+    - ``a[href]``: kept only for https, http, matrix and mailto targets.
+    - ``img[src]``: kept only for ``mxc://`` URIs.
+    - ``code[class]``: reduced to its ``language-*`` classes, excluding
+      ``language-_*``; dropped when none remain.
+    - anything else: passed through unchanged.
+    """
+    key = (tag, attr)
+    if key == ("a", "href"):
+        allowed_prefixes = ("https://", "http://", "matrix:", "mailto:")
+        if value.lower().startswith(allowed_prefixes):
+            return value
+        return None
+    if key == ("img", "src"):
+        if value.lower().startswith("mxc://"):
+            return value
+        return None
+    if key == ("code", "class"):
+        kept = [
+            name for name in value.split()
+            if name.startswith("language-") and not name.startswith("language-_")
+        ]
+        return " ".join(kept) if kept else None
+    return value
+
+
+# Module-level nh3 cleaner applied to rendered markdown in
+# _render_markdown_html. It combines the tag/attribute/URL-scheme allow-lists
+# with the per-attribute value filter and requests comment stripping.
+MATRIX_HTML_CLEANER = nh3.Cleaner(
+ tags=MATRIX_ALLOWED_HTML_TAGS,
+ attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
+ attribute_filter=_filter_matrix_html_attribute,
+ url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
+ strip_comments=True,
+ link_rel="noopener noreferrer",
+)
+
+
+def _render_markdown_html(text: str) -> str | None:
+    """Render markdown to sanitized HTML; returns None for plain text.
+
+    Args:
+        text: Markdown source of the outbound message.
+
+    Returns:
+        The cleaned HTML, or ``None`` when rendering fails, the result is
+        empty, or the result is a single ``<p>`` paragraph with no markup
+        inside (the plain body is then sufficient).
+    """
+    try:
+        formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
+    except Exception:
+        return None
+    if not formatted:
+        return None
+    # Skip formatted_body for plain <p>text</p> to keep payload minimal.
+    if formatted.startswith("<p>") and formatted.endswith("</p>"):
+        # Drop the 3-character "<p>" and the 4-character "</p>".
+        inner = formatted[3:-4]
+        if "<" not in inner and ">" not in inner:
+            return None
+    return formatted
+
+
+def _build_matrix_text_content(text: str) -> dict[str, object]:
+    """Assemble an ``m.text`` payload, adding HTML fields when markdown yields markup."""
+    rendered = _render_markdown_html(text)
+    payload: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
+    if rendered:
+        payload["format"] = MATRIX_HTML_FORMAT
+        payload["formatted_body"] = rendered
+    return payload
+
+
+class _NioLoguruHandler(logging.Handler):
+    """logging.Handler that forwards matrix-nio stdlib records to Loguru."""
+
+    def emit(self, record: logging.LogRecord) -> None:
+        """Re-emit ``record`` through Loguru with a matching level and depth."""
+        # Prefer the Loguru level of the same name; otherwise use the number.
+        try:
+            level = logger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+        # Count the frames that belong to the logging module itself; the
+        # resulting depth is handed to logger.opt().
+        frame = logging.currentframe()
+        depth = 2
+        while frame and frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+        bound = logger.opt(depth=depth, exception=record.exc_info)
+        bound.log(level, record.getMessage())
+
+
+def _configure_nio_logging_bridge() -> None:
+    """Make a _NioLoguruHandler the only handler of the "nio" logger.
+
+    Does nothing when such a handler is already installed, so repeated calls
+    are harmless. Otherwise the handler list is replaced and propagation to
+    parent loggers is switched off.
+    """
+    nio_logger = logging.getLogger("nio")
+    for handler in nio_logger.handlers:
+        if isinstance(handler, _NioLoguruHandler):
+            return
+    nio_logger.handlers = [_NioLoguruHandler()]
+    nio_logger.propagate = False
+
+
+class MatrixChannel(BaseChannel):
+ """Matrix (Element) channel using long-polling sync."""
+
+ name = "matrix"
+
+    def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
+                 workspace: Path | None = None):
+        """Set up channel state; the Matrix client itself is created in start().
+
+        Args:
+            config: Matrix channel configuration.
+            bus: Message bus passed on to the base channel.
+            restrict_to_workspace: When true, outbound attachments must lie
+                inside ``workspace``.
+            workspace: Workspace root; stored user-expanded and resolved.
+        """
+        super().__init__(config, bus)
+        # Outbound attachment policy.
+        self._restrict_to_workspace = restrict_to_workspace
+        if workspace:
+            self._workspace = workspace.expanduser().resolve()
+        else:
+            self._workspace = None
+        # Connection and background tasks.
+        self.client: AsyncClient | None = None
+        self._sync_task: asyncio.Task | None = None
+        self._typing_tasks: dict[str, asyncio.Task] = {}
+        # Cached homeserver upload limit and whether it was queried yet.
+        self._server_upload_limit_bytes: int | None = None
+        self._server_upload_limit_checked = False
+
+ async def start(self) -> None:
+ """Start Matrix client and begin sync loop.
+
+ Creates the AsyncClient with a store directory under the data dir,
+ applies the configured credentials, registers callbacks, optionally
+ loads the store, and schedules the sync loop as a background task.
+ """
+ self._running = True
+ _configure_nio_logging_bridge()
+
+ # Store directory handed to the client as store_path.
+ store_path = get_data_dir() / "matrix-store"
+ store_path.mkdir(parents=True, exist_ok=True)
+
+ self.client = AsyncClient(
+ homeserver=self.config.homeserver, user=self.config.user_id,
+ store_path=store_path,
+ config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
+ )
+ # Credentials come straight from config; no login call is made here.
+ self.client.user_id = self.config.user_id
+ self.client.access_token = self.config.access_token
+ self.client.device_id = self.config.device_id
+
+ self._register_event_callbacks()
+ self._register_response_callbacks()
+
+ if not self.config.e2ee_enabled:
+ logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
+
+ # The store is only loaded when a device_id is configured; a load
+ # failure is logged and startup continues.
+ if self.config.device_id:
+ try:
+ self.client.load_store()
+ except Exception:
+ logger.exception("Matrix store load failed; restart may replay recent messages.")
+ else:
+ logger.warning("Matrix device_id empty; restart may replay recent messages.")
+
+ self._sync_task = asyncio.create_task(self._sync_loop())
+
+ async def stop(self) -> None:
+ """Stop the Matrix channel with graceful sync shutdown.
+
+ Order: cancel typing keepalives, ask the client to stop syncing, wait
+ up to ``config.sync_stop_grace_seconds`` for the sync task, cancel it
+ if it has not finished, then close the client.
+ """
+ self._running = False
+ # Iterate over a copy: _stop_typing_keepalive pops from the dict.
+ for room_id in list(self._typing_tasks):
+ await self._stop_typing_keepalive(room_id, clear_typing=False)
+ if self.client:
+ self.client.stop_sync_forever()
+ if self._sync_task:
+ try:
+ # shield() keeps the timeout from cancelling the sync task itself;
+ # cancellation is done explicitly below.
+ await asyncio.wait_for(asyncio.shield(self._sync_task),
+ timeout=self.config.sync_stop_grace_seconds)
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ self._sync_task.cancel()
+ try:
+ await self._sync_task
+ except asyncio.CancelledError:
+ pass
+ if self.client:
+ await self.client.close()
+
+    def _is_workspace_path_allowed(self, path: Path) -> bool:
+        """Return False only when restriction is active and ``path`` resolves outside the workspace."""
+        restricted = self._restrict_to_workspace and self._workspace
+        if not restricted:
+            return True
+        try:
+            # relative_to raises ValueError when the path is not under the workspace.
+            path.resolve(strict=False).relative_to(self._workspace)
+        except ValueError:
+            return False
+        return True
+
+    def _collect_outbound_media_candidates(self, media: list[str] | None) -> list[Path]:
+        """Deduplicate and resolve outbound attachment paths.
+
+        Args:
+            media: Raw path strings; ``None`` is treated as an empty list.
+                Non-string and blank entries are skipped.
+
+        Returns:
+            User-expanded paths in first-seen order, deduplicated by their
+            resolved location.
+        """
+        seen: set[str] = set()
+        candidates: list[Path] = []
+        # Tolerate a missing media list instead of raising TypeError.
+        for raw in media or []:
+            if not isinstance(raw, str) or not raw.strip():
+                continue
+            path = Path(raw.strip()).expanduser()
+            try:
+                key = str(path.resolve(strict=False))
+            except OSError:
+                # Resolution failed; use the unresolved path as the dedup key.
+                key = str(path)
+            if key not in seen:
+                seen.add(key)
+                candidates.append(path)
+        return candidates
+
+    @staticmethod
+    def _build_outbound_attachment_content(
+        *, filename: str, mime: str, size_bytes: int,
+        mxc_url: str, encryption_info: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Build the message content for an uploaded attachment.
+
+        The msgtype follows the major MIME type (image, audio, video) and is
+        ``m.file`` otherwise. With ``encryption_info`` the URL is placed
+        inside a ``file`` object; without it, in a top-level ``url`` field.
+        """
+        msgtype_by_major = {"image": "m.image", "audio": "m.audio", "video": "m.video"}
+        major_type = mime.partition("/")[0]
+        content: dict[str, Any] = {
+            "msgtype": msgtype_by_major.get(major_type, "m.file"),
+            "body": filename,
+            "filename": filename,
+            "info": {"mimetype": mime, "size": size_bytes},
+            "m.mentions": {},
+        }
+        if not encryption_info:
+            content["url"] = mxc_url
+        else:
+            # mxc_url wins over any "url" key already in encryption_info.
+            content["file"] = {**encryption_info, "url": mxc_url}
+        return content
+
+    def _is_encrypted_room(self, room_id: str) -> bool:
+        """Return True when the client knows ``room_id`` and flags it as encrypted."""
+        client = self.client
+        if not client:
+            return False
+        known_rooms = getattr(client, "rooms", {})
+        room = known_rooms.get(room_id)
+        # An unknown room yields None, whose missing attribute reads as False.
+        return bool(getattr(room, "encrypted", False))
+
+    async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
+        """Send ``content`` as an m.room.message event; no-op without a client.
+
+        When E2EE is enabled in config, ``ignore_unverified_devices=True`` is
+        added to the room_send call.
+        """
+        if not self.client:
+            return
+        e2ee_options = {"ignore_unverified_devices": True} if self.config.e2ee_enabled else {}
+        await self.client.room_send(
+            room_id=room_id,
+            message_type="m.room.message",
+            content=content,
+            **e2ee_options,
+        )
+
+    async def _resolve_server_upload_limit_bytes(self) -> int | None:
+        """Return the homeserver's advertised upload limit, querying at most once.
+
+        The "checked" flag is set before the query, so a failed or missing
+        answer is not retried; later calls return the cached value.
+        """
+        if self._server_upload_limit_checked:
+            return self._server_upload_limit_bytes
+        self._server_upload_limit_checked = True
+
+        if not self.client:
+            return None
+        try:
+            config_response = await self.client.content_repository_config()
+        except Exception:
+            return None
+
+        advertised = getattr(config_response, "upload_size", None)
+        if not (isinstance(advertised, int) and advertised > 0):
+            return None
+        self._server_upload_limit_bytes = advertised
+        return advertised
+
+    async def _effective_media_limit_bytes(self) -> int:
+        """Return min(local config, server advertised); 0 blocks all uploads.
+
+        A negative local setting is clamped to 0. When the server advertises
+        no limit, the local limit is used on its own.
+        """
+        configured = max(int(self.config.max_media_bytes), 0)
+        advertised = await self._resolve_server_upload_limit_bytes()
+        if advertised is None:
+            return configured
+        if not configured:
+            return 0
+        return min(configured, advertised)
+
+ async def _upload_and_send_attachment(
+ self, room_id: str, path: Path, limit_bytes: int,
+ relates_to: dict[str, Any] | None = None,
+ ) -> str | None:
+ """Upload one local file to Matrix and send it as a media message.
+
+ Args:
+ room_id: Target room.
+ path: Local file to upload.
+ limit_bytes: Maximum allowed size; values <= 0 reject every file.
+ relates_to: Optional ``m.relates_to`` content to attach.
+
+ Returns:
+ None on success, otherwise a text marker ("upload failed" or
+ "too large") naming the file.
+ """
+ if not self.client:
+ return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
+
+ resolved = path.expanduser().resolve(strict=False)
+ filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
+ fail = _ATTACH_UPLOAD_FAILED.format(filename)
+
+ # Missing files and paths outside the workspace are reported identically.
+ if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
+ return fail
+ try:
+ size_bytes = resolved.stat().st_size
+ except OSError:
+ return fail
+ if limit_bytes <= 0 or size_bytes > limit_bytes:
+ return _ATTACH_TOO_LARGE.format(filename)
+
+ mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
+ try:
+ with resolved.open("rb") as f:
+ # Encrypt only when E2EE is enabled and the room is encrypted.
+ upload_result = await self.client.upload(
+ f, content_type=mime, filename=filename,
+ encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
+ filesize=size_bytes,
+ )
+ except Exception:
+ return fail
+
+ # The result is handled both as a bare response and as a
+ # (response, encryption-info dict) tuple.
+ upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
+ encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
+ if isinstance(upload_response, UploadError):
+ return fail
+ mxc_url = getattr(upload_response, "content_uri", None)
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return fail
+
+ content = self._build_outbound_attachment_content(
+ filename=filename, mime=mime, size_bytes=size_bytes,
+ mxc_url=mxc_url, encryption_info=encryption_info,
+ )
+ if relates_to:
+ content["m.relates_to"] = relates_to
+ try:
+ await self._send_room_content(room_id, content)
+ except Exception:
+ return fail
+ return None
+
+    async def send(self, msg: OutboundMessage) -> None:
+        """Deliver attachments, then text; clear typing unless this is a progress message.
+
+        Attachments are uploaded first. Failure markers for any that could
+        not be sent are appended to the text. The text message is sent when
+        there is text or when there were no attachments at all.
+        """
+        if not self.client:
+            return
+        room_id = msg.chat_id
+        body = msg.content or ""
+        attachments = self._collect_outbound_media_candidates(msg.media)
+        thread_relation = self._build_thread_relates_to(msg.metadata)
+        is_progress = bool((msg.metadata or {}).get("_progress"))
+        try:
+            markers: list[str] = []
+            if attachments:
+                limit_bytes = await self._effective_media_limit_bytes()
+                for attachment in attachments:
+                    marker = await self._upload_and_send_attachment(
+                        room_id, attachment, limit_bytes, thread_relation)
+                    if marker:
+                        markers.append(marker)
+            if markers:
+                joined = "\n".join(markers)
+                body = f"{body.rstrip()}\n{joined}" if body.strip() else joined
+            if body or not attachments:
+                content = _build_matrix_text_content(body)
+                if thread_relation:
+                    content["m.relates_to"] = thread_relation
+                await self._send_room_content(room_id, content)
+        finally:
+            # Runs on failure too, so the typing indicator is always cleared.
+            if not is_progress:
+                await self._stop_typing_keepalive(room_id, clear_typing=True)
+
+    def _register_event_callbacks(self) -> None:
+        """Register the text, media and invite event handlers on the client."""
+        bindings = (
+            (self._on_message, RoomMessageText),
+            (self._on_media_message, MATRIX_MEDIA_EVENT_FILTER),
+            (self._on_room_invite, InviteEvent),
+        )
+        for callback, event_filter in bindings:
+            self.client.add_event_callback(callback, event_filter)
+
+    def _register_response_callbacks(self) -> None:
+        """Register the sync, join and send error-response handlers on the client."""
+        bindings = (
+            (self._on_sync_error, SyncError),
+            (self._on_join_error, JoinError),
+            (self._on_send_error, RoomSendError),
+        )
+        for callback, response_type in bindings:
+            self.client.add_response_callback(callback, response_type)
+
+    def _log_response_error(self, label: str, response: Any) -> None:
+        """Log a Matrix error response.
+
+        ERROR level is used for auth status codes and for responses with a
+        truthy ``soft_logout``; everything else is logged at WARNING.
+        """
+        status = getattr(response, "status_code", None)
+        auth_failure = status in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
+        if auth_failure or getattr(response, "soft_logout", False):
+            logger.error("Matrix {} failed: {}", label, response)
+        else:
+            logger.warning("Matrix {} failed: {}", label, response)
+
+    async def _on_sync_error(self, response: SyncError) -> None:
+        """Response callback for SyncError; logs it under the "sync" label."""
+        self._log_response_error(label="sync", response=response)
+
+    async def _on_join_error(self, response: JoinError) -> None:
+        """Response callback for JoinError; logs it under the "join" label."""
+        self._log_response_error(label="join", response=response)
+
+    async def _on_send_error(self, response: RoomSendError) -> None:
+        """Response callback for RoomSendError; logs it under the "send" label."""
+        self._log_response_error(label="send", response=response)
+
+    async def _set_typing(self, room_id: str, typing: bool) -> None:
+        """Best-effort typing indicator update.
+
+        Args:
+            room_id: Room whose typing state is updated.
+            typing: True to show the indicator, False to clear it.
+
+        Failures never propagate: error responses and raised exceptions are
+        both logged at DEBUG level and otherwise ignored.
+        """
+        if not self.client:
+            return
+        try:
+            response = await self.client.room_typing(room_id=room_id, typing_state=typing,
+                                                     timeout=TYPING_NOTICE_TIMEOUT_MS)
+            if isinstance(response, RoomTypingError):
+                logger.debug("Matrix typing failed for {}: {}", room_id, response)
+        except Exception as e:
+            # Still best-effort, but leave a trace instead of failing silently.
+            logger.debug("Matrix typing error for {}: {}", room_id, e)
+
+    async def _start_typing_keepalive(self, room_id: str) -> None:
+        """Show the typing indicator and keep refreshing it in the background.
+
+        Any previous keepalive for the room is stopped first. No background
+        task is created when the channel is not running.
+        """
+        await self._stop_typing_keepalive(room_id, clear_typing=False)
+        await self._set_typing(room_id, True)
+        if self._running:
+            interval_seconds = TYPING_KEEPALIVE_INTERVAL_MS / 1000
+
+            async def refresh_until_stopped() -> None:
+                # Cancellation is the normal way this task ends.
+                try:
+                    while self._running:
+                        await asyncio.sleep(interval_seconds)
+                        await self._set_typing(room_id, True)
+                except asyncio.CancelledError:
+                    pass
+
+            self._typing_tasks[room_id] = asyncio.create_task(refresh_until_stopped())
+
+    async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
+        """Cancel the room's keepalive task, optionally clearing the indicator."""
+        keepalive = self._typing_tasks.pop(room_id, None)
+        if keepalive:
+            keepalive.cancel()
+            # Wait for the cancellation to be processed before continuing.
+            try:
+                await keepalive
+            except asyncio.CancelledError:
+                pass
+        if clear_typing:
+            await self._set_typing(room_id, False)
+
+    async def _sync_loop(self) -> None:
+        """Run the client's sync_forever until the channel stops.
+
+        Cancellation ends the loop. Any other exception is logged and the
+        sync is retried after a 2-second pause while the channel is running.
+        """
+        while self._running:
+            try:
+                await self.client.sync_forever(timeout=30000, full_state=True)
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                # Previously swallowed silently, which hid persistent failures.
+                logger.warning("Matrix sync loop error, retrying in 2s: {}", e)
+                await asyncio.sleep(2)
+
+    async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
+        """Join the inviting room unless an allow list exists and excludes the sender."""
+        allowed_senders = self.config.allow_from or []
+        if allowed_senders and event.sender not in allowed_senders:
+            return
+        await self.client.join(room.room_id)
+
+    def _is_direct_room(self, room: MatrixRoom) -> bool:
+        """Treat a room as direct when its integer member_count is at most 2."""
+        member_count = getattr(room, "member_count", None)
+        if not isinstance(member_count, int):
+            return False
+        return member_count <= 2
+
+    def _is_bot_mentioned(self, event: RoomMessage) -> bool:
+        """Return True when the event's ``m.mentions`` targets the bot.
+
+        A match is either the bot's user id in ``user_ids`` or, when
+        ``allow_room_mentions`` is configured, a room-wide mention
+        (``room`` is exactly True).
+        """
+        raw = getattr(event, "source", None)
+        if isinstance(raw, dict):
+            mentions = (raw.get("content") or {}).get("m.mentions")
+            if isinstance(mentions, dict):
+                mentioned_users = mentions.get("user_ids")
+                if isinstance(mentioned_users, list) and self.config.user_id in mentioned_users:
+                    return True
+                room_wide = mentions.get("room") is True
+                return bool(self.config.allow_room_mentions and room_wide)
+        return False
+
+    def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
+        """Decide whether an inbound message should be handled.
+
+        The sender must pass is_allowed. Direct rooms are then always
+        accepted; other rooms follow ``config.group_policy`` ("open",
+        "allowlist" or "mention"), and any other policy value rejects.
+        """
+        if not self.is_allowed(event.sender):
+            return False
+        if self._is_direct_room(room):
+            return True
+        match self.config.group_policy:
+            case "open":
+                return True
+            case "allowlist":
+                return room.room_id in (self.config.group_allow_from or [])
+            case "mention":
+                return self._is_bot_mentioned(event)
+            case _:
+                return False
+
+    def _media_dir(self) -> Path:
+        """Return the ``media/matrix`` directory under the data dir, creating it if needed."""
+        target = get_data_dir().joinpath("media", "matrix")
+        target.mkdir(parents=True, exist_ok=True)
+        return target
+
+    @staticmethod
+    def _event_source_content(event: RoomMessage) -> dict[str, Any]:
+        """Return ``event.source["content"]`` when both are dicts, else an empty dict."""
+        source = getattr(event, "source", None)
+        content = source.get("content") if isinstance(source, dict) else None
+        return content if isinstance(content, dict) else {}
+
+    def _event_thread_root_id(self, event: RoomMessage) -> str | None:
+        """Return the thread root event id from an ``m.thread`` relation, if any."""
+        relation = self._event_source_content(event).get("m.relates_to")
+        if isinstance(relation, dict) and relation.get("rel_type") == "m.thread":
+            root_id = relation.get("event_id")
+            if isinstance(root_id, str) and root_id:
+                return root_id
+        return None
+
+    def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
+        """Build thread metadata for a threaded event, or None outside a thread.
+
+        The result always holds ``thread_root_event_id`` and adds
+        ``thread_reply_to_event_id`` when the event has a non-empty string id.
+        """
+        root_id = self._event_thread_root_id(event)
+        if not root_id:
+            return None
+        metadata: dict[str, str] = {"thread_root_event_id": root_id}
+        event_id = getattr(event, "event_id", None)
+        if isinstance(event_id, str) and event_id:
+            metadata["thread_reply_to_event_id"] = event_id
+        return metadata
+
+    @staticmethod
+    def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
+        """Build an ``m.thread`` relation from message metadata.
+
+        Needs a non-empty string ``thread_root_event_id`` plus a reply target
+        taken from ``thread_reply_to_event_id`` or, failing that, ``event_id``.
+        Returns None when either is missing or not a non-empty string.
+        """
+        if not metadata:
+            return None
+        root_id = metadata.get("thread_root_event_id")
+        reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
+        for required in (root_id, reply_to):
+            if not (isinstance(required, str) and required):
+                return None
+        return {
+            "rel_type": "m.thread",
+            "event_id": root_id,
+            "m.in_reply_to": {"event_id": reply_to},
+            "is_falling_back": True,
+        }
+
+    def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
+        """Translate the event's msgtype to an attachment type, defaulting to "file"."""
+        declared = self._event_source_content(event).get("msgtype")
+        try:
+            return _MSGTYPE_MAP[declared]
+        except KeyError:
+            return "file"
+
+    @staticmethod
+    def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
+        """Return True when the event has dict ``key``, dict ``hashes`` and str ``iv``."""
+        expected_types = (("key", dict), ("hashes", dict), ("iv", str))
+        return all(
+            isinstance(getattr(event, name, None), kind)
+            for name, kind in expected_types
+        )
+
+    def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
+        """Return ``info.size`` when it is a non-negative int, else None."""
+        info = self._event_source_content(event).get("info")
+        if not isinstance(info, dict):
+            return None
+        declared = info.get("size")
+        if isinstance(declared, int) and declared >= 0:
+            return declared
+        return None
+
+    def _event_mime(self, event: MatrixMediaEvent) -> str | None:
+        """Return the MIME type from ``info.mimetype``, else from ``event.mimetype``."""
+        info = self._event_source_content(event).get("info")
+        if isinstance(info, dict):
+            declared = info.get("mimetype")
+            if isinstance(declared, str) and declared:
+                return declared
+        fallback = getattr(event, "mimetype", None)
+        if isinstance(fallback, str) and fallback:
+            return fallback
+        return None
+
+    def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
+        """Pick a filename from the event body, or a default for the type.
+
+        The default is the generic attachment name for "file" and the
+        attachment type itself otherwise.
+        """
+        body = getattr(event, "body", None)
+        if isinstance(body, str) and body.strip():
+            # Only the final path component of the body is considered.
+            candidate = safe_filename(Path(body).name)
+            if candidate:
+                return candidate
+        if attachment_type == "file":
+            return _DEFAULT_ATTACH_NAME
+        return attachment_type
+
+    def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
+                               filename: str, mime: str | None) -> Path:
+        """Compose the on-disk path for a downloaded attachment.
+
+        The name has the form ``<event-prefix>_<stem><suffix>`` inside the
+        media directory. A missing suffix is guessed from ``mime``; the stem
+        is capped at 72 characters, the suffix at 16 and the event prefix
+        at 24.
+        """
+        name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
+        extension = Path(name).suffix
+        if not extension and mime:
+            guessed = mimetypes.guess_extension(mime, strict=False)
+            if guessed:
+                name = f"{name}{guessed}"
+                extension = guessed
+        stem = Path(name).stem or attachment_type
+        # The leading "$" is stripped from the event id before sanitizing.
+        raw_event_id = str(getattr(event, "event_id", "") or "evt")
+        cleaned_id = safe_filename(raw_event_id.lstrip("$"))
+        event_prefix = (cleaned_id[:24] or "evt").strip("_")
+        return self._media_dir() / f"{event_prefix}_{stem[:72]}{extension[:16]}"
+
+ async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
+ if not self.client:
+ return None
+ response = await self.client.download(mxc=mxc_url)
+ if isinstance(response, DownloadError):
+ logger.warning("Matrix download failed for {}: {}", mxc_url, response)
+ return None
+ body = getattr(response, "body", None)
+ if isinstance(body, (bytes, bytearray)):
+ return bytes(body)
+ if isinstance(response, MemoryDownloadResponse):
+ return bytes(response.body)
+ if isinstance(body, (str, Path)):
+ path = Path(body)
+ if path.is_file():
+ try:
+ return path.read_bytes()
+ except OSError:
+ return None
+ return None
+
+ def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
+ key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
+ key = key_obj.get("k") if isinstance(key_obj, dict) else None
+ sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
+ if not all(isinstance(v, str) for v in (key, sha256, iv)):
+ return None
+ try:
+ return decrypt_attachment(ciphertext, key, sha256, iv)
+ except (EncryptionError, ValueError, TypeError):
+ logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
+ return None
+
+ async def _fetch_media_attachment(
+ self, room: MatrixRoom, event: MatrixMediaEvent,
+ ) -> tuple[dict[str, Any] | None, str]:
+ """Download, decrypt if needed, and persist a Matrix attachment."""
+ atype = self._event_attachment_type(event)
+ mime = self._event_mime(event)
+ filename = self._event_filename(event, atype)
+ mxc_url = getattr(event, "url", None)
+ fail = _ATTACH_FAILED.format(filename)
+
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return None, fail
+
+ limit_bytes = await self._effective_media_limit_bytes()
+ declared = self._event_declared_size_bytes(event)
+ if declared is not None and declared > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ downloaded = await self._download_media_bytes(mxc_url)
+ if downloaded is None:
+ return None, fail
+
+ encrypted = self._is_encrypted_media_event(event)
+ data = downloaded
+ if encrypted:
+ if (data := self._decrypt_media_bytes(event, downloaded)) is None:
+ return None, fail
+
+ if len(data) > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ path = self._build_attachment_path(event, atype, filename, mime)
+ try:
+ path.write_bytes(data)
+ except OSError:
+ return None, fail
+
+ attachment = {
+ "type": atype, "mime": mime, "filename": filename,
+ "event_id": str(getattr(event, "event_id", "") or ""),
+ "encrypted": encrypted, "size_bytes": len(data),
+ "path": str(path), "mxc_url": mxc_url,
+ }
+ return attachment, _ATTACH_MARKER.format(path)
+
+ def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
+ """Build common metadata for text and media handlers."""
+ meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
+ if isinstance(eid := getattr(event, "event_id", None), str) and eid:
+ meta["event_id"] = eid
+ if thread := self._thread_metadata(event):
+ meta.update(thread)
+ return meta
+
+ async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content=event.body, metadata=self._base_metadata(room, event),
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
+
+ async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ attachment, marker = await self._fetch_media_attachment(room, event)
+ parts: list[str] = []
+ if isinstance(body := getattr(event, "body", None), str) and body.strip():
+ parts.append(body.strip())
+ parts.append(marker)
+
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ meta = self._base_metadata(room, event)
+ if attachment:
+ meta["attachments"] = [attachment]
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content="\n".join(parts),
+ media=[attachment["path"]] if attachment else [],
+ metadata=meta,
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
index 5352a30..7b171bc 100644
--- a/nanobot/channels/qq.py
+++ b/nanobot/channels/qq.py
@@ -31,7 +31,8 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
class _Bot(botpy.Client):
def __init__(self):
- super().__init__(intents=intents)
+ # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+ super().__init__(intents=intents, ext_handlers=False)
async def on_ready(self):
logger.info("QQ bot ready: {}", self.robot.name)
@@ -100,10 +101,12 @@ class QQChannel(BaseChannel):
logger.warning("QQ client not initialized")
return
try:
+ msg_id = msg.metadata.get("message_id")
await self._client.api.post_c2c_message(
openid=msg.chat_id,
msg_type=0,
content=msg.content,
+ msg_id=msg_id,
)
except Exception as e:
logger.error("Error sending QQ message: {}", e)
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
index 906593b..57bfbcb 100644
--- a/nanobot/channels/slack.py
+++ b/nanobot/channels/slack.py
@@ -229,6 +229,11 @@ class SlackChannel(BaseChannel):
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
+ _CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
+ _INLINE_CODE_RE = re.compile(r"`[^`]+`")
+ _LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
+ _LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
+ _BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
@classmethod
def _to_mrkdwn(cls, text: str) -> str:
@@ -236,7 +241,26 @@ class SlackChannel(BaseChannel):
if not text:
return ""
text = cls._TABLE_RE.sub(cls._convert_table, text)
- return slackify_markdown(text)
+ return cls._fixup_mrkdwn(slackify_markdown(text))
+
+ @classmethod
+ def _fixup_mrkdwn(cls, text: str) -> str:
+ """Fix markdown artifacts that slackify_markdown misses."""
+ code_blocks: list[str] = []
+
+ def _save_code(m: re.Match) -> str:
+ code_blocks.append(m.group(0))
+ return f"\x00CB{len(code_blocks) - 1}\x00"
+
+ text = cls._CODE_FENCE_RE.sub(_save_code, text)
+ text = cls._INLINE_CODE_RE.sub(_save_code, text)
+ text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
+ text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
+ text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&amp;", "&"), text)
+
+ for i, block in enumerate(code_blocks):
+ text = text.replace(f"\x00CB{i}\x00", block)
+ return text
@staticmethod
def _convert_table(match: re.Match) -> str:
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 6cd98e7..969d853 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -111,6 +111,7 @@ class TelegramChannel(BaseChannel):
BOT_COMMANDS = [
BotCommand("start", "Start the bot"),
BotCommand("new", "Start a new conversation"),
+ BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"),
]
@@ -126,6 +127,8 @@ class TelegramChannel(BaseChannel):
self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
+ self._media_group_buffers: dict[str, dict] = {}
+ self._media_group_tasks: dict[str, asyncio.Task] = {}
async def start(self) -> None:
"""Start the Telegram bot with long polling."""
@@ -190,6 +193,11 @@ class TelegramChannel(BaseChannel):
# Cancel all typing indicators
for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id)
+
+ for task in self._media_group_tasks.values():
+ task.cancel()
+ self._media_group_tasks.clear()
+ self._media_group_buffers.clear()
if self._app:
logger.info("Stopping Telegram bot...")
@@ -299,6 +307,7 @@ class TelegramChannel(BaseChannel):
await update.message.reply_text(
"🐈 nanobot commands:\n"
"/new — Start a new conversation\n"
+ "/stop — Stop the current task\n"
"/help — Show available commands"
)
@@ -397,6 +406,28 @@ class TelegramChannel(BaseChannel):
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
str_chat_id = str(chat_id)
+
+ # Telegram media groups: buffer briefly, forward as one aggregated turn.
+ if media_group_id := getattr(message, "media_group_id", None):
+ key = f"{str_chat_id}:{media_group_id}"
+ if key not in self._media_group_buffers:
+ self._media_group_buffers[key] = {
+ "sender_id": sender_id, "chat_id": str_chat_id,
+ "contents": [], "media": [],
+ "metadata": {
+ "message_id": message.message_id, "user_id": user.id,
+ "username": user.username, "first_name": user.first_name,
+ "is_group": message.chat.type != "private",
+ },
+ }
+ self._start_typing(str_chat_id)
+ buf = self._media_group_buffers[key]
+ if content and content != "[empty message]":
+ buf["contents"].append(content)
+ buf["media"].extend(media_paths)
+ if key not in self._media_group_tasks:
+ self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
+ return
# Start typing indicator before processing
self._start_typing(str_chat_id)
@@ -416,6 +447,21 @@ class TelegramChannel(BaseChannel):
}
)
+ async def _flush_media_group(self, key: str) -> None:
+ """Wait briefly, then forward buffered media-group as one turn."""
+ try:
+ await asyncio.sleep(0.6)
+ if not (buf := self._media_group_buffers.pop(key, None)):
+ return
+ content = "\n".join(buf["contents"]) or "[empty message]"
+ await self._handle_message(
+ sender_id=buf["sender_id"], chat_id=buf["chat_id"],
+ content=content, media=list(dict.fromkeys(buf["media"])),
+ metadata=buf["metadata"],
+ )
+ finally:
+ self._media_group_tasks.pop(key, None)
+
def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index f5fb521..49d2390 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -2,6 +2,7 @@
import asyncio
import json
+from collections import OrderedDict
from typing import Any
from loguru import logger
@@ -15,18 +16,19 @@ from nanobot.config.schema import WhatsAppConfig
class WhatsAppChannel(BaseChannel):
"""
WhatsApp channel that connects to a Node.js bridge.
-
+
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
Communication between Python and Node.js is via WebSocket.
"""
-
+
name = "whatsapp"
-
+
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: WhatsAppConfig = config
self._ws = None
self._connected = False
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge."""
@@ -105,26 +107,34 @@ class WhatsAppChannel(BaseChannel):
# Incoming message from WhatsApp
# Deprecated by whatsapp: old phone number style typically: @s.whatspp.net
pn = data.get("pn", "")
- # New LID sytle typically:
+ # New LID style typically:
sender = data.get("sender", "")
content = data.get("content", "")
-
+ message_id = data.get("id", "")
+
+ if message_id:
+ if message_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[message_id] = None
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
# Extract just the phone number or lid as chat_id
user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info("Sender {}", sender)
-
+
# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
-
+
await self._handle_message(
sender_id=sender_id,
chat_id=sender, # Use full LID for replies
content=content,
metadata={
- "message_id": data.get("id"),
+ "message_id": message_id,
"timestamp": data.get("timestamp"),
"is_group": data.get("isGroup", False)
}
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index 1c20b50..2e417d6 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -20,6 +20,7 @@ from prompt_toolkit.patch_stdout import patch_stdout
from nanobot import __version__, __logo__
from nanobot.config.schema import Config
+from nanobot.utils.helpers import sync_workspace_templates
app = typer.Typer(
name="nanobot",
@@ -185,8 +186,7 @@ def onboard():
workspace.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}")
- # Create default bootstrap files
- _create_workspace_templates(workspace)
+ sync_workspace_templates(workspace)
console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:")
@@ -198,36 +198,6 @@ def onboard():
-def _create_workspace_templates(workspace: Path):
- """Create default workspace template files from bundled templates."""
- from importlib.resources import files as pkg_files
-
- templates_dir = pkg_files("nanobot") / "templates"
-
- for item in templates_dir.iterdir():
- if not item.name.endswith(".md"):
- continue
- dest = workspace / item.name
- if not dest.exists():
- dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
- console.print(f" [dim]Created {item.name}[/dim]")
-
- memory_dir = workspace / "memory"
- memory_dir.mkdir(exist_ok=True)
-
- memory_template = templates_dir / "memory" / "MEMORY.md"
- memory_file = memory_dir / "MEMORY.md"
- if not memory_file.exists():
- memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
- console.print(" [dim]Created memory/MEMORY.md[/dim]")
-
- history_file = memory_dir / "HISTORY.md"
- if not history_file.exists():
- history_file.write_text("", encoding="utf-8")
- console.print(" [dim]Created memory/HISTORY.md[/dim]")
-
- (workspace / "skills").mkdir(exist_ok=True)
-
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
@@ -294,6 +264,7 @@ def gateway(
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
config = load_config()
+ sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path)
@@ -312,6 +283,7 @@ def gateway(
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window,
+ reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
cron_service=cron,
@@ -447,6 +419,7 @@ def agent(
from loguru import logger
config = load_config()
+ sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
@@ -469,6 +442,7 @@ def agent(
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window,
+ reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
cron_service=cron,
@@ -960,6 +934,7 @@ def cron_run(
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
memory_window=config.agents.defaults.memory_window,
+ reasoning_effort=config.agents.defaults.reasoning_effort,
brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
restrict_to_workspace=config.tools.restrict_to_workspace,
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index 215f38d..4f06ebe 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -1,6 +1,8 @@
"""Configuration schema using Pydantic."""
from pathlib import Path
+from typing import Literal
+
from pydantic import BaseModel, Field, ConfigDict
from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings
@@ -40,6 +42,7 @@ class FeishuConfig(Base):
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
verification_token: str = "" # Verification Token for event subscription (optional)
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
+ react_emoji: str = "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
class DingTalkConfig(Base):
@@ -61,6 +64,23 @@ class DiscordConfig(Base):
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
+class MatrixConfig(Base):
+ """Matrix (Element) channel configuration."""
+
+ enabled: bool = False
+ homeserver: str = "https://matrix.org"
+ access_token: str = ""
+ user_id: str = "" # @bot:matrix.org
+ device_id: str = ""
+ e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
+ sync_stop_grace_seconds: int = 2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
+ max_media_bytes: int = 20 * 1024 * 1024 # Max attachment size accepted for Matrix media handling (inbound + outbound).
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: Literal["open", "mention", "allowlist"] = "open"
+ group_allow_from: list[str] = Field(default_factory=list)
+ allow_room_mentions: bool = False
+
+
class EmailConfig(Base):
"""Email channel configuration (IMAP inbound + SMTP outbound)."""
@@ -164,6 +184,20 @@ class QQConfig(Base):
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
+class MatrixConfig(Base):
+ """Matrix (Element) channel configuration."""
+ enabled: bool = False
+ homeserver: str = "https://matrix.org"
+ access_token: str = ""
+ user_id: str = "" # e.g. @bot:matrix.org
+ device_id: str = ""
+ e2ee_enabled: bool = True # end-to-end encryption support
+ sync_stop_grace_seconds: int = 2 # graceful sync_forever shutdown timeout
+ max_media_bytes: int = 20 * 1024 * 1024 # inbound + outbound attachment limit
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: Literal["open", "mention", "allowlist"] = "open"
+ group_allow_from: list[str] = Field(default_factory=list)
+ allow_room_mentions: bool = False
class ChannelsConfig(Base):
"""Configuration for chat channels."""
@@ -179,6 +213,7 @@ class ChannelsConfig(Base):
email: EmailConfig = Field(default_factory=EmailConfig)
slack: SlackConfig = Field(default_factory=SlackConfig)
qq: QQConfig = Field(default_factory=QQConfig)
+ matrix: MatrixConfig = Field(default_factory=MatrixConfig)
class AgentDefaults(Base):
@@ -186,10 +221,12 @@ class AgentDefaults(Base):
workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5"
+ provider: str = "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
max_tokens: int = 8192
temperature: float = 0.1
max_tool_iterations: int = 40
memory_window: int = 100
+ reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
class AgentsConfig(Base):
@@ -260,6 +297,7 @@ class ExecToolConfig(Base):
"""Shell exec tool configuration."""
timeout: int = 60
+ path_append: str = ""
class MCPServerConfig(Base):
@@ -300,6 +338,11 @@ class Config(BaseSettings):
"""Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS
+ forced = self.agents.defaults.provider
+ if forced != "auto":
+ p = getattr(self.providers, forced, None)
+ return (p, forced) if p else (None, None)
+
model_lower = (model or self.agents.defaults.model).lower()
model_normalized = model_lower.replace("-", "_")
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index eb1599a..36e9938 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -88,6 +88,7 @@ class LLMProvider(ABC):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request.
diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py
index a578d14..56e6270 100644
--- a/nanobot/providers/custom_provider.py
+++ b/nanobot/providers/custom_provider.py
@@ -18,13 +18,16 @@ class CustomProvider(LLMProvider):
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
- model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
+ model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
+ reasoning_effort: str | None = None) -> LLMResponse:
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"messages": self._sanitize_empty_content(messages),
"max_tokens": max(1, max_tokens),
"temperature": temperature,
}
+ if reasoning_effort:
+ kwargs["reasoning_effort"] = reasoning_effort
if tools:
kwargs.update(tools=tools, tool_choice="auto")
try:
diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py
index 03a6c4d..0067ae8 100644
--- a/nanobot/providers/litellm_provider.py
+++ b/nanobot/providers/litellm_provider.py
@@ -3,6 +3,8 @@
import json
import json_repair
import os
+import secrets
+import string
from typing import Any
import litellm
@@ -12,8 +14,14 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
-# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers.
+# Standard OpenAI chat-completion message keys plus reasoning_content for
+# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
+_ALNUM = string.ascii_letters + string.digits
+
+def _short_tool_id() -> str:
+ """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
+ return "".join(secrets.choice(_ALNUM) for _ in range(9))
class LiteLLMProvider(LLMProvider):
@@ -170,6 +178,7 @@ class LiteLLMProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
@@ -216,6 +225,10 @@ class LiteLLMProvider(LLMProvider):
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
+ if reasoning_effort:
+ kwargs["reasoning_effort"] = reasoning_effort
+ kwargs["drop_params"] = True
+
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
@@ -244,7 +257,7 @@ class LiteLLMProvider(LLMProvider):
args = json_repair.loads(args)
tool_calls.append(ToolCallRequest(
- id=tc.id,
+ id=_short_tool_id(),
name=tc.function.name,
arguments=args,
))
diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py
index fa28593..9039202 100644
--- a/nanobot/providers/openai_codex_provider.py
+++ b/nanobot/providers/openai_codex_provider.py
@@ -31,6 +31,7 @@ class OpenAICodexProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
) -> LLMResponse:
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index 2766929..df915b7 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -201,7 +201,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
- keywords=("openai-codex", "codex"),
+ keywords=("openai-codex",),
env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
litellm_prefix="", # Not routed through LiteLLM
diff --git a/nanobot/skills/memory/SKILL.md b/nanobot/skills/memory/SKILL.md
index 39adbde..529a02d 100644
--- a/nanobot/skills/memory/SKILL.md
+++ b/nanobot/skills/memory/SKILL.md
@@ -9,7 +9,7 @@ always: true
## Structure
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
-- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep.
+- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep. Each entry starts with [YYYY-MM-DD HH:MM].
## Search Past Events
diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md
index 84ba657..4c3e5b1 100644
--- a/nanobot/templates/AGENTS.md
+++ b/nanobot/templates/AGENTS.md
@@ -2,14 +2,6 @@
You are a helpful AI assistant. Be concise, accurate, and friendly.
-## Guidelines
-
-- Before calling tools, briefly state your intent — but NEVER predict results before receiving them
-- Use precise tense: "I will run X" before the call, "X returned Y" after
-- NEVER claim success before a tool result confirms it
-- Ask for clarification when the request is ambiguous
-- Remember important information in `memory/MEMORY.md`; past events are logged in `memory/HISTORY.md`
-
## Scheduled Reminders
When user asks for a reminder at a specific time, use `exec` to run:
diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py
index 62f80ac..8322bc8 100644
--- a/nanobot/utils/helpers.py
+++ b/nanobot/utils/helpers.py
@@ -1,80 +1,67 @@
"""Utility functions for nanobot."""
+import re
from pathlib import Path
from datetime import datetime
def ensure_dir(path: Path) -> Path:
- """Ensure a directory exists, creating it if necessary."""
+ """Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True)
return path
def get_data_path() -> Path:
- """Get the nanobot data directory (~/.nanobot)."""
+ """~/.nanobot data directory."""
return ensure_dir(Path.home() / ".nanobot")
def get_workspace_path(workspace: str | None = None) -> Path:
- """
- Get the workspace path.
-
- Args:
- workspace: Optional workspace path. Defaults to ~/.nanobot/workspace.
-
- Returns:
- Expanded and ensured workspace path.
- """
- if workspace:
- path = Path(workspace).expanduser()
- else:
- path = Path.home() / ".nanobot" / "workspace"
+ """Resolve and ensure workspace path. Defaults to ~/.nanobot/workspace."""
+ path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
return ensure_dir(path)
-def get_sessions_path() -> Path:
- """Get the sessions storage directory."""
- return ensure_dir(get_data_path() / "sessions")
-
-
-def get_skills_path(workspace: Path | None = None) -> Path:
- """Get the skills directory within the workspace."""
- ws = workspace or get_workspace_path()
- return ensure_dir(ws / "skills")
-
-
def timestamp() -> str:
- """Get current timestamp in ISO format."""
+ """Current ISO timestamp."""
return datetime.now().isoformat()
-def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str:
- """Truncate a string to max length, adding suffix if truncated."""
- if len(s) <= max_len:
- return s
- return s[: max_len - len(suffix)] + suffix
-
+_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
def safe_filename(name: str) -> str:
- """Convert a string to a safe filename."""
- # Replace unsafe characters
- unsafe = '<>:"/\\|?*'
- for char in unsafe:
- name = name.replace(char, "_")
- return name.strip()
+ """Replace unsafe path characters with underscores."""
+ return _UNSAFE_CHARS.sub("_", name).strip()
-def parse_session_key(key: str) -> tuple[str, str]:
- """
- Parse a session key into channel and chat_id.
-
- Args:
- key: Session key in format "channel:chat_id"
-
- Returns:
- Tuple of (channel, chat_id)
- """
- parts = key.split(":", 1)
- if len(parts) != 2:
- raise ValueError(f"Invalid session key: {key}")
- return parts[0], parts[1]
+def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
+ """Sync bundled templates to workspace. Only creates missing files."""
+ from importlib.resources import files as pkg_files
+ try:
+ tpl = pkg_files("nanobot") / "templates"
+ except Exception:
+ return []
+ if not tpl.is_dir():
+ return []
+
+ added: list[str] = []
+
+ def _write(src, dest: Path):
+ if dest.exists():
+ return
+ dest.parent.mkdir(parents=True, exist_ok=True)
+ dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
+ added.append(str(dest.relative_to(workspace)))
+
+ for item in tpl.iterdir():
+ if item.name.endswith(".md"):
+ _write(item, workspace / item.name)
+ _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
+ _write(None, workspace / "memory" / "HISTORY.md")
+ (workspace / "skills").mkdir(exist_ok=True)
+
+ if added and not silent:
+ from rich.console import Console
+ for name in added:
+ Console().print(f" [dim]Created {name}[/dim]")
+ return added
diff --git a/pyproject.toml b/pyproject.toml
index cb58ec5..20dcb1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "nanobot-ai"
-version = "0.1.4.post1"
+version = "0.1.4.post2"
description = "A lightweight personal AI assistant framework"
requires-python = ">=3.11"
license = {text = "MIT"}
@@ -45,6 +45,11 @@ dependencies = [
]
[project.optional-dependencies]
+matrix = [
+ "matrix-nio[e2e]>=0.25.2",
+ "mistune>=3.0.0,<4.0.0",
+ "nh3>=0.2.17,<1.0.0",
+]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py
index 323519e..a3213dd 100644
--- a/tests/test_consolidate_offset.py
+++ b/tests/test_consolidate_offset.py
@@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard:
)
@pytest.mark.asyncio
- async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
- self, tmp_path: Path
- ) -> None:
- """/new should remove lock entry for fully invalidated session key."""
+ async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
+ """/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
@@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard:
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
-
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
@@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard:
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
- # Ensure lock exists before /new.
- _ = loop._get_consolidation_lock(session.key)
- assert session.key in loop._consolidation_locks
-
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
return True
@@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard:
assert response is not None
assert "new session started" in response.content.lower()
- assert session.key not in loop._consolidation_locks
+ assert loop.sessions.get_or_create("cli:test").messages == []
diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py
index 8e2333c..9afcc7d 100644
--- a/tests/test_context_prompt_cache.py
+++ b/tests/test_context_prompt_cache.py
@@ -39,8 +39,8 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
assert prompt1 == prompt2
-def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
- """Dynamic runtime details should be added at the tail user message, not system."""
+def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
+ """Runtime metadata should be a separate user message before the actual user message."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
@@ -54,10 +54,13 @@ def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"]
+ assert messages[-2]["role"] == "user"
+ runtime_content = messages[-2]["content"]
+ assert isinstance(runtime_content, str)
+ assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
+ assert "Current Time:" in runtime_content
+ assert "Channel: cli" in runtime_content
+ assert "Chat ID: direct" in runtime_content
+
assert messages[-1]["role"] == "user"
- user_content = messages[-1]["content"]
- assert isinstance(user_content, str)
- assert "Return exactly: OK" in user_content
- assert "Current Time:" in user_content
- assert "Channel: cli" in user_content
- assert "Chat ID: direct" in user_content
+ assert messages[-1]["content"] == "Return exactly: OK"
diff --git a/tests/test_heartbeat_service.py b/tests/test_heartbeat_service.py
index ec91c6b..c5478af 100644
--- a/tests/test_heartbeat_service.py
+++ b/tests/test_heartbeat_service.py
@@ -2,34 +2,28 @@ import asyncio
import pytest
-from nanobot.heartbeat.service import (
- HEARTBEAT_OK_TOKEN,
- HeartbeatService,
-)
+from nanobot.heartbeat.service import HeartbeatService
+from nanobot.providers.base import LLMResponse, ToolCallRequest
-def test_heartbeat_ok_detection() -> None:
- def is_ok(response: str) -> bool:
- return HEARTBEAT_OK_TOKEN in response.upper()
+class DummyProvider:
+ def __init__(self, responses: list[LLMResponse]):
+ self._responses = list(responses)
- assert is_ok("HEARTBEAT_OK")
- assert is_ok("`HEARTBEAT_OK`")
- assert is_ok("**HEARTBEAT_OK**")
- assert is_ok("heartbeat_ok")
- assert is_ok("HEARTBEAT_OK.")
-
- assert not is_ok("HEARTBEAT_NOT_OK")
- assert not is_ok("all good")
+ async def chat(self, *args, **kwargs) -> LLMResponse:
+ if self._responses:
+ return self._responses.pop(0)
+ return LLMResponse(content="", tool_calls=[])
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
- async def _on_heartbeat(_: str) -> str:
- return "HEARTBEAT_OK"
+ provider = DummyProvider([])
service = HeartbeatService(
workspace=tmp_path,
- on_heartbeat=_on_heartbeat,
+ provider=provider,
+ model="openai/gpt-4o-mini",
interval_s=9999,
enabled=True,
)
@@ -42,3 +36,82 @@ async def test_start_is_idempotent(tmp_path) -> None:
service.stop()
await asyncio.sleep(0)
+
+
+@pytest.mark.asyncio
+async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
+ provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ )
+
+ action, tasks = await service._decide("heartbeat content")
+ assert action == "skip"
+ assert tasks == ""
+
+
+@pytest.mark.asyncio
+async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "run", "tasks": "check open tasks"},
+ )
+ ],
+ )
+ ])
+
+ called_with: list[str] = []
+
+ async def _on_execute(tasks: str) -> str:
+ called_with.append(tasks)
+ return "done"
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ )
+
+ result = await service.trigger_now()
+ assert result == "done"
+ assert called_with == ["check open tasks"]
+
+
+@pytest.mark.asyncio
+async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "skip"},
+ )
+ ],
+ )
+ ])
+
+ async def _on_execute(tasks: str) -> str:
+ return tasks
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ )
+
+ assert await service.trigger_now() is None
diff --git a/tests/test_matrix_channel.py b/tests/test_matrix_channel.py
new file mode 100644
index 0000000..c6714c2
--- /dev/null
+++ b/tests/test_matrix_channel.py
@@ -0,0 +1,1302 @@
+import asyncio
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+import nanobot.channels.matrix as matrix_module
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.matrix import (
+ MATRIX_HTML_FORMAT,
+ TYPING_NOTICE_TIMEOUT_MS,
+ MatrixChannel,
+)
+from nanobot.config.schema import MatrixConfig
+
+_ROOM_SEND_UNSET = object()
+
+
+class _DummyTask:
+ def __init__(self) -> None:
+ self.cancelled = False
+
+ def cancel(self) -> None:
+ self.cancelled = True
+
+ def __await__(self):
+ async def _done():
+ return None
+
+ return _done().__await__()
+
+
+class _FakeAsyncClient:
+ def __init__(self, homeserver, user, store_path, config) -> None:
+ self.homeserver = homeserver
+ self.user = user
+ self.store_path = store_path
+ self.config = config
+ self.user_id: str | None = None
+ self.access_token: str | None = None
+ self.device_id: str | None = None
+ self.load_store_called = False
+ self.stop_sync_forever_called = False
+ self.join_calls: list[str] = []
+ self.callbacks: list[tuple[object, object]] = []
+ self.response_callbacks: list[tuple[object, object]] = []
+ self.rooms: dict[str, object] = {}
+ self.room_send_calls: list[dict[str, object]] = []
+ self.typing_calls: list[tuple[str, bool, int]] = []
+ self.download_calls: list[dict[str, object]] = []
+ self.upload_calls: list[dict[str, object]] = []
+ self.download_response: object | None = None
+ self.download_bytes: bytes = b"media"
+ self.download_content_type: str = "application/octet-stream"
+ self.download_filename: str | None = None
+ self.upload_response: object | None = None
+ self.content_repository_config_response: object = SimpleNamespace(upload_size=None)
+ self.raise_on_send = False
+ self.raise_on_typing = False
+ self.raise_on_upload = False
+
+ def add_event_callback(self, callback, event_type) -> None:
+ self.callbacks.append((callback, event_type))
+
+ def add_response_callback(self, callback, response_type) -> None:
+ self.response_callbacks.append((callback, response_type))
+
+ def load_store(self) -> None:
+ self.load_store_called = True
+
+ def stop_sync_forever(self) -> None:
+ self.stop_sync_forever_called = True
+
+ async def join(self, room_id: str) -> None:
+ self.join_calls.append(room_id)
+
+ async def room_send(
+ self,
+ room_id: str,
+ message_type: str,
+ content: dict[str, object],
+ ignore_unverified_devices: object = _ROOM_SEND_UNSET,
+ ) -> None:
+ call: dict[str, object] = {
+ "room_id": room_id,
+ "message_type": message_type,
+ "content": content,
+ }
+ if ignore_unverified_devices is not _ROOM_SEND_UNSET:
+ call["ignore_unverified_devices"] = ignore_unverified_devices
+ self.room_send_calls.append(call)
+ if self.raise_on_send:
+ raise RuntimeError("send failed")
+
+ async def room_typing(
+ self,
+ room_id: str,
+ typing_state: bool = True,
+ timeout: int = 30_000,
+ ) -> None:
+ self.typing_calls.append((room_id, typing_state, timeout))
+ if self.raise_on_typing:
+ raise RuntimeError("typing failed")
+
+ async def download(self, **kwargs):
+ self.download_calls.append(kwargs)
+ if self.download_response is not None:
+ return self.download_response
+ return matrix_module.MemoryDownloadResponse(
+ body=self.download_bytes,
+ content_type=self.download_content_type,
+ filename=self.download_filename,
+ )
+
+ async def upload(
+ self,
+ data_provider,
+ content_type: str | None = None,
+ filename: str | None = None,
+ filesize: int | None = None,
+ encrypt: bool = False,
+ ):
+ if self.raise_on_upload:
+ raise RuntimeError("upload failed")
+ if isinstance(data_provider, (bytes, bytearray)):
+ raise TypeError(
+ f"data_provider type {type(data_provider)!r} is not of a usable type "
+ "(Callable, IOBase)"
+ )
+ self.upload_calls.append(
+ {
+ "data_provider": data_provider,
+ "content_type": content_type,
+ "filename": filename,
+ "filesize": filesize,
+ "encrypt": encrypt,
+ }
+ )
+ if self.upload_response is not None:
+ return self.upload_response
+ if encrypt:
+ return (
+ SimpleNamespace(content_uri="mxc://example.org/uploaded"),
+ {
+ "v": "v2",
+ "iv": "iv",
+ "hashes": {"sha256": "hash"},
+ "key": {"alg": "A256CTR", "k": "key"},
+ },
+ )
+ return SimpleNamespace(content_uri="mxc://example.org/uploaded"), None
+
+ async def content_repository_config(self):
+ return self.content_repository_config_response
+
+ async def close(self) -> None:
+ return None
+
+
+def _make_config(**kwargs) -> MatrixConfig:
+ return MatrixConfig(
+ enabled=True,
+ homeserver="https://matrix.org",
+ access_token="token",
+ user_id="@bot:matrix.org",
+ **kwargs,
+ )
+
+
+@pytest.mark.asyncio
+async def test_start_skips_load_store_when_device_id_missing(
+ monkeypatch, tmp_path
+) -> None:
+ clients: list[_FakeAsyncClient] = []
+
+ def _fake_client(*args, **kwargs) -> _FakeAsyncClient:
+ client = _FakeAsyncClient(*args, **kwargs)
+ clients.append(client)
+ return client
+
+ def _fake_create_task(coro):
+ coro.close()
+ return _DummyTask()
+
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.AsyncClientConfig",
+ lambda **kwargs: SimpleNamespace(**kwargs),
+ )
+ monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.asyncio.create_task", _fake_create_task
+ )
+
+ channel = MatrixChannel(_make_config(device_id=""), MessageBus())
+ await channel.start()
+
+ assert len(clients) == 1
+ assert clients[0].config.encryption_enabled is True
+ assert clients[0].load_store_called is False
+ assert len(clients[0].callbacks) == 3
+ assert len(clients[0].response_callbacks) == 3
+
+ await channel.stop()
+
+
+@pytest.mark.asyncio
+async def test_register_event_callbacks_uses_media_base_filter() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ channel._register_event_callbacks()
+
+ assert len(client.callbacks) == 3
+ assert client.callbacks[1][0] == channel._on_media_message
+ assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
+
+
+def test_media_event_filter_does_not_match_text_events() -> None:
+ assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
+
+
+@pytest.mark.asyncio
+async def test_start_disables_e2ee_when_configured(
+ monkeypatch, tmp_path
+) -> None:
+ clients: list[_FakeAsyncClient] = []
+
+ def _fake_client(*args, **kwargs) -> _FakeAsyncClient:
+ client = _FakeAsyncClient(*args, **kwargs)
+ clients.append(client)
+ return client
+
+ def _fake_create_task(coro):
+ coro.close()
+ return _DummyTask()
+
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.AsyncClientConfig",
+ lambda **kwargs: SimpleNamespace(**kwargs),
+ )
+ monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.asyncio.create_task", _fake_create_task
+ )
+
+ channel = MatrixChannel(_make_config(device_id="", e2ee_enabled=False), MessageBus())
+ await channel.start()
+
+ assert len(clients) == 1
+ assert clients[0].config.encryption_enabled is False
+
+ await channel.stop()
+
+
+@pytest.mark.asyncio
+async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(device_id="DEVICE"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ task = _DummyTask()
+
+ channel.client = client
+ channel._sync_task = task
+ channel._running = True
+
+ await channel.stop()
+
+ assert channel._running is False
+ assert client.stop_sync_forever_called is True
+ assert task.cancelled is False
+
+
+@pytest.mark.asyncio
+async def test_room_invite_joins_when_allow_list_is_empty() -> None:
+ channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == ["!room:matrix.org"]
+
+
+@pytest.mark.asyncio
+async def test_room_invite_respects_allow_list_when_configured() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_sets_typing_for_allowed_sender() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [
+ ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_typing_keepalive_refreshes_periodically(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ monkeypatch.setattr(matrix_module, "TYPING_KEEPALIVE_INTERVAL_MS", 10)
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ await asyncio.sleep(0.03)
+ await channel._stop_typing_keepalive("!room:matrix.org", clear_typing=True)
+
+ true_updates = [call for call in client.typing_calls if call[1] is True]
+ assert len(true_updates) >= 2
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_on_message_skips_typing_for_self_message() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@bot:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_skips_typing_for_denied_sender() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert handled == []
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_requires_mx_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+
+ await channel._on_message(room, event)
+
+ assert handled == []
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_accepts_bot_user_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello",
+ source={"content": {"m.mentions": {"user_ids": ["@bot:matrix.org"]}}},
+ )
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_allows_direct_room_without_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!dm:matrix.org", display_name="DM", member_count=2)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!dm:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_allowlist_policy_requires_room_id() -> None:
+ channel = MatrixChannel(
+ _make_config(group_policy="allowlist", group_allow_from=["!allowed:matrix.org"]),
+ MessageBus(),
+ )
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["chat_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ denied_room = SimpleNamespace(room_id="!denied:matrix.org", display_name="Denied", member_count=3)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+ await channel._on_message(denied_room, event)
+
+ allowed_room = SimpleNamespace(
+ room_id="!allowed:matrix.org",
+ display_name="Allowed",
+ member_count=3,
+ )
+ await channel._on_message(allowed_room, event)
+
+ assert handled == ["!allowed:matrix.org"]
+ assert client.typing_calls == [("!allowed:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_room_mention_requires_opt_in() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ room_mention_event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello everyone",
+ source={"content": {"m.mentions": {"room": True}}},
+ )
+
+ await channel._on_message(room, room_mention_event)
+ assert handled == []
+ assert client.typing_calls == []
+
+ channel.config.allow_room_mentions = True
+ await channel._on_message(room, room_mention_event)
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_sets_thread_metadata_when_threaded_event() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello",
+ event_id="$reply1",
+ source={
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ }
+ }
+ },
+ )
+
+ await channel._on_message(room, event)
+
+ assert len(handled) == 1
+ metadata = handled[0]["metadata"]
+ assert metadata["thread_root_event_id"] == "$root1"
+ assert metadata["thread_reply_to_event_id"] == "$reply1"
+ assert metadata["event_id"] == "$reply1"
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_downloads_attachment_and_sets_metadata(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"image"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event1",
+ source={
+ "content": {
+ "msgtype": "m.image",
+ "info": {"mimetype": "image/png", "size": 5},
+ }
+ },
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(client.download_calls) == 1
+ assert len(handled) == 1
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+ media_paths = handled[0]["media"]
+ assert isinstance(media_paths, list) and len(media_paths) == 1
+ media_path = Path(media_paths[0])
+ assert media_path.is_file()
+ assert media_path.read_bytes() == b"image"
+
+ metadata = handled[0]["metadata"]
+ attachments = metadata["attachments"]
+ assert isinstance(attachments, list) and len(attachments) == 1
+ assert attachments[0]["type"] == "image"
+ assert attachments[0]["mxc_url"] == "mxc://example.org/mediaid"
+ assert attachments[0]["path"] == str(media_path)
+ assert "[attachment: " in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_sets_thread_metadata_when_threaded_event(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"image"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event1",
+ source={
+ "content": {
+ "msgtype": "m.image",
+ "info": {"mimetype": "image/png", "size": 5},
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ },
+ }
+ },
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ metadata = handled[0]["metadata"]
+ assert metadata["thread_root_event_id"] == "$root1"
+ assert metadata["thread_reply_to_event_id"] == "$event1"
+ assert metadata["event_id"] == "$event1"
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_respects_declared_size_limit(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(max_media_bytes=3), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="large.bin",
+ url="mxc://example.org/large",
+ event_id="$event2",
+ source={"content": {"msgtype": "m.file", "info": {"size": 10}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert client.download_calls == []
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: large.bin - too large]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_uses_server_limit_when_smaller_than_local_limit(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.content_repository_config_response = SimpleNamespace(upload_size=3)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="large.bin",
+ url="mxc://example.org/large",
+ event_id="$event2_server",
+ source={"content": {"msgtype": "m.file", "info": {"size": 5}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert client.download_calls == []
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: large.bin - too large]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_response = matrix_module.DownloadError("download failed")
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event3",
+ source={"content": {"msgtype": "m.image"}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(client.download_calls) == 1
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: photo.png - download failed]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ matrix_module,
+ "decrypt_attachment",
+ lambda ciphertext, key, sha256, iv: b"plain",
+ )
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"cipher"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="secret.txt",
+ url="mxc://example.org/encrypted",
+ event_id="$event4",
+ key={"k": "key"},
+ hashes={"sha256": "hash"},
+ iv="iv",
+ source={"content": {"msgtype": "m.file", "info": {"size": 6}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ media_path = Path(handled[0]["media"][0])
+ assert media_path.read_bytes() == b"plain"
+ attachment = handled[0]["metadata"]["attachments"][0]
+ assert attachment["encrypted"] is True
+ assert attachment["size_bytes"] == 5
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ def _raise(*args, **kwargs):
+ raise matrix_module.EncryptionError("boom")
+
+ monkeypatch.setattr(matrix_module, "decrypt_attachment", _raise)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"cipher"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="secret.txt",
+ url="mxc://example.org/encrypted",
+ event_id="$event5",
+ key={"k": "key"},
+ hashes={"sha256": "hash"},
+ iv="iv",
+ source={"content": {"msgtype": "m.file"}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: secret.txt - download failed]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_send_clears_typing_after_send() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"] == {
+ "msgtype": "m.text",
+ "body": "Hi",
+ "m.mentions": {},
+ }
+ assert client.room_send_calls[0]["ignore_unverified_devices"] is True
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_uploads_media_and_sends_file_event(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ file_path = tmp_path / "test.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Please review.",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 1
+ assert not isinstance(client.upload_calls[0]["data_provider"], (bytes, bytearray))
+ assert hasattr(client.upload_calls[0]["data_provider"], "read")
+ assert client.upload_calls[0]["filename"] == "test.txt"
+ assert client.upload_calls[0]["filesize"] == 5
+ assert len(client.room_send_calls) == 2
+ assert client.room_send_calls[0]["content"]["msgtype"] == "m.file"
+ assert client.room_send_calls[0]["content"]["url"] == "mxc://example.org/uploaded"
+ assert client.room_send_calls[1]["content"]["body"] == "Please review."
+
+
+@pytest.mark.asyncio
+async def test_send_adds_thread_relates_to_for_thread_metadata() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ metadata = {
+ "thread_root_event_id": "$root1",
+ "thread_reply_to_event_id": "$reply1",
+ }
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Hi",
+ metadata=metadata,
+ )
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["m.relates_to"] == {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ "m.in_reply_to": {"event_id": "$reply1"},
+ "is_falling_back": True,
+ }
+
+
+@pytest.mark.asyncio
+async def test_send_uses_encrypted_media_payload_in_encrypted_room(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(e2ee_enabled=True), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.rooms["!encrypted:matrix.org"] = SimpleNamespace(encrypted=True)
+ channel.client = client
+
+ file_path = tmp_path / "secret.txt"
+ file_path.write_text("topsecret", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!encrypted:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 1
+ assert client.upload_calls[0]["encrypt"] is True
+ assert len(client.room_send_calls) == 1
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.file"
+ assert "file" in content
+ assert "url" not in content
+ assert content["file"]["url"] == "mxc://example.org/uploaded"
+ assert content["file"]["hashes"]["sha256"] == "hash"
+
+
+@pytest.mark.asyncio
+async def test_send_does_not_parse_attachment_marker_without_media(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ missing_path = tmp_path / "missing.txt"
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content=f"[attachment: {missing_path}]",
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == f"[attachment: {missing_path}]"
+
+
+@pytest.mark.asyncio
+async def test_send_passes_thread_relates_to_to_attachment_upload(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._server_upload_limit_checked = True
+ channel._server_upload_limit_bytes = None
+
+ captured: dict[str, object] = {}
+
+ async def _fake_upload_and_send_attachment(
+ *,
+ room_id: str,
+ path: Path,
+ limit_bytes: int,
+ relates_to: dict[str, object] | None = None,
+ ) -> str | None:
+ captured["relates_to"] = relates_to
+ return None
+
+ monkeypatch.setattr(channel, "_upload_and_send_attachment", _fake_upload_and_send_attachment)
+
+ metadata = {
+ "thread_root_event_id": "$root1",
+ "thread_reply_to_event_id": "$reply1",
+ }
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Hi",
+ media=["/tmp/fake.txt"],
+ metadata=metadata,
+ )
+ )
+
+ assert captured["relates_to"] == {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ "m.in_reply_to": {"event_id": "$reply1"},
+ "is_falling_back": True,
+ }
+
+
+@pytest.mark.asyncio
+async def test_send_workspace_restriction_blocks_external_attachment(tmp_path) -> None:
+ workspace = tmp_path / "workspace"
+ workspace.mkdir()
+ file_path = tmp_path / "external.txt"
+ file_path.write_text("outside", encoding="utf-8")
+
+ channel = MatrixChannel(
+ _make_config(),
+ MessageBus(),
+ restrict_to_workspace=True,
+ workspace=workspace,
+ )
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: external.txt - upload failed]"
+
+
+@pytest.mark.asyncio
+async def test_send_handles_upload_exception_and_reports_failure(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.raise_on_upload = True
+ channel.client = client
+
+ file_path = tmp_path / "broken.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Please review.",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 0
+ assert len(client.room_send_calls) == 1
+ assert (
+ client.room_send_calls[0]["content"]["body"]
+ == "Please review.\n[attachment: broken.txt - upload failed]"
+ )
+
+
+@pytest.mark.asyncio
+async def test_send_uses_server_upload_limit_when_smaller_than_local_limit(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.content_repository_config_response = SimpleNamespace(upload_size=3)
+ channel.client = client
+
+ file_path = tmp_path / "tiny.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: tiny.txt - too large]"
+
+
+@pytest.mark.asyncio
+async def test_send_blocks_all_outbound_media_when_limit_is_zero(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(max_media_bytes=0), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ file_path = tmp_path / "empty.txt"
+ file_path.write_bytes(b"")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: empty.txt - too large]"
+
+
+@pytest.mark.asyncio
+async def test_send_omits_ignore_unverified_devices_when_e2ee_disabled() -> None:
+ channel = MatrixChannel(_make_config(e2ee_enabled=False), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert len(client.room_send_calls) == 1
+ assert "ignore_unverified_devices" not in client.room_send_calls[0]
+
+
+@pytest.mark.asyncio
+async def test_send_stops_typing_keepalive_task() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ assert "!room:matrix.org" in channel._typing_tasks
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert "!room:matrix.org" not in channel._typing_tasks
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_progress_keeps_typing_keepalive_running() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ assert "!room:matrix.org" in channel._typing_tasks
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="working...",
+ metadata={"_progress": True, "_progress_kind": "reasoning"},
+ )
+ )
+
+ assert "!room:matrix.org" in channel._typing_tasks
+ assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_clears_typing_when_send_fails() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.raise_on_send = True
+ channel.client = client
+
+ with pytest.raises(RuntimeError, match="send failed"):
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_adds_formatted_body_for_markdown() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "# Headline\n\n- [x] done\n\n| A | B |\n| - | - |\n| 1 | 2 |"
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.text"
+ assert content["body"] == markdown_text
+ assert content["m.mentions"] == {}
+ assert content["format"] == MATRIX_HTML_FORMAT
+ assert "<h1>Headline</h1>" in str(content["formatted_body"])
+ assert "<table>" in str(content["formatted_body"])
+ assert "[x] done" in str(content["formatted_body"])
+
+
+@pytest.mark.asyncio
+async def test_send_adds_formatted_body_for_inline_url_superscript_subscript() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "Visit https://example.com and x^2^ plus H~2~O."
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.text"
+ assert content["body"] == markdown_text
+ assert content["m.mentions"] == {}
+ assert content["format"] == MATRIX_HTML_FORMAT
+ assert '<a href="https://example.com"' in str(
+ content["formatted_body"]
+ )
+ assert "<sup>2</sup>" in str(content["formatted_body"])
+ assert "<sub>2</sub>" in str(content["formatted_body"])
+
+
+@pytest.mark.asyncio
+async def test_send_sanitizes_disallowed_link_scheme() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "[click](javascript:alert(1))"
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ formatted_body = str(client.room_send_calls[0]["content"]["formatted_body"])
+ assert "javascript:" not in formatted_body
+ assert " None:
+ dirty_html = 'x'
+ cleaned_html = matrix_module.MATRIX_HTML_CLEANER.clean(dirty_html)
+
+ assert "