fix(memory): Enforce memory consolidation schema with a tool call
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
import json
|
import json
|
||||||
import json_repair
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Any, Awaitable, Callable
|
||||||
@@ -84,13 +83,13 @@ class AgentLoop:
|
|||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
restrict_to_workspace=restrict_to_workspace,
|
restrict_to_workspace=restrict_to_workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._mcp_servers = mcp_servers or {}
|
self._mcp_servers = mcp_servers or {}
|
||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
# File tools (restrict to workspace if configured)
|
# File tools (restrict to workspace if configured)
|
||||||
@@ -99,30 +98,30 @@ class AgentLoop:
|
|||||||
self.tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
self.tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
||||||
self.tools.register(EditFileTool(allowed_dir=allowed_dir))
|
self.tools.register(EditFileTool(allowed_dir=allowed_dir))
|
||||||
self.tools.register(ListDirTool(allowed_dir=allowed_dir))
|
self.tools.register(ListDirTool(allowed_dir=allowed_dir))
|
||||||
|
|
||||||
# Shell tool
|
# Shell tool
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Web tools
|
# Web tools
|
||||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
self.tools.register(WebFetchTool())
|
self.tools.register(WebFetchTool())
|
||||||
|
|
||||||
# Message tool
|
# Message tool
|
||||||
message_tool = MessageTool(send_callback=self.bus.publish_outbound)
|
message_tool = MessageTool(send_callback=self.bus.publish_outbound)
|
||||||
self.tools.register(message_tool)
|
self.tools.register(message_tool)
|
||||||
|
|
||||||
# Spawn tool (for subagents)
|
# Spawn tool (for subagents)
|
||||||
spawn_tool = SpawnTool(manager=self.subagents)
|
spawn_tool = SpawnTool(manager=self.subagents)
|
||||||
self.tools.register(spawn_tool)
|
self.tools.register(spawn_tool)
|
||||||
|
|
||||||
# Cron tool (for scheduling)
|
# Cron tool (for scheduling)
|
||||||
if self.cron_service:
|
if self.cron_service:
|
||||||
self.tools.register(CronTool(self.cron_service))
|
self.tools.register(CronTool(self.cron_service))
|
||||||
|
|
||||||
async def _connect_mcp(self) -> None:
|
async def _connect_mcp(self) -> None:
|
||||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||||
if self._mcp_connected or not self._mcp_servers:
|
if self._mcp_connected or not self._mcp_servers:
|
||||||
@@ -255,7 +254,7 @@ class AgentLoop:
|
|||||||
))
|
))
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Close MCP connections."""
|
"""Close MCP connections."""
|
||||||
if self._mcp_stack:
|
if self._mcp_stack:
|
||||||
@@ -269,7 +268,7 @@ class AgentLoop:
|
|||||||
"""Stop the agent loop."""
|
"""Stop the agent loop."""
|
||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
@@ -278,25 +277,25 @@ class AgentLoop:
|
|||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""
|
"""
|
||||||
Process a single inbound message.
|
Process a single inbound message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
msg: The inbound message to process.
|
msg: The inbound message to process.
|
||||||
session_key: Override session key (used by process_direct).
|
session_key: Override session key (used by process_direct).
|
||||||
on_progress: Optional callback for intermediate output (defaults to bus publish).
|
on_progress: Optional callback for intermediate output (defaults to bus publish).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response message, or None if no response needed.
|
The response message, or None if no response needed.
|
||||||
"""
|
"""
|
||||||
# System messages route back via chat_id ("channel:chat_id")
|
# System messages route back via chat_id ("channel:chat_id")
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
return await self._process_system_message(msg)
|
return await self._process_system_message(msg)
|
||||||
|
|
||||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||||
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
||||||
|
|
||||||
key = session_key or msg.session_key
|
key = session_key or msg.session_key
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
|
||||||
# Handle slash commands
|
# Handle slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
@@ -317,7 +316,7 @@ class AgentLoop:
|
|||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
||||||
|
|
||||||
if len(session.messages) > self.memory_window:
|
if len(session.messages) > self.memory_window:
|
||||||
asyncio.create_task(self._consolidate_memory(session))
|
asyncio.create_task(self._consolidate_memory(session))
|
||||||
|
|
||||||
@@ -342,31 +341,31 @@ class AgentLoop:
|
|||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
|
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
|
||||||
|
|
||||||
session.add_message("user", msg.content)
|
session.add_message("user", msg.content)
|
||||||
session.add_message("assistant", final_content,
|
session.add_message("assistant", final_content,
|
||||||
tools_used=tools_used if tools_used else None)
|
tools_used=tools_used if tools_used else None)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=final_content,
|
content=final_content,
|
||||||
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
||||||
"""
|
"""
|
||||||
Process a system message (e.g., subagent announce).
|
Process a system message (e.g., subagent announce).
|
||||||
|
|
||||||
The chat_id field contains "original_channel:original_chat_id" to route
|
The chat_id field contains "original_channel:original_chat_id" to route
|
||||||
the response back to the correct destination.
|
the response back to the correct destination.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Processing system message from {msg.sender_id}")
|
logger.info(f"Processing system message from {msg.sender_id}")
|
||||||
|
|
||||||
# Parse origin from chat_id (format: "channel:chat_id")
|
# Parse origin from chat_id (format: "channel:chat_id")
|
||||||
if ":" in msg.chat_id:
|
if ":" in msg.chat_id:
|
||||||
parts = msg.chat_id.split(":", 1)
|
parts = msg.chat_id.split(":", 1)
|
||||||
@@ -376,7 +375,7 @@ class AgentLoop:
|
|||||||
# Fallback
|
# Fallback
|
||||||
origin_channel = "cli"
|
origin_channel = "cli"
|
||||||
origin_chat_id = msg.chat_id
|
origin_chat_id = msg.chat_id
|
||||||
|
|
||||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
session_key = f"{origin_channel}:{origin_chat_id}"
|
||||||
session = self.sessions.get_or_create(session_key)
|
session = self.sessions.get_or_create(session_key)
|
||||||
self._set_tool_context(origin_channel, origin_chat_id)
|
self._set_tool_context(origin_channel, origin_chat_id)
|
||||||
@@ -390,17 +389,17 @@ class AgentLoop:
|
|||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "Background task completed."
|
final_content = "Background task completed."
|
||||||
|
|
||||||
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
|
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
|
||||||
session.add_message("assistant", final_content)
|
session.add_message("assistant", final_content)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=origin_channel,
|
channel=origin_channel,
|
||||||
chat_id=origin_chat_id,
|
chat_id=origin_chat_id,
|
||||||
content=final_content
|
content=final_content
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
"""Consolidate old messages into MEMORY.md + HISTORY.md.
|
||||||
|
|
||||||
@@ -439,42 +438,56 @@ class AgentLoop:
|
|||||||
conversation = "\n".join(lines)
|
conversation = "\n".join(lines)
|
||||||
current_memory = memory.read_long_term()
|
current_memory = memory.read_long_term()
|
||||||
|
|
||||||
prompt = f"""You are a memory consolidation agent. Process this conversation and return a JSON object with exactly two keys:
|
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||||
|
|
||||||
1. "history_entry": A paragraph (2-5 sentences) summarizing the key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.
|
|
||||||
|
|
||||||
2. "memory_update": The updated long-term memory content. Add any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.
|
|
||||||
|
|
||||||
## Current Long-term Memory
|
## Current Long-term Memory
|
||||||
{current_memory or "(empty)"}
|
{current_memory or "(empty)"}
|
||||||
|
|
||||||
## Conversation to Process
|
## Conversation to Process
|
||||||
{conversation}
|
{conversation}"""
|
||||||
|
|
||||||
Respond with ONLY valid JSON, no markdown fences."""
|
save_memory_tool = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "save_memory",
|
||||||
|
"description": "Save the memory consolidation result to persistent storage.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"history_entry": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. Start with a timestamp like [YYYY-MM-DD HH:MM]. Include enough detail to be useful when found by grep search later.",
|
||||||
|
},
|
||||||
|
"memory_update": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The full updated long-term memory content as a markdown string. Include all existing facts plus any new facts: user location, preferences, personal info, habits, project context, technical decisions, tools/services used. If nothing new, return the existing content unchanged.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["history_entry", "memory_update"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
|
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
|
tools=save_memory_tool,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
)
|
)
|
||||||
text = (response.content or "").strip()
|
|
||||||
if not text:
|
if not response.has_tool_calls:
|
||||||
logger.warning("Memory consolidation: LLM returned empty response, skipping")
|
logger.warning("Memory consolidation: LLM did not call save_memory tool, skipping")
|
||||||
return
|
|
||||||
if text.startswith("```"):
|
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
|
||||||
result = json_repair.loads(text)
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if entry := result.get("history_entry"):
|
args = response.tool_calls[0].arguments
|
||||||
|
if entry := args.get("history_entry"):
|
||||||
memory.append_history(entry)
|
memory.append_history(entry)
|
||||||
if update := result.get("memory_update"):
|
if update := args.get("memory_update"):
|
||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
memory.write_long_term(update)
|
memory.write_long_term(update)
|
||||||
|
|
||||||
@@ -496,14 +509,14 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process a message directly (for CLI or cron usage).
|
Process a message directly (for CLI or cron usage).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The message content.
|
content: The message content.
|
||||||
session_key: Session identifier (overrides channel:chat_id for session lookup).
|
session_key: Session identifier (overrides channel:chat_id for session lookup).
|
||||||
channel: Source channel (for tool context routing).
|
channel: Source channel (for tool context routing).
|
||||||
chat_id: Source chat ID (for tool context routing).
|
chat_id: Source chat ID (for tool context routing).
|
||||||
on_progress: Optional callback for intermediate output.
|
on_progress: Optional callback for intermediate output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The agent's response.
|
The agent's response.
|
||||||
"""
|
"""
|
||||||
@@ -514,6 +527,6 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
content=content
|
content=content
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||||
return response.content if response else ""
|
return response.content if response else ""
|
||||||
|
|||||||
Reference in New Issue
Block a user