fix(loop): serialize /new consolidation and track task refs
This commit is contained in:
@@ -21,7 +21,6 @@ from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
|||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
from nanobot.agent.tools.spawn import SpawnTool
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.memory import MemoryStore
|
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
@@ -57,6 +56,7 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
@@ -90,6 +90,8 @@ class AgentLoop:
|
|||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||||
|
self._consolidation_tasks: set[asyncio.Task] = set() # Keep strong refs for in-flight tasks
|
||||||
|
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@@ -102,11 +104,13 @@ class AgentLoop:
|
|||||||
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
|
|
||||||
# Shell tool
|
# Shell tool
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(
|
||||||
|
ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Web tools
|
# Web tools
|
||||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
@@ -130,6 +134,7 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._mcp_connected = True
|
self._mcp_connected = True
|
||||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||||
|
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stack = AsyncExitStack()
|
||||||
await self._mcp_stack.__aenter__()
|
await self._mcp_stack.__aenter__()
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||||
@@ -158,11 +163,13 @@ class AgentLoop:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _tool_hint(tool_calls: list) -> str:
|
def _tool_hint(tool_calls: list) -> str:
|
||||||
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
||||||
|
|
||||||
def _fmt(tc):
|
def _fmt(tc):
|
||||||
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
||||||
if not isinstance(val, str):
|
if not isinstance(val, str):
|
||||||
return tc.name
|
return tc.name
|
||||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||||
|
|
||||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||||
|
|
||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
@@ -210,13 +217,15 @@ class AgentLoop:
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc.name,
|
"name": tc.name,
|
||||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
]
|
]
|
||||||
messages = self.context.add_assistant_message(
|
messages = self.context.add_assistant_message(
|
||||||
messages, response.content, tool_call_dicts,
|
messages,
|
||||||
|
response.content,
|
||||||
|
tool_call_dicts,
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -234,9 +243,13 @@ class AgentLoop:
|
|||||||
# Give them one retry; don't forward the text to avoid duplicates.
|
# Give them one retry; don't forward the text to avoid duplicates.
|
||||||
if not tools_used and not text_only_retried and final_content:
|
if not tools_used and not text_only_retried and final_content:
|
||||||
text_only_retried = True
|
text_only_retried = True
|
||||||
logger.debug("Interim text response (no tools used yet), retrying: {}", final_content[:80])
|
logger.debug(
|
||||||
|
"Interim text response (no tools used yet), retrying: {}",
|
||||||
|
final_content[:80],
|
||||||
|
)
|
||||||
messages = self.context.add_assistant_message(
|
messages = self.context.add_assistant_message(
|
||||||
messages, response.content,
|
messages,
|
||||||
|
response.content,
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
)
|
)
|
||||||
final_content = None
|
final_content = None
|
||||||
@@ -253,21 +266,20 @@ class AgentLoop:
|
|||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(
|
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||||
self.bus.consume_inbound(),
|
|
||||||
timeout=1.0
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
response = await self._process_message(msg)
|
response = await self._process_message(msg)
|
||||||
if response:
|
if response:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error processing message: {}", e)
|
logger.error("Error processing message: {}", e)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
|
OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=f"Sorry, I encountered an error: {str(e)}"
|
content=f"Sorry, I encountered an error: {str(e)}",
|
||||||
))
|
)
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -285,6 +297,14 @@ class AgentLoop:
|
|||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
|
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
|
"""Return a per-session lock for memory consolidation writers."""
|
||||||
|
lock = self._consolidation_locks.get(session_key)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._consolidation_locks[session_key] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
@@ -315,34 +335,53 @@ class AgentLoop:
|
|||||||
# Handle slash commands
|
# Handle slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
# Capture messages before clearing (avoid race condition with background task)
|
|
||||||
messages_to_archive = session.messages.copy()
|
messages_to_archive = session.messages.copy()
|
||||||
session.clear()
|
lock = self._get_consolidation_lock(session.key)
|
||||||
self.sessions.save(session)
|
|
||||||
self.sessions.invalidate(session.key)
|
|
||||||
|
|
||||||
async def _consolidate_and_cleanup():
|
try:
|
||||||
|
async with lock:
|
||||||
temp_session = Session(key=session.key)
|
temp_session = Session(key=session.key)
|
||||||
temp_session.messages = messages_to_archive
|
temp_session.messages = messages_to_archive
|
||||||
await self._consolidate_memory(temp_session, archive_all=True)
|
await self._consolidate_memory(temp_session, archive_all=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("/new archival failed for {}: {}", session.key, e)
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content="Could not start a new session because memory archival failed. Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
asyncio.create_task(_consolidate_and_cleanup())
|
session.clear()
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
self.sessions.save(session)
|
||||||
content="New session started. Memory consolidation in progress.")
|
self.sessions.invalidate(session.key)
|
||||||
|
return OutboundMessage(
|
||||||
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content="New session started. Memory consolidation in progress.",
|
||||||
|
)
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands",
|
||||||
|
)
|
||||||
|
|
||||||
if len(session.messages) > self.memory_window and session.key not in self._consolidating:
|
if len(session.messages) > self.memory_window and session.key not in self._consolidating:
|
||||||
self._consolidating.add(session.key)
|
self._consolidating.add(session.key)
|
||||||
|
lock = self._get_consolidation_lock(session.key)
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
async def _consolidate_and_unlock():
|
||||||
try:
|
try:
|
||||||
|
async with lock:
|
||||||
await self._consolidate_memory(session)
|
await self._consolidate_memory(session)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
|
task = asyncio.current_task()
|
||||||
|
if task is not None:
|
||||||
|
self._consolidation_tasks.discard(task)
|
||||||
|
|
||||||
asyncio.create_task(_consolidate_and_unlock())
|
task = asyncio.create_task(_consolidate_and_unlock())
|
||||||
|
self._consolidation_tasks.add(task)
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
@@ -354,13 +393,18 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(content: str) -> None:
|
async def _bus_progress(content: str) -> None:
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
OutboundMessage(
|
||||||
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=content,
|
||||||
metadata=msg.metadata or {},
|
metadata=msg.metadata or {},
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
final_content, tools_used = await self._run_agent_loop(
|
final_content, tools_used = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages,
|
||||||
|
on_progress=on_progress or _bus_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
@@ -370,15 +414,17 @@ class AgentLoop:
|
|||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
session.add_message("user", msg.content)
|
session.add_message("user", msg.content)
|
||||||
session.add_message("assistant", final_content,
|
session.add_message(
|
||||||
tools_used=tools_used if tools_used else None)
|
"assistant", final_content, tools_used=tools_used if tools_used else None
|
||||||
|
)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=final_content,
|
content=final_content,
|
||||||
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
metadata=msg.metadata
|
||||||
|
or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
||||||
@@ -419,9 +465,7 @@ class AgentLoop:
|
|||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=origin_channel,
|
channel=origin_channel, chat_id=origin_chat_id, content=final_content
|
||||||
chat_id=origin_chat_id,
|
|
||||||
content=final_content
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||||
@@ -431,34 +475,54 @@ class AgentLoop:
|
|||||||
archive_all: If True, clear all messages and reset session (for /new command).
|
archive_all: If True, clear all messages and reset session (for /new command).
|
||||||
If False, only write to files without modifying session.
|
If False, only write to files without modifying session.
|
||||||
"""
|
"""
|
||||||
memory = MemoryStore(self.workspace)
|
memory = self.context.memory
|
||||||
|
|
||||||
if archive_all:
|
if archive_all:
|
||||||
old_messages = session.messages
|
old_messages = session.messages
|
||||||
keep_count = 0
|
keep_count = 0
|
||||||
logger.info("Memory consolidation (archive_all): {} total messages archived", len(session.messages))
|
logger.info(
|
||||||
|
"Memory consolidation (archive_all): {} total messages archived",
|
||||||
|
len(session.messages),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
keep_count = self.memory_window // 2
|
keep_count = self.memory_window // 2
|
||||||
if len(session.messages) <= keep_count:
|
if len(session.messages) <= keep_count:
|
||||||
logger.debug("Session {}: No consolidation needed (messages={}, keep={})", session.key, len(session.messages), keep_count)
|
logger.debug(
|
||||||
|
"Session {}: No consolidation needed (messages={}, keep={})",
|
||||||
|
session.key,
|
||||||
|
len(session.messages),
|
||||||
|
keep_count,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
messages_to_process = len(session.messages) - session.last_consolidated
|
messages_to_process = len(session.messages) - session.last_consolidated
|
||||||
if messages_to_process <= 0:
|
if messages_to_process <= 0:
|
||||||
logger.debug("Session {}: No new messages to consolidate (last_consolidated={}, total={})", session.key, session.last_consolidated, len(session.messages))
|
logger.debug(
|
||||||
|
"Session {}: No new messages to consolidate (last_consolidated={}, total={})",
|
||||||
|
session.key,
|
||||||
|
session.last_consolidated,
|
||||||
|
len(session.messages),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
old_messages = session.messages[session.last_consolidated : -keep_count]
|
old_messages = session.messages[session.last_consolidated : -keep_count]
|
||||||
if not old_messages:
|
if not old_messages:
|
||||||
return
|
return
|
||||||
logger.info("Memory consolidation started: {} total, {} new to consolidate, {} keep", len(session.messages), len(old_messages), keep_count)
|
logger.info(
|
||||||
|
"Memory consolidation started: {} total, {} new to consolidate, {} keep",
|
||||||
|
len(session.messages),
|
||||||
|
len(old_messages),
|
||||||
|
keep_count,
|
||||||
|
)
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
for m in old_messages:
|
for m in old_messages:
|
||||||
if not m.get("content"):
|
if not m.get("content"):
|
||||||
continue
|
continue
|
||||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
lines.append(
|
||||||
|
f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}"
|
||||||
|
)
|
||||||
conversation = "\n".join(lines)
|
conversation = "\n".join(lines)
|
||||||
current_memory = memory.read_long_term()
|
current_memory = memory.read_long_term()
|
||||||
|
|
||||||
@@ -487,7 +551,10 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
try:
|
try:
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Respond only with valid JSON."},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a memory consolidation agent. Respond only with valid JSON.",
|
||||||
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@@ -500,7 +567,10 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||||
result = json_repair.loads(text)
|
result = json_repair.loads(text)
|
||||||
if not isinstance(result, dict):
|
if not isinstance(result, dict):
|
||||||
logger.warning("Memory consolidation: unexpected response type, skipping. Response: {}", text[:200])
|
logger.warning(
|
||||||
|
"Memory consolidation: unexpected response type, skipping. Response: {}",
|
||||||
|
text[:200],
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if entry := result.get("history_entry"):
|
if entry := result.get("history_entry"):
|
||||||
@@ -519,7 +589,11 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
session.last_consolidated = 0
|
session.last_consolidated = 0
|
||||||
else:
|
else:
|
||||||
session.last_consolidated = len(session.messages) - keep_count
|
session.last_consolidated = len(session.messages) - keep_count
|
||||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
logger.info(
|
||||||
|
"Memory consolidation done: {} messages, last_consolidated={}",
|
||||||
|
len(session.messages),
|
||||||
|
session.last_consolidated,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Memory consolidation failed: {}", e)
|
logger.error("Memory consolidation failed: {}", e)
|
||||||
|
|
||||||
@@ -545,12 +619,9 @@ Respond with ONLY valid JSON, no markdown fences."""
|
|||||||
The agent's response.
|
The agent's response.
|
||||||
"""
|
"""
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(
|
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||||
channel=channel,
|
|
||||||
sender_id="user",
|
|
||||||
chat_id=chat_id,
|
|
||||||
content=content
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
response = await self._process_message(
|
||||||
|
msg, session_key=session_key, on_progress=on_progress
|
||||||
|
)
|
||||||
return response.content if response else ""
|
return response.content if response else ""
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Test session management with cache-friendly message handling."""
|
"""Test session management with cache-friendly message handling."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
@@ -475,3 +478,246 @@ class TestEmptyAndBoundarySessions:
|
|||||||
expected_count = 60 - KEEP_COUNT - 10
|
expected_count = 60 - KEEP_COUNT - 10
|
||||||
assert len(old_messages) == expected_count
|
assert len(old_messages) == expected_count
|
||||||
assert_messages_content(old_messages, 10, 34)
|
assert_messages_content(old_messages, 10, 34)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsolidationDeduplicationGuard:
|
||||||
|
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||||
|
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls
|
||||||
|
consolidation_calls += 1
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 1, (
|
||||||
|
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
consolidation_calls = 0
|
||||||
|
active = 0
|
||||||
|
max_active = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
nonlocal consolidation_calls, active, max_active
|
||||||
|
consolidation_calls += 1
|
||||||
|
active += 1
|
||||||
|
max_active = max(max_active, active)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
active -= 1
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
await loop._process_message(new_msg)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
assert consolidation_calls == 2, (
|
||||||
|
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||||
|
)
|
||||||
|
assert max_active == 1, (
|
||||||
|
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||||
|
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
started.set()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
await started.wait()
|
||||||
|
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||||
|
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
assert len(loop._consolidation_tasks) == 0, (
|
||||||
|
"Task reference must be removed after completion"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""/new waits for in-flight consolidation and archives before clear."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(15):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
archived_count = 0
|
||||||
|
|
||||||
|
async def _fake_consolidate(sess, archive_all: bool = False) -> None:
|
||||||
|
nonlocal archived_count
|
||||||
|
if archive_all:
|
||||||
|
archived_count = len(sess.messages)
|
||||||
|
return
|
||||||
|
started.set()
|
||||||
|
await release.wait()
|
||||||
|
|
||||||
|
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||||
|
await loop._process_message(msg)
|
||||||
|
await started.wait()
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||||
|
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
response = await pending_new
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||||
|
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||||
|
"""/new keeps session data if archive step fails."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
before_count = len(session.messages)
|
||||||
|
|
||||||
|
async def _failing_consolidate(_session, archive_all: bool = False) -> None:
|
||||||
|
if archive_all:
|
||||||
|
raise RuntimeError("forced archive failure")
|
||||||
|
|
||||||
|
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "failed" in response.content.lower()
|
||||||
|
session_after = loop.sessions.get_or_create("cli:test")
|
||||||
|
assert len(session_after.messages) == before_count, (
|
||||||
|
"Session must remain intact when /new archival fails"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user