From fafd8d4eb86c856c72d3dcabab59a013ed5a741a Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Thu, 26 Feb 2026 00:23:58 +0800 Subject: [PATCH 1/3] fix(agent): only suppress final reply when message tool sends to same target MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A refactoring in commit 132807a introduced a regression where the final response was silently discarded whenever the message tool was used, regardless of the target. This restored the original logic from PR #832 that only suppresses the final reply when the message tool sends to the same (channel, chat_id) as the original message. Changes: - message.py: Replace _sent_in_turn: bool with _turn_sends: list[tuple] to track actual send targets, add get_turn_sends() method - loop.py: Check if (msg.channel, msg.chat_id) is in sent_targets before suppressing final reply. Also move the "Response to" log after the suppress check to avoid misleading logs. - Add unit tests for the suppress logic This ensures: - Email sent via message tool → Feishu still gets confirmation - Message tool sends to same Feishu chat → No duplicate (suppressed) --- nanobot/agent/loop.py | 19 ++- nanobot/agent/tools/message.py | 10 +- tests/test_message_tool_suppress.py | 200 ++++++++++++++++++++++++++++ 3 files changed, 221 insertions(+), 8 deletions(-) create mode 100644 tests/test_message_tool_suppress.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 8be8e51..2a998d4 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -407,16 +407,25 @@ class AgentLoop: if final_content is None: final_content = "I've completed processing but have no response to give." - preview = final_content[:120] + "..." if len(final_content) > 120 else final_content - logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) - self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + suppress_final_reply = False if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: - return None + if isinstance(message_tool, MessageTool): + sent_targets = set(message_tool.get_turn_sends()) + suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets + if suppress_final_reply: + logger.info( + "Skipping final auto-reply because message tool already sent to {}:{} in this turn", + msg.channel, + msg.chat_id, + ) + return None + + preview = final_content[:120] + "..." if len(final_content) > 120 else final_content + logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, metadata=msg.metadata or {}, diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 40e76e3..be359f3 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -20,7 +20,7 @@ class MessageTool(Tool): self._default_channel = default_channel self._default_chat_id = default_chat_id self._default_message_id = default_message_id - self._sent_in_turn: bool = False + self._turn_sends: list[tuple[str, str]] = [] def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" @@ -34,7 +34,11 @@ class MessageTool(Tool): def start_turn(self) -> None: """Reset per-turn send tracking.""" - self._sent_in_turn = False + self._turn_sends.clear() + + def get_turn_sends(self) -> list[tuple[str, str]]: + """Get (channel, chat_id) targets sent in the current turn.""" + return list(self._turn_sends) @property def name(self) -> str: @@ -101,7 +105,7 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._sent_in_turn = True + self._turn_sends.append((channel, chat_id)) media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py new file mode 100644 index 0000000..77436a0 --- /dev/null +++ b/tests/test_message_tool_suppress.py @@ -0,0 +1,200 @@ +"""Test message tool suppress logic for final replies.""" + +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.message import MessageTool +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse, ToolCallRequest + + +class TestMessageToolSuppressLogic: + """Test that final reply is only suppressed when message tool sends to same target.""" + + @pytest.mark.asyncio + async def test_final_reply_suppressed_when_message_tool_sends_to_same_target( + self, tmp_path: Path + ) -> None: + """If message tool sends to the same (channel, chat_id), final reply is suppressed.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # First call returns tool call, second call returns final response + tool_call = ToolCallRequest( + id="call1", + name="message", + arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"} + ) + + call_count = 0 + + def mock_chat(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse(content="", tool_calls=[tool_call]) + else: + return LLMResponse(content="Done", tool_calls=[]) + + loop.provider.chat = AsyncMock(side_effect=mock_chat) + loop.tools.get_definitions = MagicMock(return_value=[]) + + # Track outbound messages + sent_messages: list[OutboundMessage] = [] + + async def _capture_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + # Set up message tool with callback + message_tool = loop.tools.get("message") + if isinstance(message_tool, MessageTool): + message_tool.set_send_callback(_capture_outbound) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message" + ) + result = await loop._process_message(msg) + + # Message tool should have sent to the same target + assert len(sent_messages) == 1 + assert sent_messages[0].channel == "feishu" + assert sent_messages[0].chat_id == "chat123" + + # Final reply should be None (suppressed) + assert result is None + + @pytest.mark.asyncio + async def test_final_reply_sent_when_message_tool_sends_to_different_target( + self, tmp_path: Path + ) -> None: + """If message tool sends to a different target, final reply is still sent.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # First call returns tool call to email, second call returns final response + tool_call = ToolCallRequest( + id="call1", + name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"} + ) + + call_count = 0 + + def mock_chat(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse(content="", tool_calls=[tool_call]) + else: + return LLMResponse(content="I've sent the email.", tool_calls=[]) + + loop.provider.chat = AsyncMock(side_effect=mock_chat) + loop.tools.get_definitions = MagicMock(return_value=[]) + + # Track outbound messages + sent_messages: list[OutboundMessage] = [] + + async def _capture_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + # Set up message tool with callback + message_tool = loop.tools.get("message") + if isinstance(message_tool, MessageTool): + message_tool.set_send_callback(_capture_outbound) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email" + ) + result = await loop._process_message(msg) + + # Message tool should have sent to email + assert len(sent_messages) == 1 + assert sent_messages[0].channel == "email" + assert sent_messages[0].chat_id == "user@example.com" + + # Final reply should be sent to Feishu (not suppressed) + assert result is not None + assert result.channel == "feishu" + assert result.chat_id == "chat123" + assert "email" in result.content.lower() or "sent" in result.content.lower() + + @pytest.mark.asyncio + async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None: + """If no message tool is used, final reply is always sent.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # Mock provider to return a simple response without tool calls + loop.provider.chat = AsyncMock(return_value=LLMResponse( + content="Hello! How can I help you?", + tool_calls=[] + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Hi" + ) + result = await loop._process_message(msg) + + # Final reply should be sent + assert result is not None + assert result.channel == "feishu" + assert result.chat_id == "chat123" + assert "Hello" in result.content + + +class TestMessageToolTurnTracking: + """Test MessageTool's turn tracking functionality.""" + + def test_turn_sends_tracking(self) -> None: + """MessageTool correctly tracks sends per turn.""" + tool = MessageTool() + + # Initially empty + assert tool.get_turn_sends() == [] + + # Simulate sends + tool._turn_sends.append(("feishu", "chat1")) + tool._turn_sends.append(("email", "user@example.com")) + + sends = tool.get_turn_sends() + assert len(sends) == 2 + assert ("feishu", "chat1") in sends + assert ("email", "user@example.com") in sends + + def test_start_turn_clears_tracking(self) -> None: + """start_turn() clears the turn sends list.""" + tool = MessageTool() + tool._turn_sends.append(("feishu", "chat1")) + assert len(tool.get_turn_sends()) == 1 + + tool.start_turn() + assert tool.get_turn_sends() == [] + + def test_get_turn_sends_returns_copy(self) -> None: + """get_turn_sends() returns a copy, not the original list.""" + tool = MessageTool() + tool._turn_sends.append(("feishu", "chat1")) + + sends = tool.get_turn_sends() + sends.append(("email", "user@example.com")) # Modify the copy + + # Original should be unchanged + assert len(tool.get_turn_sends()) == 1 From 29e6709e261632b0494760f053711bf677ab6b22 Mon Sep 17 00:00:00 2001 From: Re-bin Date: Fri, 27 Feb 2026 02:27:18 +0000 Subject: [PATCH 2/3] =?UTF-8?q?refactor:=20simplify=20message=20tool=20sup?= =?UTF-8?q?press=20=E2=80=94=20bool=20check=20instead=20of=20target=20trac?= =?UTF-8?q?king?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nanobot/agent/loop.py | 13 +- nanobot/agent/tools/message.py | 11 +- tests/test_message_tool_suppress.py | 203 ++++++++-------------------- 3 files changed, 58 insertions(+), 169 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index c6e565b..6155f99 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -444,18 +444,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - suppress_final_reply = False - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool): - sent_targets = set(message_tool.get_turn_sends()) - suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets - - if suppress_final_reply: - logger.info( - "Skipping final auto-reply because message tool already sent to {}:{} in this turn", - msg.channel, - msg.chat_id, - ) + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index be359f3..35e519a 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -20,7 +20,7 @@ class MessageTool(Tool): self._default_channel = default_channel self._default_chat_id = default_chat_id self._default_message_id = default_message_id - self._turn_sends: list[tuple[str, str]] = [] + self._sent_in_turn: bool = False def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" @@ -34,11 +34,7 @@ class MessageTool(Tool): def start_turn(self) -> None: """Reset per-turn send tracking.""" - self._turn_sends.clear() - - def get_turn_sends(self) -> list[tuple[str, str]]: - """Get (channel, chat_id) targets sent in the current turn.""" - return list(self._turn_sends) + self._sent_in_turn = False @property def name(self) -> str: @@ -105,7 +101,8 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._turn_sends.append((channel, chat_id)) + if channel == self._default_channel and chat_id == self._default_chat_id: + self._sent_in_turn = True media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index 77436a0..26b8a16 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -1,6 +1,5 @@ """Test message tool suppress logic for final replies.""" -import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -13,188 +12,92 @@ from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse, ToolCallRequest +def _make_loop(tmp_path: Path) -> AgentLoop: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + + class TestMessageToolSuppressLogic: - """Test that final reply is only suppressed when message tool sends to same target.""" + """Final reply suppressed only when message tool sends to the same target.""" @pytest.mark.asyncio - async def test_final_reply_suppressed_when_message_tool_sends_to_same_target( - self, tmp_path: Path - ) -> None: - """If message tool sends to the same (channel, chat_id), final reply is suppressed.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # First call returns tool call, second call returns final response + async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) tool_call = ToolCallRequest( - id="call1", - name="message", - arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"} + id="call1", name="message", + arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"}, ) - - call_count = 0 - - def mock_chat(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - return LLMResponse(content="", tool_calls=[tool_call]) - else: - return LLMResponse(content="Done", tool_calls=[]) - - loop.provider.chat = AsyncMock(side_effect=mock_chat) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) - # Track outbound messages - sent_messages: list[OutboundMessage] = [] + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) - async def _capture_outbound(msg: OutboundMessage) -> None: - sent_messages.append(msg) - - # Set up message tool with callback - message_tool = loop.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_send_callback(_capture_outbound) - - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send") result = await loop._process_message(msg) - # Message tool should have sent to the same target - assert len(sent_messages) == 1 - assert sent_messages[0].channel == "feishu" - assert sent_messages[0].chat_id == "chat123" - - # Final reply should be None (suppressed) - assert result is None + assert len(sent) == 1 + assert result is None # suppressed @pytest.mark.asyncio - async def test_final_reply_sent_when_message_tool_sends_to_different_target( - self, tmp_path: Path - ) -> None: - """If message tool sends to a different target, final reply is still sent.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # First call returns tool call to email, second call returns final response + async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) tool_call = ToolCallRequest( - id="call1", - name="message", - arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"} + id="call1", name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}, ) - - call_count = 0 - - def mock_chat(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - return LLMResponse(content="", tool_calls=[tool_call]) - else: - return LLMResponse(content="I've sent the email.", tool_calls=[]) - - loop.provider.chat = AsyncMock(side_effect=mock_chat) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="I've sent the email.", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) - # Track outbound messages - sent_messages: list[OutboundMessage] = [] + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) - async def _capture_outbound(msg: OutboundMessage) -> None: - sent_messages.append(msg) - - # Set up message tool with callback - message_tool = loop.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_send_callback(_capture_outbound) - - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email") result = await loop._process_message(msg) - # Message tool should have sent to email - assert len(sent_messages) == 1 - assert sent_messages[0].channel == "email" - assert sent_messages[0].chat_id == "user@example.com" - - # Final reply should be sent to Feishu (not suppressed) - assert result is not None + assert len(sent) == 1 + assert sent[0].channel == "email" + assert result is not None # not suppressed assert result.channel == "feishu" - assert result.chat_id == "chat123" - assert "email" in result.content.lower() or "sent" in result.content.lower() @pytest.mark.asyncio - async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None: - """If no message tool is used, final reply is always sent.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # Mock provider to return a simple response without tool calls - loop.provider.chat = AsyncMock(return_value=LLMResponse( - content="Hello! How can I help you?", - tool_calls=[] - )) + async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Hi" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") result = await loop._process_message(msg) - # Final reply should be sent assert result is not None - assert result.channel == "feishu" - assert result.chat_id == "chat123" assert "Hello" in result.content class TestMessageToolTurnTracking: - """Test MessageTool's turn tracking functionality.""" - def test_turn_sends_tracking(self) -> None: - """MessageTool correctly tracks sends per turn.""" + def test_sent_in_turn_tracks_same_target(self) -> None: tool = MessageTool() + tool.set_context("feishu", "chat1") + assert not tool._sent_in_turn + tool._sent_in_turn = True + assert tool._sent_in_turn - # Initially empty - assert tool.get_turn_sends() == [] - - # Simulate sends - tool._turn_sends.append(("feishu", "chat1")) - tool._turn_sends.append(("email", "user@example.com")) - - sends = tool.get_turn_sends() - assert len(sends) == 2 - assert ("feishu", "chat1") in sends - assert ("email", "user@example.com") in sends - - def test_start_turn_clears_tracking(self) -> None: - """start_turn() clears the turn sends list.""" + def test_start_turn_resets(self) -> None: tool = MessageTool() - tool._turn_sends.append(("feishu", "chat1")) - assert len(tool.get_turn_sends()) == 1 - + tool._sent_in_turn = True tool.start_turn() - assert tool.get_turn_sends() == [] - - def test_get_turn_sends_returns_copy(self) -> None: - """get_turn_sends() returns a copy, not the original list.""" - tool = MessageTool() - tool._turn_sends.append(("feishu", "chat1")) - - sends = tool.get_turn_sends() - sends.append(("email", "user@example.com")) # Modify the copy - - # Original should be unchanged - assert len(tool.get_turn_sends()) == 1 + assert not tool._sent_in_turn From ec8dee802c3727e6293e1d0bba9c6d0bb171b718 Mon Sep 17 00:00:00 2001 From: Re-bin Date: Fri, 27 Feb 2026 02:39:38 +0000 Subject: [PATCH 3/3] refactor: simplify message tool suppress and inline consolidation locks --- README.md | 2 +- nanobot/agent/loop.py | 41 ++++++++++---------------------- tests/test_consolidate_offset.py | 2 +- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index be360dc..71922fb 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ ⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines. -📏 Real-time line count: **3,966 lines** (run `bash core_agent_lines.sh` to verify anytime) +📏 Real-time line count: **3,932 lines** (run `bash core_agent_lines.sh` to verify anytime) ## 📢 News diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 6155f99..e3a9d67 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -43,6 +43,8 @@ class AgentLoop: 5. Sends responses back """ + _TOOL_RESULT_MAX_CHARS = 500 + def __init__( self, bus: MessageBus, @@ -145,17 +147,10 @@ class AgentLoop: def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Update context for all tools that need routing info.""" - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool): - message_tool.set_context(channel, chat_id, message_id) - - if spawn_tool := self.tools.get("spawn"): - if isinstance(spawn_tool, SpawnTool): - spawn_tool.set_context(channel, chat_id) - - if cron_tool := self.tools.get("cron"): - if isinstance(cron_tool, CronTool): - cron_tool.set_context(channel, chat_id) + for name in ("message", "spawn", "cron"): + if tool := self.tools.get(name): + if hasattr(tool, "set_context"): + tool.set_context(channel, chat_id, *([message_id] if name == "message" else [])) @staticmethod def _strip_think(text: str | None) -> str | None: @@ -315,18 +310,6 @@ class AgentLoop: self._running = False logger.info("Agent loop stopping") - def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock: - lock = self._consolidation_locks.get(session_key) - if lock is None: - lock = asyncio.Lock() - self._consolidation_locks[session_key] = lock - return lock - - def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None: - """Drop lock entry if no longer in use.""" - if not lock.locked(): - self._consolidation_locks.pop(session_key, None) - async def _process_message( self, msg: InboundMessage, @@ -362,7 +345,7 @@ class AgentLoop: # Slash commands cmd = msg.content.strip().lower() if cmd == "/new": - lock = self._get_consolidation_lock(session.key) + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) self._consolidating.add(session.key) try: async with lock: @@ -383,7 +366,8 @@ class AgentLoop: ) finally: self._consolidating.discard(session.key) - self._prune_consolidation_lock(session.key, lock) + if not lock.locked(): + self._consolidation_locks.pop(session.key, None) session.clear() self.sessions.save(session) @@ -397,7 +381,7 @@ class AgentLoop: unconsolidated = len(session.messages) - session.last_consolidated if (unconsolidated >= self.memory_window and session.key not in self._consolidating): self._consolidating.add(session.key) - lock = self._get_consolidation_lock(session.key) + lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) async def _consolidate_and_unlock(): try: @@ -405,7 +389,8 @@ class AgentLoop: await self._consolidate_memory(session) finally: self._consolidating.discard(session.key) - self._prune_consolidation_lock(session.key, lock) + if not lock.locked(): + self._consolidation_locks.pop(session.key, None) _task = asyncio.current_task() if _task is not None: self._consolidation_tasks.discard(_task) @@ -454,8 +439,6 @@ class AgentLoop: metadata=msg.metadata or {}, ) - _TOOL_RESULT_MAX_CHARS = 500 - def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 323519e..6755124 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -812,7 +812,7 @@ class TestConsolidationDeduplicationGuard: loop.sessions.save(session) # Ensure lock exists before /new. - _ = loop._get_consolidation_lock(session.key) + loop._consolidation_locks.setdefault(session.key, asyncio.Lock()) assert session.key in loop._consolidation_locks async def _ok_consolidate(sess, archive_all: bool = False) -> bool: