feat: /stop cancels spawned subagents via session tracking
- SubagentManager tracks _session_tasks: session_key -> {task_id, ...}
- cancel_by_session() cancels all subagents for a session
- SpawnTool passes session_key through to SubagentManager
- /stop response reports subagent cancellation count
- Cleanup callback removes from both _running_tasks and _session_tasks
Builds on #1179
This commit is contained in:
@@ -278,15 +278,24 @@ class AgentLoop:
|
|||||||
"""Handle a command that must be processed while the agent may be busy."""
|
"""Handle a command that must be processed while the agent may be busy."""
|
||||||
if cmd == "/stop":
|
if cmd == "/stop":
|
||||||
task = self._active_tasks.get(msg.session_key)
|
task = self._active_tasks.get(msg.session_key)
|
||||||
|
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||||
if task and not task.done():
|
if task and not task.done():
|
||||||
task.cancel()
|
task.cancel()
|
||||||
try:
|
try:
|
||||||
await task
|
await task
|
||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
pass
|
pass
|
||||||
|
parts = ["⏹ Task stopped."]
|
||||||
|
if sub_cancelled:
|
||||||
|
parts.append(f"Also stopped {sub_cancelled} background task(s).")
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="⏹ Task stopped.",
|
content=" ".join(parts),
|
||||||
|
))
|
||||||
|
elif sub_cancelled:
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content=f"⏹ Stopped {sub_cancelled} background task(s).",
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class SubagentManager:
|
|||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self,
|
self,
|
||||||
@@ -56,6 +57,7 @@ class SubagentManager:
|
|||||||
label: str | None = None,
|
label: str | None = None,
|
||||||
origin_channel: str = "cli",
|
origin_channel: str = "cli",
|
||||||
origin_chat_id: str = "direct",
|
origin_chat_id: str = "direct",
|
||||||
|
session_key: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Spawn a subagent to execute a task in the background.
|
Spawn a subagent to execute a task in the background.
|
||||||
@@ -82,9 +84,20 @@ class SubagentManager:
|
|||||||
self._run_subagent(task_id, task, display_label, origin)
|
self._run_subagent(task_id, task, display_label, origin)
|
||||||
)
|
)
|
||||||
self._running_tasks[task_id] = bg_task
|
self._running_tasks[task_id] = bg_task
|
||||||
|
|
||||||
# Cleanup when done
|
if session_key:
|
||||||
bg_task.add_done_callback(lambda _: self._running_tasks.pop(task_id, None))
|
self._session_tasks.setdefault(session_key, set()).add(task_id)
|
||||||
|
|
||||||
|
def _cleanup(_: asyncio.Task) -> None:
|
||||||
|
self._running_tasks.pop(task_id, None)
|
||||||
|
if session_key:
|
||||||
|
ids = self._session_tasks.get(session_key)
|
||||||
|
if ids:
|
||||||
|
ids.discard(task_id)
|
||||||
|
if not ids:
|
||||||
|
self._session_tasks.pop(session_key, None)
|
||||||
|
|
||||||
|
bg_task.add_done_callback(_cleanup)
|
||||||
|
|
||||||
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||||
@@ -252,6 +265,21 @@ Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed
|
|||||||
|
|
||||||
When you have completed the task, provide a clear summary of your findings or actions."""
|
When you have completed the task, provide a clear summary of your findings or actions."""
|
||||||
|
|
||||||
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
|
"""Cancel all subagents spawned under the given session. Returns count cancelled."""
|
||||||
|
task_ids = list(self._session_tasks.get(session_key, []))
|
||||||
|
cancelled = 0
|
||||||
|
for tid in task_ids:
|
||||||
|
t = self._running_tasks.get(tid)
|
||||||
|
if t and not t.done():
|
||||||
|
t.cancel()
|
||||||
|
try:
|
||||||
|
await t
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
cancelled += 1
|
||||||
|
return cancelled
|
||||||
|
|
||||||
def get_running_count(self) -> int:
|
def get_running_count(self) -> int:
|
||||||
"""Return the number of currently running subagents."""
|
"""Return the number of currently running subagents."""
|
||||||
return len(self._running_tasks)
|
return len(self._running_tasks)
|
||||||
|
|||||||
@@ -15,11 +15,13 @@ class SpawnTool(Tool):
|
|||||||
self._manager = manager
|
self._manager = manager
|
||||||
self._origin_channel = "cli"
|
self._origin_channel = "cli"
|
||||||
self._origin_chat_id = "direct"
|
self._origin_chat_id = "direct"
|
||||||
|
self._session_key = "cli:direct"
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str) -> None:
|
def set_context(self, channel: str, chat_id: str) -> None:
|
||||||
"""Set the origin context for subagent announcements."""
|
"""Set the origin context for subagent announcements."""
|
||||||
self._origin_channel = channel
|
self._origin_channel = channel
|
||||||
self._origin_chat_id = chat_id
|
self._origin_chat_id = chat_id
|
||||||
|
self._session_key = f"{channel}:{chat_id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -57,4 +59,5 @@ class SpawnTool(Tool):
|
|||||||
label=label,
|
label=label,
|
||||||
origin_channel=self._origin_channel,
|
origin_channel=self._origin_channel,
|
||||||
origin_chat_id=self._origin_chat_id,
|
origin_chat_id=self._origin_chat_id,
|
||||||
|
session_key=self._session_key,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -218,3 +218,101 @@ class TestTaskCancellation:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentCancellation:
|
||||||
|
"""Tests for /stop cancelling subagents spawned under a session."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session(self):
|
||||||
|
"""cancel_by_session cancels all tasks for that session."""
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_subagent():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow_subagent())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
tid = "sub-1"
|
||||||
|
session_key = "test:c1"
|
||||||
|
mgr._running_tasks[tid] = task
|
||||||
|
mgr._session_tasks[session_key] = {tid}
|
||||||
|
|
||||||
|
count = await mgr.cancel_by_session(session_key)
|
||||||
|
assert count == 1
|
||||||
|
assert cancelled.is_set()
|
||||||
|
assert task.cancelled()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_by_session_no_tasks(self):
|
||||||
|
"""cancel_by_session returns 0 when no subagents for session."""
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||||
|
|
||||||
|
count = await mgr.cancel_by_session("nonexistent:session")
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_subagents_via_loop(self):
|
||||||
|
"""/stop on AgentLoop also cancels subagents for that session."""
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
workspace = MagicMock()
|
||||||
|
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||||
|
patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager"):
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||||
|
|
||||||
|
# Replace subagents with a real SubagentManager
|
||||||
|
from nanobot.agent.subagent import SubagentManager
|
||||||
|
loop.subagents = SubagentManager(
|
||||||
|
provider=provider, workspace=MagicMock(), bus=bus
|
||||||
|
)
|
||||||
|
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
session_key = "test:c1"
|
||||||
|
|
||||||
|
async def slow_sub():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
task = asyncio.create_task(slow_sub())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
loop.subagents._running_tasks["sub-1"] = task
|
||||||
|
loop.subagents._session_tasks[session_key] = {"sub-1"}
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="test", sender_id="u1", chat_id="c1", content="/stop"
|
||||||
|
)
|
||||||
|
await loop._handle_immediate_command("/stop", msg)
|
||||||
|
|
||||||
|
assert cancelled.is_set()
|
||||||
|
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||||
|
assert "stopped" in out.content.lower() or "background" in out.content.lower()
|
||||||
|
|||||||
Reference in New Issue
Block a user