tags)
+ if cls._MD_LINK_RE.search(stripped):
+ return "post"
+
+ # Short plain text → text format
+ if len(stripped) <= cls._TEXT_MAX_LEN:
+ return "text"
+
+ # Medium plain text without any formatting → post format
+ return "post"
+
+ @classmethod
+ def _markdown_to_post(cls, content: str) -> str:
+ """Convert markdown content to Feishu post message JSON.
+
+ Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
+ Each line becomes a paragraph (row) in the post body.
+ """
+ lines = content.strip().split("\n")
+ paragraphs: list[list[dict]] = []
+
+ for line in lines:
+ elements: list[dict] = []
+ last_end = 0
+
+ for m in cls._MD_LINK_RE.finditer(line):
+ # Text before this link
+ before = line[last_end:m.start()]
+ if before:
+ elements.append({"tag": "text", "text": before})
+ elements.append({
+ "tag": "a",
+ "text": m.group(1),
+ "href": m.group(2),
+ })
+ last_end = m.end()
+
+ # Remaining text after last link
+ remaining = line[last_end:]
+ if remaining:
+ elements.append({"tag": "text", "text": remaining})
+
+ # Empty line → empty paragraph for spacing
+ if not elements:
+ elements.append({"tag": "text", "text": ""})
+
+ paragraphs.append(elements)
+
+ post_body = {
+ "zh_cn": {
+ "content": paragraphs,
+ }
+ }
+ return json.dumps(post_body, ensure_ascii=False)
+
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
_AUDIO_EXTS = {".opus"}
+ _VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
@@ -451,6 +657,7 @@ class FeishuChannel(BaseChannel):
def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key."""
+ from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
try:
with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \
@@ -474,6 +681,7 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key."""
+ from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path)
@@ -501,6 +709,7 @@ class FeishuChannel(BaseChannel):
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key."""
+ from lark_oapi.api.im.v1 import GetMessageResourceRequest
try:
request = GetMessageResourceRequest.builder() \
.message_id(message_id) \
@@ -525,6 +734,13 @@ class FeishuChannel(BaseChannel):
self, message_id: str, file_key: str, resource_type: str = "file"
) -> tuple[bytes | None, str | None]:
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
+ from lark_oapi.api.im.v1 import GetMessageResourceRequest
+
+ # Feishu API only accepts 'image' or 'file' as type parameter
+ # Convert 'audio' to 'file' for API compatibility
+ if resource_type == "audio":
+ resource_type = "file"
+
try:
request = (
GetMessageResourceRequest.builder()
@@ -559,8 +775,7 @@ class FeishuChannel(BaseChannel):
(file_path, content_text) - file_path is None if download failed
"""
loop = asyncio.get_running_loop()
- media_dir = Path.home() / ".nanobot" / "media"
- media_dir.mkdir(parents=True, exist_ok=True)
+ media_dir = get_media_dir("feishu")
data, filename = None, None
@@ -580,8 +795,9 @@ class FeishuChannel(BaseChannel):
None, self._download_file_sync, message_id, file_key, msg_type
)
if not filename:
- ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
- filename = f"{file_key[:16]}{ext}"
+ filename = file_key[:16]
+ if msg_type == "audio" and not filename.endswith(".opus"):
+ filename = f"{filename}.opus"
if data and filename:
file_path = media_dir / filename
@@ -591,8 +807,80 @@ class FeishuChannel(BaseChannel):
return None, f"[{msg_type}: download failed]"
+ _REPLY_CONTEXT_MAX_LEN = 200
+
+ def _get_message_content_sync(self, message_id: str) -> str | None:
+ """Fetch the text content of a Feishu message by ID (synchronous).
+
+ Returns a "[Reply to: ...]" context string, or None on failure.
+ """
+ from lark_oapi.api.im.v1 import GetMessageRequest
+ try:
+ request = GetMessageRequest.builder().message_id(message_id).build()
+ response = self._client.im.v1.message.get(request)
+ if not response.success():
+ logger.debug(
+ "Feishu: could not fetch parent message {}: code={}, msg={}",
+ message_id, response.code, response.msg,
+ )
+ return None
+ items = getattr(response.data, "items", None)
+ if not items:
+ return None
+ msg_obj = items[0]
+ raw_content = getattr(msg_obj, "body", None)
+ raw_content = getattr(raw_content, "content", None) if raw_content else None
+ if not raw_content:
+ return None
+ try:
+ content_json = json.loads(raw_content)
+ except (json.JSONDecodeError, TypeError):
+ return None
+ msg_type = getattr(msg_obj, "msg_type", "")
+ if msg_type == "text":
+ text = content_json.get("text", "").strip()
+ elif msg_type == "post":
+ text, _ = _extract_post_content(content_json)
+ text = text.strip()
+ else:
+ text = ""
+ if not text:
+ return None
+ if len(text) > self._REPLY_CONTEXT_MAX_LEN:
+ text = text[: self._REPLY_CONTEXT_MAX_LEN] + "..."
+ return f"[Reply to: {text}]"
+ except Exception as e:
+ logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
+ return None
+
+ def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
+ """Reply to an existing Feishu message using the Reply API (synchronous)."""
+ from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
+ try:
+ request = ReplyMessageRequest.builder() \
+ .message_id(parent_message_id) \
+ .request_body(
+ ReplyMessageRequestBody.builder()
+ .msg_type(msg_type)
+ .content(content)
+ .build()
+ ).build()
+ response = self._client.im.v1.message.reply(request)
+ if not response.success():
+ logger.error(
+ "Failed to reply to Feishu message {}: code={}, msg={}, log_id={}",
+ parent_message_id, response.code, response.msg, response.get_log_id()
+ )
+ return False
+ logger.debug("Feishu reply sent to message {}", parent_message_id)
+ return True
+ except Exception as e:
+ logger.error("Error replying to Feishu message {}: {}", parent_message_id, e)
+ return False
+
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
"""Send a single message (text/image/file/interactive) synchronously."""
+ from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
try:
request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \
@@ -626,6 +914,38 @@ class FeishuChannel(BaseChannel):
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
loop = asyncio.get_running_loop()
+ # Handle tool hint messages as code blocks in interactive cards.
+ # These are progress-only messages and should bypass normal reply routing.
+ if msg.metadata.get("_tool_hint"):
+ if msg.content and msg.content.strip():
+ await self._send_tool_hint_card(
+ receive_id_type, msg.chat_id, msg.content.strip()
+ )
+ return
+
+ # Determine whether the first message should quote the user's message.
+ # Only the very first send (media or text) in this call uses reply; subsequent
+ # chunks/media fall back to plain create to avoid redundant quote bubbles.
+ reply_message_id: str | None = None
+ if (
+ self.config.reply_to_message
+ and not msg.metadata.get("_progress", False)
+ ):
+ reply_message_id = msg.metadata.get("message_id") or None
+
+ first_send = True # tracks whether the reply has already been used
+
+ def _do_send(m_type: str, content: str) -> None:
+ """Send via reply (first message) or create (subsequent)."""
+ nonlocal first_send
+ if reply_message_id and first_send:
+ first_send = False
+ ok = self._reply_message_sync(reply_message_id, m_type, content)
+ if ok:
+ return
+ # Fall back to regular send if reply fails
+ self._send_message_sync(receive_id_type, msg.chat_id, m_type, content)
+
for file_path in msg.media:
if not os.path.isfile(file_path):
logger.warning("Media file not found: {}", file_path)
@@ -635,37 +955,58 @@ class FeishuChannel(BaseChannel):
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
if key:
await loop.run_in_executor(
- None, self._send_message_sync,
- receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
+ None, _do_send,
+ "image", json.dumps({"image_key": key}, ensure_ascii=False),
)
else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
if key:
- media_type = "audio" if ext in self._AUDIO_EXTS else "file"
+ # Use msg_type "media" for audio/video so users can play inline;
+ # "file" for everything else (documents, archives, etc.)
+ if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
+ media_type = "media"
+ else:
+ media_type = "file"
await loop.run_in_executor(
- None, self._send_message_sync,
- receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
+ None, _do_send,
+ media_type, json.dumps({"file_key": key}, ensure_ascii=False),
)
if msg.content and msg.content.strip():
- card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)}
- await loop.run_in_executor(
- None, self._send_message_sync,
- receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
- )
+ fmt = self._detect_msg_format(msg.content)
+
+ if fmt == "text":
+ # Short plain text – send as simple text message
+ text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
+ await loop.run_in_executor(None, _do_send, "text", text_body)
+
+ elif fmt == "post":
+ # Medium content with links – send as rich-text post
+ post_body = self._markdown_to_post(msg.content)
+ await loop.run_in_executor(None, _do_send, "post", post_body)
+
+ else:
+ # Complex / long content – send as interactive card
+ elements = self._build_card_elements(msg.content)
+ for chunk in self._split_elements_by_table_limit(elements):
+ card = {"config": {"wide_screen_mode": True}, "elements": chunk}
+ await loop.run_in_executor(
+ None, _do_send,
+ "interactive", json.dumps(card, ensure_ascii=False),
+ )
except Exception as e:
logger.error("Error sending Feishu message: {}", e)
-
- def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
+
+ def _on_message_sync(self, data: Any) -> None:
"""
Sync handler for incoming messages (called from WebSocket thread).
Schedules async handling in the main event loop.
"""
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
-
- async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
+
+ async def _on_message(self, data: Any) -> None:
"""Handle incoming message from Feishu."""
try:
event = data.event
@@ -691,8 +1032,12 @@ class FeishuChannel(BaseChannel):
chat_type = message.chat_type
msg_type = message.message_type
+ if chat_type == "group" and not self._is_group_message_for_bot(message):
+ logger.debug("Feishu: skipping group message (not mentioned)")
+ return
+
# Add reaction
- await self._add_reaction(message_id, "THUMBSUP")
+ await self._add_reaction(message_id, self.config.react_emoji)
# Parse content
content_parts = []
@@ -725,6 +1070,12 @@ class FeishuChannel(BaseChannel):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
if file_path:
media_paths.append(file_path)
+
+ if msg_type == "audio" and file_path:
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ content_text = f"[transcription: {transcription}]"
+
content_parts.append(content_text)
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
@@ -736,6 +1087,19 @@ class FeishuChannel(BaseChannel):
else:
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
+ # Extract reply context (parent/root message IDs)
+ parent_id = getattr(message, "parent_id", None) or None
+ root_id = getattr(message, "root_id", None) or None
+
+ # Prepend quoted message text when the user replied to another message
+ if parent_id and self._client:
+ loop = asyncio.get_running_loop()
+ reply_ctx = await loop.run_in_executor(
+ None, self._get_message_content_sync, parent_id
+ )
+ if reply_ctx:
+ content_parts.insert(0, reply_ctx)
+
content = "\n".join(content_parts) if content_parts else ""
if not content and not media_paths:
@@ -752,8 +1116,98 @@ class FeishuChannel(BaseChannel):
"message_id": message_id,
"chat_type": chat_type,
"msg_type": msg_type,
+ "parent_id": parent_id,
+ "root_id": root_id,
}
)
except Exception as e:
logger.error("Error processing Feishu message: {}", e)
+
+ def _on_reaction_created(self, data: Any) -> None:
+ """Ignore reaction events so they do not generate SDK noise."""
+ pass
+
+ def _on_message_read(self, data: Any) -> None:
+ """Ignore read events so they do not generate SDK noise."""
+ pass
+
+ def _on_bot_p2p_chat_entered(self, data: Any) -> None:
+ """Ignore p2p-enter events when a user opens a bot chat."""
+ logger.debug("Bot entered p2p chat (user opened chat window)")
+ pass
+
+ @staticmethod
+ def _format_tool_hint_lines(tool_hint: str) -> str:
+ """Split tool hints across lines on top-level call separators only."""
+ parts: list[str] = []
+ buf: list[str] = []
+ depth = 0
+ in_string = False
+ quote_char = ""
+ escaped = False
+
+ for i, ch in enumerate(tool_hint):
+ buf.append(ch)
+
+ if in_string:
+ if escaped:
+ escaped = False
+ elif ch == "\\":
+ escaped = True
+ elif ch == quote_char:
+ in_string = False
+ continue
+
+ if ch in {'"', "'"}:
+ in_string = True
+ quote_char = ch
+ continue
+
+ if ch == "(":
+ depth += 1
+ continue
+
+ if ch == ")" and depth > 0:
+ depth -= 1
+ continue
+
+ if ch == "," and depth == 0:
+ next_char = tool_hint[i + 1] if i + 1 < len(tool_hint) else ""
+ if next_char == " ":
+ parts.append("".join(buf).rstrip())
+ buf = []
+
+ if buf:
+ parts.append("".join(buf).strip())
+
+ return "\n".join(part for part in parts if part)
+
+ async def _send_tool_hint_card(self, receive_id_type: str, receive_id: str, tool_hint: str) -> None:
+ """Send tool hint as an interactive card with formatted code block.
+
+ Args:
+ receive_id_type: "chat_id" or "open_id"
+ receive_id: The target chat or user ID
+ tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")')
+ """
+ loop = asyncio.get_running_loop()
+
+ # Put each top-level tool call on its own line without altering commas inside arguments.
+ formatted_code = self._format_tool_hint_lines(tool_hint)
+
+ card = {
+ "config": {"wide_screen_mode": True},
+ "elements": [
+ {
+ "tag": "markdown",
+ "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"
+ }
+ ]
+ }
+
+ await loop.run_in_executor(
+ None, self._send_message_sync,
+ receive_id_type, receive_id, "interactive",
+ json.dumps(card, ensure_ascii=False),
+ )
diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py
index 77b7294..3820c10 100644
--- a/nanobot/channels/manager.py
+++ b/nanobot/channels/manager.py
@@ -7,7 +7,6 @@ from typing import Any
from loguru import logger
-from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config
@@ -16,127 +15,56 @@ from nanobot.config.schema import Config
class ChannelManager:
"""
Manages chat channels and coordinates message routing.
-
+
Responsibilities:
- Initialize enabled channels (Telegram, WhatsApp, etc.)
- Start/stop channels
- Route outbound messages
"""
-
+
def __init__(self, config: Config, bus: MessageBus):
self.config = config
self.bus = bus
self.channels: dict[str, BaseChannel] = {}
self._dispatch_task: asyncio.Task | None = None
-
+
self._init_channels()
-
+
def _init_channels(self) -> None:
- """Initialize channels based on config."""
-
- # Telegram channel
- if self.config.channels.telegram.enabled:
- try:
- from nanobot.channels.telegram import TelegramChannel
- self.channels["telegram"] = TelegramChannel(
- self.config.channels.telegram,
- self.bus,
- groq_api_key=self.config.providers.groq.api_key,
- )
- logger.info("Telegram channel enabled")
- except ImportError as e:
- logger.warning("Telegram channel not available: {}", e)
-
- # WhatsApp channel
- if self.config.channels.whatsapp.enabled:
- try:
- from nanobot.channels.whatsapp import WhatsAppChannel
- self.channels["whatsapp"] = WhatsAppChannel(
- self.config.channels.whatsapp, self.bus
- )
- logger.info("WhatsApp channel enabled")
- except ImportError as e:
- logger.warning("WhatsApp channel not available: {}", e)
+ """Initialize channels discovered via pkgutil scan + entry_points plugins."""
+ from nanobot.channels.registry import discover_all
- # Discord channel
- if self.config.channels.discord.enabled:
- try:
- from nanobot.channels.discord import DiscordChannel
- self.channels["discord"] = DiscordChannel(
- self.config.channels.discord, self.bus
- )
- logger.info("Discord channel enabled")
- except ImportError as e:
- logger.warning("Discord channel not available: {}", e)
-
- # Feishu channel
- if self.config.channels.feishu.enabled:
- try:
- from nanobot.channels.feishu import FeishuChannel
- self.channels["feishu"] = FeishuChannel(
- self.config.channels.feishu, self.bus
- )
- logger.info("Feishu channel enabled")
- except ImportError as e:
- logger.warning("Feishu channel not available: {}", e)
+ groq_key = self.config.providers.groq.api_key
- # Mochat channel
- if self.config.channels.mochat.enabled:
+ for name, cls in discover_all().items():
+ section = getattr(self.config.channels, name, None)
+ if section is None:
+ continue
+ enabled = (
+ section.get("enabled", False)
+ if isinstance(section, dict)
+ else getattr(section, "enabled", False)
+ )
+ if not enabled:
+ continue
try:
- from nanobot.channels.mochat import MochatChannel
+ channel = cls(section, self.bus)
+ channel.transcription_api_key = groq_key
+ self.channels[name] = channel
+ logger.info("{} channel enabled", cls.display_name)
+ except Exception as e:
+ logger.warning("{} channel not available: {}", name, e)
- self.channels["mochat"] = MochatChannel(
- self.config.channels.mochat, self.bus
- )
- logger.info("Mochat channel enabled")
- except ImportError as e:
- logger.warning("Mochat channel not available: {}", e)
+ self._validate_allow_from()
- # DingTalk channel
- if self.config.channels.dingtalk.enabled:
- try:
- from nanobot.channels.dingtalk import DingTalkChannel
- self.channels["dingtalk"] = DingTalkChannel(
- self.config.channels.dingtalk, self.bus
+ def _validate_allow_from(self) -> None:
+ for name, ch in self.channels.items():
+ if getattr(ch.config, "allow_from", None) == []:
+ raise SystemExit(
+ f'Error: "{name}" has empty allowFrom (denies all). '
+ f'Set ["*"] to allow everyone, or add specific user IDs.'
)
- logger.info("DingTalk channel enabled")
- except ImportError as e:
- logger.warning("DingTalk channel not available: {}", e)
- # Email channel
- if self.config.channels.email.enabled:
- try:
- from nanobot.channels.email import EmailChannel
- self.channels["email"] = EmailChannel(
- self.config.channels.email, self.bus
- )
- logger.info("Email channel enabled")
- except ImportError as e:
- logger.warning("Email channel not available: {}", e)
-
- # Slack channel
- if self.config.channels.slack.enabled:
- try:
- from nanobot.channels.slack import SlackChannel
- self.channels["slack"] = SlackChannel(
- self.config.channels.slack, self.bus
- )
- logger.info("Slack channel enabled")
- except ImportError as e:
- logger.warning("Slack channel not available: {}", e)
-
- # QQ channel
- if self.config.channels.qq.enabled:
- try:
- from nanobot.channels.qq import QQChannel
- self.channels["qq"] = QQChannel(
- self.config.channels.qq,
- self.bus,
- )
- logger.info("QQ channel enabled")
- except ImportError as e:
- logger.warning("QQ channel not available: {}", e)
-
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
"""Start a channel and log any exceptions."""
try:
@@ -149,23 +77,23 @@ class ChannelManager:
if not self.channels:
logger.warning("No channels enabled")
return
-
+
# Start outbound dispatcher
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
-
+
# Start channels
tasks = []
for name, channel in self.channels.items():
logger.info("Starting {} channel...", name)
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
-
+
# Wait for all to complete (they should run forever)
await asyncio.gather(*tasks, return_exceptions=True)
-
+
async def stop_all(self) -> None:
"""Stop all channels and the dispatcher."""
logger.info("Stopping all channels...")
-
+
# Stop dispatcher
if self._dispatch_task:
self._dispatch_task.cancel()
@@ -173,7 +101,7 @@ class ChannelManager:
await self._dispatch_task
except asyncio.CancelledError:
pass
-
+
# Stop all channels
for name, channel in self.channels.items():
try:
@@ -181,24 +109,24 @@ class ChannelManager:
logger.info("Stopped {} channel", name)
except Exception as e:
logger.error("Error stopping {}: {}", name, e)
-
+
async def _dispatch_outbound(self) -> None:
"""Dispatch outbound messages to the appropriate channel."""
logger.info("Outbound dispatcher started")
-
+
while True:
try:
msg = await asyncio.wait_for(
self.bus.consume_outbound(),
timeout=1.0
)
-
+
if msg.metadata.get("_progress"):
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
continue
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
continue
-
+
channel = self.channels.get(msg.channel)
if channel:
try:
@@ -207,16 +135,16 @@ class ChannelManager:
logger.error("Error sending to {}: {}", msg.channel, e)
else:
logger.warning("Unknown channel: {}", msg.channel)
-
+
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
-
+
def get_channel(self, name: str) -> BaseChannel | None:
"""Get a channel by name."""
return self.channels.get(name)
-
+
def get_status(self) -> dict[str, Any]:
"""Get status of all channels."""
return {
@@ -226,7 +154,7 @@ class ChannelManager:
}
for name, channel in self.channels.items()
}
-
+
@property
def enabled_channels(self) -> list[str]:
"""Get list of enabled channel names."""
diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py
new file mode 100644
index 0000000..9892673
--- /dev/null
+++ b/nanobot/channels/matrix.py
@@ -0,0 +1,739 @@
+"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
+
+import asyncio
+import logging
+import mimetypes
+from pathlib import Path
+from typing import Any, Literal, TypeAlias
+
+from loguru import logger
+from pydantic import Field
+
+try:
+ import nh3
+ from mistune import create_markdown
+ from nio import (
+ AsyncClient,
+ AsyncClientConfig,
+ ContentRepositoryConfigError,
+ DownloadError,
+ InviteEvent,
+ JoinError,
+ MatrixRoom,
+ MemoryDownloadResponse,
+ RoomEncryptedMedia,
+ RoomMessage,
+ RoomMessageMedia,
+ RoomMessageText,
+ RoomSendError,
+ RoomTypingError,
+ SyncError,
+ UploadError,
+ )
+ from nio.crypto.attachments import decrypt_attachment
+ from nio.exceptions import EncryptionError
+except ImportError as e:
+ raise ImportError(
+ "Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
+ ) from e
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_data_dir, get_media_dir
+from nanobot.config.schema import Base
+from nanobot.utils.helpers import safe_filename
+
+TYPING_NOTICE_TIMEOUT_MS = 30_000
+# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
+TYPING_KEEPALIVE_INTERVAL_MS = 20_000
+MATRIX_HTML_FORMAT = "org.matrix.custom.html"
+_ATTACH_MARKER = "[attachment: {}]"
+_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
+_ATTACH_FAILED = "[attachment: {} - download failed]"
+_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
+_DEFAULT_ATTACH_NAME = "attachment"
+_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
+
+MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
+MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
+
+MATRIX_MARKDOWN = create_markdown(
+ escape=True,
+ plugins=["table", "strikethrough", "url", "superscript", "subscript"],
+)
+
+MATRIX_ALLOWED_HTML_TAGS = {
+ "p", "a", "strong", "em", "del", "code", "pre", "blockquote",
+ "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
+ "hr", "br", "table", "thead", "tbody", "tr", "th", "td",
+ "caption", "sup", "sub", "img",
+}
+MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
+ "a": {"href"}, "code": {"class"}, "ol": {"start"},
+ "img": {"src", "alt", "title", "width", "height"},
+}
+MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
+
+
+def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
+ """Filter attribute values to a safe Matrix-compatible subset."""
+ if tag == "a" and attr == "href":
+ return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
+ if tag == "img" and attr == "src":
+ return value if value.lower().startswith("mxc://") else None
+ if tag == "code" and attr == "class":
+ classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
+ return " ".join(classes) if classes else None
+ return value
+
+
+MATRIX_HTML_CLEANER = nh3.Cleaner(
+ tags=MATRIX_ALLOWED_HTML_TAGS,
+ attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
+ attribute_filter=_filter_matrix_html_attribute,
+ url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
+ strip_comments=True,
+ link_rel="noopener noreferrer",
+)
+
+
+def _render_markdown_html(text: str) -> str | None:
+ """Render markdown to sanitized HTML; returns None for plain text."""
+ try:
+ formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
+ except Exception:
+ return None
+ if not formatted:
+ return None
+ # Skip formatted_body for plain text
to keep payload minimal.
+ if formatted.startswith("") and formatted.endswith("
"):
+ inner = formatted[3:-4]
+ if "<" not in inner and ">" not in inner:
+ return None
+ return formatted
+
+
+def _build_matrix_text_content(text: str) -> dict[str, object]:
+ """Build Matrix m.text payload with optional HTML formatted_body."""
+ content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
+ if html := _render_markdown_html(text):
+ content["format"] = MATRIX_HTML_FORMAT
+ content["formatted_body"] = html
+ return content
+
+
+class _NioLoguruHandler(logging.Handler):
+ """Route matrix-nio stdlib logs into Loguru."""
+
+ def emit(self, record: logging.LogRecord) -> None:
+ try:
+ level = logger.level(record.levelname).name
+ except ValueError:
+ level = record.levelno
+ frame, depth = logging.currentframe(), 2
+ while frame and frame.f_code.co_filename == logging.__file__:
+ frame, depth = frame.f_back, depth + 1
+ logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
+
+
+def _configure_nio_logging_bridge() -> None:
+ """Bridge matrix-nio logs to Loguru (idempotent)."""
+ nio_logger = logging.getLogger("nio")
+ if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
+ nio_logger.handlers = [_NioLoguruHandler()]
+ nio_logger.propagate = False
+
+
+class MatrixConfig(Base):
+ """Matrix (Element) channel configuration."""
+
+ enabled: bool = False
+ homeserver: str = "https://matrix.org"
+ access_token: str = ""
+ user_id: str = ""
+ device_id: str = ""
+ e2ee_enabled: bool = True
+ sync_stop_grace_seconds: int = 2
+ max_media_bytes: int = 20 * 1024 * 1024
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: Literal["open", "mention", "allowlist"] = "open"
+ group_allow_from: list[str] = Field(default_factory=list)
+ allow_room_mentions: bool = False
+
+
+class MatrixChannel(BaseChannel):
+ """Matrix (Element) channel using long-polling sync."""
+
+ name = "matrix"
+ display_name = "Matrix"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return MatrixConfig().model_dump(by_alias=True)
+
+ def __init__(
+ self,
+ config: Any,
+ bus: MessageBus,
+ *,
+ restrict_to_workspace: bool = False,
+ workspace: str | Path | None = None,
+ ):
+ if isinstance(config, dict):
+ config = MatrixConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.client: AsyncClient | None = None
+ self._sync_task: asyncio.Task | None = None
+ self._typing_tasks: dict[str, asyncio.Task] = {}
+ self._restrict_to_workspace = bool(restrict_to_workspace)
+ self._workspace = (
+ Path(workspace).expanduser().resolve(strict=False) if workspace is not None else None
+ )
+ self._server_upload_limit_bytes: int | None = None
+ self._server_upload_limit_checked = False
+
+ async def start(self) -> None:
+ """Start Matrix client and begin sync loop."""
+ self._running = True
+ _configure_nio_logging_bridge()
+
+ store_path = get_data_dir() / "matrix-store"
+ store_path.mkdir(parents=True, exist_ok=True)
+
+ self.client = AsyncClient(
+ homeserver=self.config.homeserver, user=self.config.user_id,
+ store_path=store_path,
+ config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
+ )
+ self.client.user_id = self.config.user_id
+ self.client.access_token = self.config.access_token
+ self.client.device_id = self.config.device_id
+
+ self._register_event_callbacks()
+ self._register_response_callbacks()
+
+ if not self.config.e2ee_enabled:
+ logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
+
+ if self.config.device_id:
+ try:
+ self.client.load_store()
+ except Exception:
+ logger.exception("Matrix store load failed; restart may replay recent messages.")
+ else:
+ logger.warning("Matrix device_id empty; restart may replay recent messages.")
+
+ self._sync_task = asyncio.create_task(self._sync_loop())
+
+ async def stop(self) -> None:
+ """Stop the Matrix channel with graceful sync shutdown."""
+ self._running = False
+ for room_id in list(self._typing_tasks):
+ await self._stop_typing_keepalive(room_id, clear_typing=False)
+ if self.client:
+ self.client.stop_sync_forever()
+ if self._sync_task:
+ try:
+ await asyncio.wait_for(asyncio.shield(self._sync_task),
+ timeout=self.config.sync_stop_grace_seconds)
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ self._sync_task.cancel()
+ try:
+ await self._sync_task
+ except asyncio.CancelledError:
+ pass
+ if self.client:
+ await self.client.close()
+
+ def _is_workspace_path_allowed(self, path: Path) -> bool:
+ """Check path is inside workspace (when restriction enabled)."""
+ if not self._restrict_to_workspace or not self._workspace:
+ return True
+ try:
+ path.resolve(strict=False).relative_to(self._workspace)
+ return True
+ except ValueError:
+ return False
+
+ def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
+ """Deduplicate and resolve outbound attachment paths."""
+ seen: set[str] = set()
+ candidates: list[Path] = []
+ for raw in media:
+ if not isinstance(raw, str) or not raw.strip():
+ continue
+ path = Path(raw.strip()).expanduser()
+ try:
+ key = str(path.resolve(strict=False))
+ except OSError:
+ key = str(path)
+ if key not in seen:
+ seen.add(key)
+ candidates.append(path)
+ return candidates
+
+ @staticmethod
+ def _build_outbound_attachment_content(
+ *, filename: str, mime: str, size_bytes: int,
+ mxc_url: str, encryption_info: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Build Matrix content payload for an uploaded file/image/audio/video."""
+ prefix = mime.split("/")[0]
+ msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
+ content: dict[str, Any] = {
+ "msgtype": msgtype, "body": filename, "filename": filename,
+ "info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
+ }
+ if encryption_info:
+ content["file"] = {**encryption_info, "url": mxc_url}
+ else:
+ content["url"] = mxc_url
+ return content
+
+ def _is_encrypted_room(self, room_id: str) -> bool:
+ if not self.client:
+ return False
+ room = getattr(self.client, "rooms", {}).get(room_id)
+ return bool(getattr(room, "encrypted", False))
+
+ async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
+ """Send m.room.message with E2EE options."""
+ if not self.client:
+ return
+ kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
+ if self.config.e2ee_enabled:
+ kwargs["ignore_unverified_devices"] = True
+ await self.client.room_send(**kwargs)
+
+ async def _resolve_server_upload_limit_bytes(self) -> int | None:
+ """Query homeserver upload limit once per channel lifecycle."""
+ if self._server_upload_limit_checked:
+ return self._server_upload_limit_bytes
+ self._server_upload_limit_checked = True
+ if not self.client:
+ return None
+ try:
+ response = await self.client.content_repository_config()
+ except Exception:
+ return None
+ upload_size = getattr(response, "upload_size", None)
+ if isinstance(upload_size, int) and upload_size > 0:
+ self._server_upload_limit_bytes = upload_size
+ return upload_size
+ return None
+
+ async def _effective_media_limit_bytes(self) -> int:
+ """min(local config, server advertised) — 0 blocks all uploads."""
+ local_limit = max(int(self.config.max_media_bytes), 0)
+ server_limit = await self._resolve_server_upload_limit_bytes()
+ if server_limit is None:
+ return local_limit
+ return min(local_limit, server_limit) if local_limit else 0
+
+ async def _upload_and_send_attachment(
+ self, room_id: str, path: Path, limit_bytes: int,
+ relates_to: dict[str, Any] | None = None,
+ ) -> str | None:
+ """Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
+ if not self.client:
+ return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
+
+ resolved = path.expanduser().resolve(strict=False)
+ filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
+ fail = _ATTACH_UPLOAD_FAILED.format(filename)
+
+ if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
+ return fail
+ try:
+ size_bytes = resolved.stat().st_size
+ except OSError:
+ return fail
+ if limit_bytes <= 0 or size_bytes > limit_bytes:
+ return _ATTACH_TOO_LARGE.format(filename)
+
+ mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
+ try:
+ with resolved.open("rb") as f:
+ upload_result = await self.client.upload(
+ f, content_type=mime, filename=filename,
+ encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
+ filesize=size_bytes,
+ )
+ except Exception:
+ return fail
+
+ upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
+ encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
+ if isinstance(upload_response, UploadError):
+ return fail
+ mxc_url = getattr(upload_response, "content_uri", None)
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return fail
+
+ content = self._build_outbound_attachment_content(
+ filename=filename, mime=mime, size_bytes=size_bytes,
+ mxc_url=mxc_url, encryption_info=encryption_info,
+ )
+ if relates_to:
+ content["m.relates_to"] = relates_to
+ try:
+ await self._send_room_content(room_id, content)
+ except Exception:
+ return fail
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send outbound content; clear typing for non-progress messages."""
+ if not self.client:
+ return
+ text = msg.content or ""
+ candidates = self._collect_outbound_media_candidates(msg.media)
+ relates_to = self._build_thread_relates_to(msg.metadata)
+ is_progress = bool((msg.metadata or {}).get("_progress"))
+ try:
+ failures: list[str] = []
+ if candidates:
+ limit_bytes = await self._effective_media_limit_bytes()
+ for path in candidates:
+ if fail := await self._upload_and_send_attachment(
+ room_id=msg.chat_id,
+ path=path,
+ limit_bytes=limit_bytes,
+ relates_to=relates_to,
+ ):
+ failures.append(fail)
+ if failures:
+ text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
+ if text or not candidates:
+ content = _build_matrix_text_content(text)
+ if relates_to:
+ content["m.relates_to"] = relates_to
+ await self._send_room_content(msg.chat_id, content)
+ finally:
+ if not is_progress:
+ await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
+
+ def _register_event_callbacks(self) -> None:
+ self.client.add_event_callback(self._on_message, RoomMessageText)
+ self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
+ self.client.add_event_callback(self._on_room_invite, InviteEvent)
+
+ def _register_response_callbacks(self) -> None:
+ self.client.add_response_callback(self._on_sync_error, SyncError)
+ self.client.add_response_callback(self._on_join_error, JoinError)
+ self.client.add_response_callback(self._on_send_error, RoomSendError)
+
+ def _log_response_error(self, label: str, response: Any) -> None:
+ """Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
+ code = getattr(response, "status_code", None)
+ is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
+ is_fatal = is_auth or getattr(response, "soft_logout", False)
+ (logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
+
+ async def _on_sync_error(self, response: SyncError) -> None:
+ self._log_response_error("sync", response)
+
+ async def _on_join_error(self, response: JoinError) -> None:
+ self._log_response_error("join", response)
+
+ async def _on_send_error(self, response: RoomSendError) -> None:
+ self._log_response_error("send", response)
+
+ async def _set_typing(self, room_id: str, typing: bool) -> None:
+ """Best-effort typing indicator update."""
+ if not self.client:
+ return
+ try:
+ response = await self.client.room_typing(room_id=room_id, typing_state=typing,
+ timeout=TYPING_NOTICE_TIMEOUT_MS)
+ if isinstance(response, RoomTypingError):
+ logger.debug("Matrix typing failed for {}: {}", room_id, response)
+ except Exception:
+ pass
+
+ async def _start_typing_keepalive(self, room_id: str) -> None:
+ """Start periodic typing refresh (spec-recommended keepalive)."""
+ await self._stop_typing_keepalive(room_id, clear_typing=False)
+ await self._set_typing(room_id, True)
+ if not self._running:
+ return
+
+ async def loop() -> None:
+ try:
+ while self._running:
+ await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
+ await self._set_typing(room_id, True)
+ except asyncio.CancelledError:
+ pass
+
+ self._typing_tasks[room_id] = asyncio.create_task(loop())
+
+ async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
+ if task := self._typing_tasks.pop(room_id, None):
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ if clear_typing:
+ await self._set_typing(room_id, False)
+
+ async def _sync_loop(self) -> None:
+ while self._running:
+ try:
+ await self.client.sync_forever(timeout=30000, full_state=True)
+ except asyncio.CancelledError:
+ break
+ except Exception:
+ await asyncio.sleep(2)
+
+ async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
+ if self.is_allowed(event.sender):
+ await self.client.join(room.room_id)
+
+ def _is_direct_room(self, room: MatrixRoom) -> bool:
+ count = getattr(room, "member_count", None)
+ return isinstance(count, int) and count <= 2
+
+ def _is_bot_mentioned(self, event: RoomMessage) -> bool:
+ """Check m.mentions payload for bot mention."""
+ source = getattr(event, "source", None)
+ if not isinstance(source, dict):
+ return False
+ mentions = (source.get("content") or {}).get("m.mentions")
+ if not isinstance(mentions, dict):
+ return False
+ user_ids = mentions.get("user_ids")
+ if isinstance(user_ids, list) and self.config.user_id in user_ids:
+ return True
+ return bool(self.config.allow_room_mentions and mentions.get("room") is True)
+
+ def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
+ """Apply sender and room policy checks."""
+ if not self.is_allowed(event.sender):
+ return False
+ if self._is_direct_room(room):
+ return True
+ policy = self.config.group_policy
+ if policy == "open":
+ return True
+ if policy == "allowlist":
+ return room.room_id in (self.config.group_allow_from or [])
+ if policy == "mention":
+ return self._is_bot_mentioned(event)
+ return False
+
+ def _media_dir(self) -> Path:
+ return get_media_dir("matrix")
+
+ @staticmethod
+ def _event_source_content(event: RoomMessage) -> dict[str, Any]:
+ source = getattr(event, "source", None)
+ if not isinstance(source, dict):
+ return {}
+ content = source.get("content")
+ return content if isinstance(content, dict) else {}
+
+ def _event_thread_root_id(self, event: RoomMessage) -> str | None:
+ relates_to = self._event_source_content(event).get("m.relates_to")
+ if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
+ return None
+ root_id = relates_to.get("event_id")
+ return root_id if isinstance(root_id, str) and root_id else None
+
+ def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
+ if not (root_id := self._event_thread_root_id(event)):
+ return None
+ meta: dict[str, str] = {"thread_root_event_id": root_id}
+ if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
+ meta["thread_reply_to_event_id"] = reply_to
+ return meta
+
+ @staticmethod
+ def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
+ if not metadata:
+ return None
+ root_id = metadata.get("thread_root_event_id")
+ if not isinstance(root_id, str) or not root_id:
+ return None
+ reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
+ if not isinstance(reply_to, str) or not reply_to:
+ return None
+ return {"rel_type": "m.thread", "event_id": root_id,
+ "m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
+
+ def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
+ msgtype = self._event_source_content(event).get("msgtype")
+ return _MSGTYPE_MAP.get(msgtype, "file")
+
+ @staticmethod
+ def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
+ return (isinstance(getattr(event, "key", None), dict)
+ and isinstance(getattr(event, "hashes", None), dict)
+ and isinstance(getattr(event, "iv", None), str))
+
+ def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
+ info = self._event_source_content(event).get("info")
+ size = info.get("size") if isinstance(info, dict) else None
+ return size if isinstance(size, int) and size >= 0 else None
+
+ def _event_mime(self, event: MatrixMediaEvent) -> str | None:
+ info = self._event_source_content(event).get("info")
+ if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
+ return m
+ m = getattr(event, "mimetype", None)
+ return m if isinstance(m, str) and m else None
+
+ def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
+ body = getattr(event, "body", None)
+ if isinstance(body, str) and body.strip():
+ if candidate := safe_filename(Path(body).name):
+ return candidate
+ return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
+
+ def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
+ filename: str, mime: str | None) -> Path:
+ safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
+ suffix = Path(safe_name).suffix
+ if not suffix and mime:
+ if guessed := mimetypes.guess_extension(mime, strict=False):
+ safe_name, suffix = f"{safe_name}{guessed}", guessed
+ stem = (Path(safe_name).stem or attachment_type)[:72]
+ suffix = suffix[:16]
+ event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
+ event_prefix = (event_id[:24] or "evt").strip("_")
+ return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
+
+ async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
+ if not self.client:
+ return None
+ response = await self.client.download(mxc=mxc_url)
+ if isinstance(response, DownloadError):
+ logger.warning("Matrix download failed for {}: {}", mxc_url, response)
+ return None
+ body = getattr(response, "body", None)
+ if isinstance(body, (bytes, bytearray)):
+ return bytes(body)
+ if isinstance(response, MemoryDownloadResponse):
+ return bytes(response.body)
+ if isinstance(body, (str, Path)):
+ path = Path(body)
+ if path.is_file():
+ try:
+ return path.read_bytes()
+ except OSError:
+ return None
+ return None
+
+ def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
+ key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
+ key = key_obj.get("k") if isinstance(key_obj, dict) else None
+ sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
+ if not all(isinstance(v, str) for v in (key, sha256, iv)):
+ return None
+ try:
+ return decrypt_attachment(ciphertext, key, sha256, iv)
+ except (EncryptionError, ValueError, TypeError):
+ logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
+ return None
+
+ async def _fetch_media_attachment(
+ self, room: MatrixRoom, event: MatrixMediaEvent,
+ ) -> tuple[dict[str, Any] | None, str]:
+ """Download, decrypt if needed, and persist a Matrix attachment."""
+ atype = self._event_attachment_type(event)
+ mime = self._event_mime(event)
+ filename = self._event_filename(event, atype)
+ mxc_url = getattr(event, "url", None)
+ fail = _ATTACH_FAILED.format(filename)
+
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return None, fail
+
+ limit_bytes = await self._effective_media_limit_bytes()
+ declared = self._event_declared_size_bytes(event)
+ if declared is not None and declared > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ downloaded = await self._download_media_bytes(mxc_url)
+ if downloaded is None:
+ return None, fail
+
+ encrypted = self._is_encrypted_media_event(event)
+ data = downloaded
+ if encrypted:
+ if (data := self._decrypt_media_bytes(event, downloaded)) is None:
+ return None, fail
+
+ if len(data) > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ path = self._build_attachment_path(event, atype, filename, mime)
+ try:
+ path.write_bytes(data)
+ except OSError:
+ return None, fail
+
+ attachment = {
+ "type": atype, "mime": mime, "filename": filename,
+ "event_id": str(getattr(event, "event_id", "") or ""),
+ "encrypted": encrypted, "size_bytes": len(data),
+ "path": str(path), "mxc_url": mxc_url,
+ }
+ return attachment, _ATTACH_MARKER.format(path)
+
+ def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
+ """Build common metadata for text and media handlers."""
+ meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
+ if isinstance(eid := getattr(event, "event_id", None), str) and eid:
+ meta["event_id"] = eid
+ if thread := self._thread_metadata(event):
+ meta.update(thread)
+ return meta
+
+ async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content=event.body, metadata=self._base_metadata(room, event),
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
+
+ async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ attachment, marker = await self._fetch_media_attachment(room, event)
+ parts: list[str] = []
+ if isinstance(body := getattr(event, "body", None), str) and body.strip():
+ parts.append(body.strip())
+
+ if attachment and attachment.get("type") == "audio":
+ transcription = await self.transcribe_audio(attachment["path"])
+ if transcription:
+ parts.append(f"[transcription: {transcription}]")
+ else:
+ parts.append(marker)
+ elif marker:
+ parts.append(marker)
+
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ meta = self._base_metadata(room, event)
+ meta["attachments"] = []
+ if attachment:
+ meta["attachments"] = [attachment]
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content="\n".join(parts),
+ media=[attachment["path"]] if attachment else [],
+ metadata=meta,
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py
index e762dfd..629379f 100644
--- a/nanobot/channels/mochat.py
+++ b/nanobot/channels/mochat.py
@@ -15,8 +15,9 @@ from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import MochatConfig
-from nanobot.utils.helpers import get_data_path
+from nanobot.config.paths import get_runtime_subdir
+from nanobot.config.schema import Base
+from pydantic import Field
try:
import socketio
@@ -208,6 +209,49 @@ def parse_timestamp(value: Any) -> int | None:
return None
+# ---------------------------------------------------------------------------
+# Config classes
+# ---------------------------------------------------------------------------
+
+class MochatMentionConfig(Base):
+ """Mochat mention behavior configuration."""
+
+ require_in_groups: bool = False
+
+
+class MochatGroupRule(Base):
+ """Mochat per-group mention requirement."""
+
+ require_mention: bool = False
+
+
+class MochatConfig(Base):
+ """Mochat channel configuration."""
+
+ enabled: bool = False
+ base_url: str = "https://mochat.io"
+ socket_url: str = ""
+ socket_path: str = "/socket.io"
+ socket_disable_msgpack: bool = False
+ socket_reconnect_delay_ms: int = 1000
+ socket_max_reconnect_delay_ms: int = 10000
+ socket_connect_timeout_ms: int = 10000
+ refresh_interval_ms: int = 30000
+ watch_timeout_ms: int = 25000
+ watch_limit: int = 100
+ retry_delay_ms: int = 500
+ max_retry_attempts: int = 0
+ claw_token: str = ""
+ agent_user_id: str = ""
+ sessions: list[str] = Field(default_factory=list)
+ panels: list[str] = Field(default_factory=list)
+ allow_from: list[str] = Field(default_factory=list)
+ mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
+ groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
+ reply_delay_mode: str = "non-mention"
+ reply_delay_ms: int = 120000
+
+
# ---------------------------------------------------------------------------
# Channel
# ---------------------------------------------------------------------------
@@ -216,15 +260,22 @@ class MochatChannel(BaseChannel):
"""Mochat channel using socket.io with fallback polling workers."""
name = "mochat"
+ display_name = "Mochat"
- def __init__(self, config: MochatConfig, bus: MessageBus):
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return MochatConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = MochatConfig.model_validate(config)
super().__init__(config, bus)
self.config: MochatConfig = config
self._http: httpx.AsyncClient | None = None
self._socket: Any = None
self._ws_connected = self._ws_ready = False
- self._state_dir = get_data_path() / "mochat"
+ self._state_dir = get_runtime_subdir("mochat")
self._cursor_path = self._state_dir / "session_cursors.json"
self._session_cursor: dict[str, int] = {}
self._cursor_save_task: asyncio.Task | None = None
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
index 5352a30..e556c98 100644
--- a/nanobot/channels/qq.py
+++ b/nanobot/channels/qq.py
@@ -2,27 +2,29 @@
import asyncio
from collections import deque
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal
from loguru import logger
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import QQConfig
+from nanobot.config.schema import Base
+from pydantic import Field
try:
import botpy
- from botpy.message import C2CMessage
+ from botpy.message import C2CMessage, GroupMessage
QQ_AVAILABLE = True
except ImportError:
QQ_AVAILABLE = False
botpy = None
C2CMessage = None
+ GroupMessage = None
if TYPE_CHECKING:
- from botpy.message import C2CMessage
+ from botpy.message import C2CMessage, GroupMessage
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
@@ -31,30 +33,53 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
class _Bot(botpy.Client):
def __init__(self):
- super().__init__(intents=intents)
+ # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+ super().__init__(intents=intents, ext_handlers=False)
async def on_ready(self):
logger.info("QQ bot ready: {}", self.robot.name)
async def on_c2c_message_create(self, message: "C2CMessage"):
- await channel._on_message(message)
+ await channel._on_message(message, is_group=False)
+
+ async def on_group_at_message_create(self, message: "GroupMessage"):
+ await channel._on_message(message, is_group=True)
async def on_direct_message_create(self, message):
- await channel._on_message(message)
+ await channel._on_message(message, is_group=False)
return _Bot
+class QQConfig(Base):
+ """QQ channel configuration using botpy SDK."""
+
+ enabled: bool = False
+ app_id: str = ""
+ secret: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ msg_format: Literal["plain", "markdown"] = "plain"
+
+
class QQChannel(BaseChannel):
"""QQ channel using botpy SDK with WebSocket connection."""
name = "qq"
+ display_name = "QQ"
- def __init__(self, config: QQConfig, bus: MessageBus):
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return QQConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = QQConfig.model_validate(config)
super().__init__(config, bus)
self.config: QQConfig = config
self._client: "botpy.Client | None" = None
self._processed_ids: deque = deque(maxlen=1000)
+ self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
+ self._chat_type_cache: dict[str, str] = {}
async def start(self) -> None:
"""Start the QQ bot."""
@@ -69,8 +94,7 @@ class QQChannel(BaseChannel):
self._running = True
BotClass = _make_bot_class(self)
self._client = BotClass()
-
- logger.info("QQ bot started (C2C private message)")
+ logger.info("QQ bot started (C2C & Group supported)")
await self._run_bot()
async def _run_bot(self) -> None:
@@ -99,16 +123,36 @@ class QQChannel(BaseChannel):
if not self._client:
logger.warning("QQ client not initialized")
return
+
try:
- await self._client.api.post_c2c_message(
- openid=msg.chat_id,
- msg_type=0,
- content=msg.content,
- )
+ msg_id = msg.metadata.get("message_id")
+ self._msg_seq += 1
+ use_markdown = self.config.msg_format == "markdown"
+ payload: dict[str, Any] = {
+ "msg_type": 2 if use_markdown else 0,
+ "msg_id": msg_id,
+ "msg_seq": self._msg_seq,
+ }
+ if use_markdown:
+ payload["markdown"] = {"content": msg.content}
+ else:
+ payload["content"] = msg.content
+
+ chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
+ if chat_type == "group":
+ await self._client.api.post_group_message(
+ group_openid=msg.chat_id,
+ **payload,
+ )
+ else:
+ await self._client.api.post_c2c_message(
+ openid=msg.chat_id,
+ **payload,
+ )
except Exception as e:
logger.error("Error sending QQ message: {}", e)
- async def _on_message(self, data: "C2CMessage") -> None:
+ async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
"""Handle incoming message from QQ."""
try:
# Dedup by message ID
@@ -116,15 +160,22 @@ class QQChannel(BaseChannel):
return
self._processed_ids.append(data.id)
- author = data.author
- user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
content = (data.content or "").strip()
if not content:
return
+ if is_group:
+ chat_id = data.group_openid
+ user_id = data.author.member_openid
+ self._chat_type_cache[chat_id] = "group"
+ else:
+ chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
+ user_id = chat_id
+ self._chat_type_cache[chat_id] = "c2c"
+
await self._handle_message(
sender_id=user_id,
- chat_id=user_id,
+ chat_id=chat_id,
content=content,
metadata={"message_id": data.id},
)
diff --git a/nanobot/channels/registry.py b/nanobot/channels/registry.py
new file mode 100644
index 0000000..04effc7
--- /dev/null
+++ b/nanobot/channels/registry.py
@@ -0,0 +1,71 @@
+"""Auto-discovery for built-in channel modules and external plugins."""
+
+from __future__ import annotations
+
+import importlib
+import pkgutil
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+if TYPE_CHECKING:
+ from nanobot.channels.base import BaseChannel
+
+_INTERNAL = frozenset({"base", "manager", "registry"})
+
+
+def discover_channel_names() -> list[str]:
+ """Return all built-in channel module names by scanning the package (zero imports)."""
+ import nanobot.channels as pkg
+
+ return [
+ name
+ for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
+ if name not in _INTERNAL and not ispkg
+ ]
+
+
+def load_channel_class(module_name: str) -> type[BaseChannel]:
+ """Import *module_name* and return the first BaseChannel subclass found."""
+ from nanobot.channels.base import BaseChannel as _Base
+
+ mod = importlib.import_module(f"nanobot.channels.{module_name}")
+ for attr in dir(mod):
+ obj = getattr(mod, attr)
+ if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
+ return obj
+ raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
+
+
+def discover_plugins() -> dict[str, type[BaseChannel]]:
+ """Discover external channel plugins registered via entry_points."""
+ from importlib.metadata import entry_points
+
+ plugins: dict[str, type[BaseChannel]] = {}
+ for ep in entry_points(group="nanobot.channels"):
+ try:
+ cls = ep.load()
+ plugins[ep.name] = cls
+ except Exception as e:
+ logger.warning("Failed to load channel plugin '{}': {}", ep.name, e)
+ return plugins
+
+
+def discover_all() -> dict[str, type[BaseChannel]]:
+ """Return all channels: built-in (pkgutil) merged with external (entry_points).
+
+ Built-in channels take priority — an external plugin cannot shadow a built-in name.
+ """
+ builtin: dict[str, type[BaseChannel]] = {}
+ for modname in discover_channel_names():
+ try:
+ builtin[modname] = load_channel_class(modname)
+ except ImportError as e:
+ logger.debug("Skipping built-in channel '{}': {}", modname, e)
+
+ external = discover_plugins()
+ shadowed = set(external) & set(builtin)
+ if shadowed:
+ logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
+
+ return {**external, **builtin}
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
index 57bfbcb..c9f353d 100644
--- a/nanobot/channels/slack.py
+++ b/nanobot/channels/slack.py
@@ -5,25 +5,58 @@ import re
from typing import Any
from loguru import logger
-from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
+from slack_sdk.socket_mode.websockets import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient
-
from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
+from pydantic import Field
+
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import SlackConfig
+from nanobot.config.schema import Base
+
+
+class SlackDMConfig(Base):
+ """Slack DM policy configuration."""
+
+ enabled: bool = True
+ policy: str = "open"
+ allow_from: list[str] = Field(default_factory=list)
+
+
+class SlackConfig(Base):
+ """Slack channel configuration."""
+
+ enabled: bool = False
+ mode: str = "socket"
+ webhook_path: str = "/slack/events"
+ bot_token: str = ""
+ app_token: str = ""
+ user_token_read_only: bool = True
+ reply_in_thread: bool = True
+ react_emoji: str = "eyes"
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: str = "mention"
+ group_allow_from: list[str] = Field(default_factory=list)
+ dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
class SlackChannel(BaseChannel):
"""Slack channel using Socket Mode."""
name = "slack"
+ display_name = "Slack"
- def __init__(self, config: SlackConfig, bus: MessageBus):
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return SlackConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = SlackConfig.model_validate(config)
super().__init__(config, bus)
self.config: SlackConfig = config
self._web_client: AsyncWebClient | None = None
@@ -82,14 +115,15 @@ class SlackChannel(BaseChannel):
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
thread_ts = slack_meta.get("thread_ts")
channel_type = slack_meta.get("channel_type")
- # Only reply in thread for channel/group messages; DMs don't use threads
- use_thread = thread_ts and channel_type != "im"
- thread_ts_param = thread_ts if use_thread else None
+ # Slack DMs don't use threads; channel/group replies may keep thread_ts.
+ thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
- if msg.content:
+ # Slack rejects empty text payloads. Keep media-only messages media-only,
+ # but send a single blank message when the bot has no text or files to send.
+ if msg.content or not (msg.media or []):
await self._web_client.chat_postMessage(
channel=msg.chat_id,
- text=self._to_mrkdwn(msg.content),
+ text=self._to_mrkdwn(msg.content) if msg.content else " ",
thread_ts=thread_ts_param,
)
@@ -278,4 +312,3 @@ class SlackChannel(BaseChannel):
if parts:
rows.append(" · ".join(parts))
return "\n".join(rows)
-
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 6cd98e7..34c4a3b 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -4,15 +4,66 @@ from __future__ import annotations
import asyncio
import re
+import time
+import unicodedata
+from typing import Any, Literal
+
from loguru import logger
-from telegram import BotCommand, Update, ReplyParameters
-from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
+from pydantic import Field
+from telegram import BotCommand, ReplyParameters, Update
+from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import TelegramConfig
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import Base
+from nanobot.utils.helpers import split_message
+
+TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
+TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
+
+
+def _strip_md(s: str) -> str:
+ """Strip markdown inline formatting from text."""
+ s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
+ s = re.sub(r'__(.+?)__', r'\1', s)
+ s = re.sub(r'~~(.+?)~~', r'\1', s)
+ s = re.sub(r'`([^`]+)`', r'\1', s)
+ return s.strip()
+
+
+def _render_table_box(table_lines: list[str]) -> str:
+ """Convert markdown pipe-table to compact aligned text for display."""
+
+ def dw(s: str) -> int:
+ return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
+
+ rows: list[list[str]] = []
+ has_sep = False
+ for line in table_lines:
+ cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
+ if all(re.match(r'^:?-+:?$', c) for c in cells if c):
+ has_sep = True
+ continue
+ rows.append(cells)
+ if not rows or not has_sep:
+ return '\n'.join(table_lines)
+
+ ncols = max(len(r) for r in rows)
+ for r in rows:
+ r.extend([''] * (ncols - len(r)))
+ widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
+
+ def dr(cells: list[str]) -> str:
+ return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
+
+ out = [dr(rows[0])]
+ out.append(' '.join('─' * w for w in widths))
+ for row in rows[1:]:
+ out.append(dr(row))
+ return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str:
@@ -21,183 +72,235 @@ def _markdown_to_telegram_html(text: str) -> str:
"""
if not text:
return ""
-
+
# 1. Extract and protect code blocks (preserve content from other processing)
code_blocks: list[str] = []
def save_code_block(m: re.Match) -> str:
code_blocks.append(m.group(1))
return f"\x00CB{len(code_blocks) - 1}\x00"
-
+
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
-
+
+ # 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
+ lines = text.split('\n')
+ rebuilt: list[str] = []
+ li = 0
+ while li < len(lines):
+ if re.match(r'^\s*\|.+\|', lines[li]):
+ tbl: list[str] = []
+ while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
+ tbl.append(lines[li])
+ li += 1
+ box = _render_table_box(tbl)
+ if box != '\n'.join(tbl):
+ code_blocks.append(box)
+ rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
+ else:
+ rebuilt.extend(tbl)
+ else:
+ rebuilt.append(lines[li])
+ li += 1
+ text = '\n'.join(rebuilt)
+
# 2. Extract and protect inline code
inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str:
inline_codes.append(m.group(1))
return f"\x00IC{len(inline_codes) - 1}\x00"
-
+
text = re.sub(r'`([^`]+)`', save_inline_code, text)
-
+
# 3. Headers # Title -> just the title text
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
-
+
# 4. Blockquotes > text -> just the text (before HTML escaping)
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
-
+
# 5. Escape HTML special characters
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
-
+
# 6. Links [text](url) - must be before bold/italic to handle nested cases
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1', text)
-
+
# 7. Bold **text** or __text__
text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)
text = re.sub(r'__(.+?)__', r'\1', text)
-
+
# 8. Italic _text_ (avoid matching inside words like some_var_name)
text = re.sub(r'(?\1', text)
-
+
# 9. Strikethrough ~~text~~
text = re.sub(r'~~(.+?)~~', r'\1', text)
-
+
# 10. Bullet lists - item -> • item
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
-
+
# 11. Restore inline code with HTML tags
for i, code in enumerate(inline_codes):
# Escape HTML in code content
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
text = text.replace(f"\x00IC{i}\x00", f"{escaped}")
-
+
# 12. Restore code blocks with HTML tags
for i, code in enumerate(code_blocks):
# Escape HTML in code content
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
text = text.replace(f"\x00CB{i}\x00", f"{escaped}
")
-
+
return text
-def _split_message(content: str, max_len: int = 4000) -> list[str]:
- """Split content into chunks within max_len, preferring line breaks."""
- if len(content) <= max_len:
- return [content]
- chunks: list[str] = []
- while content:
- if len(content) <= max_len:
- chunks.append(content)
- break
- cut = content[:max_len]
- pos = cut.rfind('\n')
- if pos == -1:
- pos = cut.rfind(' ')
- if pos == -1:
- pos = max_len
- chunks.append(content[:pos])
- content = content[pos:].lstrip()
- return chunks
+class TelegramConfig(Base):
+ """Telegram channel configuration."""
+
+ enabled: bool = False
+ token: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ proxy: str | None = None
+ reply_to_message: bool = False
+ group_policy: Literal["open", "mention"] = "mention"
class TelegramChannel(BaseChannel):
"""
Telegram channel using long polling.
-
+
Simple and reliable - no webhook/public IP needed.
"""
-
+
name = "telegram"
-
+ display_name = "Telegram"
+
# Commands registered with Telegram's command menu
BOT_COMMANDS = [
BotCommand("start", "Start the bot"),
BotCommand("new", "Start a new conversation"),
+ BotCommand("stop", "Stop the current task"),
BotCommand("help", "Show available commands"),
+ BotCommand("restart", "Restart the bot"),
]
-
- def __init__(
- self,
- config: TelegramConfig,
- bus: MessageBus,
- groq_api_key: str = "",
- ):
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return TelegramConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = TelegramConfig.model_validate(config)
super().__init__(config, bus)
self.config: TelegramConfig = config
- self.groq_api_key = groq_api_key
self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
-
+ self._media_group_buffers: dict[str, dict] = {}
+ self._media_group_tasks: dict[str, asyncio.Task] = {}
+ self._message_threads: dict[tuple[str, int], int] = {}
+ self._bot_user_id: int | None = None
+ self._bot_username: str | None = None
+
+ def is_allowed(self, sender_id: str) -> bool:
+ """Preserve Telegram's legacy id|username allowlist matching."""
+ if super().is_allowed(sender_id):
+ return True
+
+ allow_list = getattr(self.config, "allow_from", [])
+ if not allow_list or "*" in allow_list:
+ return False
+
+ sender_str = str(sender_id)
+ if sender_str.count("|") != 1:
+ return False
+
+ sid, username = sender_str.split("|", 1)
+ if not sid.isdigit() or not username:
+ return False
+
+ return sid in allow_list or username in allow_list
+
async def start(self) -> None:
"""Start the Telegram bot with long polling."""
if not self.config.token:
logger.error("Telegram bot token not configured")
return
-
+
self._running = True
-
+
# Build the application with larger connection pool to avoid pool-timeout on long runs
- req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
+ req = HTTPXRequest(
+ connection_pool_size=16,
+ pool_timeout=5.0,
+ connect_timeout=30.0,
+ read_timeout=30.0,
+ proxy=self.config.proxy if self.config.proxy else None,
+ )
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
- if self.config.proxy:
- builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
self._app = builder.build()
self._app.add_error_handler(self._on_error)
-
+
# Add command handlers
self._app.add_handler(CommandHandler("start", self._on_start))
self._app.add_handler(CommandHandler("new", self._forward_command))
+ self._app.add_handler(CommandHandler("stop", self._forward_command))
+ self._app.add_handler(CommandHandler("restart", self._forward_command))
self._app.add_handler(CommandHandler("help", self._on_help))
-
+
# Add message handler for text, photos, voice, documents
self._app.add_handler(
MessageHandler(
- (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
- & ~filters.COMMAND,
+ (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
+ & ~filters.COMMAND,
self._on_message
)
)
-
+
logger.info("Starting Telegram bot (polling mode)...")
-
+
# Initialize and start polling
await self._app.initialize()
await self._app.start()
-
+
# Get bot info and register command menu
bot_info = await self._app.bot.get_me()
+ self._bot_user_id = getattr(bot_info, "id", None)
+ self._bot_username = getattr(bot_info, "username", None)
logger.info("Telegram bot @{} connected", bot_info.username)
-
+
try:
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
logger.debug("Telegram bot commands registered")
except Exception as e:
logger.warning("Failed to register bot commands: {}", e)
-
+
# Start polling (this runs until stopped)
await self._app.updater.start_polling(
allowed_updates=["message"],
drop_pending_updates=True # Ignore old messages on startup
)
-
+
# Keep running until stopped
while self._running:
await asyncio.sleep(1)
-
+
async def stop(self) -> None:
"""Stop the Telegram bot."""
self._running = False
-
+
# Cancel all typing indicators
for chat_id in list(self._typing_tasks):
self._stop_typing(chat_id)
-
+
+ for task in self._media_group_tasks.values():
+ task.cancel()
+ self._media_group_tasks.clear()
+ self._media_group_buffers.clear()
+
if self._app:
logger.info("Stopping Telegram bot...")
await self._app.updater.stop()
await self._app.stop()
await self._app.shutdown()
self._app = None
-
+
@staticmethod
def _get_media_type(path: str) -> str:
"""Guess media type from file extension."""
@@ -216,17 +319,25 @@ class TelegramChannel(BaseChannel):
logger.warning("Telegram bot not running")
return
- self._stop_typing(msg.chat_id)
+ # Only stop typing indicator for final responses
+ if not msg.metadata.get("_progress", False):
+ self._stop_typing(msg.chat_id)
try:
chat_id = int(msg.chat_id)
except ValueError:
logger.error("Invalid chat_id: {}", msg.chat_id)
return
+ reply_to_message_id = msg.metadata.get("message_id")
+ message_thread_id = msg.metadata.get("message_thread_id")
+ if message_thread_id is None and reply_to_message_id is not None:
+ message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
+ thread_kwargs = {}
+ if message_thread_id is not None:
+ thread_kwargs["message_thread_id"] = message_thread_id
reply_params = None
if self.config.reply_to_message:
- reply_to_message_id = msg.metadata.get("message_id")
if reply_to_message_id:
reply_params = ReplyParameters(
message_id=reply_to_message_id,
@@ -245,9 +356,10 @@ class TelegramChannel(BaseChannel):
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
with open(media_path, 'rb') as f:
await sender(
- chat_id=chat_id,
+ chat_id=chat_id,
**{param: f},
- reply_parameters=reply_params
+ reply_parameters=reply_params,
+ **thread_kwargs,
)
except Exception as e:
filename = media_path.rsplit("/", 1)[-1]
@@ -255,31 +367,72 @@ class TelegramChannel(BaseChannel):
await self._app.bot.send_message(
chat_id=chat_id,
text=f"[Failed to send: {filename}]",
- reply_parameters=reply_params
+ reply_parameters=reply_params,
+ **thread_kwargs,
)
# Send text content
if msg.content and msg.content != "[empty message]":
- for chunk in _split_message(msg.content):
- try:
- html = _markdown_to_telegram_html(chunk)
- await self._app.bot.send_message(
- chat_id=chat_id,
- text=html,
- parse_mode="HTML",
- reply_parameters=reply_params
- )
- except Exception as e:
- logger.warning("HTML parse failed, falling back to plain text: {}", e)
- try:
- await self._app.bot.send_message(
- chat_id=chat_id,
- text=chunk,
- reply_parameters=reply_params
- )
- except Exception as e2:
- logger.error("Error sending Telegram message: {}", e2)
-
+ is_progress = msg.metadata.get("_progress", False)
+
+ for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
+ # Final response: simulate streaming via draft, then persist
+ if not is_progress:
+ await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
+ else:
+ await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
+
+ async def _send_text(
+ self,
+ chat_id: int,
+ text: str,
+ reply_params=None,
+ thread_kwargs: dict | None = None,
+ ) -> None:
+ """Send a plain text message with HTML fallback."""
+ try:
+ html = _markdown_to_telegram_html(text)
+ await self._app.bot.send_message(
+ chat_id=chat_id, text=html, parse_mode="HTML",
+ reply_parameters=reply_params,
+ **(thread_kwargs or {}),
+ )
+ except Exception as e:
+ logger.warning("HTML parse failed, falling back to plain text: {}", e)
+ try:
+ await self._app.bot.send_message(
+ chat_id=chat_id,
+ text=text,
+ reply_parameters=reply_params,
+ **(thread_kwargs or {}),
+ )
+ except Exception as e2:
+ logger.error("Error sending Telegram message: {}", e2)
+
+ async def _send_with_streaming(
+ self,
+ chat_id: int,
+ text: str,
+ reply_params=None,
+ thread_kwargs: dict | None = None,
+ ) -> None:
+ """Simulate streaming via send_message_draft, then persist with send_message."""
+ draft_id = int(time.time() * 1000) % (2**31)
+ try:
+ step = max(len(text) // 8, 40)
+ for i in range(step, len(text), step):
+ await self._app.bot.send_message_draft(
+ chat_id=chat_id, draft_id=draft_id, text=text[:i],
+ )
+ await asyncio.sleep(0.04)
+ await self._app.bot.send_message_draft(
+ chat_id=chat_id, draft_id=draft_id, text=text,
+ )
+ await asyncio.sleep(0.15)
+ except Exception:
+ pass
+ await self._send_text(chat_id, text, reply_params, thread_kwargs)
+
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command."""
if not update.message or not update.effective_user:
@@ -299,6 +452,8 @@ class TelegramChannel(BaseChannel):
await update.message.reply_text(
"🐈 nanobot commands:\n"
"/new — Start a new conversation\n"
+ "/stop — Stop the current task\n"
+ "/restart — Restart the bot\n"
"/help — Show available commands"
)
@@ -308,126 +463,298 @@ class TelegramChannel(BaseChannel):
sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid
+ @staticmethod
+ def _derive_topic_session_key(message) -> str | None:
+ """Derive topic-scoped session key for non-private Telegram chats."""
+ message_thread_id = getattr(message, "message_thread_id", None)
+ if message.chat.type == "private" or message_thread_id is None:
+ return None
+ return f"telegram:{message.chat_id}:topic:{message_thread_id}"
+
+ @staticmethod
+ def _build_message_metadata(message, user) -> dict:
+ """Build common Telegram inbound metadata payload."""
+ reply_to = getattr(message, "reply_to_message", None)
+ return {
+ "message_id": message.message_id,
+ "user_id": user.id,
+ "username": user.username,
+ "first_name": user.first_name,
+ "is_group": message.chat.type != "private",
+ "message_thread_id": getattr(message, "message_thread_id", None),
+ "is_forum": bool(getattr(message.chat, "is_forum", False)),
+ "reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
+ }
+
+ @staticmethod
+ def _extract_reply_context(message) -> str | None:
+ """Extract text from the message being replied to, if any."""
+ reply = getattr(message, "reply_to_message", None)
+ if not reply:
+ return None
+ text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
+ if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
+ text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
+ return f"[Reply to: {text}]" if text else None
+
+ async def _download_message_media(
+ self, msg, *, add_failure_content: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Download media from a message (current or reply). Returns (media_paths, content_parts)."""
+ media_file = None
+ media_type = None
+ if getattr(msg, "photo", None):
+ media_file = msg.photo[-1]
+ media_type = "image"
+ elif getattr(msg, "voice", None):
+ media_file = msg.voice
+ media_type = "voice"
+ elif getattr(msg, "audio", None):
+ media_file = msg.audio
+ media_type = "audio"
+ elif getattr(msg, "document", None):
+ media_file = msg.document
+ media_type = "file"
+ elif getattr(msg, "video", None):
+ media_file = msg.video
+ media_type = "video"
+ elif getattr(msg, "video_note", None):
+ media_file = msg.video_note
+ media_type = "video"
+ elif getattr(msg, "animation", None):
+ media_file = msg.animation
+ media_type = "animation"
+ if not media_file or not self._app:
+ return [], []
+ try:
+ file = await self._app.bot.get_file(media_file.file_id)
+ ext = self._get_extension(
+ media_type,
+ getattr(media_file, "mime_type", None),
+ getattr(media_file, "file_name", None),
+ )
+ media_dir = get_media_dir("telegram")
+ unique_id = getattr(media_file, "file_unique_id", media_file.file_id)
+ file_path = media_dir / f"{unique_id}{ext}"
+ await file.download_to_drive(str(file_path))
+ path_str = str(file_path)
+ if media_type in ("voice", "audio"):
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ logger.info("Transcribed {}: {}...", media_type, transcription[:50])
+ return [path_str], [f"[transcription: {transcription}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ except Exception as e:
+ logger.warning("Failed to download message media: {}", e)
+ if add_failure_content:
+ return [], [f"[{media_type}: download failed]"]
+ return [], []
+
+ async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
+ """Load bot identity once and reuse it for mention/reply checks."""
+ if self._bot_user_id is not None or self._bot_username is not None:
+ return self._bot_user_id, self._bot_username
+ if not self._app:
+ return None, None
+ bot_info = await self._app.bot.get_me()
+ self._bot_user_id = getattr(bot_info, "id", None)
+ self._bot_username = getattr(bot_info, "username", None)
+ return self._bot_user_id, self._bot_username
+
+ @staticmethod
+ def _has_mention_entity(
+ text: str,
+ entities,
+ bot_username: str,
+ bot_id: int | None,
+ ) -> bool:
+ """Check Telegram mention entities against the bot username."""
+ handle = f"@{bot_username}".lower()
+ for entity in entities or []:
+ entity_type = getattr(entity, "type", None)
+ if entity_type == "text_mention":
+ user = getattr(entity, "user", None)
+ if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
+ return True
+ continue
+ if entity_type != "mention":
+ continue
+ offset = getattr(entity, "offset", None)
+ length = getattr(entity, "length", None)
+ if offset is None or length is None:
+ continue
+ if text[offset : offset + length].lower() == handle:
+ return True
+ return handle in text.lower()
+
+ async def _is_group_message_for_bot(self, message) -> bool:
+ """Allow group messages when policy is open, @mentioned, or replying to the bot."""
+ if message.chat.type == "private" or self.config.group_policy == "open":
+ return True
+
+ bot_id, bot_username = await self._ensure_bot_identity()
+ if bot_username:
+ text = message.text or ""
+ caption = message.caption or ""
+ if self._has_mention_entity(
+ text,
+ getattr(message, "entities", None),
+ bot_username,
+ bot_id,
+ ):
+ return True
+ if self._has_mention_entity(
+ caption,
+ getattr(message, "caption_entities", None),
+ bot_username,
+ bot_id,
+ ):
+ return True
+
+ reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
+ return bool(bot_id and reply_user and reply_user.id == bot_id)
+
+ def _remember_thread_context(self, message) -> None:
+ """Cache topic thread id by chat/message id for follow-up replies."""
+ message_thread_id = getattr(message, "message_thread_id", None)
+ if message_thread_id is None:
+ return
+ key = (str(message.chat_id), message.message_id)
+ self._message_threads[key] = message_thread_id
+ if len(self._message_threads) > 1000:
+ self._message_threads.pop(next(iter(self._message_threads)))
+
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Forward slash commands to the bus for unified handling in AgentLoop."""
if not update.message or not update.effective_user:
return
+ message = update.message
+ user = update.effective_user
+ self._remember_thread_context(message)
await self._handle_message(
- sender_id=self._sender_id(update.effective_user),
- chat_id=str(update.message.chat_id),
- content=update.message.text,
+ sender_id=self._sender_id(user),
+ chat_id=str(message.chat_id),
+ content=message.text or "",
+ metadata=self._build_message_metadata(message, user),
+ session_key=self._derive_topic_session_key(message),
)
-
+
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming messages (text, photos, voice, documents)."""
if not update.message or not update.effective_user:
return
-
+
message = update.message
user = update.effective_user
chat_id = message.chat_id
sender_id = self._sender_id(user)
-
+ self._remember_thread_context(message)
+
# Store chat_id for replies
self._chat_ids[sender_id] = chat_id
-
+
+ if not await self._is_group_message_for_bot(message):
+ return
+
# Build content from text and/or media
content_parts = []
media_paths = []
-
+
# Text content
if message.text:
content_parts.append(message.text)
if message.caption:
content_parts.append(message.caption)
-
- # Handle media files
- media_file = None
- media_type = None
-
- if message.photo:
- media_file = message.photo[-1] # Largest photo
- media_type = "image"
- elif message.voice:
- media_file = message.voice
- media_type = "voice"
- elif message.audio:
- media_file = message.audio
- media_type = "audio"
- elif message.document:
- media_file = message.document
- media_type = "file"
-
- # Download media if present
- if media_file and self._app:
- try:
- file = await self._app.bot.get_file(media_file.file_id)
- ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
-
- # Save to workspace/media/
- from pathlib import Path
- media_dir = Path.home() / ".nanobot" / "media"
- media_dir.mkdir(parents=True, exist_ok=True)
-
- file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
- await file.download_to_drive(str(file_path))
-
- media_paths.append(str(file_path))
-
- # Handle voice transcription
- if media_type == "voice" or media_type == "audio":
- from nanobot.providers.transcription import GroqTranscriptionProvider
- transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
- transcription = await transcriber.transcribe(file_path)
- if transcription:
- logger.info("Transcribed {}: {}...", media_type, transcription[:50])
- content_parts.append(f"[transcription: {transcription}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
-
- logger.debug("Downloaded {} to {}", media_type, file_path)
- except Exception as e:
- logger.error("Failed to download media: {}", e)
- content_parts.append(f"[{media_type}: download failed]")
-
+
+ # Download current message media
+ current_media_paths, current_media_parts = await self._download_message_media(
+ message, add_failure_content=True
+ )
+ media_paths.extend(current_media_paths)
+ content_parts.extend(current_media_parts)
+ if current_media_paths:
+ logger.debug("Downloaded message media to {}", current_media_paths[0])
+
+ # Reply context: text and/or media from the replied-to message
+ reply = getattr(message, "reply_to_message", None)
+ if reply is not None:
+ reply_ctx = self._extract_reply_context(message)
+ reply_media, reply_media_parts = await self._download_message_media(reply)
+ if reply_media:
+ media_paths = reply_media + media_paths
+ logger.debug("Attached replied-to media: {}", reply_media[0])
+ tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
+ if tag:
+ content_parts.insert(0, tag)
content = "\n".join(content_parts) if content_parts else "[empty message]"
-
+
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
-
+
str_chat_id = str(chat_id)
-
+ metadata = self._build_message_metadata(message, user)
+ session_key = self._derive_topic_session_key(message)
+
+ # Telegram media groups: buffer briefly, forward as one aggregated turn.
+ if media_group_id := getattr(message, "media_group_id", None):
+ key = f"{str_chat_id}:{media_group_id}"
+ if key not in self._media_group_buffers:
+ self._media_group_buffers[key] = {
+ "sender_id": sender_id, "chat_id": str_chat_id,
+ "contents": [], "media": [],
+ "metadata": metadata,
+ "session_key": session_key,
+ }
+ self._start_typing(str_chat_id)
+ buf = self._media_group_buffers[key]
+ if content and content != "[empty message]":
+ buf["contents"].append(content)
+ buf["media"].extend(media_paths)
+ if key not in self._media_group_tasks:
+ self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
+ return
+
# Start typing indicator before processing
self._start_typing(str_chat_id)
-
+
# Forward to the message bus
await self._handle_message(
sender_id=sender_id,
chat_id=str_chat_id,
content=content,
media=media_paths,
- metadata={
- "message_id": message.message_id,
- "user_id": user.id,
- "username": user.username,
- "first_name": user.first_name,
- "is_group": message.chat.type != "private"
- }
+ metadata=metadata,
+ session_key=session_key,
)
-
+
+ async def _flush_media_group(self, key: str) -> None:
+ """Wait briefly, then forward buffered media-group as one turn."""
+ try:
+ await asyncio.sleep(0.6)
+ if not (buf := self._media_group_buffers.pop(key, None)):
+ return
+ content = "\n".join(buf["contents"]) or "[empty message]"
+ await self._handle_message(
+ sender_id=buf["sender_id"], chat_id=buf["chat_id"],
+ content=content, media=list(dict.fromkeys(buf["media"])),
+ metadata=buf["metadata"],
+ session_key=buf.get("session_key"),
+ )
+ finally:
+ self._media_group_tasks.pop(key, None)
+
def _start_typing(self, chat_id: str) -> None:
"""Start sending 'typing...' indicator for a chat."""
# Cancel any existing typing task for this chat
self._stop_typing(chat_id)
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
-
+
def _stop_typing(self, chat_id: str) -> None:
"""Stop the typing indicator for a chat."""
task = self._typing_tasks.pop(chat_id, None)
if task and not task.done():
task.cancel()
-
+
async def _typing_loop(self, chat_id: str) -> None:
"""Repeatedly send 'typing' action until cancelled."""
try:
@@ -438,13 +765,18 @@ class TelegramChannel(BaseChannel):
pass
except Exception as e:
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
-
+
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Log polling / handler errors instead of silently swallowing them."""
logger.error("Telegram error: {}", context.error)
- def _get_extension(self, media_type: str, mime_type: str | None) -> str:
- """Get file extension based on media type."""
+ def _get_extension(
+ self,
+ media_type: str,
+ mime_type: str | None,
+ filename: str | None = None,
+ ) -> str:
+ """Get file extension based on media type or original filename."""
if mime_type:
ext_map = {
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -452,6 +784,14 @@ class TelegramChannel(BaseChannel):
}
if mime_type in ext_map:
return ext_map[mime_type]
-
+
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
- return type_map.get(media_type, "")
+ if ext := type_map.get(media_type, ""):
+ return ext
+
+ if filename:
+ from pathlib import Path
+
+ return "".join(Path(filename).suffixes)
+
+ return ""
diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py
new file mode 100644
index 0000000..2f24855
--- /dev/null
+++ b/nanobot/channels/wecom.py
@@ -0,0 +1,370 @@
+"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
+
+import asyncio
+import importlib.util
+import os
+from collections import OrderedDict
+from typing import Any
+
+from loguru import logger
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import Base
+from pydantic import Field
+
+WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
+
+class WecomConfig(Base):
+ """WeCom (Enterprise WeChat) AI Bot channel configuration."""
+
+ enabled: bool = False
+ bot_id: str = ""
+ secret: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ welcome_message: str = ""
+
+
+# Message type display mapping
+MSG_TYPE_MAP = {
+ "image": "[image]",
+ "voice": "[voice]",
+ "file": "[file]",
+ "mixed": "[mixed content]",
+}
+
+
+class WecomChannel(BaseChannel):
+ """
+ WeCom (Enterprise WeChat) channel using WebSocket long connection.
+
+ Uses WebSocket to receive events - no public IP or webhook required.
+
+ Requires:
+ - Bot ID and Secret from WeCom AI Bot platform
+ """
+
+ name = "wecom"
+ display_name = "WeCom"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return WecomConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = WecomConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: WecomConfig = config
+ self._client: Any = None
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+ self._loop: asyncio.AbstractEventLoop | None = None
+ self._generate_req_id = None
+ # Store frame headers for each chat to enable replies
+ self._chat_frames: dict[str, Any] = {}
+
+ async def start(self) -> None:
+ """Start the WeCom bot with WebSocket long connection."""
+ if not WECOM_AVAILABLE:
+ logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
+ return
+
+ if not self.config.bot_id or not self.config.secret:
+ logger.error("WeCom bot_id and secret not configured")
+ return
+
+ from wecom_aibot_sdk import WSClient, generate_req_id
+
+ self._running = True
+ self._loop = asyncio.get_running_loop()
+ self._generate_req_id = generate_req_id
+
+ # Create WebSocket client
+ self._client = WSClient({
+ "bot_id": self.config.bot_id,
+ "secret": self.config.secret,
+ "reconnect_interval": 1000,
+ "max_reconnect_attempts": -1, # Infinite reconnect
+ "heartbeat_interval": 30000,
+ })
+
+ # Register event handlers
+ self._client.on("connected", self._on_connected)
+ self._client.on("authenticated", self._on_authenticated)
+ self._client.on("disconnected", self._on_disconnected)
+ self._client.on("error", self._on_error)
+ self._client.on("message.text", self._on_text_message)
+ self._client.on("message.image", self._on_image_message)
+ self._client.on("message.voice", self._on_voice_message)
+ self._client.on("message.file", self._on_file_message)
+ self._client.on("message.mixed", self._on_mixed_message)
+ self._client.on("event.enter_chat", self._on_enter_chat)
+
+ logger.info("WeCom bot starting with WebSocket long connection")
+ logger.info("No public IP required - using WebSocket to receive events")
+
+ # Connect
+ await self._client.connect_async()
+
+ # Keep running until stopped
+ while self._running:
+ await asyncio.sleep(1)
+
+ async def stop(self) -> None:
+ """Stop the WeCom bot."""
+ self._running = False
+ if self._client:
+ await self._client.disconnect()
+ logger.info("WeCom bot stopped")
+
+ async def _on_connected(self, frame: Any) -> None:
+ """Handle WebSocket connected event."""
+ logger.info("WeCom WebSocket connected")
+
+ async def _on_authenticated(self, frame: Any) -> None:
+ """Handle authentication success event."""
+ logger.info("WeCom authenticated successfully")
+
+ async def _on_disconnected(self, frame: Any) -> None:
+ """Handle WebSocket disconnected event."""
+ reason = frame.body if hasattr(frame, 'body') else str(frame)
+ logger.warning("WeCom WebSocket disconnected: {}", reason)
+
+ async def _on_error(self, frame: Any) -> None:
+ """Handle error event."""
+ logger.error("WeCom error: {}", frame)
+
+ async def _on_text_message(self, frame: Any) -> None:
+ """Handle text message."""
+ await self._process_message(frame, "text")
+
+ async def _on_image_message(self, frame: Any) -> None:
+ """Handle image message."""
+ await self._process_message(frame, "image")
+
+ async def _on_voice_message(self, frame: Any) -> None:
+ """Handle voice message."""
+ await self._process_message(frame, "voice")
+
+ async def _on_file_message(self, frame: Any) -> None:
+ """Handle file message."""
+ await self._process_message(frame, "file")
+
+ async def _on_mixed_message(self, frame: Any) -> None:
+ """Handle mixed content message."""
+ await self._process_message(frame, "mixed")
+
+ async def _on_enter_chat(self, frame: Any) -> None:
+ """Handle enter_chat event (user opens chat with bot)."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
+
+ if chat_id and self.config.welcome_message:
+ await self._client.reply_welcome(frame, {
+ "msgtype": "text",
+ "text": {"content": self.config.welcome_message},
+ })
+ except Exception as e:
+ logger.error("Error handling enter_chat: {}", e)
+
+ async def _process_message(self, frame: Any, msg_type: str) -> None:
+ """Process incoming message and forward to bus."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ # Ensure body is a dict
+ if not isinstance(body, dict):
+ logger.warning("Invalid body type: {}", type(body))
+ return
+
+ # Extract message info
+ msg_id = body.get("msgid", "")
+ if not msg_id:
+ msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
+
+ # Deduplication check
+ if msg_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[msg_id] = None
+
+ # Trim cache
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
+ # Extract sender info from "from" field (SDK format)
+ from_info = body.get("from", {})
+ sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
+
+ # For single chat, chatid is the sender's userid
+ # For group chat, chatid is provided in body
+ chat_type = body.get("chattype", "single")
+ chat_id = body.get("chatid", sender_id)
+
+ content_parts = []
+
+ if msg_type == "text":
+ text = body.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+
+ elif msg_type == "image":
+ image_info = body.get("image", {})
+ file_url = image_info.get("url", "")
+ aes_key = image_info.get("aeskey", "")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "image")
+ if file_path:
+ filename = os.path.basename(file_path)
+ content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
+ else:
+ content_parts.append("[image: download failed]")
+ else:
+ content_parts.append("[image: download failed]")
+
+ elif msg_type == "voice":
+ voice_info = body.get("voice", {})
+ # Voice message already contains transcribed content from WeCom
+ voice_content = voice_info.get("content", "")
+ if voice_content:
+ content_parts.append(f"[voice] {voice_content}")
+ else:
+ content_parts.append("[voice]")
+
+ elif msg_type == "file":
+ file_info = body.get("file", {})
+ file_url = file_info.get("url", "")
+ aes_key = file_info.get("aeskey", "")
+ file_name = file_info.get("name", "unknown")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
+ if file_path:
+ content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+
+ elif msg_type == "mixed":
+ # Mixed content contains multiple message items
+ msg_items = body.get("mixed", {}).get("item", [])
+ for item in msg_items:
+ item_type = item.get("type", "")
+ if item_type == "text":
+ text = item.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
+
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
+
+ content = "\n".join(content_parts) if content_parts else ""
+
+ if not content:
+ return
+
+ # Store frame for this chat to enable replies
+ self._chat_frames[chat_id] = frame
+
+ # Forward to message bus
+ # Note: media paths are included in content for broader model compatibility
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content=content,
+ media=None,
+ metadata={
+ "message_id": msg_id,
+ "msg_type": msg_type,
+ "chat_type": chat_type,
+ }
+ )
+
+ except Exception as e:
+ logger.error("Error processing WeCom message: {}", e)
+
+ async def _download_and_save_media(
+ self,
+ file_url: str,
+ aes_key: str,
+ media_type: str,
+ filename: str | None = None,
+ ) -> str | None:
+ """
+ Download and decrypt media from WeCom.
+
+ Returns:
+ file_path or None if download failed
+ """
+ try:
+ data, fname = await self._client.download_file(file_url, aes_key)
+
+ if not data:
+ logger.warning("Failed to download media from WeCom")
+ return None
+
+ media_dir = get_media_dir("wecom")
+ if not filename:
+ filename = fname or f"{media_type}_{hash(file_url) % 100000}"
+ filename = os.path.basename(filename)
+
+ file_path = media_dir / filename
+ file_path.write_bytes(data)
+ logger.debug("Downloaded {} to {}", media_type, file_path)
+ return str(file_path)
+
+ except Exception as e:
+ logger.error("Error downloading media: {}", e)
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through WeCom."""
+ if not self._client:
+ logger.warning("WeCom client not initialized")
+ return
+
+ try:
+ content = msg.content.strip()
+ if not content:
+ return
+
+ # Get the stored frame for this chat
+ frame = self._chat_frames.get(msg.chat_id)
+ if not frame:
+ logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
+ return
+
+ # Use streaming reply for better UX
+ stream_id = self._generate_req_id("stream")
+
+ # Send as streaming message with finish=True
+ await self._client.reply_stream(
+ frame,
+ stream_id,
+ content,
+ finish=True,
+ )
+
+ logger.debug("WeCom message sent to {}", msg.chat_id)
+
+ except Exception as e:
+ logger.error("Error sending WeCom message: {}", e)
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index f5fb521..b689e30 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -2,42 +2,62 @@
import asyncio
import json
+import mimetypes
+from collections import OrderedDict
from typing import Any
from loguru import logger
+from pydantic import Field
+
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import WhatsAppConfig
+from nanobot.config.schema import Base
+
+
+class WhatsAppConfig(Base):
+ """WhatsApp channel configuration."""
+
+ enabled: bool = False
+ bridge_url: str = "ws://localhost:3001"
+ bridge_token: str = ""
+ allow_from: list[str] = Field(default_factory=list)
class WhatsAppChannel(BaseChannel):
"""
WhatsApp channel that connects to a Node.js bridge.
-
+
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
Communication between Python and Node.js is via WebSocket.
"""
-
+
name = "whatsapp"
-
- def __init__(self, config: WhatsAppConfig, bus: MessageBus):
+ display_name = "WhatsApp"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return WhatsAppConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = WhatsAppConfig.model_validate(config)
super().__init__(config, bus)
- self.config: WhatsAppConfig = config
self._ws = None
self._connected = False
-
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+
async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge."""
import websockets
-
+
bridge_url = self.config.bridge_url
-
+
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
-
+
self._running = True
-
+
while self._running:
try:
async with websockets.connect(bridge_url) as ws:
@@ -47,40 +67,40 @@ class WhatsAppChannel(BaseChannel):
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
self._connected = True
logger.info("Connected to WhatsApp bridge")
-
+
# Listen for messages
async for message in ws:
try:
await self._handle_bridge_message(message)
except Exception as e:
logger.error("Error handling bridge message: {}", e)
-
+
except asyncio.CancelledError:
break
except Exception as e:
self._connected = False
self._ws = None
logger.warning("WhatsApp bridge connection error: {}", e)
-
+
if self._running:
logger.info("Reconnecting in 5 seconds...")
await asyncio.sleep(5)
-
+
async def stop(self) -> None:
"""Stop the WhatsApp channel."""
self._running = False
self._connected = False
-
+
if self._ws:
await self._ws.close()
self._ws = None
-
+
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WhatsApp."""
if not self._ws or not self._connected:
logger.warning("WhatsApp bridge not connected")
return
-
+
try:
payload = {
"type": "send",
@@ -90,7 +110,7 @@ class WhatsAppChannel(BaseChannel):
await self._ws.send(json.dumps(payload, ensure_ascii=False))
except Exception as e:
logger.error("Error sending WhatsApp message: {}", e)
-
+
async def _handle_bridge_message(self, raw: str) -> None:
"""Handle a message from the bridge."""
try:
@@ -98,51 +118,71 @@ class WhatsAppChannel(BaseChannel):
except json.JSONDecodeError:
logger.warning("Invalid JSON from bridge: {}", raw[:100])
return
-
+
msg_type = data.get("type")
-
+
if msg_type == "message":
# Incoming message from WhatsApp
# Deprecated by whatsapp: old phone number style typically: @s.whatspp.net
pn = data.get("pn", "")
- # New LID sytle typically:
+ # New LID sytle typically:
sender = data.get("sender", "")
content = data.get("content", "")
-
+ message_id = data.get("id", "")
+
+ if message_id:
+ if message_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[message_id] = None
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
# Extract just the phone number or lid as chat_id
user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info("Sender {}", sender)
-
+
# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
-
+
+ # Extract media paths (images/documents/videos downloaded by the bridge)
+ media_paths = data.get("media") or []
+
+ # Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
+ if media_paths:
+ for p in media_paths:
+ mime, _ = mimetypes.guess_type(p)
+ media_type = "image" if mime and mime.startswith("image/") else "file"
+ media_tag = f"[{media_type}: {p}]"
+ content = f"{content}\n{media_tag}" if content else media_tag
+
await self._handle_message(
sender_id=sender_id,
chat_id=sender, # Use full LID for replies
content=content,
+ media=media_paths,
metadata={
- "message_id": data.get("id"),
+ "message_id": message_id,
"timestamp": data.get("timestamp"),
"is_group": data.get("isGroup", False)
}
)
-
+
elif msg_type == "status":
# Connection status update
status = data.get("status")
logger.info("WhatsApp status: {}", status)
-
+
if status == "connected":
self._connected = True
elif status == "disconnected":
self._connected = False
-
+
elif msg_type == "qr":
# QR code for authentication
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
-
+
elif msg_type == "error":
logger.error("WhatsApp bridge error: {}", data.get('error'))
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index acea9e2..47f7316 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -1,25 +1,41 @@
"""CLI commands for nanobot."""
import asyncio
+from contextlib import contextmanager, nullcontext
import os
-import signal
-from pathlib import Path
import select
+import signal
import sys
+from pathlib import Path
+from typing import Any
+
+# Force UTF-8 encoding for Windows console
+if sys.platform == "win32":
+ if sys.stdout.encoding != "utf-8":
+ os.environ["PYTHONIOENCODING"] = "utf-8"
+ # Re-open stdout/stderr with UTF-8 encoding
+ try:
+ sys.stdout.reconfigure(encoding="utf-8", errors="replace")
+ sys.stderr.reconfigure(encoding="utf-8", errors="replace")
+ except Exception:
+ pass
import typer
+from prompt_toolkit import print_formatted_text
+from prompt_toolkit import PromptSession
+from prompt_toolkit.formatted_text import ANSI, HTML
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.patch_stdout import patch_stdout
+from prompt_toolkit.application import run_in_terminal
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
from rich.text import Text
-from prompt_toolkit import PromptSession
-from prompt_toolkit.formatted_text import HTML
-from prompt_toolkit.history import FileHistory
-from prompt_toolkit.patch_stdout import patch_stdout
-
-from nanobot import __version__, __logo__
+from nanobot import __logo__, __version__
+from nanobot.config.paths import get_workspace_path
from nanobot.config.schema import Config
+from nanobot.utils.helpers import sync_workspace_templates
app = typer.Typer(
name="nanobot",
@@ -87,7 +103,9 @@ def _init_prompt_session() -> None:
except Exception:
pass
- history_file = Path.home() / ".nanobot" / "history" / "cli_history"
+ from nanobot.config.paths import get_cli_history_path
+
+ history_file = get_cli_history_path()
history_file.parent.mkdir(parents=True, exist_ok=True)
_PROMPT_SESSION = PromptSession(
@@ -97,8 +115,25 @@ def _init_prompt_session() -> None:
)
+def _make_console() -> Console:
+ return Console(file=sys.stdout)
+
+
+def _render_interactive_ansi(render_fn) -> str:
+ """Render Rich output to ANSI so prompt_toolkit can print it safely."""
+ ansi_console = Console(
+ force_terminal=True,
+ color_system=console.color_system or "standard",
+ width=console.width,
+ )
+ with ansi_console.capture() as capture:
+ render_fn(ansi_console)
+ return capture.get()
+
+
def _print_agent_response(response: str, render_markdown: bool) -> None:
"""Render assistant response with consistent terminal styling."""
+ console = _make_console()
content = response or ""
body = Markdown(content) if render_markdown else Text(content)
console.print()
@@ -107,6 +142,79 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
console.print()
+async def _print_interactive_line(text: str) -> None:
+ """Print async interactive updates with prompt_toolkit-safe Rich styling."""
+ def _write() -> None:
+ ansi = _render_interactive_ansi(
+ lambda c: c.print(f" [dim]↳ {text}[/dim]")
+ )
+ print_formatted_text(ANSI(ansi), end="")
+
+ await run_in_terminal(_write)
+
+
+async def _print_interactive_response(response: str, render_markdown: bool) -> None:
+ """Print async interactive replies with prompt_toolkit-safe Rich styling."""
+ def _write() -> None:
+ content = response or ""
+ ansi = _render_interactive_ansi(
+ lambda c: (
+ c.print(),
+ c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
+ c.print(Markdown(content) if render_markdown else Text(content)),
+ c.print(),
+ )
+ )
+ print_formatted_text(ANSI(ansi), end="")
+
+ await run_in_terminal(_write)
+
+
+class _ThinkingSpinner:
+ """Spinner wrapper with pause support for clean progress output."""
+
+ def __init__(self, enabled: bool):
+ self._spinner = console.status(
+ "[dim]nanobot is thinking...[/dim]", spinner="dots"
+ ) if enabled else None
+ self._active = False
+
+ def __enter__(self):
+ if self._spinner:
+ self._spinner.start()
+ self._active = True
+ return self
+
+ def __exit__(self, *exc):
+ self._active = False
+ if self._spinner:
+ self._spinner.stop()
+ return False
+
+ @contextmanager
+ def pause(self):
+ """Temporarily stop spinner while printing progress."""
+ if self._spinner and self._active:
+ self._spinner.stop()
+ try:
+ yield
+ finally:
+ if self._spinner and self._active:
+ self._spinner.start()
+
+
+def _print_cli_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
+ """Print a CLI progress line, pausing the spinner if needed."""
+ with thinking.pause() if thinking else nullcontext():
+ console.print(f" [dim]↳ {text}[/dim]")
+
+
+async def _print_interactive_progress_line(text: str, thinking: _ThinkingSpinner | None) -> None:
+ """Print an interactive progress line, pausing the spinner if needed."""
+ with thinking.pause() if thinking else nullcontext():
+ await _print_interactive_line(text)
+
+
def _is_exit_command(command: str) -> bool:
"""Return True when input should end interactive chat."""
return command.lower() in EXIT_COMMANDS
@@ -158,10 +266,9 @@ def onboard():
"""Initialize nanobot configuration and workspace."""
from nanobot.config.loader import get_config_path, load_config, save_config
from nanobot.config.schema import Config
- from nanobot.utils.helpers import get_workspace_path
-
+
config_path = get_config_path()
-
+
if config_path.exists():
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
@@ -177,18 +284,18 @@ def onboard():
else:
save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}")
-
- # Create workspace , use config workspace path if exists, otherwise use ~/.nanobot/workspace; try './workspace' will create a workspace
- # on the root dir of the project
+ console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
+
+ _onboard_plugins(config_path)
+
+ # Create workspace, preferring the configured workspace path.
workspace = get_workspace_path(config.workspace_path)
-
if not workspace.exists():
workspace.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] Created workspace at {workspace}")
-
- # Create default bootstrap files
- _create_workspace_templates(workspace)
-
+
+ sync_workspace_templates(workspace)
+
console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:")
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
@@ -197,44 +304,49 @@ def onboard():
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
+def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
+ """Recursively fill in missing values from defaults without overwriting user config."""
+ if not isinstance(existing, dict) or not isinstance(defaults, dict):
+ return existing
+
+ merged = dict(existing)
+ for key, value in defaults.items():
+ if key not in merged:
+ merged[key] = value
+ else:
+ merged[key] = _merge_missing_defaults(merged[key], value)
+ return merged
-def _create_workspace_templates(workspace: Path):
- """Create default workspace template files from bundled templates."""
- from importlib.resources import files as pkg_files
+def _onboard_plugins(config_path: Path) -> None:
+ """Inject default config for all discovered channels (built-in + plugins)."""
+ import json
- templates_dir = pkg_files("nanobot") / "templates"
+ from nanobot.channels.registry import discover_all
- for item in templates_dir.iterdir():
- if not item.name.endswith(".md"):
- continue
- dest = workspace / item.name
- if not dest.exists():
- dest.write_text(item.read_text(encoding="utf-8"), encoding="utf-8")
- console.print(f" [dim]Created {item.name}[/dim]")
+ all_channels = discover_all()
+ if not all_channels:
+ return
- memory_dir = workspace / "memory"
- memory_dir.mkdir(exist_ok=True)
+ with open(config_path, encoding="utf-8") as f:
+ data = json.load(f)
- memory_template = templates_dir / "memory" / "MEMORY.md"
- memory_file = memory_dir / "MEMORY.md"
- if not memory_file.exists():
- memory_file.write_text(memory_template.read_text(encoding="utf-8"), encoding="utf-8")
- console.print(" [dim]Created memory/MEMORY.md[/dim]")
+ channels = data.setdefault("channels", {})
+ for name, cls in all_channels.items():
+ if name not in channels:
+ channels[name] = cls.default_config()
+ else:
+ channels[name] = _merge_missing_defaults(channels[name], cls.default_config())
- history_file = memory_dir / "HISTORY.md"
- if not history_file.exists():
- history_file.write_text("", encoding="utf-8")
- console.print(" [dim]Created memory/HISTORY.md[/dim]")
-
- (workspace / "skills").mkdir(exist_ok=True)
+ with open(config_path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
- from nanobot.providers.litellm_provider import LiteLLMProvider
+ from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
- from nanobot.providers.custom_provider import CustomProvider
+ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
@@ -242,30 +354,80 @@ def _make_provider(config: Config):
# OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
- return OpenAICodexProvider(default_model=model)
-
+ provider = OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
- if provider_name == "custom":
- return CustomProvider(
+ elif provider_name == "custom":
+ from nanobot.providers.custom_provider import CustomProvider
+ provider = CustomProvider(
api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model,
+ extra_headers=p.extra_headers if p else None,
+ )
+ # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
+ elif provider_name == "azure_openai":
+ if not p or not p.api_key or not p.api_base:
+ console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
+ console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
+ console.print("Use the model field to specify the deployment name.")
+ raise typer.Exit(1)
+ provider = AzureOpenAIProvider(
+ api_key=p.api_key,
+ api_base=p.api_base,
+ default_model=model,
+ )
+ else:
+ from nanobot.providers.litellm_provider import LiteLLMProvider
+ from nanobot.providers.registry import find_by_name
+ spec = find_by_name(provider_name)
+ if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
+ console.print("[red]Error: No API key configured.[/red]")
+ console.print("Set one in ~/.nanobot/config.json under providers section")
+ raise typer.Exit(1)
+ provider = LiteLLMProvider(
+ api_key=p.api_key if p else None,
+ api_base=config.get_api_base(model),
+ default_model=model,
+ extra_headers=p.extra_headers if p else None,
+ provider_name=provider_name,
)
- from nanobot.providers.registry import find_by_name
- spec = find_by_name(provider_name)
- if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
- console.print("[red]Error: No API key configured.[/red]")
- console.print("Set one in ~/.nanobot/config.json under providers section")
- raise typer.Exit(1)
-
- return LiteLLMProvider(
- api_key=p.api_key if p else None,
- api_base=config.get_api_base(model),
- default_model=model,
- extra_headers=p.extra_headers if p else None,
- provider_name=provider_name,
+ defaults = config.agents.defaults
+ provider.generation = GenerationSettings(
+ temperature=defaults.temperature,
+ max_tokens=defaults.max_tokens,
+ reasoning_effort=defaults.reasoning_effort,
)
+ return provider
+
+
+def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
+ """Load config and optionally override the active workspace."""
+ from nanobot.config.loader import load_config, set_config_path
+
+ config_path = None
+ if config:
+ config_path = Path(config).expanduser().resolve()
+ if not config_path.exists():
+ console.print(f"[red]Error: Config file not found: {config_path}[/red]")
+ raise typer.Exit(1)
+ set_config_path(config_path)
+ console.print(f"[dim]Using config: {config_path}[/dim]")
+
+ loaded = load_config(config_path)
+ if workspace:
+ loaded.agents.defaults.workspace = workspace
+ return loaded
+
+
+def _print_deprecated_memory_window_notice(config: Config) -> None:
+ """Warn when running with old memoryWindow-only config."""
+ if config.agents.defaults.should_warn_deprecated_memory_window:
+ console.print(
+ "[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
+ "`contextWindowTokens`. `memoryWindow` is ignored; run "
+ "[cyan]nanobot onboard[/cyan] to refresh your config template."
+ )
# ============================================================================
@@ -275,45 +437,49 @@ def _make_provider(config: Config):
@app.command()
def gateway(
- port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
+ port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Start the nanobot gateway."""
- from nanobot.config.loader import load_config, get_data_dir
- from nanobot.bus.queue import MessageBus
from nanobot.agent.loop import AgentLoop
+ from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
- from nanobot.session.manager import SessionManager
+ from nanobot.config.paths import get_cron_dir
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
-
+ from nanobot.session.manager import SessionManager
+
if verbose:
import logging
logging.basicConfig(level=logging.DEBUG)
-
- console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
-
- config = load_config()
+
+ config = _load_runtime_config(config, workspace)
+ _print_deprecated_memory_window_notice(config)
+ port = port if port is not None else config.gateway.port
+
+ console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
+ sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
session_manager = SessionManager(config.workspace_path)
-
+
# Create cron service first (callback set after agent creation)
- cron_store_path = get_data_dir() / "cron" / "jobs.json"
+ cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path)
-
+
# Create agent with cron service
agent = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
- temperature=config.agents.defaults.temperature,
- max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
- memory_window=config.agents.defaults.memory_window,
- brave_api_key=config.tools.web.search.api_key or None,
+ context_window_tokens=config.agents.defaults.context_window_tokens,
+ web_search_config=config.tools.web.search,
+ web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace,
@@ -321,26 +487,53 @@ def gateway(
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
)
-
+
# Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent."""
- response = await agent.process_direct(
- job.payload.message,
- session_key=f"cron:{job.id}",
- channel=job.payload.channel or "cli",
- chat_id=job.payload.to or "direct",
+ from nanobot.agent.tools.cron import CronTool
+ from nanobot.agent.tools.message import MessageTool
+ from nanobot.utils.evaluator import evaluate_response
+
+ reminder_note = (
+ "[Scheduled Task] Timer finished.\n\n"
+ f"Task '{job.name}' has been triggered.\n"
+ f"Scheduled instruction: {job.payload.message}"
)
- if job.payload.deliver and job.payload.to:
- from nanobot.bus.events import OutboundMessage
- await bus.publish_outbound(OutboundMessage(
+
+ cron_tool = agent.tools.get("cron")
+ cron_token = None
+ if isinstance(cron_tool, CronTool):
+ cron_token = cron_tool.set_cron_context(True)
+ try:
+ response = await agent.process_direct(
+ reminder_note,
+ session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
- chat_id=job.payload.to,
- content=response or ""
- ))
+ chat_id=job.payload.to or "direct",
+ )
+ finally:
+ if isinstance(cron_tool, CronTool) and cron_token is not None:
+ cron_tool.reset_cron_context(cron_token)
+
+ message_tool = agent.tools.get("message")
+ if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
+ return response
+
+ if job.payload.deliver and job.payload.to and response:
+ should_notify = await evaluate_response(
+ response, job.payload.message, provider, agent.model,
+ )
+ if should_notify:
+ from nanobot.bus.events import OutboundMessage
+ await bus.publish_outbound(OutboundMessage(
+ channel=job.payload.channel or "cli",
+ chat_id=job.payload.to,
+ content=response,
+ ))
return response
cron.on_job = on_cron_job
-
+
# Create channel manager
channels = ChannelManager(config, bus)
@@ -394,18 +587,18 @@ def gateway(
interval_s=hb_cfg.interval_s,
enabled=hb_cfg.enabled,
)
-
+
if channels.enabled_channels:
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
else:
console.print("[yellow]Warning: No channels enabled[/yellow]")
-
+
cron_status = cron.status()
if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
-
+
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
-
+
async def run():
try:
await cron.start()
@@ -416,13 +609,17 @@ def gateway(
)
except KeyboardInterrupt:
console.print("\nShutting down...")
+ except Exception:
+ import traceback
+ console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
+ console.print(traceback.format_exc())
finally:
await agent.close_mcp()
heartbeat.stop()
cron.stop()
agent.stop()
await channels.stop_all()
-
+
asyncio.run(run())
@@ -437,54 +634,53 @@ def gateway(
def agent(
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
):
"""Interact with the agent directly."""
- from nanobot.config.loader import load_config, get_data_dir
- from nanobot.bus.queue import MessageBus
- from nanobot.agent.loop import AgentLoop
- from nanobot.cron.service import CronService
from loguru import logger
-
- config = load_config()
-
+
+ from nanobot.agent.loop import AgentLoop
+ from nanobot.bus.queue import MessageBus
+ from nanobot.config.paths import get_cron_dir
+ from nanobot.cron.service import CronService
+
+ config = _load_runtime_config(config, workspace)
+ _print_deprecated_memory_window_notice(config)
+ sync_workspace_templates(config.workspace_path)
+
bus = MessageBus()
provider = _make_provider(config)
# Create cron service for tool usage (no callback needed for CLI unless running)
- cron_store_path = get_data_dir() / "cron" / "jobs.json"
+ cron_store_path = get_cron_dir() / "jobs.json"
cron = CronService(cron_store_path)
if logs:
logger.enable("nanobot")
else:
logger.disable("nanobot")
-
+
agent_loop = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
- temperature=config.agents.defaults.temperature,
- max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
- memory_window=config.agents.defaults.memory_window,
- brave_api_key=config.tools.web.search.api_key or None,
+ context_window_tokens=config.agents.defaults.context_window_tokens,
+ web_search_config=config.tools.web.search,
+ web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
)
-
- # Show spinner when logs are off (no output to miss); skip when logs are on
- def _thinking_ctx():
- if logs:
- from contextlib import nullcontext
- return nullcontext()
- # Animated spinner is safe to use with prompt_toolkit input handling
- return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
+
+ # Shared reference for progress callbacks
+ _thinking: _ThinkingSpinner | None = None
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
ch = agent_loop.channels_config
@@ -492,13 +688,16 @@ def agent(
return
if ch and not tool_hint and not ch.send_progress:
return
- console.print(f" [dim]↳ {content}[/dim]")
+ _print_cli_progress_line(content, _thinking)
if message:
# Single message mode — direct call, no bus needed
async def run_once():
- with _thinking_ctx():
+ nonlocal _thinking
+ _thinking = _ThinkingSpinner(enabled=not logs)
+ with _thinking:
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
+ _thinking = None
_print_agent_response(response, render_markdown=markdown)
await agent_loop.close_mcp()
@@ -514,12 +713,21 @@ def agent(
else:
cli_channel, cli_chat_id = "cli", session_id
- def _exit_on_sigint(signum, frame):
+ def _handle_signal(signum, frame):
+ sig_name = signal.Signals(signum).name
_restore_terminal()
- console.print("\nGoodbye!")
- os._exit(0)
+ console.print(f"\nReceived {sig_name}, goodbye!")
+ sys.exit(0)
- signal.signal(signal.SIGINT, _exit_on_sigint)
+ signal.signal(signal.SIGINT, _handle_signal)
+ signal.signal(signal.SIGTERM, _handle_signal)
+ # SIGHUP is not available on Windows
+ if hasattr(signal, 'SIGHUP'):
+ signal.signal(signal.SIGHUP, _handle_signal)
+ # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
+ # SIGPIPE is not available on Windows
+ if hasattr(signal, 'SIGPIPE'):
+ signal.signal(signal.SIGPIPE, signal.SIG_IGN)
async def run_interactive():
bus_task = asyncio.create_task(agent_loop.run())
@@ -539,14 +747,15 @@ def agent(
elif ch and not is_tool_hint and not ch.send_progress:
pass
else:
- console.print(f" [dim]↳ {msg.content}[/dim]")
+ await _print_interactive_progress_line(msg.content, _thinking)
+
elif not turn_done.is_set():
if msg.content:
turn_response.append(msg.content)
turn_done.set()
elif msg.content:
- console.print()
- _print_agent_response(msg.content, render_markdown=markdown)
+ await _print_interactive_response(msg.content, render_markdown=markdown)
+
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
@@ -578,8 +787,11 @@ def agent(
content=user_input,
))
- with _thinking_ctx():
+ nonlocal _thinking
+ _thinking = _ThinkingSpinner(enabled=not logs)
+ with _thinking:
await turn_done.wait()
+ _thinking = None
if turn_response:
_print_agent_response(turn_response[0], render_markdown=markdown)
@@ -612,6 +824,7 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status")
def channels_status():
"""Show channel status."""
+ from nanobot.channels.registry import discover_all
from nanobot.config.loader import load_config
config = load_config()
@@ -619,85 +832,19 @@ def channels_status():
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green")
- table.add_column("Configuration", style="yellow")
- # WhatsApp
- wa = config.channels.whatsapp
- table.add_row(
- "WhatsApp",
- "✓" if wa.enabled else "✗",
- wa.bridge_url
- )
-
- dc = config.channels.discord
- table.add_row(
- "Discord",
- "✓" if dc.enabled else "✗",
- dc.gateway_url
- )
-
- # Feishu
- fs = config.channels.feishu
- fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
- table.add_row(
- "Feishu",
- "✓" if fs.enabled else "✗",
- fs_config
- )
-
- # Mochat
- mc = config.channels.mochat
- mc_base = mc.base_url or "[dim]not configured[/dim]"
- table.add_row(
- "Mochat",
- "✓" if mc.enabled else "✗",
- mc_base
- )
-
- # Telegram
- tg = config.channels.telegram
- tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
- table.add_row(
- "Telegram",
- "✓" if tg.enabled else "✗",
- tg_config
- )
-
- # Slack
- slack = config.channels.slack
- slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
- table.add_row(
- "Slack",
- "✓" if slack.enabled else "✗",
- slack_config
- )
-
- # DingTalk
- dt = config.channels.dingtalk
- dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
- table.add_row(
- "DingTalk",
- "✓" if dt.enabled else "✗",
- dt_config
- )
-
- # QQ
- qq = config.channels.qq
- qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
- table.add_row(
- "QQ",
- "✓" if qq.enabled else "✗",
- qq_config
- )
-
- # Email
- em = config.channels.email
- em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
- table.add_row(
- "Email",
- "✓" if em.enabled else "✗",
- em_config
- )
+ for name, cls in sorted(discover_all().items()):
+ section = getattr(config.channels, name, None)
+ if section is None:
+ enabled = False
+ elif isinstance(section, dict):
+ enabled = section.get("enabled", False)
+ else:
+ enabled = getattr(section, "enabled", False)
+ table.add_row(
+ cls.display_name,
+ "[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
+ )
console.print(table)
@@ -706,294 +853,136 @@ def _get_bridge_dir() -> Path:
"""Get the bridge directory, setting it up if needed."""
import shutil
import subprocess
-
+
# User's bridge location
- user_bridge = Path.home() / ".nanobot" / "bridge"
-
+ from nanobot.config.paths import get_bridge_install_dir
+
+ user_bridge = get_bridge_install_dir()
+
# Check if already built
if (user_bridge / "dist" / "index.js").exists():
return user_bridge
-
+
# Check for npm
- if not shutil.which("npm"):
+ npm_path = shutil.which("npm")
+ if not npm_path:
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
raise typer.Exit(1)
-
+
# Find source bridge: first check package data, then source dir
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
-
+
source = None
if (pkg_bridge / "package.json").exists():
source = pkg_bridge
elif (src_bridge / "package.json").exists():
source = src_bridge
-
+
if not source:
console.print("[red]Bridge source not found.[/red]")
console.print("Try reinstalling: pip install --force-reinstall nanobot")
raise typer.Exit(1)
-
+
console.print(f"{__logo__} Setting up bridge...")
-
+
# Copy to user directory
user_bridge.parent.mkdir(parents=True, exist_ok=True)
if user_bridge.exists():
shutil.rmtree(user_bridge)
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
-
+
# Install and build
try:
console.print(" Installing dependencies...")
- subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
-
+ subprocess.run([npm_path, "install"], cwd=user_bridge, check=True, capture_output=True)
+
console.print(" Building...")
- subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
-
+ subprocess.run([npm_path, "run", "build"], cwd=user_bridge, check=True, capture_output=True)
+
console.print("[green]✓[/green] Bridge ready\n")
except subprocess.CalledProcessError as e:
console.print(f"[red]Build failed: {e}[/red]")
if e.stderr:
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
raise typer.Exit(1)
-
+
return user_bridge
@channels_app.command("login")
def channels_login():
"""Link device via QR code."""
+ import shutil
import subprocess
+
from nanobot.config.loader import load_config
-
+ from nanobot.config.paths import get_runtime_subdir
+
config = load_config()
bridge_dir = _get_bridge_dir()
-
+
console.print(f"{__logo__} Starting bridge...")
console.print("Scan the QR code to connect.\n")
-
+
env = {**os.environ}
- if config.channels.whatsapp.bridge_token:
- env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
-
+ wa_cfg = getattr(config.channels, "whatsapp", None) or {}
+ bridge_token = wa_cfg.get("bridgeToken", "") if isinstance(wa_cfg, dict) else getattr(wa_cfg, "bridge_token", "")
+ if bridge_token:
+ env["BRIDGE_TOKEN"] = bridge_token
+ env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
+
+ npm_path = shutil.which("npm")
+ if not npm_path:
+ console.print("[red]npm not found. Please install Node.js.[/red]")
+ raise typer.Exit(1)
+
try:
- subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
+ subprocess.run([npm_path, "start"], cwd=bridge_dir, check=True, env=env)
except subprocess.CalledProcessError as e:
console.print(f"[red]Bridge failed: {e}[/red]")
- except FileNotFoundError:
- console.print("[red]npm not found. Please install Node.js.[/red]")
# ============================================================================
-# Cron Commands
+# Plugin Commands
# ============================================================================
-cron_app = typer.Typer(help="Manage scheduled tasks")
-app.add_typer(cron_app, name="cron")
+plugins_app = typer.Typer(help="Manage channel plugins")
+app.add_typer(plugins_app, name="plugins")
-@cron_app.command("list")
-def cron_list(
- all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
-):
- """List scheduled jobs."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- jobs = service.list_jobs(include_disabled=all)
-
- if not jobs:
- console.print("No scheduled jobs.")
- return
-
- table = Table(title="Scheduled Jobs")
- table.add_column("ID", style="cyan")
- table.add_column("Name")
- table.add_column("Schedule")
- table.add_column("Status")
- table.add_column("Next Run")
-
- import time
- from datetime import datetime as _dt
- from zoneinfo import ZoneInfo
- for job in jobs:
- # Format schedule
- if job.schedule.kind == "every":
- sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
- elif job.schedule.kind == "cron":
- sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
- else:
- sched = "one-time"
-
- # Format next run
- next_run = ""
- if job.state.next_run_at_ms:
- ts = job.state.next_run_at_ms / 1000
- try:
- tz = ZoneInfo(job.schedule.tz) if job.schedule.tz else None
- next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
- except Exception:
- next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
-
- status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
-
- table.add_row(job.id, job.name, sched, status, next_run)
-
- console.print(table)
-
-
-@cron_app.command("add")
-def cron_add(
- name: str = typer.Option(..., "--name", "-n", help="Job name"),
- message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
- every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
- cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
- tz: str | None = typer.Option(None, "--tz", help="IANA timezone for cron (e.g. 'America/Vancouver')"),
- at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
- deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
- to: str = typer.Option(None, "--to", help="Recipient for delivery"),
- channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
-):
- """Add a scheduled job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
- from nanobot.cron.types import CronSchedule
-
- if tz and not cron_expr:
- console.print("[red]Error: --tz can only be used with --cron[/red]")
- raise typer.Exit(1)
-
- # Determine schedule type
- if every:
- schedule = CronSchedule(kind="every", every_ms=every * 1000)
- elif cron_expr:
- schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
- elif at:
- import datetime
- dt = datetime.datetime.fromisoformat(at)
- schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
- else:
- console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
- raise typer.Exit(1)
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- try:
- job = service.add_job(
- name=name,
- schedule=schedule,
- message=message,
- deliver=deliver,
- to=to,
- channel=channel,
- )
- except ValueError as e:
- console.print(f"[red]Error: {e}[/red]")
- raise typer.Exit(1) from e
-
- console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
-
-
-@cron_app.command("remove")
-def cron_remove(
- job_id: str = typer.Argument(..., help="Job ID to remove"),
-):
- """Remove a scheduled job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- if service.remove_job(job_id):
- console.print(f"[green]✓[/green] Removed job {job_id}")
- else:
- console.print(f"[red]Job {job_id} not found[/red]")
-
-
-@cron_app.command("enable")
-def cron_enable(
- job_id: str = typer.Argument(..., help="Job ID"),
- disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
-):
- """Enable or disable a job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- job = service.enable_job(job_id, enabled=not disable)
- if job:
- status = "disabled" if disable else "enabled"
- console.print(f"[green]✓[/green] Job '{job.name}' {status}")
- else:
- console.print(f"[red]Job {job_id} not found[/red]")
-
-
-@cron_app.command("run")
-def cron_run(
- job_id: str = typer.Argument(..., help="Job ID to run"),
- force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
-):
- """Manually run a job."""
- from loguru import logger
- from nanobot.config.loader import load_config, get_data_dir
- from nanobot.cron.service import CronService
- from nanobot.cron.types import CronJob
- from nanobot.bus.queue import MessageBus
- from nanobot.agent.loop import AgentLoop
- logger.disable("nanobot")
+@plugins_app.command("list")
+def plugins_list():
+ """List all discovered channels (built-in and plugins)."""
+ from nanobot.channels.registry import discover_all, discover_channel_names
+ from nanobot.config.loader import load_config
config = load_config()
- provider = _make_provider(config)
- bus = MessageBus()
- agent_loop = AgentLoop(
- bus=bus,
- provider=provider,
- workspace=config.workspace_path,
- model=config.agents.defaults.model,
- temperature=config.agents.defaults.temperature,
- max_tokens=config.agents.defaults.max_tokens,
- max_iterations=config.agents.defaults.max_tool_iterations,
- memory_window=config.agents.defaults.memory_window,
- brave_api_key=config.tools.web.search.api_key or None,
- exec_config=config.tools.exec,
- restrict_to_workspace=config.tools.restrict_to_workspace,
- mcp_servers=config.tools.mcp_servers,
- channels_config=config.channels,
- )
+ builtin_names = set(discover_channel_names())
+ all_channels = discover_all()
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
+ table = Table(title="Channel Plugins")
+ table.add_column("Name", style="cyan")
+ table.add_column("Source", style="magenta")
+ table.add_column("Enabled", style="green")
- result_holder = []
-
- async def on_job(job: CronJob) -> str | None:
- response = await agent_loop.process_direct(
- job.payload.message,
- session_key=f"cron:{job.id}",
- channel=job.payload.channel or "cli",
- chat_id=job.payload.to or "direct",
+ for name in sorted(all_channels):
+ cls = all_channels[name]
+ source = "builtin" if name in builtin_names else "plugin"
+ section = getattr(config.channels, name, None)
+ if section is None:
+ enabled = False
+ elif isinstance(section, dict):
+ enabled = section.get("enabled", False)
+ else:
+ enabled = getattr(section, "enabled", False)
+ table.add_row(
+ cls.display_name,
+ source,
+ "[green]yes[/green]" if enabled else "[dim]no[/dim]",
)
- result_holder.append(response)
- return response
- service.on_job = on_job
-
- async def run():
- return await service.run_job(job_id, force=force)
-
- if asyncio.run(run()):
- console.print("[green]✓[/green] Job executed")
- if result_holder:
- _print_agent_response(result_holder[0], render_markdown=True)
- else:
- console.print(f"[red]Failed to run job {job_id}[/red]")
+ console.print(table)
# ============================================================================
@@ -1004,7 +993,7 @@ def cron_run(
@app.command()
def status():
"""Show nanobot status."""
- from nanobot.config.loader import load_config, get_config_path
+ from nanobot.config.loader import get_config_path, load_config
config_path = get_config_path()
config = load_config()
@@ -1019,7 +1008,7 @@ def status():
from nanobot.providers.registry import PROVIDERS
console.print(f"Model: {config.agents.defaults.model}")
-
+
# Check API keys from registry
for spec in PROVIDERS:
p = getattr(config.providers, spec.name, None)
diff --git a/nanobot/config/__init__.py b/nanobot/config/__init__.py
index 88e8e9b..e2c24f8 100644
--- a/nanobot/config/__init__.py
+++ b/nanobot/config/__init__.py
@@ -1,6 +1,30 @@
"""Configuration module for nanobot."""
-from nanobot.config.loader import load_config, get_config_path
+from nanobot.config.loader import get_config_path, load_config
+from nanobot.config.paths import (
+ get_bridge_install_dir,
+ get_cli_history_path,
+ get_cron_dir,
+ get_data_dir,
+ get_legacy_sessions_dir,
+ get_logs_dir,
+ get_media_dir,
+ get_runtime_subdir,
+ get_workspace_path,
+)
from nanobot.config.schema import Config
-__all__ = ["Config", "load_config", "get_config_path"]
+__all__ = [
+ "Config",
+ "load_config",
+ "get_config_path",
+ "get_data_dir",
+ "get_runtime_subdir",
+ "get_media_dir",
+ "get_cron_dir",
+ "get_logs_dir",
+ "get_workspace_path",
+ "get_cli_history_path",
+ "get_bridge_install_dir",
+ "get_legacy_sessions_dir",
+]
diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py
index c789efd..7d309e5 100644
--- a/nanobot/config/loader.py
+++ b/nanobot/config/loader.py
@@ -6,17 +6,23 @@ from pathlib import Path
from nanobot.config.schema import Config
+# Global variable to store current config path (for multi-instance support)
+_current_config_path: Path | None = None
+
+
+def set_config_path(path: Path) -> None:
+ """Set the current config path (used to derive data directory)."""
+ global _current_config_path
+ _current_config_path = path
+
+
def get_config_path() -> Path:
- """Get the default configuration file path."""
+ """Get the configuration file path."""
+ if _current_config_path:
+ return _current_config_path
return Path.home() / ".nanobot" / "config.json"
-def get_data_dir() -> Path:
- """Get the nanobot data directory."""
- from nanobot.utils.helpers import get_data_path
- return get_data_path()
-
-
def load_config(config_path: Path | None = None) -> Config:
"""
Load configuration from file or create default.
diff --git a/nanobot/config/paths.py b/nanobot/config/paths.py
new file mode 100644
index 0000000..f4dfbd9
--- /dev/null
+++ b/nanobot/config/paths.py
@@ -0,0 +1,55 @@
+"""Runtime path helpers derived from the active config context."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from nanobot.config.loader import get_config_path
+from nanobot.utils.helpers import ensure_dir
+
+
+def get_data_dir() -> Path:
+ """Return the instance-level runtime data directory."""
+ return ensure_dir(get_config_path().parent)
+
+
+def get_runtime_subdir(name: str) -> Path:
+ """Return a named runtime subdirectory under the instance data dir."""
+ return ensure_dir(get_data_dir() / name)
+
+
+def get_media_dir(channel: str | None = None) -> Path:
+ """Return the media directory, optionally namespaced per channel."""
+ base = get_runtime_subdir("media")
+ return ensure_dir(base / channel) if channel else base
+
+
+def get_cron_dir() -> Path:
+ """Return the cron storage directory."""
+ return get_runtime_subdir("cron")
+
+
+def get_logs_dir() -> Path:
+ """Return the logs directory."""
+ return get_runtime_subdir("logs")
+
+
+def get_workspace_path(workspace: str | None = None) -> Path:
+ """Resolve and ensure the agent workspace path."""
+ path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
+ return ensure_dir(path)
+
+
+def get_cli_history_path() -> Path:
+ """Return the shared CLI history file path."""
+ return Path.home() / ".nanobot" / "history" / "cli_history"
+
+
+def get_bridge_install_dir() -> Path:
+ """Return the shared WhatsApp bridge installation directory."""
+ return Path.home() / ".nanobot" / "bridge"
+
+
+def get_legacy_sessions_dir() -> Path:
+ """Return the legacy global session directory used for migration fallback."""
+ return Path.home() / ".nanobot" / "sessions"
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index 215f38d..033fb63 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -1,7 +1,9 @@
"""Configuration schema using Pydantic."""
from pathlib import Path
-from pydantic import BaseModel, Field, ConfigDict
+from typing import Literal
+
+from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings
@@ -12,173 +14,17 @@ class Base(BaseModel):
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
-class WhatsAppConfig(Base):
- """WhatsApp channel configuration."""
-
- enabled: bool = False
- bridge_url: str = "ws://localhost:3001"
- bridge_token: str = "" # Shared token for bridge auth (optional, recommended)
- allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
-
-
-class TelegramConfig(Base):
- """Telegram channel configuration."""
-
- enabled: bool = False
- token: str = "" # Bot token from @BotFather
- allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
- proxy: str | None = None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
- reply_to_message: bool = False # If true, bot replies quote the original message
-
-
-class FeishuConfig(Base):
- """Feishu/Lark channel configuration using WebSocket long connection."""
-
- enabled: bool = False
- app_id: str = "" # App ID from Feishu Open Platform
- app_secret: str = "" # App Secret from Feishu Open Platform
- encrypt_key: str = "" # Encrypt Key for event subscription (optional)
- verification_token: str = "" # Verification Token for event subscription (optional)
- allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
-
-
-class DingTalkConfig(Base):
- """DingTalk channel configuration using Stream mode."""
-
- enabled: bool = False
- client_id: str = "" # AppKey
- client_secret: str = "" # AppSecret
- allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids
-
-
-class DiscordConfig(Base):
- """Discord channel configuration."""
-
- enabled: bool = False
- token: str = "" # Bot token from Discord Developer Portal
- allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
- gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
- intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
-
-
-class EmailConfig(Base):
- """Email channel configuration (IMAP inbound + SMTP outbound)."""
-
- enabled: bool = False
- consent_granted: bool = False # Explicit owner permission to access mailbox data
-
- # IMAP (receive)
- imap_host: str = ""
- imap_port: int = 993
- imap_username: str = ""
- imap_password: str = ""
- imap_mailbox: str = "INBOX"
- imap_use_ssl: bool = True
-
- # SMTP (send)
- smtp_host: str = ""
- smtp_port: int = 587
- smtp_username: str = ""
- smtp_password: str = ""
- smtp_use_tls: bool = True
- smtp_use_ssl: bool = False
- from_address: str = ""
-
- # Behavior
- auto_reply_enabled: bool = True # If false, inbound email is read but no automatic reply is sent
- poll_interval_seconds: int = 30
- mark_seen: bool = True
- max_body_chars: int = 12000
- subject_prefix: str = "Re: "
- allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses
-
-
-class MochatMentionConfig(Base):
- """Mochat mention behavior configuration."""
-
- require_in_groups: bool = False
-
-
-class MochatGroupRule(Base):
- """Mochat per-group mention requirement."""
-
- require_mention: bool = False
-
-
-class MochatConfig(Base):
- """Mochat channel configuration."""
-
- enabled: bool = False
- base_url: str = "https://mochat.io"
- socket_url: str = ""
- socket_path: str = "/socket.io"
- socket_disable_msgpack: bool = False
- socket_reconnect_delay_ms: int = 1000
- socket_max_reconnect_delay_ms: int = 10000
- socket_connect_timeout_ms: int = 10000
- refresh_interval_ms: int = 30000
- watch_timeout_ms: int = 25000
- watch_limit: int = 100
- retry_delay_ms: int = 500
- max_retry_attempts: int = 0 # 0 means unlimited retries
- claw_token: str = ""
- agent_user_id: str = ""
- sessions: list[str] = Field(default_factory=list)
- panels: list[str] = Field(default_factory=list)
- allow_from: list[str] = Field(default_factory=list)
- mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
- groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
- reply_delay_mode: str = "non-mention" # off | non-mention
- reply_delay_ms: int = 120000
-
-
-class SlackDMConfig(Base):
- """Slack DM policy configuration."""
-
- enabled: bool = True
- policy: str = "open" # "open" or "allowlist"
- allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs
-
-
-class SlackConfig(Base):
- """Slack channel configuration."""
-
- enabled: bool = False
- mode: str = "socket" # "socket" supported
- webhook_path: str = "/slack/events"
- bot_token: str = "" # xoxb-...
- app_token: str = "" # xapp-...
- user_token_read_only: bool = True
- reply_in_thread: bool = True
- react_emoji: str = "eyes"
- group_policy: str = "mention" # "mention", "open", "allowlist"
- group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
- dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
-
-
-class QQConfig(Base):
- """QQ channel configuration using botpy SDK."""
-
- enabled: bool = False
- app_id: str = "" # 机器人 ID (AppID) from q.qq.com
- secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
- allow_from: list[str] = Field(default_factory=list) # Allowed user openids (empty = public access)
-
-
class ChannelsConfig(Base):
- """Configuration for chat channels."""
+ """Configuration for chat channels.
- send_progress: bool = True # stream agent's text progress to the channel
+ Built-in and plugin channel configs are stored as extra fields (dicts).
+ Each channel parses its own config in __init__.
+ """
+
+ model_config = ConfigDict(extra="allow")
+
+ send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
- whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
- telegram: TelegramConfig = Field(default_factory=TelegramConfig)
- discord: DiscordConfig = Field(default_factory=DiscordConfig)
- feishu: FeishuConfig = Field(default_factory=FeishuConfig)
- mochat: MochatConfig = Field(default_factory=MochatConfig)
- dingtalk: DingTalkConfig = Field(default_factory=DingTalkConfig)
- email: EmailConfig = Field(default_factory=EmailConfig)
- slack: SlackConfig = Field(default_factory=SlackConfig)
- qq: QQConfig = Field(default_factory=QQConfig)
class AgentDefaults(Base):
@@ -186,10 +32,21 @@ class AgentDefaults(Base):
workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5"
+ provider: str = (
+ "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+ )
max_tokens: int = 8192
+ context_window_tokens: int = 65_536
temperature: float = 0.1
max_tool_iterations: int = 40
- memory_window: int = 100
+ # Deprecated compatibility field: accepted from old configs but ignored at runtime.
+ memory_window: int | None = Field(default=None, exclude=True)
+ reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
+
+ @property
+ def should_warn_deprecated_memory_window(self) -> bool:
+ """Return True when old memoryWindow is present without contextWindowTokens."""
+ return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
class AgentsConfig(Base):
@@ -210,20 +67,25 @@ class ProvidersConfig(Base):
"""Configuration for LLM providers."""
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
+ azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
- dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
+ dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
+ ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
- siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) API gateway
- volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) API gateway
+ siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
+ volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
+ volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
+ byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
+ byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
@@ -246,13 +108,18 @@ class GatewayConfig(Base):
class WebSearchConfig(Base):
"""Web search tool configuration."""
- api_key: str = "" # Brave Search API key
+ provider: str = "brave" # brave, tavily, duckduckgo, searxng, jina
+ api_key: str = ""
+ base_url: str = "" # SearXNG base URL
max_results: int = 5
class WebToolsConfig(Base):
"""Web tools configuration."""
+ proxy: str | None = (
+ None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ )
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
@@ -260,18 +127,20 @@ class ExecToolConfig(Base):
"""Shell exec tool configuration."""
timeout: int = 60
+ path_append: str = ""
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
+ type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
command: str = "" # Stdio: command to run (e.g. "npx")
args: list[str] = Field(default_factory=list) # Stdio: command arguments
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
- url: str = "" # HTTP: streamable HTTP endpoint URL
- headers: dict[str, str] = Field(default_factory=dict) # HTTP: Custom HTTP Headers
- tool_timeout: int = 30 # Seconds before a tool call is cancelled
-
+ url: str = "" # HTTP/SSE: endpoint URL
+ headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
+ tool_timeout: int = 30 # seconds before a tool call is cancelled
+ enabled_tools: list[str] = Field(default_factory=lambda: ["*"]) # Only register these tools; accepts raw MCP names or wrapped mcp__ names; ["*"] = all tools; [] = no tools
class ToolsConfig(Base):
"""Tools configuration."""
@@ -296,10 +165,17 @@ class Config(BaseSettings):
"""Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser()
- def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
+ def _match_provider(
+ self, model: str | None = None
+ ) -> tuple["ProviderConfig | None", str | None]:
"""Match provider config and its registry name. Returns (config, spec_name)."""
from nanobot.providers.registry import PROVIDERS
+ forced = self.agents.defaults.provider
+ if forced != "auto":
+ p = getattr(self.providers, forced, None)
+ return (p, forced) if p else (None, None)
+
model_lower = (model or self.agents.defaults.model).lower()
model_normalized = model_lower.replace("-", "_")
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
@@ -313,16 +189,34 @@ class Config(BaseSettings):
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and model_prefix and normalized_prefix == spec.name:
- if spec.is_oauth or p.api_key:
+ if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name
# Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and any(_kw_matches(kw) for kw in spec.keywords):
- if spec.is_oauth or p.api_key:
+ if spec.is_oauth or spec.is_local or p.api_key:
return p, spec.name
+ # Fallback: configured local providers can route models without
+ # provider-specific keywords (for example plain "llama3.2" on Ollama).
+ # Prefer providers whose detect_by_base_keyword matches the configured api_base
+ # (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
+ local_fallback: tuple[ProviderConfig, str] | None = None
+ for spec in PROVIDERS:
+ if not spec.is_local:
+ continue
+ p = getattr(self.providers, spec.name, None)
+ if not (p and p.api_base):
+ continue
+ if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
+ return p, spec.name
+ if local_fallback is None:
+ local_fallback = (p, spec.name)
+ if local_fallback:
+ return local_fallback
+
# Fallback: gateways first, then others (follows registry order)
# OAuth providers are NOT valid fallbacks — they require explicit model selection
for spec in PROVIDERS:
@@ -349,7 +243,7 @@ class Config(BaseSettings):
return p.api_key if p else None
def get_api_base(self, model: str | None = None) -> str | None:
- """Get API base URL for the given model. Applies default URLs for known gateways."""
+ """Get API base URL for the given model. Applies default URLs for gateway/local providers."""
from nanobot.providers.registry import find_by_name
p, name = self._match_provider(model)
@@ -360,7 +254,7 @@ class Config(BaseSettings):
# to avoid polluting the global litellm.api_base.
if name:
spec = find_by_name(name)
- if spec and spec.is_gateway and spec.default_api_base:
+ if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
return spec.default_api_base
return None
diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py
index 6889a10..1ed71f0 100644
--- a/nanobot/cron/service.py
+++ b/nanobot/cron/service.py
@@ -21,17 +21,18 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
"""Compute next run time in ms."""
if schedule.kind == "at":
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
-
+
if schedule.kind == "every":
if not schedule.every_ms or schedule.every_ms <= 0:
return None
# Next interval from now
return now_ms + schedule.every_ms
-
+
if schedule.kind == "cron" and schedule.expr:
try:
- from croniter import croniter
from zoneinfo import ZoneInfo
+
+ from croniter import croniter
# Use caller-provided reference time for deterministic scheduling
base_time = now_ms / 1000
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
@@ -41,7 +42,7 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
return int(next_dt.timestamp() * 1000)
except Exception:
return None
-
+
return None
@@ -61,23 +62,29 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
class CronService:
"""Service for managing and executing scheduled jobs."""
-
+
def __init__(
self,
store_path: Path,
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
):
self.store_path = store_path
- self.on_job = on_job # Callback to execute job, returns response text
+ self.on_job = on_job
self._store: CronStore | None = None
+ self._last_mtime: float = 0.0
self._timer_task: asyncio.Task | None = None
self._running = False
-
+
def _load_store(self) -> CronStore:
- """Load jobs from disk."""
+ """Load jobs from disk. Reloads automatically if file was modified externally."""
+ if self._store and self.store_path.exists():
+ mtime = self.store_path.stat().st_mtime
+ if mtime != self._last_mtime:
+ logger.info("Cron: jobs.json modified externally, reloading")
+ self._store = None
if self._store:
return self._store
-
+
if self.store_path.exists():
try:
data = json.loads(self.store_path.read_text(encoding="utf-8"))
@@ -117,16 +124,16 @@ class CronService:
self._store = CronStore()
else:
self._store = CronStore()
-
+
return self._store
-
+
def _save_store(self) -> None:
"""Save jobs to disk."""
if not self._store:
return
-
+
self.store_path.parent.mkdir(parents=True, exist_ok=True)
-
+
data = {
"version": self._store.version,
"jobs": [
@@ -161,8 +168,9 @@ class CronService:
for j in self._store.jobs
]
}
-
+
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
+ self._last_mtime = self.store_path.stat().st_mtime
async def start(self) -> None:
"""Start the cron service."""
@@ -172,14 +180,14 @@ class CronService:
self._save_store()
self._arm_timer()
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
-
+
def stop(self) -> None:
"""Stop the cron service."""
self._running = False
if self._timer_task:
self._timer_task.cancel()
self._timer_task = None
-
+
def _recompute_next_runs(self) -> None:
"""Recompute next run times for all enabled jobs."""
if not self._store:
@@ -188,73 +196,74 @@ class CronService:
for job in self._store.jobs:
if job.enabled:
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
-
+
def _get_next_wake_ms(self) -> int | None:
"""Get the earliest next run time across all jobs."""
if not self._store:
return None
- times = [j.state.next_run_at_ms for j in self._store.jobs
+ times = [j.state.next_run_at_ms for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms]
return min(times) if times else None
-
+
def _arm_timer(self) -> None:
"""Schedule the next timer tick."""
if self._timer_task:
self._timer_task.cancel()
-
+
next_wake = self._get_next_wake_ms()
if not next_wake or not self._running:
return
-
+
delay_ms = max(0, next_wake - _now_ms())
delay_s = delay_ms / 1000
-
+
async def tick():
await asyncio.sleep(delay_s)
if self._running:
await self._on_timer()
-
+
self._timer_task = asyncio.create_task(tick())
-
+
async def _on_timer(self) -> None:
"""Handle timer tick - run due jobs."""
+ self._load_store()
if not self._store:
return
-
+
now = _now_ms()
due_jobs = [
j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
]
-
+
for job in due_jobs:
await self._execute_job(job)
-
+
self._save_store()
self._arm_timer()
-
+
async def _execute_job(self, job: CronJob) -> None:
"""Execute a single job."""
start_ms = _now_ms()
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
-
+
try:
response = None
if self.on_job:
response = await self.on_job(job)
-
+
job.state.last_status = "ok"
job.state.last_error = None
logger.info("Cron: job '{}' completed", job.name)
-
+
except Exception as e:
job.state.last_status = "error"
job.state.last_error = str(e)
logger.error("Cron: job '{}' failed: {}", job.name, e)
-
+
job.state.last_run_at_ms = start_ms
job.updated_at_ms = _now_ms()
-
+
# Handle one-shot jobs
if job.schedule.kind == "at":
if job.delete_after_run:
@@ -265,15 +274,15 @@ class CronService:
else:
# Compute next run
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
-
+
# ========== Public API ==========
-
+
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
"""List all jobs."""
store = self._load_store()
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
-
+
def add_job(
self,
name: str,
@@ -288,7 +297,7 @@ class CronService:
store = self._load_store()
_validate_schedule_for_add(schedule)
now = _now_ms()
-
+
job = CronJob(
id=str(uuid.uuid4())[:8],
name=name,
@@ -306,28 +315,28 @@ class CronService:
updated_at_ms=now,
delete_after_run=delete_after_run,
)
-
+
store.jobs.append(job)
self._save_store()
self._arm_timer()
-
+
logger.info("Cron: added job '{}' ({})", name, job.id)
return job
-
+
def remove_job(self, job_id: str) -> bool:
"""Remove a job by ID."""
store = self._load_store()
before = len(store.jobs)
store.jobs = [j for j in store.jobs if j.id != job_id]
removed = len(store.jobs) < before
-
+
if removed:
self._save_store()
self._arm_timer()
logger.info("Cron: removed job {}", job_id)
-
+
return removed
-
+
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
"""Enable or disable a job."""
store = self._load_store()
@@ -343,7 +352,7 @@ class CronService:
self._arm_timer()
return job
return None
-
+
async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job."""
store = self._load_store()
@@ -356,7 +365,7 @@ class CronService:
self._arm_timer()
return True
return False
-
+
def status(self) -> dict:
"""Get service status."""
store = self._load_store()
diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py
index e534017..7be81ff 100644
--- a/nanobot/heartbeat/service.py
+++ b/nanobot/heartbeat/service.py
@@ -87,10 +87,13 @@ class HeartbeatService:
Returns (action, tasks) where action is 'skip' or 'run'.
"""
- response = await self.provider.chat(
+ from nanobot.utils.helpers import current_time_str
+
+ response = await self.provider.chat_with_retry(
messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": (
+ f"Current Time: {current_time_str()}\n\n"
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
f"{content}"
)},
@@ -139,6 +142,8 @@ class HeartbeatService:
async def _tick(self) -> None:
"""Execute a single heartbeat tick."""
+ from nanobot.utils.evaluator import evaluate_response
+
content = self._read_heartbeat_file()
if not content:
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
@@ -156,9 +161,16 @@ class HeartbeatService:
logger.info("Heartbeat: tasks found, executing...")
if self.on_execute:
response = await self.on_execute(tasks)
- if response and self.on_notify:
- logger.info("Heartbeat: completed, delivering response")
- await self.on_notify(response)
+
+ if response:
+ should_notify = await evaluate_response(
+ response, tasks, self.provider, self.model,
+ )
+ if should_notify and self.on_notify:
+ logger.info("Heartbeat: completed, delivering response")
+ await self.on_notify(response)
+ else:
+ logger.info("Heartbeat: silenced by post-run evaluation")
except Exception:
logger.exception("Heartbeat execution failed")
diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py
index b2bb2b9..5bd06f9 100644
--- a/nanobot/providers/__init__.py
+++ b/nanobot/providers/__init__.py
@@ -3,5 +3,6 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
-__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
+__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py
new file mode 100644
index 0000000..05fbac4
--- /dev/null
+++ b/nanobot/providers/azure_openai_provider.py
@@ -0,0 +1,213 @@
+"""Azure OpenAI provider implementation with API version 2024-10-21."""
+
+from __future__ import annotations
+
+import uuid
+from typing import Any
+from urllib.parse import urljoin
+
+import httpx
+import json_repair
+
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
+
+_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+
+
+class AzureOpenAIProvider(LLMProvider):
+ """
+ Azure OpenAI provider with API version 2024-10-21 compliance.
+
+ Features:
+ - Hardcoded API version 2024-10-21
+ - Uses model field as Azure deployment name in URL path
+ - Uses api-key header instead of Authorization Bearer
+ - Uses max_completion_tokens instead of max_tokens
+ - Direct HTTP calls, bypasses LiteLLM
+ """
+
+ def __init__(
+ self,
+ api_key: str = "",
+ api_base: str = "",
+ default_model: str = "gpt-5.2-chat",
+ ):
+ super().__init__(api_key, api_base)
+ self.default_model = default_model
+ self.api_version = "2024-10-21"
+
+ # Validate required parameters
+ if not api_key:
+ raise ValueError("Azure OpenAI api_key is required")
+ if not api_base:
+ raise ValueError("Azure OpenAI api_base is required")
+
+ # Ensure api_base ends with /
+ if not api_base.endswith('/'):
+ api_base += '/'
+ self.api_base = api_base
+
+ def _build_chat_url(self, deployment_name: str) -> str:
+ """Build the Azure OpenAI chat completions URL."""
+ # Azure OpenAI URL format:
+ # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
+ base_url = self.api_base
+ if not base_url.endswith('/'):
+ base_url += '/'
+
+ url = urljoin(
+ base_url,
+ f"openai/deployments/{deployment_name}/chat/completions"
+ )
+ return f"{url}?api-version={self.api_version}"
+
+ def _build_headers(self) -> dict[str, str]:
+ """Build headers for Azure OpenAI API with api-key header."""
+ return {
+ "Content-Type": "application/json",
+ "api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
+ "x-session-affinity": uuid.uuid4().hex, # For cache locality
+ }
+
+ @staticmethod
+ def _supports_temperature(
+ deployment_name: str,
+ reasoning_effort: str | None = None,
+ ) -> bool:
+ """Return True when temperature is likely supported for this deployment."""
+ if reasoning_effort:
+ return False
+ name = deployment_name.lower()
+ return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
+
+ def _prepare_request_payload(
+ self,
+ deployment_name: str,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+ max_tokens: int = 4096,
+ temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
+ payload: dict[str, Any] = {
+ "messages": self._sanitize_request_messages(
+ self._sanitize_empty_content(messages),
+ _AZURE_MSG_KEYS,
+ ),
+ "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
+ }
+
+ if self._supports_temperature(deployment_name, reasoning_effort):
+ payload["temperature"] = temperature
+
+ if reasoning_effort:
+ payload["reasoning_effort"] = reasoning_effort
+
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = tool_choice or "auto"
+
+ return payload
+
+ async def chat(
+ self,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ max_tokens: int = 4096,
+ temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
+ ) -> LLMResponse:
+ """
+ Send a chat completion request to Azure OpenAI.
+
+ Args:
+ messages: List of message dicts with 'role' and 'content'.
+ tools: Optional list of tool definitions in OpenAI format.
+ model: Model identifier (used as deployment name).
+ max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
+ temperature: Sampling temperature.
+ reasoning_effort: Optional reasoning effort parameter.
+
+ Returns:
+ LLMResponse with content and/or tool calls.
+ """
+ deployment_name = model or self.default_model
+ url = self._build_chat_url(deployment_name)
+ headers = self._build_headers()
+ payload = self._prepare_request_payload(
+ deployment_name, messages, tools, max_tokens, temperature, reasoning_effort,
+ tool_choice=tool_choice,
+ )
+
+ try:
+ async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
+ response = await client.post(url, headers=headers, json=payload)
+ if response.status_code != 200:
+ return LLMResponse(
+ content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
+ finish_reason="error",
+ )
+
+ response_data = response.json()
+ return self._parse_response(response_data)
+
+ except Exception as e:
+ return LLMResponse(
+ content=f"Error calling Azure OpenAI: {repr(e)}",
+ finish_reason="error",
+ )
+
+ def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
+ """Parse Azure OpenAI response into our standard format."""
+ try:
+ choice = response["choices"][0]
+ message = choice["message"]
+
+ tool_calls = []
+ if message.get("tool_calls"):
+ for tc in message["tool_calls"]:
+ # Parse arguments from JSON string if needed
+ args = tc["function"]["arguments"]
+ if isinstance(args, str):
+ args = json_repair.loads(args)
+
+ tool_calls.append(
+ ToolCallRequest(
+ id=tc["id"],
+ name=tc["function"]["name"],
+ arguments=args,
+ )
+ )
+
+ usage = {}
+ if response.get("usage"):
+ usage_data = response["usage"]
+ usage = {
+ "prompt_tokens": usage_data.get("prompt_tokens", 0),
+ "completion_tokens": usage_data.get("completion_tokens", 0),
+ "total_tokens": usage_data.get("total_tokens", 0),
+ }
+
+ reasoning_content = message.get("reasoning_content") or None
+
+ return LLMResponse(
+ content=message.get("content"),
+ tool_calls=tool_calls,
+ finish_reason=choice.get("finish_reason", "stop"),
+ usage=usage,
+ reasoning_content=reasoning_content,
+ )
+
+ except (KeyError, IndexError) as e:
+ return LLMResponse(
+ content=f"Error parsing Azure OpenAI response: {str(e)}",
+ finish_reason="error",
+ )
+
+ def get_default_model(self) -> str:
+ """Get the default model (also used as default deployment name)."""
+ return self.default_model
\ No newline at end of file
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index eb1599a..8b6956c 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -1,9 +1,13 @@
"""Base LLM provider interface."""
+import asyncio
+import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
+from loguru import logger
+
@dataclass
class ToolCallRequest:
@@ -11,6 +15,24 @@ class ToolCallRequest:
id: str
name: str
arguments: dict[str, Any]
+ provider_specific_fields: dict[str, Any] | None = None
+ function_provider_specific_fields: dict[str, Any] | None = None
+
+ def to_openai_tool_call(self) -> dict[str, Any]:
+ """Serialize to an OpenAI-style tool_call payload."""
+ tool_call = {
+ "id": self.id,
+ "type": "function",
+ "function": {
+ "name": self.name,
+ "arguments": json.dumps(self.arguments, ensure_ascii=False),
+ },
+ }
+ if self.provider_specific_fields:
+ tool_call["provider_specific_fields"] = self.provider_specific_fields
+ if self.function_provider_specific_fields:
+ tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
+ return tool_call
@dataclass
@@ -21,6 +43,7 @@ class LLMResponse:
finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
+ thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@property
def has_tool_calls(self) -> bool:
@@ -28,6 +51,21 @@ class LLMResponse:
return len(self.tool_calls) > 0
+@dataclass(frozen=True)
+class GenerationSettings:
+ """Default generation parameters for LLM calls.
+
+ Stored on the provider so every call site inherits the same defaults
+ without having to pass temperature / max_tokens / reasoning_effort
+ through every layer. Individual call sites can still override by
+ passing explicit keyword arguments to chat() / chat_with_retry().
+ """
+
+ temperature: float = 0.7
+ max_tokens: int = 4096
+ reasoning_effort: str | None = None
+
+
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
@@ -35,10 +73,37 @@ class LLMProvider(ABC):
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
-
+
+ _CHAT_RETRY_DELAYS = (1, 2, 4)
+ _TRANSIENT_ERROR_MARKERS = (
+ "429",
+ "rate limit",
+ "500",
+ "502",
+ "503",
+ "504",
+ "overloaded",
+ "timeout",
+ "timed out",
+ "connection",
+ "server error",
+ "temporarily unavailable",
+ )
+ _IMAGE_UNSUPPORTED_MARKERS = (
+ "image_url is only supported",
+ "does not support image",
+ "images are not supported",
+ "image input is not supported",
+ "image_url is not supported",
+ "unsupported image input",
+ )
+
+ _SENTINEL = object()
+
def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key
self.api_base = api_base
+ self.generation: GenerationSettings = GenerationSettings()
@staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -77,9 +142,29 @@ class LLMProvider(ABC):
result.append(clean)
continue
+ if isinstance(content, dict):
+ clean = dict(msg)
+ clean["content"] = [content]
+ result.append(clean)
+ continue
+
result.append(msg)
return result
-
+
+ @staticmethod
+ def _sanitize_request_messages(
+ messages: list[dict[str, Any]],
+ allowed_keys: frozenset[str],
+ ) -> list[dict[str, Any]]:
+ """Keep only provider-safe message keys and normalize assistant content."""
+ sanitized = []
+ for msg in messages:
+ clean = {k: v for k, v in msg.items() if k in allowed_keys}
+ if clean.get("role") == "assistant" and "content" not in clean:
+ clean["content"] = None
+ sanitized.append(clean)
+ return sanitized
+
@abstractmethod
async def chat(
self,
@@ -88,6 +173,8 @@ class LLMProvider(ABC):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request.
@@ -98,12 +185,104 @@ class LLMProvider(ABC):
model: Model identifier (provider-specific).
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
+ tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
Returns:
LLMResponse with content and/or tool calls.
"""
pass
-
+
+ @classmethod
+ def _is_transient_error(cls, content: str | None) -> bool:
+ err = (content or "").lower()
+ return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
+
+ @classmethod
+ def _is_image_unsupported_error(cls, content: str | None) -> bool:
+ err = (content or "").lower()
+ return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS)
+
+ @staticmethod
+ def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
+ """Replace image_url blocks with text placeholder. Returns None if no images found."""
+ found = False
+ result = []
+ for msg in messages:
+ content = msg.get("content")
+ if isinstance(content, list):
+ new_content = []
+ for b in content:
+ if isinstance(b, dict) and b.get("type") == "image_url":
+ new_content.append({"type": "text", "text": "[image omitted]"})
+ found = True
+ else:
+ new_content.append(b)
+ result.append({**msg, "content": new_content})
+ else:
+ result.append(msg)
+ return result if found else None
+
+ async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
+ """Call chat() and convert unexpected exceptions to error responses."""
+ try:
+ return await self.chat(**kwargs)
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc:
+ return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
+
+ async def chat_with_retry(
+ self,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+ model: str | None = None,
+ max_tokens: object = _SENTINEL,
+ temperature: object = _SENTINEL,
+ reasoning_effort: object = _SENTINEL,
+ tool_choice: str | dict[str, Any] | None = None,
+ ) -> LLMResponse:
+ """Call chat() with retry on transient provider failures.
+
+ Parameters default to ``self.generation`` when not explicitly passed,
+ so callers no longer need to thread temperature / max_tokens /
+ reasoning_effort through every layer.
+ """
+ if max_tokens is self._SENTINEL:
+ max_tokens = self.generation.max_tokens
+ if temperature is self._SENTINEL:
+ temperature = self.generation.temperature
+ if reasoning_effort is self._SENTINEL:
+ reasoning_effort = self.generation.reasoning_effort
+
+ kw: dict[str, Any] = dict(
+ messages=messages, tools=tools, model=model,
+ max_tokens=max_tokens, temperature=temperature,
+ reasoning_effort=reasoning_effort, tool_choice=tool_choice,
+ )
+
+ for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
+ response = await self._safe_chat(**kw)
+
+ if response.finish_reason != "error":
+ return response
+
+ if not self._is_transient_error(response.content):
+ if self._is_image_unsupported_error(response.content):
+ stripped = self._strip_image_content(messages)
+ if stripped is not None:
+ logger.warning("Model does not support image input, retrying without images")
+ return await self._safe_chat(**{**kw, "messages": stripped})
+ return response
+
+ logger.warning(
+ "LLM transient error (attempt {}/{}), retrying in {}s: {}",
+ attempt, len(self._CHAT_RETRY_DELAYS), delay,
+ (response.content or "")[:120].lower(),
+ )
+ await asyncio.sleep(delay)
+
+ return await self._safe_chat(**kw)
+
@abstractmethod
def get_default_model(self) -> str:
"""Get the default model for this provider."""
diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py
index a578d14..e177e55 100644
--- a/nanobot/providers/custom_provider.py
+++ b/nanobot/providers/custom_provider.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import uuid
from typing import Any
import json_repair
@@ -12,21 +13,41 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class CustomProvider(LLMProvider):
- def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
+ def __init__(
+ self,
+ api_key: str = "no-key",
+ api_base: str = "http://localhost:8000/v1",
+ default_model: str = "default",
+ extra_headers: dict[str, str] | None = None,
+ ):
super().__init__(api_key, api_base)
self.default_model = default_model
- self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
+ # Keep affinity stable for this provider instance to improve backend cache locality,
+ # while still letting users attach provider-specific headers for custom gateways.
+ default_headers = {
+ "x-session-affinity": uuid.uuid4().hex,
+ **(extra_headers or {}),
+ }
+ self._client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ default_headers=default_headers,
+ )
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
- model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
+ model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"messages": self._sanitize_empty_content(messages),
"max_tokens": max(1, max_tokens),
"temperature": temperature,
}
+ if reasoning_effort:
+ kwargs["reasoning_effort"] = reasoning_effort
if tools:
- kwargs.update(tools=tools, tool_choice="auto")
+ kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
try:
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py
index 7402a2b..d14e4c0 100644
--- a/nanobot/providers/litellm_provider.py
+++ b/nanobot/providers/litellm_provider.py
@@ -1,19 +1,27 @@
"""LiteLLM provider implementation for multi-provider support."""
-import json
-import json_repair
+import hashlib
import os
+import secrets
+import string
from typing import Any
+import json_repair
import litellm
from litellm import acompletion
+from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.registry import find_by_model, find_gateway
+# Standard chat-completion message keys.
+_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
+_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
+_ALNUM = string.ascii_letters + string.digits
-# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers.
-_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
+def _short_tool_id() -> str:
+ """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
+ return "".join(secrets.choice(_ALNUM) for _ in range(9))
class LiteLLMProvider(LLMProvider):
@@ -24,10 +32,10 @@ class LiteLLMProvider(LLMProvider):
a unified interface. Provider-specific logic is driven by the registry
(see providers/registry.py) — no if-elif chains needed here.
"""
-
+
def __init__(
- self,
- api_key: str | None = None,
+ self,
+ api_key: str | None = None,
api_base: str | None = None,
default_model: str = "anthropic/claude-opus-4-5",
extra_headers: dict[str, str] | None = None,
@@ -36,24 +44,26 @@ class LiteLLMProvider(LLMProvider):
super().__init__(api_key, api_base)
self.default_model = default_model
self.extra_headers = extra_headers or {}
-
+
# Detect gateway / local deployment.
# provider_name (from config key) is the primary signal;
# api_key / api_base are fallback for auto-detection.
self._gateway = find_gateway(provider_name, api_key, api_base)
-
+
# Configure environment variables
if api_key:
self._setup_env(api_key, api_base, default_model)
-
+
if api_base:
litellm.api_base = api_base
-
+
# Disable LiteLLM logging noise
litellm.suppress_debug_info = True
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
litellm.drop_params = True
-
+
+ self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))
+
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
"""Set environment variables based on detected provider."""
spec = self._gateway or find_by_model(model)
@@ -77,18 +87,17 @@ class LiteLLMProvider(LLMProvider):
resolved = env_val.replace("{api_key}", api_key)
resolved = resolved.replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
-
+
def _resolve_model(self, model: str) -> str:
"""Resolve model name by applying provider/gateway prefixes."""
if self._gateway:
- # Gateway mode: apply gateway prefix, skip provider-specific prefixes
prefix = self._gateway.litellm_prefix
if self._gateway.strip_model_prefix:
model = model.split("/")[-1]
- if prefix and not model.startswith(f"{prefix}/"):
+ if prefix:
model = f"{prefix}/{model}"
return model
-
+
# Standard mode: auto-prefix for known providers
spec = find_by_model(model)
if spec and spec.litellm_prefix:
@@ -107,7 +116,7 @@ class LiteLLMProvider(LLMProvider):
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"
-
+
def _supports_cache_control(self, model: str) -> bool:
"""Return True when the provider supports cache_control on content blocks."""
if self._gateway is not None:
@@ -150,17 +159,52 @@ class LiteLLMProvider(LLMProvider):
if pattern in model_lower:
kwargs.update(overrides)
return
-
+
@staticmethod
- def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
+ """Return provider-specific extra keys to preserve in request messages."""
+ spec = find_by_model(original_model) or find_by_model(resolved_model)
+ if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
+ return _ANTHROPIC_EXTRA_KEYS
+ return frozenset()
+
+ @staticmethod
+ def _normalize_tool_call_id(tool_call_id: Any) -> Any:
+ """Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
+ if not isinstance(tool_call_id, str):
+ return tool_call_id
+ if len(tool_call_id) == 9 and tool_call_id.isalnum():
+ return tool_call_id
+ return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
+
+ @staticmethod
+ def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
- sanitized = []
- for msg in messages:
- clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
- # Strict providers require "content" even when assistant only has tool_calls
- if clean.get("role") == "assistant" and "content" not in clean:
- clean["content"] = None
- sanitized.append(clean)
+ allowed = _ALLOWED_MSG_KEYS | extra_keys
+ sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
+ id_map: dict[str, str] = {}
+
+ def map_id(value: Any) -> Any:
+ if not isinstance(value, str):
+ return value
+ return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
+
+ for clean in sanitized:
+ # Keep assistant tool_calls[].id and tool tool_call_id in sync after
+ # shortening, otherwise strict providers reject the broken linkage.
+ if isinstance(clean.get("tool_calls"), list):
+ normalized_tool_calls = []
+ for tc in clean["tool_calls"]:
+ if not isinstance(tc, dict):
+ normalized_tool_calls.append(tc)
+ continue
+ tc_clean = dict(tc)
+ tc_clean["id"] = map_id(tc_clean.get("id"))
+ normalized_tool_calls.append(tc_clean)
+ clean["tool_calls"] = normalized_tool_calls
+
+ if "tool_call_id" in clean and clean["tool_call_id"]:
+ clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
async def chat(
@@ -170,22 +214,25 @@ class LiteLLMProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
"""
Send a chat completion request via LiteLLM.
-
+
Args:
messages: List of message dicts with 'role' and 'content'.
tools: Optional list of tool definitions in OpenAI format.
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
max_tokens: Maximum tokens in response.
temperature: Sampling temperature.
-
+
Returns:
LLMResponse with content and/or tool calls.
"""
original_model = model or self.default_model
model = self._resolve_model(original_model)
+ extra_msg_keys = self._extra_msg_keys(original_model, model)
if self._supports_cache_control(original_model):
messages, tools = self._apply_cache_control(messages, tools)
@@ -193,33 +240,43 @@ class LiteLLMProvider(LLMProvider):
# Clamp max_tokens to at least 1 — negative or zero values cause
# LiteLLM to reject the request with "max_tokens must be at least 1".
max_tokens = max(1, max_tokens)
-
+
kwargs: dict[str, Any] = {
"model": model,
- "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
+ "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
"max_tokens": max_tokens,
"temperature": temperature,
}
-
+
+ if self._gateway:
+ kwargs.update(self._gateway.litellm_kwargs)
+
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
self._apply_model_overrides(model, kwargs)
-
+
+ if self._langsmith_enabled:
+ kwargs.setdefault("callbacks", []).append("langsmith")
+
# Pass api_key directly — more reliable than env vars alone
if self.api_key:
kwargs["api_key"] = self.api_key
-
+
# Pass api_base for custom endpoints
if self.api_base:
kwargs["api_base"] = self.api_base
-
+
# Pass extra headers (e.g. APP-Code for AiHubMix)
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
+ if reasoning_effort:
+ kwargs["reasoning_effort"] = reasoning_effort
+ kwargs["drop_params"] = True
+
if tools:
kwargs["tools"] = tools
- kwargs["tool_choice"] = "auto"
-
+ kwargs["tool_choice"] = tool_choice or "auto"
+
try:
response = await acompletion(**kwargs)
return self._parse_response(response)
@@ -229,26 +286,50 @@ class LiteLLMProvider(LLMProvider):
content=f"Error calling LLM: {str(e)}",
finish_reason="error",
)
-
+
def _parse_response(self, response: Any) -> LLMResponse:
"""Parse LiteLLM response into our standard format."""
choice = response.choices[0]
message = choice.message
-
+ content = message.content
+ finish_reason = choice.finish_reason
+
+ # Some providers (e.g. GitHub Copilot) split content and tool_calls
+ # across multiple choices. Merge them so tool_calls are not lost.
+ raw_tool_calls = []
+ for ch in response.choices:
+ msg = ch.message
+ if hasattr(msg, "tool_calls") and msg.tool_calls:
+ raw_tool_calls.extend(msg.tool_calls)
+ if ch.finish_reason in ("tool_calls", "stop"):
+ finish_reason = ch.finish_reason
+ if not content and msg.content:
+ content = msg.content
+
+ if len(response.choices) > 1:
+ logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
+ len(response.choices), len(raw_tool_calls))
+
tool_calls = []
- if hasattr(message, "tool_calls") and message.tool_calls:
- for tc in message.tool_calls:
- # Parse arguments from JSON string if needed
- args = tc.function.arguments
- if isinstance(args, str):
- args = json_repair.loads(args)
-
- tool_calls.append(ToolCallRequest(
- id=tc.id,
- name=tc.function.name,
- arguments=args,
- ))
-
+ for tc in raw_tool_calls:
+ # Parse arguments from JSON string if needed
+ args = tc.function.arguments
+ if isinstance(args, str):
+ args = json_repair.loads(args)
+
+ provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
+ function_provider_specific_fields = (
+ getattr(tc.function, "provider_specific_fields", None) or None
+ )
+
+ tool_calls.append(ToolCallRequest(
+ id=_short_tool_id(),
+ name=tc.function.name,
+ arguments=args,
+ provider_specific_fields=provider_specific_fields,
+ function_provider_specific_fields=function_provider_specific_fields,
+ ))
+
usage = {}
if hasattr(response, "usage") and response.usage:
usage = {
@@ -256,17 +337,19 @@ class LiteLLMProvider(LLMProvider):
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
-
+
reasoning_content = getattr(message, "reasoning_content", None) or None
-
+ thinking_blocks = getattr(message, "thinking_blocks", None) or None
+
return LLMResponse(
- content=message.content,
+ content=content,
tool_calls=tool_calls,
- finish_reason=choice.finish_reason or "stop",
+ finish_reason=finish_reason or "stop",
usage=usage,
reasoning_content=reasoning_content,
+ thinking_blocks=thinking_blocks,
)
-
+
def get_default_model(self) -> str:
"""Get the default model."""
return self.default_model
diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py
index fa28593..c8f2155 100644
--- a/nanobot/providers/openai_codex_provider.py
+++ b/nanobot/providers/openai_codex_provider.py
@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator
import httpx
from loguru import logger
-
from oauth_cli_kit import get_token as get_codex_token
+
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
@@ -31,6 +31,8 @@ class OpenAICodexProvider(LLMProvider):
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
+ reasoning_effort: str | None = None,
+ tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
model = model or self.default_model
system_prompt, input_items = _convert_messages(messages)
@@ -47,10 +49,13 @@ class OpenAICodexProvider(LLMProvider):
"text": {"verbosity": "medium"},
"include": ["reasoning.encrypted_content"],
"prompt_cache_key": _prompt_cache_key(messages),
- "tool_choice": "auto",
+ "tool_choice": tool_choice or "auto",
"parallel_tool_calls": True,
}
+ if reasoning_effort:
+ body["reasoning"] = {"effort": reasoning_effort}
+
if tools:
body["tools"] = _convert_tools(tools)
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index 2766929..42c1d24 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -12,7 +12,7 @@ Every entry writes out all fields so you can copy-paste as a template.
from __future__ import annotations
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Any
@@ -26,33 +26,34 @@ class ProviderSpec:
"""
# identity
- name: str # config field name, e.g. "dashscope"
- keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
- env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
- display_name: str = "" # shown in `nanobot status`
+ name: str # config field name, e.g. "dashscope"
+ keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
+ env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
+ display_name: str = "" # shown in `nanobot status`
# model prefixing
- litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
- skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
+ litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
+ skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
env_extras: tuple[tuple[str, str], ...] = ()
# gateway / local detection
- is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
- is_local: bool = False # local deployment (vLLM, Ollama)
- detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
- detect_by_base_keyword: str = "" # match substring in api_base URL
- default_api_base: str = "" # fallback base URL
+ is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
+ is_local: bool = False # local deployment (vLLM, Ollama)
+ detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
+ detect_by_base_keyword: str = "" # match substring in api_base URL
+ default_api_base: str = "" # fallback base URL
# gateway behavior
- strip_model_prefix: bool = False # strip "provider/" before re-prefixing
+ strip_model_prefix: bool = False # strip "provider/" before re-prefixing
+ litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
- is_oauth: bool = False # if True, uses OAuth flow instead of API key
+ is_oauth: bool = False # if True, uses OAuth flow instead of API key
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
is_direct: bool = False
@@ -70,7 +71,6 @@ class ProviderSpec:
# ---------------------------------------------------------------------------
PROVIDERS: tuple[ProviderSpec, ...] = (
-
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
ProviderSpec(
name="custom",
@@ -81,16 +81,24 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
is_direct=True,
),
+ # === Azure OpenAI (direct API calls with API version 2024-10-21) =====
+ ProviderSpec(
+ name="azure_openai",
+ keywords=("azure", "azure-openai"),
+ env_key="",
+ display_name="Azure OpenAI",
+ litellm_prefix="",
+ is_direct=True,
+ ),
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
-
# OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
- litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
+ litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -102,16 +110,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
-
# AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
- env_key="OPENAI_API_KEY", # OpenAI-compatible
+ env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix",
- litellm_prefix="openai", # → openai/{model}
+ litellm_prefix="openai", # → openai/{model}
skip_prefixes=(),
env_extras=(),
is_gateway=True,
@@ -119,10 +126,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_key_prefix="",
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
- strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
+ strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(),
),
-
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
ProviderSpec(
name="siliconflow",
@@ -141,7 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
- # VolcEngine (火山引擎): OpenAI-compatible gateway
+ # VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec(
name="volcengine",
keywords=("volcengine", "volces", "ark"),
@@ -159,8 +165,62 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
- # === Standard providers (matched by model-name keywords) ===============
+ # VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
+ ProviderSpec(
+ name="volcengine_coding_plan",
+ keywords=("volcengine-plan",),
+ env_key="OPENAI_API_KEY",
+ display_name="VolcEngine Coding Plan",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="",
+ default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+ # BytePlus: VolcEngine international, pay-per-use models
+ ProviderSpec(
+ name="byteplus",
+ keywords=("byteplus",),
+ env_key="OPENAI_API_KEY",
+ display_name="BytePlus",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="bytepluses",
+ default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+
+ # BytePlus Coding Plan: same key as byteplus
+ ProviderSpec(
+ name="byteplus_coding_plan",
+ keywords=("byteplus-plan",),
+ env_key="OPENAI_API_KEY",
+ display_name="BytePlus Coding Plan",
+ litellm_prefix="volcengine",
+ skip_prefixes=(),
+ env_extras=(),
+ is_gateway=True,
+ is_local=False,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="",
+ default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
+ strip_model_prefix=True,
+ model_overrides=(),
+ ),
+
+
+ # === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
name="anthropic",
@@ -179,7 +239,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
supports_prompt_caching=True,
),
-
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec(
name="openai",
@@ -197,14 +256,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# OpenAI Codex: uses OAuth, not API key.
ProviderSpec(
name="openai_codex",
- keywords=("openai-codex", "codex"),
- env_key="", # OAuth-based, no API key
+ keywords=("openai-codex",),
+ env_key="", # OAuth-based, no API key
display_name="OpenAI Codex",
- litellm_prefix="", # Not routed through LiteLLM
+ litellm_prefix="", # Not routed through LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
@@ -214,16 +272,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://chatgpt.com/backend-api",
strip_model_prefix=False,
model_overrides=(),
- is_oauth=True, # OAuth-based authentication
+ is_oauth=True, # OAuth-based authentication
),
-
# Github Copilot: uses OAuth, not API key.
ProviderSpec(
name="github_copilot",
keywords=("github_copilot", "copilot"),
- env_key="", # OAuth-based, no API key
+ env_key="", # OAuth-based, no API key
display_name="Github Copilot",
- litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
+ litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
skip_prefixes=("github_copilot/",),
env_extras=(),
is_gateway=False,
@@ -233,17 +290,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
- is_oauth=True, # OAuth-based authentication
+ is_oauth=True, # OAuth-based authentication
),
-
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
env_key="DEEPSEEK_API_KEY",
display_name="DeepSeek",
- litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
- skip_prefixes=("deepseek/",), # avoid double-prefix
+ litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
+ skip_prefixes=("deepseek/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -253,15 +309,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
- litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
- skip_prefixes=("gemini/",), # avoid double-prefix
+ litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
+ skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -271,7 +326,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +334,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
- litellm_prefix="zai", # glm-4 → zai/glm-4
+ litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
- env_extras=(
- ("ZHIPUAI_API_KEY", "{api_key}"),
- ),
+ env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
@@ -293,14 +345,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
- litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
+ litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -311,7 +362,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +370,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
- litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
+ litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"),
- env_extras=(
- ("MOONSHOT_API_BASE", "{api_base}"),
- ),
+ env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
- default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
+ default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False,
- model_overrides=(
- ("kimi-k2.5", {"temperature": 1.0}),
- ),
+ model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
),
-
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec(
@@ -343,7 +388,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
- litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
+ litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"),
env_extras=(),
is_gateway=False,
@@ -354,9 +399,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
strip_model_prefix=False,
model_overrides=(),
),
-
# === Local deployment (matched by config key, NOT by api_base) =========
-
# vLLM / any OpenAI-compatible local server.
# Detected when config key is "vllm" (provider_name="vllm").
ProviderSpec(
@@ -364,20 +407,35 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("vllm",),
env_key="HOSTED_VLLM_API_KEY",
display_name="vLLM/Local",
- litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
+ litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="",
- default_api_base="", # user must provide in config
+ default_api_base="", # user must provide in config
+ strip_model_prefix=False,
+ model_overrides=(),
+ ),
+ # === Ollama (local, OpenAI-compatible) ===================================
+ ProviderSpec(
+ name="ollama",
+ keywords=("ollama", "nemotron"),
+ env_key="OLLAMA_API_KEY",
+ display_name="Ollama",
+ litellm_prefix="ollama_chat", # model → ollama_chat/model
+ skip_prefixes=("ollama/", "ollama_chat/"),
+ env_extras=(),
+ is_gateway=False,
+ is_local=True,
+ detect_by_key_prefix="",
+ detect_by_base_keyword="11434",
+ default_api_base="http://localhost:11434",
strip_model_prefix=False,
model_overrides=(),
),
-
# === Auxiliary (not a primary LLM provider) ============================
-
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
ProviderSpec(
@@ -385,8 +443,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
- litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
- skip_prefixes=("groq/",), # avoid double-prefix
+ litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
+ skip_prefixes=("groq/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
@@ -403,6 +461,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
# Lookup helpers
# ---------------------------------------------------------------------------
+
def find_by_model(model: str) -> ProviderSpec | None:
"""Match a standard provider by model-name keyword (case-insensitive).
Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -418,7 +477,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
return spec
for spec in std_specs:
- if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
+ if any(
+ kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
+ ):
return spec
return None
diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py
index 7a3c628..1c8cb6a 100644
--- a/nanobot/providers/transcription.py
+++ b/nanobot/providers/transcription.py
@@ -2,7 +2,6 @@
import os
from pathlib import Path
-from typing import Any
import httpx
from loguru import logger
@@ -11,33 +10,33 @@ from loguru import logger
class GroqTranscriptionProvider:
"""
Voice transcription provider using Groq's Whisper API.
-
+
Groq offers extremely fast transcription with a generous free tier.
"""
-
+
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
-
+
async def transcribe(self, file_path: str | Path) -> str:
"""
Transcribe an audio file using Groq.
-
+
Args:
file_path: Path to the audio file.
-
+
Returns:
Transcribed text.
"""
if not self.api_key:
logger.warning("Groq API key not configured for transcription")
return ""
-
+
path = Path(file_path)
if not path.exists():
logger.error("Audio file not found: {}", file_path)
return ""
-
+
try:
async with httpx.AsyncClient() as client:
with open(path, "rb") as f:
@@ -48,18 +47,18 @@ class GroqTranscriptionProvider:
headers = {
"Authorization": f"Bearer {self.api_key}",
}
-
+
response = await client.post(
self.api_url,
headers=headers,
files=files,
timeout=60.0
)
-
+
response.raise_for_status()
data = response.json()
return data.get("text", "")
-
+
except Exception as e:
logger.error("Groq transcription error: {}", e)
return ""
diff --git a/nanobot/security/__init__.py b/nanobot/security/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/nanobot/security/__init__.py
@@ -0,0 +1 @@
+
diff --git a/nanobot/security/network.py b/nanobot/security/network.py
new file mode 100644
index 0000000..9005828
--- /dev/null
+++ b/nanobot/security/network.py
@@ -0,0 +1,104 @@
+"""Network security utilities — SSRF protection and internal URL detection."""
+
+from __future__ import annotations
+
+import ipaddress
+import re
+import socket
+from urllib.parse import urlparse
+
+_BLOCKED_NETWORKS = [
+ ipaddress.ip_network("0.0.0.0/8"),
+ ipaddress.ip_network("10.0.0.0/8"),
+ ipaddress.ip_network("100.64.0.0/10"), # carrier-grade NAT
+ ipaddress.ip_network("127.0.0.0/8"),
+ ipaddress.ip_network("169.254.0.0/16"), # link-local / cloud metadata
+ ipaddress.ip_network("172.16.0.0/12"),
+ ipaddress.ip_network("192.168.0.0/16"),
+ ipaddress.ip_network("::1/128"),
+ ipaddress.ip_network("fc00::/7"), # unique local
+ ipaddress.ip_network("fe80::/10"), # link-local v6
+]
+
+_URL_RE = re.compile(r"https?://[^\s\"'`;|<>]+", re.IGNORECASE)
+
+
+def _is_private(addr: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
+ return any(addr in net for net in _BLOCKED_NETWORKS)
+
+
+def validate_url_target(url: str) -> tuple[bool, str]:
+ """Validate a URL is safe to fetch: scheme, hostname, and resolved IPs.
+
+ Returns (ok, error_message). When ok is True, error_message is empty.
+ """
+ try:
+ p = urlparse(url)
+ except Exception as e:
+ return False, str(e)
+
+ if p.scheme not in ("http", "https"):
+ return False, f"Only http/https allowed, got '{p.scheme or 'none'}'"
+ if not p.netloc:
+ return False, "Missing domain"
+
+ hostname = p.hostname
+ if not hostname:
+ return False, "Missing hostname"
+
+ try:
+ infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
+ except socket.gaierror:
+ return False, f"Cannot resolve hostname: {hostname}"
+
+ for info in infos:
+ try:
+ addr = ipaddress.ip_address(info[4][0])
+ except ValueError:
+ continue
+ if _is_private(addr):
+ return False, f"Blocked: {hostname} resolves to private/internal address {addr}"
+
+ return True, ""
+
+
+def validate_resolved_url(url: str) -> tuple[bool, str]:
+ """Validate an already-fetched URL (e.g. after redirect). Only checks the IP, skips DNS."""
+ try:
+ p = urlparse(url)
+ except Exception:
+ return True, ""
+
+ hostname = p.hostname
+ if not hostname:
+ return True, ""
+
+ try:
+ addr = ipaddress.ip_address(hostname)
+ if _is_private(addr):
+ return False, f"Redirect target is a private address: {addr}"
+ except ValueError:
+ # hostname is a domain name, resolve it
+ try:
+ infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM)
+ except socket.gaierror:
+ return True, ""
+ for info in infos:
+ try:
+ addr = ipaddress.ip_address(info[4][0])
+ except ValueError:
+ continue
+ if _is_private(addr):
+ return False, f"Redirect target {hostname} resolves to private address {addr}"
+
+ return True, ""
+
+
+def contains_internal_url(command: str) -> bool:
+ """Return True if the command string contains a URL targeting an internal/private address."""
+ for m in _URL_RE.finditer(command):
+ url = m.group(0)
+ ok, _ = validate_url_target(url)
+ if not ok:
+ return True
+ return False
diff --git a/nanobot/session/__init__.py b/nanobot/session/__init__.py
index 3faf424..931f7c6 100644
--- a/nanobot/session/__init__.py
+++ b/nanobot/session/__init__.py
@@ -1,5 +1,5 @@
"""Session management module."""
-from nanobot.session.manager import SessionManager, Session
+from nanobot.session.manager import Session, SessionManager
__all__ = ["SessionManager", "Session"]
diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py
index d59b7c9..f8244e5 100644
--- a/nanobot/session/manager.py
+++ b/nanobot/session/manager.py
@@ -2,13 +2,14 @@
import json
import shutil
-from pathlib import Path
from dataclasses import dataclass, field
from datetime import datetime
+from pathlib import Path
from typing import Any
from loguru import logger
+from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename
@@ -30,7 +31,7 @@ class Session:
updated_at: datetime = field(default_factory=datetime.now)
metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files
-
+
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
"""Add a message to the session."""
msg = {
@@ -41,27 +42,56 @@ class Session:
}
self.messages.append(msg)
self.updated_at = datetime.now()
-
+
+ @staticmethod
+ def _find_legal_start(messages: list[dict[str, Any]]) -> int:
+ """Find first index where every tool result has a matching assistant tool_call."""
+ declared: set[str] = set()
+ start = 0
+ for i, msg in enumerate(messages):
+ role = msg.get("role")
+ if role == "assistant":
+ for tc in msg.get("tool_calls") or []:
+ if isinstance(tc, dict) and tc.get("id"):
+ declared.add(str(tc["id"]))
+ elif role == "tool":
+ tid = msg.get("tool_call_id")
+ if tid and str(tid) not in declared:
+ start = i + 1
+ declared.clear()
+ for prev in messages[start:i + 1]:
+ if prev.get("role") == "assistant":
+ for tc in prev.get("tool_calls") or []:
+ if isinstance(tc, dict) and tc.get("id"):
+ declared.add(str(tc["id"]))
+ return start
+
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
- """Return unconsolidated messages for LLM input, aligned to a user turn."""
+ """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
- # Drop leading non-user messages to avoid orphaned tool_result blocks
- for i, m in enumerate(sliced):
- if m.get("role") == "user":
+ # Drop leading non-user messages to avoid starting mid-turn when possible.
+ for i, message in enumerate(sliced):
+ if message.get("role") == "user":
sliced = sliced[i:]
break
+ # Some providers reject orphan tool results if the matching assistant
+ # tool_calls message fell outside the fixed-size history window.
+ start = self._find_legal_start(sliced)
+ if start:
+ sliced = sliced[start:]
+
out: list[dict[str, Any]] = []
- for m in sliced:
- entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
- for k in ("tool_calls", "tool_call_id", "name"):
- if k in m:
- entry[k] = m[k]
+ for message in sliced:
+ entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")}
+ for key in ("tool_calls", "tool_call_id", "name"):
+ if key in message:
+ entry[key] = message[key]
out.append(entry)
return out
-
+
def clear(self) -> None:
"""Clear all messages and reset session to initial state."""
self.messages = []
@@ -79,9 +109,9 @@ class SessionManager:
def __init__(self, workspace: Path):
self.workspace = workspace
self.sessions_dir = ensure_dir(self.workspace / "sessions")
- self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
+ self.legacy_sessions_dir = get_legacy_sessions_dir()
self._cache: dict[str, Session] = {}
-
+
def _get_session_path(self, key: str) -> Path:
"""Get the file path for a session."""
safe_key = safe_filename(key.replace(":", "_"))
@@ -91,27 +121,27 @@ class SessionManager:
"""Legacy global session path (~/.nanobot/sessions/)."""
safe_key = safe_filename(key.replace(":", "_"))
return self.legacy_sessions_dir / f"{safe_key}.jsonl"
-
+
def get_or_create(self, key: str) -> Session:
"""
Get an existing session or create a new one.
-
+
Args:
key: Session key (usually channel:chat_id).
-
+
Returns:
The session.
"""
if key in self._cache:
return self._cache[key]
-
+
session = self._load(key)
if session is None:
session = Session(key=key)
-
+
self._cache[key] = session
return session
-
+
def _load(self, key: str) -> Session | None:
"""Load a session from disk."""
path = self._get_session_path(key)
@@ -158,7 +188,7 @@ class SessionManager:
except Exception as e:
logger.warning("Failed to load session {}: {}", key, e)
return None
-
+
def save(self, session: Session) -> None:
"""Save a session to disk."""
path = self._get_session_path(session.key)
@@ -177,20 +207,20 @@ class SessionManager:
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
self._cache[session.key] = session
-
+
def invalidate(self, key: str) -> None:
"""Remove a session from the in-memory cache."""
self._cache.pop(key, None)
-
+
def list_sessions(self) -> list[dict[str, Any]]:
"""
List all sessions.
-
+
Returns:
List of session info dicts.
"""
sessions = []
-
+
for path in self.sessions_dir.glob("*.jsonl"):
try:
# Read just the metadata line
@@ -208,5 +238,5 @@ class SessionManager:
})
except Exception:
continue
-
+
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
diff --git a/nanobot/skills/memory/SKILL.md b/nanobot/skills/memory/SKILL.md
index 39adbde..3f0a8fc 100644
--- a/nanobot/skills/memory/SKILL.md
+++ b/nanobot/skills/memory/SKILL.md
@@ -9,15 +9,21 @@ always: true
## Structure
- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context.
-- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep.
+- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep-style tools or in-memory filters. Each entry starts with [YYYY-MM-DD HH:MM].
## Search Past Events
-```bash
-grep -i "keyword" memory/HISTORY.md
-```
+Choose the search method based on file size:
-Use the `exec` tool to run grep. Combine patterns: `grep -iE "meeting|deadline" memory/HISTORY.md`
+- Small `memory/HISTORY.md`: use `read_file`, then search in-memory
+- Large or long-lived `memory/HISTORY.md`: use the `exec` tool for targeted search
+
+Examples:
+- **Linux/macOS:** `grep -i "keyword" memory/HISTORY.md`
+- **Windows:** `findstr /i "keyword" memory\HISTORY.md`
+- **Cross-platform Python:** `python -c "from pathlib import Path; text = Path('memory/HISTORY.md').read_text(encoding='utf-8'); print('\n'.join([l for l in text.splitlines() if 'keyword' in l.lower()][-20:]))"`
+
+Prefer targeted command-line search for large history files.
## When to Update MEMORY.md
diff --git a/nanobot/skills/skill-creator/SKILL.md b/nanobot/skills/skill-creator/SKILL.md
index 9b5eb6f..ea53abe 100644
--- a/nanobot/skills/skill-creator/SKILL.md
+++ b/nanobot/skills/skill-creator/SKILL.md
@@ -268,6 +268,8 @@ Skip this step only if the skill being developed already exists, and iteration o
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
+For `nanobot`, custom skills should live under the active workspace `skills/` directory so they can be discovered automatically at runtime (for example, `/skills/my-skill/SKILL.md`).
+
Usage:
```bash
@@ -277,9 +279,9 @@ scripts/init_skill.py --path [--resources script
Examples:
```bash
-scripts/init_skill.py my-skill --path skills/public
-scripts/init_skill.py my-skill --path skills/public --resources scripts,references
-scripts/init_skill.py my-skill --path skills/public --resources scripts --examples
+scripts/init_skill.py my-skill --path ./workspace/skills
+scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts,references
+scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts --examples
```
The script:
@@ -326,7 +328,7 @@ Write the YAML frontmatter with `name` and `description`:
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
-Do not include any other fields in YAML frontmatter.
+Keep frontmatter minimal. In `nanobot`, `metadata` and `always` are also supported when needed, but avoid adding extra fields unless they are actually required.
##### Body
@@ -349,7 +351,6 @@ scripts/package_skill.py ./dist
The packaging script will:
1. **Validate** the skill automatically, checking:
-
- YAML frontmatter format and required fields
- Skill naming conventions and directory structure
- Description completeness and quality
@@ -357,6 +358,8 @@ The packaging script will:
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
+ Security restriction: symlinks are rejected and packaging fails when any symlink is present.
+
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
### Step 6: Iterate
diff --git a/nanobot/skills/skill-creator/scripts/init_skill.py b/nanobot/skills/skill-creator/scripts/init_skill.py
new file mode 100755
index 0000000..8633fe9
--- /dev/null
+++ b/nanobot/skills/skill-creator/scripts/init_skill.py
@@ -0,0 +1,378 @@
+#!/usr/bin/env python3
+"""
+Skill Initializer - Creates a new skill from template
+
+Usage:
+ init_skill.py --path [--resources scripts,references,assets] [--examples]
+
+Examples:
+ init_skill.py my-new-skill --path skills/public
+ init_skill.py my-new-skill --path skills/public --resources scripts,references
+ init_skill.py my-api-helper --path skills/private --resources scripts --examples
+ init_skill.py custom-skill --path /custom/location
+"""
+
+import argparse
+import re
+import sys
+from pathlib import Path
+
+MAX_SKILL_NAME_LENGTH = 64
+ALLOWED_RESOURCES = {"scripts", "references", "assets"}
+
+SKILL_TEMPLATE = """---
+name: {skill_name}
+description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
+---
+
+# {skill_title}
+
+## Overview
+
+[TODO: 1-2 sentences explaining what this skill enables]
+
+## Structuring This Skill
+
+[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
+
+**1. Workflow-Based** (best for sequential processes)
+- Works well when there are clear step-by-step procedures
+- Example: DOCX skill with "Workflow Decision Tree" -> "Reading" -> "Creating" -> "Editing"
+- Structure: ## Overview -> ## Workflow Decision Tree -> ## Step 1 -> ## Step 2...
+
+**2. Task-Based** (best for tool collections)
+- Works well when the skill offers different operations/capabilities
+- Example: PDF skill with "Quick Start" -> "Merge PDFs" -> "Split PDFs" -> "Extract Text"
+- Structure: ## Overview -> ## Quick Start -> ## Task Category 1 -> ## Task Category 2...
+
+**3. Reference/Guidelines** (best for standards or specifications)
+- Works well for brand guidelines, coding standards, or requirements
+- Example: Brand styling with "Brand Guidelines" -> "Colors" -> "Typography" -> "Features"
+- Structure: ## Overview -> ## Guidelines -> ## Specifications -> ## Usage...
+
+**4. Capabilities-Based** (best for integrated systems)
+- Works well when the skill provides multiple interrelated features
+- Example: Product Management with "Core Capabilities" -> numbered capability list
+- Structure: ## Overview -> ## Core Capabilities -> ### 1. Feature -> ### 2. Feature...
+
+Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
+
+Delete this entire "Structuring This Skill" section when done - it's just guidance.]
+
+## [TODO: Replace with the first main section based on chosen structure]
+
+[TODO: Add content here. See examples in existing skills:
+- Code samples for technical skills
+- Decision trees for complex workflows
+- Concrete examples with realistic user requests
+- References to scripts/templates/references as needed]
+
+## Resources (optional)
+
+Create only the resource directories this skill actually needs. Delete this section if no resources are required.
+
+### scripts/
+Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
+
+**Examples from other skills:**
+- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
+- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
+
+**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
+
+**Note:** Scripts may be executed without loading into context, but can still be read by Codex for patching or environment adjustments.
+
+### references/
+Documentation and reference material intended to be loaded into context to inform Codex's process and thinking.
+
+**Examples from other skills:**
+- Product management: `communication.md`, `context_building.md` - detailed workflow guides
+- BigQuery: API reference documentation and query examples
+- Finance: Schema documentation, company policies
+
+**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Codex should reference while working.
+
+### assets/
+Files not intended to be loaded into context, but rather used within the output Codex produces.
+
+**Examples from other skills:**
+- Brand styling: PowerPoint template files (.pptx), logo files
+- Frontend builder: HTML/React boilerplate project directories
+- Typography: Font files (.ttf, .woff2)
+
+**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
+
+---
+
+**Not every skill requires all three types of resources.**
+"""
+
+EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
+"""
+Example helper script for {skill_name}
+
+This is a placeholder script that can be executed directly.
+Replace with actual implementation or delete if not needed.
+
+Example real scripts from other skills:
+- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
+- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
+"""
+
+def main():
+ print("This is an example script for {skill_name}")
+ # TODO: Add actual script logic here
+ # This could be data processing, file conversion, API calls, etc.
+
+if __name__ == "__main__":
+ main()
+'''
+
+EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
+
+This is a placeholder for detailed reference documentation.
+Replace with actual reference content or delete if not needed.
+
+Example real reference docs from other skills:
+- product-management/references/communication.md - Comprehensive guide for status updates
+- product-management/references/context_building.md - Deep-dive on gathering context
+- bigquery/references/ - API references and query examples
+
+## When Reference Docs Are Useful
+
+Reference docs are ideal for:
+- Comprehensive API documentation
+- Detailed workflow guides
+- Complex multi-step processes
+- Information too lengthy for main SKILL.md
+- Content that's only needed for specific use cases
+
+## Structure Suggestions
+
+### API Reference Example
+- Overview
+- Authentication
+- Endpoints with examples
+- Error codes
+- Rate limits
+
+### Workflow Guide Example
+- Prerequisites
+- Step-by-step instructions
+- Common patterns
+- Troubleshooting
+- Best practices
+"""
+
+EXAMPLE_ASSET = """# Example Asset File
+
+This placeholder represents where asset files would be stored.
+Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
+
+Asset files are NOT intended to be loaded into context, but rather used within
+the output Codex produces.
+
+Example asset files from other skills:
+- Brand guidelines: logo.png, slides_template.pptx
+- Frontend builder: hello-world/ directory with HTML/React boilerplate
+- Typography: custom-font.ttf, font-family.woff2
+- Data: sample_data.csv, test_dataset.json
+
+## Common Asset Types
+
+- Templates: .pptx, .docx, boilerplate directories
+- Images: .png, .jpg, .svg, .gif
+- Fonts: .ttf, .otf, .woff, .woff2
+- Boilerplate code: Project directories, starter files
+- Icons: .ico, .svg
+- Data files: .csv, .json, .xml, .yaml
+
+Note: This is a text placeholder. Actual assets can be any file type.
+"""
+
+
+def normalize_skill_name(skill_name):
+ """Normalize a skill name to lowercase hyphen-case."""
+ normalized = skill_name.strip().lower()
+ normalized = re.sub(r"[^a-z0-9]+", "-", normalized)
+ normalized = normalized.strip("-")
+ normalized = re.sub(r"-{2,}", "-", normalized)
+ return normalized
+
+
+def title_case_skill_name(skill_name):
+ """Convert hyphenated skill name to Title Case for display."""
+ return " ".join(word.capitalize() for word in skill_name.split("-"))
+
+
+def parse_resources(raw_resources):
+ if not raw_resources:
+ return []
+ resources = [item.strip() for item in raw_resources.split(",") if item.strip()]
+ invalid = sorted({item for item in resources if item not in ALLOWED_RESOURCES})
+ if invalid:
+ allowed = ", ".join(sorted(ALLOWED_RESOURCES))
+ print(f"[ERROR] Unknown resource type(s): {', '.join(invalid)}")
+ print(f" Allowed: {allowed}")
+ sys.exit(1)
+ deduped = []
+ seen = set()
+ for resource in resources:
+ if resource not in seen:
+ deduped.append(resource)
+ seen.add(resource)
+ return deduped
+
+
+def create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples):
+ for resource in resources:
+ resource_dir = skill_dir / resource
+ resource_dir.mkdir(exist_ok=True)
+ if resource == "scripts":
+ if include_examples:
+ example_script = resource_dir / "example.py"
+ example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
+ example_script.chmod(0o755)
+ print("[OK] Created scripts/example.py")
+ else:
+ print("[OK] Created scripts/")
+ elif resource == "references":
+ if include_examples:
+ example_reference = resource_dir / "api_reference.md"
+ example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
+ print("[OK] Created references/api_reference.md")
+ else:
+ print("[OK] Created references/")
+ elif resource == "assets":
+ if include_examples:
+ example_asset = resource_dir / "example_asset.txt"
+ example_asset.write_text(EXAMPLE_ASSET)
+ print("[OK] Created assets/example_asset.txt")
+ else:
+ print("[OK] Created assets/")
+
+
+def init_skill(skill_name, path, resources, include_examples):
+ """
+ Initialize a new skill directory with template SKILL.md.
+
+ Args:
+ skill_name: Name of the skill
+ path: Path where the skill directory should be created
+ resources: Resource directories to create
+ include_examples: Whether to create example files in resource directories
+
+ Returns:
+ Path to created skill directory, or None if error
+ """
+ # Determine skill directory path
+ skill_dir = Path(path).resolve() / skill_name
+
+ # Check if directory already exists
+ if skill_dir.exists():
+ print(f"[ERROR] Skill directory already exists: {skill_dir}")
+ return None
+
+ # Create skill directory
+ try:
+ skill_dir.mkdir(parents=True, exist_ok=False)
+ print(f"[OK] Created skill directory: {skill_dir}")
+ except Exception as e:
+ print(f"[ERROR] Error creating directory: {e}")
+ return None
+
+ # Create SKILL.md from template
+ skill_title = title_case_skill_name(skill_name)
+ skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)
+
+ skill_md_path = skill_dir / "SKILL.md"
+ try:
+ skill_md_path.write_text(skill_content)
+ print("[OK] Created SKILL.md")
+ except Exception as e:
+ print(f"[ERROR] Error creating SKILL.md: {e}")
+ return None
+
+ # Create resource directories if requested
+ if resources:
+ try:
+ create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples)
+ except Exception as e:
+ print(f"[ERROR] Error creating resource directories: {e}")
+ return None
+
+ # Print next steps
+ print(f"\n[OK] Skill '{skill_name}' initialized successfully at {skill_dir}")
+ print("\nNext steps:")
+ print("1. Edit SKILL.md to complete the TODO items and update the description")
+ if resources:
+ if include_examples:
+ print("2. Customize or delete the example files in scripts/, references/, and assets/")
+ else:
+ print("2. Add resources to scripts/, references/, and assets/ as needed")
+ else:
+ print("2. Create resource directories only if needed (scripts/, references/, assets/)")
+ print("3. Run the validator when ready to check the skill structure")
+
+ return skill_dir
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Create a new skill directory with a SKILL.md template.",
+ )
+ parser.add_argument("skill_name", help="Skill name (normalized to hyphen-case)")
+ parser.add_argument("--path", required=True, help="Output directory for the skill")
+ parser.add_argument(
+ "--resources",
+ default="",
+ help="Comma-separated list: scripts,references,assets",
+ )
+ parser.add_argument(
+ "--examples",
+ action="store_true",
+ help="Create example files inside the selected resource directories",
+ )
+ args = parser.parse_args()
+
+ raw_skill_name = args.skill_name
+ skill_name = normalize_skill_name(raw_skill_name)
+ if not skill_name:
+ print("[ERROR] Skill name must include at least one letter or digit.")
+ sys.exit(1)
+ if len(skill_name) > MAX_SKILL_NAME_LENGTH:
+ print(
+ f"[ERROR] Skill name '{skill_name}' is too long ({len(skill_name)} characters). "
+ f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
+ )
+ sys.exit(1)
+ if skill_name != raw_skill_name:
+ print(f"Note: Normalized skill name from '{raw_skill_name}' to '{skill_name}'.")
+
+ resources = parse_resources(args.resources)
+ if args.examples and not resources:
+ print("[ERROR] --examples requires --resources to be set.")
+ sys.exit(1)
+
+ path = args.path
+
+ print(f"Initializing skill: {skill_name}")
+ print(f" Location: {path}")
+ if resources:
+ print(f" Resources: {', '.join(resources)}")
+ if args.examples:
+ print(" Examples: enabled")
+ else:
+ print(" Resources: none (create as needed)")
+ print()
+
+ result = init_skill(skill_name, path, resources, args.examples)
+
+ if result:
+ sys.exit(0)
+ else:
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nanobot/skills/skill-creator/scripts/package_skill.py b/nanobot/skills/skill-creator/scripts/package_skill.py
new file mode 100755
index 0000000..48fcbbe
--- /dev/null
+++ b/nanobot/skills/skill-creator/scripts/package_skill.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""
+Skill Packager - Creates a distributable .skill file of a skill folder
+
+Usage:
+ python package_skill.py [output-directory]
+
+Example:
+ python package_skill.py skills/public/my-skill
+ python package_skill.py skills/public/my-skill ./dist
+"""
+
+import sys
+import zipfile
+from pathlib import Path
+
+from quick_validate import validate_skill
+
+
+def _is_within(path: Path, root: Path) -> bool:
+ try:
+ path.relative_to(root)
+ return True
+ except ValueError:
+ return False
+
+
+def _cleanup_partial_archive(skill_filename: Path) -> None:
+ try:
+ if skill_filename.exists():
+ skill_filename.unlink()
+ except OSError:
+ pass
+
+
+def package_skill(skill_path, output_dir=None):
+ """
+ Package a skill folder into a .skill file.
+
+ Args:
+ skill_path: Path to the skill folder
+ output_dir: Optional output directory for the .skill file (defaults to current directory)
+
+ Returns:
+ Path to the created .skill file, or None if error
+ """
+ skill_path = Path(skill_path).resolve()
+
+ # Validate skill folder exists
+ if not skill_path.exists():
+ print(f"[ERROR] Skill folder not found: {skill_path}")
+ return None
+
+ if not skill_path.is_dir():
+ print(f"[ERROR] Path is not a directory: {skill_path}")
+ return None
+
+ # Validate SKILL.md exists
+ skill_md = skill_path / "SKILL.md"
+ if not skill_md.exists():
+ print(f"[ERROR] SKILL.md not found in {skill_path}")
+ return None
+
+ # Run validation before packaging
+ print("Validating skill...")
+ valid, message = validate_skill(skill_path)
+ if not valid:
+ print(f"[ERROR] Validation failed: {message}")
+ print(" Please fix the validation errors before packaging.")
+ return None
+ print(f"[OK] {message}\n")
+
+ # Determine output location
+ skill_name = skill_path.name
+ if output_dir:
+ output_path = Path(output_dir).resolve()
+ output_path.mkdir(parents=True, exist_ok=True)
+ else:
+ output_path = Path.cwd()
+
+ skill_filename = output_path / f"{skill_name}.skill"
+
+ EXCLUDED_DIRS = {".git", ".svn", ".hg", "__pycache__", "node_modules"}
+
+ files_to_package = []
+ resolved_archive = skill_filename.resolve()
+
+ for file_path in skill_path.rglob("*"):
+ # Fail closed on symlinks so the packaged contents are explicit and predictable.
+ if file_path.is_symlink():
+ print(f"[ERROR] Symlink not allowed in packaged skill: {file_path}")
+ _cleanup_partial_archive(skill_filename)
+ return None
+
+ rel_parts = file_path.relative_to(skill_path).parts
+ if any(part in EXCLUDED_DIRS for part in rel_parts):
+ continue
+
+ if file_path.is_file():
+ resolved_file = file_path.resolve()
+ if not _is_within(resolved_file, skill_path):
+ print(f"[ERROR] File escapes skill root: {file_path}")
+ _cleanup_partial_archive(skill_filename)
+ return None
+ # If output lives under skill_path, avoid writing archive into itself.
+ if resolved_file == resolved_archive:
+ print(f"[WARN] Skipping output archive: {file_path}")
+ continue
+ files_to_package.append(file_path)
+
+ # Create the .skill file (zip format)
+ try:
+ with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
+ for file_path in files_to_package:
+ # Calculate the relative path within the zip.
+ arcname = Path(skill_name) / file_path.relative_to(skill_path)
+ zipf.write(file_path, arcname)
+ print(f" Added: {arcname}")
+
+ print(f"\n[OK] Successfully packaged skill to: {skill_filename}")
+ return skill_filename
+
+ except Exception as e:
+ _cleanup_partial_archive(skill_filename)
+ print(f"[ERROR] Error creating .skill file: {e}")
+ return None
+
+
+def main():
+ if len(sys.argv) < 2:
+ print("Usage: python package_skill.py [output-directory]")
+ print("\nExample:")
+ print(" python package_skill.py skills/public/my-skill")
+ print(" python package_skill.py skills/public/my-skill ./dist")
+ sys.exit(1)
+
+ skill_path = sys.argv[1]
+ output_dir = sys.argv[2] if len(sys.argv) > 2 else None
+
+ print(f"Packaging skill: {skill_path}")
+ if output_dir:
+ print(f" Output directory: {output_dir}")
+ print()
+
+ result = package_skill(skill_path, output_dir)
+
+ if result:
+ sys.exit(0)
+ else:
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nanobot/skills/skill-creator/scripts/quick_validate.py b/nanobot/skills/skill-creator/scripts/quick_validate.py
new file mode 100644
index 0000000..03d246d
--- /dev/null
+++ b/nanobot/skills/skill-creator/scripts/quick_validate.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python3
+"""
+Minimal validator for nanobot skill folders.
+"""
+
+import re
+import sys
+from pathlib import Path
+from typing import Optional
+
+try:
+ import yaml
+except ModuleNotFoundError:
+ yaml = None
+
+MAX_SKILL_NAME_LENGTH = 64
+ALLOWED_FRONTMATTER_KEYS = {
+ "name",
+ "description",
+ "metadata",
+ "always",
+ "license",
+ "allowed-tools",
+}
+ALLOWED_RESOURCE_DIRS = {"scripts", "references", "assets"}
+PLACEHOLDER_MARKERS = ("[todo", "todo:")
+
+
+def _extract_frontmatter(content: str) -> Optional[str]:
+ lines = content.splitlines()
+ if not lines or lines[0].strip() != "---":
+ return None
+ for i in range(1, len(lines)):
+ if lines[i].strip() == "---":
+ return "\n".join(lines[1:i])
+ return None
+
+
+def _parse_simple_frontmatter(frontmatter_text: str) -> Optional[dict[str, str]]:
+ """Fallback parser for simple frontmatter when PyYAML is unavailable."""
+ parsed: dict[str, str] = {}
+ current_key: Optional[str] = None
+ multiline_key: Optional[str] = None
+
+ for raw_line in frontmatter_text.splitlines():
+ stripped = raw_line.strip()
+ if not stripped or stripped.startswith("#"):
+ continue
+
+ is_indented = raw_line[:1].isspace()
+ if is_indented:
+ if current_key is None:
+ return None
+ current_value = parsed[current_key]
+ parsed[current_key] = f"{current_value}\n{stripped}" if current_value else stripped
+ continue
+
+ if ":" not in stripped:
+ return None
+
+ key, value = stripped.split(":", 1)
+ key = key.strip()
+ value = value.strip()
+ if not key:
+ return None
+
+ if value in {"|", ">"}:
+ parsed[key] = ""
+ current_key = key
+ multiline_key = key
+ continue
+
+ if (value.startswith('"') and value.endswith('"')) or (
+ value.startswith("'") and value.endswith("'")
+ ):
+ value = value[1:-1]
+ parsed[key] = value
+ current_key = key
+ multiline_key = None
+
+ if multiline_key is not None and multiline_key not in parsed:
+ return None
+ return parsed
+
+
+def _load_frontmatter(frontmatter_text: str) -> tuple[Optional[dict], Optional[str]]:
+ if yaml is not None:
+ try:
+ frontmatter = yaml.safe_load(frontmatter_text)
+ except yaml.YAMLError as exc:
+ return None, f"Invalid YAML in frontmatter: {exc}"
+ if not isinstance(frontmatter, dict):
+ return None, "Frontmatter must be a YAML dictionary"
+ return frontmatter, None
+
+ frontmatter = _parse_simple_frontmatter(frontmatter_text)
+ if frontmatter is None:
+ return None, "Invalid YAML in frontmatter: unsupported syntax without PyYAML installed"
+ return frontmatter, None
+
+
+def _validate_skill_name(name: str, folder_name: str) -> Optional[str]:
+ if not re.fullmatch(r"[a-z0-9]+(?:-[a-z0-9]+)*", name):
+ return (
+ f"Name '{name}' should be hyphen-case "
+ "(lowercase letters, digits, and single hyphens only)"
+ )
+ if len(name) > MAX_SKILL_NAME_LENGTH:
+ return (
+ f"Name is too long ({len(name)} characters). "
+ f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
+ )
+ if name != folder_name:
+ return f"Skill name '{name}' must match directory name '{folder_name}'"
+ return None
+
+
+def _validate_description(description: str) -> Optional[str]:
+ trimmed = description.strip()
+ if not trimmed:
+ return "Description cannot be empty"
+ lowered = trimmed.lower()
+ if any(marker in lowered for marker in PLACEHOLDER_MARKERS):
+ return "Description still contains TODO placeholder text"
+ if "<" in trimmed or ">" in trimmed:
+ return "Description cannot contain angle brackets (< or >)"
+ if len(trimmed) > 1024:
+ return f"Description is too long ({len(trimmed)} characters). Maximum is 1024 characters."
+ return None
+
+
+def validate_skill(skill_path):
+ """Validate a skill folder structure and required frontmatter."""
+ skill_path = Path(skill_path).resolve()
+
+ if not skill_path.exists():
+ return False, f"Skill folder not found: {skill_path}"
+ if not skill_path.is_dir():
+ return False, f"Path is not a directory: {skill_path}"
+
+ skill_md = skill_path / "SKILL.md"
+ if not skill_md.exists():
+ return False, "SKILL.md not found"
+
+ try:
+ content = skill_md.read_text(encoding="utf-8")
+ except OSError as exc:
+ return False, f"Could not read SKILL.md: {exc}"
+
+ frontmatter_text = _extract_frontmatter(content)
+ if frontmatter_text is None:
+ return False, "Invalid frontmatter format"
+
+ frontmatter, error = _load_frontmatter(frontmatter_text)
+ if error:
+ return False, error
+
+ unexpected_keys = sorted(set(frontmatter.keys()) - ALLOWED_FRONTMATTER_KEYS)
+ if unexpected_keys:
+ allowed = ", ".join(sorted(ALLOWED_FRONTMATTER_KEYS))
+ unexpected = ", ".join(unexpected_keys)
+ return (
+ False,
+ f"Unexpected key(s) in SKILL.md frontmatter: {unexpected}. Allowed properties are: {allowed}",
+ )
+
+ if "name" not in frontmatter:
+ return False, "Missing 'name' in frontmatter"
+ if "description" not in frontmatter:
+ return False, "Missing 'description' in frontmatter"
+
+ name = frontmatter["name"]
+ if not isinstance(name, str):
+ return False, f"Name must be a string, got {type(name).__name__}"
+ name_error = _validate_skill_name(name.strip(), skill_path.name)
+ if name_error:
+ return False, name_error
+
+ description = frontmatter["description"]
+ if not isinstance(description, str):
+ return False, f"Description must be a string, got {type(description).__name__}"
+ description_error = _validate_description(description)
+ if description_error:
+ return False, description_error
+
+ always = frontmatter.get("always")
+ if always is not None and not isinstance(always, bool):
+ return False, f"'always' must be a boolean, got {type(always).__name__}"
+
+ for child in skill_path.iterdir():
+ if child.name == "SKILL.md":
+ continue
+ if child.is_dir() and child.name in ALLOWED_RESOURCE_DIRS:
+ continue
+ if child.is_symlink():
+ continue
+ return (
+ False,
+ f"Unexpected file or directory in skill root: {child.name}. "
+ "Only SKILL.md, scripts/, references/, and assets/ are allowed.",
+ )
+
+ return True, "Skill is valid!"
+
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print("Usage: python quick_validate.py ")
+ sys.exit(1)
+
+ valid, message = validate_skill(sys.argv[1])
+ print(message)
+ sys.exit(0 if valid else 1)
diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md
index 84ba657..a24604b 100644
--- a/nanobot/templates/AGENTS.md
+++ b/nanobot/templates/AGENTS.md
@@ -2,27 +2,17 @@
You are a helpful AI assistant. Be concise, accurate, and friendly.
-## Guidelines
-
-- Before calling tools, briefly state your intent — but NEVER predict results before receiving them
-- Use precise tense: "I will run X" before the call, "X returned Y" after
-- NEVER claim success before a tool result confirms it
-- Ask for clarification when the request is ambiguous
-- Remember important information in `memory/MEMORY.md`; past events are logged in `memory/HISTORY.md`
-
## Scheduled Reminders
-When user asks for a reminder at a specific time, use `exec` to run:
-```
-nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
-```
+Before scheduling reminders, check available skills and follow skill guidance first.
+Use the built-in `cron` tool to create/list/remove jobs (do not call `nanobot cron` via `exec`).
Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).
**Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.
## Heartbeat Tasks
-`HEARTBEAT.md` is checked every 30 minutes. Use file tools to manage periodic tasks:
+`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
- **Add**: `edit_file` to append new tasks
- **Remove**: `edit_file` to delete completed tasks
diff --git a/nanobot/utils/__init__.py b/nanobot/utils/__init__.py
index 7444987..46f02ac 100644
--- a/nanobot/utils/__init__.py
+++ b/nanobot/utils/__init__.py
@@ -1,5 +1,5 @@
"""Utility functions for nanobot."""
-from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path
+from nanobot.utils.helpers import ensure_dir
-__all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
+__all__ = ["ensure_dir"]
diff --git a/nanobot/utils/evaluator.py b/nanobot/utils/evaluator.py
new file mode 100644
index 0000000..6110471
--- /dev/null
+++ b/nanobot/utils/evaluator.py
@@ -0,0 +1,92 @@
+"""Post-run evaluation for background tasks (heartbeat & cron).
+
+After the agent executes a background task, this module makes a lightweight
+LLM call to decide whether the result warrants notifying the user.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+if TYPE_CHECKING:
+ from nanobot.providers.base import LLMProvider
+
+_EVALUATE_TOOL = [
+ {
+ "type": "function",
+ "function": {
+ "name": "evaluate_notification",
+ "description": "Decide whether the user should be notified about this background task result.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "should_notify": {
+ "type": "boolean",
+ "description": "true = result contains actionable/important info the user should see; false = routine or empty, safe to suppress",
+ },
+ "reason": {
+ "type": "string",
+ "description": "One-sentence reason for the decision",
+ },
+ },
+ "required": ["should_notify"],
+ },
+ },
+ }
+]
+
+_SYSTEM_PROMPT = (
+ "You are a notification gate for a background agent. "
+ "You will be given the original task and the agent's response. "
+ "Call the evaluate_notification tool to decide whether the user "
+ "should be notified.\n\n"
+ "Notify when the response contains actionable information, errors, "
+ "completed deliverables, or anything the user explicitly asked to "
+ "be reminded about.\n\n"
+ "Suppress when the response is a routine status check with nothing "
+ "new, a confirmation that everything is normal, or essentially empty."
+)
+
+
+async def evaluate_response(
+ response: str,
+ task_context: str,
+ provider: LLMProvider,
+ model: str,
+) -> bool:
+ """Decide whether a background-task result should be delivered to the user.
+
+ Uses a lightweight tool-call LLM request (same pattern as heartbeat
+ ``_decide()``). Falls back to ``True`` (notify) on any failure so
+ that important messages are never silently dropped.
+ """
+ try:
+ llm_response = await provider.chat_with_retry(
+ messages=[
+ {"role": "system", "content": _SYSTEM_PROMPT},
+ {"role": "user", "content": (
+ f"## Original task\n{task_context}\n\n"
+ f"## Agent response\n{response}"
+ )},
+ ],
+ tools=_EVALUATE_TOOL,
+ model=model,
+ max_tokens=256,
+ temperature=0.0,
+ )
+
+ if not llm_response.has_tool_calls:
+ logger.warning("evaluate_response: no tool call returned, defaulting to notify")
+ return True
+
+ args = llm_response.tool_calls[0].arguments
+ should_notify = args.get("should_notify", True)
+ reason = args.get("reason", "")
+ logger.info("evaluate_response: should_notify={}, reason={}", should_notify, reason)
+ return bool(should_notify)
+
+ except Exception:
+ logger.exception("evaluate_response failed, defaulting to notify")
+ return True
diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py
index 62f80ac..d937b6e 100644
--- a/nanobot/utils/helpers.py
+++ b/nanobot/utils/helpers.py
@@ -1,80 +1,211 @@
"""Utility functions for nanobot."""
-from pathlib import Path
+import json
+import re
+import time
from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+import tiktoken
+
+
+def detect_image_mime(data: bytes) -> str | None:
+ """Detect image MIME type from magic bytes, ignoring file extension."""
+ if data[:8] == b"\x89PNG\r\n\x1a\n":
+ return "image/png"
+ if data[:3] == b"\xff\xd8\xff":
+ return "image/jpeg"
+ if data[:6] in (b"GIF87a", b"GIF89a"):
+ return "image/gif"
+ if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
+ return "image/webp"
+ return None
def ensure_dir(path: Path) -> Path:
- """Ensure a directory exists, creating it if necessary."""
+ """Ensure directory exists, return it."""
path.mkdir(parents=True, exist_ok=True)
return path
-def get_data_path() -> Path:
- """Get the nanobot data directory (~/.nanobot)."""
- return ensure_dir(Path.home() / ".nanobot")
-
-
-def get_workspace_path(workspace: str | None = None) -> Path:
- """
- Get the workspace path.
-
- Args:
- workspace: Optional workspace path. Defaults to ~/.nanobot/workspace.
-
- Returns:
- Expanded and ensured workspace path.
- """
- if workspace:
- path = Path(workspace).expanduser()
- else:
- path = Path.home() / ".nanobot" / "workspace"
- return ensure_dir(path)
-
-
-def get_sessions_path() -> Path:
- """Get the sessions storage directory."""
- return ensure_dir(get_data_path() / "sessions")
-
-
-def get_skills_path(workspace: Path | None = None) -> Path:
- """Get the skills directory within the workspace."""
- ws = workspace or get_workspace_path()
- return ensure_dir(ws / "skills")
-
-
def timestamp() -> str:
- """Get current timestamp in ISO format."""
+ """Current ISO timestamp."""
return datetime.now().isoformat()
-def truncate_string(s: str, max_len: int = 100, suffix: str = "...") -> str:
- """Truncate a string to max length, adding suffix if truncated."""
- if len(s) <= max_len:
- return s
- return s[: max_len - len(suffix)] + suffix
+def current_time_str() -> str:
+ """Human-readable current time with weekday and timezone, e.g. '2026-03-15 22:30 (Saturday) (CST)'."""
+ now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
+ tz = time.strftime("%Z") or "UTC"
+ return f"{now} ({tz})"
+_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
+
def safe_filename(name: str) -> str:
- """Convert a string to a safe filename."""
- # Replace unsafe characters
- unsafe = '<>:"/\\|?*'
- for char in unsafe:
- name = name.replace(char, "_")
- return name.strip()
+ """Replace unsafe path characters with underscores."""
+ return _UNSAFE_CHARS.sub("_", name).strip()
-def parse_session_key(key: str) -> tuple[str, str]:
+def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
- Parse a session key into channel and chat_id.
-
+ Split content into chunks within max_len, preferring line breaks.
+
Args:
- key: Session key in format "channel:chat_id"
-
+ content: The text content to split.
+ max_len: Maximum length per chunk (default 2000 for Discord compatibility).
+
Returns:
- Tuple of (channel, chat_id)
+ List of message chunks, each within max_len.
"""
- parts = key.split(":", 1)
- if len(parts) != 2:
- raise ValueError(f"Invalid session key: {key}")
- return parts[0], parts[1]
+ if not content:
+ return []
+ if len(content) <= max_len:
+ return [content]
+ chunks: list[str] = []
+ while content:
+ if len(content) <= max_len:
+ chunks.append(content)
+ break
+ cut = content[:max_len]
+ # Try to break at newline first, then space, then hard break
+ pos = cut.rfind('\n')
+ if pos <= 0:
+ pos = cut.rfind(' ')
+ if pos <= 0:
+ pos = max_len
+ chunks.append(content[:pos])
+ content = content[pos:].lstrip()
+ return chunks
+
+
+def build_assistant_message(
+ content: str | None,
+ tool_calls: list[dict[str, Any]] | None = None,
+ reasoning_content: str | None = None,
+ thinking_blocks: list[dict] | None = None,
+) -> dict[str, Any]:
+ """Build a provider-safe assistant message with optional reasoning fields."""
+ msg: dict[str, Any] = {"role": "assistant", "content": content}
+ if tool_calls:
+ msg["tool_calls"] = tool_calls
+ if reasoning_content is not None:
+ msg["reasoning_content"] = reasoning_content
+ if thinking_blocks:
+ msg["thinking_blocks"] = thinking_blocks
+ return msg
+
+
+def estimate_prompt_tokens(
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+) -> int:
+ """Estimate prompt tokens with tiktoken."""
+ try:
+ enc = tiktoken.get_encoding("cl100k_base")
+ parts: list[str] = []
+ for msg in messages:
+ content = msg.get("content")
+ if isinstance(content, str):
+ parts.append(content)
+ elif isinstance(content, list):
+ for part in content:
+ if isinstance(part, dict) and part.get("type") == "text":
+ txt = part.get("text", "")
+ if txt:
+ parts.append(txt)
+ if tools:
+ parts.append(json.dumps(tools, ensure_ascii=False))
+ return len(enc.encode("\n".join(parts)))
+ except Exception:
+ return 0
+
+
+def estimate_message_tokens(message: dict[str, Any]) -> int:
+ """Estimate prompt tokens contributed by one persisted message."""
+ content = message.get("content")
+ parts: list[str] = []
+ if isinstance(content, str):
+ parts.append(content)
+ elif isinstance(content, list):
+ for part in content:
+ if isinstance(part, dict) and part.get("type") == "text":
+ text = part.get("text", "")
+ if text:
+ parts.append(text)
+ else:
+ parts.append(json.dumps(part, ensure_ascii=False))
+ elif content is not None:
+ parts.append(json.dumps(content, ensure_ascii=False))
+
+ for key in ("name", "tool_call_id"):
+ value = message.get(key)
+ if isinstance(value, str) and value:
+ parts.append(value)
+ if message.get("tool_calls"):
+ parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
+
+ payload = "\n".join(parts)
+ if not payload:
+ return 1
+ try:
+ enc = tiktoken.get_encoding("cl100k_base")
+ return max(1, len(enc.encode(payload)))
+ except Exception:
+ return max(1, len(payload) // 4)
+
+
+def estimate_prompt_tokens_chain(
+ provider: Any,
+ model: str | None,
+ messages: list[dict[str, Any]],
+ tools: list[dict[str, Any]] | None = None,
+) -> tuple[int, str]:
+ """Estimate prompt tokens via provider counter first, then tiktoken fallback."""
+ provider_counter = getattr(provider, "estimate_prompt_tokens", None)
+ if callable(provider_counter):
+ try:
+ tokens, source = provider_counter(messages, tools, model)
+ if isinstance(tokens, (int, float)) and tokens > 0:
+ return int(tokens), str(source or "provider_counter")
+ except Exception:
+ pass
+
+ estimated = estimate_prompt_tokens(messages, tools)
+ if estimated > 0:
+ return int(estimated), "tiktoken"
+ return 0, "none"
+
+
+def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
+ """Sync bundled templates to workspace. Only creates missing files."""
+ from importlib.resources import files as pkg_files
+ try:
+ tpl = pkg_files("nanobot") / "templates"
+ except Exception:
+ return []
+ if not tpl.is_dir():
+ return []
+
+ added: list[str] = []
+
+ def _write(src, dest: Path):
+ if dest.exists():
+ return
+ dest.parent.mkdir(parents=True, exist_ok=True)
+ dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
+ added.append(str(dest.relative_to(workspace)))
+
+ for item in tpl.iterdir():
+ if item.name.endswith(".md") and not item.name.startswith("."):
+ _write(item, workspace / item.name)
+ _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
+ _write(None, workspace / "memory" / "HISTORY.md")
+ (workspace / "skills").mkdir(exist_ok=True)
+
+ if added and not silent:
+ from rich.console import Console
+ for name in added:
+ Console().print(f" [dim]Created {name}[/dim]")
+ return added
diff --git a/pyproject.toml b/pyproject.toml
index d15d18a..25ef590 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,8 @@
[project]
name = "nanobot-ai"
-version = "0.1.4.post2"
+version = "0.1.4.post5"
description = "A lightweight personal AI assistant framework"
+readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
@@ -18,19 +19,20 @@ classifiers = [
dependencies = [
"typer>=0.20.0,<1.0.0",
- "litellm>=1.81.5,<2.0.0",
+ "litellm>=1.82.1,<2.0.0",
"pydantic>=2.12.0,<3.0.0",
"pydantic-settings>=2.12.0,<3.0.0",
"websockets>=16.0,<17.0",
"websocket-client>=1.9.0,<2.0.0",
"httpx>=0.28.0,<1.0.0",
+ "ddgs>=9.5.5,<10.0.0",
"oauth-cli-kit>=0.1.3,<1.0.0",
"loguru>=0.7.3,<1.0.0",
"readability-lxml>=0.8.4,<1.0.0",
"rich>=14.0.0,<15.0.0",
"croniter>=6.0.0,<7.0.0",
"dingtalk-stream>=0.24.0,<1.0.0",
- "python-telegram-bot[socks]>=22.0,<23.0",
+ "python-telegram-bot[socks]>=22.6,<23.0",
"lark-oapi>=1.5.0,<2.0.0",
"socksio>=1.0.0,<2.0.0",
"python-socketio>=5.16.0,<6.0.0",
@@ -42,13 +44,30 @@ dependencies = [
"prompt-toolkit>=3.0.50,<4.0.0",
"mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0",
+ "chardet>=3.0.2,<6.0.0",
+ "openai>=2.8.0",
+ "tiktoken>=0.12.0,<1.0.0",
]
[project.optional-dependencies]
+wecom = [
+ "wecom-aibot-sdk-python>=0.1.5",
+]
+matrix = [
+ "matrix-nio[e2e]>=0.25.2",
+ "mistune>=3.0.0,<4.0.0",
+ "nh3>=0.2.17,<1.0.0",
+]
+langsmith = [
+ "langsmith>=0.1.0",
+]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"ruff>=0.1.0",
+ "matrix-nio[e2e]>=0.25.2",
+ "mistune>=3.0.0,<4.0.0",
+ "nh3>=0.2.17,<1.0.0",
]
[project.scripts]
@@ -58,13 +77,9 @@ nanobot = "nanobot.cli.commands:app"
requires = ["hatchling"]
build-backend = "hatchling.build"
-[tool.hatch.build.targets.wheel]
-packages = ["nanobot"]
+[tool.hatch.metadata]
+allow-direct-references = true
-[tool.hatch.build.targets.wheel.sources]
-"nanobot" = "nanobot"
-
-# Include non-Python files in skills and templates
[tool.hatch.build]
include = [
"nanobot/**/*.py",
@@ -73,6 +88,15 @@ include = [
"nanobot/skills/**/*.sh",
]
+[tool.hatch.build.targets.wheel]
+packages = ["nanobot"]
+
+[tool.hatch.build.targets.wheel.sources]
+"nanobot" = "nanobot"
+
+[tool.hatch.build.targets.wheel.force-include]
+"bridge" = "nanobot/bridge"
+
[tool.hatch.build.targets.sdist]
include = [
"nanobot/",
@@ -81,9 +105,6 @@ include = [
"LICENSE",
]
-[tool.hatch.build.targets.wheel.force-include]
-"bridge" = "nanobot/bridge"
-
[tool.ruff]
line-length = 100
target-version = "py311"
diff --git a/tests/test_azure_openai_provider.py b/tests/test_azure_openai_provider.py
new file mode 100644
index 0000000..77f36d4
--- /dev/null
+++ b/tests/test_azure_openai_provider.py
@@ -0,0 +1,399 @@
+"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
+
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+from nanobot.providers.base import LLMResponse
+
+
+def test_azure_openai_provider_init():
+ """Test AzureOpenAIProvider initialization without deployment_name."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+
+ assert provider.api_key == "test-key"
+ assert provider.api_base == "https://test-resource.openai.azure.com/"
+ assert provider.default_model == "gpt-4o-deployment"
+ assert provider.api_version == "2024-10-21"
+
+
+def test_azure_openai_provider_init_validation():
+ """Test AzureOpenAIProvider initialization validation."""
+ # Missing api_key
+ with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
+ AzureOpenAIProvider(api_key="", api_base="https://test.com")
+
+ # Missing api_base
+ with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
+ AzureOpenAIProvider(api_key="test", api_base="")
+
+
+def test_build_chat_url():
+ """Test Azure OpenAI URL building with different deployment names."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Test various deployment names
+ test_cases = [
+ ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
+ ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
+ ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
+ ]
+
+ for deployment_name, expected_url in test_cases:
+ url = provider._build_chat_url(deployment_name)
+ assert url == expected_url
+
+
+def test_build_chat_url_api_base_without_slash():
+ """Test URL building when api_base doesn't end with slash."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com", # No trailing slash
+ default_model="gpt-4o",
+ )
+
+ url = provider._build_chat_url("test-deployment")
+ expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
+ assert url == expected
+
+
+def test_build_headers():
+ """Test Azure OpenAI header building with api-key authentication."""
+ provider = AzureOpenAIProvider(
+ api_key="test-api-key-123",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ headers = provider._build_headers()
+ assert headers["Content-Type"] == "application/json"
+ assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
+ assert "x-session-affinity" in headers
+
+
+def test_prepare_request_payload():
+ """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ messages = [{"role": "user", "content": "Hello"}]
+ payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
+
+ assert payload["messages"] == messages
+ assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
+ assert payload["temperature"] == 0.8
+ assert "tools" not in payload
+
+ # Test with tools
+ tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
+ payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
+ assert payload_with_tools["tools"] == tools
+ assert payload_with_tools["tool_choice"] == "auto"
+
+ # Test with reasoning_effort
+ payload_with_reasoning = provider._prepare_request_payload(
+ "gpt-5-chat", messages, reasoning_effort="medium"
+ )
+ assert payload_with_reasoning["reasoning_effort"] == "medium"
+ assert "temperature" not in payload_with_reasoning
+
+
+def test_prepare_request_payload_sanitizes_messages():
+ """Test Azure payload strips non-standard message keys before sending."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ messages = [
+ {
+ "role": "assistant",
+ "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
+ "reasoning_content": "hidden chain-of-thought",
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_123",
+ "name": "x",
+ "content": "ok",
+ "extra_field": "should be removed",
+ },
+ ]
+
+ payload = provider._prepare_request_payload("gpt-4o", messages)
+
+ assert payload["messages"] == [
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_123",
+ "name": "x",
+ "content": "ok",
+ },
+ ]
+
+
+@pytest.mark.asyncio
+async def test_chat_success():
+ """Test successful chat request using model as deployment name."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+
+ # Mock response data
+ mock_response_data = {
+ "choices": [{
+ "message": {
+ "content": "Hello! How can I help you today?",
+ "role": "assistant"
+ },
+ "finish_reason": "stop"
+ }],
+ "usage": {
+ "prompt_tokens": 12,
+ "completion_tokens": 18,
+ "total_tokens": 30
+ }
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ # Test with specific model (deployment name)
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages, model="custom-deployment")
+
+ assert isinstance(result, LLMResponse)
+ assert result.content == "Hello! How can I help you today?"
+ assert result.finish_reason == "stop"
+ assert result.usage["prompt_tokens"] == 12
+ assert result.usage["completion_tokens"] == 18
+ assert result.usage["total_tokens"] == 30
+
+ # Verify URL was built with the provided model as deployment name
+ call_args = mock_context.post.call_args
+ expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
+ assert call_args[0][0] == expected_url
+
+
+@pytest.mark.asyncio
+async def test_chat_uses_default_model_when_no_model_provided():
+ """Test that chat uses default_model when no model is specified."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="default-deployment",
+ )
+
+ mock_response_data = {
+ "choices": [{
+ "message": {"content": "Response", "role": "assistant"},
+ "finish_reason": "stop"
+ }],
+ "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Test"}]
+ await provider.chat(messages) # No model specified
+
+ # Verify URL was built with default model as deployment name
+ call_args = mock_context.post.call_args
+ expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
+ assert call_args[0][0] == expected_url
+
+
+@pytest.mark.asyncio
+async def test_chat_with_tool_calls():
+ """Test chat request with tool calls in response."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Mock response with tool calls
+ mock_response_data = {
+ "choices": [{
+ "message": {
+ "content": None,
+ "role": "assistant",
+ "tool_calls": [{
+ "id": "call_12345",
+ "function": {
+ "name": "get_weather",
+ "arguments": '{"location": "San Francisco"}'
+ }
+ }]
+ },
+ "finish_reason": "tool_calls"
+ }],
+ "usage": {
+ "prompt_tokens": 20,
+ "completion_tokens": 15,
+ "total_tokens": 35
+ }
+ }
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 200
+ mock_response.json = Mock(return_value=mock_response_data)
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "What's the weather?"}]
+ tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
+ result = await provider.chat(messages, tools=tools, model="weather-model")
+
+ assert isinstance(result, LLMResponse)
+ assert result.content is None
+ assert result.finish_reason == "tool_calls"
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].name == "get_weather"
+ assert result.tool_calls[0].arguments == {"location": "San Francisco"}
+
+
+@pytest.mark.asyncio
+async def test_chat_api_error():
+ """Test chat request API error handling."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_response = AsyncMock()
+ mock_response.status_code = 401
+ mock_response.text = "Invalid authentication credentials"
+
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(return_value=mock_response)
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages)
+
+ assert isinstance(result, LLMResponse)
+ assert "Azure OpenAI API Error 401" in result.content
+ assert "Invalid authentication credentials" in result.content
+ assert result.finish_reason == "error"
+
+
+@pytest.mark.asyncio
+async def test_chat_connection_error():
+ """Test chat request connection error handling."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ with patch("httpx.AsyncClient") as mock_client:
+ mock_context = AsyncMock()
+ mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
+ mock_client.return_value.__aenter__.return_value = mock_context
+
+ messages = [{"role": "user", "content": "Hello"}]
+ result = await provider.chat(messages)
+
+ assert isinstance(result, LLMResponse)
+ assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
+ assert result.finish_reason == "error"
+
+
+def test_parse_response_malformed():
+ """Test response parsing with malformed data."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o",
+ )
+
+ # Test with missing choices
+ malformed_response = {"usage": {"prompt_tokens": 10}}
+ result = provider._parse_response(malformed_response)
+
+ assert isinstance(result, LLMResponse)
+ assert "Error parsing Azure OpenAI response" in result.content
+ assert result.finish_reason == "error"
+
+
+def test_get_default_model():
+ """Test get_default_model method."""
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="my-custom-deployment",
+ )
+
+ assert provider.get_default_model() == "my-custom-deployment"
+
+
+if __name__ == "__main__":
+ # Run basic tests
+ print("Running basic Azure OpenAI provider tests...")
+
+ # Test initialization
+ provider = AzureOpenAIProvider(
+ api_key="test-key",
+ api_base="https://test-resource.openai.azure.com",
+ default_model="gpt-4o-deployment",
+ )
+ print("✅ Provider initialization successful")
+
+ # Test URL building
+ url = provider._build_chat_url("my-deployment")
+ expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
+ assert url == expected
+ print("✅ URL building works correctly")
+
+ # Test headers
+ headers = provider._build_headers()
+ assert headers["api-key"] == "test-key"
+ assert headers["Content-Type"] == "application/json"
+ print("✅ Header building works correctly")
+
+ # Test payload preparation
+ messages = [{"role": "user", "content": "Test"}]
+ payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
+ assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
+ print("✅ Payload preparation works correctly")
+
+ print("✅ All basic tests passed! Updated test file is working correctly.")
\ No newline at end of file
diff --git a/tests/test_base_channel.py b/tests/test_base_channel.py
new file mode 100644
index 0000000..5d10d4e
--- /dev/null
+++ b/tests/test_base_channel.py
@@ -0,0 +1,25 @@
+from types import SimpleNamespace
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+
+
+class _DummyChannel(BaseChannel):
+ name = "dummy"
+
+ async def start(self) -> None:
+ return None
+
+ async def stop(self) -> None:
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ return None
+
+
+def test_is_allowed_requires_exact_match() -> None:
+ channel = _DummyChannel(SimpleNamespace(allow_from=["allow@email.com"]), MessageBus())
+
+ assert channel.is_allowed("allow@email.com") is True
+ assert channel.is_allowed("attacker|allow@email.com") is False
diff --git a/tests/test_channel_plugins.py b/tests/test_channel_plugins.py
new file mode 100644
index 0000000..e8a6d49
--- /dev/null
+++ b/tests/test_channel_plugins.py
@@ -0,0 +1,228 @@
+"""Tests for channel plugin discovery, merging, and config compatibility."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+import pytest
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.channels.manager import ChannelManager
+from nanobot.config.schema import ChannelsConfig
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+class _FakePlugin(BaseChannel):
+ name = "fakeplugin"
+ display_name = "Fake Plugin"
+
+ async def start(self) -> None:
+ pass
+
+ async def stop(self) -> None:
+ pass
+
+ async def send(self, msg: OutboundMessage) -> None:
+ pass
+
+
+class _FakeTelegram(BaseChannel):
+ """Plugin that tries to shadow built-in telegram."""
+ name = "telegram"
+ display_name = "Fake Telegram"
+
+ async def start(self) -> None:
+ pass
+
+ async def stop(self) -> None:
+ pass
+
+ async def send(self, msg: OutboundMessage) -> None:
+ pass
+
+
+def _make_entry_point(name: str, cls: type):
+ """Create a mock entry point that returns *cls* on load()."""
+ ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
+ return ep
+
+
+# ---------------------------------------------------------------------------
+# ChannelsConfig extra="allow"
+# ---------------------------------------------------------------------------
+
+def test_channels_config_accepts_unknown_keys():
+ cfg = ChannelsConfig.model_validate({
+ "myplugin": {"enabled": True, "token": "abc"},
+ })
+ extra = cfg.model_extra
+ assert extra is not None
+ assert extra["myplugin"]["enabled"] is True
+ assert extra["myplugin"]["token"] == "abc"
+
+
+def test_channels_config_getattr_returns_extra():
+ cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
+ section = getattr(cfg, "myplugin", None)
+ assert isinstance(section, dict)
+ assert section["enabled"] is True
+
+
+def test_channels_config_builtin_fields_removed():
+ """After decoupling, ChannelsConfig has no explicit channel fields."""
+ cfg = ChannelsConfig()
+ assert not hasattr(cfg, "telegram")
+ assert cfg.send_progress is True
+ assert cfg.send_tool_hints is False
+
+
+# ---------------------------------------------------------------------------
+# discover_plugins
+# ---------------------------------------------------------------------------
+
+_EP_TARGET = "importlib.metadata.entry_points"
+
+
+def test_discover_plugins_loads_entry_points():
+ from nanobot.channels.registry import discover_plugins
+
+ ep = _make_entry_point("line", _FakePlugin)
+ with patch(_EP_TARGET, return_value=[ep]):
+ result = discover_plugins()
+
+ assert "line" in result
+ assert result["line"] is _FakePlugin
+
+
+def test_discover_plugins_handles_load_error():
+ from nanobot.channels.registry import discover_plugins
+
+ def _boom():
+ raise RuntimeError("broken")
+
+ ep = SimpleNamespace(name="broken", load=_boom)
+ with patch(_EP_TARGET, return_value=[ep]):
+ result = discover_plugins()
+
+ assert "broken" not in result
+
+
+# ---------------------------------------------------------------------------
+# discover_all — merge & priority
+# ---------------------------------------------------------------------------
+
+def test_discover_all_includes_builtins():
+ from nanobot.channels.registry import discover_all, discover_channel_names
+
+ with patch(_EP_TARGET, return_value=[]):
+ result = discover_all()
+
+ # discover_all() only returns channels that are actually available (dependencies installed)
+ # discover_channel_names() returns all built-in channel names
+ # So we check that all actually loaded channels are in the result
+ for name in result:
+ assert name in discover_channel_names()
+
+
+def test_discover_all_includes_external_plugin():
+ from nanobot.channels.registry import discover_all
+
+ ep = _make_entry_point("line", _FakePlugin)
+ with patch(_EP_TARGET, return_value=[ep]):
+ result = discover_all()
+
+ assert "line" in result
+ assert result["line"] is _FakePlugin
+
+
+def test_discover_all_builtin_shadows_plugin():
+ from nanobot.channels.registry import discover_all
+
+ ep = _make_entry_point("telegram", _FakeTelegram)
+ with patch(_EP_TARGET, return_value=[ep]):
+ result = discover_all()
+
+ assert "telegram" in result
+ assert result["telegram"] is not _FakeTelegram
+
+
+# ---------------------------------------------------------------------------
+# Manager _init_channels with dict config (plugin scenario)
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_manager_loads_plugin_from_dict_config():
+ """ChannelManager should instantiate a plugin channel from a raw dict config."""
+ from nanobot.channels.manager import ChannelManager
+
+ fake_config = SimpleNamespace(
+ channels=ChannelsConfig.model_validate({
+ "fakeplugin": {"enabled": True, "allowFrom": ["*"]},
+ }),
+ providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
+ )
+
+ with patch(
+ "nanobot.channels.registry.discover_all",
+ return_value={"fakeplugin": _FakePlugin},
+ ):
+ mgr = ChannelManager.__new__(ChannelManager)
+ mgr.config = fake_config
+ mgr.bus = MessageBus()
+ mgr.channels = {}
+ mgr._dispatch_task = None
+ mgr._init_channels()
+
+ assert "fakeplugin" in mgr.channels
+ assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
+
+
+@pytest.mark.asyncio
+async def test_manager_skips_disabled_plugin():
+ fake_config = SimpleNamespace(
+ channels=ChannelsConfig.model_validate({
+ "fakeplugin": {"enabled": False},
+ }),
+ providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
+ )
+
+ with patch(
+ "nanobot.channels.registry.discover_all",
+ return_value={"fakeplugin": _FakePlugin},
+ ):
+ mgr = ChannelManager.__new__(ChannelManager)
+ mgr.config = fake_config
+ mgr.bus = MessageBus()
+ mgr.channels = {}
+ mgr._dispatch_task = None
+ mgr._init_channels()
+
+ assert "fakeplugin" not in mgr.channels
+
+
+# ---------------------------------------------------------------------------
+# Built-in channel default_config() and dict->Pydantic conversion
+# ---------------------------------------------------------------------------
+
+def test_builtin_channel_default_config():
+ """Built-in channels expose default_config() returning a dict with 'enabled': False."""
+ from nanobot.channels.telegram import TelegramChannel
+ cfg = TelegramChannel.default_config()
+ assert isinstance(cfg, dict)
+ assert cfg["enabled"] is False
+ assert "token" in cfg
+
+
+def test_builtin_channel_init_from_dict():
+ """Built-in channels accept a raw dict and convert to Pydantic internally."""
+ from nanobot.channels.telegram import TelegramChannel
+ bus = MessageBus()
+ ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
+ assert ch.config.token == "test-tok"
+ assert ch.config.allow_from == ["*"]
diff --git a/tests/test_cli_input.py b/tests/test_cli_input.py
index 9626120..e77bc13 100644
--- a/tests/test_cli_input.py
+++ b/tests/test_cli_input.py
@@ -1,5 +1,5 @@
import asyncio
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from prompt_toolkit.formatted_text import HTML
@@ -57,3 +57,57 @@ def test_init_prompt_session_creates_session():
_, kwargs = MockSession.call_args
assert kwargs["multiline"] is False
assert kwargs["enable_open_in_editor"] is False
+
+
+def test_thinking_spinner_pause_stops_and_restarts():
+ """Pause should stop the active spinner and restart it afterward."""
+ spinner = MagicMock()
+
+ with patch.object(commands.console, "status", return_value=spinner):
+ thinking = commands._ThinkingSpinner(enabled=True)
+ with thinking:
+ with thinking.pause():
+ pass
+
+ assert spinner.method_calls == [
+ call.start(),
+ call.stop(),
+ call.start(),
+ call.stop(),
+ ]
+
+
+def test_print_cli_progress_line_pauses_spinner_before_printing():
+ """CLI progress output should pause spinner to avoid garbled lines."""
+ order: list[str] = []
+ spinner = MagicMock()
+ spinner.start.side_effect = lambda: order.append("start")
+ spinner.stop.side_effect = lambda: order.append("stop")
+
+ with patch.object(commands.console, "status", return_value=spinner), \
+ patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
+ thinking = commands._ThinkingSpinner(enabled=True)
+ with thinking:
+ commands._print_cli_progress_line("tool running", thinking)
+
+ assert order == ["start", "stop", "print", "start", "stop"]
+
+
+@pytest.mark.asyncio
+async def test_print_interactive_progress_line_pauses_spinner_before_printing():
+ """Interactive progress output should also pause spinner cleanly."""
+ order: list[str] = []
+ spinner = MagicMock()
+ spinner.start.side_effect = lambda: order.append("start")
+ spinner.stop.side_effect = lambda: order.append("stop")
+
+ async def fake_print(_text: str) -> None:
+ order.append("print")
+
+ with patch.object(commands.console, "status", return_value=spinner), \
+ patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
+ thinking = commands._ThinkingSpinner(enabled=True)
+ with thinking:
+ await commands._print_interactive_progress_line("tool running", thinking)
+
+ assert order == ["start", "stop", "print", "start", "stop"]
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 044d113..b09c955 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -1,26 +1,37 @@
+import re
import shutil
from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from typer.testing import CliRunner
-from nanobot.cli.commands import app
+from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
+
+def _strip_ansi(text):
+ """Remove ANSI escape codes from text."""
+ ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
+ return ansi_escape.sub('', text)
+
runner = CliRunner()
+class _StopGateway(RuntimeError):
+ pass
+
+
@pytest.fixture
def mock_paths():
"""Mock config/workspace paths for test isolation."""
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
patch("nanobot.config.loader.save_config") as mock_sc, \
patch("nanobot.config.loader.load_config") as mock_lc, \
- patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:
+ patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
base_dir = Path("./test_onboard_data")
if base_dir.exists():
@@ -110,6 +121,64 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex"
+def test_config_matches_explicit_ollama_prefix_without_api_key():
+ config = Config()
+ config.agents.defaults.model = "ollama/llama3.2"
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
+ config = Config()
+ config.agents.defaults.provider = "ollama"
+ config.agents.defaults.model = "llama3.2"
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_auto_detects_ollama_from_local_api_base():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {"ollama": {"apiBase": "http://localhost:11434"}},
+ }
+ )
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {
+ "vllm": {"apiBase": "http://localhost:8000"},
+ "ollama": {"apiBase": "http://localhost:11434"},
+ },
+ }
+ )
+
+ assert config.get_provider_name() == "ollama"
+ assert config.get_api_base() == "http://localhost:11434"
+
+
+def test_config_falls_back_to_vllm_when_ollama_not_configured():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
+ "providers": {
+ "vllm": {"apiBase": "http://localhost:8000"},
+ },
+ }
+ )
+
+ assert config.get_provider_name() == "vllm"
+ assert config.get_api_base() == "http://localhost:8000"
+
+
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -128,3 +197,331 @@ def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
+
+
+def test_make_provider_passes_extra_headers_to_custom_provider():
+ config = Config.model_validate(
+ {
+ "agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
+ "providers": {
+ "custom": {
+ "apiKey": "test-key",
+ "apiBase": "https://example.com/v1",
+ "extraHeaders": {
+ "APP-Code": "demo-app",
+ "x-session-affinity": "sticky-session",
+ },
+ }
+ },
+ }
+ )
+
+ with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai:
+ _make_provider(config)
+
+ kwargs = mock_async_openai.call_args.kwargs
+ assert kwargs["api_key"] == "test-key"
+ assert kwargs["base_url"] == "https://example.com/v1"
+ assert kwargs["default_headers"]["APP-Code"] == "demo-app"
+ assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
+
+
+@pytest.fixture
+def mock_agent_runtime(tmp_path):
+ """Mock agent command dependencies for focused CLI tests."""
+ config = Config()
+ config.agents.defaults.workspace = str(tmp_path / "default-workspace")
+ cron_dir = tmp_path / "data" / "cron"
+
+ with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
+ patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
+ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
+ patch("nanobot.cli.commands._make_provider", return_value=object()), \
+ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
+ patch("nanobot.bus.queue.MessageBus"), \
+ patch("nanobot.cron.service.CronService"), \
+ patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
+
+ agent_loop = MagicMock()
+ agent_loop.channels_config = None
+ agent_loop.process_direct = AsyncMock(return_value="mock-response")
+ agent_loop.close_mcp = AsyncMock(return_value=None)
+ mock_agent_loop_cls.return_value = agent_loop
+
+ yield {
+ "config": config,
+ "load_config": mock_load_config,
+ "sync_templates": mock_sync_templates,
+ "agent_loop_cls": mock_agent_loop_cls,
+ "agent_loop": agent_loop,
+ "print_response": mock_print_response,
+ }
+
+
+def test_agent_help_shows_workspace_and_config_options():
+ result = runner.invoke(app, ["agent", "--help"])
+
+ assert result.exit_code == 0
+ stripped_output = _strip_ansi(result.stdout)
+ assert "--workspace" in stripped_output
+ assert "-w" in stripped_output
+ assert "--config" in stripped_output
+ assert "-c" in stripped_output
+
+
+def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
+ result = runner.invoke(app, ["agent", "-m", "hello"])
+
+ assert result.exit_code == 0
+ assert mock_agent_runtime["load_config"].call_args.args == (None,)
+ assert mock_agent_runtime["sync_templates"].call_args.args == (
+ mock_agent_runtime["config"].workspace_path,
+ )
+ assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == (
+ mock_agent_runtime["config"].workspace_path
+ )
+ mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
+ mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
+
+
+def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
+ config_path = tmp_path / "agent-config.json"
+ config_path.write_text("{}")
+
+ result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_path)])
+
+ assert result.exit_code == 0
+ assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
+
+
+def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ seen: dict[str, Path] = {}
+
+ monkeypatch.setattr(
+ "nanobot.config.loader.set_config_path",
+ lambda path: seen.__setitem__("config_path", path),
+ )
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
+ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
+ monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
+ monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
+ monkeypatch.setattr("nanobot.cron.service.CronService", lambda _store: object())
+
+ class _FakeAgentLoop:
+ def __init__(self, *args, **kwargs) -> None:
+ pass
+
+ async def process_direct(self, *_args, **_kwargs) -> str:
+ return "ok"
+
+ async def close_mcp(self) -> None:
+ return None
+
+ monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
+ monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
+
+ result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
+
+ assert result.exit_code == 0
+ assert seen["config_path"] == config_file.resolve()
+
+
+def test_agent_overrides_workspace_path(mock_agent_runtime):
+ workspace_path = Path("/tmp/agent-workspace")
+
+ result = runner.invoke(app, ["agent", "-m", "hello", "-w", str(workspace_path)])
+
+ assert result.exit_code == 0
+ assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
+ assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
+ assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
+
+
+def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, tmp_path: Path):
+ config_path = tmp_path / "agent-config.json"
+ config_path.write_text("{}")
+ workspace_path = Path("/tmp/agent-workspace")
+
+ result = runner.invoke(
+ app,
+ ["agent", "-m", "hello", "-c", str(config_path), "-w", str(workspace_path)],
+ )
+
+ assert result.exit_code == 0
+ assert mock_agent_runtime["load_config"].call_args.args == (config_path.resolve(),)
+ assert mock_agent_runtime["config"].agents.defaults.workspace == str(workspace_path)
+ assert mock_agent_runtime["sync_templates"].call_args.args == (workspace_path,)
+ assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
+
+
+def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
+ mock_agent_runtime["config"].agents.defaults.memory_window = 100
+
+ result = runner.invoke(app, ["agent", "-m", "hello"])
+
+ assert result.exit_code == 0
+ assert "memoryWindow" in result.stdout
+ assert "contextWindowTokens" in result.stdout
+
+
+def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.agents.defaults.workspace = str(tmp_path / "config-workspace")
+ seen: dict[str, Path] = {}
+
+ monkeypatch.setattr(
+ "nanobot.config.loader.set_config_path",
+ lambda path: seen.__setitem__("config_path", path),
+ )
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr(
+ "nanobot.cli.commands.sync_workspace_templates",
+ lambda path: seen.__setitem__("workspace", path),
+ )
+ monkeypatch.setattr(
+ "nanobot.cli.commands._make_provider",
+ lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ )
+
+ result = runner.invoke(app, ["gateway", "--config", str(config_file)])
+
+ assert isinstance(result.exception, _StopGateway)
+ assert seen["config_path"] == config_file.resolve()
+ assert seen["workspace"] == Path(config.agents.defaults.workspace)
+
+
+def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.agents.defaults.workspace = str(tmp_path / "config-workspace")
+ override = tmp_path / "override-workspace"
+ seen: dict[str, Path] = {}
+
+ monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr(
+ "nanobot.cli.commands.sync_workspace_templates",
+ lambda path: seen.__setitem__("workspace", path),
+ )
+ monkeypatch.setattr(
+ "nanobot.cli.commands._make_provider",
+ lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ )
+
+ result = runner.invoke(
+ app,
+ ["gateway", "--config", str(config_file), "--workspace", str(override)],
+ )
+
+ assert isinstance(result.exception, _StopGateway)
+ assert seen["workspace"] == override
+ assert config.workspace_path == override
+
+
+def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.agents.defaults.memory_window = 100
+
+ monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
+ monkeypatch.setattr(
+ "nanobot.cli.commands._make_provider",
+ lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ )
+
+ result = runner.invoke(app, ["gateway", "--config", str(config_file)])
+
+ assert isinstance(result.exception, _StopGateway)
+ assert "memoryWindow" in result.stdout
+ assert "contextWindowTokens" in result.stdout
+
+def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.agents.defaults.workspace = str(tmp_path / "config-workspace")
+ seen: dict[str, Path] = {}
+
+ monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: config_file.parent / "cron")
+ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
+ monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
+ monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
+ monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
+
+ class _StopCron:
+ def __init__(self, store_path: Path) -> None:
+ seen["cron_store"] = store_path
+ raise _StopGateway("stop")
+
+ monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
+
+ result = runner.invoke(app, ["gateway", "--config", str(config_file)])
+
+ assert isinstance(result.exception, _StopGateway)
+ assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
+
+
+def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.gateway.port = 18791
+
+ monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
+ monkeypatch.setattr(
+ "nanobot.cli.commands._make_provider",
+ lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ )
+
+ result = runner.invoke(app, ["gateway", "--config", str(config_file)])
+
+ assert isinstance(result.exception, _StopGateway)
+ assert "port 18791" in result.stdout
+
+
+def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance" / "config.json"
+ config_file.parent.mkdir(parents=True)
+ config_file.write_text("{}")
+
+ config = Config()
+ config.gateway.port = 18791
+
+ monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
+ monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
+ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
+ monkeypatch.setattr(
+ "nanobot.cli.commands._make_provider",
+ lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ )
+
+ result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
+
+ assert isinstance(result.exception, _StopGateway)
+ assert "port 18792" in result.stdout
diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py
new file mode 100644
index 0000000..f800fb5
--- /dev/null
+++ b/tests/test_config_migration.py
@@ -0,0 +1,132 @@
+import json
+from types import SimpleNamespace
+
+from typer.testing import CliRunner
+
+from nanobot.cli.commands import app
+from nanobot.config.loader import load_config, save_config
+
+runner = CliRunner()
+
+
+def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
+ config_path = tmp_path / "config.json"
+ config_path.write_text(
+ json.dumps(
+ {
+ "agents": {
+ "defaults": {
+ "maxTokens": 1234,
+ "memoryWindow": 42,
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ config = load_config(config_path)
+
+ assert config.agents.defaults.max_tokens == 1234
+ assert config.agents.defaults.context_window_tokens == 65_536
+ assert config.agents.defaults.should_warn_deprecated_memory_window is True
+
+
+def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
+ config_path = tmp_path / "config.json"
+ config_path.write_text(
+ json.dumps(
+ {
+ "agents": {
+ "defaults": {
+ "maxTokens": 2222,
+ "memoryWindow": 30,
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ config = load_config(config_path)
+ save_config(config, config_path)
+ saved = json.loads(config_path.read_text(encoding="utf-8"))
+ defaults = saved["agents"]["defaults"]
+
+ assert defaults["maxTokens"] == 2222
+ assert defaults["contextWindowTokens"] == 65_536
+ assert "memoryWindow" not in defaults
+
+
+def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
+ config_path = tmp_path / "config.json"
+ workspace = tmp_path / "workspace"
+ config_path.write_text(
+ json.dumps(
+ {
+ "agents": {
+ "defaults": {
+ "maxTokens": 3333,
+ "memoryWindow": 50,
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
+ monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
+
+ result = runner.invoke(app, ["onboard"], input="n\n")
+
+ assert result.exit_code == 0
+ assert "contextWindowTokens" in result.stdout
+ saved = json.loads(config_path.read_text(encoding="utf-8"))
+ defaults = saved["agents"]["defaults"]
+ assert defaults["maxTokens"] == 3333
+ assert defaults["contextWindowTokens"] == 65_536
+ assert "memoryWindow" not in defaults
+
+
+def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
+ config_path = tmp_path / "config.json"
+ workspace = tmp_path / "workspace"
+ config_path.write_text(
+ json.dumps(
+ {
+ "channels": {
+ "qq": {
+ "enabled": False,
+ "appId": "",
+ "secret": "",
+ "allowFrom": [],
+ }
+ }
+ }
+ ),
+ encoding="utf-8",
+ )
+
+ monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
+ monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
+ monkeypatch.setattr(
+ "nanobot.channels.registry.discover_all",
+ lambda: {
+ "qq": SimpleNamespace(
+ default_config=lambda: {
+ "enabled": False,
+ "appId": "",
+ "secret": "",
+ "allowFrom": [],
+ "msgFormat": "plain",
+ }
+ )
+ },
+ )
+
+ result = runner.invoke(app, ["onboard"], input="n\n")
+
+ assert result.exit_code == 0
+ saved = json.loads(config_path.read_text(encoding="utf-8"))
+ assert saved["channels"]["qq"]["msgFormat"] == "plain"
diff --git a/tests/test_config_paths.py b/tests/test_config_paths.py
new file mode 100644
index 0000000..473a6c8
--- /dev/null
+++ b/tests/test_config_paths.py
@@ -0,0 +1,42 @@
+from pathlib import Path
+
+from nanobot.config.paths import (
+ get_bridge_install_dir,
+ get_cli_history_path,
+ get_cron_dir,
+ get_data_dir,
+ get_legacy_sessions_dir,
+ get_logs_dir,
+ get_media_dir,
+ get_runtime_subdir,
+ get_workspace_path,
+)
+
+
+def test_runtime_dirs_follow_config_path(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance-a" / "config.json"
+ monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
+
+ assert get_data_dir() == config_file.parent
+ assert get_runtime_subdir("cron") == config_file.parent / "cron"
+ assert get_cron_dir() == config_file.parent / "cron"
+ assert get_logs_dir() == config_file.parent / "logs"
+
+
+def test_media_dir_supports_channel_namespace(monkeypatch, tmp_path: Path) -> None:
+ config_file = tmp_path / "instance-b" / "config.json"
+ monkeypatch.setattr("nanobot.config.paths.get_config_path", lambda: config_file)
+
+ assert get_media_dir() == config_file.parent / "media"
+ assert get_media_dir("telegram") == config_file.parent / "media" / "telegram"
+
+
+def test_shared_and_legacy_paths_remain_global() -> None:
+ assert get_cli_history_path() == Path.home() / ".nanobot" / "history" / "cli_history"
+ assert get_bridge_install_dir() == Path.home() / ".nanobot" / "bridge"
+ assert get_legacy_sessions_dir() == Path.home() / ".nanobot" / "sessions"
+
+
+def test_workspace_path_is_explicitly_resolved() -> None:
+ assert get_workspace_path() == Path.home() / ".nanobot" / "workspace"
+ assert get_workspace_path("~/custom-workspace") == Path.home() / "custom-workspace"
diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py
index 323519e..21e1e78 100644
--- a/tests/test_consolidate_offset.py
+++ b/tests/test_consolidate_offset.py
@@ -480,349 +480,140 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34)
-class TestConsolidationDeduplicationGuard:
- """Test that consolidation tasks are deduplicated and serialized."""
+class TestNewCommandArchival:
+ """Test /new archival behavior with the simplified consolidation flow."""
- @pytest.mark.asyncio
- async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
- """Concurrent messages above memory_window spawn only one consolidation task."""
+ @staticmethod
+ def _make_loop(tmp_path: Path):
from nanobot.agent.loop import AgentLoop
- from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
+ provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
+ bus=bus,
+ provider=provider,
+ workspace=tmp_path,
+ model="test-model",
+ context_window_tokens=1,
)
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
+ loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
-
- session = loop.sessions.get_or_create("cli:test")
- for i in range(15):
- session.add_message("user", f"msg{i}")
- session.add_message("assistant", f"resp{i}")
- loop.sessions.save(session)
-
- consolidation_calls = 0
-
- async def _fake_consolidate(_session, archive_all: bool = False) -> None:
- nonlocal consolidation_calls
- consolidation_calls += 1
- await asyncio.sleep(0.05)
-
- loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
-
- msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
- await loop._process_message(msg)
- await loop._process_message(msg)
- await asyncio.sleep(0.1)
-
- assert consolidation_calls == 1, (
- f"Expected exactly 1 consolidation, got {consolidation_calls}"
- )
+ return loop
@pytest.mark.asyncio
- async def test_new_command_guard_prevents_concurrent_consolidation(
- self, tmp_path: Path
- ) -> None:
- """/new command does not run consolidation concurrently with in-flight consolidation."""
- from nanobot.agent.loop import AgentLoop
+ async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None:
+ """/new clears session immediately; archive_messages retries until raw dump."""
from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
-
- session = loop.sessions.get_or_create("cli:test")
- for i in range(15):
- session.add_message("user", f"msg{i}")
- session.add_message("assistant", f"resp{i}")
- loop.sessions.save(session)
-
- consolidation_calls = 0
- active = 0
- max_active = 0
-
- async def _fake_consolidate(_session, archive_all: bool = False) -> None:
- nonlocal consolidation_calls, active, max_active
- consolidation_calls += 1
- active += 1
- max_active = max(max_active, active)
- await asyncio.sleep(0.05)
- active -= 1
-
- loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
-
- msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
- await loop._process_message(msg)
-
- new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
- await loop._process_message(new_msg)
- await asyncio.sleep(0.1)
-
- assert consolidation_calls == 2, (
- f"Expected normal + /new consolidations, got {consolidation_calls}"
- )
- assert max_active == 1, (
- f"Expected serialized consolidation, observed concurrency={max_active}"
- )
-
- @pytest.mark.asyncio
- async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
- """create_task results are tracked in _consolidation_tasks while in flight."""
- from nanobot.agent.loop import AgentLoop
- from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
-
- session = loop.sessions.get_or_create("cli:test")
- for i in range(15):
- session.add_message("user", f"msg{i}")
- session.add_message("assistant", f"resp{i}")
- loop.sessions.save(session)
-
- started = asyncio.Event()
-
- async def _slow_consolidate(_session, archive_all: bool = False) -> None:
- started.set()
- await asyncio.sleep(0.1)
-
- loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
-
- msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
- await loop._process_message(msg)
-
- await started.wait()
- assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
-
- await asyncio.sleep(0.15)
- assert len(loop._consolidation_tasks) == 0, (
- "Task reference must be removed after completion"
- )
-
- @pytest.mark.asyncio
- async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
- self, tmp_path: Path
- ) -> None:
- """/new waits for in-flight consolidation and archives before clear."""
- from nanobot.agent.loop import AgentLoop
- from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
-
- session = loop.sessions.get_or_create("cli:test")
- for i in range(15):
- session.add_message("user", f"msg{i}")
- session.add_message("assistant", f"resp{i}")
- loop.sessions.save(session)
-
- started = asyncio.Event()
- release = asyncio.Event()
- archived_count = 0
-
- async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
- nonlocal archived_count
- if archive_all:
- archived_count = len(sess.messages)
- return True
- started.set()
- await release.wait()
- return True
-
- loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
-
- msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
- await loop._process_message(msg)
- await started.wait()
-
- new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
- pending_new = asyncio.create_task(loop._process_message(new_msg))
-
- await asyncio.sleep(0.02)
- assert not pending_new.done(), "/new should wait while consolidation is in-flight"
-
- release.set()
- response = await pending_new
- assert response is not None
- assert "new session started" in response.content.lower()
- assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
-
- session_after = loop.sessions.get_or_create("cli:test")
- assert session_after.messages == [], "Session should be cleared after successful archival"
-
- @pytest.mark.asyncio
- async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
- """/new must keep session data if archive step reports failure."""
- from nanobot.agent.loop import AgentLoop
- from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
+ loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
- before_count = len(session.messages)
- async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
- if archive_all:
- return False
- return True
+ call_count = 0
- loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
+ async def _failing_consolidate(_messages) -> bool:
+ nonlocal call_count
+ call_count += 1
+ return False
+
+ loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
- assert "failed" in response.content.lower()
+ assert "new session started" in response.content.lower()
+
session_after = loop.sessions.get_or_create("cli:test")
- assert len(session_after.messages) == before_count, (
- "Session must remain intact when /new archival fails"
- )
+ assert len(session_after.messages) == 0
+
+ await loop.close_mcp()
+ assert call_count == 3 # retried up to raw-archive threshold
@pytest.mark.asyncio
- async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
- self, tmp_path: Path
- ) -> None:
- """/new should archive only messages not yet consolidated by prior task."""
- from nanobot.agent.loop import AgentLoop
+ async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
+ loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
+ session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session)
- started = asyncio.Event()
- release = asyncio.Event()
archived_count = -1
- async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
+ async def _fake_consolidate(messages) -> bool:
nonlocal archived_count
- if archive_all:
- archived_count = len(sess.messages)
- return True
-
- started.set()
- await release.wait()
- sess.last_consolidated = len(sess.messages) - 3
+ archived_count = len(messages)
return True
- loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
-
- msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
- await loop._process_message(msg)
- await started.wait()
+ loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
- pending_new = asyncio.create_task(loop._process_message(new_msg))
- await asyncio.sleep(0.02)
- assert not pending_new.done()
-
- release.set()
- response = await pending_new
+ response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
- assert archived_count == 3, (
- f"Expected only unconsolidated tail to archive, got {archived_count}"
- )
+
+ await loop.close_mcp()
+ assert archived_count == 3
@pytest.mark.asyncio
- async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
- self, tmp_path: Path
- ) -> None:
- """/new should remove lock entry for fully invalidated session key."""
- from nanobot.agent.loop import AgentLoop
+ async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.base import LLMResponse
-
- bus = MessageBus()
- provider = MagicMock()
- provider.get_default_model.return_value = "test-model"
- loop = AgentLoop(
- bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
- )
-
- loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
- loop.tools.get_definitions = MagicMock(return_value=[])
+ loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
- # Ensure lock exists before /new.
- _ = loop._get_consolidation_lock(session.key)
- assert session.key in loop._consolidation_locks
-
- async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
+ async def _ok_consolidate(_messages) -> bool:
return True
- loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
+ loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
- assert session.key not in loop._consolidation_locks
+ assert loop.sessions.get_or_create("cli:test").messages == []
+
+ @pytest.mark.asyncio
+ async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
+ """close_mcp waits for background tasks to complete."""
+ from nanobot.bus.events import InboundMessage
+
+ loop = self._make_loop(tmp_path)
+ session = loop.sessions.get_or_create("cli:test")
+ for i in range(3):
+ session.add_message("user", f"msg{i}")
+ session.add_message("assistant", f"resp{i}")
+ loop.sessions.save(session)
+
+ archived = asyncio.Event()
+
+ async def _slow_consolidate(_messages) -> bool:
+ await asyncio.sleep(0.1)
+ archived.set()
+ return True
+
+ loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign]
+
+ new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
+ await loop._process_message(new_msg)
+
+ assert not archived.is_set()
+ await loop.close_mcp()
+ assert archived.is_set()
diff --git a/tests/test_context_prompt_cache.py b/tests/test_context_prompt_cache.py
index 8e2333c..6eb4b4f 100644
--- a/tests/test_context_prompt_cache.py
+++ b/tests/test_context_prompt_cache.py
@@ -3,6 +3,7 @@
from __future__ import annotations
from datetime import datetime as real_datetime
+from importlib.resources import files as pkg_files
from pathlib import Path
import datetime as datetime_module
@@ -23,6 +24,13 @@ def _make_workspace(tmp_path: Path) -> Path:
return workspace
+def test_bootstrap_files_are_backed_by_templates() -> None:
+ template_dir = pkg_files("nanobot") / "templates"
+
+ for filename in ContextBuilder.BOOTSTRAP_FILES:
+ assert (template_dir / filename).is_file(), f"missing bootstrap template: {filename}"
+
+
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
"""System prompt should not change just because wall clock minute changes."""
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
@@ -39,8 +47,8 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
assert prompt1 == prompt2
-def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
- """Dynamic runtime details should be added at the tail user message, not system."""
+def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
+ """Runtime metadata should be merged with the user message."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
@@ -54,10 +62,12 @@ def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
assert messages[0]["role"] == "system"
assert "## Current Session" not in messages[0]["content"]
+ # Runtime context is now merged with user message into a single message
assert messages[-1]["role"] == "user"
user_content = messages[-1]["content"]
assert isinstance(user_content, str)
- assert "Return exactly: OK" in user_content
+ assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
assert "Current Time:" in user_content
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
+ assert "Return exactly: OK" in user_content
diff --git a/tests/test_cron_commands.py b/tests/test_cron_commands.py
deleted file mode 100644
index bce1ef5..0000000
--- a/tests/test_cron_commands.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from typer.testing import CliRunner
-
-from nanobot.cli.commands import app
-
-runner = CliRunner()
-
-
-def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
- monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
-
- result = runner.invoke(
- app,
- [
- "cron",
- "add",
- "--name",
- "demo",
- "--message",
- "hello",
- "--cron",
- "0 9 * * *",
- "--tz",
- "America/Vancovuer",
- ],
- )
-
- assert result.exit_code == 1
- assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
- assert not (tmp_path / "cron" / "jobs.json").exists()
diff --git a/tests/test_cron_service.py b/tests/test_cron_service.py
index 07e990a..9631da5 100644
--- a/tests/test_cron_service.py
+++ b/tests/test_cron_service.py
@@ -1,3 +1,5 @@
+import asyncio
+
import pytest
from nanobot.cron.service import CronService
@@ -28,3 +30,32 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.schedule.tz == "America/Vancouver"
assert job.state.next_run_at_ms is not None
+
+
+@pytest.mark.asyncio
+async def test_running_service_honors_external_disable(tmp_path) -> None:
+ store_path = tmp_path / "cron" / "jobs.json"
+ called: list[str] = []
+
+ async def on_job(job) -> None:
+ called.append(job.id)
+
+ service = CronService(store_path, on_job=on_job)
+ job = service.add_job(
+ name="external-disable",
+ schedule=CronSchedule(kind="every", every_ms=200),
+ message="hello",
+ )
+ await service.start()
+ try:
+ # Wait slightly to ensure file mtime is definitively different
+ await asyncio.sleep(0.05)
+ external = CronService(store_path)
+ updated = external.enable_job(job.id, enabled=False)
+ assert updated is not None
+ assert updated.enabled is False
+
+ await asyncio.sleep(0.35)
+ assert called == []
+ finally:
+ service.stop()
diff --git a/tests/test_dingtalk_channel.py b/tests/test_dingtalk_channel.py
new file mode 100644
index 0000000..a0b866f
--- /dev/null
+++ b/tests/test_dingtalk_channel.py
@@ -0,0 +1,213 @@
+import asyncio
+from types import SimpleNamespace
+
+import pytest
+
+from nanobot.bus.queue import MessageBus
+import nanobot.channels.dingtalk as dingtalk_module
+from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
+from nanobot.channels.dingtalk import DingTalkConfig
+
+
+class _FakeResponse:
+ def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
+ self.status_code = status_code
+ self._json_body = json_body or {}
+ self.text = "{}"
+ self.content = b""
+ self.headers = {"content-type": "application/json"}
+
+ def json(self) -> dict:
+ return self._json_body
+
+
+class _FakeHttp:
+ def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
+ self.calls: list[dict] = []
+ self._responses = list(responses) if responses else []
+
+ def _next_response(self) -> _FakeResponse:
+ if self._responses:
+ return self._responses.pop(0)
+ return _FakeResponse()
+
+ async def post(self, url: str, json=None, headers=None, **kwargs):
+ self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
+ return self._next_response()
+
+ async def get(self, url: str, **kwargs):
+ self.calls.append({"method": "GET", "url": url})
+ return self._next_response()
+
+
+@pytest.mark.asyncio
+async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
+ config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
+ bus = MessageBus()
+ channel = DingTalkChannel(config, bus)
+
+ await channel._on_message(
+ "hello",
+ sender_id="user1",
+ sender_name="Alice",
+ conversation_type="2",
+ conversation_id="conv123",
+ )
+
+ msg = await bus.consume_inbound()
+ assert msg.sender_id == "user1"
+ assert msg.chat_id == "group:conv123"
+ assert msg.metadata["conversation_type"] == "2"
+
+
+@pytest.mark.asyncio
+async def test_group_send_uses_group_messages_api() -> None:
+ config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
+ channel = DingTalkChannel(config, MessageBus())
+ channel._http = _FakeHttp()
+
+ ok = await channel._send_batch_message(
+ "token",
+ "group:conv123",
+ "sampleMarkdown",
+ {"text": "hello", "title": "Nanobot Reply"},
+ )
+
+ assert ok is True
+ call = channel._http.calls[0]
+ assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
+ assert call["json"]["openConversationId"] == "conv123"
+ assert call["json"]["msgKey"] == "sampleMarkdown"
+
+
+@pytest.mark.asyncio
+async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
+ bus = MessageBus()
+ channel = DingTalkChannel(
+ DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
+ bus,
+ )
+ handler = NanobotDingTalkHandler(channel)
+
+ class _FakeChatbotMessage:
+ text = None
+ extensions = {"content": {"recognition": "voice transcript"}}
+ sender_staff_id = "user1"
+ sender_id = "fallback-user"
+ sender_nick = "Alice"
+ message_type = "audio"
+
+ @staticmethod
+ def from_dict(_data):
+ return _FakeChatbotMessage()
+
+ monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
+ monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
+
+ status, body = await handler.process(
+ SimpleNamespace(
+ data={
+ "conversationType": "2",
+ "conversationId": "conv123",
+ "text": {"content": ""},
+ }
+ )
+ )
+
+ await asyncio.gather(*list(channel._background_tasks))
+ msg = await bus.consume_inbound()
+
+ assert (status, body) == ("OK", "OK")
+ assert msg.content == "voice transcript"
+ assert msg.sender_id == "user1"
+ assert msg.chat_id == "group:conv123"
+
+
+@pytest.mark.asyncio
+async def test_handler_processes_file_message(monkeypatch) -> None:
+ """Test that file messages are handled and forwarded with downloaded path."""
+ bus = MessageBus()
+ channel = DingTalkChannel(
+ DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
+ bus,
+ )
+ handler = NanobotDingTalkHandler(channel)
+
+ class _FakeFileChatbotMessage:
+ text = None
+ extensions = {}
+ image_content = None
+ rich_text_content = None
+ sender_staff_id = "user1"
+ sender_id = "fallback-user"
+ sender_nick = "Alice"
+ message_type = "file"
+
+ @staticmethod
+ def from_dict(_data):
+ return _FakeFileChatbotMessage()
+
+ async def fake_download(download_code, filename, sender_id):
+ return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
+
+ monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
+ monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
+ monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
+
+ status, body = await handler.process(
+ SimpleNamespace(
+ data={
+ "conversationType": "1",
+ "content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
+ "text": {"content": ""},
+ }
+ )
+ )
+
+ await asyncio.gather(*list(channel._background_tasks))
+ msg = await bus.consume_inbound()
+
+ assert (status, body) == ("OK", "OK")
+ assert "[File]" in msg.content
+ assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
+
+
+@pytest.mark.asyncio
+async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
+ """Test the two-step file download flow (get URL then download content)."""
+ channel = DingTalkChannel(
+ DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+ MessageBus(),
+ )
+
+ # Mock access token
+ async def fake_get_token():
+ return "test-token"
+
+ monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
+
+ # Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
+ file_content = b"fake file content"
+ channel._http = _FakeHttp(responses=[
+ _FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
+ _FakeResponse(200),
+ ])
+ channel._http._responses[1].content = file_content
+
+ # Redirect media dir to tmp_path
+ monkeypatch.setattr(
+ "nanobot.config.paths.get_media_dir",
+ lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
+ )
+
+ result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
+
+ assert result is not None
+ assert result.endswith("test.xlsx")
+ assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
+
+ # Verify API calls
+ assert channel._http.calls[0]["method"] == "POST"
+ assert "messageFiles/download" in channel._http.calls[0]["url"]
+ assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
+ assert channel._http.calls[1]["method"] == "GET"
diff --git a/tests/test_email_channel.py b/tests/test_email_channel.py
index adf35a8..c037ace 100644
--- a/tests/test_email_channel.py
+++ b/tests/test_email_channel.py
@@ -6,7 +6,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.email import EmailChannel
-from nanobot.config.schema import EmailConfig
+from nanobot.channels.email import EmailConfig
def _make_config() -> EmailConfig:
diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py
new file mode 100644
index 0000000..08d068b
--- /dev/null
+++ b/tests/test_evaluator.py
@@ -0,0 +1,63 @@
+import pytest
+
+from nanobot.utils.evaluator import evaluate_response
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
+
+
+class DummyProvider(LLMProvider):
+ def __init__(self, responses: list[LLMResponse]):
+ super().__init__()
+ self._responses = list(responses)
+
+ async def chat(self, *args, **kwargs) -> LLMResponse:
+ if self._responses:
+ return self._responses.pop(0)
+ return LLMResponse(content="", tool_calls=[])
+
+ def get_default_model(self) -> str:
+ return "test-model"
+
+
+def _eval_tool_call(should_notify: bool, reason: str = "") -> LLMResponse:
+ return LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="eval_1",
+ name="evaluate_notification",
+ arguments={"should_notify": should_notify, "reason": reason},
+ )
+ ],
+ )
+
+
+@pytest.mark.asyncio
+async def test_should_notify_true() -> None:
+ provider = DummyProvider([_eval_tool_call(True, "user asked to be reminded")])
+ result = await evaluate_response("Task completed with results", "check emails", provider, "m")
+ assert result is True
+
+
+@pytest.mark.asyncio
+async def test_should_notify_false() -> None:
+ provider = DummyProvider([_eval_tool_call(False, "routine check, nothing new")])
+ result = await evaluate_response("All clear, no updates", "check status", provider, "m")
+ assert result is False
+
+
+@pytest.mark.asyncio
+async def test_fallback_on_error() -> None:
+ class FailingProvider(DummyProvider):
+ async def chat(self, *args, **kwargs) -> LLMResponse:
+ raise RuntimeError("provider down")
+
+ provider = FailingProvider([])
+ result = await evaluate_response("some response", "some task", provider, "m")
+ assert result is True
+
+
+@pytest.mark.asyncio
+async def test_no_tool_call_fallback() -> None:
+ provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
+ result = await evaluate_response("some response", "some task", provider, "m")
+ assert result is True
diff --git a/tests/test_exec_security.py b/tests/test_exec_security.py
new file mode 100644
index 0000000..e65d575
--- /dev/null
+++ b/tests/test_exec_security.py
@@ -0,0 +1,69 @@
+"""Tests for exec tool internal URL blocking."""
+
+from __future__ import annotations
+
+import socket
+from unittest.mock import patch
+
+import pytest
+
+from nanobot.agent.tools.shell import ExecTool
+
+
+def _fake_resolve_private(hostname, port, family=0, type_=0):
+ return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
+
+
+def _fake_resolve_localhost(hostname, port, family=0, type_=0):
+ return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
+
+
+def _fake_resolve_public(hostname, port, family=0, type_=0):
+ return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
+
+
+@pytest.mark.asyncio
+async def test_exec_blocks_curl_metadata():
+ tool = ExecTool()
+ with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
+ result = await tool.execute(
+ command='curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
+ )
+ assert "Error" in result
+ assert "internal" in result.lower() or "private" in result.lower()
+
+
+@pytest.mark.asyncio
+async def test_exec_blocks_wget_localhost():
+ tool = ExecTool()
+ with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
+ result = await tool.execute(command="wget http://localhost:8080/secret -O /tmp/out")
+ assert "Error" in result
+
+
+@pytest.mark.asyncio
+async def test_exec_allows_normal_commands():
+ tool = ExecTool(timeout=5)
+ result = await tool.execute(command="echo hello")
+ assert "hello" in result
+ assert "Error" not in result.split("\n")[0]
+
+
+@pytest.mark.asyncio
+async def test_exec_allows_curl_to_public_url():
+ """Commands with public URLs should not be blocked by the internal URL check."""
+ tool = ExecTool()
+ with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
+ guard_result = tool._guard_command("curl https://example.com/api", "/tmp")
+ assert guard_result is None
+
+
+@pytest.mark.asyncio
+async def test_exec_blocks_chained_internal_url():
+ """Internal URLs buried in chained commands should still be caught."""
+ tool = ExecTool()
+ with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
+ result = await tool.execute(
+ command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
+ )
+ assert "Error" in result
diff --git a/tests/test_feishu_post_content.py b/tests/test_feishu_post_content.py
new file mode 100644
index 0000000..7b1cb9d
--- /dev/null
+++ b/tests/test_feishu_post_content.py
@@ -0,0 +1,65 @@
+from nanobot.channels.feishu import FeishuChannel, _extract_post_content
+
+
+def test_extract_post_content_supports_post_wrapper_shape() -> None:
+ payload = {
+ "post": {
+ "zh_cn": {
+ "title": "日报",
+ "content": [
+ [
+ {"tag": "text", "text": "完成"},
+ {"tag": "img", "image_key": "img_1"},
+ ]
+ ],
+ }
+ }
+ }
+
+ text, image_keys = _extract_post_content(payload)
+
+ assert text == "日报 完成"
+ assert image_keys == ["img_1"]
+
+
+def test_extract_post_content_keeps_direct_shape_behavior() -> None:
+ payload = {
+ "title": "Daily",
+ "content": [
+ [
+ {"tag": "text", "text": "report"},
+ {"tag": "img", "image_key": "img_a"},
+ {"tag": "img", "image_key": "img_b"},
+ ]
+ ],
+ }
+
+ text, image_keys = _extract_post_content(payload)
+
+ assert text == "Daily report"
+ assert image_keys == ["img_a", "img_b"]
+
+
+def test_register_optional_event_keeps_builder_when_method_missing() -> None:
+ class Builder:
+ pass
+
+ builder = Builder()
+ same = FeishuChannel._register_optional_event(builder, "missing", object())
+ assert same is builder
+
+
+def test_register_optional_event_calls_supported_method() -> None:
+ called = []
+
+ class Builder:
+ def register_event(self, handler):
+ called.append(handler)
+ return self
+
+ builder = Builder()
+ handler = object()
+ same = FeishuChannel._register_optional_event(builder, "register_event", handler)
+
+ assert same is builder
+ assert called == [handler]
diff --git a/tests/test_feishu_reply.py b/tests/test_feishu_reply.py
new file mode 100644
index 0000000..65d7f86
--- /dev/null
+++ b/tests/test_feishu_reply.py
@@ -0,0 +1,392 @@
+"""Tests for Feishu message reply (quote) feature."""
+import asyncio
+import json
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.feishu import FeishuChannel, FeishuConfig
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
+ config = FeishuConfig(
+ enabled=True,
+ app_id="cli_test",
+ app_secret="secret",
+ allow_from=["*"],
+ reply_to_message=reply_to_message,
+ )
+ channel = FeishuChannel(config, MessageBus())
+ channel._client = MagicMock()
+ # _loop is only used by the WebSocket thread bridge; not needed for unit tests
+ channel._loop = None
+ return channel
+
+
+def _make_feishu_event(
+ *,
+ message_id: str = "om_001",
+ chat_id: str = "oc_abc",
+ chat_type: str = "p2p",
+ msg_type: str = "text",
+ content: str = '{"text": "hello"}',
+ sender_open_id: str = "ou_alice",
+ parent_id: str | None = None,
+ root_id: str | None = None,
+):
+ message = SimpleNamespace(
+ message_id=message_id,
+ chat_id=chat_id,
+ chat_type=chat_type,
+ message_type=msg_type,
+ content=content,
+ parent_id=parent_id,
+ root_id=root_id,
+ mentions=[],
+ )
+ sender = SimpleNamespace(
+ sender_type="user",
+ sender_id=SimpleNamespace(open_id=sender_open_id),
+ )
+ return SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
+
+
+def _make_get_message_response(text: str, msg_type: str = "text", success: bool = True):
+ """Build a fake im.v1.message.get response object."""
+ body = SimpleNamespace(content=json.dumps({"text": text}))
+ item = SimpleNamespace(msg_type=msg_type, body=body)
+ data = SimpleNamespace(items=[item])
+ resp = MagicMock()
+ resp.success.return_value = success
+ resp.data = data
+ resp.code = 0
+ resp.msg = "ok"
+ return resp
+
+
+# ---------------------------------------------------------------------------
+# Config tests
+# ---------------------------------------------------------------------------
+
+def test_feishu_config_reply_to_message_defaults_false() -> None:
+ assert FeishuConfig().reply_to_message is False
+
+
+def test_feishu_config_reply_to_message_can_be_enabled() -> None:
+ config = FeishuConfig(reply_to_message=True)
+ assert config.reply_to_message is True
+
+
+# ---------------------------------------------------------------------------
+# _get_message_content_sync tests
+# ---------------------------------------------------------------------------
+
+def test_get_message_content_sync_returns_reply_prefix() -> None:
+ channel = _make_feishu_channel()
+ channel._client.im.v1.message.get.return_value = _make_get_message_response("what time is it?")
+
+ result = channel._get_message_content_sync("om_parent")
+
+ assert result == "[Reply to: what time is it?]"
+
+
+def test_get_message_content_sync_truncates_long_text() -> None:
+ channel = _make_feishu_channel()
+ long_text = "x" * (FeishuChannel._REPLY_CONTEXT_MAX_LEN + 50)
+ channel._client.im.v1.message.get.return_value = _make_get_message_response(long_text)
+
+ result = channel._get_message_content_sync("om_parent")
+
+ assert result is not None
+ assert result.endswith("...]")
+ inner = result[len("[Reply to: ") : -1]
+ assert len(inner) == FeishuChannel._REPLY_CONTEXT_MAX_LEN + len("...")
+
+
+def test_get_message_content_sync_returns_none_on_api_failure() -> None:
+ channel = _make_feishu_channel()
+ resp = MagicMock()
+ resp.success.return_value = False
+ resp.code = 230002
+ resp.msg = "bot not in group"
+ channel._client.im.v1.message.get.return_value = resp
+
+ result = channel._get_message_content_sync("om_parent")
+
+ assert result is None
+
+
+def test_get_message_content_sync_returns_none_for_non_text_type() -> None:
+ channel = _make_feishu_channel()
+ body = SimpleNamespace(content=json.dumps({"image_key": "img_1"}))
+ item = SimpleNamespace(msg_type="image", body=body)
+ data = SimpleNamespace(items=[item])
+ resp = MagicMock()
+ resp.success.return_value = True
+ resp.data = data
+ channel._client.im.v1.message.get.return_value = resp
+
+ result = channel._get_message_content_sync("om_parent")
+
+ assert result is None
+
+
+def test_get_message_content_sync_returns_none_when_empty_text() -> None:
+ channel = _make_feishu_channel()
+ channel._client.im.v1.message.get.return_value = _make_get_message_response(" ")
+
+ result = channel._get_message_content_sync("om_parent")
+
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# _reply_message_sync tests
+# ---------------------------------------------------------------------------
+
+def test_reply_message_sync_returns_true_on_success() -> None:
+ channel = _make_feishu_channel()
+ resp = MagicMock()
+ resp.success.return_value = True
+ channel._client.im.v1.message.reply.return_value = resp
+
+ ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
+
+ assert ok is True
+ channel._client.im.v1.message.reply.assert_called_once()
+
+
+def test_reply_message_sync_returns_false_on_api_error() -> None:
+ channel = _make_feishu_channel()
+ resp = MagicMock()
+ resp.success.return_value = False
+ resp.code = 400
+ resp.msg = "bad request"
+ resp.get_log_id.return_value = "log_x"
+ channel._client.im.v1.message.reply.return_value = resp
+
+ ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
+
+ assert ok is False
+
+
+def test_reply_message_sync_returns_false_on_exception() -> None:
+ channel = _make_feishu_channel()
+ channel._client.im.v1.message.reply.side_effect = RuntimeError("network error")
+
+ ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
+
+ assert ok is False
+
+
+# ---------------------------------------------------------------------------
+# send() — reply routing tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_send_uses_reply_api_when_configured() -> None:
+ channel = _make_feishu_channel(reply_to_message=True)
+
+ reply_resp = MagicMock()
+ reply_resp.success.return_value = True
+ channel._client.im.v1.message.reply.return_value = reply_resp
+
+ await channel.send(OutboundMessage(
+ channel="feishu",
+ chat_id="oc_abc",
+ content="hello",
+ metadata={"message_id": "om_001"},
+ ))
+
+ channel._client.im.v1.message.reply.assert_called_once()
+ channel._client.im.v1.message.create.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_send_uses_create_api_when_reply_disabled() -> None:
+ channel = _make_feishu_channel(reply_to_message=False)
+
+ create_resp = MagicMock()
+ create_resp.success.return_value = True
+ channel._client.im.v1.message.create.return_value = create_resp
+
+ await channel.send(OutboundMessage(
+ channel="feishu",
+ chat_id="oc_abc",
+ content="hello",
+ metadata={"message_id": "om_001"},
+ ))
+
+ channel._client.im.v1.message.create.assert_called_once()
+ channel._client.im.v1.message.reply.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_send_uses_create_api_when_no_message_id() -> None:
+ channel = _make_feishu_channel(reply_to_message=True)
+
+ create_resp = MagicMock()
+ create_resp.success.return_value = True
+ channel._client.im.v1.message.create.return_value = create_resp
+
+ await channel.send(OutboundMessage(
+ channel="feishu",
+ chat_id="oc_abc",
+ content="hello",
+ metadata={},
+ ))
+
+ channel._client.im.v1.message.create.assert_called_once()
+ channel._client.im.v1.message.reply.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_send_skips_reply_for_progress_messages() -> None:
+ channel = _make_feishu_channel(reply_to_message=True)
+
+ create_resp = MagicMock()
+ create_resp.success.return_value = True
+ channel._client.im.v1.message.create.return_value = create_resp
+
+ await channel.send(OutboundMessage(
+ channel="feishu",
+ chat_id="oc_abc",
+ content="thinking...",
+ metadata={"message_id": "om_001", "_progress": True},
+ ))
+
+ channel._client.im.v1.message.create.assert_called_once()
+ channel._client.im.v1.message.reply.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_send_fallback_to_create_when_reply_fails() -> None:
+ channel = _make_feishu_channel(reply_to_message=True)
+
+ reply_resp = MagicMock()
+ reply_resp.success.return_value = False
+ reply_resp.code = 400
+ reply_resp.msg = "error"
+ reply_resp.get_log_id.return_value = "log_x"
+ channel._client.im.v1.message.reply.return_value = reply_resp
+
+ create_resp = MagicMock()
+ create_resp.success.return_value = True
+ channel._client.im.v1.message.create.return_value = create_resp
+
+ await channel.send(OutboundMessage(
+ channel="feishu",
+ chat_id="oc_abc",
+ content="hello",
+ metadata={"message_id": "om_001"},
+ ))
+
+ # reply attempted first, then falls back to create
+ channel._client.im.v1.message.reply.assert_called_once()
+ channel._client.im.v1.message.create.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# _on_message — parent_id / root_id metadata tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_on_message_captures_parent_and_root_id_in_metadata() -> None:
+ channel = _make_feishu_channel()
+ channel._processed_message_ids.clear()
+ channel._client.im.v1.message.react.return_value = MagicMock(success=lambda: True)
+
+ captured = []
+
+ async def _capture(**kwargs):
+ captured.append(kwargs)
+
+ channel._handle_message = _capture
+
+ with patch.object(channel, "_add_reaction", return_value=None):
+ await channel._on_message(
+ _make_feishu_event(
+ parent_id="om_parent",
+ root_id="om_root",
+ )
+ )
+
+ assert len(captured) == 1
+ meta = captured[0]["metadata"]
+ assert meta["parent_id"] == "om_parent"
+ assert meta["root_id"] == "om_root"
+ assert meta["message_id"] == "om_001"
+
+
+@pytest.mark.asyncio
+async def test_on_message_parent_and_root_id_none_when_absent() -> None:
+ channel = _make_feishu_channel()
+ channel._processed_message_ids.clear()
+
+ captured = []
+
+ async def _capture(**kwargs):
+ captured.append(kwargs)
+
+ channel._handle_message = _capture
+
+ with patch.object(channel, "_add_reaction", return_value=None):
+ await channel._on_message(_make_feishu_event())
+
+ assert len(captured) == 1
+ meta = captured[0]["metadata"]
+ assert meta["parent_id"] is None
+ assert meta["root_id"] is None
+
+
+@pytest.mark.asyncio
+async def test_on_message_prepends_reply_context_when_parent_id_present() -> None:
+ channel = _make_feishu_channel()
+ channel._processed_message_ids.clear()
+ channel._client.im.v1.message.get.return_value = _make_get_message_response("original question")
+
+ captured = []
+
+ async def _capture(**kwargs):
+ captured.append(kwargs)
+
+ channel._handle_message = _capture
+
+ with patch.object(channel, "_add_reaction", return_value=None):
+ await channel._on_message(
+ _make_feishu_event(
+ content='{"text": "my answer"}',
+ parent_id="om_parent",
+ )
+ )
+
+ assert len(captured) == 1
+ content = captured[0]["content"]
+ assert content.startswith("[Reply to: original question]")
+ assert "my answer" in content
+
+
+@pytest.mark.asyncio
+async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
+ channel = _make_feishu_channel()
+ channel._processed_message_ids.clear()
+
+ captured = []
+
+ async def _capture(**kwargs):
+ captured.append(kwargs)
+
+ channel._handle_message = _capture
+
+ with patch.object(channel, "_add_reaction", return_value=None):
+ await channel._on_message(_make_feishu_event())
+
+ channel._client.im.v1.message.get.assert_not_called()
+ assert len(captured) == 1
diff --git a/tests/test_feishu_table_split.py b/tests/test_feishu_table_split.py
new file mode 100644
index 0000000..af8fa16
--- /dev/null
+++ b/tests/test_feishu_table_split.py
@@ -0,0 +1,104 @@
+"""Tests for FeishuChannel._split_elements_by_table_limit.
+
+Feishu cards reject messages that contain more than one table element
+(API error 11310: card table number over limit). The helper splits a flat
+list of card elements into groups so that each group contains at most one
+table, allowing nanobot to send multiple cards instead of failing.
+"""
+
+from nanobot.channels.feishu import FeishuChannel
+
+
+def _md(text: str) -> dict:
+ return {"tag": "markdown", "content": text}
+
+
+def _table() -> dict:
+ return {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+ "rows": [{"c0": "v"}],
+ "page_size": 2,
+ }
+
+
+split = FeishuChannel._split_elements_by_table_limit
+
+
+def test_empty_list_returns_single_empty_group() -> None:
+ assert split([]) == [[]]
+
+
+def test_no_tables_returns_single_group() -> None:
+ els = [_md("hello"), _md("world")]
+ result = split(els)
+ assert result == [els]
+
+
+def test_single_table_stays_in_one_group() -> None:
+ els = [_md("intro"), _table(), _md("outro")]
+ result = split(els)
+ assert len(result) == 1
+ assert result[0] == els
+
+
+def test_two_tables_split_into_two_groups() -> None:
+ # Use different row values so the two tables are not equal
+ t1 = {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+ "rows": [{"c0": "table-one"}],
+ "page_size": 2,
+ }
+ t2 = {
+ "tag": "table",
+ "columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
+ "rows": [{"c0": "table-two"}],
+ "page_size": 2,
+ }
+ els = [_md("before"), t1, _md("between"), t2, _md("after")]
+ result = split(els)
+ assert len(result) == 2
+ # First group: text before table-1 + table-1
+ assert t1 in result[0]
+ assert t2 not in result[0]
+ # Second group: text between tables + table-2 + text after
+ assert t2 in result[1]
+ assert t1 not in result[1]
+
+
+def test_three_tables_split_into_three_groups() -> None:
+ tables = [
+ {"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
+ for i in range(3)
+ ]
+ els = tables[:]
+ result = split(els)
+ assert len(result) == 3
+ for i, group in enumerate(result):
+ assert tables[i] in group
+
+
+def test_leading_markdown_stays_with_first_table() -> None:
+ intro = _md("intro")
+ t = _table()
+ result = split([intro, t])
+ assert len(result) == 1
+ assert result[0] == [intro, t]
+
+
+def test_trailing_markdown_after_second_table() -> None:
+ t1, t2 = _table(), _table()
+ tail = _md("end")
+ result = split([t1, t2, tail])
+ assert len(result) == 2
+ assert result[1] == [t2, tail]
+
+
+def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
+ head = _md("head")
+ t1, t2 = _table(), _table()
+ result = split([head, t1, t2])
+ # head + t1 in group 0; t2 in group 1
+ assert result[0] == [head, t1]
+ assert result[1] == [t2]
diff --git a/tests/test_feishu_tool_hint_code_block.py b/tests/test_feishu_tool_hint_code_block.py
new file mode 100644
index 0000000..2a1b812
--- /dev/null
+++ b/tests/test_feishu_tool_hint_code_block.py
@@ -0,0 +1,138 @@
+"""Tests for FeishuChannel tool hint code block formatting."""
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pytest import mark
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.channels.feishu import FeishuChannel
+
+
+@pytest.fixture
+def mock_feishu_channel():
+ """Create a FeishuChannel with mocked client."""
+ config = MagicMock()
+ config.app_id = "test_app_id"
+ config.app_secret = "test_app_secret"
+ config.encrypt_key = None
+ config.verification_token = None
+ bus = MagicMock()
+ channel = FeishuChannel(config, bus)
+ channel._client = MagicMock() # Simulate initialized client
+ return channel
+
+
+@mark.asyncio
+async def test_tool_hint_sends_code_message(mock_feishu_channel):
+ """Tool hint messages should be sent as interactive cards with code blocks."""
+ msg = OutboundMessage(
+ channel="feishu",
+ chat_id="oc_123456",
+ content='web_search("test query")',
+ metadata={"_tool_hint": True}
+ )
+
+ with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
+ await mock_feishu_channel.send(msg)
+
+ # Verify interactive message with card was sent
+ assert mock_send.call_count == 1
+ call_args = mock_send.call_args[0]
+ receive_id_type, receive_id, msg_type, content = call_args
+
+ assert receive_id_type == "chat_id"
+ assert receive_id == "oc_123456"
+ assert msg_type == "interactive"
+
+ # Parse content to verify card structure
+ card = json.loads(content)
+ assert card["config"]["wide_screen_mode"] is True
+ assert len(card["elements"]) == 1
+ assert card["elements"][0]["tag"] == "markdown"
+ # Check that code block is properly formatted with language hint
+ expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
+ assert card["elements"][0]["content"] == expected_md
+
+
+@mark.asyncio
+async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
+ """Empty tool hint messages should not be sent."""
+ msg = OutboundMessage(
+ channel="feishu",
+ chat_id="oc_123456",
+ content=" ", # whitespace only
+ metadata={"_tool_hint": True}
+ )
+
+ with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
+ await mock_feishu_channel.send(msg)
+
+ # Should not send any message
+ mock_send.assert_not_called()
+
+
+@mark.asyncio
+async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
+ """Regular messages without _tool_hint should use normal formatting."""
+ msg = OutboundMessage(
+ channel="feishu",
+ chat_id="oc_123456",
+ content="Hello, world!",
+ metadata={}
+ )
+
+ with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
+ await mock_feishu_channel.send(msg)
+
+ # Should send as text message (detected format)
+ assert mock_send.call_count == 1
+ call_args = mock_send.call_args[0]
+ _, _, msg_type, content = call_args
+ assert msg_type == "text"
+ assert json.loads(content) == {"text": "Hello, world!"}
+
+
+@mark.asyncio
+async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
+ """Multiple tool calls should be displayed each on its own line in a code block."""
+ msg = OutboundMessage(
+ channel="feishu",
+ chat_id="oc_123456",
+ content='web_search("query"), read_file("/path/to/file")',
+ metadata={"_tool_hint": True}
+ )
+
+ with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
+ await mock_feishu_channel.send(msg)
+
+ call_args = mock_send.call_args[0]
+ msg_type = call_args[2]
+ content = json.loads(call_args[3])
+ assert msg_type == "interactive"
+ # Each tool call should be on its own line
+ expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
+ assert content["elements"][0]["content"] == expected_md
+
+
+@mark.asyncio
+async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
+ """Commas inside a single tool argument must not be split onto a new line."""
+ msg = OutboundMessage(
+ channel="feishu",
+ chat_id="oc_123456",
+ content='web_search("foo, bar"), read_file("/path/to/file")',
+ metadata={"_tool_hint": True}
+ )
+
+ with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
+ await mock_feishu_channel.send(msg)
+
+ content = json.loads(mock_send.call_args[0][3])
+ expected_md = (
+ "**Tool Calls**\n\n```text\n"
+ "web_search(\"foo, bar\"),\n"
+ "read_file(\"/path/to/file\")\n```"
+ )
+ assert content["elements"][0]["content"] == expected_md
diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py
new file mode 100644
index 0000000..620aa75
--- /dev/null
+++ b/tests/test_filesystem_tools.py
@@ -0,0 +1,364 @@
+"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
+
+import pytest
+
+from nanobot.agent.tools.filesystem import (
+ EditFileTool,
+ ListDirTool,
+ ReadFileTool,
+ _find_match,
+)
+
+
+# ---------------------------------------------------------------------------
+# ReadFileTool
+# ---------------------------------------------------------------------------
+
+class TestReadFileTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return ReadFileTool(workspace=tmp_path)
+
+ @pytest.fixture()
+ def sample_file(self, tmp_path):
+ f = tmp_path / "sample.txt"
+ f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
+ return f
+
+ @pytest.mark.asyncio
+ async def test_basic_read_has_line_numbers(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file))
+ assert "1| line 1" in result
+ assert "20| line 20" in result
+
+ @pytest.mark.asyncio
+ async def test_offset_and_limit(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=5, limit=3)
+ assert "5| line 5" in result
+ assert "7| line 7" in result
+ assert "8| line 8" not in result
+ assert "Use offset=8 to continue" in result
+
+ @pytest.mark.asyncio
+ async def test_offset_beyond_end(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=999)
+ assert "Error" in result
+ assert "beyond end" in result
+
+ @pytest.mark.asyncio
+ async def test_end_of_file_marker(self, tool, sample_file):
+ result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
+ assert "End of file" in result
+
+ @pytest.mark.asyncio
+ async def test_empty_file(self, tool, tmp_path):
+ f = tmp_path / "empty.txt"
+ f.write_text("", encoding="utf-8")
+ result = await tool.execute(path=str(f))
+ assert "Empty file" in result
+
+ @pytest.mark.asyncio
+ async def test_file_not_found(self, tool, tmp_path):
+ result = await tool.execute(path=str(tmp_path / "nope.txt"))
+ assert "Error" in result
+ assert "not found" in result
+
+ @pytest.mark.asyncio
+ async def test_char_budget_trims(self, tool, tmp_path):
+ """When the selected slice exceeds _MAX_CHARS the output is trimmed."""
+ f = tmp_path / "big.txt"
+ # Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
+ f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
+ result = await tool.execute(path=str(f))
+ assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
+ assert "Use offset=" in result
+
+
+# ---------------------------------------------------------------------------
+# _find_match (unit tests for the helper)
+# ---------------------------------------------------------------------------
+
+class TestFindMatch:
+
+ def test_exact_match(self):
+ match, count = _find_match("hello world", "world")
+ assert match == "world"
+ assert count == 1
+
+ def test_exact_no_match(self):
+ match, count = _find_match("hello world", "xyz")
+ assert match is None
+ assert count == 0
+
+ def test_crlf_normalisation(self):
+ # Caller normalises CRLF before calling _find_match, so test with
+ # pre-normalised content to verify exact match still works.
+ content = "line1\nline2\nline3"
+ old_text = "line1\nline2\nline3"
+ match, count = _find_match(content, old_text)
+ assert match is not None
+ assert count == 1
+
+ def test_line_trim_fallback(self):
+ content = " def foo():\n pass\n"
+ old_text = "def foo():\n pass"
+ match, count = _find_match(content, old_text)
+ assert match is not None
+ assert count == 1
+ # The returned match should be the *original* indented text
+ assert " def foo():" in match
+
+ def test_line_trim_multiple_candidates(self):
+ content = " a\n b\n a\n b\n"
+ old_text = "a\nb"
+ match, count = _find_match(content, old_text)
+ assert count == 2
+
+ def test_empty_old_text(self):
+ match, count = _find_match("hello", "")
+ # Empty string is always "in" any string via exact match
+ assert match == ""
+
+
+# ---------------------------------------------------------------------------
+# EditFileTool
+# ---------------------------------------------------------------------------
+
+class TestEditFileTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return EditFileTool(workspace=tmp_path)
+
+ @pytest.mark.asyncio
+ async def test_exact_match(self, tool, tmp_path):
+ f = tmp_path / "a.py"
+ f.write_text("hello world", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="world", new_text="earth")
+ assert "Successfully" in result
+ assert f.read_text() == "hello earth"
+
+ @pytest.mark.asyncio
+ async def test_crlf_normalisation(self, tool, tmp_path):
+ f = tmp_path / "crlf.py"
+ f.write_bytes(b"line1\r\nline2\r\nline3")
+ result = await tool.execute(
+ path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
+ )
+ assert "Successfully" in result
+ raw = f.read_bytes()
+ assert b"LINE1" in raw
+ # CRLF line endings should be preserved throughout the file
+ assert b"\r\n" in raw
+
+ @pytest.mark.asyncio
+ async def test_trim_fallback(self, tool, tmp_path):
+ f = tmp_path / "indent.py"
+ f.write_text(" def foo():\n pass\n", encoding="utf-8")
+ result = await tool.execute(
+ path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
+ )
+ assert "Successfully" in result
+ assert "bar" in f.read_text()
+
+ @pytest.mark.asyncio
+ async def test_ambiguous_match(self, tool, tmp_path):
+ f = tmp_path / "dup.py"
+ f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
+ assert "appears" in result.lower() or "Warning" in result
+
+ @pytest.mark.asyncio
+ async def test_replace_all(self, tool, tmp_path):
+ f = tmp_path / "multi.py"
+ f.write_text("foo bar foo bar foo", encoding="utf-8")
+ result = await tool.execute(
+ path=str(f), old_text="foo", new_text="baz", replace_all=True,
+ )
+ assert "Successfully" in result
+ assert f.read_text() == "baz bar baz bar baz"
+
+ @pytest.mark.asyncio
+ async def test_not_found(self, tool, tmp_path):
+ f = tmp_path / "nf.py"
+ f.write_text("hello", encoding="utf-8")
+ result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
+ assert "Error" in result
+ assert "not found" in result
+
+
+# ---------------------------------------------------------------------------
+# ListDirTool
+# ---------------------------------------------------------------------------
+
+class TestListDirTool:
+
+ @pytest.fixture()
+ def tool(self, tmp_path):
+ return ListDirTool(workspace=tmp_path)
+
+ @pytest.fixture()
+ def populated_dir(self, tmp_path):
+ (tmp_path / "src").mkdir()
+ (tmp_path / "src" / "main.py").write_text("pass")
+ (tmp_path / "src" / "utils.py").write_text("pass")
+ (tmp_path / "README.md").write_text("hi")
+ (tmp_path / ".git").mkdir()
+ (tmp_path / ".git" / "config").write_text("x")
+ (tmp_path / "node_modules").mkdir()
+ (tmp_path / "node_modules" / "pkg").mkdir()
+ return tmp_path
+
+ @pytest.mark.asyncio
+ async def test_basic_list(self, tool, populated_dir):
+ result = await tool.execute(path=str(populated_dir))
+ assert "README.md" in result
+ assert "src" in result
+ # .git and node_modules should be ignored
+ assert ".git" not in result
+ assert "node_modules" not in result
+
+ @pytest.mark.asyncio
+ async def test_recursive(self, tool, populated_dir):
+ result = await tool.execute(path=str(populated_dir), recursive=True)
+ # Normalize path separators for cross-platform compatibility
+ normalized = result.replace("\\", "/")
+ assert "src/main.py" in normalized
+ assert "src/utils.py" in normalized
+ assert "README.md" in result
+ # Ignored dirs should not appear
+ assert ".git" not in result
+ assert "node_modules" not in result
+
+ @pytest.mark.asyncio
+ async def test_max_entries_truncation(self, tool, tmp_path):
+ for i in range(10):
+ (tmp_path / f"file_{i}.txt").write_text("x")
+ result = await tool.execute(path=str(tmp_path), max_entries=3)
+ assert "truncated" in result
+ assert "3 of 10" in result
+
+ @pytest.mark.asyncio
+ async def test_empty_dir(self, tool, tmp_path):
+ d = tmp_path / "empty"
+ d.mkdir()
+ result = await tool.execute(path=str(d))
+ assert "empty" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_not_found(self, tool, tmp_path):
+ result = await tool.execute(path=str(tmp_path / "nope"))
+ assert "Error" in result
+ assert "not found" in result
+
+
+# ---------------------------------------------------------------------------
+# Workspace restriction + extra_allowed_dirs
+# ---------------------------------------------------------------------------
+
+class TestWorkspaceRestriction:
+
+ @pytest.mark.asyncio
+ async def test_read_blocked_outside_workspace(self, tmp_path):
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ outside = tmp_path / "outside"
+ outside.mkdir()
+ secret = outside / "secret.txt"
+ secret.write_text("top secret")
+
+ tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
+ result = await tool.execute(path=str(secret))
+ assert "Error" in result
+ assert "outside" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_read_allowed_with_extra_dir(self, tmp_path):
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ skills_dir = tmp_path / "skills"
+ skills_dir.mkdir()
+ skill_file = skills_dir / "test_skill" / "SKILL.md"
+ skill_file.parent.mkdir()
+ skill_file.write_text("# Test Skill\nDo something.")
+
+ tool = ReadFileTool(
+ workspace=workspace, allowed_dir=workspace,
+ extra_allowed_dirs=[skills_dir],
+ )
+ result = await tool.execute(path=str(skill_file))
+ assert "Test Skill" in result
+ assert "Error" not in result
+
+ @pytest.mark.asyncio
+ async def test_extra_dirs_does_not_widen_write(self, tmp_path):
+ from nanobot.agent.tools.filesystem import WriteFileTool
+
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ outside = tmp_path / "outside"
+ outside.mkdir()
+
+ tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
+ result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
+ assert "Error" in result
+ assert "outside" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ skills_dir = tmp_path / "skills"
+ skills_dir.mkdir()
+ unrelated = tmp_path / "other"
+ unrelated.mkdir()
+ secret = unrelated / "secret.txt"
+ secret.write_text("nope")
+
+ tool = ReadFileTool(
+ workspace=workspace, allowed_dir=workspace,
+ extra_allowed_dirs=[skills_dir],
+ )
+ result = await tool.execute(path=str(secret))
+ assert "Error" in result
+ assert "outside" in result.lower()
+
+ @pytest.mark.asyncio
+ async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
+ """Adding extra_allowed_dirs must not break normal workspace reads."""
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ ws_file = workspace / "README.md"
+ ws_file.write_text("hello from workspace")
+ skills_dir = tmp_path / "skills"
+ skills_dir.mkdir()
+
+ tool = ReadFileTool(
+ workspace=workspace, allowed_dir=workspace,
+ extra_allowed_dirs=[skills_dir],
+ )
+ result = await tool.execute(path=str(ws_file))
+ assert "hello from workspace" in result
+ assert "Error" not in result
+
+ @pytest.mark.asyncio
+ async def test_edit_blocked_in_extra_dir(self, tmp_path):
+ """edit_file must not be able to modify files in extra_allowed_dirs."""
+ workspace = tmp_path / "ws"
+ workspace.mkdir()
+ skills_dir = tmp_path / "skills"
+ skills_dir.mkdir()
+ skill_file = skills_dir / "weather" / "SKILL.md"
+ skill_file.parent.mkdir()
+ skill_file.write_text("# Weather\nOriginal content.")
+
+ tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
+ result = await tool.execute(
+ path=str(skill_file),
+ old_text="Original content.",
+ new_text="Hacked content.",
+ )
+ assert "Error" in result
+ assert "outside" in result.lower()
+ assert skill_file.read_text() == "# Weather\nOriginal content."
diff --git a/tests/test_gemini_thought_signature.py b/tests/test_gemini_thought_signature.py
new file mode 100644
index 0000000..bc4132c
--- /dev/null
+++ b/tests/test_gemini_thought_signature.py
@@ -0,0 +1,53 @@
+from types import SimpleNamespace
+
+from nanobot.providers.base import ToolCallRequest
+from nanobot.providers.litellm_provider import LiteLLMProvider
+
+
+def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
+ provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
+
+ response = SimpleNamespace(
+ choices=[
+ SimpleNamespace(
+ finish_reason="tool_calls",
+ message=SimpleNamespace(
+ content=None,
+ tool_calls=[
+ SimpleNamespace(
+ id="call_123",
+ function=SimpleNamespace(
+ name="read_file",
+ arguments='{"path":"todo.md"}',
+ provider_specific_fields={"inner": "value"},
+ ),
+ provider_specific_fields={"thought_signature": "signed-token"},
+ )
+ ],
+ ),
+ )
+ ],
+ usage=None,
+ )
+
+ parsed = provider._parse_response(response)
+
+ assert len(parsed.tool_calls) == 1
+ assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
+ assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
+
+
+def test_tool_call_request_serializes_provider_fields() -> None:
+ tool_call = ToolCallRequest(
+ id="abc123xyz",
+ name="read_file",
+ arguments={"path": "todo.md"},
+ provider_specific_fields={"thought_signature": "signed-token"},
+ function_provider_specific_fields={"inner": "value"},
+ )
+
+ message = tool_call.to_openai_tool_call()
+
+ assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
+ assert message["function"]["provider_specific_fields"] == {"inner": "value"}
+ assert message["function"]["arguments"] == '{"path": "todo.md"}'
diff --git a/tests/test_heartbeat_service.py b/tests/test_heartbeat_service.py
index ec91c6b..8f563cf 100644
--- a/tests/test_heartbeat_service.py
+++ b/tests/test_heartbeat_service.py
@@ -2,34 +2,34 @@ import asyncio
import pytest
-from nanobot.heartbeat.service import (
- HEARTBEAT_OK_TOKEN,
- HeartbeatService,
-)
+from nanobot.heartbeat.service import HeartbeatService
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
-def test_heartbeat_ok_detection() -> None:
- def is_ok(response: str) -> bool:
- return HEARTBEAT_OK_TOKEN in response.upper()
+class DummyProvider(LLMProvider):
+ def __init__(self, responses: list[LLMResponse]):
+ super().__init__()
+ self._responses = list(responses)
+ self.calls = 0
- assert is_ok("HEARTBEAT_OK")
- assert is_ok("`HEARTBEAT_OK`")
- assert is_ok("**HEARTBEAT_OK**")
- assert is_ok("heartbeat_ok")
- assert is_ok("HEARTBEAT_OK.")
+ async def chat(self, *args, **kwargs) -> LLMResponse:
+ self.calls += 1
+ if self._responses:
+ return self._responses.pop(0)
+ return LLMResponse(content="", tool_calls=[])
- assert not is_ok("HEARTBEAT_NOT_OK")
- assert not is_ok("all good")
+ def get_default_model(self) -> str:
+ return "test-model"
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
- async def _on_heartbeat(_: str) -> str:
- return "HEARTBEAT_OK"
+ provider = DummyProvider([])
service = HeartbeatService(
workspace=tmp_path,
- on_heartbeat=_on_heartbeat,
+ provider=provider,
+ model="openai/gpt-4o-mini",
interval_s=9999,
enabled=True,
)
@@ -42,3 +42,248 @@ async def test_start_is_idempotent(tmp_path) -> None:
service.stop()
await asyncio.sleep(0)
+
+
+@pytest.mark.asyncio
+async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
+ provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ )
+
+ action, tasks = await service._decide("heartbeat content")
+ assert action == "skip"
+ assert tasks == ""
+
+
+@pytest.mark.asyncio
+async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "run", "tasks": "check open tasks"},
+ )
+ ],
+ )
+ ])
+
+ called_with: list[str] = []
+
+ async def _on_execute(tasks: str) -> str:
+ called_with.append(tasks)
+ return "done"
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ )
+
+ result = await service.trigger_now()
+ assert result == "done"
+ assert called_with == ["check open tasks"]
+
+
+@pytest.mark.asyncio
+async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "skip"},
+ )
+ ],
+ )
+ ])
+
+ async def _on_execute(tasks: str) -> str:
+ return tasks
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ )
+
+ assert await service.trigger_now() is None
+
+
+@pytest.mark.asyncio
+async def test_tick_notifies_when_evaluator_says_yes(tmp_path, monkeypatch) -> None:
+ """Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=notify -> on_notify called."""
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] check deployments", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "run", "tasks": "check deployments"},
+ )
+ ],
+ ),
+ ])
+
+ executed: list[str] = []
+ notified: list[str] = []
+
+ async def _on_execute(tasks: str) -> str:
+ executed.append(tasks)
+ return "deployment failed on staging"
+
+ async def _on_notify(response: str) -> None:
+ notified.append(response)
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ on_notify=_on_notify,
+ )
+
+ async def _eval_notify(*a, **kw):
+ return True
+
+ monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_notify)
+
+ await service._tick()
+ assert executed == ["check deployments"]
+ assert notified == ["deployment failed on staging"]
+
+
+@pytest.mark.asyncio
+async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> None:
+ """Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=silent -> on_notify NOT called."""
+ (tmp_path / "HEARTBEAT.md").write_text("- [ ] check status", encoding="utf-8")
+
+ provider = DummyProvider([
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "run", "tasks": "check status"},
+ )
+ ],
+ ),
+ ])
+
+ executed: list[str] = []
+ notified: list[str] = []
+
+ async def _on_execute(tasks: str) -> str:
+ executed.append(tasks)
+ return "everything is fine, no issues"
+
+ async def _on_notify(response: str) -> None:
+ notified.append(response)
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ on_execute=_on_execute,
+ on_notify=_on_notify,
+ )
+
+ async def _eval_silent(*a, **kw):
+ return False
+
+ monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_silent)
+
+ await service._tick()
+ assert executed == ["check status"]
+ assert notified == []
+
+
+@pytest.mark.asyncio
+async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
+ provider = DummyProvider([
+ LLMResponse(content="429 rate limit", finish_reason="error"),
+ LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1",
+ name="heartbeat",
+ arguments={"action": "run", "tasks": "check open tasks"},
+ )
+ ],
+ ),
+ ])
+
+ delays: list[int] = []
+
+ async def _fake_sleep(delay: int) -> None:
+ delays.append(delay)
+
+ monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=provider,
+ model="openai/gpt-4o-mini",
+ )
+
+ action, tasks = await service._decide("heartbeat content")
+
+ assert action == "run"
+ assert tasks == "check open tasks"
+ assert provider.calls == 2
+ assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_decide_prompt_includes_current_time(tmp_path) -> None:
+ """Phase 1 user prompt must contain current time so the LLM can judge task urgency."""
+
+ captured_messages: list[dict] = []
+
+ class CapturingProvider(LLMProvider):
+ async def chat(self, *, messages=None, **kwargs) -> LLMResponse:
+ if messages:
+ captured_messages.extend(messages)
+ return LLMResponse(
+ content="",
+ tool_calls=[
+ ToolCallRequest(
+ id="hb_1", name="heartbeat",
+ arguments={"action": "skip"},
+ )
+ ],
+ )
+
+ def get_default_model(self) -> str:
+ return "test-model"
+
+ service = HeartbeatService(
+ workspace=tmp_path,
+ provider=CapturingProvider(),
+ model="test-model",
+ )
+
+ await service._decide("- [ ] check servers at 10:00 UTC")
+
+ user_msg = captured_messages[1]
+ assert user_msg["role"] == "user"
+ assert "Current Time:" in user_msg["content"]
+
diff --git a/tests/test_litellm_kwargs.py b/tests/test_litellm_kwargs.py
new file mode 100644
index 0000000..437f8a5
--- /dev/null
+++ b/tests/test_litellm_kwargs.py
@@ -0,0 +1,161 @@
+"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec.
+
+Validates that:
+- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing.
+- The litellm_kwargs mechanism works correctly for providers that declare it.
+- Non-gateway providers are unaffected.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from typing import Any
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from nanobot.providers.litellm_provider import LiteLLMProvider
+from nanobot.providers.registry import find_by_name
+
+
+def _fake_response(content: str = "ok") -> SimpleNamespace:
+ """Build a minimal acompletion-shaped response object."""
+ message = SimpleNamespace(
+ content=content,
+ tool_calls=None,
+ reasoning_content=None,
+ thinking_blocks=None,
+ )
+ choice = SimpleNamespace(message=message, finish_reason="stop")
+ usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+ return SimpleNamespace(choices=[choice], usage=usage)
+
+
+def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None:
+ """OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg.
+
+ LiteLLM internally adds a provider/ prefix when custom_llm_provider is set,
+ which double-prefixes models (openrouter/anthropic/model) and breaks the API.
+ """
+ spec = find_by_name("openrouter")
+ assert spec is not None
+ assert spec.litellm_prefix == "openrouter"
+ assert "custom_llm_provider" not in spec.litellm_kwargs, (
+ "custom_llm_provider causes LiteLLM to double-prefix the model name"
+ )
+
+
+@pytest.mark.asyncio
+async def test_openrouter_prefixes_model_correctly() -> None:
+ """OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing."""
+ mock_acompletion = AsyncMock(return_value=_fake_response())
+
+ with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
+ provider = LiteLLMProvider(
+ api_key="sk-or-test-key",
+ api_base="https://openrouter.ai/api/v1",
+ default_model="anthropic/claude-sonnet-4-5",
+ provider_name="openrouter",
+ )
+ await provider.chat(
+ messages=[{"role": "user", "content": "hello"}],
+ model="anthropic/claude-sonnet-4-5",
+ )
+
+ call_kwargs = mock_acompletion.call_args.kwargs
+ assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
+ "LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call"
+ )
+ assert "custom_llm_provider" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_non_gateway_provider_no_extra_kwargs() -> None:
+ """Standard (non-gateway) providers must NOT inject any litellm_kwargs."""
+ mock_acompletion = AsyncMock(return_value=_fake_response())
+
+ with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
+ provider = LiteLLMProvider(
+ api_key="sk-ant-test-key",
+ default_model="claude-sonnet-4-5",
+ )
+ await provider.chat(
+ messages=[{"role": "user", "content": "hello"}],
+ model="claude-sonnet-4-5",
+ )
+
+ call_kwargs = mock_acompletion.call_args.kwargs
+ assert "custom_llm_provider" not in call_kwargs, (
+ "Standard Anthropic provider should NOT inject custom_llm_provider"
+ )
+
+
+@pytest.mark.asyncio
+async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None:
+ """Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys."""
+ mock_acompletion = AsyncMock(return_value=_fake_response())
+
+ with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
+ provider = LiteLLMProvider(
+ api_key="sk-aihub-test-key",
+ api_base="https://aihubmix.com/v1",
+ default_model="claude-sonnet-4-5",
+ provider_name="aihubmix",
+ )
+ await provider.chat(
+ messages=[{"role": "user", "content": "hello"}],
+ model="claude-sonnet-4-5",
+ )
+
+ call_kwargs = mock_acompletion.call_args.kwargs
+ assert "custom_llm_provider" not in call_kwargs
+
+
+@pytest.mark.asyncio
+async def test_openrouter_autodetect_by_key_prefix() -> None:
+ """OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name."""
+ mock_acompletion = AsyncMock(return_value=_fake_response())
+
+ with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
+ provider = LiteLLMProvider(
+ api_key="sk-or-auto-detect-key",
+ default_model="anthropic/claude-sonnet-4-5",
+ )
+ await provider.chat(
+ messages=[{"role": "user", "content": "hello"}],
+ model="anthropic/claude-sonnet-4-5",
+ )
+
+ call_kwargs = mock_acompletion.call_args.kwargs
+ assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
+ "Auto-detected OpenRouter should prefix model for LiteLLM routing"
+ )
+
+
+@pytest.mark.asyncio
+async def test_openrouter_native_model_id_gets_double_prefixed() -> None:
+ """Models like openrouter/free must be double-prefixed so LiteLLM strips one layer.
+
+ openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first
+ openrouter/ for routing, so we must send openrouter/openrouter/free to ensure
+ the API receives openrouter/free.
+ """
+ mock_acompletion = AsyncMock(return_value=_fake_response())
+
+ with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
+ provider = LiteLLMProvider(
+ api_key="sk-or-test-key",
+ api_base="https://openrouter.ai/api/v1",
+ default_model="openrouter/free",
+ provider_name="openrouter",
+ )
+ await provider.chat(
+ messages=[{"role": "user", "content": "hello"}],
+ model="openrouter/free",
+ )
+
+ call_kwargs = mock_acompletion.call_args.kwargs
+ assert call_kwargs["model"] == "openrouter/openrouter/free", (
+ "openrouter/free must become openrouter/openrouter/free — "
+ "LiteLLM strips one layer so the API receives openrouter/free"
+ )
diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py
new file mode 100644
index 0000000..b0f3dda
--- /dev/null
+++ b/tests/test_loop_consolidation_tokens.py
@@ -0,0 +1,190 @@
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from nanobot.agent.loop import AgentLoop
+import nanobot.agent.memory as memory_module
+from nanobot.bus.queue import MessageBus
+from nanobot.providers.base import LLMResponse
+
+
+def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
+ provider = MagicMock()
+ provider.get_default_model.return_value = "test-model"
+ provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
+ provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
+
+ loop = AgentLoop(
+ bus=MessageBus(),
+ provider=provider,
+ workspace=tmp_path,
+ model="test-model",
+ context_window_tokens=context_window_tokens,
+ )
+ loop.tools.get_definitions = MagicMock(return_value=[])
+ return loop
+
+
+@pytest.mark.asyncio
+async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
+ loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
+ loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
+
+ await loop.process_direct("hello", session_key="cli:test")
+
+ loop.memory_consolidator.consolidate_messages.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
+ loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
+ loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
+ session = loop.sessions.get_or_create("cli:test")
+ session.messages = [
+ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
+ {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
+ {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
+ ]
+ loop.sessions.save(session)
+ monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
+
+ await loop.process_direct("hello", session_key="cli:test")
+
+ assert loop.memory_consolidator.consolidate_messages.await_count >= 1
+
+
+@pytest.mark.asyncio
+async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
+ loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
+ loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
+
+ session = loop.sessions.get_or_create("cli:test")
+ session.messages = [
+ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
+ {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
+ {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
+ {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
+ {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
+ ]
+ loop.sessions.save(session)
+
+ token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
+ monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
+
+ await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
+
+ archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
+ assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
+ assert session.last_consolidated == 4
+
+
+@pytest.mark.asyncio
+async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
+ """Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
+ loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
+ loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
+
+ session = loop.sessions.get_or_create("cli:test")
+ session.messages = [
+ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
+ {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
+ {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
+ {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
+ {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
+ {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
+ {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
+ ]
+ loop.sessions.save(session)
+
+ call_count = [0]
+ def mock_estimate(_session):
+ call_count[0] += 1
+ if call_count[0] == 1:
+ return (500, "test")
+ if call_count[0] == 2:
+ return (300, "test")
+ return (80, "test")
+
+ loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
+ monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
+
+ await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
+
+ assert loop.memory_consolidator.consolidate_messages.await_count == 2
+ assert session.last_consolidated == 6
+
+
+@pytest.mark.asyncio
+async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
+ """Once triggered, consolidation should continue until it drops below half threshold."""
+ loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
+ loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
+
+ session = loop.sessions.get_or_create("cli:test")
+ session.messages = [
+ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
+ {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
+ {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
+ {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
+ {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
+ {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
+ {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
+ ]
+ loop.sessions.save(session)
+
+ call_count = [0]
+
+ def mock_estimate(_session):
+ call_count[0] += 1
+ if call_count[0] == 1:
+ return (500, "test")
+ if call_count[0] == 2:
+ return (150, "test")
+ return (80, "test")
+
+ loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
+ monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
+
+ await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
+
+ assert loop.memory_consolidator.consolidate_messages.await_count == 2
+ assert session.last_consolidated == 6
+
+
+@pytest.mark.asyncio
+async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
+ """Verify preflight consolidation runs before the LLM call in process_direct."""
+ order: list[str] = []
+
+ loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
+
+ async def track_consolidate(messages):
+ order.append("consolidate")
+ return True
+ loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
+
+ async def track_llm(*args, **kwargs):
+ order.append("llm")
+ return LLMResponse(content="ok", tool_calls=[])
+ loop.provider.chat_with_retry = track_llm
+
+ session = loop.sessions.get_or_create("cli:test")
+ session.messages = [
+ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
+ {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
+ {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
+ ]
+ loop.sessions.save(session)
+ monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
+
+ call_count = [0]
+ def mock_estimate(_session):
+ call_count[0] += 1
+ return (1000 if call_count[0] <= 1 else 80, "test")
+ loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
+
+ await loop.process_direct("hello", session_key="cli:test")
+
+ assert "consolidate" in order
+ assert "llm" in order
+ assert order.index("consolidate") < order.index("llm")
diff --git a/tests/test_loop_save_turn.py b/tests/test_loop_save_turn.py
new file mode 100644
index 0000000..25ba88b
--- /dev/null
+++ b/tests/test_loop_save_turn.py
@@ -0,0 +1,55 @@
+from nanobot.agent.context import ContextBuilder
+from nanobot.agent.loop import AgentLoop
+from nanobot.session.manager import Session
+
+
+def _mk_loop() -> AgentLoop:
+ loop = AgentLoop.__new__(AgentLoop)
+ loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
+ return loop
+
+
+def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
+ loop = _mk_loop()
+ session = Session(key="test:runtime-only")
+ runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
+
+ loop._save_turn(
+ session,
+ [{"role": "user", "content": [{"type": "text", "text": runtime}]}],
+ skip=0,
+ )
+ assert session.messages == []
+
+
+def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
+ loop = _mk_loop()
+ session = Session(key="test:image")
+ runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
+
+ loop._save_turn(
+ session,
+ [{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": runtime},
+ {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
+ ],
+ }],
+ skip=0,
+ )
+ assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
+
+
+def test_save_turn_keeps_tool_results_under_16k() -> None:
+ loop = _mk_loop()
+ session = Session(key="test:tool-result")
+ content = "x" * 12_000
+
+ loop._save_turn(
+ session,
+ [{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
+ skip=0,
+ )
+
+ assert session.messages[0]["content"] == content
diff --git a/tests/test_matrix_channel.py b/tests/test_matrix_channel.py
new file mode 100644
index 0000000..1f3b69c
--- /dev/null
+++ b/tests/test_matrix_channel.py
@@ -0,0 +1,1318 @@
+import asyncio
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+import nanobot.channels.matrix as matrix_module
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.matrix import (
+ MATRIX_HTML_FORMAT,
+ TYPING_NOTICE_TIMEOUT_MS,
+ MatrixChannel,
+)
+from nanobot.channels.matrix import MatrixConfig
+
+_ROOM_SEND_UNSET = object()
+
+
+class _DummyTask:
+ def __init__(self) -> None:
+ self.cancelled = False
+
+ def cancel(self) -> None:
+ self.cancelled = True
+
+ def __await__(self):
+ async def _done():
+ return None
+
+ return _done().__await__()
+
+
+class _FakeAsyncClient:
+ def __init__(self, homeserver, user, store_path, config) -> None:
+ self.homeserver = homeserver
+ self.user = user
+ self.store_path = store_path
+ self.config = config
+ self.user_id: str | None = None
+ self.access_token: str | None = None
+ self.device_id: str | None = None
+ self.load_store_called = False
+ self.stop_sync_forever_called = False
+ self.join_calls: list[str] = []
+ self.callbacks: list[tuple[object, object]] = []
+ self.response_callbacks: list[tuple[object, object]] = []
+ self.rooms: dict[str, object] = {}
+ self.room_send_calls: list[dict[str, object]] = []
+ self.typing_calls: list[tuple[str, bool, int]] = []
+ self.download_calls: list[dict[str, object]] = []
+ self.upload_calls: list[dict[str, object]] = []
+ self.download_response: object | None = None
+ self.download_bytes: bytes = b"media"
+ self.download_content_type: str = "application/octet-stream"
+ self.download_filename: str | None = None
+ self.upload_response: object | None = None
+ self.content_repository_config_response: object = SimpleNamespace(upload_size=None)
+ self.raise_on_send = False
+ self.raise_on_typing = False
+ self.raise_on_upload = False
+
+ def add_event_callback(self, callback, event_type) -> None:
+ self.callbacks.append((callback, event_type))
+
+ def add_response_callback(self, callback, response_type) -> None:
+ self.response_callbacks.append((callback, response_type))
+
+ def load_store(self) -> None:
+ self.load_store_called = True
+
+ def stop_sync_forever(self) -> None:
+ self.stop_sync_forever_called = True
+
+ async def join(self, room_id: str) -> None:
+ self.join_calls.append(room_id)
+
+ async def room_send(
+ self,
+ room_id: str,
+ message_type: str,
+ content: dict[str, object],
+ ignore_unverified_devices: object = _ROOM_SEND_UNSET,
+ ) -> None:
+ call: dict[str, object] = {
+ "room_id": room_id,
+ "message_type": message_type,
+ "content": content,
+ }
+ if ignore_unverified_devices is not _ROOM_SEND_UNSET:
+ call["ignore_unverified_devices"] = ignore_unverified_devices
+ self.room_send_calls.append(call)
+ if self.raise_on_send:
+ raise RuntimeError("send failed")
+
+ async def room_typing(
+ self,
+ room_id: str,
+ typing_state: bool = True,
+ timeout: int = 30_000,
+ ) -> None:
+ self.typing_calls.append((room_id, typing_state, timeout))
+ if self.raise_on_typing:
+ raise RuntimeError("typing failed")
+
+ async def download(self, **kwargs):
+ self.download_calls.append(kwargs)
+ if self.download_response is not None:
+ return self.download_response
+ return matrix_module.MemoryDownloadResponse(
+ body=self.download_bytes,
+ content_type=self.download_content_type,
+ filename=self.download_filename,
+ )
+
+ async def upload(
+ self,
+ data_provider,
+ content_type: str | None = None,
+ filename: str | None = None,
+ filesize: int | None = None,
+ encrypt: bool = False,
+ ):
+ if self.raise_on_upload:
+ raise RuntimeError("upload failed")
+ if isinstance(data_provider, (bytes, bytearray)):
+ raise TypeError(
+ f"data_provider type {type(data_provider)!r} is not of a usable type "
+ "(Callable, IOBase)"
+ )
+ self.upload_calls.append(
+ {
+ "data_provider": data_provider,
+ "content_type": content_type,
+ "filename": filename,
+ "filesize": filesize,
+ "encrypt": encrypt,
+ }
+ )
+ if self.upload_response is not None:
+ return self.upload_response
+ if encrypt:
+ return (
+ SimpleNamespace(content_uri="mxc://example.org/uploaded"),
+ {
+ "v": "v2",
+ "iv": "iv",
+ "hashes": {"sha256": "hash"},
+ "key": {"alg": "A256CTR", "k": "key"},
+ },
+ )
+ return SimpleNamespace(content_uri="mxc://example.org/uploaded"), None
+
+ async def content_repository_config(self):
+ return self.content_repository_config_response
+
+ async def close(self) -> None:
+ return None
+
+
+def _make_config(**kwargs) -> MatrixConfig:
+ kwargs.setdefault("allow_from", ["*"])
+ return MatrixConfig(
+ enabled=True,
+ homeserver="https://matrix.org",
+ access_token="token",
+ user_id="@bot:matrix.org",
+ **kwargs,
+ )
+
+
+@pytest.mark.asyncio
+async def test_start_skips_load_store_when_device_id_missing(
+ monkeypatch, tmp_path
+) -> None:
+ clients: list[_FakeAsyncClient] = []
+
+ def _fake_client(*args, **kwargs) -> _FakeAsyncClient:
+ client = _FakeAsyncClient(*args, **kwargs)
+ clients.append(client)
+ return client
+
+ def _fake_create_task(coro):
+ coro.close()
+ return _DummyTask()
+
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.AsyncClientConfig",
+ lambda **kwargs: SimpleNamespace(**kwargs),
+ )
+ monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.asyncio.create_task", _fake_create_task
+ )
+
+ channel = MatrixChannel(_make_config(device_id=""), MessageBus())
+ await channel.start()
+
+ assert len(clients) == 1
+ assert clients[0].config.encryption_enabled is True
+ assert clients[0].load_store_called is False
+ assert len(clients[0].callbacks) == 3
+ assert len(clients[0].response_callbacks) == 3
+
+ await channel.stop()
+
+
+@pytest.mark.asyncio
+async def test_register_event_callbacks_uses_media_base_filter() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ channel._register_event_callbacks()
+
+ assert len(client.callbacks) == 3
+ assert client.callbacks[1][0] == channel._on_media_message
+ assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
+
+
+def test_media_event_filter_does_not_match_text_events() -> None:
+ assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
+
+
+@pytest.mark.asyncio
+async def test_start_disables_e2ee_when_configured(
+ monkeypatch, tmp_path
+) -> None:
+ clients: list[_FakeAsyncClient] = []
+
+ def _fake_client(*args, **kwargs) -> _FakeAsyncClient:
+ client = _FakeAsyncClient(*args, **kwargs)
+ clients.append(client)
+ return client
+
+ def _fake_create_task(coro):
+ coro.close()
+ return _DummyTask()
+
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.AsyncClientConfig",
+ lambda **kwargs: SimpleNamespace(**kwargs),
+ )
+ monkeypatch.setattr("nanobot.channels.matrix.AsyncClient", _fake_client)
+ monkeypatch.setattr(
+ "nanobot.channels.matrix.asyncio.create_task", _fake_create_task
+ )
+
+ channel = MatrixChannel(_make_config(device_id="", e2ee_enabled=False), MessageBus())
+ await channel.start()
+
+ assert len(clients) == 1
+ assert clients[0].config.encryption_enabled is False
+
+ await channel.stop()
+
+
+@pytest.mark.asyncio
+async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(device_id="DEVICE"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ task = _DummyTask()
+
+ channel.client = client
+ channel._sync_task = task
+ channel._running = True
+
+ await channel.stop()
+
+ assert channel._running is False
+ assert client.stop_sync_forever_called is True
+ assert task.cancelled is False
+
+
+@pytest.mark.asyncio
+async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
+ channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == []
+
+
+@pytest.mark.asyncio
+async def test_room_invite_joins_when_sender_allowed() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == ["!room:matrix.org"]
+
+@pytest.mark.asyncio
+async def test_room_invite_respects_allow_list_when_configured() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org")
+ event = SimpleNamespace(sender="@alice:matrix.org")
+
+ await channel._on_room_invite(room, event)
+
+ assert client.join_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_sets_typing_for_allowed_sender() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [
+ ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS),
+ ]
+
+
+@pytest.mark.asyncio
+async def test_typing_keepalive_refreshes_periodically(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ monkeypatch.setattr(matrix_module, "TYPING_KEEPALIVE_INTERVAL_MS", 10)
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ await asyncio.sleep(0.03)
+ await channel._stop_typing_keepalive("!room:matrix.org", clear_typing=True)
+
+ true_updates = [call for call in client.typing_calls if call[1] is True]
+ assert len(true_updates) >= 2
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_on_message_skips_typing_for_self_message() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@bot:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_skips_typing_for_denied_sender() -> None:
+ channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={})
+
+ await channel._on_message(room, event)
+
+ assert handled == []
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_requires_mx_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+
+ await channel._on_message(room, event)
+
+ assert handled == []
+ assert client.typing_calls == []
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_accepts_bot_user_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello",
+ source={"content": {"m.mentions": {"user_ids": ["@bot:matrix.org"]}}},
+ )
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_mention_policy_allows_direct_room_without_mentions() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!dm:matrix.org", display_name="DM", member_count=2)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+
+ await channel._on_message(room, event)
+
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!dm:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_allowlist_policy_requires_room_id() -> None:
+ channel = MatrixChannel(
+ _make_config(group_policy="allowlist", group_allow_from=["!allowed:matrix.org"]),
+ MessageBus(),
+ )
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["chat_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ denied_room = SimpleNamespace(room_id="!denied:matrix.org", display_name="Denied", member_count=3)
+ event = SimpleNamespace(sender="@alice:matrix.org", body="Hello", source={"content": {}})
+ await channel._on_message(denied_room, event)
+
+ allowed_room = SimpleNamespace(
+ room_id="!allowed:matrix.org",
+ display_name="Allowed",
+ member_count=3,
+ )
+ await channel._on_message(allowed_room, event)
+
+ assert handled == ["!allowed:matrix.org"]
+ assert client.typing_calls == [("!allowed:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_room_mention_requires_opt_in() -> None:
+ channel = MatrixChannel(_make_config(group_policy="mention"), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[str] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs["sender_id"])
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ room_mention_event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello everyone",
+ source={"content": {"m.mentions": {"room": True}}},
+ )
+
+ await channel._on_message(room, room_mention_event)
+ assert handled == []
+ assert client.typing_calls == []
+
+ channel.config.allow_room_mentions = True
+ await channel._on_message(room, room_mention_event)
+ assert handled == ["@alice:matrix.org"]
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+
+@pytest.mark.asyncio
+async def test_on_message_sets_thread_metadata_when_threaded_event() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=3)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="Hello",
+ event_id="$reply1",
+ source={
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ }
+ }
+ },
+ )
+
+ await channel._on_message(room, event)
+
+ assert len(handled) == 1
+ metadata = handled[0]["metadata"]
+ assert metadata["thread_root_event_id"] == "$root1"
+ assert metadata["thread_reply_to_event_id"] == "$reply1"
+ assert metadata["event_id"] == "$reply1"
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_downloads_attachment_and_sets_metadata(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"image"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event1",
+ source={
+ "content": {
+ "msgtype": "m.image",
+ "info": {"mimetype": "image/png", "size": 5},
+ }
+ },
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(client.download_calls) == 1
+ assert len(handled) == 1
+ assert client.typing_calls == [("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)]
+
+ media_paths = handled[0]["media"]
+ assert isinstance(media_paths, list) and len(media_paths) == 1
+ media_path = Path(media_paths[0])
+ assert media_path.is_file()
+ assert media_path.read_bytes() == b"image"
+
+ metadata = handled[0]["metadata"]
+ attachments = metadata["attachments"]
+ assert isinstance(attachments, list) and len(attachments) == 1
+ assert attachments[0]["type"] == "image"
+ assert attachments[0]["mxc_url"] == "mxc://example.org/mediaid"
+ assert attachments[0]["path"] == str(media_path)
+ assert "[attachment: " in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_sets_thread_metadata_when_threaded_event(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"image"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event1",
+ source={
+ "content": {
+ "msgtype": "m.image",
+ "info": {"mimetype": "image/png", "size": 5},
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ },
+ }
+ },
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ metadata = handled[0]["metadata"]
+ assert metadata["thread_root_event_id"] == "$root1"
+ assert metadata["thread_reply_to_event_id"] == "$event1"
+ assert metadata["event_id"] == "$event1"
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_respects_declared_size_limit(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(max_media_bytes=3), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="large.bin",
+ url="mxc://example.org/large",
+ event_id="$event2",
+ source={"content": {"msgtype": "m.file", "info": {"size": 10}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert client.download_calls == []
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: large.bin - too large]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_uses_server_limit_when_smaller_than_local_limit(
+ monkeypatch, tmp_path
+) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.content_repository_config_response = SimpleNamespace(upload_size=3)
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="large.bin",
+ url="mxc://example.org/large",
+ event_id="$event2_server",
+ source={"content": {"msgtype": "m.file", "info": {"size": 5}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert client.download_calls == []
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: large.bin - too large]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_response = matrix_module.DownloadError("download failed")
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="photo.png",
+ url="mxc://example.org/mediaid",
+ event_id="$event3",
+ source={"content": {"msgtype": "m.image"}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(client.download_calls) == 1
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: photo.png - download failed]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+ monkeypatch.setattr(
+ matrix_module,
+ "decrypt_attachment",
+ lambda ciphertext, key, sha256, iv: b"plain",
+ )
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"cipher"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="secret.txt",
+ url="mxc://example.org/encrypted",
+ event_id="$event4",
+ key={"k": "key"},
+ hashes={"sha256": "hash"},
+ iv="iv",
+ source={"content": {"msgtype": "m.file", "info": {"size": 6}}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ media_path = Path(handled[0]["media"][0])
+ assert media_path.read_bytes() == b"plain"
+ attachment = handled[0]["metadata"]["attachments"][0]
+ assert attachment["encrypted"] is True
+ assert attachment["size_bytes"] == 5
+
+
+@pytest.mark.asyncio
+async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) -> None:
+ monkeypatch.setattr("nanobot.channels.matrix.get_data_dir", lambda: tmp_path)
+
+ def _raise(*args, **kwargs):
+ raise matrix_module.EncryptionError("boom")
+
+ monkeypatch.setattr(matrix_module, "decrypt_attachment", _raise)
+
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.download_bytes = b"cipher"
+ channel.client = client
+
+ handled: list[dict[str, object]] = []
+
+ async def _fake_handle_message(**kwargs) -> None:
+ handled.append(kwargs)
+
+ channel._handle_message = _fake_handle_message # type: ignore[method-assign]
+
+ room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room", member_count=2)
+ event = SimpleNamespace(
+ sender="@alice:matrix.org",
+ body="secret.txt",
+ url="mxc://example.org/encrypted",
+ event_id="$event5",
+ key={"k": "key"},
+ hashes={"sha256": "hash"},
+ iv="iv",
+ source={"content": {"msgtype": "m.file"}},
+ )
+
+ await channel._on_media_message(room, event)
+
+ assert len(handled) == 1
+ assert handled[0]["media"] == []
+ assert handled[0]["metadata"]["attachments"] == []
+ assert "[attachment: secret.txt - download failed]" in handled[0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_send_clears_typing_after_send() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"] == {
+ "msgtype": "m.text",
+ "body": "Hi",
+ "m.mentions": {},
+ }
+ assert client.room_send_calls[0]["ignore_unverified_devices"] is True
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_uploads_media_and_sends_file_event(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ file_path = tmp_path / "test.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Please review.",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 1
+ assert not isinstance(client.upload_calls[0]["data_provider"], (bytes, bytearray))
+ assert hasattr(client.upload_calls[0]["data_provider"], "read")
+ assert client.upload_calls[0]["filename"] == "test.txt"
+ assert client.upload_calls[0]["filesize"] == 5
+ assert len(client.room_send_calls) == 2
+ assert client.room_send_calls[0]["content"]["msgtype"] == "m.file"
+ assert client.room_send_calls[0]["content"]["url"] == "mxc://example.org/uploaded"
+ assert client.room_send_calls[1]["content"]["body"] == "Please review."
+
+
+@pytest.mark.asyncio
+async def test_send_adds_thread_relates_to_for_thread_metadata() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ metadata = {
+ "thread_root_event_id": "$root1",
+ "thread_reply_to_event_id": "$reply1",
+ }
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Hi",
+ metadata=metadata,
+ )
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["m.relates_to"] == {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ "m.in_reply_to": {"event_id": "$reply1"},
+ "is_falling_back": True,
+ }
+
+
+@pytest.mark.asyncio
+async def test_send_uses_encrypted_media_payload_in_encrypted_room(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(e2ee_enabled=True), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.rooms["!encrypted:matrix.org"] = SimpleNamespace(encrypted=True)
+ channel.client = client
+
+ file_path = tmp_path / "secret.txt"
+ file_path.write_text("topsecret", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!encrypted:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 1
+ assert client.upload_calls[0]["encrypt"] is True
+ assert len(client.room_send_calls) == 1
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.file"
+ assert "file" in content
+ assert "url" not in content
+ assert content["file"]["url"] == "mxc://example.org/uploaded"
+ assert content["file"]["hashes"]["sha256"] == "hash"
+
+
+@pytest.mark.asyncio
+async def test_send_does_not_parse_attachment_marker_without_media(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ missing_path = tmp_path / "missing.txt"
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content=f"[attachment: {missing_path}]",
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == f"[attachment: {missing_path}]"
+
+
+@pytest.mark.asyncio
+async def test_send_passes_thread_relates_to_to_attachment_upload(monkeypatch) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._server_upload_limit_checked = True
+ channel._server_upload_limit_bytes = None
+
+ captured: dict[str, object] = {}
+
+ async def _fake_upload_and_send_attachment(
+ *,
+ room_id: str,
+ path: Path,
+ limit_bytes: int,
+ relates_to: dict[str, object] | None = None,
+ ) -> str | None:
+ captured["relates_to"] = relates_to
+ return None
+
+ monkeypatch.setattr(channel, "_upload_and_send_attachment", _fake_upload_and_send_attachment)
+
+ metadata = {
+ "thread_root_event_id": "$root1",
+ "thread_reply_to_event_id": "$reply1",
+ }
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Hi",
+ media=["/tmp/fake.txt"],
+ metadata=metadata,
+ )
+ )
+
+ assert captured["relates_to"] == {
+ "rel_type": "m.thread",
+ "event_id": "$root1",
+ "m.in_reply_to": {"event_id": "$reply1"},
+ "is_falling_back": True,
+ }
+
+
+@pytest.mark.asyncio
+async def test_send_workspace_restriction_blocks_external_attachment(tmp_path) -> None:
+ workspace = tmp_path / "workspace"
+ workspace.mkdir()
+ file_path = tmp_path / "external.txt"
+ file_path.write_text("outside", encoding="utf-8")
+
+ channel = MatrixChannel(
+ _make_config(),
+ MessageBus(),
+ restrict_to_workspace=True,
+ workspace=workspace,
+ )
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: external.txt - upload failed]"
+
+
+@pytest.mark.asyncio
+async def test_send_handles_upload_exception_and_reports_failure(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.raise_on_upload = True
+ channel.client = client
+
+ file_path = tmp_path / "broken.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="Please review.",
+ media=[str(file_path)],
+ )
+ )
+
+ assert len(client.upload_calls) == 0
+ assert len(client.room_send_calls) == 1
+ assert (
+ client.room_send_calls[0]["content"]["body"]
+ == "Please review.\n[attachment: broken.txt - upload failed]"
+ )
+
+
+@pytest.mark.asyncio
+async def test_send_uses_server_upload_limit_when_smaller_than_local_limit(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(max_media_bytes=10), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.content_repository_config_response = SimpleNamespace(upload_size=3)
+ channel.client = client
+
+ file_path = tmp_path / "tiny.txt"
+ file_path.write_text("hello", encoding="utf-8")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: tiny.txt - too large]"
+
+
+@pytest.mark.asyncio
+async def test_send_blocks_all_outbound_media_when_limit_is_zero(tmp_path) -> None:
+ channel = MatrixChannel(_make_config(max_media_bytes=0), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ file_path = tmp_path / "empty.txt"
+ file_path.write_bytes(b"")
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="",
+ media=[str(file_path)],
+ )
+ )
+
+ assert client.upload_calls == []
+ assert len(client.room_send_calls) == 1
+ assert client.room_send_calls[0]["content"]["body"] == "[attachment: empty.txt - too large]"
+
+
+@pytest.mark.asyncio
+async def test_send_omits_ignore_unverified_devices_when_e2ee_disabled() -> None:
+ channel = MatrixChannel(_make_config(e2ee_enabled=False), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert len(client.room_send_calls) == 1
+ assert "ignore_unverified_devices" not in client.room_send_calls[0]
+
+
+@pytest.mark.asyncio
+async def test_send_stops_typing_keepalive_task() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ assert "!room:matrix.org" in channel._typing_tasks
+
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert "!room:matrix.org" not in channel._typing_tasks
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_progress_keeps_typing_keepalive_running() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+ channel._running = True
+
+ await channel._start_typing_keepalive("!room:matrix.org")
+ assert "!room:matrix.org" in channel._typing_tasks
+
+ await channel.send(
+ OutboundMessage(
+ channel="matrix",
+ chat_id="!room:matrix.org",
+ content="working...",
+ metadata={"_progress": True, "_progress_kind": "reasoning"},
+ )
+ )
+
+ assert "!room:matrix.org" in channel._typing_tasks
+ assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
+
+ await channel.stop()
+
+
+@pytest.mark.asyncio
+async def test_send_clears_typing_when_send_fails() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ client.raise_on_send = True
+ channel.client = client
+
+ with pytest.raises(RuntimeError, match="send failed"):
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content="Hi")
+ )
+
+ assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
+
+
+@pytest.mark.asyncio
+async def test_send_adds_formatted_body_for_markdown() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "# Headline\n\n- [x] done\n\n| A | B |\n| - | - |\n| 1 | 2 |"
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.text"
+ assert content["body"] == markdown_text
+ assert content["m.mentions"] == {}
+ assert content["format"] == MATRIX_HTML_FORMAT
+ assert "Headline
" in str(content["formatted_body"])
+ assert "" in str(content["formatted_body"])
+ assert "[x] done" in str(content["formatted_body"])
+
+
+@pytest.mark.asyncio
+async def test_send_adds_formatted_body_for_inline_url_superscript_subscript() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "Visit https://example.com and x^2^ plus H~2~O."
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ content = client.room_send_calls[0]["content"]
+ assert content["msgtype"] == "m.text"
+ assert content["body"] == markdown_text
+ assert content["m.mentions"] == {}
+ assert content["format"] == MATRIX_HTML_FORMAT
+ assert '' in str(
+ content["formatted_body"]
+ )
+ assert "2" in str(content["formatted_body"])
+ assert "2" in str(content["formatted_body"])
+
+
+@pytest.mark.asyncio
+async def test_send_sanitizes_disallowed_link_scheme() -> None:
+ channel = MatrixChannel(_make_config(), MessageBus())
+ client = _FakeAsyncClient("", "", "", None)
+ channel.client = client
+
+ markdown_text = "[click](javascript:alert(1))"
+ await channel.send(
+ OutboundMessage(channel="matrix", chat_id="!room:matrix.org", content=markdown_text)
+ )
+
+ formatted_body = str(client.room_send_calls[0]["content"]["formatted_body"])
+ assert "javascript:" not in formatted_body
+ assert " None:
+ dirty_html = 'x'
+ cleaned_html = matrix_module.MATRIX_HTML_CLEANER.clean(dirty_html)
+
+ assert "