Merge PR #1206: fix message tool suppress for cross-channel sends
This commit is contained in:
@@ -441,16 +441,14 @@ class AgentLoop:
|
|||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
if message_tool := self.tools.get("message"):
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
metadata=msg.metadata or {},
|
metadata=msg.metadata or {},
|
||||||
|
|||||||
@@ -101,7 +101,8 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
self._sent_in_turn = True
|
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||||
|
self._sent_in_turn = True
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
103
tests/test_message_tool_suppress.py
Normal file
103
tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolSuppressLogic:
|
||||||
|
"""Final reply suppressed only when message tool sends to the same target."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert result is None # suppressed
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert sent[0].channel == "email"
|
||||||
|
assert result is not None # not suppressed
|
||||||
|
assert result.channel == "feishu"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolTurnTracking:
|
||||||
|
|
||||||
|
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool.set_context("feishu", "chat1")
|
||||||
|
assert not tool._sent_in_turn
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
assert tool._sent_in_turn
|
||||||
|
|
||||||
|
def test_start_turn_resets(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
tool.start_turn()
|
||||||
|
assert not tool._sent_in_turn
|
||||||
Reference in New Issue
Block a user