diff --git a/.gitignore b/.gitignore
index 374875a..c50cab8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,4 +20,5 @@ __pycache__/
poetry.lock
.pytest_cache/
botpy.log
+nano.*.save
diff --git a/README.md b/README.md
index f169bd7..629f59f 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@
## Key Features of nanobot:
-๐ชถ **Ultra-Lightweight**: Just ~4,000 lines of core agent code โ 99% smaller than Clawdbot.
+๐ชถ **Ultra-Lightweight**: A super lightweight implementation of OpenClaw โ 99% smaller, significantly faster.
๐ฌ **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
@@ -78,6 +78,25 @@
+## Table of Contents
+
+- [News](#-news)
+- [Key Features](#key-features-of-nanobot)
+- [Architecture](#๏ธ-architecture)
+- [Features](#-features)
+- [Install](#-install)
+- [Quick Start](#-quick-start)
+- [Chat Apps](#-chat-apps)
+- [Agent Social Network](#-agent-social-network)
+- [Configuration](#๏ธ-configuration)
+- [Multiple Instances](#-multiple-instances)
+- [CLI Reference](#-cli-reference)
+- [Docker](#-docker)
+- [Linux Service](#-linux-service)
+- [Project Structure](#-project-structure)
+- [Contribute & Roadmap](#-contribute--roadmap)
+- [Star History](#-star-history)
+
## โจ Features
@@ -208,6 +227,7 @@ Connect nanobot to your favorite chat platform.
| **Slack** | Bot token + App-Level token |
| **Email** | IMAP/SMTP credentials |
| **QQ** | App ID + App Secret |
+| **Wecom** | Bot ID + Bot Secret |
Telegram (Recommended)
@@ -482,7 +502,8 @@ Uses **WebSocket** long connection โ no public IP required.
"appSecret": "xxx",
"encryptKey": "",
"verificationToken": "",
- "allowFrom": ["ou_YOUR_OPEN_ID"]
+ "allowFrom": ["ou_YOUR_OPEN_ID"],
+ "groupPolicy": "mention"
}
}
}
@@ -490,6 +511,7 @@ Uses **WebSocket** long connection โ no public IP required.
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
+> `groupPolicy`: `"mention"` (default โ respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
**3. Run**
@@ -677,6 +699,46 @@ nanobot gateway
+
+Wecom (ไผไธๅพฎไฟก)
+
+> Here we use [wecom-aibot-sdk-python](https://github.com/chengyongru/wecom_aibot_sdk) (community Python version of the official [@wecom/aibot-node-sdk](https://www.npmjs.com/package/@wecom/aibot-node-sdk)).
+>
+> Uses **WebSocket** long connection โ no public IP required.
+
+**1. Install the optional dependency**
+
+```bash
+pip install nanobot-ai[wecom]
+```
+
+**2. Create a WeCom AI Bot**
+
+Go to the WeCom admin console โ Intelligent Robot โ Create Robot โ select **API mode** with **long connection**. Copy the Bot ID and Secret.
+
+**3. Configure**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "botId": "your_bot_id",
+ "secret": "your_bot_secret",
+ "allowFrom": ["your_id"]
+ }
+ }
+}
+```
+
+**4. Run**
+
+```bash
+nanobot gateway
+```
+
+
+
## ๐ Agent Social Network
๐ nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
@@ -696,15 +758,17 @@ Config file: `~/.nanobot/config.json`
> [!TIP]
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
+> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
-> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key |
|----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | โ |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
+| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) ยท [volcengine.com](https://www.volcengine.com) |
+| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) ยท [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
@@ -714,10 +778,10 @@ Config file: `~/.nanobot/config.json`
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/็ก
ๅบๆตๅจ) | [siliconflow.cn](https://siliconflow.cn) |
-| `volcengine` | LLM (VolcEngine/็ซๅฑฑๅผๆ) | [volcengine.com](https://www.volcengine.com) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
+| `ollama` | LLM (local, Ollama) | โ |
| `vllm` | LLM (local, any OpenAI-compatible server) | โ |
| `openai_codex` | LLM (Codex, OAuth) | `nanobot provider login openai-codex` |
| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
@@ -783,6 +847,37 @@ Connects directly to any OpenAI-compatible endpoint โ LM Studio, llama.cpp, To
+
+Ollama (local)
+
+Run a local model with Ollama, then add to config:
+
+**1. Start Ollama** (example):
+```bash
+ollama run llama3.2
+```
+
+**2. Add to config** (partial โ merge into `~/.nanobot/config.json`):
+```json
+{
+ "providers": {
+ "ollama": {
+ "apiBase": "http://localhost:11434"
+ }
+ },
+ "agents": {
+ "defaults": {
+ "provider": "ollama",
+ "model": "llama3.2"
+ }
+ }
+}
+```
+
+> `provider: "auto"` also works when `providers.ollama.apiBase` is configured, but setting `"provider": "ollama"` is the clearest option.
+
+
+
vLLM (local / OpenAI-compatible)
diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index 8605a09..5fe0ee0 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -4,7 +4,9 @@ from __future__ import annotations
import asyncio
import json
+import os
import re
+import sys
from contextlib import AsyncExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -43,7 +45,7 @@ class AgentLoop:
5. Sends responses back
"""
- _TOOL_RESULT_MAX_CHARS = 500
+ _TOOL_RESULT_MAX_CHARS = 16_000
def __init__(
self,
@@ -52,9 +54,6 @@ class AgentLoop:
workspace: Path,
model: str | None = None,
max_iterations: int = 40,
- temperature: float = 0.1,
- max_tokens: int = 4096,
- reasoning_effort: str | None = None,
context_window_tokens: int = 65_536,
brave_api_key: str | None = None,
web_proxy: str | None = None,
@@ -72,9 +71,6 @@ class AgentLoop:
self.workspace = workspace
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
- self.temperature = temperature
- self.max_tokens = max_tokens
- self.reasoning_effort = reasoning_effort
self.context_window_tokens = context_window_tokens
self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
@@ -90,9 +86,6 @@ class AgentLoop:
workspace=workspace,
bus=bus,
model=self.model,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key,
web_proxy=web_proxy,
exec_config=self.exec_config,
@@ -202,9 +195,6 @@ class AgentLoop:
messages=messages,
tools=tool_defs,
model=self.model,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- reasoning_effort=self.reasoning_effort,
)
if response.has_tool_calls:
@@ -215,14 +205,7 @@ class AgentLoop:
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [
- {
- "id": tc.id,
- "type": "function",
- "function": {
- "name": tc.name,
- "arguments": json.dumps(tc.arguments, ensure_ascii=False)
- }
- }
+ tc.to_openai_tool_call()
for tc in response.tool_calls
]
messages = self.context.add_assistant_message(
@@ -275,8 +258,11 @@ class AgentLoop:
except asyncio.TimeoutError:
continue
- if msg.content.strip().lower() == "/stop":
+ cmd = msg.content.strip().lower()
+ if cmd == "/stop":
await self._handle_stop(msg)
+ elif cmd == "/restart":
+ await self._handle_restart(msg)
else:
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task)
@@ -293,11 +279,23 @@ class AgentLoop:
pass
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
total = cancelled + sub_cancelled
- content = f"โน Stopped {total} task(s)." if total else "No active task to stop."
+ content = f"Stopped {total} task(s)." if total else "No active task to stop."
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
))
+ async def _handle_restart(self, msg: InboundMessage) -> None:
+ """Restart the process in-place via os.execv."""
+ await self.bus.publish_outbound(OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
+ ))
+
+ async def _do_restart():
+ await asyncio.sleep(1)
+ os.execv(sys.executable, [sys.executable] + sys.argv)
+
+ asyncio.create_task(_do_restart())
+
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message under the global lock."""
async with self._processing_lock:
@@ -392,9 +390,16 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started.")
if cmd == "/help":
- return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
- content="๐ nanobot commands:\n/new โ Start a new conversation\n/stop โ Stop the current task\n/help โ Show available commands")
-
+ lines = [
+ "๐ nanobot commands:",
+ "/new โ Start a new conversation",
+ "/stop โ Stop the current task",
+ "/restart โ Restart the bot",
+ "/help โ Show available commands",
+ ]
+ return OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
+ )
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py
index cd5f54f..802dd04 100644
--- a/nanobot/agent/memory.py
+++ b/nanobot/agent/memory.py
@@ -57,7 +57,6 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None
-
class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -121,6 +120,7 @@ class MemoryStore:
],
tools=_SAVE_MEMORY_TOOL,
model=model,
+ tool_choice="required",
)
if not response.has_tool_calls:
diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py
index eff0b4f..eb3b3b0 100644
--- a/nanobot/agent/subagent.py
+++ b/nanobot/agent/subagent.py
@@ -28,9 +28,6 @@ class SubagentManager:
workspace: Path,
bus: MessageBus,
model: str | None = None,
- temperature: float = 0.7,
- max_tokens: int = 4096,
- reasoning_effort: str | None = None,
brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None,
@@ -41,9 +38,6 @@ class SubagentManager:
self.workspace = workspace
self.bus = bus
self.model = model or provider.get_default_model()
- self.temperature = temperature
- self.max_tokens = max_tokens
- self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@@ -128,21 +122,11 @@ class SubagentManager:
messages=messages,
tools=tools.get_definitions(),
model=self.model,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- reasoning_effort=self.reasoning_effort,
)
if response.has_tool_calls:
tool_call_dicts = [
- {
- "id": tc.id,
- "type": "function",
- "function": {
- "name": tc.name,
- "arguments": json.dumps(tc.arguments, ensure_ascii=False),
- },
- }
+ tc.to_openai_tool_call()
for tc in response.tool_calls
]
messages.append(build_assistant_message(
@@ -231,7 +215,7 @@ Stay focused on the assigned task. Your final response will be reported back to
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
return "\n\n".join(parts)
-
+
async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents for the given session. Returns count cancelled."""
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py
index 7b0b867..02c8331 100644
--- a/nanobot/agent/tools/filesystem.py
+++ b/nanobot/agent/tools/filesystem.py
@@ -1,4 +1,4 @@
-"""File system tools: read, write, edit."""
+"""File system tools: read, write, edit, list."""
import difflib
from pathlib import Path
@@ -23,62 +23,108 @@ def _resolve_path(
return resolved
-class ReadFileTool(Tool):
- """Tool to read file contents."""
-
- _MAX_CHARS = 128_000 # ~128 KB โ prevents OOM from reading huge files into LLM context
+class _FsTool(Tool):
+ """Shared base for filesystem tools โ common init and path resolution."""
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
self._workspace = workspace
self._allowed_dir = allowed_dir
+ def _resolve(self, path: str) -> Path:
+ return _resolve_path(path, self._workspace, self._allowed_dir)
+
+
+# ---------------------------------------------------------------------------
+# read_file
+# ---------------------------------------------------------------------------
+
+class ReadFileTool(_FsTool):
+ """Read file contents with optional line-based pagination."""
+
+ _MAX_CHARS = 128_000
+ _DEFAULT_LIMIT = 2000
+
@property
def name(self) -> str:
return "read_file"
@property
def description(self) -> str:
- return "Read the contents of a file at the given path."
+ return (
+ "Read the contents of a file. Returns numbered lines. "
+ "Use offset and limit to paginate through large files."
+ )
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
- "properties": {"path": {"type": "string", "description": "The file path to read"}},
+ "properties": {
+ "path": {"type": "string", "description": "The file path to read"},
+ "offset": {
+ "type": "integer",
+ "description": "Line number to start reading from (1-indexed, default 1)",
+ "minimum": 1,
+ },
+ "limit": {
+ "type": "integer",
+ "description": "Maximum number of lines to read (default 2000)",
+ "minimum": 1,
+ },
+ },
"required": ["path"],
}
- async def execute(self, path: str, **kwargs: Any) -> str:
+ async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str:
try:
- file_path = _resolve_path(path, self._workspace, self._allowed_dir)
- if not file_path.exists():
+ fp = self._resolve(path)
+ if not fp.exists():
return f"Error: File not found: {path}"
- if not file_path.is_file():
+ if not fp.is_file():
return f"Error: Not a file: {path}"
- size = file_path.stat().st_size
- if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars โค 4 bytes)
- return (
- f"Error: File too large ({size:,} bytes). "
- f"Use exec tool with head/tail/grep to read portions."
- )
+ all_lines = fp.read_text(encoding="utf-8").splitlines()
+ total = len(all_lines)
- content = file_path.read_text(encoding="utf-8")
- if len(content) > self._MAX_CHARS:
- return content[: self._MAX_CHARS] + f"\n\n... (truncated โ file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
- return content
+ if offset < 1:
+ offset = 1
+ if total == 0:
+ return f"(Empty file: {path})"
+ if offset > total:
+ return f"Error: offset {offset} is beyond end of file ({total} lines)"
+
+ start = offset - 1
+ end = min(start + (limit or self._DEFAULT_LIMIT), total)
+ numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])]
+ result = "\n".join(numbered)
+
+ if len(result) > self._MAX_CHARS:
+ trimmed, chars = [], 0
+ for line in numbered:
+ chars += len(line) + 1
+ if chars > self._MAX_CHARS:
+ break
+ trimmed.append(line)
+ end = start + len(trimmed)
+ result = "\n".join(trimmed)
+
+ if end < total:
+ result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
+ else:
+ result += f"\n\n(End of file โ {total} lines total)"
+ return result
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
- return f"Error reading file: {str(e)}"
+ return f"Error reading file: {e}"
-class WriteFileTool(Tool):
- """Tool to write content to a file."""
+# ---------------------------------------------------------------------------
+# write_file
+# ---------------------------------------------------------------------------
- def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
- self._workspace = workspace
- self._allowed_dir = allowed_dir
+class WriteFileTool(_FsTool):
+ """Write content to a file."""
@property
def name(self) -> str:
@@ -101,22 +147,48 @@ class WriteFileTool(Tool):
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
try:
- file_path = _resolve_path(path, self._workspace, self._allowed_dir)
- file_path.parent.mkdir(parents=True, exist_ok=True)
- file_path.write_text(content, encoding="utf-8")
- return f"Successfully wrote {len(content)} bytes to {file_path}"
+ fp = self._resolve(path)
+ fp.parent.mkdir(parents=True, exist_ok=True)
+ fp.write_text(content, encoding="utf-8")
+ return f"Successfully wrote {len(content)} bytes to {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
- return f"Error writing file: {str(e)}"
+ return f"Error writing file: {e}"
-class EditFileTool(Tool):
- """Tool to edit a file by replacing text."""
+# ---------------------------------------------------------------------------
+# edit_file
+# ---------------------------------------------------------------------------
- def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
- self._workspace = workspace
- self._allowed_dir = allowed_dir
+def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
+ """Locate old_text in content: exact first, then line-trimmed sliding window.
+
+ Both inputs should use LF line endings (caller normalises CRLF).
+ Returns (matched_fragment, count) or (None, 0).
+ """
+ if old_text in content:
+ return old_text, content.count(old_text)
+
+ old_lines = old_text.splitlines()
+ if not old_lines:
+ return None, 0
+ stripped_old = [l.strip() for l in old_lines]
+ content_lines = content.splitlines()
+
+ candidates = []
+ for i in range(len(content_lines) - len(stripped_old) + 1):
+ window = content_lines[i : i + len(stripped_old)]
+ if [l.strip() for l in window] == stripped_old:
+ candidates.append("\n".join(window))
+
+ if candidates:
+ return candidates[0], len(candidates)
+ return None, 0
+
+
+class EditFileTool(_FsTool):
+ """Edit a file by replacing text with fallback matching."""
@property
def name(self) -> str:
@@ -124,7 +196,11 @@ class EditFileTool(Tool):
@property
def description(self) -> str:
- return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
+ return (
+ "Edit a file by replacing old_text with new_text. "
+ "Supports minor whitespace/line-ending differences. "
+ "Set replace_all=true to replace every occurrence."
+ )
@property
def parameters(self) -> dict[str, Any]:
@@ -132,40 +208,52 @@ class EditFileTool(Tool):
"type": "object",
"properties": {
"path": {"type": "string", "description": "The file path to edit"},
- "old_text": {"type": "string", "description": "The exact text to find and replace"},
+ "old_text": {"type": "string", "description": "The text to find and replace"},
"new_text": {"type": "string", "description": "The text to replace with"},
+ "replace_all": {
+ "type": "boolean",
+ "description": "Replace all occurrences (default false)",
+ },
},
"required": ["path", "old_text", "new_text"],
}
- async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
+ async def execute(
+ self, path: str, old_text: str, new_text: str,
+ replace_all: bool = False, **kwargs: Any,
+ ) -> str:
try:
- file_path = _resolve_path(path, self._workspace, self._allowed_dir)
- if not file_path.exists():
+ fp = self._resolve(path)
+ if not fp.exists():
return f"Error: File not found: {path}"
- content = file_path.read_text(encoding="utf-8")
+ raw = fp.read_bytes()
+ uses_crlf = b"\r\n" in raw
+ content = raw.decode("utf-8").replace("\r\n", "\n")
+ match, count = _find_match(content, old_text.replace("\r\n", "\n"))
- if old_text not in content:
- return self._not_found_message(old_text, content, path)
+ if match is None:
+ return self._not_found_msg(old_text, content, path)
+ if count > 1 and not replace_all:
+ return (
+ f"Warning: old_text appears {count} times. "
+ "Provide more context to make it unique, or set replace_all=true."
+ )
- # Count occurrences
- count = content.count(old_text)
- if count > 1:
- return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
+ norm_new = new_text.replace("\r\n", "\n")
+ new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
+ if uses_crlf:
+ new_content = new_content.replace("\n", "\r\n")
- new_content = content.replace(old_text, new_text, 1)
- file_path.write_text(new_content, encoding="utf-8")
-
- return f"Successfully edited {file_path}"
+ fp.write_bytes(new_content.encode("utf-8"))
+ return f"Successfully edited {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
- return f"Error editing file: {str(e)}"
+ return f"Error editing file: {e}"
@staticmethod
- def _not_found_message(old_text: str, content: str, path: str) -> str:
- """Build a helpful error when old_text is not found."""
+ def _not_found_msg(old_text: str, content: str, path: str) -> str:
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = len(old_lines)
@@ -177,27 +265,29 @@ class EditFileTool(Tool):
best_ratio, best_start = ratio, i
if best_ratio > 0.5:
- diff = "\n".join(
- difflib.unified_diff(
- old_lines,
- lines[best_start : best_start + window],
- fromfile="old_text (provided)",
- tofile=f"{path} (actual, line {best_start + 1})",
- lineterm="",
- )
- )
+ diff = "\n".join(difflib.unified_diff(
+ old_lines, lines[best_start : best_start + window],
+ fromfile="old_text (provided)",
+ tofile=f"{path} (actual, line {best_start + 1})",
+ lineterm="",
+ ))
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
- return (
- f"Error: old_text not found in {path}. No similar text found. Verify the file content."
- )
+ return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
-class ListDirTool(Tool):
- """Tool to list directory contents."""
+# ---------------------------------------------------------------------------
+# list_dir
+# ---------------------------------------------------------------------------
- def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
- self._workspace = workspace
- self._allowed_dir = allowed_dir
+class ListDirTool(_FsTool):
+ """List directory contents with optional recursion."""
+
+ _DEFAULT_MAX = 200
+ _IGNORE_DIRS = {
+ ".git", "node_modules", "__pycache__", ".venv", "venv",
+ "dist", "build", ".tox", ".mypy_cache", ".pytest_cache",
+ ".ruff_cache", ".coverage", "htmlcov",
+ }
@property
def name(self) -> str:
@@ -205,34 +295,71 @@ class ListDirTool(Tool):
@property
def description(self) -> str:
- return "List the contents of a directory."
+ return (
+ "List the contents of a directory. "
+ "Set recursive=true to explore nested structure. "
+ "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
+ )
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
- "properties": {"path": {"type": "string", "description": "The directory path to list"}},
+ "properties": {
+ "path": {"type": "string", "description": "The directory path to list"},
+ "recursive": {
+ "type": "boolean",
+ "description": "Recursively list all files (default false)",
+ },
+ "max_entries": {
+ "type": "integer",
+ "description": "Maximum entries to return (default 200)",
+ "minimum": 1,
+ },
+ },
"required": ["path"],
}
- async def execute(self, path: str, **kwargs: Any) -> str:
+ async def execute(
+ self, path: str, recursive: bool = False,
+ max_entries: int | None = None, **kwargs: Any,
+ ) -> str:
try:
- dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
- if not dir_path.exists():
+ dp = self._resolve(path)
+ if not dp.exists():
return f"Error: Directory not found: {path}"
- if not dir_path.is_dir():
+ if not dp.is_dir():
return f"Error: Not a directory: {path}"
- items = []
- for item in sorted(dir_path.iterdir()):
- prefix = "๐ " if item.is_dir() else "๐ "
- items.append(f"{prefix}{item.name}")
+ cap = max_entries or self._DEFAULT_MAX
+ items: list[str] = []
+ total = 0
- if not items:
+ if recursive:
+ for item in sorted(dp.rglob("*")):
+ if any(p in self._IGNORE_DIRS for p in item.parts):
+ continue
+ total += 1
+ if len(items) < cap:
+ rel = item.relative_to(dp)
+ items.append(f"{rel}/" if item.is_dir() else str(rel))
+ else:
+ for item in sorted(dp.iterdir()):
+ if item.name in self._IGNORE_DIRS:
+ continue
+ total += 1
+ if len(items) < cap:
+ pfx = "๐ " if item.is_dir() else "๐ "
+ items.append(f"{pfx}{item.name}")
+
+ if not items and total == 0:
return f"Directory {path} is empty"
- return "\n".join(items)
+ result = "\n".join(items)
+ if total > cap:
+ result += f"\n\n(truncated, showing first {cap} of {total} entries)"
+ return result
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
- return f"Error listing directory: {str(e)}"
+ return f"Error listing directory: {e}"
diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py
index ce19920..bf1b082 100644
--- a/nanobot/agent/tools/shell.py
+++ b/nanobot/agent/tools/shell.py
@@ -42,6 +42,9 @@ class ExecTool(Tool):
def name(self) -> str:
return "exec"
+ _MAX_TIMEOUT = 600
+ _MAX_OUTPUT = 10_000
+
@property
def description(self) -> str:
return "Execute a shell command and return its output. Use with caution."
@@ -53,22 +56,36 @@ class ExecTool(Tool):
"properties": {
"command": {
"type": "string",
- "description": "The shell command to execute"
+ "description": "The shell command to execute",
},
"working_dir": {
"type": "string",
- "description": "Optional working directory for the command"
- }
+ "description": "Optional working directory for the command",
+ },
+ "timeout": {
+ "type": "integer",
+ "description": (
+ "Timeout in seconds. Increase for long-running commands "
+ "like compilation or installation (default 60, max 600)."
+ ),
+ "minimum": 1,
+ "maximum": 600,
+ },
},
- "required": ["command"]
+ "required": ["command"],
}
-
- async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
+
+ async def execute(
+ self, command: str, working_dir: str | None = None,
+ timeout: int | None = None, **kwargs: Any,
+ ) -> str:
cwd = working_dir or self.working_dir or os.getcwd()
guard_error = self._guard_command(command, cwd)
if guard_error:
return guard_error
-
+
+ effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
+
env = os.environ.copy()
if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
@@ -81,44 +98,46 @@ class ExecTool(Tool):
cwd=cwd,
env=env,
)
-
+
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
- timeout=self.timeout
+ timeout=effective_timeout,
)
except asyncio.TimeoutError:
process.kill()
- # Wait for the process to fully terminate so pipes are
- # drained and file descriptors are released.
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
- return f"Error: Command timed out after {self.timeout} seconds"
-
+ return f"Error: Command timed out after {effective_timeout} seconds"
+
output_parts = []
-
+
if stdout:
output_parts.append(stdout.decode("utf-8", errors="replace"))
-
+
if stderr:
stderr_text = stderr.decode("utf-8", errors="replace")
if stderr_text.strip():
output_parts.append(f"STDERR:\n{stderr_text}")
-
- if process.returncode != 0:
- output_parts.append(f"\nExit code: {process.returncode}")
-
+
+ output_parts.append(f"\nExit code: {process.returncode}")
+
result = "\n".join(output_parts) if output_parts else "(no output)"
-
- # Truncate very long output
- max_len = 10000
+
+ # Head + tail truncation to preserve both start and end of output
+ max_len = self._MAX_OUTPUT
if len(result) > max_len:
- result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)"
-
+ half = max_len // 2
+ result = (
+ result[:half]
+ + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ + result[-half:]
+ )
+
return result
-
+
except Exception as e:
return f"Error executing command: {str(e)}"
@@ -143,7 +162,8 @@ class ExecTool(Tool):
for raw in self._extract_absolute_paths(cmd):
try:
- p = Path(raw.strip()).resolve()
+ expanded = os.path.expandvars(raw.strip())
+ p = Path(expanded).expanduser().resolve()
except Exception:
continue
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
@@ -154,5 +174,6 @@ class ExecTool(Tool):
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
- posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
- return win_paths + posix_paths
+ posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
+ home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
+ return win_paths + posix_paths + home_paths
diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py
index dc53ba4..74c540a 100644
--- a/nanobot/channels/base.py
+++ b/nanobot/channels/base.py
@@ -1,6 +1,9 @@
"""Base channel interface for chat platforms."""
+from __future__ import annotations
+
from abc import ABC, abstractmethod
+from pathlib import Path
from typing import Any
from loguru import logger
@@ -18,6 +21,8 @@ class BaseChannel(ABC):
"""
name: str = "base"
+ display_name: str = "Base"
+ transcription_api_key: str = ""
def __init__(self, config: Any, bus: MessageBus):
"""
@@ -31,6 +36,19 @@ class BaseChannel(ABC):
self.bus = bus
self._running = False
+ async def transcribe_audio(self, file_path: str | Path) -> str:
+ """Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
+ if not self.transcription_api_key:
+ return ""
+ try:
+ from nanobot.providers.transcription import GroqTranscriptionProvider
+
+ provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
+ return await provider.transcribe(file_path)
+ except Exception as e:
+ logger.warning("{}: audio transcription failed: {}", self.name, e)
+ return ""
+
@abstractmethod
async def start(self) -> None:
"""
diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py
index cdcba57..4626d95 100644
--- a/nanobot/channels/dingtalk.py
+++ b/nanobot/channels/dingtalk.py
@@ -114,6 +114,7 @@ class DingTalkChannel(BaseChannel):
"""
name = "dingtalk"
+ display_name = "DingTalk"
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py
index 2ee4f77..afa20c9 100644
--- a/nanobot/channels/discord.py
+++ b/nanobot/channels/discord.py
@@ -25,6 +25,7 @@ class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket."""
name = "discord"
+ display_name = "Discord"
def __init__(self, config: DiscordConfig, bus: MessageBus):
super().__init__(config, bus)
diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py
index 16771fb..46c2103 100644
--- a/nanobot/channels/email.py
+++ b/nanobot/channels/email.py
@@ -35,6 +35,7 @@ class EmailChannel(BaseChannel):
"""
name = "email"
+ display_name = "Email"
_IMAP_MONTHS = (
"Jan",
"Feb",
diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py
index 0409c32..2eb6a6a 100644
--- a/nanobot/channels/feishu.py
+++ b/nanobot/channels/feishu.py
@@ -244,11 +244,11 @@ class FeishuChannel(BaseChannel):
"""
name = "feishu"
+ display_name = "Feishu"
- def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
+ def __init__(self, config: FeishuConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: FeishuConfig = config
- self.groq_api_key = groq_api_key
self._client: Any = None
self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None
@@ -352,6 +352,27 @@ class FeishuChannel(BaseChannel):
self._running = False
logger.info("Feishu bot stopped")
+ def _is_bot_mentioned(self, message: Any) -> bool:
+ """Check if the bot is @mentioned in the message."""
+ raw_content = message.content or ""
+ if "@_all" in raw_content:
+ return True
+
+ for mention in getattr(message, "mentions", None) or []:
+ mid = getattr(mention, "id", None)
+ if not mid:
+ continue
+ # Bot mentions have no user_id (None or "") but a valid open_id
+ if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
+ return True
+ return False
+
+ def _is_group_message_for_bot(self, message: Any) -> bool:
+ """Allow group messages when policy is open or bot is @mentioned."""
+ if self.config.group_policy == "open":
+ return True
+ return self._is_bot_mentioned(message)
+
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
@@ -893,6 +914,10 @@ class FeishuChannel(BaseChannel):
chat_type = message.chat_type
msg_type = message.message_type
+ if chat_type == "group" and not self._is_group_message_for_bot(message):
+ logger.debug("Feishu: skipping group message (not mentioned)")
+ return
+
# Add reaction
await self._add_reaction(message_id, self.config.react_emoji)
@@ -928,16 +953,10 @@ class FeishuChannel(BaseChannel):
if file_path:
media_paths.append(file_path)
- # Transcribe audio using Groq Whisper
- if msg_type == "audio" and file_path and self.groq_api_key:
- try:
- from nanobot.providers.transcription import GroqTranscriptionProvider
- transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
- transcription = await transcriber.transcribe(file_path)
- if transcription:
- content_text = f"[transcription: {transcription}]"
- except Exception as e:
- logger.warning("Failed to transcribe audio: {}", e)
+ if msg_type == "audio" and file_path:
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ content_text = f"[transcription: {transcription}]"
content_parts.append(content_text)
diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py
index 51539dd..8288ad0 100644
--- a/nanobot/channels/manager.py
+++ b/nanobot/channels/manager.py
@@ -7,7 +7,6 @@ from typing import Any
from loguru import logger
-from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config
@@ -32,123 +31,23 @@ class ChannelManager:
self._init_channels()
def _init_channels(self) -> None:
- """Initialize channels based on config."""
+ """Initialize channels discovered via pkgutil scan."""
+ from nanobot.channels.registry import discover_channel_names, load_channel_class
- # Telegram channel
- if self.config.channels.telegram.enabled:
+ groq_key = self.config.providers.groq.api_key
+
+ for modname in discover_channel_names():
+ section = getattr(self.config.channels, modname, None)
+ if not section or not getattr(section, "enabled", False):
+ continue
try:
- from nanobot.channels.telegram import TelegramChannel
- self.channels["telegram"] = TelegramChannel(
- self.config.channels.telegram,
- self.bus,
- groq_api_key=self.config.providers.groq.api_key,
- )
- logger.info("Telegram channel enabled")
+ cls = load_channel_class(modname)
+ channel = cls(section, self.bus)
+ channel.transcription_api_key = groq_key
+ self.channels[modname] = channel
+ logger.info("{} channel enabled", cls.display_name)
except ImportError as e:
- logger.warning("Telegram channel not available: {}", e)
-
- # WhatsApp channel
- if self.config.channels.whatsapp.enabled:
- try:
- from nanobot.channels.whatsapp import WhatsAppChannel
- self.channels["whatsapp"] = WhatsAppChannel(
- self.config.channels.whatsapp, self.bus
- )
- logger.info("WhatsApp channel enabled")
- except ImportError as e:
- logger.warning("WhatsApp channel not available: {}", e)
-
- # Discord channel
- if self.config.channels.discord.enabled:
- try:
- from nanobot.channels.discord import DiscordChannel
- self.channels["discord"] = DiscordChannel(
- self.config.channels.discord, self.bus
- )
- logger.info("Discord channel enabled")
- except ImportError as e:
- logger.warning("Discord channel not available: {}", e)
-
- # Feishu channel
- if self.config.channels.feishu.enabled:
- try:
- from nanobot.channels.feishu import FeishuChannel
- self.channels["feishu"] = FeishuChannel(
- self.config.channels.feishu, self.bus,
- groq_api_key=self.config.providers.groq.api_key,
- )
- logger.info("Feishu channel enabled")
- except ImportError as e:
- logger.warning("Feishu channel not available: {}", e)
-
- # Mochat channel
- if self.config.channels.mochat.enabled:
- try:
- from nanobot.channels.mochat import MochatChannel
-
- self.channels["mochat"] = MochatChannel(
- self.config.channels.mochat, self.bus
- )
- logger.info("Mochat channel enabled")
- except ImportError as e:
- logger.warning("Mochat channel not available: {}", e)
-
- # DingTalk channel
- if self.config.channels.dingtalk.enabled:
- try:
- from nanobot.channels.dingtalk import DingTalkChannel
- self.channels["dingtalk"] = DingTalkChannel(
- self.config.channels.dingtalk, self.bus
- )
- logger.info("DingTalk channel enabled")
- except ImportError as e:
- logger.warning("DingTalk channel not available: {}", e)
-
- # Email channel
- if self.config.channels.email.enabled:
- try:
- from nanobot.channels.email import EmailChannel
- self.channels["email"] = EmailChannel(
- self.config.channels.email, self.bus
- )
- logger.info("Email channel enabled")
- except ImportError as e:
- logger.warning("Email channel not available: {}", e)
-
- # Slack channel
- if self.config.channels.slack.enabled:
- try:
- from nanobot.channels.slack import SlackChannel
- self.channels["slack"] = SlackChannel(
- self.config.channels.slack, self.bus
- )
- logger.info("Slack channel enabled")
- except ImportError as e:
- logger.warning("Slack channel not available: {}", e)
-
- # QQ channel
- if self.config.channels.qq.enabled:
- try:
- from nanobot.channels.qq import QQChannel
- self.channels["qq"] = QQChannel(
- self.config.channels.qq,
- self.bus,
- )
- logger.info("QQ channel enabled")
- except ImportError as e:
- logger.warning("QQ channel not available: {}", e)
-
- # Matrix channel
- if self.config.channels.matrix.enabled:
- try:
- from nanobot.channels.matrix import MatrixChannel
- self.channels["matrix"] = MatrixChannel(
- self.config.channels.matrix,
- self.bus,
- )
- logger.info("Matrix channel enabled")
- except ImportError as e:
- logger.warning("Matrix channel not available: {}", e)
+ logger.warning("{} channel not available: {}", modname, e)
self._validate_allow_from()
diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py
index 63cb0ca..0d7a908 100644
--- a/nanobot/channels/matrix.py
+++ b/nanobot/channels/matrix.py
@@ -37,6 +37,7 @@ except ImportError as e:
) from e
from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_data_dir, get_media_dir
from nanobot.utils.helpers import safe_filename
@@ -146,15 +147,15 @@ class MatrixChannel(BaseChannel):
"""Matrix (Element) channel using long-polling sync."""
name = "matrix"
+ display_name = "Matrix"
- def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
- workspace: Path | None = None):
+ def __init__(self, config: Any, bus: MessageBus):
super().__init__(config, bus)
self.client: AsyncClient | None = None
self._sync_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
- self._restrict_to_workspace = restrict_to_workspace
- self._workspace = workspace.expanduser().resolve() if workspace else None
+ self._restrict_to_workspace = False
+ self._workspace: Path | None = None
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False
@@ -677,7 +678,14 @@ class MatrixChannel(BaseChannel):
parts: list[str] = []
if isinstance(body := getattr(event, "body", None), str) and body.strip():
parts.append(body.strip())
- if marker:
+
+ if attachment and attachment.get("type") == "audio":
+ transcription = await self.transcribe_audio(attachment["path"])
+ if transcription:
+ parts.append(f"[transcription: {transcription}]")
+ else:
+ parts.append(marker)
+ elif marker:
parts.append(marker)
await self._start_typing_keepalive(room.room_id)
diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py
index 09e31c3..52e246f 100644
--- a/nanobot/channels/mochat.py
+++ b/nanobot/channels/mochat.py
@@ -216,6 +216,7 @@ class MochatChannel(BaseChannel):
"""Mochat channel using socket.io with fallback polling workers."""
name = "mochat"
+ display_name = "Mochat"
def __init__(self, config: MochatConfig, bus: MessageBus):
super().__init__(config, bus)
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
index 5ac06e3..792cc12 100644
--- a/nanobot/channels/qq.py
+++ b/nanobot/channels/qq.py
@@ -54,6 +54,7 @@ class QQChannel(BaseChannel):
"""QQ channel using botpy SDK with WebSocket connection."""
name = "qq"
+ display_name = "QQ"
def __init__(self, config: QQConfig, bus: MessageBus):
super().__init__(config, bus)
diff --git a/nanobot/channels/registry.py b/nanobot/channels/registry.py
new file mode 100644
index 0000000..eb30ff7
--- /dev/null
+++ b/nanobot/channels/registry.py
@@ -0,0 +1,35 @@
+"""Auto-discovery for channel modules โ no hardcoded registry."""
+
+from __future__ import annotations
+
+import importlib
+import pkgutil
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from nanobot.channels.base import BaseChannel
+
+_INTERNAL = frozenset({"base", "manager", "registry"})
+
+
+def discover_channel_names() -> list[str]:
+ """Return all channel module names by scanning the package (zero imports)."""
+ import nanobot.channels as pkg
+
+ return [
+ name
+ for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
+ if name not in _INTERNAL and not ispkg
+ ]
+
+
+def load_channel_class(module_name: str) -> type[BaseChannel]:
+ """Import *module_name* and return the first BaseChannel subclass found."""
+ from nanobot.channels.base import BaseChannel as _Base
+
+ mod = importlib.import_module(f"nanobot.channels.{module_name}")
+ for attr in dir(mod):
+ obj = getattr(mod, attr)
+ if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
+ return obj
+ raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
index 0384d8d..5819212 100644
--- a/nanobot/channels/slack.py
+++ b/nanobot/channels/slack.py
@@ -21,6 +21,7 @@ class SlackChannel(BaseChannel):
"""Slack channel using Socket Mode."""
name = "slack"
+ display_name = "Slack"
def __init__(self, config: SlackConfig, bus: MessageBus):
super().__init__(config, bus)
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 5b294cc..916685b 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -20,6 +20,7 @@ from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
+TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
def _strip_md(s: str) -> str:
@@ -155,6 +156,7 @@ class TelegramChannel(BaseChannel):
"""
name = "telegram"
+ display_name = "Telegram"
# Commands registered with Telegram's command menu
BOT_COMMANDS = [
@@ -162,17 +164,12 @@ class TelegramChannel(BaseChannel):
BotCommand("new", "Start a new conversation"),
BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"),
+ BotCommand("restart", "Restart the bot"),
]
- def __init__(
- self,
- config: TelegramConfig,
- bus: MessageBus,
- groq_api_key: str = "",
- ):
+ def __init__(self, config: TelegramConfig, bus: MessageBus):
super().__init__(config, bus)
self.config: TelegramConfig = config
- self.groq_api_key = groq_api_key
self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
@@ -225,6 +222,7 @@ class TelegramChannel(BaseChannel):
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command))
+ self._app.add_handler(CommandHandler("restart", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents
@@ -456,6 +454,7 @@ class TelegramChannel(BaseChannel):
@staticmethod
def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload."""
+ reply_to = getattr(message, "reply_to_message", None)
return {
"message_id": message.message_id,
"user_id": user.id,
@@ -464,8 +463,73 @@ class TelegramChannel(BaseChannel):
"is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)),
+ "reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
}
+ @staticmethod
+ def _extract_reply_context(message) -> str | None:
+ """Extract text from the message being replied to, if any."""
+ reply = getattr(message, "reply_to_message", None)
+ if not reply:
+ return None
+ text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
+ if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
+ text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
+ return f"[Reply to: {text}]" if text else None
+
+ async def _download_message_media(
+ self, msg, *, add_failure_content: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Download media from a message (current or reply). Returns (media_paths, content_parts)."""
+ media_file = None
+ media_type = None
+ if getattr(msg, "photo", None):
+ media_file = msg.photo[-1]
+ media_type = "image"
+ elif getattr(msg, "voice", None):
+ media_file = msg.voice
+ media_type = "voice"
+ elif getattr(msg, "audio", None):
+ media_file = msg.audio
+ media_type = "audio"
+ elif getattr(msg, "document", None):
+ media_file = msg.document
+ media_type = "file"
+ elif getattr(msg, "video", None):
+ media_file = msg.video
+ media_type = "video"
+ elif getattr(msg, "video_note", None):
+ media_file = msg.video_note
+ media_type = "video"
+ elif getattr(msg, "animation", None):
+ media_file = msg.animation
+ media_type = "animation"
+ if not media_file or not self._app:
+ return [], []
+ try:
+ file = await self._app.bot.get_file(media_file.file_id)
+ ext = self._get_extension(
+ media_type,
+ getattr(media_file, "mime_type", None),
+ getattr(media_file, "file_name", None),
+ )
+ media_dir = get_media_dir("telegram")
+ file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
+ await file.download_to_drive(str(file_path))
+ path_str = str(file_path)
+ if media_type in ("voice", "audio"):
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ logger.info("Transcribed {}: {}...", media_type, transcription[:50])
+ return [path_str], [f"[transcription: {transcription}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ except Exception as e:
+ logger.warning("Failed to download message media: {}", e)
+ if add_failure_content:
+ return [], [f"[{media_type}: download failed]"]
+ return [], []
+
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
"""Load bot identity once and reuse it for mention/reply checks."""
if self._bot_user_id is not None or self._bot_username is not None:
@@ -550,7 +614,7 @@ class TelegramChannel(BaseChannel):
await self._handle_message(
sender_id=self._sender_id(user),
chat_id=str(message.chat_id),
- content=message.text,
+ content=message.text or "",
metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message),
)
@@ -582,57 +646,26 @@ class TelegramChannel(BaseChannel):
if message.caption:
content_parts.append(message.caption)
- # Handle media files
- media_file = None
- media_type = None
-
- if message.photo:
- media_file = message.photo[-1] # Largest photo
- media_type = "image"
- elif message.voice:
- media_file = message.voice
- media_type = "voice"
- elif message.audio:
- media_file = message.audio
- media_type = "audio"
- elif message.document:
- media_file = message.document
- media_type = "file"
-
- # Download media if present
- if media_file and self._app:
- try:
- file = await self._app.bot.get_file(media_file.file_id)
- ext = self._get_extension(
- media_type,
- getattr(media_file, 'mime_type', None),
- getattr(media_file, 'file_name', None),
- )
- media_dir = get_media_dir("telegram")
-
- file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
- await file.download_to_drive(str(file_path))
-
- media_paths.append(str(file_path))
-
- # Handle voice transcription
- if media_type == "voice" or media_type == "audio":
- from nanobot.providers.transcription import GroqTranscriptionProvider
- transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
- transcription = await transcriber.transcribe(file_path)
- if transcription:
- logger.info("Transcribed {}: {}...", media_type, transcription[:50])
- content_parts.append(f"[transcription: {transcription}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
-
- logger.debug("Downloaded {} to {}", media_type, file_path)
- except Exception as e:
- logger.error("Failed to download media: {}", e)
- content_parts.append(f"[{media_type}: download failed]")
+ # Download current message media
+ current_media_paths, current_media_parts = await self._download_message_media(
+ message, add_failure_content=True
+ )
+ media_paths.extend(current_media_paths)
+ content_parts.extend(current_media_parts)
+ if current_media_paths:
+ logger.debug("Downloaded message media to {}", current_media_paths[0])
+ # Reply context: text and/or media from the replied-to message
+ reply = getattr(message, "reply_to_message", None)
+ if reply is not None:
+ reply_ctx = self._extract_reply_context(message)
+ reply_media, reply_media_parts = await self._download_message_media(reply)
+ if reply_media:
+ media_paths = reply_media + media_paths
+ logger.debug("Attached replied-to media: {}", reply_media[0])
+ tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
+ if tag:
+ content_parts.insert(0, tag)
content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py
new file mode 100644
index 0000000..e0f4ae0
--- /dev/null
+++ b/nanobot/channels/wecom.py
@@ -0,0 +1,353 @@
+"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
+
+import asyncio
+import importlib.util
+import os
+from collections import OrderedDict
+from typing import Any
+
+from loguru import logger
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import WecomConfig
+
+WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
+
+# Message type display mapping
+MSG_TYPE_MAP = {
+ "image": "[image]",
+ "voice": "[voice]",
+ "file": "[file]",
+ "mixed": "[mixed content]",
+}
+
+
+class WecomChannel(BaseChannel):
+ """
+ WeCom (Enterprise WeChat) channel using WebSocket long connection.
+
+ Uses WebSocket to receive events - no public IP or webhook required.
+
+ Requires:
+ - Bot ID and Secret from WeCom AI Bot platform
+ """
+
+ name = "wecom"
+ display_name = "WeCom"
+
+ def __init__(self, config: WecomConfig, bus: MessageBus):
+ super().__init__(config, bus)
+ self.config: WecomConfig = config
+ self._client: Any = None
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+ self._loop: asyncio.AbstractEventLoop | None = None
+ self._generate_req_id = None
+ # Store frame headers for each chat to enable replies
+ self._chat_frames: dict[str, Any] = {}
+
+ async def start(self) -> None:
+ """Start the WeCom bot with WebSocket long connection."""
+ if not WECOM_AVAILABLE:
+ logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
+ return
+
+ if not self.config.bot_id or not self.config.secret:
+ logger.error("WeCom bot_id and secret not configured")
+ return
+
+ from wecom_aibot_sdk import WSClient, generate_req_id
+
+ self._running = True
+ self._loop = asyncio.get_running_loop()
+ self._generate_req_id = generate_req_id
+
+ # Create WebSocket client
+ self._client = WSClient({
+ "bot_id": self.config.bot_id,
+ "secret": self.config.secret,
+ "reconnect_interval": 1000,
+ "max_reconnect_attempts": -1, # Infinite reconnect
+ "heartbeat_interval": 30000,
+ })
+
+ # Register event handlers
+ self._client.on("connected", self._on_connected)
+ self._client.on("authenticated", self._on_authenticated)
+ self._client.on("disconnected", self._on_disconnected)
+ self._client.on("error", self._on_error)
+ self._client.on("message.text", self._on_text_message)
+ self._client.on("message.image", self._on_image_message)
+ self._client.on("message.voice", self._on_voice_message)
+ self._client.on("message.file", self._on_file_message)
+ self._client.on("message.mixed", self._on_mixed_message)
+ self._client.on("event.enter_chat", self._on_enter_chat)
+
+ logger.info("WeCom bot starting with WebSocket long connection")
+ logger.info("No public IP required - using WebSocket to receive events")
+
+ # Connect
+ await self._client.connect_async()
+
+ # Keep running until stopped
+ while self._running:
+ await asyncio.sleep(1)
+
+ async def stop(self) -> None:
+ """Stop the WeCom bot."""
+ self._running = False
+ if self._client:
+ await self._client.disconnect()
+ logger.info("WeCom bot stopped")
+
+ async def _on_connected(self, frame: Any) -> None:
+ """Handle WebSocket connected event."""
+ logger.info("WeCom WebSocket connected")
+
+ async def _on_authenticated(self, frame: Any) -> None:
+ """Handle authentication success event."""
+ logger.info("WeCom authenticated successfully")
+
+ async def _on_disconnected(self, frame: Any) -> None:
+ """Handle WebSocket disconnected event."""
+ reason = frame.body if hasattr(frame, 'body') else str(frame)
+ logger.warning("WeCom WebSocket disconnected: {}", reason)
+
+ async def _on_error(self, frame: Any) -> None:
+ """Handle error event."""
+ logger.error("WeCom error: {}", frame)
+
+ async def _on_text_message(self, frame: Any) -> None:
+ """Handle text message."""
+ await self._process_message(frame, "text")
+
+ async def _on_image_message(self, frame: Any) -> None:
+ """Handle image message."""
+ await self._process_message(frame, "image")
+
+ async def _on_voice_message(self, frame: Any) -> None:
+ """Handle voice message."""
+ await self._process_message(frame, "voice")
+
+ async def _on_file_message(self, frame: Any) -> None:
+ """Handle file message."""
+ await self._process_message(frame, "file")
+
+ async def _on_mixed_message(self, frame: Any) -> None:
+ """Handle mixed content message."""
+ await self._process_message(frame, "mixed")
+
+ async def _on_enter_chat(self, frame: Any) -> None:
+ """Handle enter_chat event (user opens chat with bot)."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
+
+ if chat_id and self.config.welcome_message:
+ await self._client.reply_welcome(frame, {
+ "msgtype": "text",
+ "text": {"content": self.config.welcome_message},
+ })
+ except Exception as e:
+ logger.error("Error handling enter_chat: {}", e)
+
+ async def _process_message(self, frame: Any, msg_type: str) -> None:
+ """Process incoming message and forward to bus."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ # Ensure body is a dict
+ if not isinstance(body, dict):
+ logger.warning("Invalid body type: {}", type(body))
+ return
+
+ # Extract message info
+ msg_id = body.get("msgid", "")
+ if not msg_id:
+ msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
+
+ # Deduplication check
+ if msg_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[msg_id] = None
+
+ # Trim cache
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
+ # Extract sender info from "from" field (SDK format)
+ from_info = body.get("from", {})
+ sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
+
+ # For single chat, chatid is the sender's userid
+ # For group chat, chatid is provided in body
+ chat_type = body.get("chattype", "single")
+ chat_id = body.get("chatid", sender_id)
+
+ content_parts = []
+
+ if msg_type == "text":
+ text = body.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+
+ elif msg_type == "image":
+ image_info = body.get("image", {})
+ file_url = image_info.get("url", "")
+ aes_key = image_info.get("aeskey", "")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "image")
+ if file_path:
+ filename = os.path.basename(file_path)
+ content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
+ else:
+ content_parts.append("[image: download failed]")
+ else:
+ content_parts.append("[image: download failed]")
+
+ elif msg_type == "voice":
+ voice_info = body.get("voice", {})
+ # Voice message already contains transcribed content from WeCom
+ voice_content = voice_info.get("content", "")
+ if voice_content:
+ content_parts.append(f"[voice] {voice_content}")
+ else:
+ content_parts.append("[voice]")
+
+ elif msg_type == "file":
+ file_info = body.get("file", {})
+ file_url = file_info.get("url", "")
+ aes_key = file_info.get("aeskey", "")
+ file_name = file_info.get("name", "unknown")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
+ if file_path:
+ content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+
+ elif msg_type == "mixed":
+ # Mixed content contains multiple message items
+ msg_items = body.get("mixed", {}).get("item", [])
+ for item in msg_items:
+ item_type = item.get("type", "")
+ if item_type == "text":
+ text = item.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
+
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
+
+ content = "\n".join(content_parts) if content_parts else ""
+
+ if not content:
+ return
+
+ # Store frame for this chat to enable replies
+ self._chat_frames[chat_id] = frame
+
+ # Forward to message bus
+ # Note: media paths are included in content for broader model compatibility
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content=content,
+ media=None,
+ metadata={
+ "message_id": msg_id,
+ "msg_type": msg_type,
+ "chat_type": chat_type,
+ }
+ )
+
+ except Exception as e:
+ logger.error("Error processing WeCom message: {}", e)
+
+ async def _download_and_save_media(
+ self,
+ file_url: str,
+ aes_key: str,
+ media_type: str,
+ filename: str | None = None,
+ ) -> str | None:
+ """
+ Download and decrypt media from WeCom.
+
+ Returns:
+ file_path or None if download failed
+ """
+ try:
+ data, fname = await self._client.download_file(file_url, aes_key)
+
+ if not data:
+ logger.warning("Failed to download media from WeCom")
+ return None
+
+ media_dir = get_media_dir("wecom")
+ if not filename:
+ filename = fname or f"{media_type}_{hash(file_url) % 100000}"
+ filename = os.path.basename(filename)
+
+ file_path = media_dir / filename
+ file_path.write_bytes(data)
+ logger.debug("Downloaded {} to {}", media_type, file_path)
+ return str(file_path)
+
+ except Exception as e:
+ logger.error("Error downloading media: {}", e)
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through WeCom."""
+ if not self._client:
+ logger.warning("WeCom client not initialized")
+ return
+
+ try:
+ content = msg.content.strip()
+ if not content:
+ return
+
+ # Get the stored frame for this chat
+ frame = self._chat_frames.get(msg.chat_id)
+ if not frame:
+ logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
+ return
+
+ # Use streaming reply for better UX
+ stream_id = self._generate_req_id("stream")
+
+ # Send as streaming message with finish=True
+ await self._client.reply_stream(
+ frame,
+ stream_id,
+ content,
+ finish=True,
+ )
+
+ logger.debug("WeCom message sent to {}", msg.chat_id)
+
+ except Exception as e:
+ logger.error("Error sending WeCom message: {}", e)
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index 1307716..7fffb80 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -22,6 +22,7 @@ class WhatsAppChannel(BaseChannel):
"""
name = "whatsapp"
+ display_name = "WhatsApp"
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
super().__init__(config, bus)
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index 332df74..91631ed 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -262,6 +262,7 @@ def onboard():
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
+ from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
@@ -271,46 +272,50 @@ def _make_provider(config: Config):
# OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
- return OpenAICodexProvider(default_model=model)
-
+ provider = OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
- from nanobot.providers.custom_provider import CustomProvider
- if provider_name == "custom":
- return CustomProvider(
+ elif provider_name == "custom":
+ from nanobot.providers.custom_provider import CustomProvider
+ provider = CustomProvider(
api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model,
)
-
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
- if provider_name == "azure_openai":
+ elif provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
-
- return AzureOpenAIProvider(
+ provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
+ else:
+ from nanobot.providers.litellm_provider import LiteLLMProvider
+ from nanobot.providers.registry import find_by_name
+ spec = find_by_name(provider_name)
+ if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
+ console.print("[red]Error: No API key configured.[/red]")
+ console.print("Set one in ~/.nanobot/config.json under providers section")
+ raise typer.Exit(1)
+ provider = LiteLLMProvider(
+ api_key=p.api_key if p else None,
+ api_base=config.get_api_base(model),
+ default_model=model,
+ extra_headers=p.extra_headers if p else None,
+ provider_name=provider_name,
+ )
- from nanobot.providers.litellm_provider import LiteLLMProvider
- from nanobot.providers.registry import find_by_name
- spec = find_by_name(provider_name)
- if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
- console.print("[red]Error: No API key configured.[/red]")
- console.print("Set one in ~/.nanobot/config.json under providers section")
- raise typer.Exit(1)
-
- return LiteLLMProvider(
- api_key=p.api_key if p else None,
- api_base=config.get_api_base(model),
- default_model=model,
- extra_headers=p.extra_headers if p else None,
- provider_name=provider_name,
+ defaults = config.agents.defaults
+ provider.generation = GenerationSettings(
+ temperature=defaults.temperature,
+ max_tokens=defaults.max_tokens,
+ reasoning_effort=defaults.reasoning_effort,
)
+ return provider
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
@@ -388,10 +393,7 @@ def gateway(
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
- temperature=config.agents.defaults.temperature,
- max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
- reasoning_effort=config.agents.defaults.reasoning_effort,
context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
@@ -574,10 +576,7 @@ def agent(
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
- temperature=config.agents.defaults.temperature,
- max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
- reasoning_effort=config.agents.defaults.reasoning_effort,
context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
@@ -733,6 +732,7 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status")
def channels_status():
"""Show channel status."""
+ from nanobot.channels.registry import discover_channel_names, load_channel_class
from nanobot.config.loader import load_config
config = load_config()
@@ -740,85 +740,19 @@ def channels_status():
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green")
- table.add_column("Configuration", style="yellow")
- # WhatsApp
- wa = config.channels.whatsapp
- table.add_row(
- "WhatsApp",
- "โ" if wa.enabled else "โ",
- wa.bridge_url
- )
-
- dc = config.channels.discord
- table.add_row(
- "Discord",
- "โ" if dc.enabled else "โ",
- dc.gateway_url
- )
-
- # Feishu
- fs = config.channels.feishu
- fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
- table.add_row(
- "Feishu",
- "โ" if fs.enabled else "โ",
- fs_config
- )
-
- # Mochat
- mc = config.channels.mochat
- mc_base = mc.base_url or "[dim]not configured[/dim]"
- table.add_row(
- "Mochat",
- "โ" if mc.enabled else "โ",
- mc_base
- )
-
- # Telegram
- tg = config.channels.telegram
- tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
- table.add_row(
- "Telegram",
- "โ" if tg.enabled else "โ",
- tg_config
- )
-
- # Slack
- slack = config.channels.slack
- slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
- table.add_row(
- "Slack",
- "โ" if slack.enabled else "โ",
- slack_config
- )
-
- # DingTalk
- dt = config.channels.dingtalk
- dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
- table.add_row(
- "DingTalk",
- "โ" if dt.enabled else "โ",
- dt_config
- )
-
- # QQ
- qq = config.channels.qq
- qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
- table.add_row(
- "QQ",
- "โ" if qq.enabled else "โ",
- qq_config
- )
-
- # Email
- em = config.channels.email
- em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
- table.add_row(
- "Email",
- "โ" if em.enabled else "โ",
- em_config
- )
+ for modname in sorted(discover_channel_names()):
+ section = getattr(config.channels, modname, None)
+ enabled = section and getattr(section, "enabled", False)
+ try:
+ cls = load_channel_class(modname)
+ display = cls.display_name
+ except ImportError:
+ display = modname.title()
+ table.add_row(
+ display,
+ "[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
+ )
console.print(table)
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index a2de239..4092eeb 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -48,6 +48,7 @@ class FeishuConfig(Base):
react_emoji: str = (
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
)
+ group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned, "open" responds to all
class DingTalkConfig(Base):
@@ -200,6 +201,14 @@ class QQConfig(Base):
) # Allowed user openids (empty = public access)
+class WecomConfig(Base):
+ """WeCom (Enterprise WeChat) AI Bot channel configuration."""
+
+ enabled: bool = False
+ bot_id: str = "" # Bot ID from WeCom AI Bot platform
+ secret: str = "" # Bot Secret from WeCom AI Bot platform
+ allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
+ welcome_message: str = "" # Welcome message for enter_chat event
class ChannelsConfig(Base):
@@ -217,6 +226,7 @@ class ChannelsConfig(Base):
slack: SlackConfig = Field(default_factory=SlackConfig)
qq: QQConfig = Field(default_factory=QQConfig)
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
+ wecom: WecomConfig = Field(default_factory=WecomConfig)
class AgentDefaults(Base):
@@ -266,14 +276,18 @@ class ProvidersConfig(Base):
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
- dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # ้ฟ้ไบ้ไนๅ้ฎ
+ dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
+ ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (็ก
ๅบๆตๅจ)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (็ซๅฑฑๅผๆ)
+ volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
+ byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
+ byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -375,16 +389,34 @@ class Config(BaseSettings):
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and model_prefix and normalized_prefix == spec.name:
- if spec.is_oauth or p.api_key:
+ if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name
# Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and any(_kw_matches(kw) for kw in spec.keywords):
- if spec.is_oauth or p.api_key:
+ if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name
+ # Fallback: configured local providers can route models without
+ # provider-specific keywords (for example plain "llama3.2" on Ollama).
+ # Prefer providers whose detect_by_base_keyword matches the configured api_base
+ # (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
+ local_fallback: tuple[ProviderConfig, str] | None = None
+ for spec in PROVIDERS:
+ if not spec.is_local:
+ continue
+ p = getattr(self.providers, spec.name, None)
+ if not (p and p.api_base):
+ continue
+ if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
+ return p, spec.name
+ if local_fallback is None:
+ local_fallback = (p, spec.name)
+ if local_fallback:
+ return local_fallback
+
# Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks โ they require explicit model selection
for spec in PROVIDERS:
@@ -411,7 +443,7 @@ class Config(BaseSettings):
return p.api_key if p else None
def get_api_base(self, model: str | None = None) -> str | None:
- """Get API base URL for the given model. Applies default URLs for known gateways."""
+ """Get API base URL for the given model. Applies default URLs for gateway/local providers."""
from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model)
@@ -422,7 +454,7 @@ class Config(BaseSettings):
# to avoid polluting the global litellm.api_base.
if name:
spec = find_by_name(name)
- if spec and spec.is_gateway and spec.default_api_base:
+ if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
return spec.default_api_base
return None
diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py
index bd79b00..05fbac4 100644
--- a/nanobot/providers/azure_openai_provider.py
+++ b/nanobot/providers/azure_openai_provider.py
@@ -88,6 +88,7 @@ class AzureOpenAIProvider(LLMProvider):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
@@ -106,7 +107,7 @@ class AzureOpenAIProvider(LLMProvider):
if tools:
payload["tools"] = tools
- payload["tool_choice"] = "auto"
+ payload["tool_choice"] = tool_choice or "auto"
return payload
@@ -118,6 +119,7 @@ class AzureOpenAIProvider(LLMProvider):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request to Azure OpenAI.
@@ -137,7 +139,8 @@ class AzureOpenAIProvider(LLMProvider):
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
- deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
+ deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
+ tool_choice=tool_choice,
)
try:
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index a3b6c47..114a948 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -1,6 +1,7 @@
"""Base LLM provider interface."""
import asyncio
+import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@@ -14,6 +15,24 @@ class ToolCallRequest:
id: str
name: str
arguments: dict[str, Any]
+ provider_specific_fields: dict[str, Any] | None = None
+ function_provider_specific_fields: dict[str, Any] | None = None
+
+ def to_openai_tool_call(self) -> dict[str, Any]:
+ """Serialize to an OpenAI-style tool_call payload."""
+ tool_call = {
+ "id": self.id,
+ "type": "function",
+ "function": {
+ "name": self.name,
+ "arguments": json.dumps(self.arguments, ensure_ascii=False),
+ },
+ }
+ if self.provider_specific_fields:
+ tool_call["provider_specific_fields"] = self.provider_specific_fields
+ if self.function_provider_specific_fields:
+ tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
+ return tool_call
@dataclass
@@ -32,6 +51,21 @@ class LLMResponse:
return len(self.tool_calls) > 0
+@dataclass(frozen=True)
+class GenerationSettings:
+ """Default generation parameters for LLM calls.
+
+ Stored on the provider so every call site inherits the same defaults
+ without having to pass temperature / max_tokens / reasoning_effort
+ through every layer. Individual call sites can still override by
+ passing explicit keyword arguments to chat() / chat_with_retry().
+ """
+
+ temperature: float = 0.7
+ max_tokens: int = 4096
+ reasoning_effort: str | None = None
+
+
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
@@ -56,9 +90,12 @@ class LLMProvider(ABC):
"temporarily unavailable",
)
+ _SENTINEL = object()
+
def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key
self.api_base = api_base
+ self.generation: GenerationSettings = GenerationSettings()
@staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -129,6 +166,7 @@ class LLMProvider(ABC):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request.
@@ -139,6 +177,7 @@ class LLMProvider(ABC):
model: Model identifier (provider-specific).
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
+ tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns:
LLMResponse with content and/or tool calls.
@@ -155,11 +194,24 @@ class LLMProvider(ABC):
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
- max_tokens: int = 4096,
- temperature: float = 0.7,
- reasoning_effort: str | None = None,
+ max_tokens: object = _SENTINEL,
+ temperature: object = _SENTINEL,
+ reasoning_effort: object = _SENTINEL,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
- """Call chat() with retry on transient provider failures."""
+ """Call chat() with retry on transient provider failures.
+
+ Parameters default to ``self.generation`` when not explicitly passed,
+ so callers no longer need to thread temperature / max_tokens /
+ reasoning_effort through every layer.
+ """
+ if max_tokens is self._SENTINEL:
+ max_tokens = self.generation.max_tokens
+ if temperature is self._SENTINEL:
+ temperature = self.generation.temperature
+ if reasoning_effort is self._SENTINEL:
+ reasoning_effort = self.generation.reasoning_effort
+
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try:
response = await self.chat(
@@ -169,6 +221,7 @@ class LLMProvider(ABC):
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
+ tool_choice=tool_choice,
)
except asyncio.CancelledError:
raise
@@ -201,6 +254,7 @@ class LLMProvider(ABC):
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
+ tool_choice=tool_choice,
)
except asyncio.CancelledError:
raise
diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py
index 66df734..f16c69b 100644
--- a/nanobot/providers/custom_provider.py
+++ b/nanobot/providers/custom_provider.py
@@ -25,7 +25,8 @@ class CustomProvider(LLMProvider):
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
- reasoning_effort: str | None = None) -> LLMResponse:
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"messages": self._sanitize_empty_content(messages),
@@ -35,7 +36,7 @@ class CustomProvider(LLMProvider):
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
if tools:
- kwargs.update(tools=tools, tool_choice="auto")
+ kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
try:
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py
index cb67635..b4508a4 100644
--- a/nanobot/providers/litellm_provider.py
+++ b/nanobot/providers/litellm_provider.py
@@ -214,6 +214,7 @@ class LiteLLMProvider(LLMProvider):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
@@ -267,7 +268,7 @@ class LiteLLMProvider(LLMProvider):
if tools:
kwargs["tools"] = tools
- kwargs["tool_choice"] = "auto"
+ kwargs["tool_choice"] = tool_choice or "auto"
try:
response = await acompletion(**kwargs)
@@ -309,10 +310,17 @@ class LiteLLMProvider(LLMProvider):
if isinstance(args, str):
args = json_repair.loads(args)
+ provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
+ function_provider_specific_fields = (
+ getattr(tc.function, "provider_specific_fields", None) or None
+ )
+
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
+ provider_specific_fields=provider_specific_fields,
+ function_provider_specific_fields=function_provider_specific_fields,
))
usage = {}
diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py
index d04e210..c8f2155 100644
--- a/nanobot/providers/openai_codex_provider.py
+++ b/nanobot/providers/openai_codex_provider.py
@@ -32,6 +32,7 @@ class OpenAICodexProvider(LLMProvider):
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
@@ -48,7 +49,7 @@ class OpenAICodexProvider(LLMProvider):
"text": {"verbosity": "medium"},
"include": ["reasoning.encrypted_content"],
"prompt_cache_key": _prompt_cache_key(messages),
- "tool_choice": "auto",
+ "tool_choice": tool_choice or "auto",
"parallel_tool_calls": True,
}
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index 3ba1a0e..2c9c185 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
- # VolcEngine (็ซๅฑฑๅผๆ): OpenAI-compatible gateway
+
+ # VolcEngine (็ซๅฑฑๅผๆ): OpenAI-compatible gateway, pay-per-use models
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
+
+ # VolcEngine Coding Plan (็ซๅฑฑๅผๆ Coding Plan): same key as volcengine
+ ProviderSpec(
+ name="volcengine_coding_plan",
+ keywords=("volcengine-plan",),
+ env_key="OPENAI_API_KEY",
+ display_name="VolcEngine Coding Plan",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="",
+ default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+
+ # BytePlus: VolcEngine international, pay-per-use models
+ ProviderSpec(
+ name="byteplus",
+ keywords=("byteplus",),
+ env_key="OPENAI_API_KEY",
+ display_name="BytePlus",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="bytepluses",
+ default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+
+ # BytePlus Coding Plan: same key as byteplus
+ ProviderSpec(
+ name="byteplus_coding_plan",
+ keywords=("byteplus-plan",),
+ env_key="OPENAI_API_KEY",
+ display_name="BytePlus Coding Plan",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="",
+ default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+
+
# === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
@@ -360,6 +417,23 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
+ # === Ollama (local, OpenAI-compatible) ===================================
+ ProviderSpec(
+ name="ollama",
+ keywords=("ollama", "nemotron"),
+ env_key="OLLAMA_API_KEY",
+ display_name="Ollama",
+ litellm_prefix="ollama_chat", # model โ ollama_chat/model
+ skip_prefixes=("ollama/", "ollama_chat/"),
+ env_extras=(),
+ is_gateway=False,
+ is_local=True,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="11434",
+ default_api_base="http://localhost:11434",
+ strip_model_prefix=False,
+ model_overrides=(),
+ ),
# === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last โ it rarely wins fallback.
diff --git a/pyproject.toml b/pyproject.toml
index 4984b03..58831c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,9 @@ dependencies = [
]
[project.optional-dependencies]
+wecom = [
+ "wecom-aibot-sdk-python>=0.1.2",
+]
matrix = [
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
@@ -69,6 +72,9 @@ nanobot = "nanobot.cli.commands:app"
requires = ["hatchling"]
build-backend = "hatchling.build"
+[tool.hatch.metadata]
+allow-direct-references = true
+
[tool.hatch.build.targets.wheel]
packages = ["nanobot"]
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 1375a3a..5848bd8 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -114,6 +114,64 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex"
+def test_config_matches_explicit_ollama_prefix_without_api_key():
+ config = Config()
+ config.agents.defaults.model = "ollama/llama3.2"
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
+ config = Config()
+ config.agents.defaults.provider = "ollama"
+ config.agents.defaults.model = "llama3.2"
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_auto_detects_ollama_from_local_api_base():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {"ollama": {"apiBase": "http://localhost:11434"}},
+ }
+ )
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {
+ "vllm": {"apiBase": "http://localhost:8000"},
+ "ollama": {"apiBase": "http://localhost:11434"},
+ },
+ }
+ )
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_falls_back_to_vllm_when_ollama_not_configured():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {
+ "vllm": {"apiBase": "http://localhost:8000"},
+ },
+ }
+ )
+
+ assert config.get_provider_name() == "vllm"
+ assert config.get_api_base() == "http://localhost:8000"
+
+
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py
new file mode 100644
index 0000000..db8f256
--- /dev/null
+++ b/tests/test_filesystem_tools.py
@@ -0,0 +1,251 @@
+"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
+
+import pytest
+
+from nanobot.agent.tools.filesystem import (
+ EditFileTool,
+ ListDirTool,
+ ReadFileTool,
+ _find_match,
+)
+
+
+# ---------------------------------------------------------------------------
+# ReadFileTool
+# ---------------------------------------------------------------------------
+
+class TestReadFileTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return ReadFileTool(workspace=tmp_path)
+
+ @pytest.fixture()
+ def sample_file(self, tmp_path):
+ f = tmp_path / "sample.txt"
+ f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
+ return f
+
+ @pytest.mark.asyncio
+ async def test_basic_read_has_line_numbers(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file))
+ assert "1| line 1" in result
+ assert "20| line 20" in result
+
+ @pytest.mark.asyncio
+ async def test_offset_and_limit(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=5, limit=3)
+ assert "5| line 5" in result
+ assert "7| line 7" in result
+ assert "8| line 8" not in result
+ assert "Use offset=8 to continue" in result
+
+ @pytest.mark.asyncio
+ async def test_offset_beyond_end(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=999)
+ assert "Error" in result
+ assert "beyond end" in result
+
+ @pytest.mark.asyncio
+ async def test_end_of_file_marker(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
+ assert "End of file" in result
+
+ @pytest.mark.asyncio
+ async def test_empty_file(self, tool, tmp_path):
+ f = tmp_path / "empty.txt"
+ f.write_text("", encoding="utf-8")
+ result = await tool.execute(path=str(f))
+ assert "Empty file" in result
+
+ @pytest.mark.asyncio
+ async def test_file_not_found(self, tool, tmp_path):
+ result = await tool.execute(path=str(tmp_path / "nope.txt"))
+ assert "Error" in result
+ assert "not found" in result
+
+ @pytest.mark.asyncio
+ async def test_char_budget_trims(self, tool, tmp_path):
+ """When the selected slice exceeds _MAX_CHARS the output is trimmed."""
+ f = tmp_path / "big.txt"
+ # Each line is ~110 chars, 2000 lines โ 220 KB > 128 KB limit
+ f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
+ result = await tool.execute(path=str(f))
+ assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
+ assert "Use offset=" in result
+
+
+# ---------------------------------------------------------------------------
+# _find_match (unit tests for the helper)
+# ---------------------------------------------------------------------------
+
+class TestFindMatch:
+
+ def test_exact_match(self):
+ match, count = _find_match("hello world", "world")
+ assert match == "world"
+ assert count == 1
+
+ def test_exact_no_match(self):
+ match, count = _find_match("hello world", "xyz")
+ assert match is None
+ assert count == 0
+
+ def test_crlf_normalisation(self):
+ # Caller normalises CRLF before calling _find_match, so test with
+ # pre-normalised content to verify exact match still works.
+ content = "line1\nline2\nline3"
+ old_text = "line1\nline2\nline3"
+ match, count = _find_match(content, old_text)
+ assert match is not None
+ assert count == 1
+
+ def test_line_trim_fallback(self):
+ content = " def foo():\n pass\n"
+ old_text = "def foo():\n pass"
+ match, count = _find_match(content, old_text)
+ assert match is not None
+ assert count == 1
+ # The returned match should be the *original* indented text
+ assert " def foo():" in match
+
+ def test_line_trim_multiple_candidates(self):
+ content = " a\n b\n a\n b\n"
+ old_text = "a\nb"
+ match, count = _find_match(content, old_text)
+ assert count == 2
+
+ def test_empty_old_text(self):
+ match, count = _find_match("hello", "")
+ # Empty string is always "in" any string via exact match
+ assert match == ""
+
+
+# ---------------------------------------------------------------------------
+# EditFileTool
+# ---------------------------------------------------------------------------
+
+class TestEditFileTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return EditFileTool(workspace=tmp_path)
+
+ @pytest.mark.asyncio
+ async def test_exact_match(self, tool, tmp_path):
+ f = tmp_path / "a.py"
+ f.write_text("hello world", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="world", new_text="earth")
+ assert "Successfully" in result
+ assert f.read_text() == "hello earth"
+
+ @pytest.mark.asyncio
+ async def test_crlf_normalisation(self, tool, tmp_path):
+ f = tmp_path / "crlf.py"
+ f.write_bytes(b"line1\r\nline2\r\nline3")
+ result = await tool.execute(
+ path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
+ )
+ assert "Successfully" in result
+ raw = f.read_bytes()
+ assert b"LINE1" in raw
+ # CRLF line endings should be preserved throughout the file
+ assert b"\r\n" in raw
+
+ @pytest.mark.asyncio
+ async def test_trim_fallback(self, tool, tmp_path):
+ f = tmp_path / "indent.py"
+ f.write_text(" def foo():\n pass\n", encoding="utf-8")
+ result = await tool.execute(
+ path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
+ )
+ assert "Successfully" in result
+ assert "bar" in f.read_text()
+
+ @pytest.mark.asyncio
+ async def test_ambiguous_match(self, tool, tmp_path):
+ f = tmp_path / "dup.py"
+ f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
+ assert "appears" in result.lower() or "Warning" in result
+
+ @pytest.mark.asyncio
+ async def test_replace_all(self, tool, tmp_path):
+ f = tmp_path / "multi.py"
+ f.write_text("foo bar foo bar foo", encoding="utf-8")
+ result = await tool.execute(
+ path=str(f), old_text="foo", new_text="baz", replace_all=True,
+ )
+ assert "Successfully" in result
+ assert f.read_text() == "baz bar baz bar baz"
+
+ @pytest.mark.asyncio
+ async def test_not_found(self, tool, tmp_path):
+ f = tmp_path / "nf.py"
+ f.write_text("hello", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
+ assert "Error" in result
+ assert "not found" in result
+
+
+# ---------------------------------------------------------------------------
+# ListDirTool
+# ---------------------------------------------------------------------------
+
+class TestListDirTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return ListDirTool(workspace=tmp_path)
+
+ @pytest.fixture()
+ def populated_dir(self, tmp_path):
+ (tmp_path / "src").mkdir()
+ (tmp_path / "src" / "main.py").write_text("pass")
+ (tmp_path / "src" / "utils.py").write_text("pass")
+ (tmp_path / "README.md").write_text("hi")
+ (tmp_path / ".git").mkdir()
+ (tmp_path / ".git" / "config").write_text("x")
+ (tmp_path / "node_modules").mkdir()
+ (tmp_path / "node_modules" / "pkg").mkdir()
+ return tmp_path
+
+ @pytest.mark.asyncio
+ async def test_basic_list(self, tool, populated_dir):
+ result = await tool.execute(path=str(populated_dir))
+ assert "README.md" in result
+ assert "src" in result
+ # .git and node_modules should be ignored
+ assert ".git" not in result
+ assert "node_modules" not in result
+
+ @pytest.mark.asyncio
+ async def test_recursive(self, tool, populated_dir):
+ result = await tool.execute(path=str(populated_dir), recursive=True)
+ assert "src/main.py" in result
+ assert "src/utils.py" in result
+ assert "README.md" in result
+ # Ignored dirs should not appear
+ assert ".git" not in result
+ assert "node_modules" not in result
+
+ @pytest.mark.asyncio
+ async def test_max_entries_truncation(self, tool, tmp_path):
+ for i in range(10):
+ (tmp_path / f"file_{i}.txt").write_text("x")
+ result = await tool.execute(path=str(tmp_path), max_entries=3)
+ assert "truncated" in result
+ assert "3 of 10" in result
+
+ @pytest.mark.asyncio
+ async def test_empty_dir(self, tool, tmp_path):
+ d = tmp_path / "empty"
+ d.mkdir()
+ result = await tool.execute(path=str(d))
+ assert "empty" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_not_found(self, tool, tmp_path):
+ result = await tool.execute(path=str(tmp_path / "nope"))
+ assert "Error" in result
+ assert "not found" in result
diff --git a/tests/test_gemini_thought_signature.py b/tests/test_gemini_thought_signature.py
new file mode 100644
index 0000000..bc4132c
--- /dev/null
+++ b/tests/test_gemini_thought_signature.py
@@ -0,0 +1,53 @@
+from types import SimpleNamespace
+
+from nanobot.providers.base import ToolCallRequest
+from nanobot.providers.litellm_provider import LiteLLMProvider
+
+
+def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
+ provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
+
+ response = SimpleNamespace(
+ choices=[
+ SimpleNamespace(
+ finish_reason="tool_calls",
+ message=SimpleNamespace(
+ content=None,
+ tool_calls=[
+ SimpleNamespace(
+ id="call_123",
+ function=SimpleNamespace(
+ name="read_file",
+ arguments='{"path":"todo.md"}',
+ provider_specific_fields={"inner": "value"},
+ ),
+ provider_specific_fields={"thought_signature": "signed-token"},
+ )
+ ],
+ ),
+ )
+ ],
+ usage=None,
+ )
+
+ parsed = provider._parse_response(response)
+
+ assert len(parsed.tool_calls) == 1
+ assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
+ assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
+
+
+def test_tool_call_request_serializes_provider_fields() -> None:
+ tool_call = ToolCallRequest(
+ id="abc123xyz",
+ name="read_file",
+ arguments={"path": "todo.md"},
+ provider_specific_fields={"thought_signature": "signed-token"},
+ function_provider_specific_fields={"inner": "value"},
+ )
+
+ message = tool_call.to_openai_tool_call()
+
+ assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
+ assert message["function"]["provider_specific_fields"] == {"inner": "value"}
+ assert message["function"]["arguments"] == '{"path": "todo.md"}'
diff --git a/tests/test_loop_save_turn.py b/tests/test_loop_save_turn.py
index aec6d1a..25ba88b 100644
--- a/tests/test_loop_save_turn.py
+++ b/tests/test_loop_save_turn.py
@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
- loop._TOOL_RESULT_MAX_CHARS = 500
+ loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
return loop
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
+
+
+def test_save_turn_keeps_tool_results_under_16k() -> None:
+ loop = _mk_loop()
+ session = Session(key="test:tool-result")
+ content = "x" * 12_000
+
+ loop._save_turn(
+ session,
+ [{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
+ skip=0,
+ )
+
+ assert session.messages[0]["content"] == content
diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py
index 0263f01..69be858 100644
--- a/tests/test_memory_consolidation_types.py
+++ b/tests/test_memory_consolidation_types.py
@@ -265,3 +265,26 @@ class TestMemoryConsolidationTypeHandling:
assert result is True
assert provider.calls == 2
assert delays == [1]
+
+ @pytest.mark.asyncio
+ async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
+ """Consolidation no longer passes generation params โ the provider owns them."""
+ store = MemoryStore(tmp_path)
+ provider = AsyncMock()
+ provider.chat_with_retry = AsyncMock(
+ return_value=_make_tool_response(
+ history_entry="[2026-01-01] User discussed testing.",
+ memory_update="# Memory\nUser likes testing.",
+ )
+ )
+ messages = _make_messages(message_count=60)
+
+ result = await store.consolidate(messages, provider, "test-model")
+
+ assert result is True
+ provider.chat_with_retry.assert_awaited_once()
+ _, kwargs = provider.chat_with_retry.await_args
+ assert kwargs["model"] == "test-model"
+ assert "temperature" not in kwargs
+ assert "max_tokens" not in kwargs
+ assert "reasoning_effort" not in kwargs
diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py
index 751ecc3..2420399 100644
--- a/tests/test_provider_retry.py
+++ b/tests/test_provider_retry.py
@@ -2,7 +2,7 @@ import asyncio
import pytest
-from nanobot.providers.base import LLMProvider, LLMResponse
+from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider):
@@ -10,9 +10,11 @@ class ScriptedProvider(LLMProvider):
super().__init__()
self._responses = list(responses)
self.calls = 0
+ self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
+ self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
+ """When callers omit generation params, provider.generation defaults are used."""
+ provider = ScriptedProvider([LLMResponse(content="ok")])
+ provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
+
+ await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+ assert provider.last_kwargs["temperature"] == 0.2
+ assert provider.last_kwargs["max_tokens"] == 321
+ assert provider.last_kwargs["reasoning_effort"] == "high"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
+ """Explicit kwargs should override provider.generation defaults."""
+ provider = ScriptedProvider([LLMResponse(content="ok")])
+ provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
+
+ await provider.chat_with_retry(
+ messages=[{"role": "user", "content": "hello"}],
+ temperature=0.9,
+ max_tokens=9999,
+ reasoning_effort="low",
+ )
+
+ assert provider.last_kwargs["temperature"] == 0.9
+ assert provider.last_kwargs["max_tokens"] == 9999
+ assert provider.last_kwargs["reasoning_effort"] == "low"
diff --git a/tests/test_restart_command.py b/tests/test_restart_command.py
new file mode 100644
index 0000000..c495347
--- /dev/null
+++ b/tests/test_restart_command.py
@@ -0,0 +1,76 @@
+"""Tests for /restart slash command."""
+
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from nanobot.bus.events import InboundMessage
+
+
+def _make_loop():
+ """Create a minimal AgentLoop with mocked dependencies."""
+ from nanobot.agent.loop import AgentLoop
+ from nanobot.bus.queue import MessageBus
+
+ bus = MessageBus()
+ provider = MagicMock()
+ provider.get_default_model.return_value = "test-model"
+ workspace = MagicMock()
+ workspace.__truediv__ = MagicMock(return_value=MagicMock())
+
+ with patch("nanobot.agent.loop.ContextBuilder"), \
+ patch("nanobot.agent.loop.SessionManager"), \
+ patch("nanobot.agent.loop.SubagentManager"):
+ loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
+ return loop, bus
+
+
+class TestRestartCommand:
+
+ @pytest.mark.asyncio
+ async def test_restart_sends_message_and_calls_execv(self):
+ loop, bus = _make_loop()
+ msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
+
+ with patch("nanobot.agent.loop.os.execv") as mock_execv:
+ await loop._handle_restart(msg)
+ out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
+ assert "Restarting" in out.content
+
+ await asyncio.sleep(1.5)
+ mock_execv.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_restart_intercepted_in_run_loop(self):
+ """Verify /restart is handled at the run-loop level, not inside _dispatch."""
+ loop, bus = _make_loop()
+ msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
+
+ with patch.object(loop, "_handle_restart") as mock_handle:
+ mock_handle.return_value = None
+ await bus.publish_inbound(msg)
+
+ loop._running = True
+ run_task = asyncio.create_task(loop.run())
+ await asyncio.sleep(0.1)
+ loop._running = False
+ run_task.cancel()
+ try:
+ await run_task
+ except asyncio.CancelledError:
+ pass
+
+ mock_handle.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_help_includes_restart(self):
+ loop, bus = _make_loop()
+ msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
+
+ response = await loop._process_message(msg)
+
+ assert response is not None
+ assert "/restart" in response.content
diff --git a/tests/test_telegram_channel.py b/tests/test_telegram_channel.py
index 678512d..897f77d 100644
--- a/tests/test_telegram_channel.py
+++ b/tests/test_telegram_channel.py
@@ -1,10 +1,13 @@
+import asyncio
+from pathlib import Path
from types import SimpleNamespace
+from unittest.mock import AsyncMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
-from nanobot.channels.telegram import TelegramChannel
+from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
from nanobot.config.schema import TelegramConfig
@@ -42,6 +45,12 @@ class _FakeBot:
async def send_chat_action(self, **kwargs) -> None:
pass
+ async def get_file(self, file_id: str):
+ """Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
+ async def _fake_download(path) -> None:
+ pass
+ return SimpleNamespace(download_to_drive=_fake_download)
+
class _FakeApp:
def __init__(self, on_start_polling) -> None:
@@ -336,3 +345,255 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0
+
+
+def test_extract_reply_context_no_reply() -> None:
+ """When there is no reply_to_message, _extract_reply_context returns None."""
+ message = SimpleNamespace(reply_to_message=None)
+ assert TelegramChannel._extract_reply_context(message) is None
+
+
+def test_extract_reply_context_with_text() -> None:
+ """When reply has text, return prefixed string."""
+ reply = SimpleNamespace(text="Hello world", caption=None)
+ message = SimpleNamespace(reply_to_message=reply)
+ assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
+
+
+def test_extract_reply_context_with_caption_only() -> None:
+ """When reply has only caption (no text), caption is used."""
+ reply = SimpleNamespace(text=None, caption="Photo caption")
+ message = SimpleNamespace(reply_to_message=reply)
+ assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
+
+
+def test_extract_reply_context_truncation() -> None:
+ """Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
+ long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
+ reply = SimpleNamespace(text=long_text, caption=None)
+ message = SimpleNamespace(reply_to_message=reply)
+ result = TelegramChannel._extract_reply_context(message)
+ assert result is not None
+ assert result.startswith("[Reply to: ")
+ assert result.endswith("...]")
+ assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
+
+
+def test_extract_reply_context_no_text_returns_none() -> None:
+ """When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
+ reply = SimpleNamespace(text=None, caption=None)
+ message = SimpleNamespace(reply_to_message=reply)
+ assert TelegramChannel._extract_reply_context(message) is None
+
+
+@pytest.mark.asyncio
+async def test_on_message_includes_reply_context() -> None:
+ """When user replies to a message, content passed to bus starts with reply context."""
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+ handled = []
+ async def capture_handle(**kwargs) -> None:
+ handled.append(kwargs)
+ channel._handle_message = capture_handle
+ channel._start_typing = lambda _chat_id: None
+
+ reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
+ update = _make_telegram_update(text="translate this", reply_to_message=reply)
+ await channel._on_message(update, None)
+
+ assert len(handled) == 1
+ assert handled[0]["content"].startswith("[Reply to: Hello]")
+ assert "translate this" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_download_message_media_returns_path_when_download_succeeds(
+ monkeypatch, tmp_path
+) -> None:
+ """_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
+ media_dir = tmp_path / "media" / "telegram"
+ media_dir.mkdir(parents=True)
+ monkeypatch.setattr(
+ "nanobot.channels.telegram.get_media_dir",
+ lambda channel=None: media_dir if channel else tmp_path / "media",
+ )
+
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+ channel._app.bot.get_file = AsyncMock(
+ return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
+ )
+
+ msg = SimpleNamespace(
+ photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
+ voice=None,
+ audio=None,
+ document=None,
+ video=None,
+ video_note=None,
+ animation=None,
+ )
+ paths, parts = await channel._download_message_media(msg)
+ assert len(paths) == 1
+ assert len(parts) == 1
+ assert "fid123" in paths[0]
+ assert "[image:" in parts[0]
+
+
+@pytest.mark.asyncio
+async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
+ """When user replies to a message with media, that media is downloaded and attached to the turn."""
+ media_dir = tmp_path / "media" / "telegram"
+ media_dir.mkdir(parents=True)
+ monkeypatch.setattr(
+ "nanobot.channels.telegram.get_media_dir",
+ lambda channel=None: media_dir if channel else tmp_path / "media",
+ )
+
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
+ MessageBus(),
+ )
+ app = _FakeApp(lambda: None)
+ app.bot.get_file = AsyncMock(
+ return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
+ )
+ channel._app = app
+ handled = []
+ async def capture_handle(**kwargs) -> None:
+ handled.append(kwargs)
+ channel._handle_message = capture_handle
+ channel._start_typing = lambda _chat_id: None
+
+ reply_with_photo = SimpleNamespace(
+ text=None,
+ caption=None,
+ photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
+ document=None,
+ voice=None,
+ audio=None,
+ video=None,
+ video_note=None,
+ animation=None,
+ )
+ update = _make_telegram_update(
+ text="what is the image?",
+ reply_to_message=reply_with_photo,
+ )
+ await channel._on_message(update, None)
+
+ assert len(handled) == 1
+ assert handled[0]["content"].startswith("[Reply to: [image:")
+ assert "what is the image?" in handled[0]["content"]
+ assert len(handled[0]["media"]) == 1
+ assert "reply_photo_fid" in handled[0]["media"][0]
+
+
+@pytest.mark.asyncio
+async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
+ """When reply has media but download fails, no media attached and no reply tag."""
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+ channel._app.bot.get_file = None
+ handled = []
+ async def capture_handle(**kwargs) -> None:
+ handled.append(kwargs)
+ channel._handle_message = capture_handle
+ channel._start_typing = lambda _chat_id: None
+
+ reply_with_photo = SimpleNamespace(
+ text=None,
+ caption=None,
+ photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
+ document=None,
+ voice=None,
+ audio=None,
+ video=None,
+ video_note=None,
+ animation=None,
+ )
+ update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
+ await channel._on_message(update, None)
+
+ assert len(handled) == 1
+ assert "what is this?" in handled[0]["content"]
+ assert handled[0]["media"] == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
+ """When replying to a message with caption + photo, both text context and media are included."""
+ media_dir = tmp_path / "media" / "telegram"
+ media_dir.mkdir(parents=True)
+ monkeypatch.setattr(
+ "nanobot.channels.telegram.get_media_dir",
+ lambda channel=None: media_dir if channel else tmp_path / "media",
+ )
+
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
+ MessageBus(),
+ )
+ app = _FakeApp(lambda: None)
+ app.bot.get_file = AsyncMock(
+ return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
+ )
+ channel._app = app
+ handled = []
+ async def capture_handle(**kwargs) -> None:
+ handled.append(kwargs)
+ channel._handle_message = capture_handle
+ channel._start_typing = lambda _chat_id: None
+
+ reply_with_caption_and_photo = SimpleNamespace(
+ text=None,
+ caption="A cute cat",
+ photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
+ document=None,
+ voice=None,
+ audio=None,
+ video=None,
+ video_note=None,
+ animation=None,
+ )
+ update = _make_telegram_update(
+ text="what breed is this?",
+ reply_to_message=reply_with_caption_and_photo,
+ )
+ await channel._on_message(update, None)
+
+ assert len(handled) == 1
+ assert "[Reply to: A cute cat]" in handled[0]["content"]
+ assert "what breed is this?" in handled[0]["content"]
+ assert len(handled[0]["media"]) == 1
+ assert "cat_fid" in handled[0]["media"][0]
+
+
+@pytest.mark.asyncio
+async def test_forward_command_does_not_inject_reply_context() -> None:
+ """Slash commands forwarded via _forward_command must not include reply context."""
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+ handled = []
+ async def capture_handle(**kwargs) -> None:
+ handled.append(kwargs)
+ channel._handle_message = capture_handle
+
+ reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
+ update = _make_telegram_update(text="/new", reply_to_message=reply)
+ await channel._forward_command(update, None)
+
+ assert len(handled) == 1
+ assert handled[0]["content"] == "/new"
diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py
index c2b4b6a..095c041 100644
--- a/tests/test_tool_validation.py
+++ b/tests/test_tool_validation.py
@@ -108,6 +108,32 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
assert "/tmp/out.txt" in paths
+def test_exec_extract_absolute_paths_captures_home_paths() -> None:
+ cmd = "cat ~/.nanobot/config.json > ~/out.txt"
+ paths = ExecTool._extract_absolute_paths(cmd)
+ assert "~/.nanobot/config.json" in paths
+ assert "~/out.txt" in paths
+
+
+def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
+ cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
+ paths = ExecTool._extract_absolute_paths(cmd)
+ assert "/tmp/data.txt" in paths
+ assert "~/.nanobot/config.json" in paths
+
+
+def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
+ tool = ExecTool(restrict_to_workspace=True)
+ error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
+ assert error == "Error: Command blocked by safety guard (path outside working dir)"
+
+
+def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
+ tool = ExecTool(restrict_to_workspace=True)
+ error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
+ assert error == "Error: Command blocked by safety guard (path outside working dir)"
+
+
# --- cast_params tests ---
@@ -337,3 +363,44 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]
+
+
+# --- ExecTool enhancement tests ---
+
+
+async def test_exec_always_returns_exit_code() -> None:
+ """Exit code should appear in output even on success (exit 0)."""
+ tool = ExecTool()
+ result = await tool.execute(command="echo hello")
+ assert "Exit code: 0" in result
+ assert "hello" in result
+
+
+async def test_exec_head_tail_truncation() -> None:
+ """Long output should preserve both head and tail."""
+ tool = ExecTool()
+ # Generate output that exceeds _MAX_OUTPUT
+ big = "A" * 6000 + "\n" + "B" * 6000
+ result = await tool.execute(command=f"echo '{big}'")
+ assert "chars truncated" in result
+ # Head portion should start with As
+ assert result.startswith("A")
+ # Tail portion should end with the exit code which comes after Bs
+ assert "Exit code:" in result
+
+
+async def test_exec_timeout_parameter() -> None:
+ """LLM-supplied timeout should override the constructor default."""
+ tool = ExecTool(timeout=60)
+ # A very short timeout should cause the command to be killed
+ result = await tool.execute(command="sleep 10", timeout=1)
+ assert "timed out" in result
+ assert "1 seconds" in result
+
+
+async def test_exec_timeout_capped_at_max() -> None:
+ """Timeout values above _MAX_TIMEOUT should be clamped."""
+ tool = ExecTool()
+ # Should not raise โ just clamp to 600
+ result = await tool.execute(command="echo ok", timeout=9999)
+ assert "Exit code: 0" in result