fix(loop): lock /new snapshot and prune stale consolidation locks
This commit is contained in:
@@ -304,6 +304,14 @@ class AgentLoop:
|
|||||||
self._consolidation_locks[session_key] = lock
|
self._consolidation_locks[session_key] = lock
|
||||||
return lock
|
return lock
|
||||||
|
|
||||||
|
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
    """Drop an unused per-session lock entry to avoid unbounded growth.

    The entry for ``session_key`` is removed from ``self._consolidation_locks``
    only when all of the following hold:

    * ``lock`` is not currently held,
    * no coroutine is queued waiting on ``lock``,
    * the lock registered under ``session_key`` is this very ``lock`` object.

    Args:
        session_key: Key under which the lock is registered.
        lock: The lock object the caller obtained for ``session_key``.
    """
    # asyncio.Lock exposes no public way to inspect queued waiters, so the
    # private ``_waiters`` attribute is read defensively; a missing or empty
    # value is treated as "no waiters".
    has_waiters = bool(getattr(lock, "_waiters", None))
    if lock.locked() or has_waiters:
        return

    # Identity guard: if the entry was already pruned and re-created, the
    # registered lock is a different object that may be held by someone else.
    # Removing it on the basis of a stale lock's state would let a later
    # caller create yet another lock and bypass the mutual exclusion.
    if self._consolidation_locks.get(session_key) is lock:
        del self._consolidation_locks[session_key]
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
@@ -334,11 +342,11 @@ class AgentLoop:
|
|||||||
# Handle slash commands
|
# Handle slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
messages_to_archive = session.messages.copy()
|
|
||||||
lock = self._get_consolidation_lock(session.key)
|
lock = self._get_consolidation_lock(session.key)
|
||||||
|
messages_to_archive: list[dict[str, Any]] = []
|
||||||
try:
|
try:
|
||||||
async with lock:
|
async with lock:
|
||||||
|
messages_to_archive = session.messages[session.last_consolidated :].copy()
|
||||||
temp_session = Session(key=session.key)
|
temp_session = Session(key=session.key)
|
||||||
temp_session.messages = messages_to_archive
|
temp_session.messages = messages_to_archive
|
||||||
archived = await self._consolidate_memory(temp_session, archive_all=True)
|
archived = await self._consolidate_memory(temp_session, archive_all=True)
|
||||||
@@ -360,6 +368,7 @@ class AgentLoop:
|
|||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self.sessions.invalidate(session.key)
|
self.sessions.invalidate(session.key)
|
||||||
|
self._prune_consolidation_lock(session.key, lock)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
@@ -382,6 +391,7 @@ class AgentLoop:
|
|||||||
await self._consolidate_memory(session)
|
await self._consolidate_memory(session)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
|
self._prune_consolidation_lock(session.key, lock)
|
||||||
_task = asyncio.current_task()
|
_task = asyncio.current_task()
|
||||||
if _task is not None:
|
if _task is not None:
|
||||||
self._consolidation_tasks.discard(_task)
|
self._consolidation_tasks.discard(_task)
|
||||||
|
|||||||
@@ -723,3 +723,106 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
assert len(session_after.messages) == before_count, (
|
assert len(session_after.messages) == before_count, (
|
||||||
"Session must remain intact when /new archival fails"
|
"Session must remain intact when /new archival fails"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
    self, tmp_path: Path
) -> None:
    """/new must archive only the tail left unconsolidated by an in-flight task.

    A fake consolidation routine blocks on an event during its normal
    (non-archive) run and, once released, marks all but the last three
    messages as consolidated. While it is blocked, ``/new`` is issued and is
    expected to stay pending; after release, the archive-all call made by
    ``/new`` must see exactly three messages.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus
    from nanobot.providers.base import LLMResponse

    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    loop = AgentLoop(
        bus=bus,
        provider=provider,
        workspace=tmp_path,
        model="test-model",
        memory_window=10,
    )
    loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Seed 15 user/assistant pairs (30 messages) and persist the session.
    session = loop.sessions.get_or_create("cli:test")
    for idx in range(15):
        session.add_message("user", f"msg{idx}")
        session.add_message("assistant", f"resp{idx}")
    loop.sessions.save(session)

    started = asyncio.Event()
    release = asyncio.Event()
    # Holds the message count seen by the archive-all call; -1 means "never called".
    observed = {"archived": -1}

    async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
        if not archive_all:
            # Normal run: announce start, block until released, then leave
            # only the last three messages unconsolidated.
            started.set()
            await release.wait()
            sess.last_consolidated = len(sess.messages) - 3
            return True
        # Archive-all run (issued by /new): record how many messages it got.
        observed["archived"] = len(sess.messages)
        return True

    loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

    # A regular message first; the test then waits for the fake's normal run
    # to signal that it has started.
    hello = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
    await loop._process_message(hello)
    await started.wait()

    # Issue /new while the fake consolidation is still blocked.
    reset = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
    pending_new = asyncio.create_task(loop._process_message(reset))
    await asyncio.sleep(0.02)
    assert not pending_new.done()

    release.set()
    response = await pending_new

    assert response is not None
    assert "new session started" in response.content.lower()
    archived_count = observed["archived"]
    assert archived_count == 3, (
        f"Expected only unconsolidated tail to archive, got {archived_count}"
    )
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
    self, tmp_path: Path
) -> None:
    """/new must remove the consolidation-lock entry for the session key.

    A lock entry is created up front for the session; after a successful
    ``/new`` (with consolidation stubbed to succeed) the key must no longer
    appear in ``loop._consolidation_locks``.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus
    from nanobot.providers.base import LLMResponse

    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    loop = AgentLoop(
        bus=bus,
        provider=provider,
        workspace=tmp_path,
        model="test-model",
        memory_window=10,
    )
    loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Seed three user/assistant pairs and persist the session.
    session = loop.sessions.get_or_create("cli:test")
    for idx in range(3):
        session.add_message("user", f"msg{idx}")
        session.add_message("assistant", f"resp{idx}")
    loop.sessions.save(session)

    # Ensure lock exists before /new.
    loop._get_consolidation_lock(session.key)
    assert session.key in loop._consolidation_locks

    async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
        # Stub: consolidation always reports success without doing any work.
        return True

    loop._consolidate_memory = _ok_consolidate  # type: ignore[method-assign]

    reset = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
    response = await loop._process_message(reset)

    assert response is not None
    assert "new session started" in response.content.lower()
    assert session.key not in loop._consolidation_locks
|
||||||
|
|||||||
Reference in New Issue
Block a user