merge origin/main into pr-1327
Made-with: Cursor
This commit is contained in:
@@ -20,6 +20,7 @@
|
|||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
- **2026-03-08** 🚀 Released **v0.1.4.post4** — a reliability-packed release with safer defaults, better multi-instance support, sturdier MCP, and major channel and provider improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post4) for details.
|
||||||
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
|
- **2026-03-07** 🚀 Azure OpenAI provider, WhatsApp media, QQ group chats, and more Telegram/Feishu polish.
|
||||||
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
|
- **2026-03-06** 🪄 Lighter providers, smarter media handling, and sturdier memory and CLI compatibility.
|
||||||
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
|
- **2026-03-05** ⚡️ Telegram draft streaming, MCP SSE support, and broader channel reliability fixes.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
|||||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l)
|
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||||
echo " Core total: $total lines"
|
echo " Core total: $total lines"
|
||||||
echo ""
|
echo ""
|
||||||
echo " (excludes: channels/, cli/, providers/)"
|
echo " (excludes: channels/, cli/, providers/, skills/)"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
from nanobot.utils.helpers import detect_image_mime
|
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
@@ -182,12 +182,10 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
thinking_blocks: list[dict] | None = None,
|
thinking_blocks: list[dict] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Add an assistant message to the message list."""
|
"""Add an assistant message to the message list."""
|
||||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
messages.append(build_assistant_message(
|
||||||
if tool_calls:
|
content,
|
||||||
msg["tool_calls"] = tool_calls
|
tool_calls=tool_calls,
|
||||||
if reasoning_content is not None:
|
reasoning_content=reasoning_content,
|
||||||
msg["reasoning_content"] = reasoning_content
|
thinking_blocks=thinking_blocks,
|
||||||
if thinking_blocks:
|
))
|
||||||
msg["thinking_blocks"] = thinking_blocks
|
|
||||||
messages.append(msg)
|
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import weakref
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
@@ -13,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryConsolidator
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
@@ -55,8 +54,8 @@ class AgentLoop:
|
|||||||
max_iterations: int = 40,
|
max_iterations: int = 40,
|
||||||
temperature: float = 0.1,
|
temperature: float = 0.1,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
memory_window: int = 100,
|
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
context_window_tokens: int = 65_536,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
@@ -75,8 +74,8 @@ class AgentLoop:
|
|||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.memory_window = memory_window
|
|
||||||
self.reasoning_effort = reasoning_effort
|
self.reasoning_effort = reasoning_effort
|
||||||
|
self.context_window_tokens = context_window_tokens
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@@ -105,11 +104,17 @@ class AgentLoop:
|
|||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
|
||||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
|
||||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
|
self.memory_consolidator = MemoryConsolidator(
|
||||||
|
workspace=workspace,
|
||||||
|
provider=provider,
|
||||||
|
model=self.model,
|
||||||
|
sessions=self.sessions,
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
build_messages=self.context.build_messages,
|
||||||
|
get_tool_definitions=self.tools.get_definitions,
|
||||||
|
)
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@@ -182,7 +187,7 @@ class AgentLoop:
|
|||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
"""Run the agent iteration loop."""
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_content = None
|
final_content = None
|
||||||
@@ -191,9 +196,11 @@ class AgentLoop:
|
|||||||
while iteration < self.max_iterations:
|
while iteration < self.max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
response = await self.provider.chat(
|
tool_defs = self.tools.get_definitions()
|
||||||
|
|
||||||
|
response = await self.provider.chat_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tools.get_definitions(),
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
@@ -341,8 +348,9 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
history = session.get_history(max_messages=0)
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
@@ -350,6 +358,7 @@ class AgentLoop:
|
|||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@@ -362,27 +371,20 @@ class AgentLoop:
|
|||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
try:
|
try:
|
||||||
async with lock:
|
if not await self.memory_consolidator.archive_unconsolidated(session):
|
||||||
snapshot = session.messages[session.last_consolidated:]
|
return OutboundMessage(
|
||||||
if snapshot:
|
channel=msg.channel,
|
||||||
temp = Session(key=session.key)
|
chat_id=msg.chat_id,
|
||||||
temp.messages = list(snapshot)
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
if not await self._consolidate_memory(temp, archive_all=True):
|
)
|
||||||
return OutboundMessage(
|
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("/new archival failed for {}", session.key)
|
logger.exception("/new archival failed for {}", session.key)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
@@ -393,30 +395,14 @@ class AgentLoop:
|
|||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||||
|
|
||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
|
||||||
try:
|
|
||||||
async with lock:
|
|
||||||
await self._consolidate_memory(session)
|
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
_task = asyncio.current_task()
|
|
||||||
if _task is not None:
|
|
||||||
self._consolidation_tasks.discard(_task)
|
|
||||||
|
|
||||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
|
||||||
self._consolidation_tasks.add(_task)
|
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if isinstance(message_tool, MessageTool):
|
||||||
message_tool.start_turn()
|
message_tool.start_turn()
|
||||||
|
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
history = session.get_history(max_messages=0)
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
@@ -441,6 +427,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
@@ -487,13 +474,6 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
|
||||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
|
||||||
return await MemoryStore(self.workspace).consolidate(
|
|
||||||
session, self.provider, self.model,
|
|
||||||
archive_all=archive_all, memory_window=self.memory_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|||||||
@@ -2,17 +2,19 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import weakref
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.utils.helpers import ensure_dir
|
from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
|
||||||
_SAVE_MEMORY_TOOL = [
|
_SAVE_MEMORY_TOOL = [
|
||||||
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
|
|||||||
"properties": {
|
"properties": {
|
||||||
"history_entry": {
|
"history_entry": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
"description": "A paragraph summarizing key events/decisions/topics. "
|
||||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||||
},
|
},
|
||||||
"memory_update": {
|
"memory_update": {
|
||||||
@@ -42,6 +44,20 @@ _SAVE_MEMORY_TOOL = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_text(value: Any) -> str:
|
||||||
|
"""Normalize tool-call payload values to text for file storage."""
|
||||||
|
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
||||||
|
"""Normalize provider tool-call arguments to the expected dict shape."""
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
if isinstance(args, list):
|
||||||
|
return args[0] if args and isinstance(args[0], dict) else None
|
||||||
|
return args if isinstance(args, dict) else None
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
|
|
||||||
@@ -66,40 +82,27 @@ class MemoryStore:
|
|||||||
long_term = self.read_long_term()
|
long_term = self.read_long_term()
|
||||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_messages(messages: list[dict]) -> str:
|
||||||
|
lines = []
|
||||||
|
for message in messages:
|
||||||
|
if not message.get("content"):
|
||||||
|
continue
|
||||||
|
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
|
||||||
|
lines.append(
|
||||||
|
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
async def consolidate(
|
async def consolidate(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
messages: list[dict],
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
*,
|
|
||||||
archive_all: bool = False,
|
|
||||||
memory_window: int = 50,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
"""Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
|
||||||
|
if not messages:
|
||||||
Returns True on success (including no-op), False on failure.
|
return True
|
||||||
"""
|
|
||||||
if archive_all:
|
|
||||||
old_messages = session.messages
|
|
||||||
keep_count = 0
|
|
||||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
|
||||||
else:
|
|
||||||
keep_count = memory_window // 2
|
|
||||||
if len(session.messages) <= keep_count:
|
|
||||||
return True
|
|
||||||
if len(session.messages) - session.last_consolidated <= 0:
|
|
||||||
return True
|
|
||||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
|
||||||
if not old_messages:
|
|
||||||
return True
|
|
||||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for m in old_messages:
|
|
||||||
if not m.get("content"):
|
|
||||||
continue
|
|
||||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
|
||||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
|
||||||
|
|
||||||
current_memory = self.read_long_term()
|
current_memory = self.read_long_term()
|
||||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||||
@@ -108,10 +111,10 @@ class MemoryStore:
|
|||||||
{current_memory or "(empty)"}
|
{current_memory or "(empty)"}
|
||||||
|
|
||||||
## Conversation to Process
|
## Conversation to Process
|
||||||
{chr(10).join(lines)}"""
|
{self._format_messages(messages)}"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await provider.chat(
|
response = await provider.chat_with_retry(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
@@ -124,34 +127,158 @@ class MemoryStore:
|
|||||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
args = response.tool_calls[0].arguments
|
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
|
||||||
# Some providers return arguments as a JSON string instead of dict
|
if args is None:
|
||||||
if isinstance(args, str):
|
logger.warning("Memory consolidation: unexpected save_memory arguments")
|
||||||
args = json.loads(args)
|
|
||||||
# Some providers return arguments as a list (handle edge case)
|
|
||||||
if isinstance(args, list):
|
|
||||||
if args and isinstance(args[0], dict):
|
|
||||||
args = args[0]
|
|
||||||
else:
|
|
||||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
|
||||||
return False
|
|
||||||
if not isinstance(args, dict):
|
|
||||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if entry := args.get("history_entry"):
|
||||||
if not isinstance(entry, str):
|
self.append_history(_ensure_text(entry))
|
||||||
entry = json.dumps(entry, ensure_ascii=False)
|
|
||||||
self.append_history(entry)
|
|
||||||
if update := args.get("memory_update"):
|
if update := args.get("memory_update"):
|
||||||
if not isinstance(update, str):
|
update = _ensure_text(update)
|
||||||
update = json.dumps(update, ensure_ascii=False)
|
|
||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
self.write_long_term(update)
|
self.write_long_term(update)
|
||||||
|
|
||||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Memory consolidation failed")
|
logger.exception("Memory consolidation failed")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConsolidator:
|
||||||
|
"""Owns consolidation policy, locking, and session offset updates."""
|
||||||
|
|
||||||
|
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace: Path,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
sessions: SessionManager,
|
||||||
|
context_window_tokens: int,
|
||||||
|
build_messages: Callable[..., list[dict[str, Any]]],
|
||||||
|
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||||
|
):
|
||||||
|
self.store = MemoryStore(workspace)
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.sessions = sessions
|
||||||
|
self.context_window_tokens = context_window_tokens
|
||||||
|
self._build_messages = build_messages
|
||||||
|
self._get_tool_definitions = get_tool_definitions
|
||||||
|
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||||
|
|
||||||
|
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
|
"""Return the shared consolidation lock for one session."""
|
||||||
|
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
|
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||||
|
"""Archive a selected message chunk into persistent memory."""
|
||||||
|
return await self.store.consolidate(messages, self.provider, self.model)
|
||||||
|
|
||||||
|
def pick_consolidation_boundary(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
tokens_to_remove: int,
|
||||||
|
) -> tuple[int, int] | None:
|
||||||
|
"""Pick a user-turn boundary that removes enough old prompt tokens."""
|
||||||
|
start = session.last_consolidated
|
||||||
|
if start >= len(session.messages) or tokens_to_remove <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
removed_tokens = 0
|
||||||
|
last_boundary: tuple[int, int] | None = None
|
||||||
|
for idx in range(start, len(session.messages)):
|
||||||
|
message = session.messages[idx]
|
||||||
|
if idx > start and message.get("role") == "user":
|
||||||
|
last_boundary = (idx, removed_tokens)
|
||||||
|
if removed_tokens >= tokens_to_remove:
|
||||||
|
return last_boundary
|
||||||
|
removed_tokens += estimate_message_tokens(message)
|
||||||
|
|
||||||
|
return last_boundary
|
||||||
|
|
||||||
|
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||||
|
"""Estimate current prompt size for the normal session history view."""
|
||||||
|
history = session.get_history(max_messages=0)
|
||||||
|
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||||
|
probe_messages = self._build_messages(
|
||||||
|
history=history,
|
||||||
|
current_message="[token-probe]",
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
)
|
||||||
|
return estimate_prompt_tokens_chain(
|
||||||
|
self.provider,
|
||||||
|
self.model,
|
||||||
|
probe_messages,
|
||||||
|
self._get_tool_definitions(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def archive_unconsolidated(self, session: Session) -> bool:
|
||||||
|
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||||
|
lock = self.get_lock(session.key)
|
||||||
|
async with lock:
|
||||||
|
snapshot = session.messages[session.last_consolidated:]
|
||||||
|
if not snapshot:
|
||||||
|
return True
|
||||||
|
return await self.consolidate_messages(snapshot)
|
||||||
|
|
||||||
|
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||||
|
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||||
|
if not session.messages or self.context_window_tokens <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
lock = self.get_lock(session.key)
|
||||||
|
async with lock:
|
||||||
|
target = self.context_window_tokens // 2
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
if estimated <= 0:
|
||||||
|
return
|
||||||
|
if estimated < self.context_window_tokens:
|
||||||
|
logger.debug(
|
||||||
|
"Token consolidation idle {}: {}/{} via {}",
|
||||||
|
session.key,
|
||||||
|
estimated,
|
||||||
|
self.context_window_tokens,
|
||||||
|
source,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
|
||||||
|
if estimated <= target:
|
||||||
|
return
|
||||||
|
|
||||||
|
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
|
||||||
|
if boundary is None:
|
||||||
|
logger.debug(
|
||||||
|
"Token consolidation: no safe boundary for {} (round {})",
|
||||||
|
session.key,
|
||||||
|
round_num,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
end_idx = boundary[0]
|
||||||
|
chunk = session.messages[session.last_consolidated:end_idx]
|
||||||
|
if not chunk:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
|
||||||
|
round_num,
|
||||||
|
session.key,
|
||||||
|
estimated,
|
||||||
|
self.context_window_tokens,
|
||||||
|
source,
|
||||||
|
len(chunk),
|
||||||
|
)
|
||||||
|
if not await self.consolidate_messages(chunk):
|
||||||
|
return
|
||||||
|
session.last_consolidated = end_idx
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
if estimated <= 0:
|
||||||
|
return
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from nanobot.bus.events import InboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.utils.helpers import build_assistant_message
|
||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
@@ -123,7 +124,7 @@ class SubagentManager:
|
|||||||
while iteration < max_iterations:
|
while iteration < max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat_with_retry(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@@ -133,7 +134,6 @@ class SubagentManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
# Add assistant message with tool calls
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
{
|
{
|
||||||
"id": tc.id,
|
"id": tc.id,
|
||||||
@@ -145,11 +145,12 @@ class SubagentManager:
|
|||||||
}
|
}
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
]
|
]
|
||||||
messages.append({
|
messages.append(build_assistant_message(
|
||||||
"role": "assistant",
|
response.content or "",
|
||||||
"content": response.content or "",
|
tool_calls=tool_call_dicts,
|
||||||
"tool_calls": tool_call_dicts,
|
reasoning_content=response.reasoning_content,
|
||||||
})
|
thinking_blocks=response.thinking_blocks,
|
||||||
|
))
|
||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ class NanobotDingTalkHandler(CallbackHandler):
|
|||||||
content = ""
|
content = ""
|
||||||
if chatbot_msg.text:
|
if chatbot_msg.text:
|
||||||
content = chatbot_msg.text.content.strip()
|
content = chatbot_msg.text.content.strip()
|
||||||
|
elif chatbot_msg.extensions.get("content", {}).get("recognition"):
|
||||||
|
content = chatbot_msg.extensions["content"]["recognition"].strip()
|
||||||
if not content:
|
if not content:
|
||||||
content = message.data.get("text", {}).get("content", "").strip()
|
content = message.data.get("text", {}).get("content", "").strip()
|
||||||
|
|
||||||
|
|||||||
@@ -753,8 +753,9 @@ class FeishuChannel(BaseChannel):
|
|||||||
None, self._download_file_sync, message_id, file_key, msg_type
|
None, self._download_file_sync, message_id, file_key, msg_type
|
||||||
)
|
)
|
||||||
if not filename:
|
if not filename:
|
||||||
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
filename = file_key[:16]
|
||||||
filename = f"{file_key[:16]}{ext}"
|
if msg_type == "audio" and not filename.endswith(".opus"):
|
||||||
|
filename = f"{filename}.opus"
|
||||||
|
|
||||||
if data and filename:
|
if data and filename:
|
||||||
file_path = media_dir / filename
|
file_path = media_dir / filename
|
||||||
|
|||||||
@@ -81,8 +81,8 @@ class SlackChannel(BaseChannel):
|
|||||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||||
thread_ts = slack_meta.get("thread_ts")
|
thread_ts = slack_meta.get("thread_ts")
|
||||||
channel_type = slack_meta.get("channel_type")
|
channel_type = slack_meta.get("channel_type")
|
||||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||||
thread_ts_param = thread_ts if use_thread else None
|
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
||||||
|
|
||||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||||
# but send a single blank message when the bot has no text or files to send.
|
# but send a single blank message when the bot has no text or files to send.
|
||||||
@@ -278,4 +278,3 @@ class SlackChannel(BaseChannel):
|
|||||||
if parts:
|
if parts:
|
||||||
rows.append(" · ".join(parts))
|
rows.append(" · ".join(parts))
|
||||||
return "\n".join(rows)
|
return "\n".join(rows)
|
||||||
|
|
||||||
|
|||||||
@@ -179,6 +179,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._media_group_buffers: dict[str, dict] = {}
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
self._message_threads: dict[tuple[str, int], int] = {}
|
self._message_threads: dict[tuple[str, int], int] = {}
|
||||||
|
self._bot_user_id: int | None = None
|
||||||
|
self._bot_username: str | None = None
|
||||||
|
|
||||||
def is_allowed(self, sender_id: str) -> bool:
|
def is_allowed(self, sender_id: str) -> bool:
|
||||||
"""Preserve Telegram's legacy id|username allowlist matching."""
|
"""Preserve Telegram's legacy id|username allowlist matching."""
|
||||||
@@ -242,6 +244,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
# Get bot info and register command menu
|
# Get bot info and register command menu
|
||||||
bot_info = await self._app.bot.get_me()
|
bot_info = await self._app.bot.get_me()
|
||||||
|
self._bot_user_id = getattr(bot_info, "id", None)
|
||||||
|
self._bot_username = getattr(bot_info, "username", None)
|
||||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -462,6 +466,70 @@ class TelegramChannel(BaseChannel):
|
|||||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
|
||||||
|
"""Load bot identity once and reuse it for mention/reply checks."""
|
||||||
|
if self._bot_user_id is not None or self._bot_username is not None:
|
||||||
|
return self._bot_user_id, self._bot_username
|
||||||
|
if not self._app:
|
||||||
|
return None, None
|
||||||
|
bot_info = await self._app.bot.get_me()
|
||||||
|
self._bot_user_id = getattr(bot_info, "id", None)
|
||||||
|
self._bot_username = getattr(bot_info, "username", None)
|
||||||
|
return self._bot_user_id, self._bot_username
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_mention_entity(
|
||||||
|
text: str,
|
||||||
|
entities,
|
||||||
|
bot_username: str,
|
||||||
|
bot_id: int | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Check Telegram mention entities against the bot username."""
|
||||||
|
handle = f"@{bot_username}".lower()
|
||||||
|
for entity in entities or []:
|
||||||
|
entity_type = getattr(entity, "type", None)
|
||||||
|
if entity_type == "text_mention":
|
||||||
|
user = getattr(entity, "user", None)
|
||||||
|
if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
|
||||||
|
return True
|
||||||
|
continue
|
||||||
|
if entity_type != "mention":
|
||||||
|
continue
|
||||||
|
offset = getattr(entity, "offset", None)
|
||||||
|
length = getattr(entity, "length", None)
|
||||||
|
if offset is None or length is None:
|
||||||
|
continue
|
||||||
|
if text[offset : offset + length].lower() == handle:
|
||||||
|
return True
|
||||||
|
return handle in text.lower()
|
||||||
|
|
||||||
|
async def _is_group_message_for_bot(self, message) -> bool:
|
||||||
|
"""Allow group messages when policy is open, @mentioned, or replying to the bot."""
|
||||||
|
if message.chat.type == "private" or self.config.group_policy == "open":
|
||||||
|
return True
|
||||||
|
|
||||||
|
bot_id, bot_username = await self._ensure_bot_identity()
|
||||||
|
if bot_username:
|
||||||
|
text = message.text or ""
|
||||||
|
caption = message.caption or ""
|
||||||
|
if self._has_mention_entity(
|
||||||
|
text,
|
||||||
|
getattr(message, "entities", None),
|
||||||
|
bot_username,
|
||||||
|
bot_id,
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if self._has_mention_entity(
|
||||||
|
caption,
|
||||||
|
getattr(message, "caption_entities", None),
|
||||||
|
bot_username,
|
||||||
|
bot_id,
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
|
||||||
|
return bool(bot_id and reply_user and reply_user.id == bot_id)
|
||||||
|
|
||||||
def _remember_thread_context(self, message) -> None:
|
def _remember_thread_context(self, message) -> None:
|
||||||
"""Cache topic thread id by chat/message id for follow-up replies."""
|
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||||
message_thread_id = getattr(message, "message_thread_id", None)
|
message_thread_id = getattr(message, "message_thread_id", None)
|
||||||
@@ -501,6 +569,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
# Store chat_id for replies
|
# Store chat_id for replies
|
||||||
self._chat_ids[sender_id] = chat_id
|
self._chat_ids[sender_id] = chat_id
|
||||||
|
|
||||||
|
if not await self._is_group_message_for_bot(message):
|
||||||
|
return
|
||||||
|
|
||||||
# Build content from text and/or media
|
# Build content from text and/or media
|
||||||
content_parts = []
|
content_parts = []
|
||||||
media_paths = []
|
media_paths = []
|
||||||
|
|||||||
@@ -191,6 +191,8 @@ def onboard():
|
|||||||
save_config(Config())
|
save_config(Config())
|
||||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||||
|
|
||||||
|
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
|
||||||
|
|
||||||
# Create workspace
|
# Create workspace
|
||||||
workspace = get_workspace_path()
|
workspace = get_workspace_path()
|
||||||
|
|
||||||
@@ -283,6 +285,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
|
|||||||
return loaded
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _print_deprecated_memory_window_notice(config: Config) -> None:
|
||||||
|
"""Warn when running with old memoryWindow-only config."""
|
||||||
|
if config.agents.defaults.should_warn_deprecated_memory_window:
|
||||||
|
console.print(
|
||||||
|
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
|
||||||
|
"`contextWindowTokens`. `memoryWindow` is ignored; run "
|
||||||
|
"[cyan]nanobot onboard[/cyan] to refresh your config template."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Gateway / Server
|
# Gateway / Server
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -290,7 +302,7 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
|
|||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def gateway(
|
def gateway(
|
||||||
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
|
port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
|
||||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
||||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||||
@@ -310,6 +322,8 @@ def gateway(
|
|||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
config = _load_runtime_config(config, workspace)
|
config = _load_runtime_config(config, workspace)
|
||||||
|
_print_deprecated_memory_window_notice(config)
|
||||||
|
port = port if port is not None else config.gateway.port
|
||||||
|
|
||||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
@@ -330,8 +344,8 @@ def gateway(
|
|||||||
temperature=config.agents.defaults.temperature,
|
temperature=config.agents.defaults.temperature,
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||||
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@@ -493,6 +507,7 @@ def agent(
|
|||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
config = _load_runtime_config(config, workspace)
|
config = _load_runtime_config(config, workspace)
|
||||||
|
_print_deprecated_memory_window_notice(config)
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
@@ -515,8 +530,8 @@ def agent(
|
|||||||
temperature=config.agents.defaults.temperature,
|
temperature=config.agents.defaults.temperature,
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
max_tokens=config.agents.defaults.max_tokens,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
memory_window=config.agents.defaults.memory_window,
|
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||||
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class TelegramConfig(Base):
|
|||||||
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||||
)
|
)
|
||||||
reply_to_message: bool = False # If true, bot replies quote the original message
|
reply_to_message: bool = False # If true, bot replies quote the original message
|
||||||
|
group_policy: Literal["open", "mention"] = "mention" # "mention" responds when @mentioned or replied to, "open" responds to all
|
||||||
|
|
||||||
|
|
||||||
class FeishuConfig(Base):
|
class FeishuConfig(Base):
|
||||||
@@ -236,11 +237,18 @@ class AgentDefaults(Base):
|
|||||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||||
)
|
)
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
|
context_window_tokens: int = 65_536
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 40
|
||||||
memory_window: int = 100
|
# Deprecated compatibility field: accepted from old configs but ignored at runtime.
|
||||||
|
memory_window: int | None = Field(default=None, exclude=True)
|
||||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_warn_deprecated_memory_window(self) -> bool:
|
||||||
|
"""Return True when old memoryWindow is present without contextWindowTokens."""
|
||||||
|
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
|
||||||
|
|
||||||
|
|
||||||
class AgentsConfig(Base):
|
class AgentsConfig(Base):
|
||||||
"""Agent configuration."""
|
"""Agent configuration."""
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class HeartbeatService:
|
|||||||
|
|
||||||
Returns (action, tasks) where action is 'skip' or 'run'.
|
Returns (action, tasks) where action is 'skip' or 'run'.
|
||||||
"""
|
"""
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat_with_retry(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||||
{"role": "user", "content": (
|
{"role": "user", "content": (
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
"""Base LLM provider interface."""
|
"""Base LLM provider interface."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRequest:
|
class ToolCallRequest:
|
||||||
@@ -37,6 +40,22 @@ class LLMProvider(ABC):
|
|||||||
while maintaining a consistent interface.
|
while maintaining a consistent interface.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||||
|
_TRANSIENT_ERROR_MARKERS = (
|
||||||
|
"429",
|
||||||
|
"rate limit",
|
||||||
|
"500",
|
||||||
|
"502",
|
||||||
|
"503",
|
||||||
|
"504",
|
||||||
|
"overloaded",
|
||||||
|
"timeout",
|
||||||
|
"timed out",
|
||||||
|
"connection",
|
||||||
|
"server error",
|
||||||
|
"temporarily unavailable",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
@@ -126,6 +145,71 @@ class LLMProvider(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _is_transient_error(cls, content: str | None) -> bool:
|
||||||
|
err = (content or "").lower()
|
||||||
|
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||||
|
|
||||||
|
async def chat_with_retry(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""Call chat() with retry on transient provider failures."""
|
||||||
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||||
|
try:
|
||||||
|
response = await self.chat(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
response = LLMResponse(
|
||||||
|
content=f"Error calling LLM: {exc}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.finish_reason != "error":
|
||||||
|
return response
|
||||||
|
if not self._is_transient_error(response.content):
|
||||||
|
return response
|
||||||
|
|
||||||
|
err = (response.content or "").lower()
|
||||||
|
logger.warning(
|
||||||
|
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||||
|
attempt,
|
||||||
|
len(self._CHAT_RETRY_DELAYS),
|
||||||
|
delay,
|
||||||
|
err[:120],
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.chat(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
reasoning_effort=reasoning_effort,
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
return LLMResponse(
|
||||||
|
content=f"Error calling LLM: {exc}",
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
"""Get the default model for this provider."""
|
"""Get the default model for this provider."""
|
||||||
|
|||||||
@@ -268,6 +268,8 @@ Skip this step only if the skill being developed already exists, and iteration o
|
|||||||
|
|
||||||
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
|
When creating a new skill from scratch, always run the `init_skill.py` script. The script conveniently generates a new template skill directory that automatically includes everything a skill requires, making the skill creation process much more efficient and reliable.
|
||||||
|
|
||||||
|
For `nanobot`, custom skills should live under the active workspace `skills/` directory so they can be discovered automatically at runtime (for example, `<workspace>/skills/my-skill/SKILL.md`).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -277,9 +279,9 @@ scripts/init_skill.py <skill-name> --path <output-directory> [--resources script
|
|||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
scripts/init_skill.py my-skill --path skills/public
|
scripts/init_skill.py my-skill --path ./workspace/skills
|
||||||
scripts/init_skill.py my-skill --path skills/public --resources scripts,references
|
scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts,references
|
||||||
scripts/init_skill.py my-skill --path skills/public --resources scripts --examples
|
scripts/init_skill.py my-skill --path ./workspace/skills --resources scripts --examples
|
||||||
```
|
```
|
||||||
|
|
||||||
The script:
|
The script:
|
||||||
@@ -326,7 +328,7 @@ Write the YAML frontmatter with `name` and `description`:
|
|||||||
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
|
- Include all "when to use" information here - Not in the body. The body is only loaded after triggering, so "When to Use This Skill" sections in the body are not helpful to the agent.
|
||||||
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
|
- Example description for a `docx` skill: "Comprehensive document creation, editing, and analysis with support for tracked changes, comments, formatting preservation, and text extraction. Use when the agent needs to work with professional documents (.docx files) for: (1) Creating new documents, (2) Modifying or editing content, (3) Working with tracked changes, (4) Adding comments, or any other document tasks"
|
||||||
|
|
||||||
Do not include any other fields in YAML frontmatter.
|
Keep frontmatter minimal. In `nanobot`, `metadata` and `always` are also supported when needed, but avoid adding extra fields unless they are actually required.
|
||||||
|
|
||||||
##### Body
|
##### Body
|
||||||
|
|
||||||
@@ -349,7 +351,6 @@ scripts/package_skill.py <path/to/skill-folder> ./dist
|
|||||||
The packaging script will:
|
The packaging script will:
|
||||||
|
|
||||||
1. **Validate** the skill automatically, checking:
|
1. **Validate** the skill automatically, checking:
|
||||||
|
|
||||||
- YAML frontmatter format and required fields
|
- YAML frontmatter format and required fields
|
||||||
- Skill naming conventions and directory structure
|
- Skill naming conventions and directory structure
|
||||||
- Description completeness and quality
|
- Description completeness and quality
|
||||||
@@ -357,6 +358,8 @@ The packaging script will:
|
|||||||
|
|
||||||
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
|
2. **Package** the skill if validation passes, creating a .skill file named after the skill (e.g., `my-skill.skill`) that includes all files and maintains the proper directory structure for distribution. The .skill file is a zip file with a .skill extension.
|
||||||
|
|
||||||
|
Security restriction: symlinks are rejected and packaging fails when any symlink is present.
|
||||||
|
|
||||||
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
|
If validation fails, the script will report the errors and exit without creating a package. Fix any validation errors and run the packaging command again.
|
||||||
|
|
||||||
### Step 6: Iterate
|
### Step 6: Iterate
|
||||||
|
|||||||
378
nanobot/skills/skill-creator/scripts/init_skill.py
Executable file
378
nanobot/skills/skill-creator/scripts/init_skill.py
Executable file
@@ -0,0 +1,378 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Skill Initializer - Creates a new skill from template
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
init_skill.py <skill-name> --path <path> [--resources scripts,references,assets] [--examples]
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
init_skill.py my-new-skill --path skills/public
|
||||||
|
init_skill.py my-new-skill --path skills/public --resources scripts,references
|
||||||
|
init_skill.py my-api-helper --path skills/private --resources scripts --examples
|
||||||
|
init_skill.py custom-skill --path /custom/location
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
MAX_SKILL_NAME_LENGTH = 64
|
||||||
|
ALLOWED_RESOURCES = {"scripts", "references", "assets"}
|
||||||
|
|
||||||
|
SKILL_TEMPLATE = """---
|
||||||
|
name: {skill_name}
|
||||||
|
description: [TODO: Complete and informative explanation of what the skill does and when to use it. Include WHEN to use this skill - specific scenarios, file types, or tasks that trigger it.]
|
||||||
|
---
|
||||||
|
|
||||||
|
# {skill_title}
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
[TODO: 1-2 sentences explaining what this skill enables]
|
||||||
|
|
||||||
|
## Structuring This Skill
|
||||||
|
|
||||||
|
[TODO: Choose the structure that best fits this skill's purpose. Common patterns:
|
||||||
|
|
||||||
|
**1. Workflow-Based** (best for sequential processes)
|
||||||
|
- Works well when there are clear step-by-step procedures
|
||||||
|
- Example: DOCX skill with "Workflow Decision Tree" -> "Reading" -> "Creating" -> "Editing"
|
||||||
|
- Structure: ## Overview -> ## Workflow Decision Tree -> ## Step 1 -> ## Step 2...
|
||||||
|
|
||||||
|
**2. Task-Based** (best for tool collections)
|
||||||
|
- Works well when the skill offers different operations/capabilities
|
||||||
|
- Example: PDF skill with "Quick Start" -> "Merge PDFs" -> "Split PDFs" -> "Extract Text"
|
||||||
|
- Structure: ## Overview -> ## Quick Start -> ## Task Category 1 -> ## Task Category 2...
|
||||||
|
|
||||||
|
**3. Reference/Guidelines** (best for standards or specifications)
|
||||||
|
- Works well for brand guidelines, coding standards, or requirements
|
||||||
|
- Example: Brand styling with "Brand Guidelines" -> "Colors" -> "Typography" -> "Features"
|
||||||
|
- Structure: ## Overview -> ## Guidelines -> ## Specifications -> ## Usage...
|
||||||
|
|
||||||
|
**4. Capabilities-Based** (best for integrated systems)
|
||||||
|
- Works well when the skill provides multiple interrelated features
|
||||||
|
- Example: Product Management with "Core Capabilities" -> numbered capability list
|
||||||
|
- Structure: ## Overview -> ## Core Capabilities -> ### 1. Feature -> ### 2. Feature...
|
||||||
|
|
||||||
|
Patterns can be mixed and matched as needed. Most skills combine patterns (e.g., start with task-based, add workflow for complex operations).
|
||||||
|
|
||||||
|
Delete this entire "Structuring This Skill" section when done - it's just guidance.]
|
||||||
|
|
||||||
|
## [TODO: Replace with the first main section based on chosen structure]
|
||||||
|
|
||||||
|
[TODO: Add content here. See examples in existing skills:
|
||||||
|
- Code samples for technical skills
|
||||||
|
- Decision trees for complex workflows
|
||||||
|
- Concrete examples with realistic user requests
|
||||||
|
- References to scripts/templates/references as needed]
|
||||||
|
|
||||||
|
## Resources (optional)
|
||||||
|
|
||||||
|
Create only the resource directories this skill actually needs. Delete this section if no resources are required.
|
||||||
|
|
||||||
|
### scripts/
|
||||||
|
Executable code (Python/Bash/etc.) that can be run directly to perform specific operations.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- PDF skill: `fill_fillable_fields.py`, `extract_form_field_info.py` - utilities for PDF manipulation
|
||||||
|
- DOCX skill: `document.py`, `utilities.py` - Python modules for document processing
|
||||||
|
|
||||||
|
**Appropriate for:** Python scripts, shell scripts, or any executable code that performs automation, data processing, or specific operations.
|
||||||
|
|
||||||
|
**Note:** Scripts may be executed without loading into context, but can still be read by Codex for patching or environment adjustments.
|
||||||
|
|
||||||
|
### references/
|
||||||
|
Documentation and reference material intended to be loaded into context to inform Codex's process and thinking.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- Product management: `communication.md`, `context_building.md` - detailed workflow guides
|
||||||
|
- BigQuery: API reference documentation and query examples
|
||||||
|
- Finance: Schema documentation, company policies
|
||||||
|
|
||||||
|
**Appropriate for:** In-depth documentation, API references, database schemas, comprehensive guides, or any detailed information that Codex should reference while working.
|
||||||
|
|
||||||
|
### assets/
|
||||||
|
Files not intended to be loaded into context, but rather used within the output Codex produces.
|
||||||
|
|
||||||
|
**Examples from other skills:**
|
||||||
|
- Brand styling: PowerPoint template files (.pptx), logo files
|
||||||
|
- Frontend builder: HTML/React boilerplate project directories
|
||||||
|
- Typography: Font files (.ttf, .woff2)
|
||||||
|
|
||||||
|
**Appropriate for:** Templates, boilerplate code, document templates, images, icons, fonts, or any files meant to be copied or used in the final output.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Not every skill requires all three types of resources.**
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXAMPLE_SCRIPT = '''#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Example helper script for {skill_name}
|
||||||
|
|
||||||
|
This is a placeholder script that can be executed directly.
|
||||||
|
Replace with actual implementation or delete if not needed.
|
||||||
|
|
||||||
|
Example real scripts from other skills:
|
||||||
|
- pdf/scripts/fill_fillable_fields.py - Fills PDF form fields
|
||||||
|
- pdf/scripts/convert_pdf_to_images.py - Converts PDF pages to images
|
||||||
|
"""
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("This is an example script for {skill_name}")
|
||||||
|
# TODO: Add actual script logic here
|
||||||
|
# This could be data processing, file conversion, API calls, etc.
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
'''
|
||||||
|
|
||||||
|
EXAMPLE_REFERENCE = """# Reference Documentation for {skill_title}
|
||||||
|
|
||||||
|
This is a placeholder for detailed reference documentation.
|
||||||
|
Replace with actual reference content or delete if not needed.
|
||||||
|
|
||||||
|
Example real reference docs from other skills:
|
||||||
|
- product-management/references/communication.md - Comprehensive guide for status updates
|
||||||
|
- product-management/references/context_building.md - Deep-dive on gathering context
|
||||||
|
- bigquery/references/ - API references and query examples
|
||||||
|
|
||||||
|
## When Reference Docs Are Useful
|
||||||
|
|
||||||
|
Reference docs are ideal for:
|
||||||
|
- Comprehensive API documentation
|
||||||
|
- Detailed workflow guides
|
||||||
|
- Complex multi-step processes
|
||||||
|
- Information too lengthy for main SKILL.md
|
||||||
|
- Content that's only needed for specific use cases
|
||||||
|
|
||||||
|
## Structure Suggestions
|
||||||
|
|
||||||
|
### API Reference Example
|
||||||
|
- Overview
|
||||||
|
- Authentication
|
||||||
|
- Endpoints with examples
|
||||||
|
- Error codes
|
||||||
|
- Rate limits
|
||||||
|
|
||||||
|
### Workflow Guide Example
|
||||||
|
- Prerequisites
|
||||||
|
- Step-by-step instructions
|
||||||
|
- Common patterns
|
||||||
|
- Troubleshooting
|
||||||
|
- Best practices
|
||||||
|
"""
|
||||||
|
|
||||||
|
EXAMPLE_ASSET = """# Example Asset File
|
||||||
|
|
||||||
|
This placeholder represents where asset files would be stored.
|
||||||
|
Replace with actual asset files (templates, images, fonts, etc.) or delete if not needed.
|
||||||
|
|
||||||
|
Asset files are NOT intended to be loaded into context, but rather used within
|
||||||
|
the output Codex produces.
|
||||||
|
|
||||||
|
Example asset files from other skills:
|
||||||
|
- Brand guidelines: logo.png, slides_template.pptx
|
||||||
|
- Frontend builder: hello-world/ directory with HTML/React boilerplate
|
||||||
|
- Typography: custom-font.ttf, font-family.woff2
|
||||||
|
- Data: sample_data.csv, test_dataset.json
|
||||||
|
|
||||||
|
## Common Asset Types
|
||||||
|
|
||||||
|
- Templates: .pptx, .docx, boilerplate directories
|
||||||
|
- Images: .png, .jpg, .svg, .gif
|
||||||
|
- Fonts: .ttf, .otf, .woff, .woff2
|
||||||
|
- Boilerplate code: Project directories, starter files
|
||||||
|
- Icons: .ico, .svg
|
||||||
|
- Data files: .csv, .json, .xml, .yaml
|
||||||
|
|
||||||
|
Note: This is a text placeholder. Actual assets can be any file type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_skill_name(skill_name):
|
||||||
|
"""Normalize a skill name to lowercase hyphen-case."""
|
||||||
|
normalized = skill_name.strip().lower()
|
||||||
|
normalized = re.sub(r"[^a-z0-9]+", "-", normalized)
|
||||||
|
normalized = normalized.strip("-")
|
||||||
|
normalized = re.sub(r"-{2,}", "-", normalized)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def title_case_skill_name(skill_name):
|
||||||
|
"""Convert hyphenated skill name to Title Case for display."""
|
||||||
|
return " ".join(word.capitalize() for word in skill_name.split("-"))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_resources(raw_resources):
|
||||||
|
if not raw_resources:
|
||||||
|
return []
|
||||||
|
resources = [item.strip() for item in raw_resources.split(",") if item.strip()]
|
||||||
|
invalid = sorted({item for item in resources if item not in ALLOWED_RESOURCES})
|
||||||
|
if invalid:
|
||||||
|
allowed = ", ".join(sorted(ALLOWED_RESOURCES))
|
||||||
|
print(f"[ERROR] Unknown resource type(s): {', '.join(invalid)}")
|
||||||
|
print(f" Allowed: {allowed}")
|
||||||
|
sys.exit(1)
|
||||||
|
deduped = []
|
||||||
|
seen = set()
|
||||||
|
for resource in resources:
|
||||||
|
if resource not in seen:
|
||||||
|
deduped.append(resource)
|
||||||
|
seen.add(resource)
|
||||||
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
|
def create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples):
|
||||||
|
for resource in resources:
|
||||||
|
resource_dir = skill_dir / resource
|
||||||
|
resource_dir.mkdir(exist_ok=True)
|
||||||
|
if resource == "scripts":
|
||||||
|
if include_examples:
|
||||||
|
example_script = resource_dir / "example.py"
|
||||||
|
example_script.write_text(EXAMPLE_SCRIPT.format(skill_name=skill_name))
|
||||||
|
example_script.chmod(0o755)
|
||||||
|
print("[OK] Created scripts/example.py")
|
||||||
|
else:
|
||||||
|
print("[OK] Created scripts/")
|
||||||
|
elif resource == "references":
|
||||||
|
if include_examples:
|
||||||
|
example_reference = resource_dir / "api_reference.md"
|
||||||
|
example_reference.write_text(EXAMPLE_REFERENCE.format(skill_title=skill_title))
|
||||||
|
print("[OK] Created references/api_reference.md")
|
||||||
|
else:
|
||||||
|
print("[OK] Created references/")
|
||||||
|
elif resource == "assets":
|
||||||
|
if include_examples:
|
||||||
|
example_asset = resource_dir / "example_asset.txt"
|
||||||
|
example_asset.write_text(EXAMPLE_ASSET)
|
||||||
|
print("[OK] Created assets/example_asset.txt")
|
||||||
|
else:
|
||||||
|
print("[OK] Created assets/")
|
||||||
|
|
||||||
|
|
||||||
|
def init_skill(skill_name, path, resources, include_examples):
|
||||||
|
"""
|
||||||
|
Initialize a new skill directory with template SKILL.md.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill
|
||||||
|
path: Path where the skill directory should be created
|
||||||
|
resources: Resource directories to create
|
||||||
|
include_examples: Whether to create example files in resource directories
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to created skill directory, or None if error
|
||||||
|
"""
|
||||||
|
# Determine skill directory path
|
||||||
|
skill_dir = Path(path).resolve() / skill_name
|
||||||
|
|
||||||
|
# Check if directory already exists
|
||||||
|
if skill_dir.exists():
|
||||||
|
print(f"[ERROR] Skill directory already exists: {skill_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create skill directory
|
||||||
|
try:
|
||||||
|
skill_dir.mkdir(parents=True, exist_ok=False)
|
||||||
|
print(f"[OK] Created skill directory: {skill_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Error creating directory: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create SKILL.md from template
|
||||||
|
skill_title = title_case_skill_name(skill_name)
|
||||||
|
skill_content = SKILL_TEMPLATE.format(skill_name=skill_name, skill_title=skill_title)
|
||||||
|
|
||||||
|
skill_md_path = skill_dir / "SKILL.md"
|
||||||
|
try:
|
||||||
|
skill_md_path.write_text(skill_content)
|
||||||
|
print("[OK] Created SKILL.md")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Error creating SKILL.md: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create resource directories if requested
|
||||||
|
if resources:
|
||||||
|
try:
|
||||||
|
create_resource_dirs(skill_dir, skill_name, skill_title, resources, include_examples)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Error creating resource directories: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Print next steps
|
||||||
|
print(f"\n[OK] Skill '{skill_name}' initialized successfully at {skill_dir}")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print("1. Edit SKILL.md to complete the TODO items and update the description")
|
||||||
|
if resources:
|
||||||
|
if include_examples:
|
||||||
|
print("2. Customize or delete the example files in scripts/, references/, and assets/")
|
||||||
|
else:
|
||||||
|
print("2. Add resources to scripts/, references/, and assets/ as needed")
|
||||||
|
else:
|
||||||
|
print("2. Create resource directories only if needed (scripts/, references/, assets/)")
|
||||||
|
print("3. Run the validator when ready to check the skill structure")
|
||||||
|
|
||||||
|
return skill_dir
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Create a new skill directory with a SKILL.md template.",
|
||||||
|
)
|
||||||
|
parser.add_argument("skill_name", help="Skill name (normalized to hyphen-case)")
|
||||||
|
parser.add_argument("--path", required=True, help="Output directory for the skill")
|
||||||
|
parser.add_argument(
|
||||||
|
"--resources",
|
||||||
|
default="",
|
||||||
|
help="Comma-separated list: scripts,references,assets",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--examples",
|
||||||
|
action="store_true",
|
||||||
|
help="Create example files inside the selected resource directories",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
raw_skill_name = args.skill_name
|
||||||
|
skill_name = normalize_skill_name(raw_skill_name)
|
||||||
|
if not skill_name:
|
||||||
|
print("[ERROR] Skill name must include at least one letter or digit.")
|
||||||
|
sys.exit(1)
|
||||||
|
if len(skill_name) > MAX_SKILL_NAME_LENGTH:
|
||||||
|
print(
|
||||||
|
f"[ERROR] Skill name '{skill_name}' is too long ({len(skill_name)} characters). "
|
||||||
|
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
if skill_name != raw_skill_name:
|
||||||
|
print(f"Note: Normalized skill name from '{raw_skill_name}' to '{skill_name}'.")
|
||||||
|
|
||||||
|
resources = parse_resources(args.resources)
|
||||||
|
if args.examples and not resources:
|
||||||
|
print("[ERROR] --examples requires --resources to be set.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
path = args.path
|
||||||
|
|
||||||
|
print(f"Initializing skill: {skill_name}")
|
||||||
|
print(f" Location: {path}")
|
||||||
|
if resources:
|
||||||
|
print(f" Resources: {', '.join(resources)}")
|
||||||
|
if args.examples:
|
||||||
|
print(" Examples: enabled")
|
||||||
|
else:
|
||||||
|
print(" Resources: none (create as needed)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
result = init_skill(skill_name, path, resources, args.examples)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
154
nanobot/skills/skill-creator/scripts/package_skill.py
Executable file
154
nanobot/skills/skill-creator/scripts/package_skill.py
Executable file
@@ -0,0 +1,154 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Skill Packager - Creates a distributable .skill file of a skill folder
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python package_skill.py <path/to/skill-folder> [output-directory]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python package_skill.py skills/public/my-skill
|
||||||
|
python package_skill.py skills/public/my-skill ./dist
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from quick_validate import validate_skill
|
||||||
|
|
||||||
|
|
||||||
|
def _is_within(path: Path, root: Path) -> bool:
|
||||||
|
try:
|
||||||
|
path.relative_to(root)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_partial_archive(skill_filename: Path) -> None:
|
||||||
|
try:
|
||||||
|
if skill_filename.exists():
|
||||||
|
skill_filename.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def package_skill(skill_path, output_dir=None):
|
||||||
|
"""
|
||||||
|
Package a skill folder into a .skill file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_path: Path to the skill folder
|
||||||
|
output_dir: Optional output directory for the .skill file (defaults to current directory)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the created .skill file, or None if error
|
||||||
|
"""
|
||||||
|
skill_path = Path(skill_path).resolve()
|
||||||
|
|
||||||
|
# Validate skill folder exists
|
||||||
|
if not skill_path.exists():
|
||||||
|
print(f"[ERROR] Skill folder not found: {skill_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not skill_path.is_dir():
|
||||||
|
print(f"[ERROR] Path is not a directory: {skill_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate SKILL.md exists
|
||||||
|
skill_md = skill_path / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
print(f"[ERROR] SKILL.md not found in {skill_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Run validation before packaging
|
||||||
|
print("Validating skill...")
|
||||||
|
valid, message = validate_skill(skill_path)
|
||||||
|
if not valid:
|
||||||
|
print(f"[ERROR] Validation failed: {message}")
|
||||||
|
print(" Please fix the validation errors before packaging.")
|
||||||
|
return None
|
||||||
|
print(f"[OK] {message}\n")
|
||||||
|
|
||||||
|
# Determine output location
|
||||||
|
skill_name = skill_path.name
|
||||||
|
if output_dir:
|
||||||
|
output_path = Path(output_dir).resolve()
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
output_path = Path.cwd()
|
||||||
|
|
||||||
|
skill_filename = output_path / f"{skill_name}.skill"
|
||||||
|
|
||||||
|
EXCLUDED_DIRS = {".git", ".svn", ".hg", "__pycache__", "node_modules"}
|
||||||
|
|
||||||
|
files_to_package = []
|
||||||
|
resolved_archive = skill_filename.resolve()
|
||||||
|
|
||||||
|
for file_path in skill_path.rglob("*"):
|
||||||
|
# Fail closed on symlinks so the packaged contents are explicit and predictable.
|
||||||
|
if file_path.is_symlink():
|
||||||
|
print(f"[ERROR] Symlink not allowed in packaged skill: {file_path}")
|
||||||
|
_cleanup_partial_archive(skill_filename)
|
||||||
|
return None
|
||||||
|
|
||||||
|
rel_parts = file_path.relative_to(skill_path).parts
|
||||||
|
if any(part in EXCLUDED_DIRS for part in rel_parts):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if file_path.is_file():
|
||||||
|
resolved_file = file_path.resolve()
|
||||||
|
if not _is_within(resolved_file, skill_path):
|
||||||
|
print(f"[ERROR] File escapes skill root: {file_path}")
|
||||||
|
_cleanup_partial_archive(skill_filename)
|
||||||
|
return None
|
||||||
|
# If output lives under skill_path, avoid writing archive into itself.
|
||||||
|
if resolved_file == resolved_archive:
|
||||||
|
print(f"[WARN] Skipping output archive: {file_path}")
|
||||||
|
continue
|
||||||
|
files_to_package.append(file_path)
|
||||||
|
|
||||||
|
# Create the .skill file (zip format)
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(skill_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||||
|
for file_path in files_to_package:
|
||||||
|
# Calculate the relative path within the zip.
|
||||||
|
arcname = Path(skill_name) / file_path.relative_to(skill_path)
|
||||||
|
zipf.write(file_path, arcname)
|
||||||
|
print(f" Added: {arcname}")
|
||||||
|
|
||||||
|
print(f"\n[OK] Successfully packaged skill to: {skill_filename}")
|
||||||
|
return skill_filename
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
_cleanup_partial_archive(skill_filename)
|
||||||
|
print(f"[ERROR] Error creating .skill file: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: python package_skill.py <path/to/skill-folder> [output-directory]")
|
||||||
|
print("\nExample:")
|
||||||
|
print(" python package_skill.py skills/public/my-skill")
|
||||||
|
print(" python package_skill.py skills/public/my-skill ./dist")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
skill_path = sys.argv[1]
|
||||||
|
output_dir = sys.argv[2] if len(sys.argv) > 2 else None
|
||||||
|
|
||||||
|
print(f"Packaging skill: {skill_path}")
|
||||||
|
if output_dir:
|
||||||
|
print(f" Output directory: {output_dir}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
result = package_skill(skill_path, output_dir)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
213
nanobot/skills/skill-creator/scripts/quick_validate.py
Normal file
213
nanobot/skills/skill-creator/scripts/quick_validate.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Minimal validator for nanobot skill folders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
yaml = None
|
||||||
|
|
||||||
|
MAX_SKILL_NAME_LENGTH = 64
|
||||||
|
ALLOWED_FRONTMATTER_KEYS = {
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"metadata",
|
||||||
|
"always",
|
||||||
|
"license",
|
||||||
|
"allowed-tools",
|
||||||
|
}
|
||||||
|
ALLOWED_RESOURCE_DIRS = {"scripts", "references", "assets"}
|
||||||
|
PLACEHOLDER_MARKERS = ("[todo", "todo:")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frontmatter(content: str) -> Optional[str]:
|
||||||
|
lines = content.splitlines()
|
||||||
|
if not lines or lines[0].strip() != "---":
|
||||||
|
return None
|
||||||
|
for i in range(1, len(lines)):
|
||||||
|
if lines[i].strip() == "---":
|
||||||
|
return "\n".join(lines[1:i])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_simple_frontmatter(frontmatter_text: str) -> Optional[dict[str, str]]:
|
||||||
|
"""Fallback parser for simple frontmatter when PyYAML is unavailable."""
|
||||||
|
parsed: dict[str, str] = {}
|
||||||
|
current_key: Optional[str] = None
|
||||||
|
multiline_key: Optional[str] = None
|
||||||
|
|
||||||
|
for raw_line in frontmatter_text.splitlines():
|
||||||
|
stripped = raw_line.strip()
|
||||||
|
if not stripped or stripped.startswith("#"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_indented = raw_line[:1].isspace()
|
||||||
|
if is_indented:
|
||||||
|
if current_key is None:
|
||||||
|
return None
|
||||||
|
current_value = parsed[current_key]
|
||||||
|
parsed[current_key] = f"{current_value}\n{stripped}" if current_value else stripped
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ":" not in stripped:
|
||||||
|
return None
|
||||||
|
|
||||||
|
key, value = stripped.split(":", 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
if not key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if value in {"|", ">"}:
|
||||||
|
parsed[key] = ""
|
||||||
|
current_key = key
|
||||||
|
multiline_key = key
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (value.startswith('"') and value.endswith('"')) or (
|
||||||
|
value.startswith("'") and value.endswith("'")
|
||||||
|
):
|
||||||
|
value = value[1:-1]
|
||||||
|
parsed[key] = value
|
||||||
|
current_key = key
|
||||||
|
multiline_key = None
|
||||||
|
|
||||||
|
if multiline_key is not None and multiline_key not in parsed:
|
||||||
|
return None
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _load_frontmatter(frontmatter_text: str) -> tuple[Optional[dict], Optional[str]]:
|
||||||
|
if yaml is not None:
|
||||||
|
try:
|
||||||
|
frontmatter = yaml.safe_load(frontmatter_text)
|
||||||
|
except yaml.YAMLError as exc:
|
||||||
|
return None, f"Invalid YAML in frontmatter: {exc}"
|
||||||
|
if not isinstance(frontmatter, dict):
|
||||||
|
return None, "Frontmatter must be a YAML dictionary"
|
||||||
|
return frontmatter, None
|
||||||
|
|
||||||
|
frontmatter = _parse_simple_frontmatter(frontmatter_text)
|
||||||
|
if frontmatter is None:
|
||||||
|
return None, "Invalid YAML in frontmatter: unsupported syntax without PyYAML installed"
|
||||||
|
return frontmatter, None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_skill_name(name: str, folder_name: str) -> Optional[str]:
|
||||||
|
if not re.fullmatch(r"[a-z0-9]+(?:-[a-z0-9]+)*", name):
|
||||||
|
return (
|
||||||
|
f"Name '{name}' should be hyphen-case "
|
||||||
|
"(lowercase letters, digits, and single hyphens only)"
|
||||||
|
)
|
||||||
|
if len(name) > MAX_SKILL_NAME_LENGTH:
|
||||||
|
return (
|
||||||
|
f"Name is too long ({len(name)} characters). "
|
||||||
|
f"Maximum is {MAX_SKILL_NAME_LENGTH} characters."
|
||||||
|
)
|
||||||
|
if name != folder_name:
|
||||||
|
return f"Skill name '{name}' must match directory name '{folder_name}'"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_description(description: str) -> Optional[str]:
|
||||||
|
trimmed = description.strip()
|
||||||
|
if not trimmed:
|
||||||
|
return "Description cannot be empty"
|
||||||
|
lowered = trimmed.lower()
|
||||||
|
if any(marker in lowered for marker in PLACEHOLDER_MARKERS):
|
||||||
|
return "Description still contains TODO placeholder text"
|
||||||
|
if "<" in trimmed or ">" in trimmed:
|
||||||
|
return "Description cannot contain angle brackets (< or >)"
|
||||||
|
if len(trimmed) > 1024:
|
||||||
|
return f"Description is too long ({len(trimmed)} characters). Maximum is 1024 characters."
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_skill(skill_path):
|
||||||
|
"""Validate a skill folder structure and required frontmatter."""
|
||||||
|
skill_path = Path(skill_path).resolve()
|
||||||
|
|
||||||
|
if not skill_path.exists():
|
||||||
|
return False, f"Skill folder not found: {skill_path}"
|
||||||
|
if not skill_path.is_dir():
|
||||||
|
return False, f"Path is not a directory: {skill_path}"
|
||||||
|
|
||||||
|
skill_md = skill_path / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
return False, "SKILL.md not found"
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = skill_md.read_text(encoding="utf-8")
|
||||||
|
except OSError as exc:
|
||||||
|
return False, f"Could not read SKILL.md: {exc}"
|
||||||
|
|
||||||
|
frontmatter_text = _extract_frontmatter(content)
|
||||||
|
if frontmatter_text is None:
|
||||||
|
return False, "Invalid frontmatter format"
|
||||||
|
|
||||||
|
frontmatter, error = _load_frontmatter(frontmatter_text)
|
||||||
|
if error:
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
unexpected_keys = sorted(set(frontmatter.keys()) - ALLOWED_FRONTMATTER_KEYS)
|
||||||
|
if unexpected_keys:
|
||||||
|
allowed = ", ".join(sorted(ALLOWED_FRONTMATTER_KEYS))
|
||||||
|
unexpected = ", ".join(unexpected_keys)
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Unexpected key(s) in SKILL.md frontmatter: {unexpected}. Allowed properties are: {allowed}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if "name" not in frontmatter:
|
||||||
|
return False, "Missing 'name' in frontmatter"
|
||||||
|
if "description" not in frontmatter:
|
||||||
|
return False, "Missing 'description' in frontmatter"
|
||||||
|
|
||||||
|
name = frontmatter["name"]
|
||||||
|
if not isinstance(name, str):
|
||||||
|
return False, f"Name must be a string, got {type(name).__name__}"
|
||||||
|
name_error = _validate_skill_name(name.strip(), skill_path.name)
|
||||||
|
if name_error:
|
||||||
|
return False, name_error
|
||||||
|
|
||||||
|
description = frontmatter["description"]
|
||||||
|
if not isinstance(description, str):
|
||||||
|
return False, f"Description must be a string, got {type(description).__name__}"
|
||||||
|
description_error = _validate_description(description)
|
||||||
|
if description_error:
|
||||||
|
return False, description_error
|
||||||
|
|
||||||
|
always = frontmatter.get("always")
|
||||||
|
if always is not None and not isinstance(always, bool):
|
||||||
|
return False, f"'always' must be a boolean, got {type(always).__name__}"
|
||||||
|
|
||||||
|
for child in skill_path.iterdir():
|
||||||
|
if child.name == "SKILL.md":
|
||||||
|
continue
|
||||||
|
if child.is_dir() and child.name in ALLOWED_RESOURCE_DIRS:
|
||||||
|
continue
|
||||||
|
if child.is_symlink():
|
||||||
|
continue
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Unexpected file or directory in skill root: {child.name}. "
|
||||||
|
"Only SKILL.md, scripts/, references/, and assets/ are allowed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return True, "Skill is valid!"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) != 2:
|
||||||
|
print("Usage: python quick_validate.py <skill_directory>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
valid, message = validate_skill(sys.argv[1])
|
||||||
|
print(message)
|
||||||
|
sys.exit(0 if valid else 1)
|
||||||
@@ -1,8 +1,12 @@
|
|||||||
"""Utility functions for nanobot."""
|
"""Utility functions for nanobot."""
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
def detect_image_mime(data: bytes) -> str | None:
|
def detect_image_mime(data: bytes) -> str | None:
|
||||||
@@ -68,6 +72,104 @@ def split_message(content: str, max_len: int = 2000) -> list[str]:
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_assistant_message(
|
||||||
|
content: str | None,
|
||||||
|
tool_calls: list[dict[str, Any]] | None = None,
|
||||||
|
reasoning_content: str | None = None,
|
||||||
|
thinking_blocks: list[dict] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a provider-safe assistant message with optional reasoning fields."""
|
||||||
|
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||||
|
if tool_calls:
|
||||||
|
msg["tool_calls"] = tool_calls
|
||||||
|
if reasoning_content is not None:
|
||||||
|
msg["reasoning_content"] = reasoning_content
|
||||||
|
if thinking_blocks:
|
||||||
|
msg["thinking_blocks"] = thinking_blocks
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_prompt_tokens(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Estimate prompt tokens with tiktoken."""
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
parts: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
txt = part.get("text", "")
|
||||||
|
if txt:
|
||||||
|
parts.append(txt)
|
||||||
|
if tools:
|
||||||
|
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||||
|
return len(enc.encode("\n".join(parts)))
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_message_tokens(message: dict[str, Any]) -> int:
|
||||||
|
"""Estimate prompt tokens contributed by one persisted message."""
|
||||||
|
content = message.get("content")
|
||||||
|
parts: list[str] = []
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
text = part.get("text", "")
|
||||||
|
if text:
|
||||||
|
parts.append(text)
|
||||||
|
else:
|
||||||
|
parts.append(json.dumps(part, ensure_ascii=False))
|
||||||
|
elif content is not None:
|
||||||
|
parts.append(json.dumps(content, ensure_ascii=False))
|
||||||
|
|
||||||
|
for key in ("name", "tool_call_id"):
|
||||||
|
value = message.get(key)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
parts.append(value)
|
||||||
|
if message.get("tool_calls"):
|
||||||
|
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||||
|
|
||||||
|
payload = "\n".join(parts)
|
||||||
|
if not payload:
|
||||||
|
return 1
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return max(1, len(enc.encode(payload)))
|
||||||
|
except Exception:
|
||||||
|
return max(1, len(payload) // 4)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_prompt_tokens_chain(
|
||||||
|
provider: Any,
|
||||||
|
model: str | None,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> tuple[int, str]:
|
||||||
|
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
|
||||||
|
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
|
||||||
|
if callable(provider_counter):
|
||||||
|
try:
|
||||||
|
tokens, source = provider_counter(messages, tools, model)
|
||||||
|
if isinstance(tokens, (int, float)) and tokens > 0:
|
||||||
|
return int(tokens), str(source or "provider_counter")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
estimated = estimate_prompt_tokens(messages, tools)
|
||||||
|
if estimated > 0:
|
||||||
|
return int(estimated), "tiktoken"
|
||||||
|
return 0, "none"
|
||||||
|
|
||||||
|
|
||||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||||
from importlib.resources import files as pkg_files
|
from importlib.resources import files as pkg_files
|
||||||
@@ -88,7 +190,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
|||||||
added.append(str(dest.relative_to(workspace)))
|
added.append(str(dest.relative_to(workspace)))
|
||||||
|
|
||||||
for item in tpl.iterdir():
|
for item in tpl.iterdir():
|
||||||
if item.name.endswith(".md"):
|
if item.name.endswith(".md") and not item.name.startswith("."):
|
||||||
_write(item, workspace / item.name)
|
_write(item, workspace / item.name)
|
||||||
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
_write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md")
|
||||||
_write(None, workspace / "memory" / "HISTORY.md")
|
_write(None, workspace / "memory" / "HISTORY.md")
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"typer>=0.20.0,<1.0.0",
|
"typer>=0.20.0,<1.0.0",
|
||||||
"litellm>=1.81.5,<2.0.0",
|
"litellm>=1.82.1,<2.0.0",
|
||||||
"pydantic>=2.12.0,<3.0.0",
|
"pydantic>=2.12.0,<3.0.0",
|
||||||
"pydantic-settings>=2.12.0,<3.0.0",
|
"pydantic-settings>=2.12.0,<3.0.0",
|
||||||
"websockets>=16.0,<17.0",
|
"websockets>=16.0,<17.0",
|
||||||
@@ -45,6 +45,7 @@ dependencies = [
|
|||||||
"chardet>=3.0.2,<6.0.0",
|
"chardet>=3.0.2,<6.0.0",
|
||||||
"openai>=2.8.0",
|
"openai>=2.8.0",
|
||||||
"wecom-aibot-sdk-python>=0.1.2",
|
"wecom-aibot-sdk-python>=0.1.2",
|
||||||
|
"tiktoken>=0.12.0,<1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -267,6 +267,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
|||||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
|
||||||
|
mock_agent_runtime["config"].agents.defaults.memory_window = 100
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "memoryWindow" in result.stdout
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||||
config_file = tmp_path / "instance" / "config.json"
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
config_file.parent.mkdir(parents=True)
|
config_file.parent.mkdir(parents=True)
|
||||||
@@ -328,6 +338,28 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
|||||||
assert config.workspace_path == override
|
assert config.workspace_path == override
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
|
config_file.parent.mkdir(parents=True)
|
||||||
|
config_file.write_text("{}")
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.memory_window = 100
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.cli.commands._make_provider",
|
||||||
|
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||||
|
|
||||||
|
assert isinstance(result.exception, _StopGateway)
|
||||||
|
assert "memoryWindow" in result.stdout
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
|
||||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||||
config_file = tmp_path / "instance" / "config.json"
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
config_file.parent.mkdir(parents=True)
|
config_file.parent.mkdir(parents=True)
|
||||||
@@ -356,3 +388,47 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
|
|||||||
|
|
||||||
assert isinstance(result.exception, _StopGateway)
|
assert isinstance(result.exception, _StopGateway)
|
||||||
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
|
config_file.parent.mkdir(parents=True)
|
||||||
|
config_file.write_text("{}")
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.gateway.port = 18791
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.cli.commands._make_provider",
|
||||||
|
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||||
|
|
||||||
|
assert isinstance(result.exception, _StopGateway)
|
||||||
|
assert "port 18791" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
|
config_file.parent.mkdir(parents=True)
|
||||||
|
config_file.write_text("{}")
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config.gateway.port = 18791
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.cli.commands._make_provider",
|
||||||
|
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||||
|
|
||||||
|
assert isinstance(result.exception, _StopGateway)
|
||||||
|
assert "port 18792" in result.stdout
|
||||||
|
|||||||
88
tests/test_config_migration.py
Normal file
88
tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from nanobot.cli.commands import app
|
||||||
|
from nanobot.config.loader import load_config, save_config
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 1234,
|
||||||
|
"memoryWindow": 42,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
|
||||||
|
assert config.agents.defaults.max_tokens == 1234
|
||||||
|
assert config.agents.defaults.context_window_tokens == 65_536
|
||||||
|
assert config.agents.defaults.should_warn_deprecated_memory_window is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 2222,
|
||||||
|
"memoryWindow": 30,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path)
|
||||||
|
save_config(config, config_path)
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
defaults = saved["agents"]["defaults"]
|
||||||
|
|
||||||
|
assert defaults["maxTokens"] == 2222
|
||||||
|
assert defaults["contextWindowTokens"] == 65_536
|
||||||
|
assert "memoryWindow" not in defaults
|
||||||
|
|
||||||
|
|
||||||
|
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 3333,
|
||||||
|
"memoryWindow": 50,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||||
|
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "contextWindowTokens" in result.stdout
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
defaults = saved["agents"]["defaults"]
|
||||||
|
assert defaults["maxTokens"] == 3333
|
||||||
|
assert defaults["contextWindowTokens"] == 65_536
|
||||||
|
assert "memoryWindow" not in defaults
|
||||||
@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
|
|||||||
assert_messages_content(old_messages, 10, 34)
|
assert_messages_content(old_messages, 10, 34)
|
||||||
|
|
||||||
|
|
||||||
class TestConsolidationDeduplicationGuard:
|
class TestNewCommandArchival:
|
||||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@staticmethod
|
||||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
def _make_loop(tmp_path: Path):
|
||||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||||
loop = AgentLoop(
|
loop = AgentLoop(
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
context_window_tokens=1,
|
||||||
)
|
)
|
||||||
|
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
return loop
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
consolidation_calls = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
nonlocal consolidation_calls
|
|
||||||
consolidation_calls += 1
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert consolidation_calls == 1, (
|
|
||||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
consolidation_calls = 0
|
|
||||||
active = 0
|
|
||||||
max_active = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
nonlocal consolidation_calls, active, max_active
|
|
||||||
consolidation_calls += 1
|
|
||||||
active += 1
|
|
||||||
max_active = max(max_active, active)
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
active -= 1
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
|
||||||
await loop._process_message(new_msg)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert consolidation_calls == 2, (
|
|
||||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
|
||||||
)
|
|
||||||
assert max_active == 1, (
|
|
||||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
|
||||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
|
|
||||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
|
||||||
started.set()
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
|
|
||||||
await started.wait()
|
|
||||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
|
||||||
|
|
||||||
await asyncio.sleep(0.15)
|
|
||||||
assert len(loop._consolidation_tasks) == 0, (
|
|
||||||
"Task reference must be removed after completion"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new waits for in-flight consolidation and archives before clear."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
|
||||||
for i in range(15):
|
|
||||||
session.add_message("user", f"msg{i}")
|
|
||||||
session.add_message("assistant", f"resp{i}")
|
|
||||||
loop.sessions.save(session)
|
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
release = asyncio.Event()
|
|
||||||
archived_count = 0
|
|
||||||
|
|
||||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
|
||||||
nonlocal archived_count
|
|
||||||
if archive_all:
|
|
||||||
archived_count = len(sess.messages)
|
|
||||||
return True
|
|
||||||
started.set()
|
|
||||||
await release.wait()
|
|
||||||
return True
|
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await started.wait()
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
|
||||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
|
||||||
|
|
||||||
await asyncio.sleep(0.02)
|
|
||||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
|
||||||
|
|
||||||
release.set()
|
|
||||||
response = await pending_new
|
|
||||||
assert response is not None
|
|
||||||
assert "new session started" in response.content.lower()
|
|
||||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
|
||||||
|
|
||||||
session_after = loop.sessions.get_or_create("cli:test")
|
|
||||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||||
"""/new must keep session data if archive step reports failure."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
before_count = len(session.messages)
|
before_count = len(session.messages)
|
||||||
|
|
||||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _failing_consolidate(_messages) -> bool:
|
||||||
if archive_all:
|
return False
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
response = await loop._process_message(new_msg)
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "failed" in response.content.lower()
|
assert "failed" in response.content.lower()
|
||||||
session_after = loop.sessions.get_or_create("cli:test")
|
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||||
assert len(session_after.messages) == before_count, (
|
|
||||||
"Session must remain intact when /new archival fails"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""/new should archive only messages not yet consolidated by prior task."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
session.last_consolidated = len(session.messages) - 3
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
started = asyncio.Event()
|
|
||||||
release = asyncio.Event()
|
|
||||||
archived_count = -1
|
archived_count = -1
|
||||||
|
|
||||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _fake_consolidate(messages) -> bool:
|
||||||
nonlocal archived_count
|
nonlocal archived_count
|
||||||
if archive_all:
|
archived_count = len(messages)
|
||||||
archived_count = len(sess.messages)
|
|
||||||
return True
|
|
||||||
|
|
||||||
started.set()
|
|
||||||
await release.wait()
|
|
||||||
sess.last_consolidated = len(sess.messages) - 3
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
|
||||||
await loop._process_message(msg)
|
|
||||||
await started.wait()
|
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
response = await loop._process_message(new_msg)
|
||||||
await asyncio.sleep(0.02)
|
|
||||||
assert not pending_new.done()
|
|
||||||
|
|
||||||
release.set()
|
|
||||||
response = await pending_new
|
|
||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "new session started" in response.content.lower()
|
assert "new session started" in response.content.lower()
|
||||||
assert archived_count == 3, (
|
assert archived_count == 3
|
||||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||||
"""/new clears session and returns confirmation."""
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
|
||||||
)
|
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
session = loop.sessions.get_or_create("cli:test")
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
session.add_message("user", f"msg{i}")
|
session.add_message("user", f"msg{i}")
|
||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _ok_consolidate(_messages) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
response = await loop._process_message(new_msg)
|
response = await loop._process_message(new_msg)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.dingtalk import DingTalkChannel
|
import nanobot.channels.dingtalk as dingtalk_module
|
||||||
|
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
|
||||||
from nanobot.config.schema import DingTalkConfig
|
from nanobot.config.schema import DingTalkConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None:
|
|||||||
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||||
assert call["json"]["openConversationId"] == "conv123"
|
assert call["json"]["openConversationId"] == "conv123"
|
||||||
assert call["json"]["msgKey"] == "sampleMarkdown"
|
assert call["json"]["msgKey"] == "sampleMarkdown"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = DingTalkChannel(
|
||||||
|
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
handler = NanobotDingTalkHandler(channel)
|
||||||
|
|
||||||
|
class _FakeChatbotMessage:
|
||||||
|
text = None
|
||||||
|
extensions = {"content": {"recognition": "voice transcript"}}
|
||||||
|
sender_staff_id = "user1"
|
||||||
|
sender_id = "fallback-user"
|
||||||
|
sender_nick = "Alice"
|
||||||
|
message_type = "audio"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(_data):
|
||||||
|
return _FakeChatbotMessage()
|
||||||
|
|
||||||
|
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
|
||||||
|
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||||
|
|
||||||
|
status, body = await handler.process(
|
||||||
|
SimpleNamespace(
|
||||||
|
data={
|
||||||
|
"conversationType": "2",
|
||||||
|
"conversationId": "conv123",
|
||||||
|
"text": {"content": ""},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.gather(*list(channel._background_tasks))
|
||||||
|
msg = await bus.consume_inbound()
|
||||||
|
|
||||||
|
assert (status, body) == ("OK", "OK")
|
||||||
|
assert msg.content == "voice transcript"
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
assert msg.chat_id == "group:conv123"
|
||||||
|
|||||||
@@ -3,18 +3,24 @@ import asyncio
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
class DummyProvider:
|
class DummyProvider(LLMProvider):
|
||||||
def __init__(self, responses: list[LLMResponse]):
|
def __init__(self, responses: list[LLMResponse]):
|
||||||
|
super().__init__()
|
||||||
self._responses = list(responses)
|
self._responses = list(responses)
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
if self._responses:
|
if self._responses:
|
||||||
return self._responses.pop(0)
|
return self._responses.pop(0)
|
||||||
return LLMResponse(content="", tool_calls=[])
|
return LLMResponse(content="", tool_calls=[])
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "test-model"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_is_idempotent(tmp_path) -> None:
|
async def test_start_is_idempotent(tmp_path) -> None:
|
||||||
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert await service.trigger_now() is None
|
assert await service.trigger_now() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||||
|
provider = DummyProvider([
|
||||||
|
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "run", "tasks": "check open tasks"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
provider=provider,
|
||||||
|
model="openai/gpt-4o-mini",
|
||||||
|
)
|
||||||
|
|
||||||
|
action, tasks = await service._decide("heartbeat content")
|
||||||
|
|
||||||
|
assert action == "run"
|
||||||
|
assert tasks == "check open tasks"
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert delays == [1]
|
||||||
|
|||||||
190
tests/test_loop_consolidation_tokens.py
Normal file
190
tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
import nanobot.agent.memory as memory_module
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||||
|
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||||
|
assert session.last_consolidated == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||||
|
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
return (500, "test")
|
||||||
|
if call_count[0] == 2:
|
||||||
|
return (300, "test")
|
||||||
|
return (80, "test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||||
|
assert session.last_consolidated == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||||
|
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||||
|
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||||
|
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] == 1:
|
||||||
|
return (500, "test")
|
||||||
|
if call_count[0] == 2:
|
||||||
|
return (150, "test")
|
||||||
|
return (80, "test")
|
||||||
|
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||||
|
|
||||||
|
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||||
|
assert session.last_consolidated == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||||
|
order: list[str] = []
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
|
||||||
|
async def track_consolidate(messages):
|
||||||
|
order.append("consolidate")
|
||||||
|
return True
|
||||||
|
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
async def track_llm(*args, **kwargs):
|
||||||
|
order.append("llm")
|
||||||
|
return LLMResponse(content="ok", tool_calls=[])
|
||||||
|
loop.provider.chat_with_retry = track_llm
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||||
|
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||||
|
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||||
|
]
|
||||||
|
loop.sessions.save(session)
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||||
|
|
||||||
|
call_count = [0]
|
||||||
|
def mock_estimate(_session):
|
||||||
|
call_count[0] += 1
|
||||||
|
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||||
|
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop.process_direct("hello", session_key="cli:test")
|
||||||
|
|
||||||
|
assert "consolidate" in order
|
||||||
|
assert "llm" in order
|
||||||
|
assert order.index("consolidate") < order.index("llm")
|
||||||
@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
def _make_messages(message_count: int = 30):
|
||||||
"""Create a mock session with messages."""
|
"""Create a list of mock messages."""
|
||||||
session = MagicMock()
|
return [
|
||||||
session.messages = [
|
|
||||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||||
for i in range(message_count)
|
for i in range(message_count)
|
||||||
]
|
]
|
||||||
session.last_consolidated = 0
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tool_response(history_entry, memory_update):
|
def _make_tool_response(history_entry, memory_update):
|
||||||
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptedProvider(LLMProvider):
|
||||||
|
def __init__(self, responses: list[LLMResponse]):
|
||||||
|
super().__init__()
|
||||||
|
self._responses = list(responses)
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
|
if self._responses:
|
||||||
|
return self._responses.pop(0)
|
||||||
|
return LLMResponse(content="", tool_calls=[])
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "test-model"
|
||||||
|
|
||||||
|
|
||||||
class TestMemoryConsolidationTypeHandling:
|
class TestMemoryConsolidationTypeHandling:
|
||||||
"""Test that consolidation handles various argument types correctly."""
|
"""Test that consolidation handles various argument types correctly."""
|
||||||
|
|
||||||
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
memory_update="# Memory\nUser likes testing.",
|
memory_update="# Memory\nUser likes testing.",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert store.history_file.exists()
|
assert store.history_file.exists()
|
||||||
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert store.history_file.exists()
|
assert store.history_file.exists()
|
||||||
@@ -112,9 +127,10 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert "User discussed testing." in store.history_file.read_text()
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
@@ -127,21 +143,23 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
provider.chat = AsyncMock(
|
provider.chat = AsyncMock(
|
||||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||||
)
|
)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
assert not store.history_file.exists()
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||||
"""Consolidation should be a no-op when messages < keep_count."""
|
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
session = _make_session(message_count=10)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages: list[dict] = []
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
provider.chat.assert_not_called()
|
provider.chat.assert_not_called()
|
||||||
@@ -167,9 +185,10 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
assert "User discussed testing." in store.history_file.read_text()
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
@@ -192,9 +211,10 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
@@ -215,8 +235,33 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
provider.chat = AsyncMock(return_value=response)
|
provider.chat = AsyncMock(return_value=response)
|
||||||
session = _make_session(message_count=60)
|
provider.chat_with_retry = provider.chat
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="503 server error", finish_reason="error"),
|
||||||
|
_make_tool_response(
|
||||||
|
history_entry="[2026-01-01] User discussed testing.",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
),
|
||||||
|
])
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert delays == [1]
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
|
||||||
|
|
||||||
class TestMessageToolSuppressLogic:
|
class TestMessageToolSuppressLogic:
|
||||||
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
LLMResponse(content="", tool_calls=[tool_call]),
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
LLMResponse(content="Done", tool_calls=[]),
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
LLMResponse(content="", tool_calls=[tool_call]),
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||||
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
),
|
),
|
||||||
LLMResponse(content="Done", tool_calls=[]),
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
])
|
])
|
||||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
loop.tools.execute = AsyncMock(return_value="ok")
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
|||||||
92
tests/test_provider_retry.py
Normal file
92
tests/test_provider_retry.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptedProvider(LLMProvider):
|
||||||
|
def __init__(self, responses):
|
||||||
|
super().__init__()
|
||||||
|
self._responses = list(responses)
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
|
response = self._responses.pop(0)
|
||||||
|
if isinstance(response, BaseException):
|
||||||
|
raise response
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "test-model"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="429 rate limit", finish_reason="error"),
|
||||||
|
LLMResponse(content="ok"),
|
||||||
|
])
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
assert response.finish_reason == "stop"
|
||||||
|
assert response.content == "ok"
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert delays == [1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||||
|
])
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
assert response.content == "401 unauthorized"
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert delays == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="429 rate limit a", finish_reason="error"),
|
||||||
|
LLMResponse(content="429 rate limit b", finish_reason="error"),
|
||||||
|
LLMResponse(content="429 rate limit c", finish_reason="error"),
|
||||||
|
LLMResponse(content="503 final server error", finish_reason="error"),
|
||||||
|
])
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
assert response.content == "503 final server error"
|
||||||
|
assert provider.calls == 4
|
||||||
|
assert delays == [1, 2, 4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||||
|
provider = ScriptedProvider([asyncio.CancelledError()])
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
127
tests/test_skill_creator_scripts.py
Normal file
127
tests/test_skill_creator_scripts.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import importlib
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
|
||||||
|
if str(SCRIPT_DIR) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SCRIPT_DIR))
|
||||||
|
|
||||||
|
init_skill = importlib.import_module("init_skill")
|
||||||
|
package_skill = importlib.import_module("package_skill")
|
||||||
|
quick_validate = importlib.import_module("quick_validate")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
|
||||||
|
skill_dir = init_skill.init_skill(
|
||||||
|
"demo-skill",
|
||||||
|
tmp_path,
|
||||||
|
["scripts", "references", "assets"],
|
||||||
|
include_examples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert skill_dir == tmp_path / "demo-skill"
|
||||||
|
assert (skill_dir / "SKILL.md").exists()
|
||||||
|
assert (skill_dir / "scripts" / "example.py").exists()
|
||||||
|
assert (skill_dir / "references" / "api_reference.md").exists()
|
||||||
|
assert (skill_dir / "assets" / "example_asset.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_skill_accepts_existing_skill_creator() -> None:
|
||||||
|
valid, message = quick_validate.validate_skill(
|
||||||
|
Path("nanobot/skills/skill-creator").resolve()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert valid, message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
|
||||||
|
skill_dir = tmp_path / "placeholder-skill"
|
||||||
|
skill_dir.mkdir()
|
||||||
|
(skill_dir / "SKILL.md").write_text(
|
||||||
|
"---\n"
|
||||||
|
"name: placeholder-skill\n"
|
||||||
|
'description: "[TODO: fill me in]"\n'
|
||||||
|
"---\n"
|
||||||
|
"# Placeholder\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
valid, message = quick_validate.validate_skill(skill_dir)
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "TODO placeholder" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
|
||||||
|
skill_dir = tmp_path / "bad-root-skill"
|
||||||
|
skill_dir.mkdir()
|
||||||
|
(skill_dir / "SKILL.md").write_text(
|
||||||
|
"---\n"
|
||||||
|
"name: bad-root-skill\n"
|
||||||
|
"description: Valid description\n"
|
||||||
|
"---\n"
|
||||||
|
"# Skill\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
|
||||||
|
|
||||||
|
valid, message = quick_validate.validate_skill(skill_dir)
|
||||||
|
|
||||||
|
assert not valid
|
||||||
|
assert "Unexpected file or directory in skill root" in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_package_skill_creates_archive(tmp_path: Path) -> None:
|
||||||
|
skill_dir = tmp_path / "package-me"
|
||||||
|
skill_dir.mkdir()
|
||||||
|
(skill_dir / "SKILL.md").write_text(
|
||||||
|
"---\n"
|
||||||
|
"name: package-me\n"
|
||||||
|
"description: Package this skill.\n"
|
||||||
|
"---\n"
|
||||||
|
"# Skill\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
scripts_dir = skill_dir / "scripts"
|
||||||
|
scripts_dir.mkdir()
|
||||||
|
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
|
||||||
|
|
||||||
|
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||||
|
|
||||||
|
assert archive_path == (tmp_path / "dist" / "package-me.skill")
|
||||||
|
assert archive_path.exists()
|
||||||
|
with zipfile.ZipFile(archive_path, "r") as archive:
|
||||||
|
names = set(archive.namelist())
|
||||||
|
assert "package-me/SKILL.md" in names
|
||||||
|
assert "package-me/scripts/helper.py" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
|
||||||
|
skill_dir = tmp_path / "symlink-skill"
|
||||||
|
skill_dir.mkdir()
|
||||||
|
(skill_dir / "SKILL.md").write_text(
|
||||||
|
"---\n"
|
||||||
|
"name: symlink-skill\n"
|
||||||
|
"description: Reject symlinks during packaging.\n"
|
||||||
|
"---\n"
|
||||||
|
"# Skill\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
scripts_dir = skill_dir / "scripts"
|
||||||
|
scripts_dir.mkdir()
|
||||||
|
target = tmp_path / "outside.txt"
|
||||||
|
target.write_text("secret\n", encoding="utf-8")
|
||||||
|
link = scripts_dir / "outside.txt"
|
||||||
|
|
||||||
|
try:
|
||||||
|
link.symlink_to(target)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
return
|
||||||
|
|
||||||
|
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
|
||||||
|
|
||||||
|
assert archive_path is None
|
||||||
|
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()
|
||||||
90
tests/test_slack_channel.py
Normal file
90
tests/test_slack_channel.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.slack import SlackChannel
|
||||||
|
from nanobot.config.schema import SlackConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAsyncWebClient:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.chat_post_calls: list[dict[str, object | None]] = []
|
||||||
|
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||||
|
|
||||||
|
async def chat_postMessage(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel: str,
|
||||||
|
text: str,
|
||||||
|
thread_ts: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.chat_post_calls.append(
|
||||||
|
{
|
||||||
|
"channel": channel,
|
||||||
|
"text": text,
|
||||||
|
"thread_ts": thread_ts,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def files_upload_v2(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel: str,
|
||||||
|
file: str,
|
||||||
|
thread_ts: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.file_upload_calls.append(
|
||||||
|
{
|
||||||
|
"channel": channel,
|
||||||
|
"file": file,
|
||||||
|
"thread_ts": thread_ts,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="C123",
|
||||||
|
content="hello",
|
||||||
|
media=["/tmp/demo.txt"],
|
||||||
|
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(fake_web.chat_post_calls) == 1
|
||||||
|
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||||
|
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
|
||||||
|
assert len(fake_web.file_upload_calls) == 1
|
||||||
|
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_omits_thread_for_dm_messages() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="D123",
|
||||||
|
content="hello",
|
||||||
|
media=["/tmp/demo.txt"],
|
||||||
|
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(fake_web.chat_post_calls) == 1
|
||||||
|
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
|
||||||
|
assert fake_web.chat_post_calls[0]["thread_ts"] is None
|
||||||
|
assert len(fake_web.file_upload_calls) == 1
|
||||||
|
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||||
@@ -165,3 +165,46 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
|
||||||
|
captured_second_call: list[dict] = []
|
||||||
|
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def scripted_chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="thinking",
|
||||||
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||||
|
reasoning_content="hidden reasoning",
|
||||||
|
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||||
|
)
|
||||||
|
captured_second_call[:] = messages
|
||||||
|
return LLMResponse(content="done", tool_calls=[])
|
||||||
|
provider.chat_with_retry = scripted_chat_with_retry
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||||
|
|
||||||
|
async def fake_execute(self, name, arguments):
|
||||||
|
return "tool result"
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||||
|
|
||||||
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
|
assistant_messages = [
|
||||||
|
msg for msg in captured_second_call
|
||||||
|
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||||
|
]
|
||||||
|
assert len(assistant_messages) == 1
|
||||||
|
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||||
|
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||||
|
|||||||
@@ -27,9 +27,11 @@ class _FakeUpdater:
|
|||||||
class _FakeBot:
|
class _FakeBot:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.sent_messages: list[dict] = []
|
self.sent_messages: list[dict] = []
|
||||||
|
self.get_me_calls = 0
|
||||||
|
|
||||||
async def get_me(self):
|
async def get_me(self):
|
||||||
return SimpleNamespace(username="nanobot_test")
|
self.get_me_calls += 1
|
||||||
|
return SimpleNamespace(id=999, username="nanobot_test")
|
||||||
|
|
||||||
async def set_my_commands(self, commands) -> None:
|
async def set_my_commands(self, commands) -> None:
|
||||||
self.commands = commands
|
self.commands = commands
|
||||||
@@ -37,6 +39,9 @@ class _FakeBot:
|
|||||||
async def send_message(self, **kwargs) -> None:
|
async def send_message(self, **kwargs) -> None:
|
||||||
self.sent_messages.append(kwargs)
|
self.sent_messages.append(kwargs)
|
||||||
|
|
||||||
|
async def send_chat_action(self, **kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _FakeApp:
|
class _FakeApp:
|
||||||
def __init__(self, on_start_polling) -> None:
|
def __init__(self, on_start_polling) -> None:
|
||||||
@@ -87,6 +92,35 @@ class _FakeBuilder:
|
|||||||
return self.app
|
return self.app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_telegram_update(
|
||||||
|
*,
|
||||||
|
chat_type: str = "group",
|
||||||
|
text: str | None = None,
|
||||||
|
caption: str | None = None,
|
||||||
|
entities=None,
|
||||||
|
caption_entities=None,
|
||||||
|
reply_to_message=None,
|
||||||
|
):
|
||||||
|
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
|
||||||
|
message = SimpleNamespace(
|
||||||
|
chat=SimpleNamespace(type=chat_type, is_forum=False),
|
||||||
|
chat_id=-100123,
|
||||||
|
text=text,
|
||||||
|
caption=caption,
|
||||||
|
entities=entities or [],
|
||||||
|
caption_entities=caption_entities or [],
|
||||||
|
reply_to_message=reply_to_message,
|
||||||
|
photo=None,
|
||||||
|
voice=None,
|
||||||
|
audio=None,
|
||||||
|
document=None,
|
||||||
|
media_group_id=None,
|
||||||
|
message_thread_id=None,
|
||||||
|
message_id=1,
|
||||||
|
)
|
||||||
|
return SimpleNamespace(message=message, effective_user=user)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||||
config = TelegramConfig(
|
config = TelegramConfig(
|
||||||
@@ -131,6 +165,10 @@ def test_get_extension_falls_back_to_original_filename() -> None:
|
|||||||
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_telegram_group_policy_defaults_to_mention() -> None:
|
||||||
|
assert TelegramConfig().group_policy == "mention"
|
||||||
|
|
||||||
|
|
||||||
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
|
||||||
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
|
||||||
|
|
||||||
@@ -182,3 +220,119 @@ async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
|||||||
|
|
||||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
handled = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
|
||||||
|
|
||||||
|
assert handled == []
|
||||||
|
assert channel._app.bot.get_me_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
handled = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||||
|
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
|
||||||
|
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
|
||||||
|
|
||||||
|
assert len(handled) == 2
|
||||||
|
assert channel._app.bot.get_me_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_policy_mention_accepts_caption_mention() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
handled = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
mention = SimpleNamespace(type="mention", offset=0, length=13)
|
||||||
|
await channel._on_message(
|
||||||
|
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert handled[0]["content"] == "@nanobot_test photo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
handled = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||||
|
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_policy_open_accepts_plain_group_message() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
|
||||||
|
handled = []
|
||||||
|
|
||||||
|
async def capture_handle(**kwargs) -> None:
|
||||||
|
handled.append(kwargs)
|
||||||
|
|
||||||
|
channel._handle_message = capture_handle
|
||||||
|
channel._start_typing = lambda _chat_id: None
|
||||||
|
|
||||||
|
await channel._on_message(_make_telegram_update(text="hello group"), None)
|
||||||
|
|
||||||
|
assert len(handled) == 1
|
||||||
|
assert channel._app.bot.get_me_calls == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user