fix(loop): require successful archival before /new clear

This commit is contained in:
Alexander Minges
2026-02-20 12:48:54 +01:00
parent 5f9eca4664
commit 9ada8e6854
2 changed files with 151 additions and 91 deletions

View File

@@ -56,6 +56,7 @@ class AgentLoop:
): ):
from nanobot.config.schema import ExecToolConfig from nanobot.config.schema import ExecToolConfig
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
self.bus = bus self.bus = bus
self.provider = provider self.provider = provider
self.workspace = workspace self.workspace = workspace
@@ -83,7 +84,7 @@ class AgentLoop:
exec_config=self.exec_config, exec_config=self.exec_config,
restrict_to_workspace=restrict_to_workspace, restrict_to_workspace=restrict_to_workspace,
) )
self._running = False self._running = False
self._mcp_servers = mcp_servers or {} self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
@@ -92,7 +93,7 @@ class AgentLoop:
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {} self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
"""Register the default set of tools.""" """Register the default set of tools."""
# File tools (workspace for relative paths, restrict if configured) # File tools (workspace for relative paths, restrict if configured)
@@ -101,36 +102,39 @@ class AgentLoop:
self.tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
# Shell tool # Shell tool
self.tools.register(ExecTool( self.tools.register(
working_dir=str(self.workspace), ExecTool(
timeout=self.exec_config.timeout, working_dir=str(self.workspace),
restrict_to_workspace=self.restrict_to_workspace, timeout=self.exec_config.timeout,
)) restrict_to_workspace=self.restrict_to_workspace,
)
)
# Web tools # Web tools
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebSearchTool(api_key=self.brave_api_key))
self.tools.register(WebFetchTool()) self.tools.register(WebFetchTool())
# Message tool # Message tool
message_tool = MessageTool(send_callback=self.bus.publish_outbound) message_tool = MessageTool(send_callback=self.bus.publish_outbound)
self.tools.register(message_tool) self.tools.register(message_tool)
# Spawn tool (for subagents) # Spawn tool (for subagents)
spawn_tool = SpawnTool(manager=self.subagents) spawn_tool = SpawnTool(manager=self.subagents)
self.tools.register(spawn_tool) self.tools.register(spawn_tool)
# Cron tool (for scheduling) # Cron tool (for scheduling)
if self.cron_service: if self.cron_service:
self.tools.register(CronTool(self.cron_service)) self.tools.register(CronTool(self.cron_service))
async def _connect_mcp(self) -> None: async def _connect_mcp(self) -> None:
"""Connect to configured MCP servers (one-time, lazy).""" """Connect to configured MCP servers (one-time, lazy)."""
if self._mcp_connected or not self._mcp_servers: if self._mcp_connected or not self._mcp_servers:
return return
self._mcp_connected = True self._mcp_connected = True
from nanobot.agent.tools.mcp import connect_mcp_servers from nanobot.agent.tools.mcp import connect_mcp_servers
self._mcp_stack = AsyncExitStack() self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__() await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
@@ -159,11 +163,13 @@ class AgentLoop:
@staticmethod
def _tool_hint(tool_calls: list) -> str:
    """Format tool calls as a concise hint, e.g. 'web_search("query")'.

    Each call is rendered as ``name("<first argument>")`` when the first
    value in its ``arguments`` mapping is a string, and as the bare tool
    name otherwise (no arguments, or a non-string first value).

    Args:
        tool_calls: Objects exposing ``name`` and ``arguments`` attributes.

    Returns:
        The rendered calls joined by ", "; an empty string for no calls.
    """
    max_len = 40  # Keep hints short: show at most this many characters.

    def _fmt(tc) -> str:
        # Only the first argument value is shown; the rest are omitted.
        first = next(iter(tc.arguments.values()), None) if tc.arguments else None
        if not isinstance(first, str):
            return tc.name
        # Mark cut-off values so a truncated argument is not mistaken for
        # the complete one (same "..." suffix as the log previews use).
        shown = first if len(first) <= max_len else first[:max_len] + "..."
        return f'{tc.name}("{shown}")'

    return ", ".join(_fmt(tc) for tc in tool_calls)
async def _run_agent_loop( async def _run_agent_loop(
@@ -211,13 +217,15 @@ class AgentLoop:
"type": "function", "type": "function",
"function": { "function": {
"name": tc.name, "name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False) "arguments": json.dumps(tc.arguments, ensure_ascii=False),
} },
} }
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, tool_call_dicts, messages,
response.content,
tool_call_dicts,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
) )
@@ -235,9 +243,13 @@ class AgentLoop:
# Give them one retry; don't forward the text to avoid duplicates. # Give them one retry; don't forward the text to avoid duplicates.
if not tools_used and not text_only_retried and final_content: if not tools_used and not text_only_retried and final_content:
text_only_retried = True text_only_retried = True
logger.debug("Interim text response (no tools used yet), retrying: {}", final_content[:80]) logger.debug(
"Interim text response (no tools used yet), retrying: {}",
final_content[:80],
)
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, messages,
response.content,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
) )
final_content = None final_content = None
@@ -254,24 +266,23 @@ class AgentLoop:
while self._running: while self._running:
try: try:
msg = await asyncio.wait_for( msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
self.bus.consume_inbound(),
timeout=1.0
)
try: try:
response = await self._process_message(msg) response = await self._process_message(msg)
if response: if response:
await self.bus.publish_outbound(response) await self.bus.publish_outbound(response)
except Exception as e: except Exception as e:
logger.error("Error processing message: {}", e) logger.error("Error processing message: {}", e)
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, OutboundMessage(
chat_id=msg.chat_id, channel=msg.channel,
content=f"Sorry, I encountered an error: {str(e)}" chat_id=msg.chat_id,
)) content=f"Sorry, I encountered an error: {str(e)}",
)
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Close MCP connections.""" """Close MCP connections."""
if self._mcp_stack: if self._mcp_stack:
@@ -292,7 +303,7 @@ class AgentLoop:
lock = asyncio.Lock() lock = asyncio.Lock()
self._consolidation_locks[session_key] = lock self._consolidation_locks[session_key] = lock
return lock return lock
async def _process_message( async def _process_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
@@ -301,25 +312,25 @@ class AgentLoop:
) -> OutboundMessage | None: ) -> OutboundMessage | None:
""" """
Process a single inbound message. Process a single inbound message.
Args: Args:
msg: The inbound message to process. msg: The inbound message to process.
session_key: Override session key (used by process_direct). session_key: Override session key (used by process_direct).
on_progress: Optional callback for intermediate output (defaults to bus publish). on_progress: Optional callback for intermediate output (defaults to bus publish).
Returns: Returns:
The response message, or None if no response needed. The response message, or None if no response needed.
""" """
# System messages route back via chat_id ("channel:chat_id") # System messages route back via chat_id ("channel:chat_id")
if msg.channel == "system": if msg.channel == "system":
return await self._process_system_message(msg) return await self._process_system_message(msg)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
key = session_key or msg.session_key key = session_key or msg.session_key
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
# Handle slash commands # Handle slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
@@ -330,24 +341,37 @@ class AgentLoop:
async with lock: async with lock:
temp_session = Session(key=session.key) temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive temp_session.messages = messages_to_archive
await self._consolidate_memory(temp_session, archive_all=True) archived = await self._consolidate_memory(temp_session, archive_all=True)
except Exception as e: except Exception as e:
logger.error("/new archival failed for {}: {}", session.key, e) logger.error("/new archival failed for {}: {}", session.key, e)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content="Could not start a new session because memory archival failed. Please try again." content="Could not start a new session because memory archival failed. Please try again.",
)
if messages_to_archive and not archived:
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Could not start a new session because memory archival failed. Please try again.",
) )
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
self.sessions.invalidate(session.key) self.sessions.invalidate(session.key)
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(
content="New session started. Memory consolidation in progress.") channel=msg.channel,
chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.",
)
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands") channel=msg.channel,
chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands",
)
if len(session.messages) > self.memory_window and session.key not in self._consolidating: if len(session.messages) > self.memory_window and session.key not in self._consolidating:
self._consolidating.add(session.key) self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key) lock = self._get_consolidation_lock(session.key)
@@ -375,42 +399,49 @@ class AgentLoop:
) )
async def _bus_progress(content: str) -> None: async def _bus_progress(content: str) -> None:
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, content=content, OutboundMessage(
metadata=msg.metadata or {}, channel=msg.channel,
)) chat_id=msg.chat_id,
content=content,
metadata=msg.metadata or {},
)
)
final_content, tools_used = await self._run_agent_loop( final_content, tools_used = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress, initial_messages,
on_progress=on_progress or _bus_progress,
) )
if final_content is None: if final_content is None:
final_content = "I've completed processing but have no response to give." final_content = "I've completed processing but have no response to give."
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
session.add_message("user", msg.content) session.add_message("user", msg.content)
session.add_message("assistant", final_content, session.add_message(
tools_used=tools_used if tools_used else None) "assistant", final_content, tools_used=tools_used if tools_used else None
)
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content=final_content, content=final_content,
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts) metadata=msg.metadata
or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
) )
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None: async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
""" """
Process a system message (e.g., subagent announce). Process a system message (e.g., subagent announce).
The chat_id field contains "original_channel:original_chat_id" to route The chat_id field contains "original_channel:original_chat_id" to route
the response back to the correct destination. the response back to the correct destination.
""" """
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
# Parse origin from chat_id (format: "channel:chat_id") # Parse origin from chat_id (format: "channel:chat_id")
if ":" in msg.chat_id: if ":" in msg.chat_id:
parts = msg.chat_id.split(":", 1) parts = msg.chat_id.split(":", 1)
@@ -420,7 +451,7 @@ class AgentLoop:
# Fallback # Fallback
origin_channel = "cli" origin_channel = "cli"
origin_chat_id = msg.chat_id origin_chat_id = msg.chat_id
session_key = f"{origin_channel}:{origin_chat_id}" session_key = f"{origin_channel}:{origin_chat_id}"
session = self.sessions.get_or_create(session_key) session = self.sessions.get_or_create(session_key)
self._set_tool_context(origin_channel, origin_chat_id, msg.metadata.get("message_id")) self._set_tool_context(origin_channel, origin_chat_id, msg.metadata.get("message_id"))
@@ -434,18 +465,16 @@ class AgentLoop:
if final_content is None: if final_content is None:
final_content = "Background task completed." final_content = "Background task completed."
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}") session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
session.add_message("assistant", final_content) session.add_message("assistant", final_content)
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=origin_channel, channel=origin_channel, chat_id=origin_chat_id, content=final_content
chat_id=origin_chat_id,
content=final_content
) )
async def _consolidate_memory(self, session, archive_all: bool = False) -> None: async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md. """Consolidate old messages into MEMORY.md + HISTORY.md.
Args: Args:
@@ -457,29 +486,49 @@ class AgentLoop:
if archive_all: if archive_all:
old_messages = session.messages old_messages = session.messages
keep_count = 0 keep_count = 0
logger.info("Memory consolidation (archive_all): {} total messages archived", len(session.messages)) logger.info(
"Memory consolidation (archive_all): {} total messages archived",
len(session.messages),
)
else: else:
keep_count = self.memory_window // 2 keep_count = self.memory_window // 2
if len(session.messages) <= keep_count: if len(session.messages) <= keep_count:
logger.debug("Session {}: No consolidation needed (messages={}, keep={})", session.key, len(session.messages), keep_count) logger.debug(
return "Session {}: No consolidation needed (messages={}, keep={})",
session.key,
len(session.messages),
keep_count,
)
return True
messages_to_process = len(session.messages) - session.last_consolidated messages_to_process = len(session.messages) - session.last_consolidated
if messages_to_process <= 0: if messages_to_process <= 0:
logger.debug("Session {}: No new messages to consolidate (last_consolidated={}, total={})", session.key, session.last_consolidated, len(session.messages)) logger.debug(
return "Session {}: No new messages to consolidate (last_consolidated={}, total={})",
session.key,
session.last_consolidated,
len(session.messages),
)
return True
old_messages = session.messages[session.last_consolidated:-keep_count] old_messages = session.messages[session.last_consolidated : -keep_count]
if not old_messages: if not old_messages:
return return True
logger.info("Memory consolidation started: {} total, {} new to consolidate, {} keep", len(session.messages), len(old_messages), keep_count) logger.info(
"Memory consolidation started: {} total, {} new to consolidate, {} keep",
len(session.messages),
len(old_messages),
keep_count,
)
lines = [] lines = []
for m in old_messages: for m in old_messages:
if not m.get("content"): if not m.get("content"):
continue continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") lines.append(
f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}"
)
conversation = "\n".join(lines) conversation = "\n".join(lines)
current_memory = memory.read_long_term() current_memory = memory.read_long_term()
@@ -508,7 +557,10 @@ Respond with ONLY valid JSON, no markdown fences."""
try: try:
response = await self.provider.chat( response = await self.provider.chat(
messages=[ messages=[
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."}, {
"role": "system",
"content": "You are a memory consolidation agent. Respond only with valid JSON.",
},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
], ],
model=self.model, model=self.model,
@@ -516,13 +568,16 @@ Respond with ONLY valid JSON, no markdown fences."""
text = (response.content or "").strip() text = (response.content or "").strip()
if not text: if not text:
logger.warning("Memory consolidation: LLM returned empty response, skipping") logger.warning("Memory consolidation: LLM returned empty response, skipping")
return return False
if text.startswith("```"): if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip() text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text) result = json_repair.loads(text)
if not isinstance(result, dict): if not isinstance(result, dict):
logger.warning("Memory consolidation: unexpected response type, skipping. Response: {}", text[:200]) logger.warning(
return "Memory consolidation: unexpected response type, skipping. Response: {}",
text[:200],
)
return False
if entry := result.get("history_entry"): if entry := result.get("history_entry"):
# Defensive: ensure entry is a string (LLM may return dict) # Defensive: ensure entry is a string (LLM may return dict)
@@ -540,9 +595,15 @@ Respond with ONLY valid JSON, no markdown fences."""
session.last_consolidated = 0 session.last_consolidated = 0
else: else:
session.last_consolidated = len(session.messages) - keep_count session.last_consolidated = len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) logger.info(
"Memory consolidation done: {} messages, last_consolidated={}",
len(session.messages),
session.last_consolidated,
)
return True
except Exception as e: except Exception as e:
logger.error("Memory consolidation failed: {}", e) logger.error("Memory consolidation failed: {}", e)
return False
async def process_direct( async def process_direct(
self, self,
@@ -554,24 +615,21 @@ Respond with ONLY valid JSON, no markdown fences."""
) -> str: ) -> str:
""" """
Process a message directly (for CLI or cron usage). Process a message directly (for CLI or cron usage).
Args: Args:
content: The message content. content: The message content.
session_key: Session identifier (overrides channel:chat_id for session lookup). session_key: Session identifier (overrides channel:chat_id for session lookup).
channel: Source channel (for tool context routing). channel: Source channel (for tool context routing).
chat_id: Source chat ID (for tool context routing). chat_id: Source chat ID (for tool context routing).
on_progress: Optional callback for intermediate output. on_progress: Optional callback for intermediate output.
Returns: Returns:
The agent's response. The agent's response.
""" """
await self._connect_mcp() await self._connect_mcp()
msg = InboundMessage( msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
channel=channel,
sender_id="user", response = await self._process_message(
chat_id=chat_id, msg, session_key=session_key, on_progress=on_progress
content=content
) )
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
return response.content if response else "" return response.content if response else ""

View File

@@ -652,13 +652,14 @@ class TestConsolidationDeduplicationGuard:
release = asyncio.Event() release = asyncio.Event()
archived_count = 0 archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> None: async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count nonlocal archived_count
if archive_all: if archive_all:
archived_count = len(sess.messages) archived_count = len(sess.messages)
return return True
started.set() started.set()
await release.wait() await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
@@ -683,7 +684,7 @@ class TestConsolidationDeduplicationGuard:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new keeps session data if archive step fails.""" """/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@@ -706,9 +707,10 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session) loop.sessions.save(session)
before_count = len(session.messages) before_count = len(session.messages)
async def _failing_consolidate(_session, archive_all: bool = False) -> None: async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
if archive_all: if archive_all:
raise RuntimeError("forced archive failure") return False
return True
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign] loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]