fix(loop): serialize /new consolidation and track task refs

This commit is contained in:
Alexander Minges
2026-02-20 12:38:43 +01:00
parent c8089021a5
commit 755e424127
2 changed files with 407 additions and 90 deletions

View File

@@ -21,7 +21,6 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.spawn import SpawnTool from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.memory import MemoryStore
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
@@ -57,6 +56,7 @@ class AgentLoop:
): ):
from nanobot.config.schema import ExecToolConfig from nanobot.config.schema import ExecToolConfig
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
self.bus = bus self.bus = bus
self.provider = provider self.provider = provider
self.workspace = workspace self.workspace = workspace
@@ -90,6 +90,8 @@ class AgentLoop:
self._mcp_stack: AsyncExitStack | None = None self._mcp_stack: AsyncExitStack | None = None
self._mcp_connected = False self._mcp_connected = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Keep strong refs for in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._register_default_tools() self._register_default_tools()
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
@@ -102,11 +104,13 @@ class AgentLoop:
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
# Shell tool # Shell tool
self.tools.register(ExecTool( self.tools.register(
ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),
timeout=self.exec_config.timeout, timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace, restrict_to_workspace=self.restrict_to_workspace,
)) )
)
# Web tools # Web tools
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebSearchTool(api_key=self.brave_api_key))
@@ -130,6 +134,7 @@ class AgentLoop:
return return
self._mcp_connected = True self._mcp_connected = True
from nanobot.agent.tools.mcp import connect_mcp_servers from nanobot.agent.tools.mcp import connect_mcp_servers
self._mcp_stack = AsyncExitStack() self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__() await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack) await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
@@ -158,11 +163,13 @@ class AgentLoop:
@staticmethod @staticmethod
def _tool_hint(tool_calls: list) -> str: def _tool_hint(tool_calls: list) -> str:
"""Format tool calls as concise hint, e.g. 'web_search("query")'.""" """Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc): def _fmt(tc):
val = next(iter(tc.arguments.values()), None) if tc.arguments else None val = next(iter(tc.arguments.values()), None) if tc.arguments else None
if not isinstance(val, str): if not isinstance(val, str):
return tc.name return tc.name
return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")' return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")'
return ", ".join(_fmt(tc) for tc in tool_calls) return ", ".join(_fmt(tc) for tc in tool_calls)
async def _run_agent_loop( async def _run_agent_loop(
@@ -210,13 +217,15 @@ class AgentLoop:
"type": "function", "type": "function",
"function": { "function": {
"name": tc.name, "name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False) "arguments": json.dumps(tc.arguments, ensure_ascii=False),
} },
} }
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, tool_call_dicts, messages,
response.content,
tool_call_dicts,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
) )
@@ -234,9 +243,13 @@ class AgentLoop:
# Give them one retry; don't forward the text to avoid duplicates. # Give them one retry; don't forward the text to avoid duplicates.
if not tools_used and not text_only_retried and final_content: if not tools_used and not text_only_retried and final_content:
text_only_retried = True text_only_retried = True
logger.debug("Interim text response (no tools used yet), retrying: {}", final_content[:80]) logger.debug(
"Interim text response (no tools used yet), retrying: {}",
final_content[:80],
)
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(
messages, response.content, messages,
response.content,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
) )
final_content = None final_content = None
@@ -253,21 +266,20 @@ class AgentLoop:
while self._running: while self._running:
try: try:
msg = await asyncio.wait_for( msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
self.bus.consume_inbound(),
timeout=1.0
)
try: try:
response = await self._process_message(msg) response = await self._process_message(msg)
if response: if response:
await self.bus.publish_outbound(response) await self.bus.publish_outbound(response)
except Exception as e: except Exception as e:
logger.error("Error processing message: {}", e) logger.error("Error processing message: {}", e)
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}" content=f"Sorry, I encountered an error: {str(e)}",
)) )
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
@@ -285,6 +297,14 @@ class AgentLoop:
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
"""Return a per-session lock for memory consolidation writers."""
lock = self._consolidation_locks.get(session_key)
if lock is None:
lock = asyncio.Lock()
self._consolidation_locks[session_key] = lock
return lock
async def _process_message( async def _process_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
@@ -315,34 +335,53 @@ class AgentLoop:
# Handle slash commands # Handle slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
# Capture messages before clearing (avoid race condition with background task)
messages_to_archive = session.messages.copy() messages_to_archive = session.messages.copy()
session.clear() lock = self._get_consolidation_lock(session.key)
self.sessions.save(session)
self.sessions.invalidate(session.key)
async def _consolidate_and_cleanup(): try:
async with lock:
temp_session = Session(key=session.key) temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive temp_session.messages = messages_to_archive
await self._consolidate_memory(temp_session, archive_all=True) await self._consolidate_memory(temp_session, archive_all=True)
except Exception as e:
logger.error("/new archival failed for {}: {}", session.key, e)
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Could not start a new session because memory archival failed. Please try again.",
)
asyncio.create_task(_consolidate_and_cleanup()) session.clear()
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, self.sessions.save(session)
content="New session started. Memory consolidation in progress.") self.sessions.invalidate(session.key)
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.",
)
if cmd == "/help": if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands") channel=msg.channel,
chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands",
)
if len(session.messages) > self.memory_window and session.key not in self._consolidating: if len(session.messages) > self.memory_window and session.key not in self._consolidating:
self._consolidating.add(session.key) self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key)
async def _consolidate_and_unlock(): async def _consolidate_and_unlock():
try: try:
async with lock:
await self._consolidate_memory(session) await self._consolidate_memory(session)
finally: finally:
self._consolidating.discard(session.key) self._consolidating.discard(session.key)
task = asyncio.current_task()
if task is not None:
self._consolidation_tasks.discard(task)
asyncio.create_task(_consolidate_and_unlock()) task = asyncio.create_task(_consolidate_and_unlock())
self._consolidation_tasks.add(task)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
initial_messages = self.context.build_messages( initial_messages = self.context.build_messages(
@@ -354,13 +393,18 @@ class AgentLoop:
) )
async def _bus_progress(content: str) -> None: async def _bus_progress(content: str) -> None:
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(
channel=msg.channel, chat_id=msg.chat_id, content=content, OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=content,
metadata=msg.metadata or {}, metadata=msg.metadata or {},
)) )
)
final_content, tools_used = await self._run_agent_loop( final_content, tools_used = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress, initial_messages,
on_progress=on_progress or _bus_progress,
) )
if final_content is None: if final_content is None:
@@ -370,15 +414,17 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
session.add_message("user", msg.content) session.add_message("user", msg.content)
session.add_message("assistant", final_content, session.add_message(
tools_used=tools_used if tools_used else None) "assistant", final_content, tools_used=tools_used if tools_used else None
)
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
content=final_content, content=final_content,
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts) metadata=msg.metadata
or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
) )
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None: async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
@@ -419,9 +465,7 @@ class AgentLoop:
self.sessions.save(session) self.sessions.save(session)
return OutboundMessage( return OutboundMessage(
channel=origin_channel, channel=origin_channel, chat_id=origin_chat_id, content=final_content
chat_id=origin_chat_id,
content=final_content
) )
async def _consolidate_memory(self, session, archive_all: bool = False) -> None: async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
@@ -431,34 +475,54 @@ class AgentLoop:
archive_all: If True, clear all messages and reset session (for /new command). archive_all: If True, clear all messages and reset session (for /new command).
If False, only write to files without modifying session. If False, only write to files without modifying session.
""" """
memory = MemoryStore(self.workspace) memory = self.context.memory
if archive_all: if archive_all:
old_messages = session.messages old_messages = session.messages
keep_count = 0 keep_count = 0
logger.info("Memory consolidation (archive_all): {} total messages archived", len(session.messages)) logger.info(
"Memory consolidation (archive_all): {} total messages archived",
len(session.messages),
)
else: else:
keep_count = self.memory_window // 2 keep_count = self.memory_window // 2
if len(session.messages) <= keep_count: if len(session.messages) <= keep_count:
logger.debug("Session {}: No consolidation needed (messages={}, keep={})", session.key, len(session.messages), keep_count) logger.debug(
"Session {}: No consolidation needed (messages={}, keep={})",
session.key,
len(session.messages),
keep_count,
)
return return
messages_to_process = len(session.messages) - session.last_consolidated messages_to_process = len(session.messages) - session.last_consolidated
if messages_to_process <= 0: if messages_to_process <= 0:
logger.debug("Session {}: No new messages to consolidate (last_consolidated={}, total={})", session.key, session.last_consolidated, len(session.messages)) logger.debug(
"Session {}: No new messages to consolidate (last_consolidated={}, total={})",
session.key,
session.last_consolidated,
len(session.messages),
)
return return
old_messages = session.messages[session.last_consolidated : -keep_count] old_messages = session.messages[session.last_consolidated : -keep_count]
if not old_messages: if not old_messages:
return return
logger.info("Memory consolidation started: {} total, {} new to consolidate, {} keep", len(session.messages), len(old_messages), keep_count) logger.info(
"Memory consolidation started: {} total, {} new to consolidate, {} keep",
len(session.messages),
len(old_messages),
keep_count,
)
lines = [] lines = []
for m in old_messages: for m in old_messages:
if not m.get("content"): if not m.get("content"):
continue continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") lines.append(
f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}"
)
conversation = "\n".join(lines) conversation = "\n".join(lines)
current_memory = memory.read_long_term() current_memory = memory.read_long_term()
@@ -487,7 +551,10 @@ Respond with ONLY valid JSON, no markdown fences."""
try: try:
response = await self.provider.chat( response = await self.provider.chat(
messages=[ messages=[
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."}, {
"role": "system",
"content": "You are a memory consolidation agent. Respond only with valid JSON.",
},
{"role": "user", "content": prompt}, {"role": "user", "content": prompt},
], ],
model=self.model, model=self.model,
@@ -500,7 +567,10 @@ Respond with ONLY valid JSON, no markdown fences."""
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip() text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text) result = json_repair.loads(text)
if not isinstance(result, dict): if not isinstance(result, dict):
logger.warning("Memory consolidation: unexpected response type, skipping. Response: {}", text[:200]) logger.warning(
"Memory consolidation: unexpected response type, skipping. Response: {}",
text[:200],
)
return return
if entry := result.get("history_entry"): if entry := result.get("history_entry"):
@@ -519,7 +589,11 @@ Respond with ONLY valid JSON, no markdown fences."""
session.last_consolidated = 0 session.last_consolidated = 0
else: else:
session.last_consolidated = len(session.messages) - keep_count session.last_consolidated = len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) logger.info(
"Memory consolidation done: {} messages, last_consolidated={}",
len(session.messages),
session.last_consolidated,
)
except Exception as e: except Exception as e:
logger.error("Memory consolidation failed: {}", e) logger.error("Memory consolidation failed: {}", e)
@@ -545,12 +619,9 @@ Respond with ONLY valid JSON, no markdown fences."""
The agent's response. The agent's response.
""" """
await self._connect_mcp() await self._connect_mcp()
msg = InboundMessage( msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
channel=channel,
sender_id="user",
chat_id=chat_id,
content=content
)
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress) response = await self._process_message(
msg, session_key=session_key, on_progress=on_progress
)
return response.content if response else "" return response.content if response else ""

View File

@@ -1,5 +1,8 @@
"""Test session management with cache-friendly message handling.""" """Test session management with cache-friendly message handling."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from pathlib import Path from pathlib import Path
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
@@ -475,3 +478,246 @@ class TestEmptyAndBoundarySessions:
expected_count = 60 - KEEP_COUNT - 10 expected_count = 60 - KEEP_COUNT - 10
assert len(old_messages) == expected_count assert len(old_messages) == expected_count
assert_messages_content(old_messages, 10, 34) assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
    """Consolidation tasks must be deduplicated, strongly referenced, and serialized."""

    @staticmethod
    def _build_agent(tmp_path: Path, seed_pairs: int = 15):
        """Create an AgentLoop with a mocked LLM provider and a pre-seeded session.

        Seeds ``seed_pairs`` user/assistant message pairs into session
        ``cli:test`` so the memory-window threshold can be exercised.
        """
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        agent = AgentLoop(
            bus=MessageBus(),
            provider=provider,
            workspace=tmp_path,
            model="test-model",
            memory_window=10,
        )
        agent.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        agent.tools.get_definitions = MagicMock(return_value=[])
        session = agent.sessions.get_or_create("cli:test")
        for i in range(seed_pairs):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        agent.sessions.save(session)
        return agent

    @pytest.mark.asyncio
    async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
        """Concurrent messages above memory_window spawn only one consolidation task."""
        from nanobot.bus.events import InboundMessage

        agent = self._build_agent(tmp_path)
        call_count = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.05)

        agent._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        inbound = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await agent._process_message(inbound)
        await agent._process_message(inbound)
        await asyncio.sleep(0.1)

        assert call_count == 1, f"Expected exactly 1 consolidation, got {call_count}"

    @pytest.mark.asyncio
    async def test_new_command_guard_prevents_concurrent_consolidation(
        self, tmp_path: Path
    ) -> None:
        """/new command does not run consolidation concurrently with in-flight consolidation."""
        from nanobot.bus.events import InboundMessage

        agent = self._build_agent(tmp_path)
        call_count = 0
        in_flight = 0
        peak_in_flight = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal call_count, in_flight, peak_in_flight
            call_count += 1
            in_flight += 1
            peak_in_flight = max(peak_in_flight, in_flight)
            await asyncio.sleep(0.05)
            in_flight -= 1

        agent._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await agent._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await agent._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        )
        await asyncio.sleep(0.1)

        assert call_count == 2, f"Expected normal + /new consolidations, got {call_count}"
        assert peak_in_flight == 1, (
            f"Expected serialized consolidation, observed concurrency={peak_in_flight}"
        )

    @pytest.mark.asyncio
    async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
        """create_task results are tracked in _consolidation_tasks while in flight."""
        from nanobot.bus.events import InboundMessage

        agent = self._build_agent(tmp_path)
        started = asyncio.Event()

        async def _slow_consolidate(_session, archive_all: bool = False) -> None:
            started.set()
            await asyncio.sleep(0.1)

        agent._consolidate_memory = _slow_consolidate  # type: ignore[method-assign]

        await agent._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await started.wait()
        assert len(agent._consolidation_tasks) == 1, "Task must be referenced while in-flight"

        await asyncio.sleep(0.15)
        assert len(agent._consolidation_tasks) == 0, (
            "Task reference must be removed after completion"
        )

    @pytest.mark.asyncio
    async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
        self, tmp_path: Path
    ) -> None:
        """/new waits for in-flight consolidation and archives before clear."""
        from nanobot.bus.events import InboundMessage

        agent = self._build_agent(tmp_path)
        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = 0

        async def _fake_consolidate(sess, archive_all: bool = False) -> None:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return
            started.set()
            await release.wait()

        agent._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        await agent._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        )
        await started.wait()

        pending_new = asyncio.create_task(
            agent._process_message(
                InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
            )
        )
        await asyncio.sleep(0.02)
        assert not pending_new.done(), "/new should wait while consolidation is in-flight"

        release.set()
        response = await pending_new
        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
        session_after = agent.sessions.get_or_create("cli:test")
        assert session_after.messages == [], "Session should be cleared after successful archival"

    @pytest.mark.asyncio
    async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
        """/new keeps session data if archive step fails."""
        from nanobot.bus.events import InboundMessage

        agent = self._build_agent(tmp_path, seed_pairs=5)
        before_count = len(agent.sessions.get_or_create("cli:test").messages)

        async def _failing_consolidate(_session, archive_all: bool = False) -> None:
            if archive_all:
                raise RuntimeError("forced archive failure")

        agent._consolidate_memory = _failing_consolidate  # type: ignore[method-assign]

        response = await agent._process_message(
            InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        )
        assert response is not None
        assert "failed" in response.content.lower()

        session_after = agent.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == before_count, (
            "Session must remain intact when /new archival fails"
        )