Merge PR #1326: use WeakValueDictionary for consolidation locks

refactor: use WeakValueDictionary for consolidation locks
This commit is contained in:
Xubin Ren
2026-02-28 16:31:16 +08:00
committed by GitHub
2 changed files with 5 additions and 15 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import json
import re
import weakref
from contextlib import AsyncExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -100,7 +101,7 @@ class AgentLoop:
self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._processing_lock = asyncio.Lock()
self._register_default_tools()
@@ -373,8 +374,6 @@ class AgentLoop:
)
finally:
self._consolidating.discard(session.key)
if not lock.locked():
self._consolidation_locks.pop(session.key, None)
session.clear()
self.sessions.save(session)
@@ -396,8 +395,6 @@ class AgentLoop:
await self._consolidate_memory(session)
finally:
self._consolidating.discard(session.key)
if not lock.locked():
self._consolidation_locks.pop(session.key, None)
_task = asyncio.current_task()
if _task is not None:
self._consolidation_tasks.discard(_task)

View File

@@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard:
)
@pytest.mark.asyncio
async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
self, tmp_path: Path
) -> None:
"""/new should remove lock entry for fully invalidated session key."""
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
@@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard:
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
@@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard:
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
# Ensure lock exists before /new.
loop._consolidation_locks.setdefault(session.key, asyncio.Lock())
assert session.key in loop._consolidation_locks
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
return True
@@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard:
assert response is not None
assert "new session started" in response.content.lower()
assert session.key not in loop._consolidation_locks
assert loop.sessions.get_or_create("cli:test").messages == []