diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index c6e565b..6155f99 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -444,18 +444,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - suppress_final_reply = False - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool): - sent_targets = set(message_tool.get_turn_sends()) - suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets - - if suppress_final_reply: - logger.info( - "Skipping final auto-reply because message tool already sent to {}:{} in this turn", - msg.channel, - msg.chat_id, - ) + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index be359f3..35e519a 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -20,7 +20,7 @@ class MessageTool(Tool): self._default_channel = default_channel self._default_chat_id = default_chat_id self._default_message_id = default_message_id - self._turn_sends: list[tuple[str, str]] = [] + self._sent_in_turn: bool = False def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" @@ -34,11 +34,7 @@ class MessageTool(Tool): def start_turn(self) -> None: """Reset per-turn send tracking.""" - self._turn_sends.clear() - - def get_turn_sends(self) -> list[tuple[str, str]]: - """Get (channel, chat_id) targets sent in the current turn.""" - return list(self._turn_sends) + self._sent_in_turn = False @property def name(self) -> str: @@ -105,7 +101,8 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._turn_sends.append((channel, chat_id)) + if channel == self._default_channel and chat_id == self._default_chat_id: + self._sent_in_turn = True media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index 77436a0..26b8a16 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -1,6 +1,5 @@ """Test message tool suppress logic for final replies.""" -import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -13,188 +12,92 @@ from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse, ToolCallRequest +def _make_loop(tmp_path: Path) -> AgentLoop: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + + class TestMessageToolSuppressLogic: - """Test that final reply is only suppressed when message tool sends to same target.""" + """Final reply suppressed only when message tool sends to the same target.""" @pytest.mark.asyncio - async def test_final_reply_suppressed_when_message_tool_sends_to_same_target( - self, tmp_path: Path - ) -> None: - """If message tool sends to the same (channel, chat_id), final reply is suppressed.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # First call returns tool call, second call returns final response + async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) tool_call = ToolCallRequest( - id="call1", - name="message", - arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"} + id="call1", name="message", + arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"}, ) - - call_count = 0 - - def mock_chat(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - return LLMResponse(content="", tool_calls=[tool_call]) - else: - return LLMResponse(content="Done", tool_calls=[]) - - loop.provider.chat = AsyncMock(side_effect=mock_chat) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) - # Track outbound messages - sent_messages: list[OutboundMessage] = [] + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) - async def _capture_outbound(msg: OutboundMessage) -> None: - sent_messages.append(msg) - - # Set up message tool with callback - message_tool = loop.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_send_callback(_capture_outbound) - - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send") result = await loop._process_message(msg) - # Message tool should have sent to the same target - assert len(sent_messages) == 1 - assert sent_messages[0].channel == "feishu" - assert sent_messages[0].chat_id == "chat123" - - # Final reply should be None (suppressed) - assert result is None + assert len(sent) == 1 + assert result is None # suppressed @pytest.mark.asyncio - async def test_final_reply_sent_when_message_tool_sends_to_different_target( - self, tmp_path: Path - ) -> None: - """If message tool sends to a different target, final reply is still sent.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # First call returns tool call to email, second call returns final response + async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) tool_call = ToolCallRequest( - id="call1", - name="message", - arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"} + id="call1", name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}, ) - - call_count = 0 - - def mock_chat(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - return LLMResponse(content="", tool_calls=[tool_call]) - else: - return LLMResponse(content="I've sent the email.", tool_calls=[]) - - loop.provider.chat = AsyncMock(side_effect=mock_chat) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="I've sent the email.", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) - # Track outbound messages - sent_messages: list[OutboundMessage] = [] + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) - async def _capture_outbound(msg: OutboundMessage) -> None: - sent_messages.append(msg) - - # Set up message tool with callback - message_tool = loop.tools.get("message") - if isinstance(message_tool, MessageTool): - message_tool.set_send_callback(_capture_outbound) - - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email") result = await loop._process_message(msg) - # Message tool should have sent to email - assert len(sent_messages) == 1 - assert sent_messages[0].channel == "email" - assert sent_messages[0].chat_id == "user@example.com" - - # Final reply should be sent to Feishu (not suppressed) - assert result is not None + assert len(sent) == 1 + assert sent[0].channel == "email" + assert result is not None # not suppressed assert result.channel == "feishu" - assert result.chat_id == "chat123" - assert "email" in result.content.lower() or "sent" in result.content.lower() @pytest.mark.asyncio - async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None: - """If no message tool is used, final reply is always sent.""" - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - # Mock provider to return a simple response without tool calls - loop.provider.chat = AsyncMock(return_value=LLMResponse( - content="Hello! How can I help you?", - tool_calls=[] - )) + async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) - msg = InboundMessage( - channel="feishu", sender_id="user1", chat_id="chat123", content="Hi" - ) + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") result = await loop._process_message(msg) - # Final reply should be sent assert result is not None - assert result.channel == "feishu" - assert result.chat_id == "chat123" assert "Hello" in result.content class TestMessageToolTurnTracking: - """Test MessageTool's turn tracking functionality.""" - def test_turn_sends_tracking(self) -> None: - """MessageTool correctly tracks sends per turn.""" + def test_sent_in_turn_tracks_same_target(self) -> None: tool = MessageTool() + tool.set_context("feishu", "chat1") + assert not tool._sent_in_turn + tool._sent_in_turn = True + assert tool._sent_in_turn - # Initially empty - assert tool.get_turn_sends() == [] - - # Simulate sends - tool._turn_sends.append(("feishu", "chat1")) - tool._turn_sends.append(("email", "user@example.com")) - - sends = tool.get_turn_sends() - assert len(sends) == 2 - assert ("feishu", "chat1") in sends - assert ("email", "user@example.com") in sends - - def test_start_turn_clears_tracking(self) -> None: - """start_turn() clears the turn sends list.""" + def test_start_turn_resets(self) -> None: tool = MessageTool() - tool._turn_sends.append(("feishu", "chat1")) - assert len(tool.get_turn_sends()) == 1 - + tool._sent_in_turn = True tool.start_turn() - assert tool.get_turn_sends() == [] - - def test_get_turn_sends_returns_copy(self) -> None: - """get_turn_sends() returns a copy, not the original list.""" - tool = MessageTool() - tool._turn_sends.append(("feishu", "chat1")) - - sends = tool.get_turn_sends() - sends.append(("email", "user@example.com")) # Modify the copy - - # Original should be unchanged - assert len(tool.get_turn_sends()) == 1 + assert not tool._sent_in_turn