refactor: simplify message tool turn tracking to a single boolean flag
This commit is contained in:
@@ -383,18 +383,8 @@ class AgentLoop:
|
|||||||
tools_used=tools_used if tools_used else None)
|
tools_used=tools_used if tools_used else None)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
suppress_final_reply = False
|
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||||
sent_targets = set(message_tool.get_turn_sends())
|
|
||||||
suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets
|
|
||||||
|
|
||||||
if suppress_final_reply:
|
|
||||||
logger.info(
|
|
||||||
"Skipping final auto-reply because message tool already sent to {}:{} in this turn",
|
|
||||||
msg.channel,
|
|
||||||
msg.chat_id,
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
|
|||||||
@@ -20,24 +20,21 @@ class MessageTool(Tool):
|
|||||||
self._default_channel = default_channel
|
self._default_channel = default_channel
|
||||||
self._default_chat_id = default_chat_id
|
self._default_chat_id = default_chat_id
|
||||||
self._default_message_id = default_message_id
|
self._default_message_id = default_message_id
|
||||||
self._turn_sends: list[tuple[str, str]] = []
|
self._sent_in_turn: bool = False
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Set the current message context."""
|
"""Set the current message context."""
|
||||||
self._default_channel = channel
|
self._default_channel = channel
|
||||||
self._default_chat_id = chat_id
|
self._default_chat_id = chat_id
|
||||||
self._default_message_id = message_id
|
self._default_message_id = message_id
|
||||||
|
|
||||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||||
"""Set the callback for sending messages."""
|
"""Set the callback for sending messages."""
|
||||||
self._send_callback = callback
|
self._send_callback = callback
|
||||||
|
|
||||||
def start_turn(self) -> None:
|
def start_turn(self) -> None:
|
||||||
"""Reset per-turn send tracking."""
|
"""Reset per-turn send tracking."""
|
||||||
self._turn_sends.clear()
|
self._sent_in_turn = False
|
||||||
|
|
||||||
def get_turn_sends(self) -> list[tuple[str, str]]:
|
|
||||||
"""Get (channel, chat_id) targets sent in the current turn."""
|
|
||||||
return list(self._turn_sends)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -104,7 +101,7 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
self._turn_sends.append((channel, chat_id))
|
self._sent_in_turn = True
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user