Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
490 lines
21 KiB
Python
490 lines
21 KiB
Python
"""Agent loop: the core processing engine."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import re
|
|
from contextlib import AsyncExitStack
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|
|
|
from loguru import logger
|
|
|
|
from nanobot.agent.context import ContextBuilder
|
|
from nanobot.agent.memory import MemoryConsolidator
|
|
from nanobot.agent.subagent import SubagentManager
|
|
from nanobot.agent.tools.cron import CronTool
|
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
|
from nanobot.agent.tools.message import MessageTool
|
|
from nanobot.agent.tools.registry import ToolRegistry
|
|
from nanobot.agent.tools.shell import ExecTool
|
|
from nanobot.agent.tools.spawn import SpawnTool
|
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.providers.base import LLMProvider
|
|
from nanobot.session.manager import Session, SessionManager
|
|
|
|
if TYPE_CHECKING:
|
|
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
|
|
from nanobot.cron.service import CronService
|
|
|
|
|
|
class AgentLoop:
    """
    The agent loop is the core processing engine.

    It:
    1. Receives messages from the bus
    2. Builds context with history, memory, skills
    3. Calls the LLM
    4. Executes tool calls
    5. Sends responses back
    """

    # Maximum number of characters of a tool result kept when a turn is
    # persisted to the session (see ``_save_turn``).
    _TOOL_RESULT_MAX_CHARS = 500

    def __init__(
        self,
        bus: MessageBus,
        provider: LLMProvider,
        workspace: Path,
        model: str | None = None,
        max_iterations: int = 40,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        reasoning_effort: str | None = None,
        context_window_tokens: int = 65_536,
        brave_api_key: str | None = None,
        web_proxy: str | None = None,
        exec_config: ExecToolConfig | None = None,
        cron_service: CronService | None = None,
        restrict_to_workspace: bool = False,
        session_manager: SessionManager | None = None,
        mcp_servers: dict | None = None,
        channels_config: ChannelsConfig | None = None,
    ):
        """Wire up the provider, tools, sessions, subagents and memory consolidator.

        Args:
            bus: Message bus used to consume inbound and publish outbound messages.
            provider: LLM provider; also supplies the default model when
                ``model`` is not given.
            workspace: Workspace directory handed to the context builder,
                session manager, tools and memory consolidator.
            model: Model name; falls back to ``provider.get_default_model()``.
            max_iterations: Upper bound on LLM round-trips per message.
            temperature: Sampling temperature forwarded to the provider.
            max_tokens: Completion token limit forwarded to the provider.
            reasoning_effort: Optional reasoning-effort hint forwarded to the provider.
            context_window_tokens: Token budget handed to the memory consolidator.
            brave_api_key: API key passed to the web search tool.
            web_proxy: Proxy passed to the web search/fetch tools.
            exec_config: Shell tool configuration; a default ``ExecToolConfig``
                is created when omitted.
            cron_service: When set, the cron tool is registered.
            restrict_to_workspace: Restrict file and shell tools to ``workspace``.
            session_manager: Session store; a ``SessionManager(workspace)`` is
                created when omitted.
            mcp_servers: MCP server configuration, connected lazily.
            channels_config: Stored on the instance as-is.
        """
        # Imported here (not only under TYPE_CHECKING) because the class is
        # instantiated below when no exec_config is supplied.
        from nanobot.config.schema import ExecToolConfig

        self.bus = bus
        self.channels_config = channels_config
        self.provider = provider
        self.workspace = workspace
        self.model = model or provider.get_default_model()
        self.max_iterations = max_iterations
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.reasoning_effort = reasoning_effort
        self.context_window_tokens = context_window_tokens
        self.brave_api_key = brave_api_key
        self.web_proxy = web_proxy
        self.exec_config = exec_config or ExecToolConfig()
        self.cron_service = cron_service
        self.restrict_to_workspace = restrict_to_workspace

        self.context = ContextBuilder(workspace)
        self.sessions = session_manager or SessionManager(workspace)
        self.tools = ToolRegistry()
        self.subagents = SubagentManager(
            provider=provider,
            workspace=workspace,
            bus=bus,
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            reasoning_effort=reasoning_effort,
            brave_api_key=brave_api_key,
            web_proxy=web_proxy,
            exec_config=self.exec_config,
            restrict_to_workspace=restrict_to_workspace,
        )

        self._running = False
        self._mcp_servers = mcp_servers or {}
        self._mcp_stack: AsyncExitStack | None = None
        self._mcp_connected = False
        self._mcp_connecting = False
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
        self._processing_lock = asyncio.Lock()
        self.memory_consolidator = MemoryConsolidator(
            workspace=workspace,
            provider=provider,
            model=self.model,
            sessions=self.sessions,
            context_window_tokens=context_window_tokens,
            build_messages=self.context.build_messages,
            get_tool_definitions=self.tools.get_definitions,
        )
        self._register_default_tools()

    def _register_default_tools(self) -> None:
        """Register the default set of tools."""
        allowed_dir = self.workspace if self.restrict_to_workspace else None
        for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
            self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
        self.tools.register(ExecTool(
            working_dir=str(self.workspace),
            timeout=self.exec_config.timeout,
            restrict_to_workspace=self.restrict_to_workspace,
            path_append=self.exec_config.path_append,
        ))
        self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
        self.tools.register(WebFetchTool(proxy=self.web_proxy))
        self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
        self.tools.register(SpawnTool(manager=self.subagents))
        # The cron tool is optional: only available when a cron service was injected.
        if self.cron_service:
            self.tools.register(CronTool(self.cron_service))

    async def _connect_mcp(self) -> None:
        """Connect to configured MCP servers (one-time, lazy).

        Does nothing when already connected, when a connection attempt is in
        progress, or when no servers are configured. On failure the exit stack
        is torn down and reset so a later call can try again.
        """
        if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
            return
        self._mcp_connecting = True
        from nanobot.agent.tools.mcp import connect_mcp_servers
        try:
            self._mcp_stack = AsyncExitStack()
            await self._mcp_stack.__aenter__()
            await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
            self._mcp_connected = True
        except Exception as e:
            logger.error("Failed to connect MCP servers (will retry next message): {}", e)
            if self._mcp_stack:
                try:
                    await self._mcp_stack.aclose()
                except Exception:
                    # Best-effort cleanup: the connection failure was already logged.
                    pass
                self._mcp_stack = None
        finally:
            self._mcp_connecting = False

    def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
        """Update context for all tools that need routing info.

        Only the ``message`` tool additionally receives ``message_id``.
        """
        for name in ("message", "spawn", "cron"):
            if tool := self.tools.get(name):
                if hasattr(tool, "set_context"):
                    tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))

    @staticmethod
    def _strip_think(text: str | None) -> str | None:
        """Remove <think>…</think> blocks that some models embed in content.

        Returns ``None`` when the input is empty or nothing but whitespace
        remains after stripping.
        """
        if not text:
            return None
        return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None

    @staticmethod
    def _tool_hint(tool_calls: list) -> str:
        """Format tool calls as concise hint, e.g. 'web_search("query")'.

        For each call the first argument value is shown when it is a string
        (truncated to 40 characters); otherwise only the tool name is shown.
        """
        def _fmt(tc):
            raw = tc.arguments
            if isinstance(raw, list):
                # Arguments may arrive wrapped in a list; an empty list must
                # not raise IndexError, it simply means "no arguments".
                raw = raw[0] if raw else None
            args = raw or {}
            val = next(iter(args.values()), None) if isinstance(args, dict) else None
            if not isinstance(val, str):
                return tc.name
            return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'

        return ", ".join(_fmt(tc) for tc in tool_calls)

    async def _run_agent_loop(
        self,
        initial_messages: list[dict],
        on_progress: Callable[..., Awaitable[None]] | None = None,
    ) -> tuple[str | None, list[str], list[dict]]:
        """Run the agent iteration loop.

        Calls the provider repeatedly, executing any requested tool calls and
        feeding their results back, until the model answers without tool calls,
        reports an error, or ``max_iterations`` is reached.

        Args:
            initial_messages: Conversation to start from.
            on_progress: Optional async callback; receives intermediate text
                and, with ``tool_hint=True``, a short summary of tool calls.

        Returns:
            ``(final_content, tools_used, messages)`` — the final text (or
            ``None``), the names of all tools invoked in order, and the
            message list including everything added during the loop.
        """
        messages = initial_messages
        iteration = 0
        final_content = None
        tools_used: list[str] = []

        while iteration < self.max_iterations:
            iteration += 1

            tool_defs = self.tools.get_definitions()

            response = await self.provider.chat_with_retry(
                messages=messages,
                tools=tool_defs,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                reasoning_effort=self.reasoning_effort,
            )

            if response.has_tool_calls:
                if on_progress:
                    thought = self._strip_think(response.content)
                    if thought:
                        await on_progress(thought)
                    await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)

                tool_call_dicts = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments, ensure_ascii=False),
                        },
                    }
                    for tc in response.tool_calls
                ]
                messages = self.context.add_assistant_message(
                    messages, response.content, tool_call_dicts,
                    reasoning_content=response.reasoning_content,
                    thinking_blocks=response.thinking_blocks,
                )

                for tool_call in response.tool_calls:
                    tools_used.append(tool_call.name)
                    args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
                    logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
                    result = await self.tools.execute(tool_call.name, tool_call.arguments)
                    messages = self.context.add_tool_result(
                        messages, tool_call.id, tool_call.name, result
                    )
            else:
                clean = self._strip_think(response.content)
                # Don't persist error responses to session history — they can
                # poison the context and cause permanent 400 loops (#1303).
                if response.finish_reason == "error":
                    logger.error("LLM returned error: {}", (clean or "")[:200])
                    final_content = clean or "Sorry, I encountered an error calling the AI model."
                    break
                messages = self.context.add_assistant_message(
                    messages, clean, reasoning_content=response.reasoning_content,
                    thinking_blocks=response.thinking_blocks,
                )
                final_content = clean
                break

        if final_content is None and iteration >= self.max_iterations:
            logger.warning("Max iterations ({}) reached", self.max_iterations)
            final_content = (
                f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
                "without completing the task. You can try breaking the task into smaller steps."
            )

        return final_content, tools_used, messages

    def _forget_task(self, session_key: str, task: Any) -> None:
        """Drop a finished task from the per-session registry.

        Safe to call when the session entry is already gone (``_handle_stop``
        pops it) or the task was already removed. Empty entries are deleted so
        the registry does not accumulate one stale list per session.
        """
        tasks = self._active_tasks.get(session_key)
        if not tasks:
            return
        if task in tasks:
            tasks.remove(task)
        if not tasks:
            self._active_tasks.pop(session_key, None)

    async def run(self) -> None:
        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
        self._running = True
        await self._connect_mcp()
        logger.info("Agent loop started")

        while self._running:
            try:
                # Short timeout so a change to ``_running`` is noticed promptly.
                msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            if msg.content.strip().lower() == "/stop":
                await self._handle_stop(msg)
            else:
                task = asyncio.create_task(self._dispatch(msg))
                self._active_tasks.setdefault(msg.session_key, []).append(task)
                # Bind the key as a default argument: ``msg`` is rebound on the
                # next loop iteration before the callback fires.
                task.add_done_callback(
                    lambda t, k=msg.session_key: self._forget_task(k, t)
                )

    async def _handle_stop(self, msg: InboundMessage) -> None:
        """Cancel all active tasks and subagents for the session."""
        tasks = self._active_tasks.pop(msg.session_key, [])
        cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
        for t in tasks:
            try:
                # Wait for each task to actually finish unwinding.
                await t
            except (asyncio.CancelledError, Exception):
                pass
        sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
        total = cancelled + sub_cancelled
        content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
        await self.bus.publish_outbound(OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=content,
        ))

    async def _dispatch(self, msg: InboundMessage) -> None:
        """Process a message under the global lock.

        Publishes the response (or, for the ``cli`` channel, an empty message
        when there is none). Cancellation is logged and re-raised; any other
        exception is logged and answered with a generic error message.
        """
        async with self._processing_lock:
            try:
                response = await self._process_message(msg)
                if response is not None:
                    await self.bus.publish_outbound(response)
                elif msg.channel == "cli":
                    await self.bus.publish_outbound(OutboundMessage(
                        channel=msg.channel, chat_id=msg.chat_id,
                        content="", metadata=msg.metadata or {},
                    ))
            except asyncio.CancelledError:
                logger.info("Task cancelled for session {}", msg.session_key)
                raise
            except Exception:
                logger.exception("Error processing message for session {}", msg.session_key)
                await self.bus.publish_outbound(OutboundMessage(
                    channel=msg.channel, chat_id=msg.chat_id,
                    content="Sorry, I encountered an error.",
                ))

    async def close_mcp(self) -> None:
        """Close MCP connections."""
        if self._mcp_stack:
            try:
                await self._mcp_stack.aclose()
            except (RuntimeError, BaseExceptionGroup):
                pass  # MCP SDK cancel scope cleanup is noisy but harmless
            self._mcp_stack = None

    def stop(self) -> None:
        """Stop the agent loop."""
        self._running = False
        logger.info("Agent loop stopping")

    async def _process_message(
        self,
        msg: InboundMessage,
        session_key: str | None = None,
        on_progress: Callable[[str], Awaitable[None]] | None = None,
    ) -> OutboundMessage | None:
        """Process a single inbound message and return the response.

        Args:
            msg: The inbound message.
            session_key: Overrides ``msg.session_key`` for non-system messages.
            on_progress: Progress callback; defaults to publishing progress
                messages on the bus.

        Returns:
            The outbound reply, or ``None`` when the message tool already sent
            something during this turn.
        """
        # ``metadata`` is treated as optional everywhere else in this class,
        # so guard the lookup instead of calling ``.get`` on a possible None.
        message_id = (msg.metadata or {}).get("message_id")

        # System messages: parse origin from chat_id ("channel:chat_id")
        if msg.channel == "system":
            channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
                                else ("cli", msg.chat_id))
            logger.info("Processing system message from {}", msg.sender_id)
            key = f"{channel}:{chat_id}"
            session = self.sessions.get_or_create(key)
            await self.memory_consolidator.maybe_consolidate_by_tokens(session)
            self._set_tool_context(channel, chat_id, message_id)
            history = session.get_history(max_messages=0)
            messages = self.context.build_messages(
                history=history,
                current_message=msg.content, channel=channel, chat_id=chat_id,
            )
            final_content, _, all_msgs = await self._run_agent_loop(messages)
            # Skip the system prompt plus the replayed history: only this
            # turn's messages are appended to the session.
            self._save_turn(session, all_msgs, 1 + len(history))
            self.sessions.save(session)
            await self.memory_consolidator.maybe_consolidate_by_tokens(session)
            return OutboundMessage(channel=channel, chat_id=chat_id,
                                   content=final_content or "Background task completed.")

        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
        logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)

        key = session_key or msg.session_key
        session = self.sessions.get_or_create(key)

        # Slash commands
        cmd = msg.content.strip().lower()
        if cmd == "/new":
            # Archive first; the session is only cleared when archival succeeds.
            try:
                archived = await self.memory_consolidator.archive_unconsolidated(session)
            except Exception:
                logger.exception("/new archival failed for {}", session.key)
                archived = False
            if not archived:
                return OutboundMessage(
                    channel=msg.channel,
                    chat_id=msg.chat_id,
                    content="Memory archival failed, session not cleared. Please try again.",
                )

            session.clear()
            self.sessions.save(session)
            self.sessions.invalidate(session.key)
            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                   content="New session started.")
        if cmd == "/help":
            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                   content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")

        await self.memory_consolidator.maybe_consolidate_by_tokens(session)

        self._set_tool_context(msg.channel, msg.chat_id, message_id)
        if message_tool := self.tools.get("message"):
            if isinstance(message_tool, MessageTool):
                message_tool.start_turn()

        history = session.get_history(max_messages=0)
        initial_messages = self.context.build_messages(
            history=history,
            current_message=msg.content,
            media=msg.media if msg.media else None,
            channel=msg.channel, chat_id=msg.chat_id,
        )

        async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
            # Default progress sink: publish on the bus, tagged as progress.
            meta = dict(msg.metadata or {})
            meta["_progress"] = True
            meta["_tool_hint"] = tool_hint
            await self.bus.publish_outbound(OutboundMessage(
                channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
            ))

        final_content, _, all_msgs = await self._run_agent_loop(
            initial_messages, on_progress=on_progress or _bus_progress,
        )

        if final_content is None:
            final_content = "I've completed processing but have no response to give."

        self._save_turn(session, all_msgs, 1 + len(history))
        self.sessions.save(session)
        await self.memory_consolidator.maybe_consolidate_by_tokens(session)

        # If the message tool already sent something this turn, don't reply again.
        if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
            return None

        preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
        logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
        return OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=final_content,
            metadata=msg.metadata or {},
        )

    def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
        """Save new-turn messages into session, truncating large tool results.

        Args:
            session: Session whose ``messages`` list is appended to.
            messages: Full message list of the turn.
            skip: Number of leading messages to ignore (already stored or
                not meant to be stored).

        Each stored entry is a shallow copy, so the caller's dicts are not
        modified. Empty assistant messages and user messages that contain
        only runtime context are dropped; inline ``data:image/`` payloads are
        replaced by an ``[image]`` text placeholder.
        """
        from datetime import datetime
        for m in messages[skip:]:
            entry = dict(m)
            role, content = entry.get("role"), entry.get("content")
            if role == "assistant" and not content and not entry.get("tool_calls"):
                continue  # skip empty assistant messages — they poison session context
            if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
                entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
            elif role == "user":
                if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
                    # Strip the runtime-context prefix, keep only the user text.
                    parts = content.split("\n\n", 1)
                    if len(parts) > 1 and parts[1].strip():
                        entry["content"] = parts[1]
                    else:
                        continue
                if isinstance(content, list):
                    filtered = []
                    for c in content:
                        if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
                            continue  # Strip runtime context from multimodal messages
                        if (c.get("type") == "image_url"
                                and c.get("image_url", {}).get("url", "").startswith("data:image/")):
                            filtered.append({"type": "text", "text": "[image]"})
                        else:
                            filtered.append(c)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
        session.updated_at = datetime.now()

    async def process_direct(
        self,
        content: str,
        session_key: str = "cli:direct",
        channel: str = "cli",
        chat_id: str = "direct",
        on_progress: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Process a message directly (for CLI or cron usage).

        Returns the response text, or an empty string when processing
        produced no outbound message.
        """
        await self._connect_mcp()
        msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
        response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
        return response.content if response else ""
|