refactor: simplify message tool suppress — bool check instead of target tracking

This commit is contained in:
Re-bin
2026-02-27 02:27:18 +00:00
parent ac1c40db91
commit 29e6709e26
3 changed files with 58 additions and 169 deletions

View File

@@ -444,18 +444,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
suppress_final_reply = False if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
sent_targets = set(message_tool.get_turn_sends())
suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets
if suppress_final_reply:
logger.info(
"Skipping final auto-reply because message tool already sent to {}:{} in this turn",
msg.channel,
msg.chat_id,
)
return None return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content

View File

@@ -20,7 +20,7 @@ class MessageTool(Tool):
self._default_channel = default_channel self._default_channel = default_channel
self._default_chat_id = default_chat_id self._default_chat_id = default_chat_id
self._default_message_id = default_message_id self._default_message_id = default_message_id
self._turn_sends: list[tuple[str, str]] = [] self._sent_in_turn: bool = False
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Set the current message context.""" """Set the current message context."""
@@ -34,11 +34,7 @@ class MessageTool(Tool):
def start_turn(self) -> None: def start_turn(self) -> None:
"""Reset per-turn send tracking.""" """Reset per-turn send tracking."""
self._turn_sends.clear() self._sent_in_turn = False
def get_turn_sends(self) -> list[tuple[str, str]]:
"""Get (channel, chat_id) targets sent in the current turn."""
return list(self._turn_sends)
@property @property
def name(self) -> str: def name(self) -> str:
@@ -105,7 +101,8 @@ class MessageTool(Tool):
try: try:
await self._send_callback(msg) await self._send_callback(msg)
self._turn_sends.append((channel, chat_id)) if channel == self._default_channel and chat_id == self._default_chat_id:
self._sent_in_turn = True
media_info = f" with {len(media)} attachments" if media else "" media_info = f" with {len(media)} attachments" if media else ""
return f"Message sent to {channel}:{chat_id}{media_info}" return f"Message sent to {channel}:{chat_id}{media_info}"
except Exception as e: except Exception as e:

View File

@@ -1,6 +1,5 @@
"""Test message tool suppress logic for final replies.""" """Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@@ -13,188 +12,92 @@ from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
class TestMessageToolSuppressLogic: class TestMessageToolSuppressLogic:
"""Test that final reply is only suppressed when message tool sends to same target.""" """Final reply suppressed only when message tool sends to the same target."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_final_reply_suppressed_when_message_tool_sends_to_same_target( async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
self, tmp_path: Path loop = _make_loop(tmp_path)
) -> None:
"""If message tool sends to the same (channel, chat_id), final reply is suppressed."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# First call returns tool call, second call returns final response
tool_call = ToolCallRequest( tool_call = ToolCallRequest(
id="call1", id="call1", name="message",
name="message", arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
arguments={"content": "Hello from tool", "channel": "feishu", "chat_id": "chat123"}
) )
calls = iter([
call_count = 0 LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
def mock_chat(*args, **kwargs): ])
nonlocal call_count loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
call_count += 1
if call_count == 1:
return LLMResponse(content="", tool_calls=[tool_call])
else:
return LLMResponse(content="Done", tool_calls=[])
loop.provider.chat = AsyncMock(side_effect=mock_chat)
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
# Track outbound messages sent: list[OutboundMessage] = []
sent_messages: list[OutboundMessage] = [] mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
async def _capture_outbound(msg: OutboundMessage) -> None: msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
sent_messages.append(msg)
# Set up message tool with callback
message_tool = loop.tools.get("message")
if isinstance(message_tool, MessageTool):
message_tool.set_send_callback(_capture_outbound)
msg = InboundMessage(
channel="feishu", sender_id="user1", chat_id="chat123", content="Send a message"
)
result = await loop._process_message(msg) result = await loop._process_message(msg)
# Message tool should have sent to the same target assert len(sent) == 1
assert len(sent_messages) == 1 assert result is None # suppressed
assert sent_messages[0].channel == "feishu"
assert sent_messages[0].chat_id == "chat123"
# Final reply should be None (suppressed)
assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_final_reply_sent_when_message_tool_sends_to_different_target( async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
self, tmp_path: Path loop = _make_loop(tmp_path)
) -> None:
"""If message tool sends to a different target, final reply is still sent."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# First call returns tool call to email, second call returns final response
tool_call = ToolCallRequest( tool_call = ToolCallRequest(
id="call1", id="call1", name="message",
name="message", arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"}
) )
calls = iter([
call_count = 0 LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
def mock_chat(*args, **kwargs): ])
nonlocal call_count loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
call_count += 1
if call_count == 1:
return LLMResponse(content="", tool_calls=[tool_call])
else:
return LLMResponse(content="I've sent the email.", tool_calls=[])
loop.provider.chat = AsyncMock(side_effect=mock_chat)
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
# Track outbound messages sent: list[OutboundMessage] = []
sent_messages: list[OutboundMessage] = [] mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
async def _capture_outbound(msg: OutboundMessage) -> None: msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
sent_messages.append(msg)
# Set up message tool with callback
message_tool = loop.tools.get("message")
if isinstance(message_tool, MessageTool):
message_tool.set_send_callback(_capture_outbound)
msg = InboundMessage(
channel="feishu", sender_id="user1", chat_id="chat123", content="Send an email"
)
result = await loop._process_message(msg) result = await loop._process_message(msg)
# Message tool should have sent to email assert len(sent) == 1
assert len(sent_messages) == 1 assert sent[0].channel == "email"
assert sent_messages[0].channel == "email" assert result is not None # not suppressed
assert sent_messages[0].chat_id == "user@example.com"
# Final reply should be sent to Feishu (not suppressed)
assert result is not None
assert result.channel == "feishu" assert result.channel == "feishu"
assert result.chat_id == "chat123"
assert "email" in result.content.lower() or "sent" in result.content.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_final_reply_sent_when_no_message_tool_used(self, tmp_path: Path) -> None: async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
"""If no message tool is used, final reply is always sent.""" loop = _make_loop(tmp_path)
bus = MessageBus() loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
# Mock provider to return a simple response without tool calls
loop.provider.chat = AsyncMock(return_value=LLMResponse(
content="Hello! How can I help you?",
tool_calls=[]
))
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage( msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
channel="feishu", sender_id="user1", chat_id="chat123", content="Hi"
)
result = await loop._process_message(msg) result = await loop._process_message(msg)
# Final reply should be sent
assert result is not None assert result is not None
assert result.channel == "feishu"
assert result.chat_id == "chat123"
assert "Hello" in result.content assert "Hello" in result.content
class TestMessageToolTurnTracking: class TestMessageToolTurnTracking:
"""Test MessageTool's turn tracking functionality."""
def test_turn_sends_tracking(self) -> None: def test_sent_in_turn_tracks_same_target(self) -> None:
"""MessageTool correctly tracks sends per turn."""
tool = MessageTool() tool = MessageTool()
tool.set_context("feishu", "chat1")
assert not tool._sent_in_turn
tool._sent_in_turn = True
assert tool._sent_in_turn
# Initially empty def test_start_turn_resets(self) -> None:
assert tool.get_turn_sends() == []
# Simulate sends
tool._turn_sends.append(("feishu", "chat1"))
tool._turn_sends.append(("email", "user@example.com"))
sends = tool.get_turn_sends()
assert len(sends) == 2
assert ("feishu", "chat1") in sends
assert ("email", "user@example.com") in sends
def test_start_turn_clears_tracking(self) -> None:
"""start_turn() clears the turn sends list."""
tool = MessageTool() tool = MessageTool()
tool._turn_sends.append(("feishu", "chat1")) tool._sent_in_turn = True
assert len(tool.get_turn_sends()) == 1
tool.start_turn() tool.start_turn()
assert tool.get_turn_sends() == [] assert not tool._sent_in_turn
def test_get_turn_sends_returns_copy(self) -> None:
"""get_turn_sends() returns a copy, not the original list."""
tool = MessageTool()
tool._turn_sends.append(("feishu", "chat1"))
sends = tool.get_turn_sends()
sends.append(("email", "user@example.com")) # Modify the copy
# Original should be unchanged
assert len(tool.get_turn_sends()) == 1