Merge remote-tracking branch 'origin/main' into pr-1919

This commit is contained in:
Xubin Ren
2026-03-13 03:02:17 +00:00
21 changed files with 725 additions and 89 deletions

3
.gitignore vendored
View File

@@ -1,5 +1,6 @@
.worktrees/ .worktrees/
.assets .assets
.docs
.env .env
*.pyc *.pyc
dist/ dist/
@@ -7,7 +8,7 @@ build/
docs/ docs/
*.egg-info/ *.egg-info/
*.egg *.egg
*.pyc *.pycs
*.pyo *.pyo
*.pyd *.pyd
*.pyw *.pyw

View File

@@ -64,7 +64,7 @@
## Key Features of nanobot: ## Key Features of nanobot:
🪶 **Ultra-Lightweight**: Just ~4,000 lines of core agent code — 99% smaller than Clawdbot. 🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research. 🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
@@ -502,7 +502,8 @@ Uses **WebSocket** long connection — no public IP required.
"appSecret": "xxx", "appSecret": "xxx",
"encryptKey": "", "encryptKey": "",
"verificationToken": "", "verificationToken": "",
"allowFrom": ["ou_YOUR_OPEN_ID"] "allowFrom": ["ou_YOUR_OPEN_ID"],
"groupPolicy": "mention"
} }
} }
} }
@@ -510,6 +511,7 @@ Uses **WebSocket** long connection — no public IP required.
> `encryptKey` and `verificationToken` are optional for Long Connection mode. > `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users. > `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
**3. Run** **3. Run**
@@ -756,15 +758,17 @@ Config file: `~/.nanobot/config.json`
> [!TIP] > [!TIP]
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. > - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config. > - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
| Provider | Purpose | Get API Key | | Provider | Purpose | Get API Key |
|----------|---------|-------------| |----------|---------|-------------|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | | `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | | `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | | `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
@@ -774,7 +778,6 @@ Config file: `~/.nanobot/config.json`
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) | | `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
| `volcengine` | LLM (VolcEngine/火山引擎) | [volcengine.com](https://www.volcengine.com) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |

View File

@@ -4,7 +4,9 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import os
import re import re
import sys
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -43,7 +45,7 @@ class AgentLoop:
5. Sends responses back 5. Sends responses back
""" """
_TOOL_RESULT_MAX_CHARS = 500 _TOOL_RESULT_MAX_CHARS = 16_000
def __init__( def __init__(
self, self,
@@ -137,7 +139,7 @@ class AgentLoop:
await self._mcp_stack.__aenter__() await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True self._mcp_connected = True
except Exception as e: except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e) logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack: if self._mcp_stack:
try: try:
@@ -256,8 +258,11 @@ class AgentLoop:
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
if msg.content.strip().lower() == "/stop": cmd = msg.content.strip().lower()
if cmd == "/stop":
await self._handle_stop(msg) await self._handle_stop(msg)
elif cmd == "/restart":
await self._handle_restart(msg)
else: else:
task = asyncio.create_task(self._dispatch(msg)) task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task) self._active_tasks.setdefault(msg.session_key, []).append(task)
@@ -274,11 +279,23 @@ class AgentLoop:
pass pass
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key) sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
total = cancelled + sub_cancelled total = cancelled + sub_cancelled
content = f"Stopped {total} task(s)." if total else "No active task to stop." content = f"Stopped {total} task(s)." if total else "No active task to stop."
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content, channel=msg.channel, chat_id=msg.chat_id, content=content,
)) ))
async def _handle_restart(self, msg: InboundMessage) -> None:
"""Restart the process in-place via os.execv."""
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
))
async def _do_restart():
await asyncio.sleep(1)
os.execv(sys.executable, [sys.executable] + sys.argv)
asyncio.create_task(_do_restart())
async def _dispatch(self, msg: InboundMessage) -> None: async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message under the global lock.""" """Process a message under the global lock."""
async with self._processing_lock: async with self._processing_lock:
@@ -373,9 +390,16 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started.") content="New session started.")
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, lines = [
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") "🐈 nanobot commands:",
"/new — Start a new conversation",
"/stop — Stop the current task",
"/restart — Restart the bot",
"/help — Show available commands",
]
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))

View File

@@ -112,14 +112,17 @@ class MemoryStore:
## Conversation to Process ## Conversation to Process
{self._format_messages(messages)}""" {self._format_messages(messages)}"""
chat_messages = [
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
]
try: try:
response = await provider.chat_with_retry( response = await provider.chat_with_retry(
messages=[ messages=chat_messages,
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
],
tools=_SAVE_MEMORY_TOOL, tools=_SAVE_MEMORY_TOOL,
model=model, model=model,
tool_choice={"type": "function", "function": {"name": "save_memory"}},
) )
if not response.has_tool_calls: if not response.has_tool_calls:

View File

@@ -352,6 +352,27 @@ class FeishuChannel(BaseChannel):
self._running = False self._running = False
logger.info("Feishu bot stopped") logger.info("Feishu bot stopped")
def _is_bot_mentioned(self, message: Any) -> bool:
"""Check if the bot is @mentioned in the message."""
raw_content = message.content or ""
if "@_all" in raw_content:
return True
for mention in getattr(message, "mentions", None) or []:
mid = getattr(mention, "id", None)
if not mid:
continue
# Bot mentions have no user_id (None or "") but a valid open_id
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
return True
return False
def _is_group_message_for_bot(self, message: Any) -> bool:
"""Allow group messages when policy is open or bot is @mentioned."""
if self.config.group_policy == "open":
return True
return self._is_bot_mentioned(message)
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None: def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
"""Sync helper for adding reaction (runs in thread pool).""" """Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
@@ -893,6 +914,10 @@ class FeishuChannel(BaseChannel):
chat_type = message.chat_type chat_type = message.chat_type
msg_type = message.message_type msg_type = message.message_type
if chat_type == "group" and not self._is_group_message_for_bot(message):
logger.debug("Feishu: skipping group message (not mentioned)")
return
# Add reaction # Add reaction
await self._add_reaction(message_id, self.config.react_emoji) await self._add_reaction(message_id, self.config.react_emoji)

View File

@@ -114,16 +114,16 @@ class QQChannel(BaseChannel):
if msg_type == "group": if msg_type == "group":
await self._client.api.post_group_message( await self._client.api.post_group_message(
group_openid=msg.chat_id, group_openid=msg.chat_id,
msg_type=2, msg_type=0,
markdown={"content": msg.content}, content=msg.content,
msg_id=msg_id, msg_id=msg_id,
msg_seq=self._msg_seq, msg_seq=self._msg_seq,
) )
else: else:
await self._client.api.post_c2c_message( await self._client.api.post_c2c_message(
openid=msg.chat_id, openid=msg.chat_id,
msg_type=2, msg_type=0,
markdown={"content": msg.content}, content=msg.content,
msg_id=msg_id, msg_id=msg_id,
msg_seq=self._msg_seq, msg_seq=self._msg_seq,
) )

View File

@@ -20,6 +20,7 @@ from nanobot.config.schema import TelegramConfig
from nanobot.utils.helpers import split_message from nanobot.utils.helpers import split_message
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
def _strip_md(s: str) -> str: def _strip_md(s: str) -> str:
@@ -163,6 +164,7 @@ class TelegramChannel(BaseChannel):
BotCommand("new", "Start a new conversation"), BotCommand("new", "Start a new conversation"),
BotCommand("stop", "Stop the current task"), BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"), BotCommand("help", "Show available commands"),
BotCommand("restart", "Restart the bot"),
] ]
def __init__(self, config: TelegramConfig, bus: MessageBus): def __init__(self, config: TelegramConfig, bus: MessageBus):
@@ -220,6 +222,7 @@ class TelegramChannel(BaseChannel):
self._app.add_handler(CommandHandler("start", self._on_start)) self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command)) self._app.add_handler(CommandHandler("new", self._forward_command))
self._app.add_handler(CommandHandler("stop", self._forward_command)) self._app.add_handler(CommandHandler("stop", self._forward_command))
self._app.add_handler(CommandHandler("restart", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help)) self._app.add_handler(CommandHandler("help", self._on_help))
# Add message handler for text, photos, voice, documents # Add message handler for text, photos, voice, documents
@@ -451,6 +454,7 @@ class TelegramChannel(BaseChannel):
@staticmethod @staticmethod
def _build_message_metadata(message, user) -> dict: def _build_message_metadata(message, user) -> dict:
"""Build common Telegram inbound metadata payload.""" """Build common Telegram inbound metadata payload."""
reply_to = getattr(message, "reply_to_message", None)
return { return {
"message_id": message.message_id, "message_id": message.message_id,
"user_id": user.id, "user_id": user.id,
@@ -459,8 +463,73 @@ class TelegramChannel(BaseChannel):
"is_group": message.chat.type != "private", "is_group": message.chat.type != "private",
"message_thread_id": getattr(message, "message_thread_id", None), "message_thread_id": getattr(message, "message_thread_id", None),
"is_forum": bool(getattr(message.chat, "is_forum", False)), "is_forum": bool(getattr(message.chat, "is_forum", False)),
"reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
} }
@staticmethod
def _extract_reply_context(message) -> str | None:
"""Extract text from the message being replied to, if any."""
reply = getattr(message, "reply_to_message", None)
if not reply:
return None
text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
return f"[Reply to: {text}]" if text else None
async def _download_message_media(
self, msg, *, add_failure_content: bool = False
) -> tuple[list[str], list[str]]:
"""Download media from a message (current or reply). Returns (media_paths, content_parts)."""
media_file = None
media_type = None
if getattr(msg, "photo", None):
media_file = msg.photo[-1]
media_type = "image"
elif getattr(msg, "voice", None):
media_file = msg.voice
media_type = "voice"
elif getattr(msg, "audio", None):
media_file = msg.audio
media_type = "audio"
elif getattr(msg, "document", None):
media_file = msg.document
media_type = "file"
elif getattr(msg, "video", None):
media_file = msg.video
media_type = "video"
elif getattr(msg, "video_note", None):
media_file = msg.video_note
media_type = "video"
elif getattr(msg, "animation", None):
media_file = msg.animation
media_type = "animation"
if not media_file or not self._app:
return [], []
try:
file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(
media_type,
getattr(media_file, "mime_type", None),
getattr(media_file, "file_name", None),
)
media_dir = get_media_dir("telegram")
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path))
path_str = str(file_path)
if media_type in ("voice", "audio"):
transcription = await self.transcribe_audio(file_path)
if transcription:
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
return [path_str], [f"[transcription: {transcription}]"]
return [path_str], [f"[{media_type}: {path_str}]"]
return [path_str], [f"[{media_type}: {path_str}]"]
except Exception as e:
logger.warning("Failed to download message media: {}", e)
if add_failure_content:
return [], [f"[{media_type}: download failed]"]
return [], []
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]: async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
"""Load bot identity once and reuse it for mention/reply checks.""" """Load bot identity once and reuse it for mention/reply checks."""
if self._bot_user_id is not None or self._bot_username is not None: if self._bot_user_id is not None or self._bot_username is not None:
@@ -545,7 +614,7 @@ class TelegramChannel(BaseChannel):
await self._handle_message( await self._handle_message(
sender_id=self._sender_id(user), sender_id=self._sender_id(user),
chat_id=str(message.chat_id), chat_id=str(message.chat_id),
content=message.text, content=message.text or "",
metadata=self._build_message_metadata(message, user), metadata=self._build_message_metadata(message, user),
session_key=self._derive_topic_session_key(message), session_key=self._derive_topic_session_key(message),
) )
@@ -577,54 +646,26 @@ class TelegramChannel(BaseChannel):
if message.caption: if message.caption:
content_parts.append(message.caption) content_parts.append(message.caption)
# Handle media files # Download current message media
media_file = None current_media_paths, current_media_parts = await self._download_message_media(
media_type = None message, add_failure_content=True
)
if message.photo: media_paths.extend(current_media_paths)
media_file = message.photo[-1] # Largest photo content_parts.extend(current_media_parts)
media_type = "image" if current_media_paths:
elif message.voice: logger.debug("Downloaded message media to {}", current_media_paths[0])
media_file = message.voice
media_type = "voice"
elif message.audio:
media_file = message.audio
media_type = "audio"
elif message.document:
media_file = message.document
media_type = "file"
# Download media if present
if media_file and self._app:
try:
file = await self._app.bot.get_file(media_file.file_id)
ext = self._get_extension(
media_type,
getattr(media_file, 'mime_type', None),
getattr(media_file, 'file_name', None),
)
media_dir = get_media_dir("telegram")
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
await file.download_to_drive(str(file_path))
media_paths.append(str(file_path))
if media_type in ("voice", "audio"):
transcription = await self.transcribe_audio(file_path)
if transcription:
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
content_parts.append(f"[transcription: {transcription}]")
else:
content_parts.append(f"[{media_type}: {file_path}]")
else:
content_parts.append(f"[{media_type}: {file_path}]")
logger.debug("Downloaded {} to {}", media_type, file_path)
except Exception as e:
logger.error("Failed to download media: {}", e)
content_parts.append(f"[{media_type}: download failed]")
# Reply context: text and/or media from the replied-to message
reply = getattr(message, "reply_to_message", None)
if reply is not None:
reply_ctx = self._extract_reply_context(message)
reply_media, reply_media_parts = await self._download_message_media(reply)
if reply_media:
media_paths = reply_media + media_paths
logger.debug("Attached replied-to media: {}", reply_media[0])
tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
if tag:
content_parts.insert(0, tag)
content = "\n".join(content_parts) if content_parts else "[empty message]" content = "\n".join(content_parts) if content_parts else "[empty message]"
logger.debug("Telegram message from {}: {}...", sender_id, content[:50]) logger.debug("Telegram message from {}: {}...", sender_id, content[:50])

View File

@@ -19,10 +19,12 @@ if sys.platform == "win32":
pass pass
import typer import typer
from prompt_toolkit import print_formatted_text
from prompt_toolkit import PromptSession from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import HTML from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.application import run_in_terminal
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.table import Table from rich.table import Table
@@ -111,8 +113,25 @@ def _init_prompt_session() -> None:
) )
def _make_console() -> Console:
return Console(file=sys.stdout)
def _render_interactive_ansi(render_fn) -> str:
"""Render Rich output to ANSI so prompt_toolkit can print it safely."""
ansi_console = Console(
force_terminal=True,
color_system=console.color_system or "standard",
width=console.width,
)
with ansi_console.capture() as capture:
render_fn(ansi_console)
return capture.get()
def _print_agent_response(response: str, render_markdown: bool) -> None: def _print_agent_response(response: str, render_markdown: bool) -> None:
"""Render assistant response with consistent terminal styling.""" """Render assistant response with consistent terminal styling."""
console = _make_console()
content = response or "" content = response or ""
body = Markdown(content) if render_markdown else Text(content) body = Markdown(content) if render_markdown else Text(content)
console.print() console.print()
@@ -121,6 +140,34 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
console.print() console.print()
async def _print_interactive_line(text: str) -> None:
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
def _write() -> None:
ansi = _render_interactive_ansi(
lambda c: c.print(f" [dim]↳ {text}[/dim]")
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
def _write() -> None:
content = response or ""
ansi = _render_interactive_ansi(
lambda c: (
c.print(),
c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
c.print(Markdown(content) if render_markdown else Text(content)),
c.print(),
)
)
print_formatted_text(ANSI(ansi), end="")
await run_in_terminal(_write)
def _is_exit_command(command: str) -> bool: def _is_exit_command(command: str) -> bool:
"""Return True when input should end interactive chat.""" """Return True when input should end interactive chat."""
return command.lower() in EXIT_COMMANDS return command.lower() in EXIT_COMMANDS
@@ -610,14 +657,15 @@ def agent(
elif ch and not is_tool_hint and not ch.send_progress: elif ch and not is_tool_hint and not ch.send_progress:
pass pass
else: else:
console.print(f" [dim]↳ {msg.content}[/dim]") await _print_interactive_line(msg.content)
elif not turn_done.is_set(): elif not turn_done.is_set():
if msg.content: if msg.content:
turn_response.append(msg.content) turn_response.append(msg.content)
turn_done.set() turn_done.set()
elif msg.content: elif msg.content:
console.print() await _print_interactive_response(msg.content, render_markdown=markdown)
_print_agent_response(msg.content, render_markdown=markdown)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -48,6 +48,7 @@ class FeishuConfig(Base):
react_emoji: str = ( react_emoji: str = (
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE) "THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
) )
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned, "open" responds to all
class DingTalkConfig(Base): class DingTalkConfig(Base):
@@ -275,15 +276,18 @@ class ProvidersConfig(Base):
deepseek: ProviderConfig = Field(default_factory=ProviderConfig) deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig) groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig) zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问 dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig) vllm: ProviderConfig = Field(default_factory=ProviderConfig)
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
gemini: ProviderConfig = Field(default_factory=ProviderConfig) gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig) moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig) minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth) openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -397,12 +401,21 @@ class Config(BaseSettings):
# Fallback: configured local providers can route models without # Fallback: configured local providers can route models without
# provider-specific keywords (for example plain "llama3.2" on Ollama). # provider-specific keywords (for example plain "llama3.2" on Ollama).
# Prefer providers whose detect_by_base_keyword matches the configured api_base
# (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
local_fallback: tuple[ProviderConfig, str] | None = None
for spec in PROVIDERS: for spec in PROVIDERS:
if not spec.is_local: if not spec.is_local:
continue continue
p = getattr(self.providers, spec.name, None) p = getattr(self.providers, spec.name, None)
if p and p.api_base: if not (p and p.api_base):
continue
if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
return p, spec.name return p, spec.name
if local_fallback is None:
local_fallback = (p, spec.name)
if local_fallback:
return local_fallback
# Fallback: gateways first, then others (follows registry order) # Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection # OAuth providers are NOT valid fallbacks — they require explicit model selection

View File

@@ -88,6 +88,7 @@ class AzureOpenAIProvider(LLMProvider):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = { payload: dict[str, Any] = {
@@ -106,7 +107,7 @@ class AzureOpenAIProvider(LLMProvider):
if tools: if tools:
payload["tools"] = tools payload["tools"] = tools
payload["tool_choice"] = "auto" payload["tool_choice"] = tool_choice or "auto"
return payload return payload
@@ -118,6 +119,7 @@ class AzureOpenAIProvider(LLMProvider):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request to Azure OpenAI. Send a chat completion request to Azure OpenAI.
@@ -137,7 +139,8 @@ class AzureOpenAIProvider(LLMProvider):
url = self._build_chat_url(deployment_name) url = self._build_chat_url(deployment_name)
headers = self._build_headers() headers = self._build_headers()
payload = self._prepare_request_payload( payload = self._prepare_request_payload(
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
tool_choice=tool_choice,
) )
try: try:

View File

@@ -166,6 +166,7 @@ class LLMProvider(ABC):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request. Send a chat completion request.
@@ -176,6 +177,7 @@ class LLMProvider(ABC):
model: Model identifier (provider-specific). model: Model identifier (provider-specific).
max_tokens: Maximum tokens in response. max_tokens: Maximum tokens in response.
temperature: Sampling temperature. temperature: Sampling temperature.
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns: Returns:
LLMResponse with content and/or tool calls. LLMResponse with content and/or tool calls.
@@ -195,6 +197,7 @@ class LLMProvider(ABC):
max_tokens: object = _SENTINEL, max_tokens: object = _SENTINEL,
temperature: object = _SENTINEL, temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL, reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
"""Call chat() with retry on transient provider failures. """Call chat() with retry on transient provider failures.
@@ -218,6 +221,7 @@ class LLMProvider(ABC):
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
reasoning_effort=reasoning_effort, reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
@@ -250,6 +254,7 @@ class LLMProvider(ABC):
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
reasoning_effort=reasoning_effort, reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise

View File

@@ -25,7 +25,8 @@ class CustomProvider(LLMProvider):
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
reasoning_effort: str | None = None) -> LLMResponse: reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": model or self.default_model, "model": model or self.default_model,
"messages": self._sanitize_empty_content(messages), "messages": self._sanitize_empty_content(messages),
@@ -35,7 +36,7 @@ class CustomProvider(LLMProvider):
if reasoning_effort: if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort kwargs["reasoning_effort"] = reasoning_effort
if tools: if tools:
kwargs.update(tools=tools, tool_choice="auto") kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
try: try:
return self._parse(await self._client.chat.completions.create(**kwargs)) return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e: except Exception as e:

View File

@@ -214,6 +214,7 @@ class LiteLLMProvider(LLMProvider):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
""" """
Send a chat completion request via LiteLLM. Send a chat completion request via LiteLLM.
@@ -267,7 +268,7 @@ class LiteLLMProvider(LLMProvider):
if tools: if tools:
kwargs["tools"] = tools kwargs["tools"] = tools
kwargs["tool_choice"] = "auto" kwargs["tool_choice"] = tool_choice or "auto"
try: try:
response = await acompletion(**kwargs) response = await acompletion(**kwargs)

View File

@@ -32,6 +32,7 @@ class OpenAICodexProvider(LLMProvider):
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7, temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse: ) -> LLMResponse:
model = model or self.default_model model = model or self.default_model
system_prompt, input_items = _convert_messages(messages) system_prompt, input_items = _convert_messages(messages)
@@ -48,7 +49,7 @@ class OpenAICodexProvider(LLMProvider):
"text": {"verbosity": "medium"}, "text": {"verbosity": "medium"},
"include": ["reasoning.encrypted_content"], "include": ["reasoning.encrypted_content"],
"prompt_cache_key": _prompt_cache_key(messages), "prompt_cache_key": _prompt_cache_key(messages),
"tool_choice": "auto", "tool_choice": tool_choice or "auto",
"parallel_tool_calls": True, "parallel_tool_calls": True,
} }

View File

@@ -145,7 +145,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine (火山引擎): OpenAI-compatible gateway
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec( ProviderSpec(
name="volcengine", name="volcengine",
keywords=("volcengine", "volces", "ark"), keywords=("volcengine", "volces", "ark"),
@@ -162,6 +163,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False, strip_model_prefix=False,
model_overrides=(), model_overrides=(),
), ),
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
ProviderSpec(
name="volcengine_coding_plan",
keywords=("volcengine-plan",),
env_key="OPENAI_API_KEY",
display_name="VolcEngine Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus: VolcEngine international, pay-per-use models
ProviderSpec(
name="byteplus",
keywords=("byteplus",),
env_key="OPENAI_API_KEY",
display_name="BytePlus",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="bytepluses",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
strip_model_prefix=True,
model_overrides=(),
),
# BytePlus Coding Plan: same key as byteplus
ProviderSpec(
name="byteplus_coding_plan",
keywords=("byteplus-plan",),
env_key="OPENAI_API_KEY",
display_name="BytePlus Coding Plan",
litellm_prefix="volcengine",
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
strip_model_prefix=True,
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) =============== # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec( ProviderSpec(

View File

@@ -49,7 +49,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
wecom = [ wecom = [
"wecom-aibot-sdk-python @ git+https://github.com/chengyongru/wecom_aibot_sdk.git@v0.1.2", "wecom-aibot-sdk-python>=0.1.2",
] ]
matrix = [ matrix = [
"matrix-nio[e2e]>=0.25.2", "matrix-nio[e2e]>=0.25.2",

View File

@@ -143,6 +143,35 @@ def test_config_auto_detects_ollama_from_local_api_base():
assert config.get_api_base() == "http://localhost:11434" assert config.get_api_base() == "http://localhost:11434"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex") spec = find_by_model("github-copilot/gpt-5.3-codex")

View File

@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop: def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop) loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500 loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
return loop return loop
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
skip=0, skip=0,
) )
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}] assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
def test_save_turn_keeps_tool_results_under_16k() -> None:
loop = _mk_loop()
session = Session(key="test:tool-result")
content = "x" * 12_000
loop._save_turn(
session,
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
skip=0,
)
assert session.messages[0]["content"] == content

View File

@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None: async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient() channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group" channel._chat_type_cache["group123"] = "group"
@@ -60,7 +60,37 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
assert len(channel._client.api.group_calls) == 1 assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0] call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123" assert call == {
assert call["msg_id"] == "msg1" "group_openid": "group123",
assert call["msg_seq"] == 2 "msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls

View File

@@ -0,0 +1,76 @@
"""Tests for /restart slash command."""
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from nanobot.bus.events import InboundMessage
def _make_loop():
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager"):
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
return loop, bus
class TestRestartCommand:
@pytest.mark.asyncio
async def test_restart_sends_message_and_calls_execv(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
with patch("nanobot.agent.loop.os.execv") as mock_execv:
await loop._handle_restart(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "Restarting" in out.content
await asyncio.sleep(1.5)
mock_execv.assert_called_once()
@pytest.mark.asyncio
async def test_restart_intercepted_in_run_loop(self):
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
with patch.object(loop, "_handle_restart") as mock_handle:
mock_handle.return_value = None
await bus.publish_inbound(msg)
loop._running = True
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
loop._running = False
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
mock_handle.assert_called_once()
@pytest.mark.asyncio
async def test_help_includes_restart(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
response = await loop._process_message(msg)
assert response is not None
assert "/restart" in response.content

View File

@@ -1,10 +1,13 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest import pytest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
from nanobot.config.schema import TelegramConfig from nanobot.config.schema import TelegramConfig
@@ -42,6 +45,12 @@ class _FakeBot:
async def send_chat_action(self, **kwargs) -> None: async def send_chat_action(self, **kwargs) -> None:
pass pass
async def get_file(self, file_id: str):
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
async def _fake_download(path) -> None:
pass
return SimpleNamespace(download_to_drive=_fake_download)
class _FakeApp: class _FakeApp:
def __init__(self, on_start_polling) -> None: def __init__(self, on_start_polling) -> None:
@@ -336,3 +345,255 @@ async def test_group_policy_open_accepts_plain_group_message() -> None:
assert len(handled) == 1 assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0 assert channel._app.bot.get_me_calls == 0
def test_extract_reply_context_no_reply() -> None:
"""When there is no reply_to_message, _extract_reply_context returns None."""
message = SimpleNamespace(reply_to_message=None)
assert TelegramChannel._extract_reply_context(message) is None
def test_extract_reply_context_with_text() -> None:
"""When reply has text, return prefixed string."""
reply = SimpleNamespace(text="Hello world", caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
def test_extract_reply_context_with_caption_only() -> None:
"""When reply has only caption (no text), caption is used."""
reply = SimpleNamespace(text=None, caption="Photo caption")
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
def test_extract_reply_context_truncation() -> None:
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
reply = SimpleNamespace(text=long_text, caption=None)
message = SimpleNamespace(reply_to_message=reply)
result = TelegramChannel._extract_reply_context(message)
assert result is not None
assert result.startswith("[Reply to: ")
assert result.endswith("...]")
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
def test_extract_reply_context_no_text_returns_none() -> None:
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
reply = SimpleNamespace(text=None, caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) is None
@pytest.mark.asyncio
async def test_on_message_includes_reply_context() -> None:
"""When user replies to a message, content passed to bus starts with reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="translate this", reply_to_message=reply)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: Hello]")
assert "translate this" in handled[0]["content"]
@pytest.mark.asyncio
async def test_download_message_media_returns_path_when_download_succeeds(
monkeypatch, tmp_path
) -> None:
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
msg = SimpleNamespace(
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert len(paths) == 1
assert len(parts) == 1
assert "fid123" in paths[0]
assert "[image:" in parts[0]
@pytest.mark.asyncio
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what is the image?",
reply_to_message=reply_with_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: [image:")
assert "what is the image?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "reply_photo_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
"""When reply has media but download fails, no media attached and no reply tag."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = None
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
await channel._on_message(update, None)
assert len(handled) == 1
assert "what is this?" in handled[0]["content"]
assert handled[0]["media"] == []
@pytest.mark.asyncio
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
"""When replying to a message with caption + photo, both text context and media are included."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_caption_and_photo = SimpleNamespace(
text=None,
caption="A cute cat",
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what breed is this?",
reply_to_message=reply_with_caption_and_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert "[Reply to: A cute cat]" in handled[0]["content"]
assert "what breed is this?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "cat_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_forward_command_does_not_inject_reply_context() -> None:
"""Slash commands forwarded via _forward_command must not include reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="/new", reply_to_message=reply)
await channel._forward_command(update, None)
assert len(handled) == 1
assert handled[0]["content"] == "/new"