From fafd8d4eb86c856c72d3dcabab59a013ed5a741a Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Thu, 26 Feb 2026 00:23:58 +0800 Subject: [PATCH] fix(agent): only suppress final reply when message tool sends to same target MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A refactoring in commit 132807a introduced a regression where the final response was silently discarded whenever the message tool was used, regardless of the target. This restored the original logic from PR #832 that only suppresses the final reply when the message tool sends to the same (channel, chat_id) as the original message. Changes: - message.py: Replace _sent_in_turn: bool with _turn_sends: list[tuple] to track actual send targets, add get_turn_sends() method - loop.py: Check if (msg.channel, msg.chat_id) is in sent_targets before suppressing final reply. Also move the "Response to" log after the suppress check to avoid misleading logs. - Add unit tests for the suppress logic This ensures: - Email sent via message tool → Feishu still gets confirmation - Message tool sends to same Feishu chat → No duplicate (suppressed) --- nanobot/agent/loop.py | 19 ++- nanobot/agent/tools/message.py | 10 +- tests/test_message_tool_suppress.py | 200 ++++++++++++++++++++++++++++ 3 files changed, 221 insertions(+), 8 deletions(-) create mode 100644 tests/test_message_tool_suppress.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 8be8e51..2a998d4 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -407,16 +407,25 @@ class AgentLoop: if final_content is None: final_content = "I've completed processing but have no response to give." - preview = final_content[:120] + "..." if len(final_content) > 120 else final_content - logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) - self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + suppress_final_reply = False if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: - return None + if isinstance(message_tool, MessageTool): + sent_targets = set(message_tool.get_turn_sends()) + suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets + if suppress_final_reply: + logger.info( + "Skipping final auto-reply because message tool already sent to {}:{} in this turn", + msg.channel, + msg.chat_id, + ) + return None + + preview = final_content[:120] + "..." if len(final_content) > 120 else final_content + logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, metadata=msg.metadata or {}, diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 40e76e3..be359f3 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -20,7 +20,7 @@ class MessageTool(Tool): self._default_channel = default_channel self._default_chat_id = default_chat_id self._default_message_id = default_message_id - self._sent_in_turn: bool = False + self._turn_sends: list[tuple[str, str]] = [] def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" @@ -34,7 +34,11 @@ class MessageTool(Tool): def start_turn(self) -> None: """Reset per-turn send tracking.""" - self._sent_in_turn = False + self._turn_sends.clear() + + def get_turn_sends(self) -> list[tuple[str, str]]: + """Get (channel, chat_id) targets sent in the current turn.""" + return list(self._turn_sends) @property def name(self) -> str: @@ -101,7 +105,7 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._sent_in_turn = True + self._turn_sends.append((channel, chat_id)) media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py new file mode 100644 index 0000000..77436a0 --- /dev/null +++ b/tests/test_message_tool_suppress.py @@ -0,0 +1,200 @@ +"""Test message tool suppress logic for final replies.""" + +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.message import MessageTool +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse, ToolCallRequest + + +class TestMessageToolSuppressLogic: + """Test that final reply is only suppressed when message tool sends to same target.""" + + @pytest.mark.asyncio + async def test_final_reply_suppressed_when_message_tool_sends_to_same_target( + self, tmp_path: Path + ) -> None: + """If message tool sends to the same (channel, chat_id), final reply is suppressed.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # First call returns tool call, second call returns final response + tool_call = ToolCallRequest( + id="call1", + name="message", + arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"} + ) + + call_count = 0 + + def mock_chat(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse(content="", tool_calls=[tool_call]) + else: + return LLMResponse(content="Done", tool_calls=[]) + + loop.provider.chat = AsyncMock(side_effect=mock_chat) + loop.tools.get_definitions = MagicMock(return_value=[]) + + # Track outbound messages + sent_messages: list[OutboundMessage] = [] + + async def _capture_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + # Set up message tool with callback + message_tool = loop.tools.get("message") + if isinstance(message_tool, MessageTool): + message_tool.set_send_callback(_capture_outbound) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message" + ) + result = await loop._process_message(msg) + + # Message tool should have sent to the same target + assert len(sent_messages) == 1 + assert sent_messages[0].channel == "feishu" + assert sent_messages[0].chat_id == "chat123" + + # Final reply should be None (suppressed) + assert result is None + + @pytest.mark.asyncio + async def test_final_reply_sent_when_message_tool_sends_to_different_target( + self, tmp_path: Path + ) -> None: + """If message tool sends to a different target, final reply is still sent.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # First call returns tool call to email, second call returns final response + tool_call = ToolCallRequest( + id="call1", + name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"} + ) + + call_count = 0 + + def mock_chat(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse(content="", tool_calls=[tool_call]) + else: + return LLMResponse(content="I've sent the email.", tool_calls=[]) + + loop.provider.chat = AsyncMock(side_effect=mock_chat) + loop.tools.get_definitions = MagicMock(return_value=[]) + + # Track outbound messages + sent_messages: list[OutboundMessage] = [] + + async def _capture_outbound(msg: OutboundMessage) -> None: + sent_messages.append(msg) + + # Set up message tool with callback + message_tool = loop.tools.get("message") + if isinstance(message_tool, MessageTool): + message_tool.set_send_callback(_capture_outbound) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email" + ) + result = await loop._process_message(msg) + + # Message tool should have sent to email + assert len(sent_messages) == 1 + assert sent_messages[0].channel == "email" + assert sent_messages[0].chat_id == "user@example.com" + + # Final reply should be sent to Feishu (not suppressed) + assert result is not None + assert result.channel == "feishu" + assert result.chat_id == "chat123" + assert "email" in result.content.lower() or "sent" in result.content.lower() + + @pytest.mark.asyncio + async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None: + """If no message tool is used, final reply is always sent.""" + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + ) + + # Mock provider to return a simple response without tool calls + loop.provider.chat = AsyncMock(return_value=LLMResponse( + content="Hello! How can I help you?", + tool_calls=[] + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage( + channel="feishu", sender_id="user1", chat_id="chat123", content="Hi" + ) + result = await loop._process_message(msg) + + # Final reply should be sent + assert result is not None + assert result.channel == "feishu" + assert result.chat_id == "chat123" + assert "Hello" in result.content + + +class TestMessageToolTurnTracking: + """Test MessageTool's turn tracking functionality.""" + + def test_turn_sends_tracking(self) -> None: + """MessageTool correctly tracks sends per turn.""" + tool = MessageTool() + + # Initially empty + assert tool.get_turn_sends() == [] + + # Simulate sends + tool._turn_sends.append(("feishu", "chat1")) + tool._turn_sends.append(("email", "user@example.com")) + + sends = tool.get_turn_sends() + assert len(sends) == 2 + assert ("feishu", "chat1") in sends + assert ("email", "user@example.com") in sends + + def test_start_turn_clears_tracking(self) -> None: + """start_turn() clears the turn sends list.""" + tool = MessageTool() + tool._turn_sends.append(("feishu", "chat1")) + assert len(tool.get_turn_sends()) == 1 + + tool.start_turn() + assert tool.get_turn_sends() == [] + + def test_get_turn_sends_returns_copy(self) -> None: + """get_turn_sends() returns a copy, not the original list.""" + tool = MessageTool() + tool._turn_sends.append(("feishu", "chat1")) + + sends = tool.get_turn_sends() + sends.append(("email", "user@example.com")) # Modify the copy + + # Original should be unchanged + assert len(tool.get_turn_sends()) == 1