# Files
# nanobot/nanobot/agent/loop.py
#
# 483 lines
# 21 KiB
# Python

"""Agent loop: the core processing engine."""
from __future__ import annotations
import asyncio
import json
import re
from contextlib import AsyncExitStack
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryStore
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
# These names are only needed for annotations; importing them under the
# TYPE_CHECKING guard keeps them out of the runtime import graph.
if TYPE_CHECKING:
    from nanobot.config.schema import ChannelsConfig, ExecToolConfig
    from nanobot.cron.service import CronService
class AgentLoop:
    """
    The agent loop is the core processing engine.

    It:
    1. Receives messages from the bus
    2. Builds context with history, memory, skills
    3. Calls the LLM
    4. Executes tool calls
    5. Sends responses back
    """

    # Tool results longer than this are cut before being stored in a session.
    _TOOL_RESULT_MAX_CHARS = 500
    # Longest argument preview shown in a progress hint before it is cut.
    _TOOL_HINT_MAX_CHARS = 40

    def __init__(
        self,
        bus: MessageBus,
        provider: LLMProvider,
        workspace: Path,
        model: str | None = None,
        max_iterations: int = 40,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        memory_window: int = 100,
        brave_api_key: str | None = None,
        exec_config: ExecToolConfig | None = None,
        cron_service: CronService | None = None,
        restrict_to_workspace: bool = False,
        session_manager: SessionManager | None = None,
        mcp_servers: dict | None = None,
        channels_config: ChannelsConfig | None = None,
    ):
        """Store configuration, build collaborators and register the default tools.

        Args:
            bus: Message bus that inbound messages are consumed from and
                outbound messages are published to.
            provider: LLM provider used for every chat call.
            workspace: Directory handed to the context builder, the session
                manager, the subagent manager and the filesystem/exec tools.
            model: Model name; falls back to ``provider.get_default_model()``.
            max_iterations: Upper bound on LLM calls within one turn.
            temperature: Passed through to ``provider.chat``.
            max_tokens: Passed through to ``provider.chat``.
            memory_window: Maximum number of history messages loaded per turn,
                and the unconsolidated-message count that triggers background
                consolidation.
            brave_api_key: Passed to ``WebSearchTool`` and the subagent manager.
            exec_config: Settings for ``ExecTool``; a default
                ``ExecToolConfig()`` is used when omitted.
            cron_service: When given, a ``CronTool`` is registered for it.
            restrict_to_workspace: When true, the filesystem tools receive the
                workspace as ``allowed_dir`` and the flag is forwarded to
                ``ExecTool`` and the subagent manager.
            session_manager: Session store; defaults to
                ``SessionManager(workspace)``.
            mcp_servers: MCP server configuration forwarded to
                ``connect_mcp_servers``; empty/None disables MCP.
            channels_config: Stored on the instance; not read elsewhere in
                this class.
        """
        # Imported here because the module-level import is TYPE_CHECKING-only.
        from nanobot.config.schema import ExecToolConfig

        self.bus = bus
        self.channels_config = channels_config
        self.provider = provider
        self.workspace = workspace
        self.model = model or provider.get_default_model()
        self.max_iterations = max_iterations
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.memory_window = memory_window
        self.brave_api_key = brave_api_key
        self.exec_config = exec_config or ExecToolConfig()
        self.cron_service = cron_service
        self.restrict_to_workspace = restrict_to_workspace

        self.context = ContextBuilder(workspace)
        self.sessions = session_manager or SessionManager(workspace)
        self.tools = ToolRegistry()
        self.subagents = SubagentManager(
            provider=provider,
            workspace=workspace,
            bus=bus,
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            brave_api_key=brave_api_key,
            exec_config=self.exec_config,
            restrict_to_workspace=restrict_to_workspace,
        )

        self._running = False
        self._mcp_servers = mcp_servers or {}
        self._mcp_stack: AsyncExitStack | None = None
        self._mcp_connected = False
        self._mcp_connecting = False
        self._consolidating: set[str] = set()  # Session keys with consolidation in progress
        self._consolidation_tasks: set[asyncio.Task] = set()  # Strong refs to in-flight tasks
        self._consolidation_locks: dict[str, asyncio.Lock] = {}
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
        self._processing_lock = asyncio.Lock()
        self._register_default_tools()

    def _register_default_tools(self) -> None:
        """Register the default set of tools."""
        # Filesystem tools are confined to the workspace only when requested.
        allowed_dir = self.workspace if self.restrict_to_workspace else None
        for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
            self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
        self.tools.register(ExecTool(
            working_dir=str(self.workspace),
            timeout=self.exec_config.timeout,
            restrict_to_workspace=self.restrict_to_workspace,
            path_append=self.exec_config.path_append,
        ))
        self.tools.register(WebSearchTool(api_key=self.brave_api_key))
        self.tools.register(WebFetchTool())
        self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
        self.tools.register(SpawnTool(manager=self.subagents))
        if self.cron_service:
            self.tools.register(CronTool(self.cron_service))

    async def _connect_mcp(self) -> None:
        """Connect to configured MCP servers (one-time, lazy).

        Does nothing when already connected, when a connection attempt is in
        progress, or when no servers are configured. On failure the partially
        built exit stack is closed and dropped so a later call can try again.
        """
        if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
            return
        self._mcp_connecting = True
        from nanobot.agent.tools.mcp import connect_mcp_servers
        try:
            self._mcp_stack = AsyncExitStack()
            await self._mcp_stack.__aenter__()
            await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
            self._mcp_connected = True
        except Exception as e:
            logger.error("Failed to connect MCP servers (will retry next message): {}", e)
            if self._mcp_stack:
                try:
                    await self._mcp_stack.aclose()
                except Exception as close_err:
                    # Best-effort cleanup: the connection already failed, so a
                    # failing close is only worth a debug line.
                    logger.debug("Ignoring error while closing MCP stack: {}", close_err)
                self._mcp_stack = None
        finally:
            self._mcp_connecting = False

    def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
        """Update context for all tools that need routing info.

        Only the ``message`` tool additionally receives ``message_id``.
        """
        for name in ("message", "spawn", "cron"):
            if tool := self.tools.get(name):
                if hasattr(tool, "set_context"):
                    tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))

    @staticmethod
    def _strip_think(text: str | None) -> str | None:
        """Remove <think>…</think> blocks that some models embed in content.

        Returns ``None`` when the input is empty or nothing but whitespace is
        left after stripping.
        """
        if not text:
            return None
        return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None

    @staticmethod
    def _tool_hint(tool_calls: list) -> str:
        """Format tool calls as concise hint, e.g. 'web_search("query")'.

        Each call is rendered as its name plus a preview of its first argument
        when that argument is a string; otherwise only the name is shown.
        Previews longer than ``_TOOL_HINT_MAX_CHARS`` are cut and marked with
        an ellipsis so a truncated value is not mistaken for the full one.
        """
        limit = AgentLoop._TOOL_HINT_MAX_CHARS

        def _fmt(tc) -> str:
            args = tc.arguments
            # Anything that is not a non-empty dict has no "first argument"
            # to preview (a list here would otherwise fail on .values()).
            first = next(iter(args.values()), None) if isinstance(args, dict) and args else None
            if not isinstance(first, str):
                return tc.name
            if len(first) > limit:
                return f'{tc.name}("{first[:limit]}…")'
            return f'{tc.name}("{first}")'

        return ", ".join(_fmt(tc) for tc in tool_calls)

    async def _run_agent_loop(
        self,
        initial_messages: list[dict],
        on_progress: Callable[..., Awaitable[None]] | None = None,
    ) -> tuple[str | None, list[str], list[dict]]:
        """Run the agent iteration loop. Returns (final_content, tools_used, messages).

        Args:
            initial_messages: Conversation to start from.
            on_progress: Optional async callback. It is awaited with the
                model's visible text before tool execution, and with the tool
                hint plus ``tool_hint=True``.

        Returns:
            ``final_content`` is the model's last reply with think blocks
            removed (``None`` if that reply was empty), or a fixed notice when
            ``max_iterations`` LLM calls all ended in tool calls.
            ``tools_used`` lists the executed tool names in order, and
            ``messages`` is the conversation including everything added here.
        """
        messages = initial_messages
        iteration = 0
        final_content = None
        tools_used: list[str] = []

        while iteration < self.max_iterations:
            iteration += 1
            response = await self.provider.chat(
                messages=messages,
                tools=self.tools.get_definitions(),
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            if response.has_tool_calls:
                if on_progress:
                    clean = self._strip_think(response.content)
                    if clean:
                        await on_progress(clean)
                    await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)

                tool_call_dicts = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.arguments, ensure_ascii=False),
                        },
                    }
                    for tc in response.tool_calls
                ]
                messages = self.context.add_assistant_message(
                    messages, response.content, tool_call_dicts,
                    reasoning_content=response.reasoning_content,
                )

                for tool_call in response.tool_calls:
                    tools_used.append(tool_call.name)
                    args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
                    logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
                    result = await self.tools.execute(tool_call.name, tool_call.arguments)
                    messages = self.context.add_tool_result(
                        messages, tool_call.id, tool_call.name, result
                    )
            else:
                clean = self._strip_think(response.content)
                messages = self.context.add_assistant_message(
                    messages, clean, reasoning_content=response.reasoning_content,
                )
                final_content = clean
                break
        else:
            # Reached only when the loop ran out of iterations without a final
            # (non-tool) reply. An empty final reply on the last iteration
            # leaves via ``break`` and is therefore not reported as a timeout.
            logger.warning("Max iterations ({}) reached", self.max_iterations)
            final_content = (
                f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
                "without completing the task. You can try breaking the task into smaller steps."
            )

        return final_content, tools_used, messages

    def _forget_task(self, session_key: str, task: asyncio.Task) -> None:
        """Drop a finished task from ``_active_tasks`` and prune emptied keys."""
        tasks = self._active_tasks.get(session_key)
        if tasks is None:
            # The key was already removed, e.g. popped by ``_handle_stop``.
            return
        if task in tasks:
            tasks.remove(task)
        if not tasks:
            # Do not keep an empty list around for every session ever seen.
            del self._active_tasks[session_key]

    async def run(self) -> None:
        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
        self._running = True
        await self._connect_mcp()
        logger.info("Agent loop started")

        while self._running:
            try:
                # The timeout lets the loop re-check ``_running`` once a second.
                msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            if msg.content.strip().lower() == "/stop":
                await self._handle_stop(msg)
            else:
                task = asyncio.create_task(self._dispatch(msg))
                self._active_tasks.setdefault(msg.session_key, []).append(task)
                # Bind the key as a default so the callback does not see a
                # later message's session key.
                task.add_done_callback(
                    lambda t, k=msg.session_key: self._forget_task(k, t)
                )

    async def _handle_stop(self, msg: InboundMessage) -> None:
        """Cancel all active tasks and subagents for the session."""
        tasks = self._active_tasks.pop(msg.session_key, [])
        cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
        if tasks:
            # Wait for the tasks to finish unwinding. Their own errors and
            # cancellations are collected as results, while cancellation of
            # this coroutine itself still propagates.
            await asyncio.gather(*tasks, return_exceptions=True)
        sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
        total = cancelled + sub_cancelled
        content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
        await self.bus.publish_outbound(OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=content,
        ))

    async def _dispatch(self, msg: InboundMessage) -> None:
        """Process a message under the global lock.

        Publishes the response when there is one. For the ``cli`` channel an
        empty message is published when there is none. Unexpected errors are
        logged and answered with an apology; cancellation is logged and
        re-raised.
        """
        async with self._processing_lock:
            try:
                response = await self._process_message(msg)
                if response is not None:
                    await self.bus.publish_outbound(response)
                elif msg.channel == "cli":
                    await self.bus.publish_outbound(OutboundMessage(
                        channel=msg.channel, chat_id=msg.chat_id,
                        content="", metadata=msg.metadata or {},
                    ))
            except asyncio.CancelledError:
                logger.info("Task cancelled for session {}", msg.session_key)
                raise
            except Exception:
                logger.exception("Error processing message for session {}", msg.session_key)
                await self.bus.publish_outbound(OutboundMessage(
                    channel=msg.channel, chat_id=msg.chat_id,
                    content="Sorry, I encountered an error.",
                ))

    async def close_mcp(self) -> None:
        """Close MCP connections."""
        if self._mcp_stack:
            try:
                await self._mcp_stack.aclose()
            except (RuntimeError, BaseExceptionGroup):
                pass  # MCP SDK cancel scope cleanup is noisy but harmless
            self._mcp_stack = None

    def stop(self) -> None:
        """Stop the agent loop."""
        self._running = False
        logger.info("Agent loop stopping")

    async def _process_message(
        self,
        msg: InboundMessage,
        session_key: str | None = None,
        on_progress: Callable[..., Awaitable[None]] | None = None,
    ) -> OutboundMessage | None:
        """Process a single inbound message and return the response.

        Args:
            msg: The inbound message.
            session_key: Overrides ``msg.session_key`` when given.
            on_progress: Optional async progress callback; it is also called
                with the keyword ``tool_hint=True``. When omitted, progress is
                published on the bus.

        Returns:
            The reply to publish, or ``None`` when the message tool already
            sent something during this turn.
        """
        # Metadata may be missing; other call sites guard the same way.
        message_id = (msg.metadata or {}).get("message_id")

        # System messages: parse origin from chat_id ("channel:chat_id")
        if msg.channel == "system":
            channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
                                else ("cli", msg.chat_id))
            logger.info("Processing system message from {}", msg.sender_id)
            key = f"{channel}:{chat_id}"
            session = self.sessions.get_or_create(key)
            self._set_tool_context(channel, chat_id, message_id)
            history = session.get_history(max_messages=self.memory_window)
            messages = self.context.build_messages(
                history=history,
                current_message=msg.content, channel=channel, chat_id=chat_id,
            )
            final_content, _, all_msgs = await self._run_agent_loop(messages)
            # Skip what was already there before this turn — presumably one
            # leading system message plus the history; confirm against
            # ContextBuilder.build_messages.
            self._save_turn(session, all_msgs, 1 + len(history))
            self.sessions.save(session)
            return OutboundMessage(channel=channel, chat_id=chat_id,
                                   content=final_content or "Background task completed.")

        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
        logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)

        key = session_key or msg.session_key
        session = self.sessions.get_or_create(key)

        # Slash commands
        cmd = msg.content.strip().lower()
        if cmd == "/new":
            lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
            self._consolidating.add(session.key)
            try:
                async with lock:
                    # Archive only what has not been consolidated yet, using a
                    # temporary session so the real one is untouched on failure.
                    snapshot = session.messages[session.last_consolidated:]
                    if snapshot:
                        temp = Session(key=session.key)
                        temp.messages = list(snapshot)
                        if not await self._consolidate_memory(temp, archive_all=True):
                            return OutboundMessage(
                                channel=msg.channel, chat_id=msg.chat_id,
                                content="Memory archival failed, session not cleared. Please try again.",
                            )
            except Exception:
                logger.exception("/new archival failed for {}", session.key)
                return OutboundMessage(
                    channel=msg.channel, chat_id=msg.chat_id,
                    content="Memory archival failed, session not cleared. Please try again.",
                )
            finally:
                self._consolidating.discard(session.key)
                if not lock.locked():
                    self._consolidation_locks.pop(session.key, None)

            session.clear()
            self.sessions.save(session)
            self.sessions.invalidate(session.key)
            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                   content="New session started.")
        if cmd == "/help":
            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                   content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")

        # Kick off background consolidation once enough messages piled up and
        # none is already running for this session.
        unconsolidated = len(session.messages) - session.last_consolidated
        if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
            self._consolidating.add(session.key)
            lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())

            async def _consolidate_and_unlock():
                try:
                    async with lock:
                        await self._consolidate_memory(session)
                finally:
                    self._consolidating.discard(session.key)
                    if not lock.locked():
                        self._consolidation_locks.pop(session.key, None)
                    _task = asyncio.current_task()
                    if _task is not None:
                        self._consolidation_tasks.discard(_task)

            # Keep a strong reference so the task is not garbage collected.
            _task = asyncio.create_task(_consolidate_and_unlock())
            self._consolidation_tasks.add(_task)

        self._set_tool_context(msg.channel, msg.chat_id, message_id)
        if message_tool := self.tools.get("message"):
            if isinstance(message_tool, MessageTool):
                message_tool.start_turn()

        history = session.get_history(max_messages=self.memory_window)
        initial_messages = self.context.build_messages(
            history=history,
            current_message=msg.content,
            media=msg.media if msg.media else None,
            channel=msg.channel, chat_id=msg.chat_id,
        )

        async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
            # Default progress sink: publish on the bus, flagged as progress.
            meta = dict(msg.metadata or {})
            meta["_progress"] = True
            meta["_tool_hint"] = tool_hint
            await self.bus.publish_outbound(OutboundMessage(
                channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
            ))

        final_content, _, all_msgs = await self._run_agent_loop(
            initial_messages, on_progress=on_progress or _bus_progress,
        )

        if final_content is None:
            final_content = "I've completed processing but have no response to give."

        # Same skip convention as the system-message branch above.
        self._save_turn(session, all_msgs, 1 + len(history))
        self.sessions.save(session)

        # The message tool already delivered output this turn; avoid a duplicate.
        if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
            return None

        preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
        logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
        return OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=final_content,
            metadata=msg.metadata or {},
        )

    def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
        """Save new-turn messages into session, truncating large tool results.

        Args:
            session: Session whose ``messages`` list is appended to and whose
                ``updated_at`` is refreshed.
            messages: Full message list of the turn.
            skip: Number of leading messages to ignore.

        For every stored entry, ``reasoning_content`` is dropped, string tool
        results beyond ``_TOOL_RESULT_MAX_CHARS`` are cut, inline ``data:image/``
        parts of user messages are replaced by an ``[image]`` text part, and a
        ``timestamp`` is added when missing.
        """
        from datetime import datetime

        for m in messages[skip:]:
            entry = {k: v for k, v in m.items() if k != "reasoning_content"}
            if entry.get("role") == "tool" and isinstance(entry.get("content"), str):
                content = entry["content"]
                if len(content) > self._TOOL_RESULT_MAX_CHARS:
                    entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
            if entry.get("role") == "user" and isinstance(entry.get("content"), list):
                entry["content"] = [
                    {"type": "text", "text": "[image]"} if (
                        c.get("type") == "image_url"
                        and c.get("image_url", {}).get("url", "").startswith("data:image/")
                    ) else c
                    for c in entry["content"]
                ]
            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
        session.updated_at = datetime.now()

    async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
        """Delegate to MemoryStore.consolidate(). Returns True on success."""
        return await MemoryStore(self.workspace).consolidate(
            session, self.provider, self.model,
            archive_all=archive_all, memory_window=self.memory_window,
        )

    async def process_direct(
        self,
        content: str,
        session_key: str = "cli:direct",
        channel: str = "cli",
        chat_id: str = "direct",
        on_progress: Callable[..., Awaitable[None]] | None = None,
    ) -> str:
        """Process a message directly (for CLI or cron usage).

        Wraps ``content`` in an ``InboundMessage`` from sender ``"user"``,
        processes it without going through the bus queue, and returns the
        reply text (empty string when there is no reply). ``on_progress``
        must accept the keyword ``tool_hint``.
        """
        await self._connect_mcp()
        msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
        response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
        return response.content if response else ""