Merge remote-tracking branch 'origin/main' into pr-1062

This commit is contained in:
Re-bin
2026-02-24 12:14:00 +00:00
10 changed files with 391 additions and 119 deletions

View File

@@ -125,6 +125,13 @@ class MemoryStore:
return False return False
args = response.tool_calls[0].arguments args = response.tool_calls[0].arguments
# Some providers return arguments as a JSON string instead of dict
if isinstance(args, str):
args = json.loads(args)
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False
if entry := args.get("history_entry"): if entry := args.get("history_entry"):
if not isinstance(entry, str): if not isinstance(entry, str):
entry = json.dumps(entry, ensure_ascii=False) entry = json.dumps(entry, ensure_ascii=False)

View File

@@ -58,12 +58,17 @@ class WebSearchTool(Tool):
} }
def __init__(self, api_key: str | None = None, max_results: int = 5): def __init__(self, api_key: str | None = None, max_results: int = 5):
self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "") self.api_key = api_key
self.max_results = max_results self.max_results = max_results
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
if not self.api_key: api_key = self.api_key or os.environ.get("BRAVE_API_KEY", "")
return "Error: BRAVE_API_KEY not configured" if not api_key:
return (
"Error: Brave Search API key not configured. "
"Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
"(or export BRAVE_API_KEY), then restart the gateway."
)
try: try:
n = min(max(count or self.max_results, 1), 10) n = min(max(count or self.max_results, 1), 10)
@@ -71,7 +76,7 @@ class WebSearchTool(Tool):
r = await client.get( r = await client.get(
"https://api.search.brave.com/res/v1/web/search", "https://api.search.brave.com/res/v1/web/search",
params={"q": query, "count": n}, params={"q": query, "count": n},
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key}, headers={"Accept": "application/json", "X-Subscription-Token": api_key},
timeout=10.0 timeout=10.0
) )
r.raise_for_status() r.raise_for_status()

View File

@@ -108,11 +108,6 @@ class EmailChannel(BaseChannel):
logger.warning("Skip email send: consent_granted is false") logger.warning("Skip email send: consent_granted is false")
return return
force_send = bool((msg.metadata or {}).get("force_send"))
if not self.config.auto_reply_enabled and not force_send:
logger.info("Skip automatic email reply: auto_reply_enabled is false")
return
if not self.config.smtp_host: if not self.config.smtp_host:
logger.warning("Email channel SMTP host not configured") logger.warning("Email channel SMTP host not configured")
return return
@@ -122,6 +117,15 @@ class EmailChannel(BaseChannel):
logger.warning("Email channel missing recipient address") logger.warning("Email channel missing recipient address")
return return
# Determine if this is a reply (recipient has sent us an email before)
is_reply = to_addr in self._last_subject_by_chat
force_send = bool((msg.metadata or {}).get("force_send"))
# autoReplyEnabled only controls automatic replies, not proactive sends
if is_reply and not self.config.auto_reply_enabled and not force_send:
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
return
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply") base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
subject = self._reply_subject(base_subject) subject = self._reply_subject(base_subject)
if msg.metadata and isinstance(msg.metadata.get("subject"), str): if msg.metadata and isinstance(msg.metadata.get("subject"), str):

View File

@@ -180,21 +180,25 @@ def _extract_element_content(element: dict) -> list[str]:
return parts return parts
def _extract_post_text(content_json: dict) -> str: def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
"""Extract plain text from Feishu post (rich text) message content. """Extract text and image keys from Feishu post (rich text) message content.
Supports two formats: Supports two formats:
1. Direct format: {"title": "...", "content": [...]} 1. Direct format: {"title": "...", "content": [...]}
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}} 2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
Returns:
(text, image_keys) - extracted text and list of image keys
""" """
def extract_from_lang(lang_content: dict) -> str | None: def extract_from_lang(lang_content: dict) -> tuple[str | None, list[str]]:
if not isinstance(lang_content, dict): if not isinstance(lang_content, dict):
return None return None, []
title = lang_content.get("title", "") title = lang_content.get("title", "")
content_blocks = lang_content.get("content", []) content_blocks = lang_content.get("content", [])
if not isinstance(content_blocks, list): if not isinstance(content_blocks, list):
return None return None, []
text_parts = [] text_parts = []
image_keys = []
if title: if title:
text_parts.append(title) text_parts.append(title)
for block in content_blocks: for block in content_blocks:
@@ -209,22 +213,36 @@ def _extract_post_text(content_json: dict) -> str:
text_parts.append(element.get("text", "")) text_parts.append(element.get("text", ""))
elif tag == "at": elif tag == "at":
text_parts.append(f"@{element.get('user_name', 'user')}") text_parts.append(f"@{element.get('user_name', 'user')}")
return " ".join(text_parts).strip() if text_parts else None elif tag == "img":
img_key = element.get("image_key")
if img_key:
image_keys.append(img_key)
text = " ".join(text_parts).strip() if text_parts else None
return text, image_keys
# Try direct format first # Try direct format first
if "content" in content_json: if "content" in content_json:
result = extract_from_lang(content_json) text, images = extract_from_lang(content_json)
if result: if text or images:
return result return text or "", images
# Try localized format # Try localized format
for lang_key in ("zh_cn", "en_us", "ja_jp"): for lang_key in ("zh_cn", "en_us", "ja_jp"):
lang_content = content_json.get(lang_key) lang_content = content_json.get(lang_key)
result = extract_from_lang(lang_content) text, images = extract_from_lang(lang_content)
if result: if text or images:
return result return text or "", images
return "" return "", []
def _extract_post_text(content_json: dict) -> str:
    """Extract plain text from Feishu post (rich text) message content.

    Thin legacy wrapper around ``_extract_post_content``: returns only the
    text portion and discards any embedded image keys.
    """
    text, _image_keys = _extract_post_content(content_json)
    return text
class FeishuChannel(BaseChannel): class FeishuChannel(BaseChannel):
@@ -691,9 +709,17 @@ class FeishuChannel(BaseChannel):
content_parts.append(text) content_parts.append(text)
elif msg_type == "post": elif msg_type == "post":
text = _extract_post_text(content_json) text, image_keys = _extract_post_content(content_json)
if text: if text:
content_parts.append(text) content_parts.append(text)
# Download images embedded in post
for img_key in image_keys:
file_path, content_text = await self._download_and_save_media(
"image", {"image_key": img_key}, message_id
)
if file_path:
media_paths.append(file_path)
content_parts.append(content_text)
elif msg_type in ("image", "audio", "file", "media"): elif msg_type in ("image", "audio", "file", "media"):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id) file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)

View File

@@ -360,19 +360,19 @@ def gateway(
return "cli", "direct" return "cli", "direct"
# Create heartbeat service # Create heartbeat service
async def on_heartbeat(prompt: str) -> str: async def on_heartbeat_execute(tasks: str) -> str:
"""Execute heartbeat through the agent.""" """Phase 2: execute heartbeat tasks through the full agent loop."""
channel, chat_id = _pick_heartbeat_target() channel, chat_id = _pick_heartbeat_target()
async def _silent(*_args, **_kwargs): async def _silent(*_args, **_kwargs):
pass pass
return await agent.process_direct( return await agent.process_direct(
prompt, tasks,
session_key="heartbeat", session_key="heartbeat",
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
on_progress=_silent, # suppress: heartbeat should not push progress to external channels on_progress=_silent,
) )
async def on_heartbeat_notify(response: str) -> None: async def on_heartbeat_notify(response: str) -> None:
@@ -383,12 +383,15 @@ def gateway(
return # No external channel available to deliver to return # No external channel available to deliver to
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response)) await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
hb_cfg = config.gateway.heartbeat
heartbeat = HeartbeatService( heartbeat = HeartbeatService(
workspace=config.workspace_path, workspace=config.workspace_path,
on_heartbeat=on_heartbeat, provider=provider,
model=agent.model,
on_execute=on_heartbeat_execute,
on_notify=on_heartbeat_notify, on_notify=on_heartbeat_notify,
interval_s=30 * 60, # 30 minutes interval_s=hb_cfg.interval_s,
enabled=True enabled=hb_cfg.enabled,
) )
if channels.enabled_channels: if channels.enabled_channels:
@@ -400,7 +403,7 @@ def gateway(
if cron_status["jobs"] > 0: if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs") console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
console.print(f"[green]✓[/green] Heartbeat: every 30m") console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
async def run(): async def run():
try: try:

View File

@@ -228,11 +228,19 @@ class ProvidersConfig(Base):
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth) github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
class HeartbeatConfig(Base):
    """Configuration for the periodic heartbeat service."""

    # Whether the heartbeat loop runs at all.
    enabled: bool = True
    # Seconds between heartbeat ticks; defaults to 30 minutes.
    interval_s: int = 30 * 60
class GatewayConfig(Base): class GatewayConfig(Base):
"""Gateway/server configuration.""" """Gateway/server configuration."""
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 18790 port: int = 18790
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
class WebSearchConfig(Base): class WebSearchConfig(Base):

View File

@@ -1,61 +1,69 @@
"""Heartbeat service - periodic agent wake-up to check for tasks.""" """Heartbeat service - periodic agent wake-up to check for tasks."""
from __future__ import annotations
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Coroutine from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger from loguru import logger
# Default interval: 30 minutes if TYPE_CHECKING:
DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60 from nanobot.providers.base import LLMProvider
# Token the agent replies with when there is nothing to report _HEARTBEAT_TOOL = [
HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK" {
"type": "function",
# The prompt sent to agent during heartbeat "function": {
HEARTBEAT_PROMPT = ( "name": "heartbeat",
"Read HEARTBEAT.md in your workspace and follow any instructions listed there. " "description": "Report heartbeat decision after reviewing tasks.",
f"If nothing needs attention, reply with exactly: {HEARTBEAT_OK_TOKEN}" "parameters": {
) "type": "object",
"properties": {
"action": {
def _is_heartbeat_empty(content: str | None) -> bool: "type": "string",
"""Check if HEARTBEAT.md has no actionable content.""" "enum": ["skip", "run"],
if not content: "description": "skip = nothing to do, run = has active tasks",
return True },
"tasks": {
# Lines to skip: empty, headers, HTML comments, empty checkboxes "type": "string",
skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"} "description": "Natural-language summary of active tasks (required for run)",
},
for line in content.split("\n"): },
line = line.strip() "required": ["action"],
if not line or line.startswith("#") or line.startswith("<!--") or line in skip_patterns: },
continue },
return False # Found actionable content }
]
return True
class HeartbeatService: class HeartbeatService:
""" """
Periodic heartbeat service that wakes the agent to check for tasks. Periodic heartbeat service that wakes the agent to check for tasks.
The agent reads HEARTBEAT.md from the workspace and executes any tasks Phase 1 (decision): reads HEARTBEAT.md and asks the LLM — via a virtual
listed there. If it has something to report, the response is forwarded tool call — whether there are active tasks. This avoids free-text parsing
to the user via on_notify. If nothing needs attention, the agent replies and the unreliable HEARTBEAT_OK token.
HEARTBEAT_OK and the response is silently dropped.
Phase 2 (execution): only triggered when Phase 1 returns ``run``. The
``on_execute`` callback runs the task through the full agent loop and
returns the result to deliver.
""" """
def __init__( def __init__(
self, self,
workspace: Path, workspace: Path,
on_heartbeat: Callable[[str], Coroutine[Any, Any, str]] | None = None, provider: LLMProvider,
model: str,
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None, on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
interval_s: int = DEFAULT_HEARTBEAT_INTERVAL_S, interval_s: int = 30 * 60,
enabled: bool = True, enabled: bool = True,
): ):
self.workspace = workspace self.workspace = workspace
self.on_heartbeat = on_heartbeat self.provider = provider
self.model = model
self.on_execute = on_execute
self.on_notify = on_notify self.on_notify = on_notify
self.interval_s = interval_s self.interval_s = interval_s
self.enabled = enabled self.enabled = enabled
@@ -67,7 +75,6 @@ class HeartbeatService:
return self.workspace / "HEARTBEAT.md" return self.workspace / "HEARTBEAT.md"
def _read_heartbeat_file(self) -> str | None: def _read_heartbeat_file(self) -> str | None:
"""Read HEARTBEAT.md content."""
if self.heartbeat_file.exists(): if self.heartbeat_file.exists():
try: try:
return self.heartbeat_file.read_text(encoding="utf-8") return self.heartbeat_file.read_text(encoding="utf-8")
@@ -75,6 +82,29 @@ class HeartbeatService:
return None return None
return None return None
async def _decide(self, content: str) -> tuple[str, str]:
"""Phase 1: ask LLM to decide skip/run via virtual tool call.
Returns (action, tasks) where action is 'skip' or 'run'.
"""
response = await self.provider.chat(
messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": (
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
f"{content}"
)},
],
tools=_HEARTBEAT_TOOL,
model=self.model,
)
if not response.has_tool_calls:
return "skip", ""
args = response.tool_calls[0].arguments
return args.get("action", "skip"), args.get("tasks", "")
async def start(self) -> None: async def start(self) -> None:
"""Start the heartbeat service.""" """Start the heartbeat service."""
if not self.enabled: if not self.enabled:
@@ -110,28 +140,34 @@ class HeartbeatService:
async def _tick(self) -> None: async def _tick(self) -> None:
"""Execute a single heartbeat tick.""" """Execute a single heartbeat tick."""
content = self._read_heartbeat_file() content = self._read_heartbeat_file()
if not content:
# Skip if HEARTBEAT.md is empty or doesn't exist logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
if _is_heartbeat_empty(content):
logger.debug("Heartbeat: no tasks (HEARTBEAT.md empty)")
return return
logger.info("Heartbeat: checking for tasks...") logger.info("Heartbeat: checking for tasks...")
if self.on_heartbeat: try:
try: action, tasks = await self._decide(content)
response = await self.on_heartbeat(HEARTBEAT_PROMPT)
if HEARTBEAT_OK_TOKEN in response.upper(): if action != "run":
logger.info("Heartbeat: OK (nothing to report)") logger.info("Heartbeat: OK (nothing to report)")
else: return
logger.info("Heartbeat: tasks found, executing...")
if self.on_execute:
response = await self.on_execute(tasks)
if response and self.on_notify:
logger.info("Heartbeat: completed, delivering response") logger.info("Heartbeat: completed, delivering response")
if self.on_notify: await self.on_notify(response)
await self.on_notify(response) except Exception:
except Exception: logger.exception("Heartbeat execution failed")
logger.exception("Heartbeat execution failed")
async def trigger_now(self) -> str | None: async def trigger_now(self) -> str | None:
"""Manually trigger a heartbeat.""" """Manually trigger a heartbeat."""
if self.on_heartbeat: content = self._read_heartbeat_file()
return await self.on_heartbeat(HEARTBEAT_PROMPT) if not content:
return None return None
action, tasks = await self._decide(content)
if action != "run" or not self.on_execute:
return None
return await self.on_execute(tasks)

View File

@@ -10,27 +10,6 @@ This file documents non-obvious constraints and usage patterns.
- Output is truncated at 10,000 characters - Output is truncated at 10,000 characters
- `restrictToWorkspace` config can limit file access to the workspace - `restrictToWorkspace` config can limit file access to the workspace
## Cron — Scheduled Reminders ## cron — Scheduled Reminders
Use `exec` to create scheduled reminders: - Please refer to cron skill for usage.
```bash
# Recurring: every day at 9am
nanobot cron add --name "morning" --message "Good morning!" --cron "0 9 * * *"
# With timezone (--tz only works with --cron)
nanobot cron add --name "standup" --message "Standup time!" --cron "0 10 * * 1-5" --tz "Asia/Shanghai"
# Recurring: every 2 hours
nanobot cron add --name "water" --message "Drink water!" --every 7200
# One-time: specific ISO time
nanobot cron add --name "meeting" --message "Meeting starts now!" --at "2025-01-31T15:00:00"
# Deliver to a specific channel/user
nanobot cron add --name "reminder" --message "Check email" --at "2025-01-31T09:00:00" --deliver --to "USER_ID" --channel "CHANNEL"
# Manage jobs
nanobot cron list
nanobot cron remove <job_id>
```

View File

@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None: async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
"""When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
class FakeSMTP: class FakeSMTP:
def __init__(self, _host: str, _port: int, timeout: int = 30) -> None: def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
self.sent_messages: list[EmailMessage] = [] self.sent_messages: list[EmailMessage] = []
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
cfg = _make_config() cfg = _make_config()
cfg.auto_reply_enabled = False cfg.auto_reply_enabled = False
channel = EmailChannel(cfg, MessageBus()) channel = EmailChannel(cfg, MessageBus())
# Mark alice as someone who sent us an email (making this a "reply")
channel._last_subject_by_chat["alice@example.com"] = "Previous email"
# Reply should be skipped (auto_reply_enabled=False)
await channel.send( await channel.send(
OutboundMessage( OutboundMessage(
channel="email", channel="email",
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
) )
assert fake_instances == [] assert fake_instances == []
# Reply with force_send=True should be sent
await channel.send( await channel.send(
OutboundMessage( OutboundMessage(
channel="email", channel="email",
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
assert len(fake_instances[0].sent_messages) == 1 assert len(fake_instances[0].sent_messages) == 1
@pytest.mark.asyncio
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
    """Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""

    class StubSMTP:
        def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
            self.sent_messages: list[EmailMessage] = []

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def starttls(self, context=None):
            return None

        def login(self, _user: str, _pw: str):
            return None

        def send_message(self, msg: EmailMessage):
            self.sent_messages.append(msg)

    created: list[StubSMTP] = []

    def _smtp_factory(host: str, port: int, timeout: int = 30):
        smtp = StubSMTP(host, port, timeout=timeout)
        created.append(smtp)
        return smtp

    monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)

    cfg = _make_config()
    cfg.auto_reply_enabled = False
    channel = EmailChannel(cfg, MessageBus())

    # bob@example.com has never emailed us, so this outbound message is a
    # proactive send — it must go out even with automatic replies disabled.
    await channel.send(
        OutboundMessage(
            channel="email",
            chat_id="bob@example.com",
            content="Hello, this is a proactive email.",
        )
    )

    assert len(created) == 1
    assert len(created[0].sent_messages) == 1
    delivered = created[0].sent_messages[0]
    assert delivered["To"] == "bob@example.com"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None: async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
class FakeSMTP: class FakeSMTP:

View File

@@ -0,0 +1,147 @@
"""Test MemoryStore.consolidate() handles non-string tool call arguments.
Regression test for https://github.com/HKUDS/nanobot/issues/1042
When memory consolidation receives dict values instead of strings from the LLM
tool call response, it should serialize them to JSON instead of raising TypeError.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50):
"""Create a mock session with messages."""
session = MagicMock()
session.messages = [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update):
    """Build an LLMResponse whose single tool call invokes save_memory."""
    call = ToolCallRequest(
        id="call_1",
        name="save_memory",
        arguments={
            "history_entry": history_entry,
            "memory_update": memory_update,
        },
    )
    return LLMResponse(content=None, tool_calls=[call])
class TestMemoryConsolidationTypeHandling:
    """Test that consolidation handles various argument types correctly."""

    @pytest.mark.asyncio
    async def test_string_arguments_work(self, tmp_path: Path) -> None:
        """Normal case: LLM returns string arguments."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=_make_tool_response(
                history_entry="[2026-01-01] User discussed testing.",
                memory_update="# Memory\nUser likes testing.",
            )
        )
        session = _make_session(message_count=60)
        result = await store.consolidate(session, provider, "test-model", memory_window=50)
        assert result is True
        assert store.history_file.exists()
        # Both the history entry and the memory update should land on disk verbatim.
        assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
        assert "User likes testing." in store.memory_file.read_text()

    @pytest.mark.asyncio
    async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
        """Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=_make_tool_response(
                history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
                memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
            )
        )
        session = _make_session(message_count=60)
        result = await store.consolidate(session, provider, "test-model", memory_window=50)
        assert result is True
        assert store.history_file.exists()
        # Dict values must be written as JSON (not Python repr / str()).
        history_content = store.history_file.read_text()
        parsed = json.loads(history_content.strip())
        assert parsed["summary"] == "User discussed testing."
        memory_content = store.memory_file.read_text()
        parsed_mem = json.loads(memory_content)
        assert "User likes testing" in parsed_mem["facts"]

    @pytest.mark.asyncio
    async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
        """Some providers return arguments as a JSON string instead of parsed dict."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        # Simulate arguments being a JSON string (not yet parsed)
        response = LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(
                    id="call_1",
                    name="save_memory",
                    arguments=json.dumps({
                        "history_entry": "[2026-01-01] User discussed testing.",
                        "memory_update": "# Memory\nUser likes testing.",
                    }),
                )
            ],
        )
        provider.chat = AsyncMock(return_value=response)
        session = _make_session(message_count=60)
        result = await store.consolidate(session, provider, "test-model", memory_window=50)
        assert result is True
        assert "User discussed testing." in store.history_file.read_text()

    @pytest.mark.asyncio
    async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
        """When LLM doesn't use the save_memory tool, return False."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
        )
        session = _make_session(message_count=60)
        result = await store.consolidate(session, provider, "test-model", memory_window=50)
        assert result is False
        # Nothing should be persisted when the model skipped the tool.
        assert not store.history_file.exists()

    @pytest.mark.asyncio
    async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
        """Consolidation should be a no-op when messages < keep_count."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        session = _make_session(message_count=10)
        result = await store.consolidate(session, provider, "test-model", memory_window=50)
        assert result is True
        # Too little history to consolidate: the LLM must not be called at all.
        provider.chat.assert_not_called()