refactor(memory): switch consolidation to token-based context windows
Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
This commit is contained in:
@@ -267,6 +267,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
|
||||
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
|
||||
|
||||
|
||||
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
|
||||
mock_agent_runtime["config"].agents.defaults.memory_window = 100
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
@@ -327,6 +337,29 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
assert seen["workspace"] == override
|
||||
assert config.workspace_path == override
|
||||
|
||||
|
||||
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.memory_window = 100
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
|
||||
88
tests/test_config_migration.py
Normal file
88
tests/test_config_migration.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 1234,
|
||||
"memoryWindow": 42,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.agents.defaults.max_tokens == 1234
|
||||
assert config.agents.defaults.context_window_tokens == 65_536
|
||||
assert config.agents.defaults.should_warn_deprecated_memory_window is True
|
||||
|
||||
|
||||
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 2222,
|
||||
"memoryWindow": 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
save_config(config, config_path)
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
|
||||
assert defaults["maxTokens"] == 2222
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
|
||||
|
||||
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
workspace = tmp_path / "workspace"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"maxTokens": 3333,
|
||||
"memoryWindow": 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
defaults = saved["agents"]["defaults"]
|
||||
assert defaults["maxTokens"] == 3333
|
||||
assert defaults["contextWindowTokens"] == 65_536
|
||||
assert "memoryWindow" not in defaults
|
||||
@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
|
||||
assert_messages_content(old_messages, 10, 34)
|
||||
|
||||
|
||||
class TestConsolidationDeduplicationGuard:
|
||||
"""Test that consolidation tasks are deduplicated and serialized."""
|
||||
class TestNewCommandArchival:
|
||||
"""Test /new archival behavior with the simplified consolidation flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
|
||||
"""Concurrent messages above memory_window spawn only one consolidation task."""
|
||||
@staticmethod
|
||||
def _make_loop(tmp_path: Path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=1,
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls
|
||||
consolidation_calls += 1
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await loop._process_message(msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 1, (
|
||||
f"Expected exactly 1 consolidation, got {consolidation_calls}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_command_guard_prevents_concurrent_consolidation(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new command does not run consolidation concurrently with in-flight consolidation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
consolidation_calls = 0
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
nonlocal consolidation_calls, active, max_active
|
||||
consolidation_calls += 1
|
||||
active += 1
|
||||
max_active = max(max_active, active)
|
||||
await asyncio.sleep(0.05)
|
||||
active -= 1
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
await loop._process_message(new_msg)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert consolidation_calls == 2, (
|
||||
f"Expected normal + /new consolidations, got {consolidation_calls}"
|
||||
)
|
||||
assert max_active == 1, (
|
||||
f"Expected serialized consolidation, observed concurrency={max_active}"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
|
||||
"""create_task results are tracked in _consolidation_tasks while in flight."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||
started.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
|
||||
await started.wait()
|
||||
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
assert len(loop._consolidation_tasks) == 0, (
|
||||
"Task reference must be removed after completion"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new waits for in-flight consolidation and archives before clear."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = 0
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
started.set()
|
||||
await release.wait()
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
|
||||
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert session_after.messages == [], "Session should be cleared after successful archival"
|
||||
return loop
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
|
||||
"""/new must keep session data if archive step reports failure."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
if archive_all:
|
||||
return False
|
||||
return True
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
return False
|
||||
|
||||
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "failed" in response.content.lower()
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
assert len(session_after.messages) == before_count, (
|
||||
"Session must remain intact when /new archival fails"
|
||||
)
|
||||
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""/new should archive only messages not yet consolidated by prior task."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(15):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
session.last_consolidated = len(session.messages) - 3
|
||||
loop.sessions.save(session)
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
return True
|
||||
|
||||
started.set()
|
||||
await release.wait()
|
||||
sess.last_consolidated = len(sess.messages) - 3
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
|
||||
await loop._process_message(msg)
|
||||
await started.wait()
|
||||
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
pending_new = asyncio.create_task(loop._process_message(new_msg))
|
||||
await asyncio.sleep(0.02)
|
||||
assert not pending_new.done()
|
||||
|
||||
release.set()
|
||||
response = await pending_new
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert archived_count == 3, (
|
||||
f"Expected only unconsolidated tail to archive, got {archived_count}"
|
||||
)
|
||||
assert archived_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
|
||||
"""/new clears session and returns confirmation."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||
)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg)
|
||||
|
||||
190
tests/test_loop_consolidation_tokens.py
Normal file
190
tests/test_loop_consolidation_tokens.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
|
||||
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
|
||||
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
|
||||
assert session.last_consolidated == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
|
||||
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (300, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
|
||||
"""Once triggered, consolidation should continue until it drops below half threshold."""
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
|
||||
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
|
||||
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
|
||||
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
return (500, "test")
|
||||
if call_count[0] == 2:
|
||||
return (150, "test")
|
||||
return (80, "test")
|
||||
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.memory_consolidator.consolidate_messages.await_count == 2
|
||||
assert session.last_consolidated == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
|
||||
"""Verify preflight consolidation runs before the LLM call in process_direct."""
|
||||
order: list[str] = []
|
||||
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
|
||||
async def track_consolidate(messages):
|
||||
order.append("consolidate")
|
||||
return True
|
||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||
|
||||
async def track_llm(*args, **kwargs):
|
||||
order.append("llm")
|
||||
return LLMResponse(content="ok", tool_calls=[])
|
||||
loop.provider.chat_with_retry = track_llm
|
||||
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = [
|
||||
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
|
||||
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
|
||||
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
|
||||
]
|
||||
loop.sessions.save(session)
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
|
||||
|
||||
call_count = [0]
|
||||
def mock_estimate(_session):
|
||||
call_count[0] += 1
|
||||
return (1000 if call_count[0] <= 1 else 80, "test")
|
||||
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
|
||||
await loop.process_direct("hello", session_key="cli:test")
|
||||
|
||||
assert "consolidate" in order
|
||||
assert "llm" in order
|
||||
assert order.index("consolidate") < order.index("llm")
|
||||
@@ -7,7 +7,7 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -15,15 +15,12 @@ from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_session(message_count: int = 30, memory_window: int = 50):
|
||||
"""Create a mock session with messages."""
|
||||
session = MagicMock()
|
||||
session.messages = [
|
||||
def _make_messages(message_count: int = 30):
|
||||
"""Create a list of mock messages."""
|
||||
return [
|
||||
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
|
||||
for i in range(message_count)
|
||||
]
|
||||
session.last_consolidated = 0
|
||||
return session
|
||||
|
||||
|
||||
def _make_tool_response(history_entry, memory_update):
|
||||
@@ -74,9 +71,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -95,9 +92,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert store.history_file.exists()
|
||||
@@ -131,9 +128,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -147,22 +144,22 @@ class TestMemoryConsolidationTypeHandling:
|
||||
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
|
||||
)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when messages < keep_count."""
|
||||
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
|
||||
"""Consolidation should be a no-op when the selected chunk is empty."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=10)
|
||||
messages: list[dict] = []
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
@@ -189,9 +186,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
@@ -215,9 +212,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -239,9 +236,9 @@ class TestMemoryConsolidationTypeHandling:
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
provider.chat_with_retry = provider.chat
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -255,7 +252,7 @@ class TestMemoryConsolidationTypeHandling:
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
),
|
||||
])
|
||||
session = _make_session(message_count=60)
|
||||
messages = _make_messages(message_count=60)
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
@@ -263,7 +260,7 @@ class TestMemoryConsolidationTypeHandling:
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
result = await store.consolidate(messages, provider, "test-model")
|
||||
|
||||
assert result is True
|
||||
assert provider.calls == 2
|
||||
|
||||
@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user