perf: background post-response memory consolidation for faster replies
This commit is contained in:
@@ -100,7 +100,7 @@ class AgentLoop:
|
|||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._pending_archives: list[asyncio.Task] = []
|
self._background_tasks: list[asyncio.Task] = []
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
self.memory_consolidator = MemoryConsolidator(
|
self.memory_consolidator = MemoryConsolidator(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
@@ -257,8 +257,6 @@ class AgentLoop:
|
|||||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
self._running = True
|
self._running = True
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
# Start background consolidation task
|
|
||||||
await self.memory_consolidator.start_background_task()
|
|
||||||
logger.info("Agent loop started")
|
logger.info("Agent loop started")
|
||||||
|
|
||||||
while self._running:
|
while self._running:
|
||||||
@@ -334,9 +332,9 @@ class AgentLoop:
|
|||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Drain pending background archives, then close MCP connections."""
|
"""Drain pending background archives, then close MCP connections."""
|
||||||
if self._pending_archives:
|
if self._background_tasks:
|
||||||
await asyncio.gather(*self._pending_archives, return_exceptions=True)
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
self._pending_archives.clear()
|
self._background_tasks.clear()
|
||||||
if self._mcp_stack:
|
if self._mcp_stack:
|
||||||
try:
|
try:
|
||||||
await self._mcp_stack.aclose()
|
await self._mcp_stack.aclose()
|
||||||
@@ -344,11 +342,16 @@ class AgentLoop:
|
|||||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||||
self._mcp_stack = None
|
self._mcp_stack = None
|
||||||
|
|
||||||
async def stop(self) -> None:
|
def _schedule_background(self, coro) -> None:
|
||||||
"""Stop the agent loop and background tasks."""
|
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.append(task)
|
||||||
|
task.add_done_callback(self._background_tasks.remove)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the agent loop."""
|
||||||
self._running = False
|
self._running = False
|
||||||
await self.memory_consolidator.stop_background_task()
|
logger.info("Agent loop stopping")
|
||||||
logger.info("Agent loop stopped")
|
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
@@ -364,8 +367,7 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
self.memory_consolidator.record_activity(key)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens_async(session)
|
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
@@ -375,6 +377,7 @@ class AgentLoop:
|
|||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@@ -383,7 +386,6 @@ class AgentLoop:
|
|||||||
|
|
||||||
key = session_key or msg.session_key
|
key = session_key or msg.session_key
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
self.memory_consolidator.record_activity(key)
|
|
||||||
|
|
||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
@@ -394,11 +396,7 @@ class AgentLoop:
|
|||||||
self.sessions.invalidate(session.key)
|
self.sessions.invalidate(session.key)
|
||||||
|
|
||||||
if snapshot:
|
if snapshot:
|
||||||
task = asyncio.create_task(
|
self._schedule_background(self.memory_consolidator.archive_messages(snapshot))
|
||||||
self.memory_consolidator.archive_messages(snapshot)
|
|
||||||
)
|
|
||||||
self._pending_archives.append(task)
|
|
||||||
task.add_done_callback(self._pending_archives.remove)
|
|
||||||
|
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="New session started.")
|
content="New session started.")
|
||||||
@@ -413,8 +411,7 @@ class AgentLoop:
|
|||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||||
)
|
)
|
||||||
# Record activity and schedule background consolidation for non-slash commands
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self.memory_consolidator.record_activity(key)
|
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
@@ -446,6 +443,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -220,14 +220,9 @@ class MemoryStore:
|
|||||||
|
|
||||||
|
|
||||||
class MemoryConsolidator:
|
class MemoryConsolidator:
|
||||||
"""Owns consolidation policy, locking, and session offset updates.
|
"""Owns consolidation policy, locking, and session offset updates."""
|
||||||
|
|
||||||
Consolidation runs asynchronously in the background when sessions are idle,
|
|
||||||
so it doesn't block user interactions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||||
_IDLE_CHECK_INTERVAL = 30 # seconds between idle checks
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -247,57 +242,11 @@ class MemoryConsolidator:
|
|||||||
self._build_messages = build_messages
|
self._build_messages = build_messages
|
||||||
self._get_tool_definitions = get_tool_definitions
|
self._get_tool_definitions = get_tool_definitions
|
||||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||||
self._background_task: asyncio.Task[None] | None = None
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._session_last_activity: dict[str, float] = {} # session_key -> last activity timestamp
|
|
||||||
|
|
||||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
"""Return the shared consolidation lock for one session."""
|
"""Return the shared consolidation lock for one session."""
|
||||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
def record_activity(self, session_key: str) -> None:
|
|
||||||
"""Record that a session is active (for idle detection)."""
|
|
||||||
self._session_last_activity[session_key] = asyncio.get_event_loop().time()
|
|
||||||
|
|
||||||
async def start_background_task(self) -> None:
|
|
||||||
"""Start the background task that checks for idle sessions and consolidates."""
|
|
||||||
if self._background_task is not None and not self._background_task.done():
|
|
||||||
return # Already running
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._background_task = asyncio.create_task(self._idle_consolidation_loop())
|
|
||||||
|
|
||||||
async def stop_background_task(self) -> None:
|
|
||||||
"""Stop the background task."""
|
|
||||||
self._stop_event.set()
|
|
||||||
if self._background_task is not None and not self._background_task.done():
|
|
||||||
self._background_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._background_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._background_task = None
|
|
||||||
|
|
||||||
async def _idle_consolidation_loop(self) -> None:
|
|
||||||
"""Background loop that checks for idle sessions and triggers consolidation."""
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(self._IDLE_CHECK_INTERVAL)
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check all sessions for idleness
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
|
||||||
for session in list(self.sessions.all()):
|
|
||||||
last_active = self._session_last_activity.get(session.key, 0)
|
|
||||||
if current_time - last_active > self._IDLE_CHECK_INTERVAL * 2:
|
|
||||||
# Session is idle, trigger consolidation
|
|
||||||
await self.maybe_consolidate_by_tokens_async(session)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error in background consolidation loop")
|
|
||||||
|
|
||||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||||
"""Archive a selected message chunk into persistent memory."""
|
"""Archive a selected message chunk into persistent memory."""
|
||||||
return await self.store.consolidate(messages, self.provider, self.model)
|
return await self.store.consolidate(messages, self.provider, self.model)
|
||||||
@@ -350,26 +299,8 @@ class MemoryConsolidator:
|
|||||||
return True
|
return True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||||
"""Schedule token-based consolidation to run asynchronously in background.
|
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||||
|
|
||||||
This method is synchronous and just schedules the consolidation task.
|
|
||||||
The actual consolidation runs in the background when the session is idle.
|
|
||||||
"""
|
|
||||||
if not session.messages or self.context_window_tokens <= 0:
|
|
||||||
return
|
|
||||||
# Schedule for background execution
|
|
||||||
asyncio.create_task(self._schedule_consolidation(session))
|
|
||||||
|
|
||||||
async def _schedule_consolidation(self, session: Session) -> None:
|
|
||||||
"""Internal method to run consolidation asynchronously."""
|
|
||||||
await self.maybe_consolidate_by_tokens_async(session)
|
|
||||||
|
|
||||||
async def maybe_consolidate_by_tokens_async(self, session: Session) -> None:
|
|
||||||
"""Async version: Loop and archive old messages until prompt fits within half the context window.
|
|
||||||
|
|
||||||
This is called from the background task when a session is idle.
|
|
||||||
"""
|
|
||||||
if not session.messages or self.context_window_tokens <= 0:
|
if not session.messages or self.context_window_tokens <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -424,11 +355,3 @@ class MemoryConsolidator:
|
|||||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
if estimated <= 0:
|
if estimated <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Token consolidation complete for {}: {}/{} via {}",
|
|
||||||
session.key,
|
|
||||||
estimated,
|
|
||||||
self.context_window_tokens,
|
|
||||||
source,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
"""Pytest configuration for nanobot tests."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def enable_asyncio_auto_mode():
|
|
||||||
"""Auto-configure asyncio mode for all async tests."""
|
|
||||||
pass
|
|
||||||
@@ -1,411 +0,0 @@
|
|||||||
"""Test async memory consolidation background task.
|
|
||||||
|
|
||||||
Tests for the new async background consolidation feature where token-based
|
|
||||||
consolidation runs when sessions are idle instead of blocking user interactions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
from nanobot.agent.memory import MemoryConsolidator
|
|
||||||
from nanobot.bus.queue import MessageBus
|
|
||||||
from nanobot.providers.base import LLMResponse
|
|
||||||
|
|
||||||
|
|
||||||
class TestMemoryConsolidatorBackgroundTask:
|
|
||||||
"""Tests for the background consolidation task."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_start_and_stop_background_task(self, tmp_path) -> None:
|
|
||||||
"""Test that background task can be started and stopped cleanly."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start background task
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
assert consolidator._background_task is not None
|
|
||||||
assert not consolidator._stop_event.is_set()
|
|
||||||
|
|
||||||
# Stop background task
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
assert consolidator._background_task is None or consolidator._background_task.done()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_background_loop_checks_idle_sessions(self, tmp_path) -> None:
|
|
||||||
"""Test that background loop checks for idle sessions."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
session1 = MagicMock()
|
|
||||||
session1.key = "cli:session1"
|
|
||||||
session1.messages = [{"role": "user", "content": "msg"}]
|
|
||||||
session2 = MagicMock()
|
|
||||||
session2.key = "cli:session2"
|
|
||||||
session2.messages = []
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[session1, session2])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark session1 as recently active (should not consolidate)
|
|
||||||
consolidator._session_last_activity["cli:session1"] = asyncio.get_event_loop().time()
|
|
||||||
# Leave session2 without activity record (should be considered idle)
|
|
||||||
|
|
||||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
|
||||||
|
|
||||||
# Run the background loop with a very short interval for testing
|
|
||||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
|
||||||
# Start task and let it run briefly
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
|
|
||||||
# session2 should have been checked for consolidation (it's idle)
|
|
||||||
# session1 should not have been consolidated (recently active)
|
|
||||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_record_activity_updates_timestamp(self, tmp_path) -> None:
|
|
||||||
"""Test that record_activity updates the activity timestamp."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initially no activity recorded
|
|
||||||
assert "cli:test" not in consolidator._session_last_activity
|
|
||||||
|
|
||||||
# Record activity
|
|
||||||
consolidator.record_activity("cli:test")
|
|
||||||
assert "cli:test" in consolidator._session_last_activity
|
|
||||||
|
|
||||||
# Wait a bit and check timestamp changed
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
consolidator.record_activity("cli:test")
|
|
||||||
# The timestamp should have updated (though we can't easily verify the exact value)
|
|
||||||
assert consolidator._session_last_activity["cli:test"] > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_maybe_consolidate_by_tokens_schedules_async_task(self, tmp_path) -> None:
|
|
||||||
"""Test that maybe_consolidate_by_tokens schedules an async task."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
|
|
||||||
session = MagicMock()
|
|
||||||
session.messages = [{"role": "user", "content": "msg"}]
|
|
||||||
session.key = "cli:test"
|
|
||||||
session.context_window_tokens = 200
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[session])
|
|
||||||
sessions.save = MagicMock()
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the async version to track calls
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
|
||||||
|
|
||||||
# Call the synchronous method - should schedule a task
|
|
||||||
consolidator.maybe_consolidate_by_tokens(session)
|
|
||||||
|
|
||||||
# The async version should have been scheduled via create_task
|
|
||||||
await asyncio.sleep(0.1) # Let the task start
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentLoopIntegration:
|
|
||||||
"""Integration tests for AgentLoop with background consolidation."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_loop_starts_background_task(self, tmp_path) -> None:
|
|
||||||
"""Test that run() starts the background consolidation task."""
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus,
|
|
||||||
provider=provider,
|
|
||||||
workspace=tmp_path,
|
|
||||||
model="test-model",
|
|
||||||
context_window_tokens=200,
|
|
||||||
)
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
# Start the loop in background
|
|
||||||
import asyncio
|
|
||||||
run_task = asyncio.create_task(loop.run())
|
|
||||||
|
|
||||||
# Give it time to start the background task
|
|
||||||
await asyncio.sleep(0.3)
|
|
||||||
|
|
||||||
# Background task should be started
|
|
||||||
assert loop.memory_consolidator._background_task is not None
|
|
||||||
|
|
||||||
# Stop the loop
|
|
||||||
await loop.stop()
|
|
||||||
await run_task
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_loop_stops_background_task(self, tmp_path) -> None:
|
|
||||||
"""Test that stop() stops the background consolidation task."""
|
|
||||||
bus = MessageBus()
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
|
|
||||||
loop = AgentLoop(
|
|
||||||
bus=bus,
|
|
||||||
provider=provider,
|
|
||||||
workspace=tmp_path,
|
|
||||||
model="test-model",
|
|
||||||
context_window_tokens=200,
|
|
||||||
)
|
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
# Start the loop in background
|
|
||||||
run_task = asyncio.create_task(loop.run())
|
|
||||||
await asyncio.sleep(0.3)
|
|
||||||
|
|
||||||
# Stop via async stop method
|
|
||||||
await loop.stop()
|
|
||||||
|
|
||||||
# Background task should be stopped
|
|
||||||
assert loop.memory_consolidator._background_task is None or \
|
|
||||||
loop.memory_consolidator._background_task.done()
|
|
||||||
|
|
||||||
|
|
||||||
class TestIdleDetection:
|
|
||||||
"""Tests for idle session detection logic."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recently_active_session_not_considered_idle(self, tmp_path) -> None:
|
|
||||||
"""Test that recently active sessions are not consolidated."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
session = MagicMock()
|
|
||||||
session.key = "cli:active"
|
|
||||||
session.messages = [{"role": "user", "content": "msg"}]
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[session])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark as recently active (within idle threshold)
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
|
||||||
consolidator._session_last_activity["cli:active"] = current_time
|
|
||||||
|
|
||||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
# Sleep less than 2 * interval to ensure session remains active
|
|
||||||
await asyncio.sleep(0.15)
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
|
|
||||||
# Should not have been called for recently active session
|
|
||||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_idle_session_triggers_consolidation(self, tmp_path) -> None:
|
|
||||||
"""Test that idle sessions trigger consolidation."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
session = MagicMock()
|
|
||||||
session.key = "cli:idle"
|
|
||||||
session.messages = [{"role": "user", "content": "msg"}]
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[session])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark as inactive (older than idle threshold)
|
|
||||||
current_time = asyncio.get_event_loop().time()
|
|
||||||
consolidator._session_last_activity["cli:idle"] = current_time - 10 # 10 seconds ago
|
|
||||||
|
|
||||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
|
||||||
|
|
||||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
|
|
||||||
# Should have been called for idle session
|
|
||||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestScheduleConsolidation:
|
|
||||||
"""Tests for the schedule consolidation mechanism."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_schedule_consolidation_runs_async_version(self, tmp_path) -> None:
|
|
||||||
"""Test that scheduling runs the async version."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
session = MagicMock()
|
|
||||||
session.messages = [{"role": "user", "content": "msg"}]
|
|
||||||
session.key = "cli:scheduled"
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[session])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the async version to track calls
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
|
||||||
|
|
||||||
# Schedule consolidation
|
|
||||||
await consolidator._schedule_consolidation(session)
|
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestBackgroundTaskCancellation:
|
|
||||||
"""Tests for background task cancellation and error handling."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_background_task_handles_exceptions_gracefully(self, tmp_path) -> None:
|
|
||||||
"""Test that exceptions in the loop don't crash it."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock maybe_consolidate_by_tokens_async to raise an exception
|
|
||||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock( # type: ignore[method-assign]
|
|
||||||
side_effect=Exception("Test exception")
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
# Task should still be running despite exceptions
|
|
||||||
assert consolidator._background_task is not None
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stop_cancels_running_task(self, tmp_path) -> None:
|
|
||||||
"""Test that stop properly cancels a running task."""
|
|
||||||
provider = MagicMock()
|
|
||||||
provider.get_default_model.return_value = "test-model"
|
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
|
||||||
|
|
||||||
sessions = MagicMock()
|
|
||||||
sessions.all = MagicMock(return_value=[])
|
|
||||||
|
|
||||||
consolidator = MemoryConsolidator(
|
|
||||||
workspace=tmp_path,
|
|
||||||
provider=provider,
|
|
||||||
model="test-model",
|
|
||||||
sessions=sessions,
|
|
||||||
context_window_tokens=200,
|
|
||||||
build_messages=lambda **kw: [],
|
|
||||||
get_tool_definitions=lambda: [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start a task that will sleep for a while
|
|
||||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 10): # Long interval
|
|
||||||
await consolidator.start_background_task()
|
|
||||||
# Task should be running
|
|
||||||
assert consolidator._background_task is not None
|
|
||||||
|
|
||||||
# Stop should cancel it
|
|
||||||
await consolidator.stop_background_task()
|
|
||||||
|
|
||||||
# Verify task was cancelled or completed
|
|
||||||
assert consolidator._background_task is None or \
|
|
||||||
consolidator._background_task.done()
|
|
||||||
@@ -591,8 +591,8 @@ class TestNewCommandArchival:
|
|||||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_close_mcp_drains_pending_archives(self, tmp_path: Path) -> None:
|
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
|
||||||
"""close_mcp waits for background archive tasks to complete."""
|
"""close_mcp waits for background tasks to complete."""
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
loop = self._make_loop(tmp_path)
|
loop = self._make_loop(tmp_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user