fix: parallel subagent cancellation + register task before lock

- cancel_by_session: call cancel() on every running task first, then
  await them all together with asyncio.gather(return_exceptions=True)
  instead of cancelling and awaiting each task sequentially
- _dispatch: register the current task in _active_tasks before acquiring
  the processing lock, so /stop can find and cancel tasks that are still
  queued waiting for the lock (synced from #1179)
This commit is contained in:
coldxiangyu
2026-02-25 18:21:46 +08:00
parent 2466b8b843
commit 4768b9a09d
2 changed files with 32 additions and 29 deletions

View File

@@ -304,30 +304,35 @@ class AgentLoop:
)) ))
async def _dispatch(self, msg: InboundMessage) -> None: async def _dispatch(self, msg: InboundMessage) -> None:
"""Dispatch a message for processing under the global lock.""" """Dispatch a message for processing under the global lock.
async with self._processing_lock:
self._active_tasks[msg.session_key] = asyncio.current_task() # type: ignore[arg-type] The task is registered in _active_tasks *before* acquiring the lock
try: so that /stop can find (and cancel) tasks that are still queued.
response = await self._process_message(msg) """
if response is not None: self._active_tasks[msg.session_key] = asyncio.current_task() # type: ignore[arg-type]
await self.bus.publish_outbound(response) try:
elif msg.channel == "cli": async with self._processing_lock:
try:
response = await self._process_message(msg)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
))
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
# Response already sent by _handle_immediate_command
except Exception as e:
logger.error("Error processing message: {}", e)
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel,
content="", metadata=msg.metadata or {}, chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}"
)) ))
except asyncio.CancelledError: finally:
logger.info("Task cancelled for session {}", msg.session_key) self._active_tasks.pop(msg.session_key, None)
# Response already sent by _handle_immediate_command
except Exception as e:
logger.error("Error processing message: {}", e)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=f"Sorry, I encountered an error: {str(e)}"
))
finally:
self._active_tasks.pop(msg.session_key, None)
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Close MCP connections.""" """Close MCP connections."""

View File

@@ -268,17 +268,15 @@ When you have completed the task, provide a clear summary of your findings or ac
async def cancel_by_session(self, session_key: str) -> int: async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents spawned under the given session. Returns count cancelled.""" """Cancel all subagents spawned under the given session. Returns count cancelled."""
task_ids = list(self._session_tasks.get(session_key, [])) task_ids = list(self._session_tasks.get(session_key, []))
cancelled = 0 to_cancel: list[asyncio.Task] = []
for tid in task_ids: for tid in task_ids:
t = self._running_tasks.get(tid) t = self._running_tasks.get(tid)
if t and not t.done(): if t and not t.done():
t.cancel() t.cancel()
try: to_cancel.append(t)
await t if to_cancel:
except (asyncio.CancelledError, Exception): await asyncio.gather(*to_cancel, return_exceptions=True)
pass return len(to_cancel)
cancelled += 1
return cancelled
def get_running_count(self) -> int: def get_running_count(self) -> int:
"""Return the number of currently running subagents.""" """Return the number of currently running subagents."""