fix(agent): only suppress final reply when message tool sends to same target
A refactoring in commit 132807a introduced a regression where the final
response was silently discarded whenever the message tool was used,
regardless of the target. This restored the original logic from PR #832
that only suppresses the final reply when the message tool sends to the
same (channel, chat_id) as the original message.
Changes:
- message.py: Replace _sent_in_turn: bool with _turn_sends: list[tuple]
to track actual send targets, add get_turn_sends() method
- loop.py: Check if (msg.channel, msg.chat_id) is in sent_targets before
suppressing final reply. Also move the "Response to" log after the
suppress check to avoid misleading logs.
- Add unit tests for the suppress logic
This ensures:
- Email sent via message tool → Feishu still gets confirmation
- Message tool sends to same Feishu chat → No duplicate (suppressed)
This commit is contained in:
@@ -407,16 +407,25 @@ class AgentLoop:
|
|||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
suppress_final_reply = False
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
if isinstance(message_tool, MessageTool):
|
||||||
return None
|
sent_targets = set(message_tool.get_turn_sends())
|
||||||
|
suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets
|
||||||
|
|
||||||
|
if suppress_final_reply:
|
||||||
|
logger.info(
|
||||||
|
"Skipping final auto-reply because message tool already sent to {}:{} in this turn",
|
||||||
|
msg.channel,
|
||||||
|
msg.chat_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
metadata=msg.metadata or {},
|
metadata=msg.metadata or {},
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class MessageTool(Tool):
|
|||||||
self._default_channel = default_channel
|
self._default_channel = default_channel
|
||||||
self._default_chat_id = default_chat_id
|
self._default_chat_id = default_chat_id
|
||||||
self._default_message_id = default_message_id
|
self._default_message_id = default_message_id
|
||||||
self._sent_in_turn: bool = False
|
self._turn_sends: list[tuple[str, str]] = []
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Set the current message context."""
|
"""Set the current message context."""
|
||||||
@@ -34,7 +34,11 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
def start_turn(self) -> None:
|
def start_turn(self) -> None:
|
||||||
"""Reset per-turn send tracking."""
|
"""Reset per-turn send tracking."""
|
||||||
self._sent_in_turn = False
|
self._turn_sends.clear()
|
||||||
|
|
||||||
|
def get_turn_sends(self) -> list[tuple[str, str]]:
|
||||||
|
"""Get (channel, chat_id) targets sent in the current turn."""
|
||||||
|
return list(self._turn_sends)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -101,7 +105,7 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
self._sent_in_turn = True
|
self._turn_sends.append((channel, chat_id))
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
200
tests/test_message_tool_suppress.py
Normal file
200
tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolSuppressLogic:
|
||||||
|
"""Test that final reply is only suppressed when message tool sends to same target."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_reply_suppressed_when_message_tool_sends_to_same_target(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""If message tool sends to the same (channel, chat_id), final reply is suppressed."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call returns tool call, second call returns final response
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1",
|
||||||
|
name="message",
|
||||||
|
arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"}
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_chat(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return LLMResponse(content="", tool_calls=[tool_call])
|
||||||
|
else:
|
||||||
|
return LLMResponse(content="Done", tool_calls=[])
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=mock_chat)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
# Track outbound messages
|
||||||
|
sent_messages: list[OutboundMessage] = []
|
||||||
|
|
||||||
|
async def _capture_outbound(msg: OutboundMessage) -> None:
|
||||||
|
sent_messages.append(msg)
|
||||||
|
|
||||||
|
# Set up message tool with callback
|
||||||
|
message_tool = loop.tools.get("message")
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_tool.set_send_callback(_capture_outbound)
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message"
|
||||||
|
)
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
# Message tool should have sent to the same target
|
||||||
|
assert len(sent_messages) == 1
|
||||||
|
assert sent_messages[0].channel == "feishu"
|
||||||
|
assert sent_messages[0].chat_id == "chat123"
|
||||||
|
|
||||||
|
# Final reply should be None (suppressed)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_reply_sent_when_message_tool_sends_to_different_target(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
"""If message tool sends to a different target, final reply is still sent."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call returns tool call to email, second call returns final response
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1",
|
||||||
|
name="message",
|
||||||
|
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}
|
||||||
|
)
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def mock_chat(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return LLMResponse(content="", tool_calls=[tool_call])
|
||||||
|
else:
|
||||||
|
return LLMResponse(content="I've sent the email.", tool_calls=[])
|
||||||
|
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=mock_chat)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
# Track outbound messages
|
||||||
|
sent_messages: list[OutboundMessage] = []
|
||||||
|
|
||||||
|
async def _capture_outbound(msg: OutboundMessage) -> None:
|
||||||
|
sent_messages.append(msg)
|
||||||
|
|
||||||
|
# Set up message tool with callback
|
||||||
|
message_tool = loop.tools.get("message")
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_tool.set_send_callback(_capture_outbound)
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email"
|
||||||
|
)
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
# Message tool should have sent to email
|
||||||
|
assert len(sent_messages) == 1
|
||||||
|
assert sent_messages[0].channel == "email"
|
||||||
|
assert sent_messages[0].chat_id == "user@example.com"
|
||||||
|
|
||||||
|
# Final reply should be sent to Feishu (not suppressed)
|
||||||
|
assert result is not None
|
||||||
|
assert result.channel == "feishu"
|
||||||
|
assert result.chat_id == "chat123"
|
||||||
|
assert "email" in result.content.lower() or "sent" in result.content.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
|
"""If no message tool is used, final reply is always sent."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock provider to return a simple response without tool calls
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Hello! How can I help you?",
|
||||||
|
tool_calls=[]
|
||||||
|
))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="feishu", sender_id="user1", chat_id="chat123", content="Hi"
|
||||||
|
)
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
# Final reply should be sent
|
||||||
|
assert result is not None
|
||||||
|
assert result.channel == "feishu"
|
||||||
|
assert result.chat_id == "chat123"
|
||||||
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolTurnTracking:
|
||||||
|
"""Test MessageTool's turn tracking functionality."""
|
||||||
|
|
||||||
|
def test_turn_sends_tracking(self) -> None:
|
||||||
|
"""MessageTool correctly tracks sends per turn."""
|
||||||
|
tool = MessageTool()
|
||||||
|
|
||||||
|
# Initially empty
|
||||||
|
assert tool.get_turn_sends() == []
|
||||||
|
|
||||||
|
# Simulate sends
|
||||||
|
tool._turn_sends.append(("feishu", "chat1"))
|
||||||
|
tool._turn_sends.append(("email", "user@example.com"))
|
||||||
|
|
||||||
|
sends = tool.get_turn_sends()
|
||||||
|
assert len(sends) == 2
|
||||||
|
assert ("feishu", "chat1") in sends
|
||||||
|
assert ("email", "user@example.com") in sends
|
||||||
|
|
||||||
|
def test_start_turn_clears_tracking(self) -> None:
|
||||||
|
"""start_turn() clears the turn sends list."""
|
||||||
|
tool = MessageTool()
|
||||||
|
tool._turn_sends.append(("feishu", "chat1"))
|
||||||
|
assert len(tool.get_turn_sends()) == 1
|
||||||
|
|
||||||
|
tool.start_turn()
|
||||||
|
assert tool.get_turn_sends() == []
|
||||||
|
|
||||||
|
def test_get_turn_sends_returns_copy(self) -> None:
|
||||||
|
"""get_turn_sends() returns a copy, not the original list."""
|
||||||
|
tool = MessageTool()
|
||||||
|
tool._turn_sends.append(("feishu", "chat1"))
|
||||||
|
|
||||||
|
sends = tool.get_turn_sends()
|
||||||
|
sends.append(("email", "user@example.com")) # Modify the copy
|
||||||
|
|
||||||
|
# Original should be unchanged
|
||||||
|
assert len(tool.get_turn_sends()) == 1
|
||||||
Reference in New Issue
Block a user