feat: extensible command system + task-based dispatch with /stop
- Add commands.py with CommandDef registry, parse_command(), get_help_text() - Refactor run() to dispatch messages as asyncio tasks (non-blocking) - /stop is an 'immediate' command: handled inline, cancels active task - Global processing lock serializes message handling (safe for shared state) - _pending_tasks set prevents GC of dispatched tasks before lock acquisition - _dispatch() registers/clears active tasks, catches CancelledError gracefully - /help now auto-generated from COMMANDS registry Closes #849
This commit is contained in:
59
nanobot/agent/commands.py
Normal file
59
nanobot/agent/commands.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Command definitions and dispatch for the agent loop.
|
||||||
|
|
||||||
|
Commands are slash-prefixed messages (e.g. /stop, /new, /help) that are
|
||||||
|
handled specially — either immediately in the run() loop or inside
|
||||||
|
_process_message before the LLM is called.
|
||||||
|
|
||||||
|
To add a new command:
|
||||||
|
1. Add a CommandDef to COMMANDS
|
||||||
|
2. If immediate=True, add a handler in AgentLoop._handle_immediate_command
|
||||||
|
3. If immediate=False, add handling in AgentLoop._process_message
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CommandDef:
|
||||||
|
"""Definition of a slash command."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
immediate: bool = False # True = handled in run() loop, bypasses message processing
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of all known commands.
|
||||||
|
# "immediate" commands are handled while the agent may be busy (e.g. /stop).
|
||||||
|
# Non-immediate commands go through normal _process_message flow.
|
||||||
|
COMMANDS: dict[str, CommandDef] = {
|
||||||
|
"/stop": CommandDef("/stop", "Stop the current task", immediate=True),
|
||||||
|
"/new": CommandDef("/new", "Start a new conversation"),
|
||||||
|
"/help": CommandDef("/help", "Show available commands"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_command(text: str) -> str | None:
|
||||||
|
"""Extract a slash command from message text.
|
||||||
|
|
||||||
|
Returns the command string (e.g. "/stop") or None if not a command.
|
||||||
|
"""
|
||||||
|
stripped = text.strip()
|
||||||
|
if not stripped.startswith("/"):
|
||||||
|
return None
|
||||||
|
return stripped.split()[0].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def is_immediate_command(cmd: str) -> bool:
|
||||||
|
"""Check if a command should be handled immediately, bypassing processing."""
|
||||||
|
defn = COMMANDS.get(cmd)
|
||||||
|
return defn.immediate if defn else False
|
||||||
|
|
||||||
|
|
||||||
|
def get_help_text() -> str:
|
||||||
|
"""Generate help text from registered commands."""
|
||||||
|
lines = ["🐈 nanobot commands:"]
|
||||||
|
for defn in COMMANDS.values():
|
||||||
|
lines.append(f"{defn.name} — {defn.description}")
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.commands import get_help_text, is_immediate_command, parse_command
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
@@ -99,6 +100,9 @@ class AgentLoop:
|
|||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||||
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._active_tasks: dict[str, asyncio.Task] = {} # session_key -> running task
|
||||||
|
self._pending_tasks: set[asyncio.Task] = set() # Strong refs until dispatch starts
|
||||||
|
self._processing_lock = asyncio.Lock() # Serialize message processing
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
@@ -238,7 +242,12 @@ class AgentLoop:
|
|||||||
return final_content, tools_used, messages
|
return final_content, tools_used, messages
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, processing messages from the bus."""
|
"""Run the agent loop, processing messages from the bus.
|
||||||
|
|
||||||
|
Regular messages are dispatched as asyncio tasks so the loop stays
|
||||||
|
responsive to immediate commands like /stop. A global processing
|
||||||
|
lock serializes message handling to avoid shared-state races.
|
||||||
|
"""
|
||||||
self._running = True
|
self._running = True
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
logger.info("Agent loop started")
|
logger.info("Agent loop started")
|
||||||
@@ -249,14 +258,58 @@ class AgentLoop:
|
|||||||
self.bus.consume_inbound(),
|
self.bus.consume_inbound(),
|
||||||
timeout=1.0
|
timeout=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Immediate commands (/stop) are handled inline
|
||||||
|
cmd = parse_command(msg.content)
|
||||||
|
if cmd and is_immediate_command(cmd):
|
||||||
|
await self._handle_immediate_command(cmd, msg)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Regular messages (including non-immediate commands) are
|
||||||
|
# dispatched as tasks so the loop keeps consuming.
|
||||||
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
|
self._pending_tasks.add(task)
|
||||||
|
task.add_done_callback(self._pending_tasks.discard)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def _handle_immediate_command(self, cmd: str, msg: InboundMessage) -> None:
|
||||||
|
"""Handle a command that must be processed while the agent may be busy."""
|
||||||
|
if cmd == "/stop":
|
||||||
|
task = self._active_tasks.get(msg.session_key)
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="⏹ Task stopped.",
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="No active task to stop.",
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
|
"""Dispatch a message for processing under the global lock."""
|
||||||
|
async with self._processing_lock:
|
||||||
|
self._active_tasks[msg.session_key] = asyncio.current_task() # type: ignore[arg-type]
|
||||||
try:
|
try:
|
||||||
response = await self._process_message(msg)
|
response = await self._process_message(msg)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
await self.bus.publish_outbound(response)
|
await self.bus.publish_outbound(response)
|
||||||
elif msg.channel == "cli":
|
elif msg.channel == "cli":
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {},
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata=msg.metadata or {},
|
||||||
))
|
))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
|
# Response already sent by _handle_immediate_command
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error processing message: {}", e)
|
logger.error("Error processing message: {}", e)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
@@ -264,8 +317,8 @@ class AgentLoop:
|
|||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
content=f"Sorry, I encountered an error: {str(e)}"
|
content=f"Sorry, I encountered an error: {str(e)}"
|
||||||
))
|
))
|
||||||
except asyncio.TimeoutError:
|
finally:
|
||||||
continue
|
self._active_tasks.pop(msg.session_key, None)
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Close MCP connections."""
|
"""Close MCP connections."""
|
||||||
@@ -358,7 +411,7 @@ class AgentLoop:
|
|||||||
content="New session started.")
|
content="New session started.")
|
||||||
if cmd == "/help":
|
if cmd == "/help":
|
||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands")
|
content=get_help_text())
|
||||||
|
|
||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
unconsolidated = len(session.messages) - session.last_consolidated
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||||
|
|||||||
220
tests/test_task_cancel.py
Normal file
220
tests/test_task_cancel.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Tests for the command system and task cancellation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.commands import (
|
||||||
|
COMMANDS,
|
||||||
|
get_help_text,
|
||||||
|
is_immediate_command,
|
||||||
|
parse_command,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# commands.py unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestParseCommand:
|
||||||
|
def test_slash_command(self):
|
||||||
|
assert parse_command("/stop") == "/stop"
|
||||||
|
|
||||||
|
def test_slash_command_with_args(self):
|
||||||
|
assert parse_command("/new some args") == "/new"
|
||||||
|
|
||||||
|
def test_not_a_command(self):
|
||||||
|
assert parse_command("hello world") is None
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert parse_command("") is None
|
||||||
|
|
||||||
|
def test_leading_whitespace(self):
|
||||||
|
assert parse_command(" /help") == "/help"
|
||||||
|
|
||||||
|
def test_uppercase_normalized(self):
|
||||||
|
assert parse_command("/STOP") == "/stop"
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsImmediateCommand:
|
||||||
|
def test_stop_is_immediate(self):
|
||||||
|
assert is_immediate_command("/stop") is True
|
||||||
|
|
||||||
|
def test_new_is_not_immediate(self):
|
||||||
|
assert is_immediate_command("/new") is False
|
||||||
|
|
||||||
|
def test_help_is_not_immediate(self):
|
||||||
|
assert is_immediate_command("/help") is False
|
||||||
|
|
||||||
|
def test_unknown_command(self):
|
||||||
|
assert is_immediate_command("/unknown") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetHelpText:
|
||||||
|
def test_contains_all_commands(self):
|
||||||
|
text = get_help_text()
|
||||||
|
for cmd in COMMANDS:
|
||||||
|
assert cmd in text
|
||||||
|
|
||||||
|
def test_contains_descriptions(self):
|
||||||
|
text = get_help_text()
|
||||||
|
for defn in COMMANDS.values():
|
||||||
|
assert defn.description in text
|
||||||
|
|
||||||
|
def test_starts_with_header(self):
|
||||||
|
assert get_help_text().startswith("🐈")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Task cancellation integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestTaskCancellation:
|
||||||
|
"""Tests for /stop cancelling an active task in AgentLoop."""
|
||||||
|
|
||||||
|
def _make_loop(self):
|
||||||
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
workspace = MagicMock()
|
||||||
|
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||||
|
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=workspace,
|
||||||
|
)
|
||||||
|
return loop, bus
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_no_active_task(self):
|
||||||
|
"""'/stop' when nothing is running returns 'No active task'."""
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = self._make_loop()
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="test", sender_id="u1", chat_id="c1", content="/stop"
|
||||||
|
)
|
||||||
|
await loop._handle_immediate_command("/stop", msg)
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "No active task" in out.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_active_task(self):
|
||||||
|
"""'/stop' cancels a running task."""
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = self._make_loop()
|
||||||
|
session_key = "test:c1"
|
||||||
|
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_task():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow_task())
|
||||||
|
await asyncio.sleep(0) # Let task enter its await
|
||||||
|
loop._active_tasks[session_key] = task
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="test", sender_id="u1", chat_id="c1", content="/stop"
|
||||||
|
)
|
||||||
|
await loop._handle_immediate_command("/stop", msg)
|
||||||
|
|
||||||
|
assert cancelled.is_set()
|
||||||
|
assert task.cancelled()
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "stopped" in out.content.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_registers_and_clears_task(self):
|
||||||
|
"""_dispatch registers the task in _active_tasks and clears it after."""
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = self._make_loop()
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="test", sender_id="u1", chat_id="c1", content="hello"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock _process_message to return a simple response
|
||||||
|
loop._process_message = AsyncMock(
|
||||||
|
return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
|
||||||
|
)
|
||||||
|
|
||||||
|
task = asyncio.create_task(loop._dispatch(msg))
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Task should be cleaned up
|
||||||
|
assert msg.session_key not in loop._active_tasks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_handles_cancelled_error(self):
|
||||||
|
"""_dispatch catches CancelledError gracefully."""
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
loop, bus = self._make_loop()
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="test", sender_id="u1", chat_id="c1", content="hello"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_process(m, **kwargs):
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
loop._process_message = mock_process
|
||||||
|
|
||||||
|
task = asyncio.create_task(loop._dispatch(msg))
|
||||||
|
await asyncio.sleep(0.05) # Let task start
|
||||||
|
|
||||||
|
assert msg.session_key in loop._active_tasks
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Task should be cleaned up even after cancel
|
||||||
|
assert msg.session_key not in loop._active_tasks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_processing_lock_serializes(self):
|
||||||
|
"""Only one message processes at a time due to _processing_lock."""
|
||||||
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
loop, bus = self._make_loop()
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def mock_process(m, **kwargs):
|
||||||
|
order.append(f"start-{m.content}")
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
order.append(f"end-{m.content}")
|
||||||
|
return OutboundMessage(channel="test", chat_id="c1", content=m.content)
|
||||||
|
|
||||||
|
loop._process_message = mock_process
|
||||||
|
|
||||||
|
msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
|
||||||
|
msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")
|
||||||
|
|
||||||
|
t1 = asyncio.create_task(loop._dispatch(msg1))
|
||||||
|
t2 = asyncio.create_task(loop._dispatch(msg2))
|
||||||
|
await asyncio.gather(t1, t2)
|
||||||
|
|
||||||
|
# Should be serialized: start-a, end-a, start-b, end-b
|
||||||
|
assert order == ["start-a", "end-a", "start-b", "end-b"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
Reference in New Issue
Block a user