Merge branch 'main' into pr-1083
This commit is contained in:
@@ -16,10 +16,13 @@
|
|||||||
|
|
||||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||||
|
|
||||||
📏 Real-time line count: **3,897 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
📏 Real-time line count: **3,966 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
- **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
|
||||||
|
- **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
|
||||||
|
- **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
|
||||||
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
- **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
|
||||||
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
- **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
|
||||||
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
- **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
nanobot - A lightweight AI agent framework
|
nanobot - A lightweight AI agent framework
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.4"
|
__version__ = "0.1.4.post2"
|
||||||
__logo__ = "🐈"
|
__logo__ = "🐈"
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import platform
|
import platform
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -72,10 +74,6 @@ Skills with available="false" need dependencies installed first - you can try in
|
|||||||
|
|
||||||
def _get_identity(self) -> str:
|
def _get_identity(self) -> str:
|
||||||
"""Get the core identity section."""
|
"""Get the core identity section."""
|
||||||
from datetime import datetime
|
|
||||||
import time as _time
|
|
||||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
|
||||||
tz = _time.strftime("%Z") or "UTC"
|
|
||||||
workspace_path = str(self.workspace.expanduser().resolve())
|
workspace_path = str(self.workspace.expanduser().resolve())
|
||||||
system = platform.system()
|
system = platform.system()
|
||||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||||
@@ -84,9 +82,6 @@ Skills with available="false" need dependencies installed first - you can try in
|
|||||||
|
|
||||||
You are nanobot, a helpful AI assistant.
|
You are nanobot, a helpful AI assistant.
|
||||||
|
|
||||||
## Current Time
|
|
||||||
{now} ({tz})
|
|
||||||
|
|
||||||
## Runtime
|
## Runtime
|
||||||
{runtime}
|
{runtime}
|
||||||
|
|
||||||
@@ -108,6 +103,23 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
## Memory
|
## Memory
|
||||||
- Remember important facts: write to {workspace_path}/memory/MEMORY.md
|
- Remember important facts: write to {workspace_path}/memory/MEMORY.md
|
||||||
- Recall past events: grep {workspace_path}/memory/HISTORY.md"""
|
- Recall past events: grep {workspace_path}/memory/HISTORY.md"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _inject_runtime_context(
|
||||||
|
user_content: str | list[dict[str, Any]],
|
||||||
|
channel: str | None,
|
||||||
|
chat_id: str | None,
|
||||||
|
) -> str | list[dict[str, Any]]:
|
||||||
|
"""Append dynamic runtime context to the tail of the user message."""
|
||||||
|
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||||
|
tz = time.strftime("%Z") or "UTC"
|
||||||
|
lines = [f"Current Time: {now} ({tz})"]
|
||||||
|
if channel and chat_id:
|
||||||
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
|
block = "[Runtime Context]\n" + "\n".join(lines)
|
||||||
|
if isinstance(user_content, str):
|
||||||
|
return f"{user_content}\n\n{block}"
|
||||||
|
return [*user_content, {"type": "text", "text": block}]
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
@@ -148,8 +160,6 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
|
|
||||||
# System prompt
|
# System prompt
|
||||||
system_prompt = self.build_system_prompt(skill_names)
|
system_prompt = self.build_system_prompt(skill_names)
|
||||||
if channel and chat_id:
|
|
||||||
system_prompt += f"\n\n## Current Session\nChannel: {channel}\nChat ID: {chat_id}"
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
# History
|
# History
|
||||||
@@ -157,6 +167,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
|||||||
|
|
||||||
# Current message (with optional image attachments)
|
# Current message (with optional image attachments)
|
||||||
user_content = self._build_user_content(current_message, media)
|
user_content = self._build_user_content(current_message, media)
|
||||||
|
user_content = self._inject_runtime_context(user_content, channel, chat_id)
|
||||||
messages.append({"role": "user", "content": user_content})
|
messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
@@ -125,6 +125,13 @@ class MemoryStore:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
args = response.tool_calls[0].arguments
|
args = response.tool_calls[0].arguments
|
||||||
|
# Some providers return arguments as a JSON string instead of dict
|
||||||
|
if isinstance(args, str):
|
||||||
|
args = json.loads(args)
|
||||||
|
if not isinstance(args, dict):
|
||||||
|
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||||
|
return False
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if entry := args.get("history_entry"):
|
||||||
if not isinstance(entry, str):
|
if not isinstance(entry, str):
|
||||||
entry = json.dumps(entry, ensure_ascii=False)
|
entry = json.dumps(entry, ensure_ascii=False)
|
||||||
|
|||||||
@@ -69,20 +69,18 @@ async def connect_mcp_servers(
|
|||||||
read, write = await stack.enter_async_context(stdio_client(params))
|
read, write = await stack.enter_async_context(stdio_client(params))
|
||||||
elif cfg.url:
|
elif cfg.url:
|
||||||
from mcp.client.streamable_http import streamable_http_client
|
from mcp.client.streamable_http import streamable_http_client
|
||||||
if cfg.headers:
|
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||||
http_client = await stack.enter_async_context(
|
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||||
httpx.AsyncClient(
|
http_client = await stack.enter_async_context(
|
||||||
headers=cfg.headers,
|
httpx.AsyncClient(
|
||||||
follow_redirects=True
|
headers=cfg.headers or None,
|
||||||
)
|
follow_redirects=True,
|
||||||
)
|
timeout=None,
|
||||||
read, write, _ = await stack.enter_async_context(
|
|
||||||
streamable_http_client(cfg.url, http_client=http_client)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
read, write, _ = await stack.enter_async_context(
|
|
||||||
streamable_http_client(cfg.url)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
read, write, _ = await stack.enter_async_context(
|
||||||
|
streamable_http_client(cfg.url, http_client=http_client)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -36,19 +36,7 @@ class ToolRegistry:
|
|||||||
return [tool.to_schema() for tool in self._tools.values()]
|
return [tool.to_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||||
"""
|
"""Execute a tool by name with given parameters."""
|
||||||
Execute a tool by name with given parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Tool name.
|
|
||||||
params: Tool parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tool execution result as string.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If tool not found.
|
|
||||||
"""
|
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
|
||||||
tool = self._tools.get(name)
|
tool = self._tools.get(name)
|
||||||
|
|||||||
@@ -9,12 +9,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class SpawnTool(Tool):
|
class SpawnTool(Tool):
|
||||||
"""
|
"""Tool to spawn a subagent for background task execution."""
|
||||||
Tool to spawn a subagent for background task execution.
|
|
||||||
|
|
||||||
The subagent runs asynchronously and announces its result back
|
|
||||||
to the main agent when complete.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, manager: "SubagentManager"):
|
def __init__(self, manager: "SubagentManager"):
|
||||||
self._manager = manager
|
self._manager = manager
|
||||||
|
|||||||
@@ -58,12 +58,21 @@ class WebSearchTool(Tool):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
def __init__(self, api_key: str | None = None, max_results: int = 5):
|
||||||
self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
|
self._init_api_key = api_key
|
||||||
self.max_results = max_results
|
self.max_results = max_results
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_key(self) -> str:
|
||||||
|
"""Resolve API key at call time so env/config changes are picked up."""
|
||||||
|
return self._init_api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
return "Error: BRAVE_API_KEY not configured"
|
return (
|
||||||
|
"Error: Brave Search API key not configured. "
|
||||||
|
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
|
||||||
|
"(or export BRAVE_API_KEY), then restart the gateway."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
n = min(max(count or self.max_results, 1), 10)
|
n = min(max(count or self.max_results, 1), 10)
|
||||||
@@ -71,7 +80,7 @@ class WebSearchTool(Tool):
|
|||||||
r = await client.get(
|
r = await client.get(
|
||||||
"https://api.search.brave.com/res/v1/web/search",
|
"https://api.search.brave.com/res/v1/web/search",
|
||||||
params={"q": query, "count": n},
|
params={"q": query, "count": n},
|
||||||
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
headers={"Accept": "application/json", "X-Subscription-Token": api_key},
|
||||||
timeout=10.0
|
timeout=10.0
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|||||||
@@ -108,11 +108,6 @@ class EmailChannel(BaseChannel):
|
|||||||
logger.warning("Skip email send: consent_granted is false")
|
logger.warning("Skip email send: consent_granted is false")
|
||||||
return
|
return
|
||||||
|
|
||||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
|
||||||
if not self.config.auto_reply_enabled and not force_send:
|
|
||||||
logger.info("Skip automatic email reply: auto_reply_enabled is false")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.config.smtp_host:
|
if not self.config.smtp_host:
|
||||||
logger.warning("Email channel SMTP host not configured")
|
logger.warning("Email channel SMTP host not configured")
|
||||||
return
|
return
|
||||||
@@ -122,6 +117,15 @@ class EmailChannel(BaseChannel):
|
|||||||
logger.warning("Email channel missing recipient address")
|
logger.warning("Email channel missing recipient address")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Determine if this is a reply (recipient has sent us an email before)
|
||||||
|
is_reply = to_addr in self._last_subject_by_chat
|
||||||
|
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||||
|
|
||||||
|
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||||
|
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||||
|
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
||||||
|
return
|
||||||
|
|
||||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||||
subject = self._reply_subject(base_subject)
|
subject = self._reply_subject(base_subject)
|
||||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||||
|
|||||||
@@ -180,21 +180,25 @@ def _extract_element_content(element: dict) -> list[str]:
|
|||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
def _extract_post_text(content_json: dict) -> str:
|
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||||
"""Extract plain text from Feishu post (rich text) message content.
|
"""Extract text and image keys from Feishu post (rich text) message content.
|
||||||
|
|
||||||
Supports two formats:
|
Supports two formats:
|
||||||
1. Direct format: {"title": "...", "content": [...]}
|
1. Direct format: {"title": "...", "content": [...]}
|
||||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(text, image_keys) - extracted text and list of image keys
|
||||||
"""
|
"""
|
||||||
def extract_from_lang(lang_content: dict) -> str | None:
|
def extract_from_lang(lang_content: dict) -> tuple[str | None, list[str]]:
|
||||||
if not isinstance(lang_content, dict):
|
if not isinstance(lang_content, dict):
|
||||||
return None
|
return None, []
|
||||||
title = lang_content.get("title", "")
|
title = lang_content.get("title", "")
|
||||||
content_blocks = lang_content.get("content", [])
|
content_blocks = lang_content.get("content", [])
|
||||||
if not isinstance(content_blocks, list):
|
if not isinstance(content_blocks, list):
|
||||||
return None
|
return None, []
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
image_keys = []
|
||||||
if title:
|
if title:
|
||||||
text_parts.append(title)
|
text_parts.append(title)
|
||||||
for block in content_blocks:
|
for block in content_blocks:
|
||||||
@@ -209,22 +213,36 @@ def _extract_post_text(content_json: dict) -> str:
|
|||||||
text_parts.append(element.get("text", ""))
|
text_parts.append(element.get("text", ""))
|
||||||
elif tag == "at":
|
elif tag == "at":
|
||||||
text_parts.append(f"@{element.get('user_name', 'user')}")
|
text_parts.append(f"@{element.get('user_name', 'user')}")
|
||||||
return " ".join(text_parts).strip() if text_parts else None
|
elif tag == "img":
|
||||||
|
img_key = element.get("image_key")
|
||||||
|
if img_key:
|
||||||
|
image_keys.append(img_key)
|
||||||
|
text = " ".join(text_parts).strip() if text_parts else None
|
||||||
|
return text, image_keys
|
||||||
|
|
||||||
# Try direct format first
|
# Try direct format first
|
||||||
if "content" in content_json:
|
if "content" in content_json:
|
||||||
result = extract_from_lang(content_json)
|
text, images = extract_from_lang(content_json)
|
||||||
if result:
|
if text or images:
|
||||||
return result
|
return text or "", images
|
||||||
|
|
||||||
# Try localized format
|
# Try localized format
|
||||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||||
lang_content = content_json.get(lang_key)
|
lang_content = content_json.get(lang_key)
|
||||||
result = extract_from_lang(lang_content)
|
text, images = extract_from_lang(lang_content)
|
||||||
if result:
|
if text or images:
|
||||||
return result
|
return text or "", images
|
||||||
|
|
||||||
return ""
|
return "", []
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_post_text(content_json: dict) -> str:
|
||||||
|
"""Extract plain text from Feishu post (rich text) message content.
|
||||||
|
|
||||||
|
Legacy wrapper for _extract_post_content, returns only text.
|
||||||
|
"""
|
||||||
|
text, _ = _extract_post_content(content_json)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class FeishuChannel(BaseChannel):
|
class FeishuChannel(BaseChannel):
|
||||||
@@ -691,9 +709,17 @@ class FeishuChannel(BaseChannel):
|
|||||||
content_parts.append(text)
|
content_parts.append(text)
|
||||||
|
|
||||||
elif msg_type == "post":
|
elif msg_type == "post":
|
||||||
text = _extract_post_text(content_json)
|
text, image_keys = _extract_post_content(content_json)
|
||||||
if text:
|
if text:
|
||||||
content_parts.append(text)
|
content_parts.append(text)
|
||||||
|
# Download images embedded in post
|
||||||
|
for img_key in image_keys:
|
||||||
|
file_path, content_text = await self._download_and_save_media(
|
||||||
|
"image", {"image_key": img_key}, message_id
|
||||||
|
)
|
||||||
|
if file_path:
|
||||||
|
media_paths.append(file_path)
|
||||||
|
content_parts.append(content_text)
|
||||||
|
|
||||||
elif msg_type in ("image", "audio", "file", "media"):
|
elif msg_type in ("image", "audio", "file", "media"):
|
||||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||||
|
|||||||
@@ -229,6 +229,11 @@ class SlackChannel(BaseChannel):
|
|||||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||||
|
|
||||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||||
|
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||||
|
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||||
|
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||||
|
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||||
|
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _to_mrkdwn(cls, text: str) -> str:
|
def _to_mrkdwn(cls, text: str) -> str:
|
||||||
@@ -236,7 +241,26 @@ class SlackChannel(BaseChannel):
|
|||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||||
return slackify_markdown(text)
|
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||||
|
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||||
|
code_blocks: list[str] = []
|
||||||
|
|
||||||
|
def _save_code(m: re.Match) -> str:
|
||||||
|
code_blocks.append(m.group(0))
|
||||||
|
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||||
|
|
||||||
|
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||||
|
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||||
|
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||||
|
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||||
|
|
||||||
|
for i, block in enumerate(code_blocks):
|
||||||
|
text = text.replace(f"\x00CB{i}\x00", block)
|
||||||
|
return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_table(match: re.Match) -> str:
|
def _convert_table(match: re.Match) -> str:
|
||||||
|
|||||||
@@ -360,19 +360,19 @@ def gateway(
|
|||||||
return "cli", "direct"
|
return "cli", "direct"
|
||||||
|
|
||||||
# Create heartbeat service
|
# Create heartbeat service
|
||||||
async def on_heartbeat(prompt: str) -> str:
|
async def on_heartbeat_execute(tasks: str) -> str:
|
||||||
"""Execute heartbeat through the agent."""
|
"""Phase 2: execute heartbeat tasks through the full agent loop."""
|
||||||
channel, chat_id = _pick_heartbeat_target()
|
channel, chat_id = _pick_heartbeat_target()
|
||||||
|
|
||||||
async def _silent(*_args, **_kwargs):
|
async def _silent(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return await agent.process_direct(
|
return await agent.process_direct(
|
||||||
prompt,
|
tasks,
|
||||||
session_key="heartbeat",
|
session_key="heartbeat",
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
on_progress=_silent, # suppress: heartbeat should not push progress to external channels
|
on_progress=_silent,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_heartbeat_notify(response: str) -> None:
|
async def on_heartbeat_notify(response: str) -> None:
|
||||||
@@ -383,12 +383,15 @@ def gateway(
|
|||||||
return # No external channel available to deliver to
|
return # No external channel available to deliver to
|
||||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||||
|
|
||||||
|
hb_cfg = config.gateway.heartbeat
|
||||||
heartbeat = HeartbeatService(
|
heartbeat = HeartbeatService(
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
on_heartbeat=on_heartbeat,
|
provider=provider,
|
||||||
|
model=agent.model,
|
||||||
|
on_execute=on_heartbeat_execute,
|
||||||
on_notify=on_heartbeat_notify,
|
on_notify=on_heartbeat_notify,
|
||||||
interval_s=30 * 60, # 30 minutes
|
interval_s=hb_cfg.interval_s,
|
||||||
enabled=True
|
enabled=hb_cfg.enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
if channels.enabled_channels:
|
if channels.enabled_channels:
|
||||||
@@ -400,7 +403,7 @@ def gateway(
|
|||||||
if cron_status["jobs"] > 0:
|
if cron_status["jobs"] > 0:
|
||||||
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
||||||
|
|
||||||
console.print(f"[green]✓[/green] Heartbeat: every 30m")
|
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||||
|
|
||||||
async def run():
|
async def run():
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -228,11 +228,19 @@ class ProvidersConfig(Base):
|
|||||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||||
|
|
||||||
|
|
||||||
|
class HeartbeatConfig(Base):
|
||||||
|
"""Heartbeat service configuration."""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
interval_s: int = 30 * 60 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
class GatewayConfig(Base):
|
class GatewayConfig(Base):
|
||||||
"""Gateway/server configuration."""
|
"""Gateway/server configuration."""
|
||||||
|
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 18790
|
port: int = 18790
|
||||||
|
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||||
|
|
||||||
|
|
||||||
class WebSearchConfig(Base):
|
class WebSearchConfig(Base):
|
||||||
|
|||||||
@@ -1,80 +1,110 @@
|
|||||||
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# Default interval: 30 minutes
|
if TYPE_CHECKING:
|
||||||
DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
# Token the agent replies with when there is nothing to report
|
_HEARTBEAT_TOOL = [
|
||||||
HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK"
|
{
|
||||||
|
"type": "function",
|
||||||
# The prompt sent to agent during heartbeat
|
"function": {
|
||||||
HEARTBEAT_PROMPT = (
|
"name": "heartbeat",
|
||||||
"Read HEARTBEAT.md in your workspace and follow any instructions listed there. "
|
"description": "Report heartbeat decision after reviewing tasks.",
|
||||||
f"If nothing needs attention, reply with exactly: {HEARTBEAT_OK_TOKEN}"
|
"parameters": {
|
||||||
)
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": {
|
||||||
def _is_heartbeat_empty(content: str | None) -> bool:
|
"type": "string",
|
||||||
"""Check if HEARTBEAT.md has no actionable content."""
|
"enum": ["skip", "run"],
|
||||||
if not content:
|
"description": "skip = nothing to do, run = has active tasks",
|
||||||
return True
|
},
|
||||||
|
"tasks": {
|
||||||
# Lines to skip: empty, headers, HTML comments, empty checkboxes
|
"type": "string",
|
||||||
skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"}
|
"description": "Natural-language summary of active tasks (required for run)",
|
||||||
|
},
|
||||||
for line in content.split("\n"):
|
},
|
||||||
line = line.strip()
|
"required": ["action"],
|
||||||
if not line or line.startswith("#") or line.startswith("<!--") or line in skip_patterns:
|
},
|
||||||
continue
|
},
|
||||||
return False # Found actionable content
|
}
|
||||||
|
]
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class HeartbeatService:
|
class HeartbeatService:
|
||||||
"""
|
"""
|
||||||
Periodic heartbeat service that wakes the agent to check for tasks.
|
Periodic heartbeat service that wakes the agent to check for tasks.
|
||||||
|
|
||||||
The agent reads HEARTBEAT.md from the workspace and executes any tasks
|
Phase 1 (decision): reads HEARTBEAT.md and asks the LLM — via a virtual
|
||||||
listed there. If it has something to report, the response is forwarded
|
tool call — whether there are active tasks. This avoids free-text parsing
|
||||||
to the user via on_notify. If nothing needs attention, the agent replies
|
and the unreliable HEARTBEAT_OK token.
|
||||||
HEARTBEAT_OK and the response is silently dropped.
|
|
||||||
|
Phase 2 (execution): only triggered when Phase 1 returns ``run``. The
|
||||||
|
``on_execute`` callback runs the task through the full agent loop and
|
||||||
|
returns the result to deliver.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
on_heartbeat: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||||
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||||
interval_s: int = DEFAULT_HEARTBEAT_INTERVAL_S,
|
interval_s: int = 30 * 60,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
):
|
):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.on_heartbeat = on_heartbeat
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.on_execute = on_execute
|
||||||
self.on_notify = on_notify
|
self.on_notify = on_notify
|
||||||
self.interval_s = interval_s
|
self.interval_s = interval_s
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task: asyncio.Task | None = None
|
self._task: asyncio.Task | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def heartbeat_file(self) -> Path:
|
def heartbeat_file(self) -> Path:
|
||||||
return self.workspace / "HEARTBEAT.md"
|
return self.workspace / "HEARTBEAT.md"
|
||||||
|
|
||||||
def _read_heartbeat_file(self) -> str | None:
|
def _read_heartbeat_file(self) -> str | None:
|
||||||
"""Read HEARTBEAT.md content."""
|
|
||||||
if self.heartbeat_file.exists():
|
if self.heartbeat_file.exists():
|
||||||
try:
|
try:
|
||||||
return self.heartbeat_file.read_text(encoding="utf-8")
|
return self.heartbeat_file.read_text(encoding="utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _decide(self, content: str) -> tuple[str, str]:
|
||||||
|
"""Phase 1: ask LLM to decide skip/run via virtual tool call.
|
||||||
|
|
||||||
|
Returns (action, tasks) where action is 'skip' or 'run'.
|
||||||
|
"""
|
||||||
|
response = await self.provider.chat(
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||||
|
{"role": "user", "content": (
|
||||||
|
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
|
||||||
|
f"{content}"
|
||||||
|
)},
|
||||||
|
],
|
||||||
|
tools=_HEARTBEAT_TOOL,
|
||||||
|
model=self.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.has_tool_calls:
|
||||||
|
return "skip", ""
|
||||||
|
|
||||||
|
args = response.tool_calls[0].arguments
|
||||||
|
return args.get("action", "skip"), args.get("tasks", "")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the heartbeat service."""
|
"""Start the heartbeat service."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
@@ -83,18 +113,18 @@ class HeartbeatService:
|
|||||||
if self._running:
|
if self._running:
|
||||||
logger.warning("Heartbeat already running")
|
logger.warning("Heartbeat already running")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
self._task = asyncio.create_task(self._run_loop())
|
||||||
logger.info("Heartbeat started (every {}s)", self.interval_s)
|
logger.info("Heartbeat started (every {}s)", self.interval_s)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the heartbeat service."""
|
"""Stop the heartbeat service."""
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
self._task = None
|
self._task = None
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
async def _run_loop(self) -> None:
|
||||||
"""Main heartbeat loop."""
|
"""Main heartbeat loop."""
|
||||||
while self._running:
|
while self._running:
|
||||||
@@ -106,32 +136,38 @@ class HeartbeatService:
|
|||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Heartbeat error: {}", e)
|
logger.error("Heartbeat error: {}", e)
|
||||||
|
|
||||||
async def _tick(self) -> None:
|
async def _tick(self) -> None:
|
||||||
"""Execute a single heartbeat tick."""
|
"""Execute a single heartbeat tick."""
|
||||||
content = self._read_heartbeat_file()
|
content = self._read_heartbeat_file()
|
||||||
|
if not content:
|
||||||
# Skip if HEARTBEAT.md is empty or doesn't exist
|
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
|
||||||
if _is_heartbeat_empty(content):
|
|
||||||
logger.debug("Heartbeat: no tasks (HEARTBEAT.md empty)")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Heartbeat: checking for tasks...")
|
logger.info("Heartbeat: checking for tasks...")
|
||||||
|
|
||||||
if self.on_heartbeat:
|
try:
|
||||||
try:
|
action, tasks = await self._decide(content)
|
||||||
response = await self.on_heartbeat(HEARTBEAT_PROMPT)
|
|
||||||
if HEARTBEAT_OK_TOKEN in response.upper():
|
if action != "run":
|
||||||
logger.info("Heartbeat: OK (nothing to report)")
|
logger.info("Heartbeat: OK (nothing to report)")
|
||||||
else:
|
return
|
||||||
|
|
||||||
|
logger.info("Heartbeat: tasks found, executing...")
|
||||||
|
if self.on_execute:
|
||||||
|
response = await self.on_execute(tasks)
|
||||||
|
if response and self.on_notify:
|
||||||
logger.info("Heartbeat: completed, delivering response")
|
logger.info("Heartbeat: completed, delivering response")
|
||||||
if self.on_notify:
|
await self.on_notify(response)
|
||||||
await self.on_notify(response)
|
except Exception:
|
||||||
except Exception:
|
logger.exception("Heartbeat execution failed")
|
||||||
logger.exception("Heartbeat execution failed")
|
|
||||||
|
|
||||||
async def trigger_now(self) -> str | None:
|
async def trigger_now(self) -> str | None:
|
||||||
"""Manually trigger a heartbeat."""
|
"""Manually trigger a heartbeat."""
|
||||||
if self.on_heartbeat:
|
content = self._read_heartbeat_file()
|
||||||
return await self.on_heartbeat(HEARTBEAT_PROMPT)
|
if not content:
|
||||||
return None
|
return None
|
||||||
|
action, tasks = await self._decide(content)
|
||||||
|
if action != "run" or not self.on_execute:
|
||||||
|
return None
|
||||||
|
return await self.on_execute(tasks)
|
||||||
|
|||||||
@@ -12,8 +12,9 @@ from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|||||||
from nanobot.providers.registry import find_by_model, find_gateway
|
from nanobot.providers.registry import find_by_model, find_gateway
|
||||||
|
|
||||||
|
|
||||||
# Standard OpenAI chat-completion message keys; extras (e.g. reasoning_content) are stripped for strict providers.
|
# Standard OpenAI chat-completion message keys plus reasoning_content for
|
||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
|
||||||
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "nanobot-ai"
|
name = "nanobot-ai"
|
||||||
version = "0.1.4.post1"
|
version = "0.1.4.post2"
|
||||||
description = "A lightweight personal AI assistant framework"
|
description = "A lightweight personal AI assistant framework"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
|
|||||||
63
tests/test_context_prompt_cache.py
Normal file
63
tests/test_context_prompt_cache.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Tests for cache-friendly prompt construction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime as real_datetime
|
||||||
|
from pathlib import Path
|
||||||
|
import datetime as datetime_module
|
||||||
|
|
||||||
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDatetime(real_datetime):
|
||||||
|
current = real_datetime(2026, 2, 24, 13, 59)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def now(cls, tz=None): # type: ignore[override]
|
||||||
|
return cls.current
|
||||||
|
|
||||||
|
|
||||||
|
def _make_workspace(tmp_path: Path) -> Path:
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
workspace.mkdir(parents=True)
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
|
||||||
|
"""System prompt should not change just because wall clock minute changes."""
|
||||||
|
monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)
|
||||||
|
|
||||||
|
workspace = _make_workspace(tmp_path)
|
||||||
|
builder = ContextBuilder(workspace)
|
||||||
|
|
||||||
|
_FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
|
||||||
|
prompt1 = builder.build_system_prompt()
|
||||||
|
|
||||||
|
_FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
|
||||||
|
prompt2 = builder.build_system_prompt()
|
||||||
|
|
||||||
|
assert prompt1 == prompt2
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_context_is_appended_to_current_user_message(tmp_path) -> None:
|
||||||
|
"""Dynamic runtime details should be added at the tail user message, not system."""
|
||||||
|
workspace = _make_workspace(tmp_path)
|
||||||
|
builder = ContextBuilder(workspace)
|
||||||
|
|
||||||
|
messages = builder.build_messages(
|
||||||
|
history=[],
|
||||||
|
current_message="Return exactly: OK",
|
||||||
|
channel="cli",
|
||||||
|
chat_id="direct",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert messages[0]["role"] == "system"
|
||||||
|
assert "## Current Session" not in messages[0]["content"]
|
||||||
|
|
||||||
|
assert messages[-1]["role"] == "user"
|
||||||
|
user_content = messages[-1]["content"]
|
||||||
|
assert isinstance(user_content, str)
|
||||||
|
assert "Return exactly: OK" in user_content
|
||||||
|
assert "Current Time:" in user_content
|
||||||
|
assert "Channel: cli" in user_content
|
||||||
|
assert "Chat ID: direct" in user_content
|
||||||
@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
|
||||||
|
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
|
||||||
class FakeSMTP:
|
class FakeSMTP:
|
||||||
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||||
self.sent_messages: list[EmailMessage] = []
|
self.sent_messages: list[EmailMessage] = []
|
||||||
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
cfg = _make_config()
|
cfg = _make_config()
|
||||||
cfg.auto_reply_enabled = False
|
cfg.auto_reply_enabled = False
|
||||||
channel = EmailChannel(cfg, MessageBus())
|
channel = EmailChannel(cfg, MessageBus())
|
||||||
|
|
||||||
|
# Mark alice as someone who sent us an email (making this a "reply")
|
||||||
|
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
|
||||||
|
|
||||||
|
# Reply should be skipped (auto_reply_enabled=False)
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel="email",
|
channel="email",
|
||||||
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
)
|
)
|
||||||
assert fake_instances == []
|
assert fake_instances == []
|
||||||
|
|
||||||
|
# Reply with force_send=True should be sent
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel="email",
|
channel="email",
|
||||||
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
|
|||||||
assert len(fake_instances[0].sent_messages) == 1
|
assert len(fake_instances[0].sent_messages) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
|
||||||
|
"""Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
|
||||||
|
class FakeSMTP:
|
||||||
|
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
|
||||||
|
self.sent_messages: list[EmailMessage] = []
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def starttls(self, context=None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def login(self, _user: str, _pw: str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_message(self, msg: EmailMessage):
|
||||||
|
self.sent_messages.append(msg)
|
||||||
|
|
||||||
|
fake_instances: list[FakeSMTP] = []
|
||||||
|
|
||||||
|
def _smtp_factory(host: str, port: int, timeout: int = 30):
|
||||||
|
instance = FakeSMTP(host, port, timeout=timeout)
|
||||||
|
fake_instances.append(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)
|
||||||
|
|
||||||
|
cfg = _make_config()
|
||||||
|
cfg.auto_reply_enabled = False
|
||||||
|
channel = EmailChannel(cfg, MessageBus())
|
||||||
|
|
||||||
|
# bob@example.com has never sent us an email (proactive send)
|
||||||
|
# This should be sent even with auto_reply_enabled=False
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="email",
|
||||||
|
chat_id="bob@example.com",
|
||||||
|
content="Hello, this is a proactive email.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert len(fake_instances) == 1
|
||||||
|
assert len(fake_instances[0].sent_messages) == 1
|
||||||
|
sent = fake_instances[0].sent_messages[0]
|
||||||
|
assert sent["To"] == "bob@example.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
|
||||||
class FakeSMTP:
|
class FakeSMTP:
|
||||||
|
|||||||
147
tests/test_memory_consolidation_types.py
Normal file
147
tests/test_memory_consolidation_types.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
|
||||||
|
|
||||||
|
Regression test for https://github.com/HKUDS/nanobot/issues/1042
|
||||||
|
When memory consolidation receives dict values instead of strings from the LLM
|
||||||
|
tool call response, it should serialize them to JSON instead of raising TypeError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.memory import MemoryStore
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||||
|
"""Create a mock session with messages."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.messages = [
|
||||||
|
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||||
|
for i in range(message_count)
|
||||||
|
]
|
||||||
|
session.last_consolidated = 0
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_response(history_entry, memory_update):
|
||||||
|
"""Create an LLMResponse with a save_memory tool call."""
|
||||||
|
return LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={
|
||||||
|
"history_entry": history_entry,
|
||||||
|
"memory_update": memory_update,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryConsolidationTypeHandling:
|
||||||
|
"""Test that consolidation handles various argument types correctly."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_string_arguments_work(self, tmp_path: Path) -> None:
|
||||||
|
"""Normal case: LLM returns string arguments."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry="[2026-01-01] User discussed testing.",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert store.history_file.exists()
|
||||||
|
assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
|
||||||
|
assert "User likes testing." in store.memory_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
|
||||||
|
"""Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
|
||||||
|
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert store.history_file.exists()
|
||||||
|
history_content = store.history_file.read_text()
|
||||||
|
parsed = json.loads(history_content.strip())
|
||||||
|
assert parsed["summary"] == "User discussed testing."
|
||||||
|
|
||||||
|
memory_content = store.memory_file.read_text()
|
||||||
|
parsed_mem = json.loads(memory_content)
|
||||||
|
assert "User likes testing" in parsed_mem["facts"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
|
||||||
|
"""Some providers return arguments as a JSON string instead of parsed dict."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
|
||||||
|
# Simulate arguments being a JSON string (not yet parsed)
|
||||||
|
response = LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments=json.dumps({
|
||||||
|
"history_entry": "[2026-01-01] User discussed testing.",
|
||||||
|
"memory_update": "# Memory\nUser likes testing.",
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
provider.chat = AsyncMock(return_value=response)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert "User discussed testing." in store.history_file.read_text()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
|
||||||
|
"""When LLM doesn't use the save_memory tool, return False."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||||
|
"""Consolidation should be a no-op when messages < keep_count."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
session = _make_session(message_count=10)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
provider.chat.assert_not_called()
|
||||||
Reference in New Issue
Block a user