Merge remote-tracking branch 'origin/main' into pr-1257
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
|
||||||
|
|
||||||
📏 Real-time line count: **3,966 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
📏 Real-time line count: **3,932 lines** (run `bash core_agent_lines.sh` to verify anytime)
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ class AgentLoop:
|
|||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_TOOL_RESULT_MAX_CHARS = 500
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
@@ -145,17 +147,10 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Update context for all tools that need routing info."""
|
"""Update context for all tools that need routing info."""
|
||||||
if message_tool := self.tools.get("message"):
|
for name in ("message", "spawn", "cron"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if tool := self.tools.get(name):
|
||||||
message_tool.set_context(channel, chat_id, message_id)
|
if hasattr(tool, "set_context"):
|
||||||
|
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||||
if spawn_tool := self.tools.get("spawn"):
|
|
||||||
if isinstance(spawn_tool, SpawnTool):
|
|
||||||
spawn_tool.set_context(channel, chat_id)
|
|
||||||
|
|
||||||
if cron_tool := self.tools.get("cron"):
|
|
||||||
if isinstance(cron_tool, CronTool):
|
|
||||||
cron_tool.set_context(channel, chat_id)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _strip_think(text: str | None) -> str | None:
|
def _strip_think(text: str | None) -> str | None:
|
||||||
@@ -315,18 +310,6 @@ class AgentLoop:
|
|||||||
self._running = False
|
self._running = False
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
|
|
||||||
lock = self._consolidation_locks.get(session_key)
|
|
||||||
if lock is None:
|
|
||||||
lock = asyncio.Lock()
|
|
||||||
self._consolidation_locks[session_key] = lock
|
|
||||||
return lock
|
|
||||||
|
|
||||||
def _prune_consolidation_lock(self, session_key: str, lock: asyncio.Lock) -> None:
|
|
||||||
"""Drop lock entry if no longer in use."""
|
|
||||||
if not lock.locked():
|
|
||||||
self._consolidation_locks.pop(session_key, None)
|
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
self,
|
self,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
@@ -362,7 +345,7 @@ class AgentLoop:
|
|||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
lock = self._get_consolidation_lock(session.key)
|
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||||
self._consolidating.add(session.key)
|
self._consolidating.add(session.key)
|
||||||
try:
|
try:
|
||||||
async with lock:
|
async with lock:
|
||||||
@@ -383,7 +366,8 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
self._prune_consolidation_lock(session.key, lock)
|
if not lock.locked():
|
||||||
|
self._consolidation_locks.pop(session.key, None)
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
@@ -397,7 +381,7 @@ class AgentLoop:
|
|||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
unconsolidated = len(session.messages) - session.last_consolidated
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||||
self._consolidating.add(session.key)
|
self._consolidating.add(session.key)
|
||||||
lock = self._get_consolidation_lock(session.key)
|
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
async def _consolidate_and_unlock():
|
||||||
try:
|
try:
|
||||||
@@ -405,7 +389,8 @@ class AgentLoop:
|
|||||||
await self._consolidate_memory(session)
|
await self._consolidate_memory(session)
|
||||||
finally:
|
finally:
|
||||||
self._consolidating.discard(session.key)
|
self._consolidating.discard(session.key)
|
||||||
self._prune_consolidation_lock(session.key, lock)
|
if not lock.locked():
|
||||||
|
self._consolidation_locks.pop(session.key, None)
|
||||||
_task = asyncio.current_task()
|
_task = asyncio.current_task()
|
||||||
if _task is not None:
|
if _task is not None:
|
||||||
self._consolidation_tasks.discard(_task)
|
self._consolidation_tasks.discard(_task)
|
||||||
@@ -441,23 +426,19 @@ class AgentLoop:
|
|||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
if message_tool := self.tools.get("message"):
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||||
metadata=msg.metadata or {},
|
metadata=msg.metadata or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
_TOOL_RESULT_MAX_CHARS = 500
|
|
||||||
|
|
||||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
"""Save new-turn messages into session, truncating large tool results."""
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|||||||
@@ -101,7 +101,8 @@ class MessageTool(Tool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
self._sent_in_turn = True
|
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||||
|
self._sent_in_turn = True
|
||||||
media_info = f" with {len(media)} attachments" if media else ""
|
media_info = f" with {len(media)} attachments" if media else ""
|
||||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class WebSearchTool(Tool):
|
|||||||
r = await client.get(
|
r = await client.get(
|
||||||
"https://api.search.brave.com/res/v1/web/search",
|
"https://api.search.brave.com/res/v1/web/search",
|
||||||
params={"q": query, "count": n},
|
params={"q": query, "count": n},
|
||||||
headers={"Accept": "application/json", "X-Subscription-Token": api_key},
|
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
||||||
timeout=10.0
|
timeout=10.0
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|||||||
@@ -127,6 +127,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
self._app: Application | None = None
|
self._app: Application | None = None
|
||||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||||
|
self._media_group_buffers: dict[str, dict] = {}
|
||||||
|
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Telegram bot with long polling."""
|
"""Start the Telegram bot with long polling."""
|
||||||
@@ -192,6 +194,11 @@ class TelegramChannel(BaseChannel):
|
|||||||
for chat_id in list(self._typing_tasks):
|
for chat_id in list(self._typing_tasks):
|
||||||
self._stop_typing(chat_id)
|
self._stop_typing(chat_id)
|
||||||
|
|
||||||
|
for task in self._media_group_tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
self._media_group_tasks.clear()
|
||||||
|
self._media_group_buffers.clear()
|
||||||
|
|
||||||
if self._app:
|
if self._app:
|
||||||
logger.info("Stopping Telegram bot...")
|
logger.info("Stopping Telegram bot...")
|
||||||
await self._app.updater.stop()
|
await self._app.updater.stop()
|
||||||
@@ -400,6 +407,28 @@ class TelegramChannel(BaseChannel):
|
|||||||
|
|
||||||
str_chat_id = str(chat_id)
|
str_chat_id = str(chat_id)
|
||||||
|
|
||||||
|
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||||
|
if media_group_id := getattr(message, "media_group_id", None):
|
||||||
|
key = f"{str_chat_id}:{media_group_id}"
|
||||||
|
if key not in self._media_group_buffers:
|
||||||
|
self._media_group_buffers[key] = {
|
||||||
|
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||||
|
"contents": [], "media": [],
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message.message_id, "user_id": user.id,
|
||||||
|
"username": user.username, "first_name": user.first_name,
|
||||||
|
"is_group": message.chat.type != "private",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self._start_typing(str_chat_id)
|
||||||
|
buf = self._media_group_buffers[key]
|
||||||
|
if content and content != "[empty message]":
|
||||||
|
buf["contents"].append(content)
|
||||||
|
buf["media"].extend(media_paths)
|
||||||
|
if key not in self._media_group_tasks:
|
||||||
|
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||||
|
return
|
||||||
|
|
||||||
# Start typing indicator before processing
|
# Start typing indicator before processing
|
||||||
self._start_typing(str_chat_id)
|
self._start_typing(str_chat_id)
|
||||||
|
|
||||||
@@ -418,6 +447,21 @@ class TelegramChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _flush_media_group(self, key: str) -> None:
|
||||||
|
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.6)
|
||||||
|
if not (buf := self._media_group_buffers.pop(key, None)):
|
||||||
|
return
|
||||||
|
content = "\n".join(buf["contents"]) or "[empty message]"
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||||
|
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||||
|
metadata=buf["metadata"],
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._media_group_tasks.pop(key, None)
|
||||||
|
|
||||||
def _start_typing(self, chat_id: str) -> None:
|
def _start_typing(self, chat_id: str) -> None:
|
||||||
"""Start sending 'typing...' indicator for a chat."""
|
"""Start sending 'typing...' indicator for a chat."""
|
||||||
# Cancel any existing typing task for this chat
|
# Cancel any existing typing task for this chat
|
||||||
|
|||||||
@@ -812,7 +812,7 @@ class TestConsolidationDeduplicationGuard:
|
|||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
# Ensure lock exists before /new.
|
# Ensure lock exists before /new.
|
||||||
_ = loop._get_consolidation_lock(session.key)
|
loop._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||||
assert session.key in loop._consolidation_locks
|
assert session.key in loop._consolidation_locks
|
||||||
|
|
||||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||||
|
|||||||
103
tests/test_message_tool_suppress.py
Normal file
103
tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolSuppressLogic:
|
||||||
|
"""Final reply suppressed only when message tool sends to the same target."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert result is None # suppressed
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call1", name="message",
|
||||||
|
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
mt = loop.tools.get("message")
|
||||||
|
if isinstance(mt, MessageTool):
|
||||||
|
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert len(sent) == 1
|
||||||
|
assert sent[0].channel == "email"
|
||||||
|
assert result is not None # not suppressed
|
||||||
|
assert result.channel == "feishu"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||||
|
result = await loop._process_message(msg)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageToolTurnTracking:
|
||||||
|
|
||||||
|
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool.set_context("feishu", "chat1")
|
||||||
|
assert not tool._sent_in_turn
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
assert tool._sent_in_turn
|
||||||
|
|
||||||
|
def test_start_turn_resets(self) -> None:
|
||||||
|
tool = MessageTool()
|
||||||
|
tool._sent_in_turn = True
|
||||||
|
tool.start_turn()
|
||||||
|
assert not tool._sent_in_turn
|
||||||
Reference in New Issue
Block a user