fix(memory): Enforce memory consolidation schema with a tool call

This commit is contained in:
Rudolfs Tilgass
2026-02-19 21:02:52 +01:00
parent d22929305f
commit afca0278ad

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import json import json
import json_repair
from pathlib import Path from pathlib import Path
import re import re
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable
@@ -84,13 +83,13 @@ class AgentLoop:
exec_config=self.exec_config, exec_config=self.exec_config,
restrict_to_workspace=restrict_to_workspace, restrict_to_workspace=restrict_to_workspace,
) )
self._running = False self._running = False
self._mcp_servers = mcp_servers or {} self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False self._mcp_connected = False
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
"""Register the default set of tools.""" """Register the default set of tools."""
# File tools (restrict to workspace if configured) # File tools (restrict to workspace if configured)
@@ -99,30 +98,30 @@ class AgentLoop:
self.tools.register(WriteFileTool(allowed_dir=allowed_dir)) self.tools.register(WriteFileTool(allowed_dir=allowed_dir))
self.tools.register(EditFileTool(allowed_dir=allowed_dir)) self.tools.register(EditFileTool(allowed_dir=allowed_dir))
self.tools.register(ListDirTool(allowed_dir=allowed_dir)) self.tools.register(ListDirTool(allowed_dir=allowed_dir))
# Shell tool # Shell tool
self.tools.register(ExecTool( self.tools.register(ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),
timeout=self.exec_config.timeout, timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
)) ))
# Web tools # Web tools
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebSearchTool(api_key=self.brave_api_key))
self.tools.register(WebFetchTool()) self.tools.register(WebFetchTool())
# Message tool # Message tool
message_tool = MessageTool(send_callback=self.bus.publish_outbound) message_tool = MessageTool(send_callback=self.bus.publish_outbound)
self.tools.register(message_tool) self.tools.register(message_tool)
# Spawn tool (for subagents) # Spawn tool (for subagents)
spawn_tool = SpawnTool(manager=self.subagents) spawn_tool = SpawnTool(manager=self.subagents)
self.tools.register(spawn_tool) self.tools.register(spawn_tool)
# Cron tool (for scheduling) # Cron tool (for scheduling)
if self.cron_service: if self.cron_service:
self.tools.register(CronTool(self.cron_service)) self.tools.register(CronTool(self.cron_service))
async def _connect_mcp(self) -> None: async def _connect_mcp(self) -> None:
"""Connect to configured MCP servers (one-time, lazy).""" """Connect to configured MCP servers (one-time, lazy)."""
if self._mcp_connected or not self._mcp_servers: if self._mcp_connected or not self._mcp_servers:
@@ -255,7 +254,7 @@ class AgentLoop:
)) ))
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Close MCP connections.""" """Close MCP connections."""
if self._mcp_stack: if self._mcp_stack:
@@ -269,7 +268,7 @@ class AgentLoop:
"""Stop the agent loop.""" """Stop the agent loop."""
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
async def _process_message( async def _process_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
@@ -278,25 +277,25 @@ class AgentLoop:
) -> OutboundMessage | None: ) -> OutboundMessage | None:
""" """
Process a single inbound message. Process a single inbound message.
Args: Args:
msg: The inbound message to process. msg: The inbound message to process.
session_key: Override session key (used by process_direct). session_key: Override session key (used by process_direct).
on_progress: Optional callback for intermediate output (defaults to bus publish). on_progress: Optional callback for intermediate output (defaults to bus publish).
Returns: Returns:
The response message, or None if no response needed. The response message, or None if no response needed.
""" """
# System messages route back via chat_id ("channel:chat_id") # System messages route back via chat_id ("channel:chat_id")
if msg.channel == "system": if msg.channel == "system":
return await self._process_system_message(msg) return await self._process_system_message(msg)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}") logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
key = session_key or msg.session_key key = session_key or msg.session_key
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
# Handle slash commands # Handle slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
@@ -317,7 +316,7 @@ class AgentLoop:
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands") content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
if len(session.messages) > self.memory_window: if len(session.messages) > self.memory_window:
asyncio.create_task(self._consolidate_memory(session)) asyncio.create_task(self._consolidate_memory(session))
@@ -342,31 +341,31 @@ class AgentLoop:
if final_content is None: if final_content is None:
final_content = "I've completed processing but have no response to give." final_content = "I've completed processing but have no response to give."
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}") logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
session.add_message("user", msg.content) session.add_message("user", msg.content)
session.add_message("assistant", final_content, session.add_message("assistant", final_content,
tools_used=tools_used if tools_used else None) tools_used=tools_used if tools_used else None)
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content=final_content, content=final_content,
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts) metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
) )
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None: async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
""" """
Process a system message (e.g., subagent announce). Process a system message (e.g., subagent announce).
The chat_id field contains "original_channel:original_chat_id" to route The chat_id field contains "original_channel:original_chat_id" to route
the response back to the correct destination. the response back to the correct destination.
""" """
logger.info(f"Processing system message from {msg.sender_id}") logger.info(f"Processing system message from {msg.sender_id}")
# Parse origin from chat_id (format: "channel:chat_id") # Parse origin from chat_id (format: "channel:chat_id")
if ":" in msg.chat_id: if ":" in msg.chat_id:
parts = msg.chat_id.split(":", 1) parts = msg.chat_id.split(":", 1)
@@ -376,7 +375,7 @@ class AgentLoop:
# Fallback # Fallback
origin_channel = "cli" origin_channel = "cli"
origin_chat_id = msg.chat_id origin_chat_id = msg.chat_id
session_key = f"{origin_channel}:{origin_chat_id}" session_key = f"{origin_channel}:{origin_chat_id}"
session = self.sessions.get_or_create(session_key) session = self.sessions.get_or_create(session_key)
self._set_tool_context(origin_channel, origin_chat_id) self._set_tool_context(origin_channel, origin_chat_id)
@@ -390,17 +389,17 @@ class AgentLoop:
if final_content is None: if final_content is None:
final_content = "Background task completed." final_content = "Background task completed."
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}") session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
session.add_message("assistant", final_content) session.add_message("assistant", final_content)
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=origin_channel, channel=origin_channel,
chat_id=origin_chat_id, chat_id=origin_chat_id,
content=final_content content=final_content
) )
async def _consolidate_memory(self, session, archive_all: bool = False) -> None: async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
"""Consolidate old messages into MEMORY.md + HISTORY.md. """Consolidate old messages into MEMORY.md + HISTORY.md.
@@ -439,42 +438,56 @@ class AgentLoop:
conversation = "\n".join(lines) conversation = "\n".join(lines)
current_memory = memory.read_long_term() current_memory = memory.read_long_term()
prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys: prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.
2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
## Current Long-term Memory ## Current Long-term Memory
{current_memory or "(empty)"} {current_memory or "(empty)"}
## Conversation to Process ## Conversation to Process
{conversation} {conversation}"""
Respond with ONLY valid JSON, no markdown fences.""" save_memory_tool = [
{
"type": "function",
"function": {
"name": "save_memory",
"description": "Save the memory consolidation result to persistent storage.",
"parameters": {
"type": "object",
"properties": {
"history_entry": {
"type": "string",
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.",
},
"memory_update": {
"type": "string",
"description": "The full updated long-term memory content as a markdown string. Include all existing facts plus any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.",
},
},
"required": ["history_entry", "memory_update"],
},
},
}
]
try: try:
response = await self.provider.chat( response = await self.provider.chat(
messages=[ messages=[
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."}, {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
], ],
tools=save_memory_tool,
model=self.model, model=self.model,
) )
text = (response.content or "").strip()
if not text: if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM returned empty response, skipping") logger.warning("Memory consolidation: LLM did not call save_memory tool, skipping")
return
if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text)
if not isinstance(result, dict):
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
return return
if entry := result.get("history_entry"): args = response.tool_calls[0].arguments
if entry := args.get("history_entry"):
memory.append_history(entry) memory.append_history(entry)
if update := result.get("memory_update"): if update := args.get("memory_update"):
if update != current_memory: if update != current_memory:
memory.write_long_term(update) memory.write_long_term(update)
@@ -496,14 +509,14 @@ Respond with ONLY valid JSON, no markdown fences."""
) -> str: ) -> str:
""" """
Process a message directly (for CLI or cron usage). Process a message directly (for CLI or cron usage).
Args: Args:
content: The message content. content: The message content.
session_key: Session identifier (overrides channel:chat_id for session lookup). session_key: Session identifier (overrides channel:chat_id for session lookup).
channel: Source channel (for tool context routing). channel: Source channel (for tool context routing).
chat_id: Source chat ID (for tool context routing). chat_id: Source chat ID (for tool context routing).
on_progress: Optional callback for intermediate output. on_progress: Optional callback for intermediate output.
Returns: Returns:
The agent's response. The agent's response.
""" """
@@ -514,6 +527,6 @@ Respond with ONLY valid JSON, no markdown fences."""
chat_id=chat_id, chat_id=chat_id,
content=content content=content
) )
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress) response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
return response.content if response else "" return response.content if response else ""