feat: /stop cancels spawned subagents via session tracking
- SubagentManager tracks _session_tasks: session_key -> {task_id, ...}
- cancel_by_session() cancels all subagents for a session
- SpawnTool passes session_key through to SubagentManager
- /stop response reports subagent cancellation count
- Cleanup callback removes from both _running_tasks and _session_tasks
Builds on #1179
This commit is contained in:
@@ -218,3 +218,101 @@ class TestTaskCancellation:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubagentCancellation:
|
||||
"""Tests for /stop cancelling subagents spawned under a session."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session(self):
|
||||
"""cancel_by_session cancels all tasks for that session."""
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def slow_subagent():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow_subagent())
|
||||
await asyncio.sleep(0)
|
||||
tid = "sub-1"
|
||||
session_key = "test:c1"
|
||||
mgr._running_tasks[tid] = task
|
||||
mgr._session_tasks[session_key] = {tid}
|
||||
|
||||
count = await mgr.cancel_by_session(session_key)
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
assert task.cancelled()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_no_tasks(self):
|
||||
"""cancel_by_session returns 0 when no subagents for session."""
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
|
||||
count = await mgr.cancel_by_session("nonexistent:session")
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cancels_subagents_via_loop(self):
|
||||
"""/stop on AgentLoop also cancels subagents for that session."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
workspace = MagicMock()
|
||||
workspace.__truediv__ = MagicMock(return_value=MagicMock())
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager"):
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
|
||||
|
||||
# Replace subagents with a real SubagentManager
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
loop.subagents = SubagentManager(
|
||||
provider=provider, workspace=MagicMock(), bus=bus
|
||||
)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
session_key = "test:c1"
|
||||
|
||||
async def slow_sub():
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
task = asyncio.create_task(slow_sub())
|
||||
await asyncio.sleep(0)
|
||||
loop.subagents._running_tasks["sub-1"] = task
|
||||
loop.subagents._session_tasks[session_key] = {"sub-1"}
|
||||
|
||||
msg = InboundMessage(
|
||||
channel="test", sender_id="u1", chat_id="c1", content="/stop"
|
||||
)
|
||||
await loop._handle_immediate_command("/stop", msg)
|
||||
|
||||
assert cancelled.is_set()
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert "stopped" in out.content.lower() or "background" in out.content.lower()
|
||||
|
||||
Reference in New Issue
Block a user