refactor: simplify message tool suppress — bool check instead of target tracking

This commit is contained in:
Re-bin
2026-02-27 02:27:18 +00:00
parent ac1c40db91
commit 29e6709e26
3 changed files with 58 additions and 169 deletions

View File

@@ -444,18 +444,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
suppress_final_reply = False
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
sent_targets = set(message_tool.get_turn_sends())
suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets
if suppress_final_reply:
logger.info(
"Skipping final auto-reply because message tool already sent to {}:{} in this turn",
msg.channel,
msg.chat_id,
)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content

View File

@@ -20,7 +20,7 @@ class MessageTool(Tool):
self._default_channel = default_channel
self._default_chat_id = default_chat_id
self._default_message_id = default_message_id
self._turn_sends: list[tuple[str, str]] = []
self._sent_in_turn: bool = False
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Set the current message context."""
@@ -34,11 +34,7 @@ class MessageTool(Tool):
def start_turn(self) -> None:
"""Reset per-turn send tracking."""
self._turn_sends.clear()
def get_turn_sends(self) -> list[tuple[str, str]]:
"""Get (channel, chat_id) targets sent in the current turn."""
return list(self._turn_sends)
self._sent_in_turn = False
@property
def name(self) -> str:
@@ -105,7 +101,8 @@ class MessageTool(Tool):
try:
await self._send_callback(msg)
self._turn_sends.append((channel, chat_id))
if channel == self._default_channel and chat_id == self._default_chat_id:
self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}"
except Exception as e:

View File

@@ -1,6 +1,5 @@
"""Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@@ -13,188 +12,92 @@ from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
class TestMessageToolSuppressLogic:
"""Test that final reply is only suppressed when message tool sends to same target."""
"""Final reply suppressed only when message tool sends to the same target."""
@pytest.mark.asyncio
async def test_final_reply_suppressed_when_message_tool_sends_to_same_target(
self, tmp_path: Path
) -> None:
"""If message tool sends to the same (channel, chat_id), final reply is suppressed."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# First call returns tool call, second call returns final response
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1",
name="message",
arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"}
id="call1", name="message",
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
)
call_count = 0
def mock_chat(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return LLMResponse(content="", tool_calls=[tool_call])
else:
return LLMResponse(content="Done", tool_calls=[])
loop.provider.chat = AsyncMock(side_effect=mock_chat)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
# Track outbound messages
sent_messages: list[OutboundMessage] = []
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
async def _capture_outbound(msg: OutboundMessage) -> None:
sent_messages.append(msg)
# Set up message tool with callback
message_tool = loop.tools.get("message")
if isinstance(message_tool, MessageTool):
message_tool.set_send_callback(_capture_outbound)
msg = InboundMessage(
channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message"
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
result = await loop._process_message(msg)
# Message tool should have sent to the same target
assert len(sent_messages) == 1
assert sent_messages[0].channel == "feishu"
assert sent_messages[0].chat_id == "chat123"
# Final reply should be None (suppressed)
assert result is None
assert len(sent) == 1
assert result is None # suppressed
@pytest.mark.asyncio
async def test_final_reply_sent_when_message_tool_sends_to_different_target(
self, tmp_path: Path
) -> None:
"""If message tool sends to a different target, final reply is still sent."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# First call returns tool call to email, second call returns final response
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1",
name="message",
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}
id="call1", name="message",
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
)
call_count = 0
def mock_chat(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return LLMResponse(content="", tool_calls=[tool_call])
else:
return LLMResponse(content="I've sent the email.", tool_calls=[])
loop.provider.chat = AsyncMock(side_effect=mock_chat)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
# Track outbound messages
sent_messages: list[OutboundMessage] = []
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
async def _capture_outbound(msg: OutboundMessage) -> None:
sent_messages.append(msg)
# Set up message tool with callback
message_tool = loop.tools.get("message")
if isinstance(message_tool, MessageTool):
message_tool.set_send_callback(_capture_outbound)
msg = InboundMessage(
channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email"
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
result = await loop._process_message(msg)
# Message tool should have sent to email
assert len(sent_messages) == 1
assert sent_messages[0].channel == "email"
assert sent_messages[0].chat_id == "user@example.com"
# Final reply should be sent to Feishu (not suppressed)
assert result is not None
assert len(sent) == 1
assert sent[0].channel == "email"
assert result is not None # not suppressed
assert result.channel == "feishu"
assert result.chat_id == "chat123"
assert "email" in result.content.lower() or "sent" in result.content.lower()
@pytest.mark.asyncio
async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None:
"""If no message tool is used, final reply is always sent."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# Mock provider to return a simple response without tool calls
loop.provider.chat = AsyncMock(return_value=LLMResponse(
content="Hello! How can I help you?",
tool_calls=[]
))
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(
channel="feishu", sender_id="user1", chat_id="chat123", content="Hi"
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
result = await loop._process_message(msg)
# Final reply should be sent
assert result is not None
assert result.channel == "feishu"
assert result.chat_id == "chat123"
assert "Hello" in result.content
class TestMessageToolTurnTracking:
"""Test MessageTool's turn tracking functionality."""
def test_turn_sends_tracking(self) -> None:
"""MessageTool correctly tracks sends per turn."""
def test_sent_in_turn_tracks_same_target(self) -> None:
tool = MessageTool()
tool.set_context("feishu", "chat1")
assert not tool._sent_in_turn
tool._sent_in_turn = True
assert tool._sent_in_turn
# Initially empty
assert tool.get_turn_sends() == []
# Simulate sends
tool._turn_sends.append(("feishu", "chat1"))
tool._turn_sends.append(("email", "user@example.com"))
sends = tool.get_turn_sends()
assert len(sends) == 2
assert ("feishu", "chat1") in sends
assert ("email", "user@example.com") in sends
def test_start_turn_clears_tracking(self) -> None:
"""start_turn() clears the turn sends list."""
def test_start_turn_resets(self) -> None:
tool = MessageTool()
tool._turn_sends.append(("feishu", "chat1"))
assert len(tool.get_turn_sends()) == 1
tool._sent_in_turn = True
tool.start_turn()
assert tool.get_turn_sends() == []
def test_get_turn_sends_returns_copy(self) -> None:
"""get_turn_sends() returns a copy, not the original list."""
tool = MessageTool()
tool._turn_sends.append(("feishu", "chat1"))
sends = tool.get_turn_sends()
sends.append(("email", "user@example.com")) # Modify the copy
# Original should be unchanged
assert len(tool.get_turn_sends()) == 1
assert not tool._sent_in_turn