"""Tests for token-budget-driven memory consolidation in AgentLoop.

Covers: threshold gating, archive-chunk boundaries, the consolidate-until-
under-target loop, preflight ordering relative to the LLM call, and the
slow-consolidation background-continuation path.
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

import nanobot.agent.memory as memory_module
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse


def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
    """Build an AgentLoop with a fully mocked provider.

    The provider's prompt-token estimate is pinned to *estimated_tokens* so
    each test controls where the prompt sits relative to the
    *context_window_tokens* consolidation threshold.
    """
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    # (token_count, counter_name) tuple — matches the provider contract used below.
    provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
    provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
    loop = AgentLoop(
        bus=MessageBus(),
        provider=provider,
        workspace=tmp_path,
        model="test-model",
        context_window_tokens=context_window_tokens,
    )
    # No real tools: keep the LLM request payload free of tool definitions.
    loop.tools.get_definitions = MagicMock(return_value=[])
    return loop


@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
    """A prompt estimate under the context window must not trigger consolidation."""
    loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
    loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    await loop.process_direct("hello", session_key="cli:test")
    loop.memory_consolidator.consolidate_messages.assert_not_awaited()


@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
    """A prompt estimate over the context window triggers at least one consolidation."""
    loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
    ]
    loop.sessions.save(session)
    # Make every stored message look large so the consolidator has work to do.
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
    await loop.process_direct("hello", session_key="cli:test")
    assert loop.memory_consolidator.consolidate_messages.await_count >= 1


@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
    """The archived chunk must end just before a user message, never mid-turn.

    With 120 tokens per message, the consolidator is expected to archive the
    first four messages (u1..a2) and stop at the "u3" user boundary.
    """
    loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
        {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
        {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
    ]
    loop.sessions.save(session)
    # Uniform per-message cost keyed by content, so the chunk cut point is
    # determined purely by the user-message boundary rule.
    token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
    # First positional arg of the (last) consolidate call is the archived chunk.
    archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
    assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
    assert session.last_consolidated == 4


@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
    """Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
        {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
        {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
        {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
        {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
    ]
    loop.sessions.save(session)
    call_count = [0]

    # Successive prompt estimates: 500 (over), 300 (still over), then 80
    # (under) — forces exactly two consolidation passes before stopping.
    def mock_estimate(_session):
        call_count[0] += 1
        if call_count[0] == 1:
            return (500, "test")
        if call_count[0] == 2:
            return (300, "test")
        return (80, "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
    assert loop.memory_consolidator.consolidate_messages.await_count == 2
    assert session.last_consolidated == 6


@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
    """Once triggered, consolidation should continue until it drops below half threshold."""
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
        {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
        {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
        {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
        {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
    ]
    loop.sessions.save(session)
    call_count = [0]

    # 500 (over the 200 trigger), then 150 (under the trigger but still at or
    # above the half-threshold target of 100), then 80 (under target) — the
    # middle value is what distinguishes this test from the one above.
    def mock_estimate(_session):
        call_count[0] += 1
        if call_count[0] == 1:
            return (500, "test")
        if call_count[0] == 2:
            return (150, "test")
        return (80, "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
    await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
    assert loop.memory_consolidator.consolidate_messages.await_count == 2
    assert session.last_consolidated == 6


@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
    """Verify preflight consolidation runs before the LLM call in process_direct."""
    order: list[str] = []
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)

    async def track_consolidate(messages):
        order.append("consolidate")
        return True

    loop.memory_consolidator.consolidate_messages = track_consolidate  # type: ignore[method-assign]

    async def track_llm(*args, **kwargs):
        order.append("llm")
        return LLMResponse(content="ok", tool_calls=[])

    loop.provider.chat_with_retry = track_llm
    session = loop.sessions.get_or_create("cli:test")
    session.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
    ]
    loop.sessions.save(session)
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
    call_count = [0]

    # First estimate is over threshold (forces the preflight pass); every
    # later estimate is under, so consolidation runs exactly once.
    def mock_estimate(_session):
        call_count[0] += 1
        return (1000 if call_count[0] <= 1 else 80, "test")

    loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate  # type: ignore[method-assign]
    await loop.process_direct("hello", session_key="cli:test")
    assert "consolidate" in order
    assert "llm" in order
    # The ordering assertion is the point of the test.
    assert order.index("consolidate") < order.index("llm")


@pytest.mark.asyncio
async def test_slow_preflight_consolidation_continues_in_background(tmp_path, monkeypatch) -> None:
    """A consolidation that blows the preflight time budget must not block the
    LLM call; it should keep running in the background and finish later."""
    order: list[str] = []
    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    # Tiny budget guarantees the event-gated consolidation below exceeds it.
    monkeypatch.setattr(loop, "_PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS", 0.01)
    release = asyncio.Event()

    async def slow_consolidation(_session):
        order.append("consolidate-start")
        # Parks here until the test releases it, simulating a slow pass.
        await release.wait()
        order.append("consolidate-end")

    async def track_llm(*args, **kwargs):
        order.append("llm")
        return LLMResponse(content="ok", tool_calls=[])

    loop.memory_consolidator.maybe_consolidate_by_tokens = slow_consolidation  # type: ignore[method-assign]
    loop.provider.chat_with_retry = track_llm
    await loop.process_direct("hello", session_key="cli:test")
    # Consolidation started and the LLM ran, but the slow pass is still pending.
    assert "consolidate-start" in order
    assert "llm" in order
    assert "consolidate-end" not in order
    release.set()
    # NOTE(review): close_mcp() appears to await outstanding background work,
    # which lets the released consolidation finish — confirm against AgentLoop.
    await loop.close_mcp()
    assert "consolidate-end" in order