fix(loop): serialize /new consolidation and track task refs

This commit is contained in:
Alexander Minges
2026-02-20 12:38:43 +01:00
parent c8089021a5
commit 755e424127
2 changed files with 407 additions and 90 deletions

View File

@@ -21,7 +21,6 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.memory import MemoryStore
from nanobot.agent.subagent import SubagentManager
from nanobot.session.manager import Session, SessionManager
@@ -57,6 +56,7 @@ class AgentLoop:
):
from nanobot.config.schema import ExecToolConfig
from nanobot.cron.service import CronService
self.bus = bus
self.provider = provider
self.workspace = workspace
@@ -90,6 +90,8 @@ class AgentLoop:
self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Keep strong refs for in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._register_default_tools()
def _register_default_tools(self) -> None:
@@ -102,11 +104,13 @@ class AgentLoop:
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
# Shell tool
self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
))
self.tools.register(
ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
)
)
# Web tools
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
@@ -130,6 +134,7 @@ class AgentLoop:
return
self._mcp_connected = True
from nanobot.agent.tools.mcp import connect_mcp_servers
self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
@@ -158,11 +163,13 @@ class AgentLoop:
@staticmethod
def _tool_hint(tool_calls: list) -> str:
    """Format tool calls as concise hint, e.g. 'web_search("query")'."""

    def _describe(call) -> str:
        # Preview uses the first argument value, but only when it is a string.
        preview = next(iter(call.arguments.values()), None) if call.arguments else None
        if not isinstance(preview, str):
            return call.name
        if len(preview) > 40:
            preview = preview[:40]
        return f'{call.name}("{preview}")'

    return ", ".join(_describe(call) for call in tool_calls)
async def _run_agent_loop(
@@ -210,13 +217,15 @@ class AgentLoop:
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
}
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
for tc in response.tool_calls
]
messages = self.context.add_assistant_message(
messages, response.content, tool_call_dicts,
messages,
response.content,
tool_call_dicts,
reasoning_content=response.reasoning_content,
)
@@ -234,9 +243,13 @@ class AgentLoop:
# Give them one retry; don't forward the text to avoid duplicates.
if not tools_used and not text_only_retried and final_content:
text_only_retried = True
logger.debug("Interim text response (no tools used yet), retrying: {}", final_content[:80])
logger.debug(
"Interim text response (no tools used yet), retrying: {}",
final_content[:80],
)
messages = self.context.add_assistant_message(
messages, response.content,
messages,
response.content,
reasoning_content=response.reasoning_content,
)
final_content = None
@@ -253,21 +266,20 @@ class AgentLoop:
while self._running:
try:
msg = await asyncio.wait_for(
self.bus.consume_inbound(),
timeout=1.0
)
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
try:
response = await self._process_message(msg)
if response:
await self.bus.publish_outbound(response)
except Exception as e:
logger.error("Error processing message: {}", e)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}"
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}",
)
)
except asyncio.TimeoutError:
continue
@@ -285,6 +297,14 @@ class AgentLoop:
self._running = False
logger.info("Agent loop stopping")
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
    """Return a per-session lock for memory consolidation writers."""
    existing = self._consolidation_locks.get(session_key)
    if existing is not None:
        return existing
    # First writer for this session: create and cache a fresh lock.
    fresh = asyncio.Lock()
    self._consolidation_locks[session_key] = fresh
    return fresh
async def _process_message(
self,
msg: InboundMessage,
@@ -315,34 +335,53 @@ class AgentLoop:
# Handle slash commands
cmd = msg.content.strip().lower()
if cmd == "/new":
# Capture messages before clearing (avoid race condition with background task)
messages_to_archive = session.messages.copy()
lock = self._get_consolidation_lock(session.key)
try:
async with lock:
temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive
await self._consolidate_memory(temp_session, archive_all=True)
except Exception as e:
logger.error("/new archival failed for {}: {}", session.key, e)
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Could not start a new session because memory archival failed. Please try again.",
)
session.clear()
self.sessions.save(session)
self.sessions.invalidate(session.key)
async def _consolidate_and_cleanup():
temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive
await self._consolidate_memory(temp_session, archive_all=True)
asyncio.create_task(_consolidate_and_cleanup())
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.")
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.",
)
if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands",
)
if len(session.messages) > self.memory_window and session.key not in self._consolidating:
self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key)
async def _consolidate_and_unlock():
try:
await self._consolidate_memory(session)
async with lock:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
task = asyncio.current_task()
if task is not None:
self._consolidation_tasks.discard(task)
asyncio.create_task(_consolidate_and_unlock())
task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
initial_messages = self.context.build_messages(
@@ -354,13 +393,18 @@ class AgentLoop:
)
async def _bus_progress(content: str) -> None:
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
metadata=msg.metadata or {},
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=content,
metadata=msg.metadata or {},
)
)
final_content, tools_used = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
initial_messages,
on_progress=on_progress or _bus_progress,
)
if final_content is None:
@@ -370,15 +414,17 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
session.add_message("user", msg.content)
session.add_message("assistant", final_content,
tools_used=tools_used if tools_used else None)
session.add_message(
"assistant", final_content, tools_used=tools_used if tools_used else None
)
self.sessions.save(session)
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=final_content,
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
metadata=msg.metadata
or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
)
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
@@ -419,9 +465,7 @@ class AgentLoop:
self.sessions.save(session)
return OutboundMessage(
channel=origin_channel,
chat_id=origin_chat_id,
content=final_content
channel=origin_channel, chat_id=origin_chat_id, content=final_content
)
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
@@ -431,34 +475,54 @@ class AgentLoop:
archive_all: If True, clear all messages and reset session (for /new command).
If False, only write to files without modifying session.
"""
memory = MemoryStore(self.workspace)
memory = self.context.memory
if archive_all:
old_messages = session.messages
keep_count = 0
logger.info("Memory consolidation (archive_all): {} total messages archived", len(session.messages))
logger.info(
"Memory consolidation (archive_all): {} total messages archived",
len(session.messages),
)
else:
keep_count = self.memory_window // 2
if len(session.messages) <= keep_count:
logger.debug("Session {}: No consolidation needed (messages={}, keep={})", session.key, len(session.messages), keep_count)
logger.debug(
"Session {}: No consolidation needed (messages={}, keep={})",
session.key,
len(session.messages),
keep_count,
)
return
messages_to_process = len(session.messages) - session.last_consolidated
if messages_to_process <= 0:
logger.debug("Session {}: No new messages to consolidate (last_consolidated={}, total={})", session.key, session.last_consolidated, len(session.messages))
logger.debug(
"Session {}: No new messages to consolidate (last_consolidated={}, total={})",
session.key,
session.last_consolidated,
len(session.messages),
)
return
old_messages = session.messages[session.last_consolidated:-keep_count]
old_messages = session.messages[session.last_consolidated : -keep_count]
if not old_messages:
return
logger.info("Memory consolidation started: {} total, {} new to consolidate, {} keep", len(session.messages), len(old_messages), keep_count)
logger.info(
"Memory consolidation started: {} total, {} new to consolidate, {} keep",
len(session.messages),
len(old_messages),
keep_count,
)
lines = []
for m in old_messages:
if not m.get("content"):
continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
lines.append(
f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}"
)
conversation = "\n".join(lines)
current_memory = memory.read_long_term()
@@ -487,7 +551,10 @@ Respond with ONLY valid JSON, no markdown fences."""
try:
response = await self.provider.chat(
messages=[
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
{
"role": "system",
"content": "You are a memory consolidation agent. Respond only with valid JSON.",
},
{"role": "user", "content": prompt},
],
model=self.model,
@@ -500,7 +567,10 @@ Respond with ONLY valid JSON, no markdown fences."""
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text)
if not isinstance(result, dict):
logger.warning("Memory consolidation: unexpected response type, skipping. Response: {}", text[:200])
logger.warning(
"Memory consolidation: unexpected response type, skipping. Response: {}",
text[:200],
)
return
if entry := result.get("history_entry"):
@@ -519,7 +589,11 @@ Respond with ONLY valid JSON, no markdown fences."""
session.last_consolidated = 0
else:
session.last_consolidated = len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
logger.info(
"Memory consolidation done: {} messages, last_consolidated={}",
len(session.messages),
session.last_consolidated,
)
except Exception as e:
logger.error("Memory consolidation failed: {}", e)
@@ -545,12 +619,9 @@ Respond with ONLY valid JSON, no markdown fences."""
The agent's response.
"""
await self._connect_mcp()
msg = InboundMessage(
channel=channel,
sender_id="user",
chat_id=chat_id,
content=content
)
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
response = await self._process_message(
msg, session_key=session_key, on_progress=on_progress
)
return response.content if response else ""

View File

@@ -1,5 +1,8 @@
"""Test session management with cache-friendly message handling."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from pathlib import Path
from nanobot.session.manager import Session, SessionManager
@@ -475,3 +478,246 @@ class TestEmptyAndBoundarySessions:
expected_count = 60 - KEEP_COUNT - 10
assert len(old_messages) == expected_count
assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
    """Test that consolidation tasks are deduplicated and serialized."""

    @staticmethod
    def _make_loop(tmp_path: Path):
        """Build an AgentLoop whose mocked provider always answers 'ok' with no tools."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=MessageBus(),
            provider=provider,
            workspace=tmp_path,
            model="test-model",
            memory_window=10,
        )
        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])
        return loop

    @staticmethod
    def _seed_session(loop, turns: int):
        """Fill the 'cli:test' session with `turns` user/assistant pairs and save it."""
        session = loop.sessions.get_or_create("cli:test")
        for i in range(turns):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)
        return session

    @pytest.mark.asyncio
    async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
        """Concurrent messages above memory_window spawn only one consolidation task."""
        from nanobot.bus.events import InboundMessage

        loop = self._make_loop(tmp_path)
        self._seed_session(loop, 15)  # 30 messages: well past memory_window=10

        consolidation_calls = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal consolidation_calls
            consolidation_calls += 1
            await asyncio.sleep(0.05)

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        # Two back-to-back messages while the first consolidation is still sleeping.
        await loop._process_message(msg)
        await loop._process_message(msg)
        await asyncio.sleep(0.1)

        assert consolidation_calls == 1, (
            f"Expected exactly 1 consolidation, got {consolidation_calls}"
        )

    @pytest.mark.asyncio
    async def test_new_command_guard_prevents_concurrent_consolidation(
        self, tmp_path: Path
    ) -> None:
        """/new command does not run consolidation concurrently with in-flight consolidation."""
        from nanobot.bus.events import InboundMessage

        loop = self._make_loop(tmp_path)
        self._seed_session(loop, 15)

        consolidation_calls = 0
        active = 0
        max_active = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal consolidation_calls, active, max_active
            consolidation_calls += 1
            active += 1
            max_active = max(max_active, active)
            await asyncio.sleep(0.05)
            active -= 1

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await loop._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await loop._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        )
        await asyncio.sleep(0.1)

        assert consolidation_calls == 2, (
            f"Expected normal + /new consolidations, got {consolidation_calls}"
        )
        assert max_active == 1, (
            f"Expected serialized consolidation, observed concurrency={max_active}"
        )

    @pytest.mark.asyncio
    async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
        """create_task results are tracked in _consolidation_tasks while in flight."""
        from nanobot.bus.events import InboundMessage

        loop = self._make_loop(tmp_path)
        self._seed_session(loop, 15)

        started = asyncio.Event()

        async def _slow_consolidate(_session, archive_all: bool = False) -> None:
            started.set()
            await asyncio.sleep(0.1)

        loop._consolidate_memory = _slow_consolidate  # type: ignore[method-assign]

        await loop._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await started.wait()
        assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
        await asyncio.sleep(0.15)
        assert len(loop._consolidation_tasks) == 0, (
            "Task reference must be removed after completion"
        )

    @pytest.mark.asyncio
    async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
        self, tmp_path: Path
    ) -> None:
        """/new waits for in-flight consolidation and archives before clear."""
        from nanobot.bus.events import InboundMessage

        loop = self._make_loop(tmp_path)
        self._seed_session(loop, 15)

        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = 0

        async def _fake_consolidate(sess, archive_all: bool = False) -> None:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return
            started.set()
            await release.wait()  # hold the per-session lock until the test releases it

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await loop._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await started.wait()

        pending_new = asyncio.create_task(
            loop._process_message(
                InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
            )
        )
        await asyncio.sleep(0.02)
        assert not pending_new.done(), "/new should wait while consolidation is in-flight"

        release.set()
        response = await pending_new
        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
        session_after = loop.sessions.get_or_create("cli:test")
        assert session_after.messages == [], "Session should be cleared after successful archival"

    @pytest.mark.asyncio
    async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
        """/new keeps session data if archive step fails."""
        from nanobot.bus.events import InboundMessage

        loop = self._make_loop(tmp_path)
        session = self._seed_session(loop, 5)
        before_count = len(session.messages)

        async def _failing_consolidate(_session, archive_all: bool = False) -> None:
            if archive_all:
                raise RuntimeError("forced archive failure")

        loop._consolidate_memory = _failing_consolidate  # type: ignore[method-assign]

        response = await loop._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        )
        assert response is not None
        assert "failed" in response.content.lower()
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == before_count, (
            "Session must remain intact when /new archival fails"
        )