diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 3e513cb..6155f99 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -441,16 +441,14 @@ class AgentLoop: if final_content is None: final_content = "I've completed processing but have no response to give." - preview = final_content[:120] + "..." if len(final_content) > 120 else final_content - logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) - self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: - return None + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + return None + preview = final_content[:120] + "..." if len(final_content) > 120 else final_content + logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=final_content, metadata=msg.metadata or {}, diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 40e76e3..35e519a 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -101,7 +101,8 @@ class MessageTool(Tool): try: await self._send_callback(msg) - self._sent_in_turn = True + if channel == self._default_channel and chat_id == self._default_chat_id: + self._sent_in_turn = True media_info = f" with {len(media)} attachments" if media else "" return f"Message sent to {channel}:{chat_id}{media_info}" except Exception as e: diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py new file mode 100644 index 0000000..26b8a16 --- /dev/null +++ b/tests/test_message_tool_suppress.py @@ -0,0 +1,103 @@ +"""Test message tool suppress logic for final replies.""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +from nanobot.agent.tools.message import MessageTool +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse, ToolCallRequest + + +def _make_loop(tmp_path: Path) -> AgentLoop: + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + + +class TestMessageToolSuppressLogic: + """Final reply suppressed only when message tool sends to the same target.""" + + @pytest.mark.asyncio + async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send") + result = await loop._process_message(msg) + + assert len(sent) == 1 + assert result is None # suppressed + + @pytest.mark.asyncio + async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="I've sent the email.", tool_calls=[]), + ]) + loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email") + result = await loop._process_message(msg) + + assert len(sent) == 1 + assert sent[0].channel == "email" + assert result is not None # not suppressed + assert result.channel == "feishu" + + @pytest.mark.asyncio + async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: + loop = _make_loop(tmp_path) + loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) + loop.tools.get_definitions = MagicMock(return_value=[]) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") + result = await loop._process_message(msg) + + assert result is not None + assert "Hello" in result.content + + +class TestMessageToolTurnTracking: + + def test_sent_in_turn_tracks_same_target(self) -> None: + tool = MessageTool() + tool.set_context("feishu", "chat1") + assert not tool._sent_in_turn + tool._sent_in_turn = True + assert tool._sent_in_turn + + def test_start_turn_resets(self) -> None: + tool = MessageTool() + tool._sent_in_turn = True + tool.start_turn() + assert not tool._sent_in_turn