perf: background post-response memory consolidation for faster replies
This commit is contained in:
@@ -100,7 +100,7 @@ class AgentLoop:
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._pending_archives: list[asyncio.Task] = []
|
||||
self._background_tasks: list[asyncio.Task] = []
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self.memory_consolidator = MemoryConsolidator(
|
||||
workspace=workspace,
|
||||
@@ -257,8 +257,6 @@ class AgentLoop:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
# Start background consolidation task
|
||||
await self.memory_consolidator.start_background_task()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
@@ -334,9 +332,9 @@ class AgentLoop:
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
if self._pending_archives:
|
||||
await asyncio.gather(*self._pending_archives, return_exceptions=True)
|
||||
self._pending_archives.clear()
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
@@ -344,11 +342,16 @@ class AgentLoop:
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the agent loop and background tasks."""
|
||||
def _schedule_background(self, coro) -> None:
|
||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.append(task)
|
||||
task.add_done_callback(self._background_tasks.remove)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
await self.memory_consolidator.stop_background_task()
|
||||
logger.info("Agent loop stopped")
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
@@ -364,8 +367,7 @@ class AgentLoop:
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
self.memory_consolidator.record_activity(key)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens_async(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
messages = self.context.build_messages(
|
||||
@@ -375,6 +377,7 @@ class AgentLoop:
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
@@ -383,7 +386,6 @@ class AgentLoop:
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
self.memory_consolidator.record_activity(key)
|
||||
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
@@ -394,11 +396,7 @@ class AgentLoop:
|
||||
self.sessions.invalidate(session.key)
|
||||
|
||||
if snapshot:
|
||||
task = asyncio.create_task(
|
||||
self.memory_consolidator.archive_messages(snapshot)
|
||||
)
|
||||
self._pending_archives.append(task)
|
||||
task.add_done_callback(self._pending_archives.remove)
|
||||
self._schedule_background(self.memory_consolidator.archive_messages(snapshot))
|
||||
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
@@ -413,8 +411,7 @@ class AgentLoop:
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||
)
|
||||
# Record activity and schedule background consolidation for non-slash commands
|
||||
self.memory_consolidator.record_activity(key)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
@@ -446,6 +443,7 @@ class AgentLoop:
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
@@ -220,14 +220,9 @@ class MemoryStore:
|
||||
|
||||
|
||||
class MemoryConsolidator:
|
||||
"""Owns consolidation policy, locking, and session offset updates.
|
||||
|
||||
Consolidation runs asynchronously in the background when sessions are idle,
|
||||
so it doesn't block user interactions.
|
||||
"""
|
||||
"""Owns consolidation policy, locking, and session offset updates."""
|
||||
|
||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||
_IDLE_CHECK_INTERVAL = 30 # seconds between idle checks
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -247,57 +242,11 @@ class MemoryConsolidator:
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
self._background_task: asyncio.Task[None] | None = None
|
||||
self._stop_event = asyncio.Event()
|
||||
self._session_last_activity: dict[str, float] = {} # session_key -> last activity timestamp
|
||||
|
||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||
"""Return the shared consolidation lock for one session."""
|
||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
def record_activity(self, session_key: str) -> None:
|
||||
"""Record that a session is active (for idle detection)."""
|
||||
self._session_last_activity[session_key] = asyncio.get_event_loop().time()
|
||||
|
||||
async def start_background_task(self) -> None:
|
||||
"""Start the background task that checks for idle sessions and consolidates."""
|
||||
if self._background_task is not None and not self._background_task.done():
|
||||
return # Already running
|
||||
self._stop_event.clear()
|
||||
self._background_task = asyncio.create_task(self._idle_consolidation_loop())
|
||||
|
||||
async def stop_background_task(self) -> None:
|
||||
"""Stop the background task."""
|
||||
self._stop_event.set()
|
||||
if self._background_task is not None and not self._background_task.done():
|
||||
self._background_task.cancel()
|
||||
try:
|
||||
await self._background_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._background_task = None
|
||||
|
||||
async def _idle_consolidation_loop(self) -> None:
|
||||
"""Background loop that checks for idle sessions and triggers consolidation."""
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
await asyncio.sleep(self._IDLE_CHECK_INTERVAL)
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
|
||||
# Check all sessions for idleness
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
for session in list(self.sessions.all()):
|
||||
last_active = self._session_last_activity.get(session.key, 0)
|
||||
if current_time - last_active > self._IDLE_CHECK_INTERVAL * 2:
|
||||
# Session is idle, trigger consolidation
|
||||
await self.maybe_consolidate_by_tokens_async(session)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Error in background consolidation loop")
|
||||
|
||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||
"""Archive a selected message chunk into persistent memory."""
|
||||
return await self.store.consolidate(messages, self.provider, self.model)
|
||||
@@ -350,26 +299,8 @@ class MemoryConsolidator:
|
||||
return True
|
||||
return True
|
||||
|
||||
def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Schedule token-based consolidation to run asynchronously in background.
|
||||
|
||||
This method is synchronous and just schedules the consolidation task.
|
||||
The actual consolidation runs in the background when the session is idle.
|
||||
"""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
# Schedule for background execution
|
||||
asyncio.create_task(self._schedule_consolidation(session))
|
||||
|
||||
async def _schedule_consolidation(self, session: Session) -> None:
|
||||
"""Internal method to run consolidation asynchronously."""
|
||||
await self.maybe_consolidate_by_tokens_async(session)
|
||||
|
||||
async def maybe_consolidate_by_tokens_async(self, session: Session) -> None:
|
||||
"""Async version: Loop and archive old messages until prompt fits within half the context window.
|
||||
|
||||
This is called from the background task when a session is idle.
|
||||
"""
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
|
||||
@@ -424,11 +355,3 @@ class MemoryConsolidator:
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
if estimated <= 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
"Token consolidation complete for {}: {}/{} via {}",
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
"""Pytest configuration for nanobot tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_asyncio_auto_mode():
|
||||
"""Auto-configure asyncio mode for all async tests."""
|
||||
pass
|
||||
@@ -1,411 +0,0 @@
|
||||
"""Test async memory consolidation background task.
|
||||
|
||||
Tests for the new async background consolidation feature where token-based
|
||||
consolidation runs when sessions are idle instead of blocking user interactions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
class TestMemoryConsolidatorBackgroundTask:
|
||||
"""Tests for the background consolidation task."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_and_stop_background_task(self, tmp_path) -> None:
|
||||
"""Test that background task can be started and stopped cleanly."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Start background task
|
||||
await consolidator.start_background_task()
|
||||
assert consolidator._background_task is not None
|
||||
assert not consolidator._stop_event.is_set()
|
||||
|
||||
# Stop background task
|
||||
await consolidator.stop_background_task()
|
||||
assert consolidator._background_task is None or consolidator._background_task.done()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_loop_checks_idle_sessions(self, tmp_path) -> None:
|
||||
"""Test that background loop checks for idle sessions."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
session1 = MagicMock()
|
||||
session1.key = "cli:session1"
|
||||
session1.messages = [{"role": "user", "content": "msg"}]
|
||||
session2 = MagicMock()
|
||||
session2.key = "cli:session2"
|
||||
session2.messages = []
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[session1, session2])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mark session1 as recently active (should not consolidate)
|
||||
consolidator._session_last_activity["cli:session1"] = asyncio.get_event_loop().time()
|
||||
# Leave session2 without activity record (should be considered idle)
|
||||
|
||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
# Run the background loop with a very short interval for testing
|
||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
||||
# Start task and let it run briefly
|
||||
await consolidator.start_background_task()
|
||||
await asyncio.sleep(0.5)
|
||||
await consolidator.stop_background_task()
|
||||
|
||||
# session2 should have been checked for consolidation (it's idle)
|
||||
# session1 should not have been consolidated (recently active)
|
||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_activity_updates_timestamp(self, tmp_path) -> None:
|
||||
"""Test that record_activity updates the activity timestamp."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Initially no activity recorded
|
||||
assert "cli:test" not in consolidator._session_last_activity
|
||||
|
||||
# Record activity
|
||||
consolidator.record_activity("cli:test")
|
||||
assert "cli:test" in consolidator._session_last_activity
|
||||
|
||||
# Wait a bit and check timestamp changed
|
||||
await asyncio.sleep(0.1)
|
||||
consolidator.record_activity("cli:test")
|
||||
# The timestamp should have updated (though we can't easily verify the exact value)
|
||||
assert consolidator._session_last_activity["cli:test"] > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_consolidate_by_tokens_schedules_async_task(self, tmp_path) -> None:
|
||||
"""Test that maybe_consolidate_by_tokens schedules an async task."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
session = MagicMock()
|
||||
session.messages = [{"role": "user", "content": "msg"}]
|
||||
session.key = "cli:test"
|
||||
session.context_window_tokens = 200
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[session])
|
||||
sessions.save = MagicMock()
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mock the async version to track calls
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
# Call the synchronous method - should schedule a task
|
||||
consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
# The async version should have been scheduled via create_task
|
||||
await asyncio.sleep(0.1) # Let the task start
|
||||
|
||||
|
||||
class TestAgentLoopIntegration:
|
||||
"""Integration tests for AgentLoop with background consolidation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_starts_background_task(self, tmp_path) -> None:
|
||||
"""Test that run() starts the background consolidation task."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=200,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
# Start the loop in background
|
||||
import asyncio
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
|
||||
# Give it time to start the background task
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Background task should be started
|
||||
assert loop.memory_consolidator._background_task is not None
|
||||
|
||||
# Stop the loop
|
||||
await loop.stop()
|
||||
await run_task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_stops_background_task(self, tmp_path) -> None:
|
||||
"""Test that stop() stops the background consolidation task."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=200,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
# Start the loop in background
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Stop via async stop method
|
||||
await loop.stop()
|
||||
|
||||
# Background task should be stopped
|
||||
assert loop.memory_consolidator._background_task is None or \
|
||||
loop.memory_consolidator._background_task.done()
|
||||
|
||||
|
||||
class TestIdleDetection:
|
||||
"""Tests for idle session detection logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recently_active_session_not_considered_idle(self, tmp_path) -> None:
|
||||
"""Test that recently active sessions are not consolidated."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
session = MagicMock()
|
||||
session.key = "cli:active"
|
||||
session.messages = [{"role": "user", "content": "msg"}]
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[session])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mark as recently active (within idle threshold)
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
consolidator._session_last_activity["cli:active"] = current_time
|
||||
|
||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
||||
await consolidator.start_background_task()
|
||||
# Sleep less than 2 * interval to ensure session remains active
|
||||
await asyncio.sleep(0.15)
|
||||
await consolidator.stop_background_task()
|
||||
|
||||
# Should not have been called for recently active session
|
||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idle_session_triggers_consolidation(self, tmp_path) -> None:
|
||||
"""Test that idle sessions trigger consolidation."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
session = MagicMock()
|
||||
session.key = "cli:idle"
|
||||
session.messages = [{"role": "user", "content": "msg"}]
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[session])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mark as inactive (older than idle threshold)
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
consolidator._session_last_activity["cli:idle"] = current_time - 10 # 10 seconds ago
|
||||
|
||||
# Mock maybe_consolidate_by_tokens_async to track calls
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
||||
await consolidator.start_background_task()
|
||||
await asyncio.sleep(0.5)
|
||||
await consolidator.stop_background_task()
|
||||
|
||||
# Should have been called for idle session
|
||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
|
||||
|
||||
|
||||
class TestScheduleConsolidation:
|
||||
"""Tests for the schedule consolidation mechanism."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schedule_consolidation_runs_async_version(self, tmp_path) -> None:
|
||||
"""Test that scheduling runs the async version."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
session = MagicMock()
|
||||
session.messages = [{"role": "user", "content": "msg"}]
|
||||
session.key = "cli:scheduled"
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[session])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mock the async version to track calls
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
# Schedule consolidation
|
||||
await consolidator._schedule_consolidation(session)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
|
||||
|
||||
|
||||
class TestBackgroundTaskCancellation:
|
||||
"""Tests for background task cancellation and error handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_task_handles_exceptions_gracefully(self, tmp_path) -> None:
|
||||
"""Test that exceptions in the loop don't crash it."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Mock maybe_consolidate_by_tokens_async to raise an exception
|
||||
consolidator.maybe_consolidate_by_tokens_async = AsyncMock( # type: ignore[method-assign]
|
||||
side_effect=Exception("Test exception")
|
||||
)
|
||||
|
||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
|
||||
await consolidator.start_background_task()
|
||||
await asyncio.sleep(0.5)
|
||||
# Task should still be running despite exceptions
|
||||
assert consolidator._background_task is not None
|
||||
await consolidator.stop_background_task()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_running_task(self, tmp_path) -> None:
|
||||
"""Test that stop properly cancels a running task."""
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
sessions = MagicMock()
|
||||
sessions.all = MagicMock(return_value=[])
|
||||
|
||||
consolidator = MemoryConsolidator(
|
||||
workspace=tmp_path,
|
||||
provider=provider,
|
||||
model="test-model",
|
||||
sessions=sessions,
|
||||
context_window_tokens=200,
|
||||
build_messages=lambda **kw: [],
|
||||
get_tool_definitions=lambda: [],
|
||||
)
|
||||
|
||||
# Start a task that will sleep for a while
|
||||
with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 10): # Long interval
|
||||
await consolidator.start_background_task()
|
||||
# Task should be running
|
||||
assert consolidator._background_task is not None
|
||||
|
||||
# Stop should cancel it
|
||||
await consolidator.stop_background_task()
|
||||
|
||||
# Verify task was cancelled or completed
|
||||
assert consolidator._background_task is None or \
|
||||
consolidator._background_task.done()
|
||||
@@ -591,8 +591,8 @@ class TestNewCommandArchival:
|
||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_mcp_drains_pending_archives(self, tmp_path: Path) -> None:
|
||||
"""close_mcp waits for background archive tasks to complete."""
|
||||
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
|
||||
"""close_mcp waits for background tasks to complete."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user