perf: background post-response memory consolidation for faster replies

This commit is contained in:
Xubin Ren
2026-03-16 09:01:11 +00:00
parent 6d63e22e86
commit 46b19b15e1
5 changed files with 23 additions and 522 deletions

View File

@@ -100,7 +100,7 @@ class AgentLoop:
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._pending_archives: list[asyncio.Task] = [] self._background_tasks: list[asyncio.Task] = []
self._processing_lock = asyncio.Lock() self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator( self.memory_consolidator = MemoryConsolidator(
workspace=workspace, workspace=workspace,
@@ -257,8 +257,6 @@ class AgentLoop:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
self._running = True self._running = True
await self._connect_mcp() await self._connect_mcp()
# Start background consolidation task
await self.memory_consolidator.start_background_task()
logger.info("Agent loop started") logger.info("Agent loop started")
while self._running: while self._running:
@@ -334,9 +332,9 @@ class AgentLoop:
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections.""" """Drain pending background archives, then close MCP connections."""
if self._pending_archives: if self._background_tasks:
await asyncio.gather(*self._pending_archives, return_exceptions=True) await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._pending_archives.clear() self._background_tasks.clear()
if self._mcp_stack: if self._mcp_stack:
try: try:
await self._mcp_stack.aclose() await self._mcp_stack.aclose()
@@ -344,11 +342,16 @@ class AgentLoop:
pass # MCP SDK cancel scope cleanup is noisy but harmless pass # MCP SDK cancel scope cleanup is noisy but harmless
self._mcp_stack = None self._mcp_stack = None
async def stop(self) -> None: def _schedule_background(self, coro) -> None:
"""Stop the agent loop and background tasks.""" """Schedule a coroutine as a tracked background task (drained on shutdown)."""
task = asyncio.create_task(coro)
self._background_tasks.append(task)
task.add_done_callback(self._background_tasks.remove)
def stop(self) -> None:
"""Stop the agent loop."""
self._running = False self._running = False
await self.memory_consolidator.stop_background_task() logger.info("Agent loop stopping")
logger.info("Agent loop stopped")
async def _process_message( async def _process_message(
self, self,
@@ -364,8 +367,7 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
self.memory_consolidator.record_activity(key) await self.memory_consolidator.maybe_consolidate_by_tokens(session)
await self.memory_consolidator.maybe_consolidate_by_tokens_async(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0) history = session.get_history(max_messages=0)
messages = self.context.build_messages( messages = self.context.build_messages(
@@ -375,6 +377,7 @@ class AgentLoop:
final_content, _, all_msgs = await self._run_agent_loop(messages) final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.") content=final_content or "Background task completed.")
@@ -383,7 +386,6 @@ class AgentLoop:
key = session_key or msg.session_key key = session_key or msg.session_key
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
self.memory_consolidator.record_activity(key)
# Slash commands # Slash commands
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
@@ -394,11 +396,7 @@ class AgentLoop:
self.sessions.invalidate(session.key) self.sessions.invalidate(session.key)
if snapshot: if snapshot:
task = asyncio.create_task( self._schedule_background(self.memory_consolidator.archive_messages(snapshot))
self.memory_consolidator.archive_messages(snapshot)
)
self._pending_archives.append(task)
task.add_done_callback(self._pending_archives.remove)
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started.") content="New session started.")
@@ -413,8 +411,7 @@ class AgentLoop:
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines), channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
) )
# Record activity and schedule background consolidation for non-slash commands await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self.memory_consolidator.record_activity(key)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
@@ -446,6 +443,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None return None

View File

@@ -220,14 +220,9 @@ class MemoryStore:
class MemoryConsolidator: class MemoryConsolidator:
"""Owns consolidation policy, locking, and session offset updates. """Owns consolidation policy, locking, and session offset updates."""
Consolidation runs asynchronously in the background when sessions are idle,
so it doesn't block user interactions.
"""
_MAX_CONSOLIDATION_ROUNDS = 5 _MAX_CONSOLIDATION_ROUNDS = 5
_IDLE_CHECK_INTERVAL = 30 # seconds between idle checks
def __init__( def __init__(
self, self,
@@ -247,57 +242,11 @@ class MemoryConsolidator:
self._build_messages = build_messages self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._background_task: asyncio.Task[None] | None = None
self._stop_event = asyncio.Event()
self._session_last_activity: dict[str, float] = {} # session_key -> last activity timestamp
def get_lock(self, session_key: str) -> asyncio.Lock: def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session.""" """Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock()) return self._locks.setdefault(session_key, asyncio.Lock())
def record_activity(self, session_key: str) -> None:
"""Record that a session is active (for idle detection)."""
self._session_last_activity[session_key] = asyncio.get_event_loop().time()
async def start_background_task(self) -> None:
"""Start the background task that checks for idle sessions and consolidates."""
if self._background_task is not None and not self._background_task.done():
return # Already running
self._stop_event.clear()
self._background_task = asyncio.create_task(self._idle_consolidation_loop())
async def stop_background_task(self) -> None:
"""Stop the background task."""
self._stop_event.set()
if self._background_task is not None and not self._background_task.done():
self._background_task.cancel()
try:
await self._background_task
except asyncio.CancelledError:
pass
self._background_task = None
async def _idle_consolidation_loop(self) -> None:
"""Background loop that checks for idle sessions and triggers consolidation."""
while not self._stop_event.is_set():
try:
await asyncio.sleep(self._IDLE_CHECK_INTERVAL)
if self._stop_event.is_set():
break
# Check all sessions for idleness
current_time = asyncio.get_event_loop().time()
for session in list(self.sessions.all()):
last_active = self._session_last_activity.get(session.key, 0)
if current_time - last_active > self._IDLE_CHECK_INTERVAL * 2:
# Session is idle, trigger consolidation
await self.maybe_consolidate_by_tokens_async(session)
except asyncio.CancelledError:
break
except Exception:
logger.exception("Error in background consolidation loop")
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
"""Archive a selected message chunk into persistent memory.""" """Archive a selected message chunk into persistent memory."""
return await self.store.consolidate(messages, self.provider, self.model) return await self.store.consolidate(messages, self.provider, self.model)
@@ -350,26 +299,8 @@ class MemoryConsolidator:
return True return True
return True return True
def maybe_consolidate_by_tokens(self, session: Session) -> None: async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Schedule token-based consolidation to run asynchronously in background. """Loop: archive old messages until prompt fits within half the context window."""
This method is synchronous and just schedules the consolidation task.
The actual consolidation runs in the background when the session is idle.
"""
if not session.messages or self.context_window_tokens <= 0:
return
# Schedule for background execution
asyncio.create_task(self._schedule_consolidation(session))
async def _schedule_consolidation(self, session: Session) -> None:
"""Internal method to run consolidation asynchronously."""
await self.maybe_consolidate_by_tokens_async(session)
async def maybe_consolidate_by_tokens_async(self, session: Session) -> None:
"""Async version: Loop and archive old messages until prompt fits within half the context window.
This is called from the background task when a session is idle.
"""
if not session.messages or self.context_window_tokens <= 0: if not session.messages or self.context_window_tokens <= 0:
return return
@@ -424,11 +355,3 @@ class MemoryConsolidator:
estimated, source = self.estimate_session_prompt_tokens(session) estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0: if estimated <= 0:
return return
logger.debug(
"Token consolidation complete for {}: {}/{} via {}",
session.key,
estimated,
self.context_window_tokens,
source,
)

View File

@@ -1,9 +0,0 @@
"""Pytest configuration for nanobot tests."""
import pytest
@pytest.fixture(autouse=True)
def enable_asyncio_auto_mode():
    """Autouse placeholder fixture; its body is empty, so it performs no setup.

    NOTE(review): the name suggests this switches pytest-asyncio to auto mode,
    but nothing in this function does so -- confirm the mode is configured
    elsewhere (e.g. in the pytest configuration file).
    """

View File

@@ -1,411 +0,0 @@
"""Test async memory consolidation background task.
Tests for the new async background consolidation feature where token-based
consolidation runs when sessions are idle instead of blocking user interactions.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryConsolidator
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
class TestMemoryConsolidatorBackgroundTask:
    """Tests for the background consolidation task."""

    @staticmethod
    def _make_consolidator(tmp_path, sessions):
        """Build a MemoryConsolidator backed by mocks.

        Args:
            tmp_path: Workspace directory handed to the consolidator.
            sessions: Session-manager stand-in; the idle loop calls ``sessions.all()``.

        Returns:
            A consolidator with a 200-token context window and a mocked provider.
        """
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        provider.chat_with_retry = AsyncMock(
            return_value=LLMResponse(content="ok", tool_calls=[])
        )
        return MemoryConsolidator(
            workspace=tmp_path,
            provider=provider,
            model="test-model",
            sessions=sessions,
            context_window_tokens=200,
            build_messages=lambda **kw: [],
            get_tool_definitions=lambda: [],
        )

    @pytest.mark.asyncio
    async def test_start_and_stop_background_task(self, tmp_path) -> None:
        """Test that background task can be started and stopped cleanly."""
        sessions = MagicMock()
        sessions.all = MagicMock(return_value=[])
        consolidator = self._make_consolidator(tmp_path, sessions)

        # Start background task
        await consolidator.start_background_task()
        assert consolidator._background_task is not None
        assert not consolidator._stop_event.is_set()

        # Stop background task
        await consolidator.stop_background_task()
        assert consolidator._background_task is None or consolidator._background_task.done()

    @pytest.mark.asyncio
    async def test_background_loop_checks_idle_sessions(self, tmp_path) -> None:
        """Test that the background loop consolidates a session with no recorded activity."""
        session1 = MagicMock()
        session1.key = "cli:session1"
        session1.messages = [{"role": "user", "content": "msg"}]
        session2 = MagicMock()
        session2.key = "cli:session2"
        session2.messages = []
        sessions = MagicMock()
        sessions.all = MagicMock(return_value=[session1, session2])
        consolidator = self._make_consolidator(tmp_path, sessions)

        # Mark session1 as recently active.
        consolidator._session_last_activity["cli:session1"] = (
            asyncio.get_running_loop().time()
        )
        # session2 has no activity record, so the loop treats its last activity
        # as 0 and therefore as idle (assumes the loop clock is already past
        # the idle threshold, which holds for a monotonic uptime clock).

        # Mock maybe_consolidate_by_tokens_async to track calls
        consolidator.maybe_consolidate_by_tokens_async = AsyncMock()  # type: ignore[method-assign]

        # Run the background loop with a very short interval for testing
        with patch.object(consolidator, "_IDLE_CHECK_INTERVAL", 0.1):
            await consolidator.start_background_task()
            await asyncio.sleep(0.5)
            await consolidator.stop_background_task()

        # The previous check (`await_count >= 0`) could never fail. session2 is
        # idle from the first tick, so it must have been handed to consolidation.
        # No claim is made about session1: its timestamp also ages past the
        # 2 * interval threshold during the 0.5 s run.
        consolidator.maybe_consolidate_by_tokens_async.assert_any_await(session2)

    @pytest.mark.asyncio
    async def test_record_activity_updates_timestamp(self, tmp_path) -> None:
        """Test that record_activity stores a timestamp and refreshes it on later calls."""
        sessions = MagicMock()
        sessions.all = MagicMock(return_value=[])
        consolidator = self._make_consolidator(tmp_path, sessions)

        # Initially no activity recorded
        assert "cli:test" not in consolidator._session_last_activity

        # Record activity
        consolidator.record_activity("cli:test")
        assert "cli:test" in consolidator._session_last_activity
        first_seen = consolidator._session_last_activity["cli:test"]

        # Wait a bit and record again; the stored loop time must move forward.
        await asyncio.sleep(0.1)
        consolidator.record_activity("cli:test")
        assert consolidator._session_last_activity["cli:test"] > first_seen

    @pytest.mark.asyncio
    async def test_maybe_consolidate_by_tokens_schedules_async_task(self, tmp_path) -> None:
        """Test that maybe_consolidate_by_tokens schedules the async consolidation."""
        session = MagicMock()
        session.messages = [{"role": "user", "content": "msg"}]
        session.key = "cli:test"
        session.context_window_tokens = 200
        sessions = MagicMock()
        sessions.all = MagicMock(return_value=[session])
        sessions.save = MagicMock()
        consolidator = self._make_consolidator(tmp_path, sessions)

        # Mock the async version to track calls
        consolidator.maybe_consolidate_by_tokens_async = AsyncMock()  # type: ignore[method-assign]

        # Call the synchronous method - should schedule a task
        consolidator.maybe_consolidate_by_tokens(session)

        # Give the scheduled task a chance to run, then verify it reached the
        # async implementation (the test previously asserted nothing).
        await asyncio.sleep(0.1)
        consolidator.maybe_consolidate_by_tokens_async.assert_awaited_once_with(session)
class TestAgentLoopIntegration:
    """Integration tests for AgentLoop with background consolidation."""

    @staticmethod
    def _make_loop(tmp_path):
        """Build an AgentLoop on a fresh MessageBus with a mocked provider and no tools."""
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=MessageBus(),
            provider=provider,
            workspace=tmp_path,
            model="test-model",
            context_window_tokens=200,
        )
        loop.tools.get_definitions = MagicMock(return_value=[])
        return loop

    @pytest.mark.asyncio
    async def test_loop_starts_background_task(self, tmp_path) -> None:
        """Test that run() starts the background consolidation task."""
        loop = self._make_loop(tmp_path)

        # Start the loop in background
        run_task = asyncio.create_task(loop.run())

        # Give it time to start the background task
        await asyncio.sleep(0.3)

        # Background task should be started
        assert loop.memory_consolidator._background_task is not None

        # Stop the loop
        await loop.stop()
        await run_task

    @pytest.mark.asyncio
    async def test_loop_stops_background_task(self, tmp_path) -> None:
        """Test that stop() stops the background consolidation task."""
        loop = self._make_loop(tmp_path)

        # Start the loop in background
        run_task = asyncio.create_task(loop.run())
        await asyncio.sleep(0.3)

        # Stop via async stop method, then let run() finish so the task is not
        # left pending when the test's event loop is torn down.
        await loop.stop()
        await run_task

        # Background task should be stopped
        background = loop.memory_consolidator._background_task
        assert background is None or background.done()
class TestIdleDetection:
    """Tests for idle session detection logic."""

    @staticmethod
    def _consolidator_for(tmp_path, session):
        """Return a mocked-out MemoryConsolidator whose session list holds only *session*."""
        llm = MagicMock()
        llm.get_default_model.return_value = "test-model"
        llm.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        registry = MagicMock()
        registry.all = MagicMock(return_value=[session])
        return MemoryConsolidator(
            workspace=tmp_path,
            provider=llm,
            model="test-model",
            sessions=registry,
            context_window_tokens=200,
            build_messages=lambda **kw: [],
            get_tool_definitions=lambda: [],
        )

    @pytest.mark.asyncio
    async def test_recently_active_session_not_considered_idle(self, tmp_path) -> None:
        """A session touched just now must not be consolidated on the first tick."""
        active = MagicMock()
        active.key = "cli:active"
        active.messages = [{"role": "user", "content": "msg"}]
        consolidator = self._consolidator_for(tmp_path, active)

        # Stamp the session as active right now (inside the idle threshold).
        now = asyncio.get_event_loop().time()
        consolidator._session_last_activity["cli:active"] = now

        # Replace the real consolidation with a spy.
        spy = AsyncMock()
        consolidator.maybe_consolidate_by_tokens_async = spy  # type: ignore[method-assign]

        with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
            await consolidator.start_background_task()
            # Stay below 2 * interval so the session never crosses the threshold.
            await asyncio.sleep(0.15)
            await consolidator.stop_background_task()

        # The spy must be untouched for a recently active session.
        assert consolidator.maybe_consolidate_by_tokens_async.await_count == 0

    @pytest.mark.asyncio
    async def test_idle_session_triggers_consolidation(self, tmp_path) -> None:
        """A session whose last activity is far in the past gets consolidated."""
        stale = MagicMock()
        stale.key = "cli:idle"
        stale.messages = [{"role": "user", "content": "msg"}]
        consolidator = self._consolidator_for(tmp_path, stale)

        # Backdate the activity stamp by 10 seconds, well past the threshold.
        now = asyncio.get_event_loop().time()
        consolidator._session_last_activity["cli:idle"] = now - 10

        # Replace the real consolidation with a spy.
        spy = AsyncMock()
        consolidator.maybe_consolidate_by_tokens_async = spy  # type: ignore[method-assign]

        with patch.object(consolidator, '_IDLE_CHECK_INTERVAL', 0.1):
            await consolidator.start_background_task()
            await asyncio.sleep(0.5)
            await consolidator.stop_background_task()

        # At least one tick must have picked up the idle session.
        assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
class TestScheduleConsolidation:
    """Tests for the schedule consolidation mechanism."""

    @pytest.mark.asyncio
    async def test_schedule_consolidation_runs_async_version(self, tmp_path) -> None:
        """_schedule_consolidation must delegate to maybe_consolidate_by_tokens_async."""
        # Session stand-in with a single message.
        scheduled = MagicMock()
        scheduled.messages = [{"role": "user", "content": "msg"}]
        scheduled.key = "cli:scheduled"

        registry = MagicMock()
        registry.all = MagicMock(return_value=[scheduled])

        llm = MagicMock()
        llm.get_default_model.return_value = "test-model"
        llm.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))

        consolidator = MemoryConsolidator(
            workspace=tmp_path,
            provider=llm,
            model="test-model",
            sessions=registry,
            context_window_tokens=200,
            build_messages=lambda **kw: [],
            get_tool_definitions=lambda: [],
        )

        # Swap the async implementation for a spy so the delegation is observable.
        spy = AsyncMock()
        consolidator.maybe_consolidate_by_tokens_async = spy  # type: ignore[method-assign]

        # Drive the scheduling coroutine directly, then yield briefly.
        await consolidator._schedule_consolidation(scheduled)
        await asyncio.sleep(0.1)

        assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
class TestBackgroundTaskCancellation:
    """Tests for background task cancellation and error handling."""

    @staticmethod
    def _make_consolidator(tmp_path, session_list):
        """Build a MemoryConsolidator backed by mocks.

        Args:
            tmp_path: Workspace directory handed to the consolidator.
            session_list: Sessions that ``sessions.all()`` will return to the idle loop.

        Returns:
            A consolidator with a 200-token context window and a mocked provider.
        """
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        provider.chat_with_retry = AsyncMock(
            return_value=LLMResponse(content="ok", tool_calls=[])
        )
        sessions = MagicMock()
        sessions.all = MagicMock(return_value=session_list)
        return MemoryConsolidator(
            workspace=tmp_path,
            provider=provider,
            model="test-model",
            sessions=sessions,
            context_window_tokens=200,
            build_messages=lambda **kw: [],
            get_tool_definitions=lambda: [],
        )

    @pytest.mark.asyncio
    async def test_background_task_handles_exceptions_gracefully(self, tmp_path) -> None:
        """Test that exceptions in the loop don't crash it."""
        # An idle session is required: with an empty session list the failing
        # mock below would never be invoked and the test would prove nothing.
        # No activity is recorded for it, so the loop treats it as idle.
        session = MagicMock()
        session.key = "cli:idle"
        session.messages = [{"role": "user", "content": "msg"}]
        consolidator = self._make_consolidator(tmp_path, [session])

        # Mock maybe_consolidate_by_tokens_async to raise an exception
        consolidator.maybe_consolidate_by_tokens_async = AsyncMock(  # type: ignore[method-assign]
            side_effect=Exception("Test exception")
        )

        with patch.object(consolidator, "_IDLE_CHECK_INTERVAL", 0.1):
            await consolidator.start_background_task()
            try:
                await asyncio.sleep(0.5)
                # The failing call really happened ...
                assert consolidator.maybe_consolidate_by_tokens_async.await_count >= 1
                # ... and the task survived it. `is not None` alone would also
                # pass for a task that had already crashed.
                assert consolidator._background_task is not None
                assert not consolidator._background_task.done()
            finally:
                await consolidator.stop_background_task()

    @pytest.mark.asyncio
    async def test_stop_cancels_running_task(self, tmp_path) -> None:
        """Test that stop properly cancels a running task."""
        consolidator = self._make_consolidator(tmp_path, [])

        # Start a task that will sleep for a while
        with patch.object(consolidator, "_IDLE_CHECK_INTERVAL", 10):  # Long interval
            await consolidator.start_background_task()
            task = consolidator._background_task

            # Task should exist
            assert task is not None

            # Yield once so the task actually starts and parks in its long
            # sleep; otherwise stop would cancel a task that never ran.
            await asyncio.sleep(0)
            assert not task.done()

            # Stop should cancel it
            await consolidator.stop_background_task()

        # Verify the task finished and the consolidator no longer tracks a live task
        assert task.done()
        assert consolidator._background_task is None or \
            consolidator._background_task.done()

View File

@@ -591,8 +591,8 @@ class TestNewCommandArchival:
assert loop.sessions.get_or_create("cli:test").messages == [] assert loop.sessions.get_or_create("cli:test").messages == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_close_mcp_drains_pending_archives(self, tmp_path: Path) -> None: async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
"""close_mcp waits for background archive tasks to complete.""" """close_mcp waits for background tasks to complete."""
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path) loop = self._make_loop(tmp_path)