diff --git a/README.md b/README.md index 3d5fb63..5be0ce5 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ ## ๐Ÿ“ข News +- **2026-03-08** ๐Ÿš€ Released **v0.1.4.post4** โ€” a reliability-packed release with safer defaults, better multi-instance support, sturdier MCP, and major channel and provider improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post4) for details. - **2026-03-07** ๐Ÿš€ Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish. - **2026-03-06** ๐Ÿช„ Lighter providers, smarter media handling, and sturdier memory and CLI compatibility. - **2026-03-05** โšก๏ธ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes. diff --git a/core_agent_lines.sh b/core_agent_lines.sh index 3f5301a..df32394 100755 --- a/core_agent_lines.sh +++ b/core_agent_lines.sh @@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l) printf " %-16s %5s lines\n" "(root)" "$root" echo "" -total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l) +total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l) echo " Core total: $total lines" echo "" -echo " (excludes: channels/, cli/, providers/)" +echo " (excludes: channels/, cli/, providers/, skills/)" diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 2c648eb..e47fcb8 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -10,7 +10,7 @@ from typing import Any from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader -from nanobot.utils.helpers import detect_image_mime +from nanobot.utils.helpers import build_assistant_message, detect_image_mime class ContextBuilder: @@ -182,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send thinking_blocks: list[dict] | None = None, ) -> list[dict[str, Any]]: """Add an assistant message to the message list.""" - msg: dict[str, Any] = {"role": "assistant", "content": content} - if tool_calls: - msg["tool_calls"] = tool_calls - if reasoning_content is not None: - msg["reasoning_content"] = reasoning_content - if thinking_blocks: - msg["thinking_blocks"] = thinking_blocks - messages.append(msg) + messages.append(build_assistant_message( + content, + tool_calls=tool_calls, + reasoning_content=reasoning_content, + thinking_blocks=thinking_blocks, + )) return messages diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ca9a06e..8605a09 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -5,7 +5,6 @@ from __future__ import annotations import asyncio import json import re -import weakref from contextlib import AsyncExitStack from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable @@ -13,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger from nanobot.agent.context import ContextBuilder -from nanobot.agent.memory import MemoryStore +from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool @@ -55,8 +54,8 @@ class AgentLoop: max_iterations: int = 40, temperature: float = 0.1, max_tokens: int = 4096, - memory_window: int = 100, reasoning_effort: str | None = None, + context_window_tokens: int = 65_536, brave_api_key: str | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -75,8 +74,8 @@ class AgentLoop: self.max_iterations = max_iterations self.temperature = temperature self.max_tokens = max_tokens - self.memory_window = memory_window self.reasoning_effort = reasoning_effort + self.context_window_tokens = context_window_tokens self.brave_api_key = brave_api_key self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -105,11 +104,17 @@ class AgentLoop: self._mcp_stack: AsyncExitStack | None = None self._mcp_connected = False self._mcp_connecting = False - self._consolidating: set[str] = set() # Session keys with consolidation in progress - self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks - self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._processing_lock = asyncio.Lock() + self.memory_consolidator = MemoryConsolidator( + workspace=workspace, + provider=provider, + model=self.model, + sessions=self.sessions, + context_window_tokens=context_window_tokens, + build_messages=self.context.build_messages, + get_tool_definitions=self.tools.get_definitions, + ) self._register_default_tools() def _register_default_tools(self) -> None: @@ -182,7 +187,7 @@ class AgentLoop: initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, ) -> tuple[str | None, list[str], list[dict]]: - """Run the agent iteration loop. Returns (final_content, tools_used, messages).""" + """Run the agent iteration loop.""" messages = initial_messages iteration = 0 final_content = None @@ -191,9 +196,11 @@ class AgentLoop: while iteration < self.max_iterations: iteration += 1 - response = await self.provider.chat( + tool_defs = self.tools.get_definitions() + + response = await self.provider.chat_with_retry( messages=messages, - tools=self.tools.get_definitions(), + tools=tool_defs, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, @@ -341,8 +348,9 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) - history = session.get_history(max_messages=self.memory_window) + history = session.get_history(max_messages=0) messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, @@ -350,6 +358,7 @@ class AgentLoop: final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -362,27 +371,20 @@ class AgentLoop: # Slash commands cmd = msg.content.strip().lower() if cmd == "/new": - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - self._consolidating.add(session.key) try: - async with lock: - snapshot = session.messages[session.last_consolidated:] - if snapshot: - temp = Session(key=session.key) - temp.messages = list(snapshot) - if not await self._consolidate_memory(temp, archive_all=True): - return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Memory archival failed, session not cleared. Please try again.", - ) + if not await self.memory_consolidator.archive_unconsolidated(session): + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. Please try again.", + ) except Exception: logger.exception("/new archival failed for {}", session.key) return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, content="Memory archival failed, session not cleared. Please try again.", ) - finally: - self._consolidating.discard(session.key) session.clear() self.sessions.save(session) @@ -393,30 +395,14 @@ class AgentLoop: return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="๐Ÿˆ nanobot commands:\n/new โ€” Start a new conversation\n/stop โ€” Stop the current task\n/help โ€” Show available commands") - unconsolidated = len(session.messages) - session.last_consolidated - if (unconsolidated >= self.memory_window and session.key not in self._consolidating): - self._consolidating.add(session.key) - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - - async def _consolidate_and_unlock(): - try: - async with lock: - await self._consolidate_memory(session) - finally: - self._consolidating.discard(session.key) - _task = asyncio.current_task() - if _task is not None: - self._consolidation_tasks.discard(_task) - - _task = asyncio.create_task(_consolidate_and_unlock()) - self._consolidation_tasks.add(_task) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): if isinstance(message_tool, MessageTool): message_tool.start_turn() - history = session.get_history(max_messages=self.memory_window) + history = session.get_history(max_messages=0) initial_messages = self.context.build_messages( history=history, current_message=msg.content, @@ -441,6 +427,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None @@ -487,13 +474,6 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() - async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: - """Delegate to MemoryStore.consolidate(). Returns True on success.""" - return await MemoryStore(self.workspace).consolidate( - session, self.provider, self.model, - archive_all=archive_all, memory_window=self.memory_window, - ) - async def process_direct( self, content: str, diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 21fe77d..cd5f54f 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -2,17 +2,19 @@ from __future__ import annotations +import asyncio import json +import weakref from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable from loguru import logger -from nanobot.utils.helpers import ensure_dir +from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain if TYPE_CHECKING: from nanobot.providers.base import LLMProvider - from nanobot.session.manager import Session + from nanobot.session.manager import Session, SessionManager _SAVE_MEMORY_TOOL = [ @@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [ "properties": { "history_entry": { "type": "string", - "description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. " + "description": "A paragraph summarizing key events/decisions/topics. " "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", }, "memory_update": { @@ -42,6 +44,20 @@ _SAVE_MEMORY_TOOL = [ ] +def _ensure_text(value: Any) -> str: + """Normalize tool-call payload values to text for file storage.""" + return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False) + + +def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None: + """Normalize provider tool-call arguments to the expected dict shape.""" + if isinstance(args, str): + args = json.loads(args) + if isinstance(args, list): + return args[0] if args and isinstance(args[0], dict) else None + return args if isinstance(args, dict) else None + + class MemoryStore: """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" @@ -66,40 +82,27 @@ class MemoryStore: long_term = self.read_long_term() return f"## Long-term Memory\n{long_term}" if long_term else "" + @staticmethod + def _format_messages(messages: list[dict]) -> str: + lines = [] + for message in messages: + if not message.get("content"): + continue + tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else "" + lines.append( + f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}" + ) + return "\n".join(lines) + async def consolidate( self, - session: Session, + messages: list[dict], provider: LLMProvider, model: str, - *, - archive_all: bool = False, - memory_window: int = 50, ) -> bool: - """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. - - Returns True on success (including no-op), False on failure. - """ - if archive_all: - old_messages = session.messages - keep_count = 0 - logger.info("Memory consolidation (archive_all): {} messages", len(session.messages)) - else: - keep_count = memory_window // 2 - if len(session.messages) <= keep_count: - return True - if len(session.messages) - session.last_consolidated <= 0: - return True - old_messages = session.messages[session.last_consolidated:-keep_count] - if not old_messages: - return True - logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count) - - lines = [] - for m in old_messages: - if not m.get("content"): - continue - tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" - lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") + """Consolidate the provided message chunk into MEMORY.md + HISTORY.md.""" + if not messages: + return True current_memory = self.read_long_term() prompt = f"""Process this conversation and call the save_memory tool with your consolidation. @@ -108,10 +111,10 @@ class MemoryStore: {current_memory or "(empty)"} ## Conversation to Process -{chr(10).join(lines)}""" +{self._format_messages(messages)}""" try: - response = await provider.chat( + response = await provider.chat_with_retry( messages=[ {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, {"role": "user", "content": prompt}, @@ -124,34 +127,158 @@ class MemoryStore: logger.warning("Memory consolidation: LLM did not call save_memory, skipping") return False - args = response.tool_calls[0].arguments - # Some providers return arguments as a JSON string instead of dict - if isinstance(args, str): - args = json.loads(args) - # Some providers return arguments as a list (handle edge case) - if isinstance(args, list): - if args and isinstance(args[0], dict): - args = args[0] - else: - logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list") - return False - if not isinstance(args, dict): - logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) + args = _normalize_save_memory_args(response.tool_calls[0].arguments) + if args is None: + logger.warning("Memory consolidation: unexpected save_memory arguments") return False if entry := args.get("history_entry"): - if not isinstance(entry, str): - entry = json.dumps(entry, ensure_ascii=False) - self.append_history(entry) + self.append_history(_ensure_text(entry)) if update := args.get("memory_update"): - if not isinstance(update, str): - update = json.dumps(update, ensure_ascii=False) + update = _ensure_text(update) if update != current_memory: self.write_long_term(update) - session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count - logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) + logger.info("Memory consolidation done for {} messages", len(messages)) return True except Exception: logger.exception("Memory consolidation failed") return False + + +class MemoryConsolidator: + """Owns consolidation policy, locking, and session offset updates.""" + + _MAX_CONSOLIDATION_ROUNDS = 5 + + def __init__( + self, + workspace: Path, + provider: LLMProvider, + model: str, + sessions: SessionManager, + context_window_tokens: int, + build_messages: Callable[..., list[dict[str, Any]]], + get_tool_definitions: Callable[[], list[dict[str, Any]]], + ): + self.store = MemoryStore(workspace) + self.provider = provider + self.model = model + self.sessions = sessions + self.context_window_tokens = context_window_tokens + self._build_messages = build_messages + self._get_tool_definitions = get_tool_definitions + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + + def get_lock(self, session_key: str) -> asyncio.Lock: + """Return the shared consolidation lock for one session.""" + return self._locks.setdefault(session_key, asyncio.Lock()) + + async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: + """Archive a selected message chunk into persistent memory.""" + return await self.store.consolidate(messages, self.provider, self.model) + + def pick_consolidation_boundary( + self, + session: Session, + tokens_to_remove: int, + ) -> tuple[int, int] | None: + """Pick a user-turn boundary that removes enough old prompt tokens.""" + start = session.last_consolidated + if start >= len(session.messages) or tokens_to_remove <= 0: + return None + + removed_tokens = 0 + last_boundary: tuple[int, int] | None = None + for idx in range(start, len(session.messages)): + message = session.messages[idx] + if idx > start and message.get("role") == "user": + last_boundary = (idx, removed_tokens) + if removed_tokens >= tokens_to_remove: + return last_boundary + removed_tokens += estimate_message_tokens(message) + + return last_boundary + + def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: + """Estimate current prompt size for the normal session history view.""" + history = session.get_history(max_messages=0) + channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) + probe_messages = self._build_messages( + history=history, + current_message="[token-probe]", + channel=channel, + chat_id=chat_id, + ) + return estimate_prompt_tokens_chain( + self.provider, + self.model, + probe_messages, + self._get_tool_definitions(), + ) + + async def archive_unconsolidated(self, session: Session) -> bool: + """Archive the full unconsolidated tail for /new-style session rollover.""" + lock = self.get_lock(session.key) + async with lock: + snapshot = session.messages[session.last_consolidated:] + if not snapshot: + return True + return await self.consolidate_messages(snapshot) + + async def maybe_consolidate_by_tokens(self, session: Session) -> None: + """Loop: archive old messages until prompt fits within half the context window.""" + if not session.messages or self.context_window_tokens <= 0: + return + + lock = self.get_lock(session.key) + async with lock: + target = self.context_window_tokens // 2 + estimated, source = self.estimate_session_prompt_tokens(session) + if estimated <= 0: + return + if estimated < self.context_window_tokens: + logger.debug( + "Token consolidation idle {}: {}/{} via {}", + session.key, + estimated, + self.context_window_tokens, + source, + ) + return + + for round_num in range(self._MAX_CONSOLIDATION_ROUNDS): + if estimated <= target: + return + + boundary = self.pick_consolidation_boundary(session, max(1, estimated - target)) + if boundary is None: + logger.debug( + "Token consolidation: no safe boundary for {} (round {})", + session.key, + round_num, + ) + return + + end_idx = boundary[0] + chunk = session.messages[session.last_consolidated:end_idx] + if not chunk: + return + + logger.info( + "Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs", + round_num, + session.key, + estimated, + self.context_window_tokens, + source, + len(chunk), + ) + if not await self.consolidate_messages(chunk): + return + session.last_consolidated = end_idx + self.sessions.save(session) + + estimated, source = self.estimate_session_prompt_tokens(session) + if estimated <= 0: + return diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index f2d6ee5..eff0b4f 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.config.schema import ExecToolConfig from nanobot.providers.base import LLMProvider +from nanobot.utils.helpers import build_assistant_message class SubagentManager: @@ -123,7 +124,7 @@ class SubagentManager: while iteration < max_iterations: iteration += 1 - response = await self.provider.chat( + response = await self.provider.chat_with_retry( messages=messages, tools=tools.get_definitions(), model=self.model, @@ -133,7 +134,6 @@ class SubagentManager: ) if response.has_tool_calls: - # Add assistant message with tool calls tool_call_dicts = [ { "id": tc.id, @@ -145,11 +145,12 @@ class SubagentManager: } for tc in response.tool_calls ] - messages.append({ - "role": "assistant", - "content": response.content or "", - "tool_calls": tool_call_dicts, - }) + messages.append(build_assistant_message( + response.content or "", + tool_calls=tool_call_dicts, + reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + )) # Execute tools for tool_call in response.tool_calls: diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 3c301a9..cdcba57 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler): content = "" if chatbot_msg.text: content = chatbot_msg.text.content.strip() + elif chatbot_msg.extensions.get("content", {}).get("recognition"): + content = chatbot_msg.extensions["content"]["recognition"].strip() if not content: content = message.data.get("text", {}).get("content", "").strip() diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index a637025..0409c32 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -753,8 +753,9 @@ class FeishuChannel(BaseChannel): None, self._download_file_sync, message_id, file_key, msg_type ) if not filename: - ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "") - filename = f"{file_key[:16]}{ext}" + filename = file_key[:16] + if msg_type == "audio" and not filename.endswith(".opus"): + filename = f"{filename}.opus" if data and filename: file_path = media_dir / filename diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index a4e7324..0384d8d 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -81,8 +81,8 @@ class SlackChannel(BaseChannel): slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {} thread_ts = slack_meta.get("thread_ts") channel_type = slack_meta.get("channel_type") - # Only reply in thread for channel/group messages; DMs don't use threads - thread_ts_param = thread_ts if use_thread else None + # Slack DMs don't use threads; channel/group replies may keep thread_ts. + thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None # Slack rejects empty text payloads. Keep media-only messages media-only, # but send a single blank message when the bot has no text or files to send. @@ -278,4 +278,3 @@ class SlackChannel(BaseChannel): if parts: rows.append(" ยท ".join(parts)) return "\n".join(rows) - diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index ecb1440..5b294cc 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -179,6 +179,8 @@ class TelegramChannel(BaseChannel): self._media_group_buffers: dict[str, dict] = {} self._media_group_tasks: dict[str, asyncio.Task] = {} self._message_threads: dict[tuple[str, int], int] = {} + self._bot_user_id: int | None = None + self._bot_username: str | None = None def is_allowed(self, sender_id: str) -> bool: """Preserve Telegram's legacy id|username allowlist matching.""" @@ -242,6 +244,8 @@ class TelegramChannel(BaseChannel): # Get bot info and register command menu bot_info = await self._app.bot.get_me() + self._bot_user_id = getattr(bot_info, "id", None) + self._bot_username = getattr(bot_info, "username", None) logger.info("Telegram bot @{} connected", bot_info.username) try: @@ -462,6 +466,70 @@ class TelegramChannel(BaseChannel): "is_forum": bool(getattr(message.chat, "is_forum", False)), } + async def _ensure_bot_identity(self) -> tuple[int | None, str | None]: + """Load bot identity once and reuse it for mention/reply checks.""" + if self._bot_user_id is not None or self._bot_username is not None: + return self._bot_user_id, self._bot_username + if not self._app: + return None, None + bot_info = await self._app.bot.get_me() + self._bot_user_id = getattr(bot_info, "id", None) + self._bot_username = getattr(bot_info, "username", None) + return self._bot_user_id, self._bot_username + + @staticmethod + def _has_mention_entity( + text: str, + entities, + bot_username: str, + bot_id: int | None, + ) -> bool: + """Check Telegram mention entities against the bot username.""" + handle = f"@{bot_username}".lower() + for entity in entities or []: + entity_type = getattr(entity, "type", None) + if entity_type == "text_mention": + user = getattr(entity, "user", None) + if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id: + return True + continue + if entity_type != "mention": + continue + offset = getattr(entity, "offset", None) + length = getattr(entity, "length", None) + if offset is None or length is None: + continue + if text[offset : offset + length].lower() == handle: + return True + return handle in text.lower() + + async def _is_group_message_for_bot(self, message) -> bool: + """Allow group messages when policy is open, @mentioned, or replying to the bot.""" + if message.chat.type == "private" or self.config.group_policy == "open": + return True + + bot_id, bot_username = await self._ensure_bot_identity() + if bot_username: + text = message.text or "" + caption = message.caption or "" + if self._has_mention_entity( + text, + getattr(message, "entities", None), + bot_username, + bot_id, + ): + return True + if self._has_mention_entity( + caption, + getattr(message, "caption_entities", None), + bot_username, + bot_id, + ): + return True + + reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None) + return bool(bot_id and reply_user and reply_user.id == bot_id) + def _remember_thread_context(self, message) -> None: """Cache topic thread id by chat/message id for follow-up replies.""" message_thread_id = getattr(message, "message_thread_id", None) @@ -501,6 +569,9 @@ class TelegramChannel(BaseChannel): # Store chat_id for replies self._chat_ids[sender_id] = chat_id + if not await self._is_group_message_for_bot(message): + return + # Build content from text and/or media content_parts = [] media_paths = [] diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 2c8d6d3..cf69450 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -191,6 +191,8 @@ def onboard(): save_config(Config()) console.print(f"[green]โœ“[/green] Created config at {config_path}") + console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]") + # Create workspace workspace = get_workspace_path() @@ -283,6 +285,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None return loaded +def _print_deprecated_memory_window_notice(config: Config) -> None: + """Warn when running with old memoryWindow-only config.""" + if config.agents.defaults.should_warn_deprecated_memory_window: + console.print( + "[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without " + "`contextWindowTokens`. `memoryWindow` is ignored; run " + "[cyan]nanobot onboard[/cyan] to refresh your config template." + ) + + # ============================================================================ # Gateway / Server # ============================================================================ @@ -290,7 +302,7 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None @app.command() def gateway( - port: int = typer.Option(18790, "--port", "-p", help="Gateway port"), + port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"), workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"), config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), @@ -310,6 +322,8 @@ def gateway( logging.basicConfig(level=logging.DEBUG) config = _load_runtime_config(config, workspace) + _print_deprecated_memory_window_notice(config) + port = port if port is not None else config.gateway.port console.print(f"{__logo__} Starting nanobot gateway on port {port}...") sync_workspace_templates(config.workspace_path) @@ -330,8 +344,8 @@ def gateway( temperature=config.agents.defaults.temperature, max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, - memory_window=config.agents.defaults.memory_window, reasoning_effort=config.agents.defaults.reasoning_effort, + context_window_tokens=config.agents.defaults.context_window_tokens, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -493,6 +507,7 @@ def agent( from nanobot.cron.service import CronService config = _load_runtime_config(config, workspace) + _print_deprecated_memory_window_notice(config) sync_workspace_templates(config.workspace_path) bus = MessageBus() @@ -515,8 +530,8 @@ def agent( temperature=config.agents.defaults.temperature, max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, - memory_window=config.agents.defaults.memory_window, reasoning_effort=config.agents.defaults.reasoning_effort, + context_window_tokens=config.agents.defaults.context_window_tokens, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 63eae48..b772d18 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -33,6 +33,7 @@ class TelegramConfig(Base): None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080" ) reply_to_message: bool = False # If true, bot replies quote the original message + group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned or replied to, "open" responds to all class FeishuConfig(Base): @@ -236,11 +237,18 @@ class AgentDefaults(Base): "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection ) max_tokens: int = 8192 + context_window_tokens: int = 65_536 temperature: float = 0.1 max_tool_iterations: int = 40 - memory_window: int = 100 + # Deprecated compatibility field: accepted from old configs but ignored at runtime. + memory_window: int | None = Field(default=None, exclude=True) reasoning_effort: str | None = None # low / medium / high โ€” enables LLM thinking mode + @property + def should_warn_deprecated_memory_window(self) -> bool: + """Return True when old memoryWindow is present without contextWindowTokens.""" + return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set + class AgentsConfig(Base): """Agent configuration.""" diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py index e534017..831ae85 100644 --- a/nanobot/heartbeat/service.py +++ b/nanobot/heartbeat/service.py @@ -87,7 +87,7 @@ class HeartbeatService: Returns (action, tasks) where action is 'skip' or 'run'. """ - response = await self.provider.chat( + response = await self.provider.chat_with_retry( messages=[ {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "user", "content": ( diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 0f73544..a3b6c47 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -1,9 +1,12 @@ """Base LLM provider interface.""" +import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any +from loguru import logger + @dataclass class ToolCallRequest: @@ -37,6 +40,22 @@ class LLMProvider(ABC): while maintaining a consistent interface. """ + _CHAT_RETRY_DELAYS = (1, 2, 4) + _TRANSIENT_ERROR_MARKERS = ( + "429", + "rate limit", + "500", + "502", + "503", + "504", + "overloaded", + "timeout", + "timed out", + "connection", + "server error", + "temporarily unavailable", + ) + def __init__(self, api_key: str | None = None, api_base: str | None = None): self.api_key = api_key self.api_base = api_base @@ -126,6 +145,71 @@ class LLMProvider(ABC): """ pass + @classmethod + def _is_transient_error(cls, content: str | None) -> bool: + err = (content or "").lower() + return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) + + async def chat_with_retry( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + ) -> LLMResponse: + """Call chat() with retry on transient provider failures.""" + for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): + try: + response = await self.chat( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) + except asyncio.CancelledError: + raise + except Exception as exc: + response = LLMResponse( + content=f"Error calling LLM: {exc}", + finish_reason="error", + ) + + if response.finish_reason != "error": + return response + if not self._is_transient_error(response.content): + return response + + err = (response.content or "").lower() + logger.warning( + "LLM transient error (attempt {}/{}), retrying in {}s: {}", + attempt, + len(self._CHAT_RETRY_DELAYS), + delay, + err[:120], + ) + await asyncio.sleep(delay) + + try: + return await self.chat( + messages=messages, + tools=tools, + model=model, + max_tokens=max_tokens, + temperature=temperature, + reasoning_effort=reasoning_effort, + ) + except asyncio.CancelledError: + raise + except Exception as exc: + return LLMResponse( + content=f"Error calling LLM: {exc}", + finish_reason="error", + ) + @abstractmethod def get_default_model(self) -> str: """Get the default model for this provider.""" diff --git a/nanobot/skills/skill-creator/SKILL.md b/nanobot/skills/skill-creator/SKILL.md index 9b5eb6f..ea53abe 100644 --- a/nanobot/skills/skill-creator/SKILL.md +++ b/nanobot/skills/skill-creator/SKILL.md @@ -268,6 +268,8 @@ Skip this step only if the skill being developed already exists, and iteration o When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable. +For `nanobot`, custom skills should live under the active workspace `skills/` directory so they can be discovered automatically at runtime (for example, `/skills/my-skill/SKILL.md`). + Usage: ```bash @@ -277,9 +279,9 @@ scripts/init_skill.py --path [--resources script Examples: ```bash -scripts/init_skill.py my-skill --path skills/public -scripts/init_skill.py my-skill --path skills/public --resources scripts,references -scripts/init_skill.py my-skill --path skills/public --resources scripts --examples +scripts/init_skill.py my-skill --path ./workspace/skills +scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts,references +scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts --examples ``` The script: @@ -326,7 +328,7 @@ Write the YAML frontmatter with `name` and `description`: - Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent. - Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks" -Do not include any other fields in YAML frontmatter. +Keep frontmatter minimal. In `nanobot`, `metadata` and `always` are also supported when needed, but avoid adding extra fields unless they are actually required. ##### Body @@ -349,7 +351,6 @@ scripts/package_skill.py ./dist The packaging script will: 1. **Validate** the skill automatically, checking: - - YAML frontmatter format and required fields - Skill naming conventions and directory structure - Description completeness and quality @@ -357,6 +358,8 @@ The packaging script will: 2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension. + Security restriction: symlinks are rejected and packaging fails when any symlink is present. + If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again. ### Step 6: Iterate diff --git a/nanobot/skills/skill-creator/scripts/init_skill.py b/nanobot/skills/skill-creator/scripts/init_skill.py new file mode 100755 index 0000000..8633fe9 --- /dev/null +++ b/nanobot/skills/skill-creator/scripts/init_skill.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +""" +Skill Initializer - Creates a new skill from template + +Usage: + init_skill.py --path [--resources scripts,references,assets] [--examples] + +Examples: + init_skill.py my-new-skill --path skills/public + init_skill.py my-new-skill --path skills/public --resources scripts,references + init_skill.py my-api-helper --path skills/private --resources scripts --examples + init_skill.py custom-skill --path /custom/location +""" + +import argparse +import re +import sys +from pathlib import Path + +MAX_SKILL_NAME_LENGTH = 64 +ALLOWED_RESOURCES = {"scripts", "references", "assets"} + +SKILL_TEMPLATE = """--- +name: {skill_name} +description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.] +--- + +# {skill_title} + +## Overview + +[TODO: 1-2 sentences explaining what this skill enables] + +## Structuring This Skill + +[TODO: Choose the structure that best fits this skill's purpose. Common patterns: + +**1. Workflow-Based** (best for sequential processes) +- Works well when there are clear step-by-step procedures +- Example: DOCX skill with "Workflow Decision Tree" -> "Reading" -> "Creating" -> "Editing" +- Structure: ## Overview -> ## Workflow Decision Tree -> ## Step 1 -> ## Step 2... + +**2. Task-Based** (best for tool collections) +- Works well when the skill offers different operations/capabilities +- Example: PDF skill with "Quick Start" -> "Merge PDFs" -> "Split PDFs" -> "Extract Text" +- Structure: ## Overview -> ## Quick Start -> ## Task Category 1 -> ## Task Category 2... + +**3. Reference/Guidelines** (best for standards or specifications) +- Works well for brand guidelines, coding standards, or requirements +- Example: Brand styling with "Brand Guidelines" -> "Colors" -> "Typography" -> "Features" +- Structure: ## Overview -> ## Guidelines -> ## Specifications -> ## Usage... + +**4. Capabilities-Based** (best for integrated systems) +- Works well when the skill provides multiple interrelated features +- Example: Product Management with "Core Capabilities" -> numbered capability list +- Structure: ## Overview -> ## Core Capabilities -> ### 1. Feature -> ### 2. Feature... + +Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations). + +Delete this entire "Structuring This Skill" section when done - it's just guidance.] + +## [TODO: Replace with the first main section based on chosen structure] + +[TODO: Add content here. See examples in existing skills: +- Code samples for technical skills +- Decision trees for complex workflows +- Concrete examples with realistic user requests +- References to scripts/templates/references as needed] + +## Resources (optional) + +Create only the resource directories this skill actually needs. Delete this section if no resources are required. + +### scripts/ +Executable code (Python/Bash/etc.) that can be run directly to perform specific operations. + +**Examples from other skills:** +- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation +- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing + +**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations. + +**Note:** Scripts may be executed without loading into context, but can still be read by Codex for patching or environment adjustments. + +### references/ +Documentation and reference material intended to be loaded into context to inform Codex's process and thinking. + +**Examples from other skills:** +- Product management: `communication.md`, `context_building.md` - detailed workflow guides +- BigQuery: API reference documentation and query examples +- Finance: Schema documentation, company policies + +**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Codex should reference while working. + +### assets/ +Files not intended to be loaded into context, but rather used within the output Codex produces. + +**Examples from other skills:** +- Brand styling: PowerPoint template files (.pptx), logo files +- Frontend builder: HTML/React boilerplate project directories +- Typography: Font files (.ttf, .woff2) + +**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output. + +--- + +**Not every skill requires all three types of resources.** +""" + +EXAMPLE_SCRIPT = '''#!/usr/bin/env python3 +""" +Example helper script for {skill_name} + +This is a placeholder script that can be executed directly. +Replace with actual implementation or delete if not needed. + +Example real scripts from other skills: +- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields +- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images +""" + +def main(): + print("This is an example script for {skill_name}") + # TODO: Add actual script logic here + # This could be data processing, file conversion, API calls, etc. + +if __name__ == "__main__": + main() +''' + +EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title} + +This is a placeholder for detailed reference documentation. +Replace with actual reference content or delete if not needed. + +Example real reference docs from other skills: +- product-management/references/communication.md - Comprehensive guide for status updates +- product-management/references/context_building.md - Deep-dive on gathering context +- bigquery/references/ - API references and query examples + +## When Reference Docs Are Useful + +Reference docs are ideal for: +- Comprehensive API documentation +- Detailed workflow guides +- Complex multi-step processes +- Information too lengthy for main SKILL.md +- Content that's only needed for specific use cases + +## Structure Suggestions + +### API Reference Example +- Overview +- Authentication +- Endpoints with examples +- Error codes +- Rate limits + +### Workflow Guide Example +- Prerequisites +- Step-by-step instructions +- Common patterns +- Troubleshooting +- Best practices +""" + +EXAMPLE_ASSET = """# Example Asset File + +This placeholder represents where asset files would be stored. +Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed. + +Asset files are NOT intended to be loaded into context, but rather used within +the output Codex produces. + +Example asset files from other skills: +- Brand guidelines: logo.png, slides_template.pptx +- Frontend builder: hello-world/ directory with HTML/React boilerplate +- Typography: custom-font.ttf, font-family.woff2 +- Data: sample_data.csv, test_dataset.json + +## Common Asset Types + +- Templates: .pptx, .docx, boilerplate directories +- Images: .png, .jpg, .svg, .gif +- Fonts: .ttf, .otf, .woff, .woff2 +- Boilerplate code: Project directories, starter files +- Icons: .ico, .svg +- Data files: .csv, .json, .xml, .yaml + +Note: This is a text placeholder. Actual assets can be any file type. +""" + + +def normalize_skill_name(skill_name): + """Normalize a skill name to lowercase hyphen-case.""" + normalized = skill_name.strip().lower() + normalized = re.sub(r"[^a-z0-9]+", "-", normalized) + normalized = normalized.strip("-") + normalized = re.sub(r"-{2,}", "-", normalized) + return normalized + + +def title_case_skill_name(skill_name): + """Convert hyphenated skill name to Title Case for display.""" + return " ".join(word.capitalize() for word in skill_name.split("-")) + + +def parse_resources(raw_resources): + if not raw_resources: + return [] + resources = [item.strip() for item in raw_resources.split(",") if item.strip()] + invalid = sorted({item for item in resources if item not in ALLOWED_RESOURCES}) + if invalid: + allowed = ", ".join(sorted(ALLOWED_RESOURCES)) + print(f"[ERROR] Unknown resource type(s): {', '.join(invalid)}") + print(f" Allowed: {allowed}") + sys.exit(1) + deduped = [] + seen = set() + for resource in resources: + if resource not in seen: + deduped.append(resource) + seen.add(resource) + return deduped + + +def create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples): + for resource in resources: + resource_dir = skill_dir / resource + resource_dir.mkdir(exist_ok=True) + if resource == "scripts": + if include_examples: + example_script = resource_dir / "example.py" + example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name)) + example_script.chmod(0o755) + print("[OK] Created scripts/example.py") + else: + print("[OK] Created scripts/") + elif resource == "references": + if include_examples: + example_reference = resource_dir / "api_reference.md" + example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title)) + print("[OK] Created references/api_reference.md") + else: + print("[OK] Created references/") + elif resource == "assets": + if include_examples: + example_asset = resource_dir / "example_asset.txt" + example_asset.write_text(EXAMPLE_ASSET) + print("[OK] Created assets/example_asset.txt") + else: + print("[OK] Created assets/") + + +def init_skill(skill_name, path, resources, include_examples): + """ + Initialize a new skill directory with template SKILL.md. + + Args: + skill_name: Name of the skill + path: Path where the skill directory should be created + resources: Resource directories to create + include_examples: Whether to create example files in resource directories + + Returns: + Path to created skill directory, or None if error + """ + # Determine skill directory path + skill_dir = Path(path).resolve() / skill_name + + # Check if directory already exists + if skill_dir.exists(): + print(f"[ERROR] Skill directory already exists: {skill_dir}") + return None + + # Create skill directory + try: + skill_dir.mkdir(parents=True, exist_ok=False) + print(f"[OK] Created skill directory: {skill_dir}") + except Exception as e: + print(f"[ERROR] Error creating directory: {e}") + return None + + # Create SKILL.md from template + skill_title = title_case_skill_name(skill_name) + skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title) + + skill_md_path = skill_dir / "SKILL.md" + try: + skill_md_path.write_text(skill_content) + print("[OK] Created SKILL.md") + except Exception as e: + print(f"[ERROR] Error creating SKILL.md: {e}") + return None + + # Create resource directories if requested + if resources: + try: + create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples) + except Exception as e: + print(f"[ERROR] Error creating resource directories: {e}") + return None + + # Print next steps + print(f"\n[OK] Skill '{skill_name}' initialized successfully at {skill_dir}") + print("\nNext steps:") + print("1. Edit SKILL.md to complete the TODO items and update the description") + if resources: + if include_examples: + print("2. Customize or delete the example files in scripts/, references/, and assets/") + else: + print("2. Add resources to scripts/, references/, and assets/ as needed") + else: + print("2. Create resource directories only if needed (scripts/, references/, assets/)") + print("3. Run the validator when ready to check the skill structure") + + return skill_dir + + +def main(): + parser = argparse.ArgumentParser( + description="Create a new skill directory with a SKILL.md template.", + ) + parser.add_argument("skill_name", help="Skill name (normalized to hyphen-case)") + parser.add_argument("--path", required=True, help="Output directory for the skill") + parser.add_argument( + "--resources", + default="", + help="Comma-separated list: scripts,references,assets", + ) + parser.add_argument( + "--examples", + action="store_true", + help="Create example files inside the selected resource directories", + ) + args = parser.parse_args() + + raw_skill_name = args.skill_name + skill_name = normalize_skill_name(raw_skill_name) + if not skill_name: + print("[ERROR] Skill name must include at least one letter or digit.") + sys.exit(1) + if len(skill_name) > MAX_SKILL_NAME_LENGTH: + print( + f"[ERROR] Skill name '{skill_name}' is too long ({len(skill_name)} characters). " + f"Maximum is {MAX_SKILL_NAME_LENGTH} characters." + ) + sys.exit(1) + if skill_name != raw_skill_name: + print(f"Note: Normalized skill name from '{raw_skill_name}' to '{skill_name}'.") + + resources = parse_resources(args.resources) + if args.examples and not resources: + print("[ERROR] --examples requires --resources to be set.") + sys.exit(1) + + path = args.path + + print(f"Initializing skill: {skill_name}") + print(f" Location: {path}") + if resources: + print(f" Resources: {', '.join(resources)}") + if args.examples: + print(" Examples: enabled") + else: + print(" Resources: none (create as needed)") + print() + + result = init_skill(skill_name, path, resources, args.examples) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/nanobot/skills/skill-creator/scripts/package_skill.py b/nanobot/skills/skill-creator/scripts/package_skill.py new file mode 100755 index 0000000..48fcbbe --- /dev/null +++ b/nanobot/skills/skill-creator/scripts/package_skill.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +""" +Skill Packager - Creates a distributable .skill file of a skill folder + +Usage: + python package_skill.py [output-directory] + +Example: + python package_skill.py skills/public/my-skill + python package_skill.py skills/public/my-skill ./dist +""" + +import sys +import zipfile +from pathlib import Path + +from quick_validate import validate_skill + + +def _is_within(path: Path, root: Path) -> bool: + try: + path.relative_to(root) + return True + except ValueError: + return False + + +def _cleanup_partial_archive(skill_filename: Path) -> None: + try: + if skill_filename.exists(): + skill_filename.unlink() + except OSError: + pass + + +def package_skill(skill_path, output_dir=None): + """ + Package a skill folder into a .skill file. + + Args: + skill_path: Path to the skill folder + output_dir: Optional output directory for the .skill file (defaults to current directory) + + Returns: + Path to the created .skill file, or None if error + """ + skill_path = Path(skill_path).resolve() + + # Validate skill folder exists + if not skill_path.exists(): + print(f"[ERROR] Skill folder not found: {skill_path}") + return None + + if not skill_path.is_dir(): + print(f"[ERROR] Path is not a directory: {skill_path}") + return None + + # Validate SKILL.md exists + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + print(f"[ERROR] SKILL.md not found in {skill_path}") + return None + + # Run validation before packaging + print("Validating skill...") + valid, message = validate_skill(skill_path) + if not valid: + print(f"[ERROR] Validation failed: {message}") + print(" Please fix the validation errors before packaging.") + return None + print(f"[OK] {message}\n") + + # Determine output location + skill_name = skill_path.name + if output_dir: + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + else: + output_path = Path.cwd() + + skill_filename = output_path / f"{skill_name}.skill" + + EXCLUDED_DIRS = {".git", ".svn", ".hg", "__pycache__", "node_modules"} + + files_to_package = [] + resolved_archive = skill_filename.resolve() + + for file_path in skill_path.rglob("*"): + # Fail closed on symlinks so the packaged contents are explicit and predictable. + if file_path.is_symlink(): + print(f"[ERROR] Symlink not allowed in packaged skill: {file_path}") + _cleanup_partial_archive(skill_filename) + return None + + rel_parts = file_path.relative_to(skill_path).parts + if any(part in EXCLUDED_DIRS for part in rel_parts): + continue + + if file_path.is_file(): + resolved_file = file_path.resolve() + if not _is_within(resolved_file, skill_path): + print(f"[ERROR] File escapes skill root: {file_path}") + _cleanup_partial_archive(skill_filename) + return None + # If output lives under skill_path, avoid writing archive into itself. + if resolved_file == resolved_archive: + print(f"[WARN] Skipping output archive: {file_path}") + continue + files_to_package.append(file_path) + + # Create the .skill file (zip format) + try: + with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf: + for file_path in files_to_package: + # Calculate the relative path within the zip. + arcname = Path(skill_name) / file_path.relative_to(skill_path) + zipf.write(file_path, arcname) + print(f" Added: {arcname}") + + print(f"\n[OK] Successfully packaged skill to: {skill_filename}") + return skill_filename + + except Exception as e: + _cleanup_partial_archive(skill_filename) + print(f"[ERROR] Error creating .skill file: {e}") + return None + + +def main(): + if len(sys.argv) < 2: + print("Usage: python package_skill.py [output-directory]") + print("\nExample:") + print(" python package_skill.py skills/public/my-skill") + print(" python package_skill.py skills/public/my-skill ./dist") + sys.exit(1) + + skill_path = sys.argv[1] + output_dir = sys.argv[2] if len(sys.argv) > 2 else None + + print(f"Packaging skill: {skill_path}") + if output_dir: + print(f" Output directory: {output_dir}") + print() + + result = package_skill(skill_path, output_dir) + + if result: + sys.exit(0) + else: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/nanobot/skills/skill-creator/scripts/quick_validate.py b/nanobot/skills/skill-creator/scripts/quick_validate.py new file mode 100644 index 0000000..03d246d --- /dev/null +++ b/nanobot/skills/skill-creator/scripts/quick_validate.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Minimal validator for nanobot skill folders. +""" + +import re +import sys +from pathlib import Path +from typing import Optional + +try: + import yaml +except ModuleNotFoundError: + yaml = None + +MAX_SKILL_NAME_LENGTH = 64 +ALLOWED_FRONTMATTER_KEYS = { + "name", + "description", + "metadata", + "always", + "license", + "allowed-tools", +} +ALLOWED_RESOURCE_DIRS = {"scripts", "references", "assets"} +PLACEHOLDER_MARKERS = ("[todo", "todo:") + + +def _extract_frontmatter(content: str) -> Optional[str]: + lines = content.splitlines() + if not lines or lines[0].strip() != "---": + return None + for i in range(1, len(lines)): + if lines[i].strip() == "---": + return "\n".join(lines[1:i]) + return None + + +def _parse_simple_frontmatter(frontmatter_text: str) -> Optional[dict[str, str]]: + """Fallback parser for simple frontmatter when PyYAML is unavailable.""" + parsed: dict[str, str] = {} + current_key: Optional[str] = None + multiline_key: Optional[str] = None + + for raw_line in frontmatter_text.splitlines(): + stripped = raw_line.strip() + if not stripped or stripped.startswith("#"): + continue + + is_indented = raw_line[:1].isspace() + if is_indented: + if current_key is None: + return None + current_value = parsed[current_key] + parsed[current_key] = f"{current_value}\n{stripped}" if current_value else stripped + continue + + if ":" not in stripped: + return None + + key, value = stripped.split(":", 1) + key = key.strip() + value = value.strip() + if not key: + return None + + if value in {"|", ">"}: + parsed[key] = "" + current_key = key + multiline_key = key + continue + + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + value = value[1:-1] + parsed[key] = value + current_key = key + multiline_key = None + + if multiline_key is not None and multiline_key not in parsed: + return None + return parsed + + +def _load_frontmatter(frontmatter_text: str) -> tuple[Optional[dict], Optional[str]]: + if yaml is not None: + try: + frontmatter = yaml.safe_load(frontmatter_text) + except yaml.YAMLError as exc: + return None, f"Invalid YAML in frontmatter: {exc}" + if not isinstance(frontmatter, dict): + return None, "Frontmatter must be a YAML dictionary" + return frontmatter, None + + frontmatter = _parse_simple_frontmatter(frontmatter_text) + if frontmatter is None: + return None, "Invalid YAML in frontmatter: unsupported syntax without PyYAML installed" + return frontmatter, None + + +def _validate_skill_name(name: str, folder_name: str) -> Optional[str]: + if not re.fullmatch(r"[a-z0-9]+(?:-[a-z0-9]+)*", name): + return ( + f"Name '{name}' should be hyphen-case " + "(lowercase letters, digits, and single hyphens only)" + ) + if len(name) > MAX_SKILL_NAME_LENGTH: + return ( + f"Name is too long ({len(name)} characters). " + f"Maximum is {MAX_SKILL_NAME_LENGTH} characters." + ) + if name != folder_name: + return f"Skill name '{name}' must match directory name '{folder_name}'" + return None + + +def _validate_description(description: str) -> Optional[str]: + trimmed = description.strip() + if not trimmed: + return "Description cannot be empty" + lowered = trimmed.lower() + if any(marker in lowered for marker in PLACEHOLDER_MARKERS): + return "Description still contains TODO placeholder text" + if "<" in trimmed or ">" in trimmed: + return "Description cannot contain angle brackets (< or >)" + if len(trimmed) > 1024: + return f"Description is too long ({len(trimmed)} characters). Maximum is 1024 characters." + return None + + +def validate_skill(skill_path): + """Validate a skill folder structure and required frontmatter.""" + skill_path = Path(skill_path).resolve() + + if not skill_path.exists(): + return False, f"Skill folder not found: {skill_path}" + if not skill_path.is_dir(): + return False, f"Path is not a directory: {skill_path}" + + skill_md = skill_path / "SKILL.md" + if not skill_md.exists(): + return False, "SKILL.md not found" + + try: + content = skill_md.read_text(encoding="utf-8") + except OSError as exc: + return False, f"Could not read SKILL.md: {exc}" + + frontmatter_text = _extract_frontmatter(content) + if frontmatter_text is None: + return False, "Invalid frontmatter format" + + frontmatter, error = _load_frontmatter(frontmatter_text) + if error: + return False, error + + unexpected_keys = sorted(set(frontmatter.keys()) - ALLOWED_FRONTMATTER_KEYS) + if unexpected_keys: + allowed = ", ".join(sorted(ALLOWED_FRONTMATTER_KEYS)) + unexpected = ", ".join(unexpected_keys) + return ( + False, + f"Unexpected key(s) in SKILL.md frontmatter: {unexpected}. Allowed properties are: {allowed}", + ) + + if "name" not in frontmatter: + return False, "Missing 'name' in frontmatter" + if "description" not in frontmatter: + return False, "Missing 'description' in frontmatter" + + name = frontmatter["name"] + if not isinstance(name, str): + return False, f"Name must be a string, got {type(name).__name__}" + name_error = _validate_skill_name(name.strip(), skill_path.name) + if name_error: + return False, name_error + + description = frontmatter["description"] + if not isinstance(description, str): + return False, f"Description must be a string, got {type(description).__name__}" + description_error = _validate_description(description) + if description_error: + return False, description_error + + always = frontmatter.get("always") + if always is not None and not isinstance(always, bool): + return False, f"'always' must be a boolean, got {type(always).__name__}" + + for child in skill_path.iterdir(): + if child.name == "SKILL.md": + continue + if child.is_dir() and child.name in ALLOWED_RESOURCE_DIRS: + continue + if child.is_symlink(): + continue + return ( + False, + f"Unexpected file or directory in skill root: {child.name}. " + "Only SKILL.md, scripts/, references/, and assets/ are allowed.", + ) + + return True, "Skill is valid!" + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python quick_validate.py ") + sys.exit(1) + + valid, message = validate_skill(sys.argv[1]) + print(message) + sys.exit(0 if valid else 1) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 57c60dc..5ca06f4 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -1,8 +1,12 @@ """Utility functions for nanobot.""" +import json import re from datetime import datetime from pathlib import Path +from typing import Any + +import tiktoken def detect_image_mime(data: bytes) -> str | None: @@ -68,6 +72,104 @@ def split_message(content: str, max_len: int = 2000) -> list[str]: return chunks +def build_assistant_message( + content: str | None, + tool_calls: list[dict[str, Any]] | None = None, + reasoning_content: str | None = None, + thinking_blocks: list[dict] | None = None, +) -> dict[str, Any]: + """Build a provider-safe assistant message with optional reasoning fields.""" + msg: dict[str, Any] = {"role": "assistant", "content": content} + if tool_calls: + msg["tool_calls"] = tool_calls + if reasoning_content is not None: + msg["reasoning_content"] = reasoning_content + if thinking_blocks: + msg["thinking_blocks"] = thinking_blocks + return msg + + +def estimate_prompt_tokens( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, +) -> int: + """Estimate prompt tokens with tiktoken.""" + try: + enc = tiktoken.get_encoding("cl100k_base") + parts: list[str] = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + if tools: + parts.append(json.dumps(tools, ensure_ascii=False)) + return len(enc.encode("\n".join(parts))) + except Exception: + return 0 + + +def estimate_message_tokens(message: dict[str, Any]) -> int: + """Estimate prompt tokens contributed by one persisted message.""" + content = message.get("content") + parts: list[str] = [] + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text", "") + if text: + parts.append(text) + else: + parts.append(json.dumps(part, ensure_ascii=False)) + elif content is not None: + parts.append(json.dumps(content, ensure_ascii=False)) + + for key in ("name", "tool_call_id"): + value = message.get(key) + if isinstance(value, str) and value: + parts.append(value) + if message.get("tool_calls"): + parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) + + payload = "\n".join(parts) + if not payload: + return 1 + try: + enc = tiktoken.get_encoding("cl100k_base") + return max(1, len(enc.encode(payload))) + except Exception: + return max(1, len(payload) // 4) + + +def estimate_prompt_tokens_chain( + provider: Any, + model: str | None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, +) -> tuple[int, str]: + """Estimate prompt tokens via provider counter first, then tiktoken fallback.""" + provider_counter = getattr(provider, "estimate_prompt_tokens", None) + if callable(provider_counter): + try: + tokens, source = provider_counter(messages, tools, model) + if isinstance(tokens, (int, float)) and tokens > 0: + return int(tokens), str(source or "provider_counter") + except Exception: + pass + + estimated = estimate_prompt_tokens(messages, tools) + if estimated > 0: + return int(estimated), "tiktoken" + return 0, "none" + + def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: """Sync bundled templates to workspace. Only creates missing files.""" from importlib.resources import files as pkg_files @@ -88,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str] added.append(str(dest.relative_to(workspace))) for item in tpl.iterdir(): - if item.name.endswith(".md"): + if item.name.endswith(".md") and not item.name.startswith("."): _write(item, workspace / item.name) _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md") _write(None, workspace / "memory" / "HISTORY.md") diff --git a/pyproject.toml b/pyproject.toml index fac53ce..0582be6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ dependencies = [ "typer>=0.20.0,<1.0.0", - "litellm>=1.81.5,<2.0.0", + "litellm>=1.82.1,<2.0.0", "pydantic>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0", "websockets>=16.0,<17.0", @@ -45,6 +45,7 @@ dependencies = [ "chardet>=3.0.2,<6.0.0", "openai>=2.8.0", "wecom-aibot-sdk-python>=0.1.2", + "tiktoken>=0.12.0,<1.0.0", ] [project.optional-dependencies] diff --git a/tests/test_commands.py b/tests/test_commands.py index 19c1998..1375a3a 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -267,6 +267,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path +def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime): + mock_agent_runtime["config"].agents.defaults.memory_window = 100 + + result = runner.invoke(app, ["agent", "-m", "hello"]) + + assert result.exit_code == 0 + assert "memoryWindow" in result.stdout + assert "contextWindowTokens" in result.stdout + + def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) @@ -328,6 +338,28 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) assert config.workspace_path == override +def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.memory_window = 100 + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGateway) + assert "memoryWindow" in result.stdout + assert "contextWindowTokens" in result.stdout + def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) @@ -356,3 +388,47 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat assert isinstance(result.exception, _StopGateway) assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" + + +def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.gateway.port = 18791 + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGateway) + assert "port 18791" in result.stdout + + +def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.gateway.port = 18791 + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) + + assert isinstance(result.exception, _StopGateway) + assert "port 18792" in result.stdout diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py new file mode 100644 index 0000000..62e601e --- /dev/null +++ b/tests/test_config_migration.py @@ -0,0 +1,88 @@ +import json + +from typer.testing import CliRunner + +from nanobot.cli.commands import app +from nanobot.config.loader import load_config, save_config + +runner = CliRunner() + + +def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 1234, + "memoryWindow": 42, + } + } + } + ), + encoding="utf-8", + ) + + config = load_config(config_path) + + assert config.agents.defaults.max_tokens == 1234 + assert config.agents.defaults.context_window_tokens == 65_536 + assert config.agents.defaults.should_warn_deprecated_memory_window is True + + +def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 2222, + "memoryWindow": 30, + } + } + } + ), + encoding="utf-8", + ) + + config = load_config(config_path) + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + defaults = saved["agents"]["defaults"] + + assert defaults["maxTokens"] == 2222 + assert defaults["contextWindowTokens"] == 65_536 + assert "memoryWindow" not in defaults + + +def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + workspace = tmp_path / "workspace" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 3333, + "memoryWindow": 50, + } + } + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) + monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace) + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + assert "contextWindowTokens" in result.stdout + saved = json.loads(config_path.read_text(encoding="utf-8")) + defaults = saved["agents"]["defaults"] + assert defaults["maxTokens"] == 3333 + assert defaults["contextWindowTokens"] == 65_536 + assert "memoryWindow" not in defaults diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index a3213dd..7d12338 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions: assert_messages_content(old_messages, 10, 34) -class TestConsolidationDeduplicationGuard: - """Test that consolidation tasks are deduplicated and serialized.""" +class TestNewCommandArchival: + """Test /new archival behavior with the simplified consolidation flow.""" - @pytest.mark.asyncio - async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None: - """Concurrent messages above memory_window spawn only one consolidation task.""" + @staticmethod + def _make_loop(tmp_path: Path): from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (10_000, "test") loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=1, ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - consolidation_calls = 0 - - async def _fake_consolidate(_session, archive_all: bool = False) -> None: - nonlocal consolidation_calls - consolidation_calls += 1 - await asyncio.sleep(0.05) - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await loop._process_message(msg) - await asyncio.sleep(0.1) - - assert consolidation_calls == 1, ( - f"Expected exactly 1 consolidation, got {consolidation_calls}" - ) - - @pytest.mark.asyncio - async def test_new_command_guard_prevents_concurrent_consolidation( - self, tmp_path: Path - ) -> None: - """/new command does not run consolidation concurrently with in-flight consolidation.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - consolidation_calls = 0 - active = 0 - max_active = 0 - - async def _fake_consolidate(_session, archive_all: bool = False) -> None: - nonlocal consolidation_calls, active, max_active - consolidation_calls += 1 - active += 1 - max_active = max(max_active, active) - await asyncio.sleep(0.05) - active -= 1 - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - - new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - await loop._process_message(new_msg) - await asyncio.sleep(0.1) - - assert consolidation_calls == 2, ( - f"Expected normal + /new consolidations, got {consolidation_calls}" - ) - assert max_active == 1, ( - f"Expected serialized consolidation, observed concurrency={max_active}" - ) - - @pytest.mark.asyncio - async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None: - """create_task results are tracked in _consolidation_tasks while in flight.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - started = asyncio.Event() - - async def _slow_consolidate(_session, archive_all: bool = False) -> None: - started.set() - await asyncio.sleep(0.1) - - loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - - await started.wait() - assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight" - - await asyncio.sleep(0.15) - assert len(loop._consolidation_tasks) == 0, ( - "Task reference must be removed after completion" - ) - - @pytest.mark.asyncio - async def test_new_waits_for_inflight_consolidation_and_preserves_messages( - self, tmp_path: Path - ) -> None: - """/new waits for in-flight consolidation and archives before clear.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - started = asyncio.Event() - release = asyncio.Event() - archived_count = 0 - - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: - nonlocal archived_count - if archive_all: - archived_count = len(sess.messages) - return True - started.set() - await release.wait() - return True - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await started.wait() - - new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - pending_new = asyncio.create_task(loop._process_message(new_msg)) - - await asyncio.sleep(0.02) - assert not pending_new.done(), "/new should wait while consolidation is in-flight" - - release.set() - response = await pending_new - assert response is not None - assert "new session started" in response.content.lower() - assert archived_count > 0, "Expected /new archival to process a non-empty snapshot" - - session_after = loop.sessions.get_or_create("cli:test") - assert session_after.messages == [], "Session should be cleared after successful archival" + return loop @pytest.mark.asyncio async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: - """/new must keep session data if archive step reports failure.""" - from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(5): session.add_message("user", f"msg{i}") @@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard: loop.sessions.save(session) before_count = len(session.messages) - async def _failing_consolidate(sess, archive_all: bool = False) -> bool: - if archive_all: - return False - return True + async def _failing_consolidate(_messages) -> bool: + return False - loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign] + loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) assert response is not None assert "failed" in response.content.lower() - session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == before_count, ( - "Session must remain intact when /new archival fails" - ) + assert len(loop.sessions.get_or_create("cli:test").messages) == before_count @pytest.mark.asyncio - async def test_new_archives_only_unconsolidated_messages_after_inflight_task( - self, tmp_path: Path - ) -> None: - """/new should archive only messages not yet consolidated by prior task.""" - from nanobot.agent.loop import AgentLoop + async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None: from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(15): session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") + session.last_consolidated = len(session.messages) - 3 loop.sessions.save(session) - started = asyncio.Event() - release = asyncio.Event() archived_count = -1 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(messages) -> bool: nonlocal archived_count - if archive_all: - archived_count = len(sess.messages) - return True - - started.set() - await release.wait() - sess.last_consolidated = len(sess.messages) - 3 + archived_count = len(messages) return True - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await started.wait() + loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - pending_new = asyncio.create_task(loop._process_message(new_msg)) - await asyncio.sleep(0.02) - assert not pending_new.done() - - release.set() - response = await pending_new + response = await loop._process_message(new_msg) assert response is not None assert "new session started" in response.content.lower() - assert archived_count == 3, ( - f"Expected only unconsolidated tail to archive, got {archived_count}" - ) + assert archived_count == 3 @pytest.mark.asyncio async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None: - """/new clears session and returns confirmation.""" - from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(3): session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - async def _ok_consolidate(sess, archive_all: bool = False) -> bool: + async def _ok_consolidate(_messages) -> bool: return True - loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] + loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) diff --git a/tests/test_dingtalk_channel.py b/tests/test_dingtalk_channel.py index 7595a33..6051014 100644 --- a/tests/test_dingtalk_channel.py +++ b/tests/test_dingtalk_channel.py @@ -1,9 +1,11 @@ +import asyncio from types import SimpleNamespace import pytest from nanobot.bus.queue import MessageBus -from nanobot.channels.dingtalk import DingTalkChannel +import nanobot.channels.dingtalk as dingtalk_module +from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler from nanobot.config.schema import DingTalkConfig @@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None: assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send" assert call["json"]["openConversationId"] == "conv123" assert call["json"]["msgKey"] == "sampleMarkdown" + + +@pytest.mark.asyncio +async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None: + bus = MessageBus() + channel = DingTalkChannel( + DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]), + bus, + ) + handler = NanobotDingTalkHandler(channel) + + class _FakeChatbotMessage: + text = None + extensions = {"content": {"recognition": "voice transcript"}} + sender_staff_id = "user1" + sender_id = "fallback-user" + sender_nick = "Alice" + message_type = "audio" + + @staticmethod + def from_dict(_data): + return _FakeChatbotMessage() + + monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage) + monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK")) + + status, body = await handler.process( + SimpleNamespace( + data={ + "conversationType": "2", + "conversationId": "conv123", + "text": {"content": ""}, + } + ) + ) + + await asyncio.gather(*list(channel._background_tasks)) + msg = await bus.consume_inbound() + + assert (status, body) == ("OK", "OK") + assert msg.content == "voice transcript" + assert msg.sender_id == "user1" + assert msg.chat_id == "group:conv123" diff --git a/tests/test_heartbeat_service.py b/tests/test_heartbeat_service.py index c5478af..9ce8912 100644 --- a/tests/test_heartbeat_service.py +++ b/tests/test_heartbeat_service.py @@ -3,18 +3,24 @@ import asyncio import pytest from nanobot.heartbeat.service import HeartbeatService -from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -class DummyProvider: +class DummyProvider(LLMProvider): def __init__(self, responses: list[LLMResponse]): + super().__init__() self._responses = list(responses) + self.calls = 0 async def chat(self, *args, **kwargs) -> LLMResponse: + self.calls += 1 if self._responses: return self._responses.pop(0) return LLMResponse(content="", tool_calls=[]) + def get_default_model(self) -> str: + return "test-model" + @pytest.mark.asyncio async def test_start_is_idempotent(tmp_path) -> None: @@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None: ) assert await service.trigger_now() is None + + +@pytest.mark.asyncio +async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None: + provider = DummyProvider([ + LLMResponse(content="429 rate limit", finish_reason="error"), + LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="hb_1", + name="heartbeat", + arguments={"action": "run", "tasks": "check open tasks"}, + ) + ], + ), + ]) + + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + monkeypatch.setattr(asyncio, "sleep", _fake_sleep) + + service = HeartbeatService( + workspace=tmp_path, + provider=provider, + model="openai/gpt-4o-mini", + ) + + action, tasks = await service._decide("heartbeat content") + + assert action == "run" + assert tasks == "check open tasks" + assert provider.calls == 2 + assert delays == [1] diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py new file mode 100644 index 0000000..b0f3dda --- /dev/null +++ b/tests/test_loop_consolidation_tokens.py @@ -0,0 +1,190 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +import nanobot.agent.memory as memory_module +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse + + +def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop: + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + + loop = AgentLoop( + bus=MessageBus(), + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=context_window_tokens, + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + return loop + + +@pytest.mark.asyncio +async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None: + loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + await loop.process_direct("hello", session_key="cli:test") + + loop.memory_consolidator.consolidate_messages.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None: + loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500) + + await loop.process_direct("hello", session_key="cli:test") + + assert loop.memory_consolidator.consolidate_messages.await_count >= 1 + + +@pytest.mark.asyncio +async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None: + loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + ] + loop.sessions.save(session) + + token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120} + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]]) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0] + assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"] + assert session.last_consolidated == 4 + + +@pytest.mark.asyncio +async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None: + """Verify maybe_consolidate_by_tokens keeps looping until under threshold.""" + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"}, + {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"}, + ] + loop.sessions.save(session) + + call_count = [0] + def mock_estimate(_session): + call_count[0] += 1 + if call_count[0] == 1: + return (500, "test") + if call_count[0] == 2: + return (300, "test") + return (80, "test") + + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert session.last_consolidated == 6 + + +@pytest.mark.asyncio +async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None: + """Once triggered, consolidation should continue until it drops below half threshold.""" + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"}, + {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"}, + ] + loop.sessions.save(session) + + call_count = [0] + + def mock_estimate(_session): + call_count[0] += 1 + if call_count[0] == 1: + return (500, "test") + if call_count[0] == 2: + return (150, "test") + return (80, "test") + + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert session.last_consolidated == 6 + + +@pytest.mark.asyncio +async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None: + """Verify preflight consolidation runs before the LLM call in process_direct.""" + order: list[str] = [] + + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + + async def track_consolidate(messages): + order.append("consolidate") + return True + loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign] + + async def track_llm(*args, **kwargs): + order.append("llm") + return LLMResponse(content="ok", tool_calls=[]) + loop.provider.chat_with_retry = track_llm + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500) + + call_count = [0] + def mock_estimate(_session): + call_count[0] += 1 + return (1000 if call_count[0] <= 1 else 80, "test") + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + + await loop.process_direct("hello", session_key="cli:test") + + assert "consolidate" in order + assert "llm" in order + assert order.index("consolidate") < order.index("llm") diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py index ff15584..0263f01 100644 --- a/tests/test_memory_consolidation_types.py +++ b/tests/test_memory_consolidation_types.py @@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro import json from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock import pytest from nanobot.agent.memory import MemoryStore -from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -def _make_session(message_count: int = 30, memory_window: int = 50): - """Create a mock session with messages.""" - session = MagicMock() - session.messages = [ +def _make_messages(message_count: int = 30): + """Create a list of mock messages.""" + return [ {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"} for i in range(message_count) ] - session.last_consolidated = 0 - return session def _make_tool_response(history_entry, memory_update): @@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update): ) +class ScriptedProvider(LLMProvider): + def __init__(self, responses: list[LLMResponse]): + super().__init__() + self._responses = list(responses) + self.calls = 0 + + async def chat(self, *args, **kwargs) -> LLMResponse: + self.calls += 1 + if self._responses: + return self._responses.pop(0) + return LLMResponse(content="", tool_calls=[]) + + def get_default_model(self) -> str: + return "test-model" + + class TestMemoryConsolidationTypeHandling: """Test that consolidation handles various argument types correctly.""" @@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling: memory_update="# Memory\nUser likes testing.", ) ) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert store.history_file.exists() @@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling: memory_update={"facts": ["User likes testing"], "topics": ["testing"]}, ) ) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert store.history_file.exists() @@ -112,9 +127,10 @@ class TestMemoryConsolidationTypeHandling: ], ) provider.chat = AsyncMock(return_value=response) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert "User discussed testing." in store.history_file.read_text() @@ -127,21 +143,23 @@ class TestMemoryConsolidationTypeHandling: provider.chat = AsyncMock( return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[]) ) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False assert not store.history_file.exists() @pytest.mark.asyncio - async def test_skips_when_few_messages(self, tmp_path: Path) -> None: - """Consolidation should be a no-op when messages < keep_count.""" + async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None: + """Consolidation should be a no-op when the selected chunk is empty.""" store = MemoryStore(tmp_path) provider = AsyncMock() - session = _make_session(message_count=10) + provider.chat_with_retry = provider.chat + messages: list[dict] = [] - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True provider.chat.assert_not_called() @@ -167,9 +185,10 @@ class TestMemoryConsolidationTypeHandling: ], ) provider.chat = AsyncMock(return_value=response) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert "User discussed testing." in store.history_file.read_text() @@ -192,9 +211,10 @@ class TestMemoryConsolidationTypeHandling: ], ) provider.chat = AsyncMock(return_value=response) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False @@ -215,8 +235,33 @@ class TestMemoryConsolidationTypeHandling: ], ) provider.chat = AsyncMock(return_value=response) - session = _make_session(message_count=60) + provider.chat_with_retry = provider.chat + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False + + @pytest.mark.asyncio + async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None: + store = MemoryStore(tmp_path) + provider = ScriptedProvider([ + LLMResponse(content="503 server error", finish_reason="error"), + _make_tool_response( + history_entry="[2026-01-01] User discussed testing.", + memory_update="# Memory\nUser likes testing.", + ), + ]) + messages = _make_messages(message_count=60) + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + result = await store.consolidate(messages, provider, "test-model") + + assert result is True + assert provider.calls == 2 + assert delays == [1] diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index 63b0fd1..1091de4 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") class TestMessageToolSuppressLogic: @@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic: LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="Done", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) sent: list[OutboundMessage] = [] @@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic: LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="I've sent the email.", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) sent: list[OutboundMessage] = [] @@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic: @pytest.mark.asyncio async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") @@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic: ), LLMResponse(content="Done", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.execute = AsyncMock(return_value="ok") diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py new file mode 100644 index 0000000..751ecc3 --- /dev/null +++ b/tests/test_provider_retry.py @@ -0,0 +1,92 @@ +import asyncio + +import pytest + +from nanobot.providers.base import LLMProvider, LLMResponse + + +class ScriptedProvider(LLMProvider): + def __init__(self, responses): + super().__init__() + self._responses = list(responses) + self.calls = 0 + + async def chat(self, *args, **kwargs) -> LLMResponse: + self.calls += 1 + response = self._responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + def get_default_model(self) -> str: + return "test-model" + + +@pytest.mark.asyncio +async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse(content="429 rate limit", finish_reason="error"), + LLMResponse(content="ok"), + ]) + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.finish_reason == "stop" + assert response.content == "ok" + assert provider.calls == 2 + assert delays == [1] + + +@pytest.mark.asyncio +async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse(content="401 unauthorized", finish_reason="error"), + ]) + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "401 unauthorized" + assert provider.calls == 1 + assert delays == [] + + +@pytest.mark.asyncio +async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse(content="429 rate limit a", finish_reason="error"), + LLMResponse(content="429 rate limit b", finish_reason="error"), + LLMResponse(content="429 rate limit c", finish_reason="error"), + LLMResponse(content="503 final server error", finish_reason="error"), + ]) + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "503 final server error" + assert provider.calls == 4 + assert delays == [1, 2, 4] + + +@pytest.mark.asyncio +async def test_chat_with_retry_preserves_cancelled_error() -> None: + provider = ScriptedProvider([asyncio.CancelledError()]) + + with pytest.raises(asyncio.CancelledError): + await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) diff --git a/tests/test_skill_creator_scripts.py b/tests/test_skill_creator_scripts.py new file mode 100644 index 0000000..4207c6f --- /dev/null +++ b/tests/test_skill_creator_scripts.py @@ -0,0 +1,127 @@ +import importlib +import shutil +import sys +import zipfile +from pathlib import Path + + +SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve() +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + +init_skill = importlib.import_module("init_skill") +package_skill = importlib.import_module("package_skill") +quick_validate = importlib.import_module("quick_validate") + + +def test_init_skill_creates_expected_files(tmp_path: Path) -> None: + skill_dir = init_skill.init_skill( + "demo-skill", + tmp_path, + ["scripts", "references", "assets"], + include_examples=True, + ) + + assert skill_dir == tmp_path / "demo-skill" + assert (skill_dir / "SKILL.md").exists() + assert (skill_dir / "scripts" / "example.py").exists() + assert (skill_dir / "references" / "api_reference.md").exists() + assert (skill_dir / "assets" / "example_asset.txt").exists() + + +def test_validate_skill_accepts_existing_skill_creator() -> None: + valid, message = quick_validate.validate_skill( + Path("nanobot/skills/skill-creator").resolve() + ) + + assert valid, message + + +def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None: + skill_dir = tmp_path / "placeholder-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\n" + "name: placeholder-skill\n" + 'description: "[TODO: fill me in]"\n' + "---\n" + "# Placeholder\n", + encoding="utf-8", + ) + + valid, message = quick_validate.validate_skill(skill_dir) + + assert not valid + assert "TODO placeholder" in message + + +def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None: + skill_dir = tmp_path / "bad-root-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\n" + "name: bad-root-skill\n" + "description: Valid description\n" + "---\n" + "# Skill\n", + encoding="utf-8", + ) + (skill_dir / "README.md").write_text("extra\n", encoding="utf-8") + + valid, message = quick_validate.validate_skill(skill_dir) + + assert not valid + assert "Unexpected file or directory in skill root" in message + + +def test_package_skill_creates_archive(tmp_path: Path) -> None: + skill_dir = tmp_path / "package-me" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\n" + "name: package-me\n" + "description: Package this skill.\n" + "---\n" + "# Skill\n", + encoding="utf-8", + ) + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir() + (scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8") + + archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist") + + assert archive_path == (tmp_path / "dist" / "package-me.skill") + assert archive_path.exists() + with zipfile.ZipFile(archive_path, "r") as archive: + names = set(archive.namelist()) + assert "package-me/SKILL.md" in names + assert "package-me/scripts/helper.py" in names + + +def test_package_skill_rejects_symlink(tmp_path: Path) -> None: + skill_dir = tmp_path / "symlink-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text( + "---\n" + "name: symlink-skill\n" + "description: Reject symlinks during packaging.\n" + "---\n" + "# Skill\n", + encoding="utf-8", + ) + scripts_dir = skill_dir / "scripts" + scripts_dir.mkdir() + target = tmp_path / "outside.txt" + target.write_text("secret\n", encoding="utf-8") + link = scripts_dir / "outside.txt" + + try: + link.symlink_to(target) + except (OSError, NotImplementedError): + return + + archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist") + + assert archive_path is None + assert not (tmp_path / "dist" / "symlink-skill.skill").exists() diff --git a/tests/test_slack_channel.py b/tests/test_slack_channel.py new file mode 100644 index 0000000..891f86a --- /dev/null +++ b/tests/test_slack_channel.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.slack import SlackChannel +from nanobot.config.schema import SlackConfig + + +class _FakeAsyncWebClient: + def __init__(self) -> None: + self.chat_post_calls: list[dict[str, object | None]] = [] + self.file_upload_calls: list[dict[str, object | None]] = [] + + async def chat_postMessage( + self, + *, + channel: str, + text: str, + thread_ts: str | None = None, + ) -> None: + self.chat_post_calls.append( + { + "channel": channel, + "text": text, + "thread_ts": thread_ts, + } + ) + + async def files_upload_v2( + self, + *, + channel: str, + file: str, + thread_ts: str | None = None, + ) -> None: + self.file_upload_calls.append( + { + "channel": channel, + "file": file, + "thread_ts": thread_ts, + } + ) + + +@pytest.mark.asyncio +async def test_send_uses_thread_for_channel_messages() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="C123", + content="hello", + media=["/tmp/demo.txt"], + metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}}, + ) + ) + + assert len(fake_web.chat_post_calls) == 1 + assert fake_web.chat_post_calls[0]["text"] == "hello\n" + assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100" + assert len(fake_web.file_upload_calls) == 1 + assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100" + + +@pytest.mark.asyncio +async def test_send_omits_thread_for_dm_messages() -> None: + channel = SlackChannel(SlackConfig(enabled=True), MessageBus()) + fake_web = _FakeAsyncWebClient() + channel._web_client = fake_web + + await channel.send( + OutboundMessage( + channel="slack", + chat_id="D123", + content="hello", + media=["/tmp/demo.txt"], + metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}}, + ) + ) + + assert len(fake_web.chat_post_calls) == 1 + assert fake_web.chat_post_calls[0]["text"] == "hello\n" + assert fake_web.chat_post_calls[0]["thread_ts"] is None + assert len(fake_web.file_upload_calls) == 1 + assert fake_web.file_upload_calls[0]["thread_ts"] is None diff --git a/tests/test_task_cancel.py b/tests/test_task_cancel.py index 27a2d73..62ab2cc 100644 --- a/tests/test_task_cancel.py +++ b/tests/test_task_cancel.py @@ -165,3 +165,46 @@ class TestSubagentCancellation: provider.get_default_model.return_value = "test-model" mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) assert await mgr.cancel_by_session("nonexistent") == 0 + + @pytest.mark.asyncio + async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path): + from nanobot.agent.subagent import SubagentManager + from nanobot.bus.queue import MessageBus + from nanobot.providers.base import LLMResponse, ToolCallRequest + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + + captured_second_call: list[dict] = [] + + call_count = {"n": 0} + + async def scripted_chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + reasoning_content="hidden reasoning", + thinking_blocks=[{"type": "thinking", "thinking": "step"}], + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[]) + provider.chat_with_retry = scripted_chat_with_retry + mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + + async def fake_execute(self, name, arguments): + return "tool result" + + monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + + await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) + + assistant_messages = [ + msg for msg in captured_second_call + if msg.get("role") == "assistant" and msg.get("tool_calls") + ] + assert len(assistant_messages) == 1 + assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" + assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] diff --git a/tests/test_telegram_channel.py b/tests/test_telegram_channel.py index 88c3f54..678512d 100644 --- a/tests/test_telegram_channel.py +++ b/tests/test_telegram_channel.py @@ -27,9 +27,11 @@ class _FakeUpdater: class _FakeBot: def __init__(self) -> None: self.sent_messages: list[dict] = [] + self.get_me_calls = 0 async def get_me(self): - return SimpleNamespace(username="nanobot_test") + self.get_me_calls += 1 + return SimpleNamespace(id=999, username="nanobot_test") async def set_my_commands(self, commands) -> None: self.commands = commands @@ -37,6 +39,9 @@ class _FakeBot: async def send_message(self, **kwargs) -> None: self.sent_messages.append(kwargs) + async def send_chat_action(self, **kwargs) -> None: + pass + class _FakeApp: def __init__(self, on_start_polling) -> None: @@ -87,6 +92,35 @@ class _FakeBuilder: return self.app +def _make_telegram_update( + *, + chat_type: str = "group", + text: str | None = None, + caption: str | None = None, + entities=None, + caption_entities=None, + reply_to_message=None, +): + user = SimpleNamespace(id=12345, username="alice", first_name="Alice") + message = SimpleNamespace( + chat=SimpleNamespace(type=chat_type, is_forum=False), + chat_id=-100123, + text=text, + caption=caption, + entities=entities or [], + caption_entities=caption_entities or [], + reply_to_message=reply_to_message, + photo=None, + voice=None, + audio=None, + document=None, + media_group_id=None, + message_thread_id=None, + message_id=1, + ) + return SimpleNamespace(message=message, effective_user=user) + + @pytest.mark.asyncio async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None: config = TelegramConfig( @@ -131,6 +165,10 @@ def test_get_extension_falls_back_to_original_filename() -> None: assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz" +def test_telegram_group_policy_defaults_to_mention() -> None: + assert TelegramConfig().group_policy == "mention" + + def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None: channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus()) @@ -182,3 +220,119 @@ async def test_send_reply_infers_topic_from_message_id_cache() -> None: assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42 assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10 + + +@pytest.mark.asyncio +async def test_group_policy_mention_ignores_unmentioned_group_message() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + handled = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle + channel._start_typing = lambda _chat_id: None + + await channel._on_message(_make_telegram_update(text="hello everyone"), None) + + assert handled == [] + assert channel._app.bot.get_me_calls == 1 + + +@pytest.mark.asyncio +async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + handled = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle + channel._start_typing = lambda _chat_id: None + + mention = SimpleNamespace(type="mention", offset=0, length=13) + await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None) + await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None) + + assert len(handled) == 2 + assert channel._app.bot.get_me_calls == 1 + + +@pytest.mark.asyncio +async def test_group_policy_mention_accepts_caption_mention() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + handled = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle + channel._start_typing = lambda _chat_id: None + + mention = SimpleNamespace(type="mention", offset=0, length=13) + await channel._on_message( + _make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]), + None, + ) + + assert len(handled) == 1 + assert handled[0]["content"] == "@nanobot_test photo" + + +@pytest.mark.asyncio +async def test_group_policy_mention_accepts_reply_to_bot() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + handled = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle + channel._start_typing = lambda _chat_id: None + + reply = SimpleNamespace(from_user=SimpleNamespace(id=999)) + await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None) + + assert len(handled) == 1 + + +@pytest.mark.asyncio +async def test_group_policy_open_accepts_plain_group_message() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + handled = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle + channel._start_typing = lambda _chat_id: None + + await channel._on_message(_make_telegram_update(text="hello group"), None) + + assert len(handled) == 1 + assert channel._app.bot.get_me_calls == 0