fix: parallel subagent cancellation + register task before lock
- cancel_by_session: use asyncio.gather for parallel cancellation instead of sequential await per task - _dispatch: register in _active_tasks before acquiring lock so /stop can find queued tasks (synced from #1179)
This commit is contained in:
@@ -304,30 +304,35 @@ class AgentLoop:
|
|||||||
))
|
))
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Dispatch a message for processing under the global lock."""
|
"""Dispatch a message for processing under the global lock.
|
||||||
async with self._processing_lock:
|
|
||||||
self._active_tasks[msg.session_key] = asyncio.current_task() # type: ignore[arg-type]
|
The task is registered in _active_tasks *before* acquiring the lock
|
||||||
try:
|
so that /stop can find (and cancel) tasks that are still queued.
|
||||||
response = await self._process_message(msg)
|
"""
|
||||||
if response is not None:
|
self._active_tasks[msg.session_key] = asyncio.current_task() # type: ignore[arg-type]
|
||||||
await self.bus.publish_outbound(response)
|
try:
|
||||||
elif msg.channel == "cli":
|
async with self._processing_lock:
|
||||||
|
try:
|
||||||
|
response = await self._process_message(msg)
|
||||||
|
if response is not None:
|
||||||
|
await self.bus.publish_outbound(response)
|
||||||
|
elif msg.channel == "cli":
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="", metadata=msg.metadata or {},
|
||||||
|
))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Task cancelled for session {}", msg.session_key)
|
||||||
|
# Response already sent by _handle_immediate_command
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing message: {}", e)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
content="", metadata=msg.metadata or {},
|
chat_id=msg.chat_id,
|
||||||
|
content=f"Sorry, I encountered an error: {str(e)}"
|
||||||
))
|
))
|
||||||
except asyncio.CancelledError:
|
finally:
|
||||||
logger.info("Task cancelled for session {}", msg.session_key)
|
self._active_tasks.pop(msg.session_key, None)
|
||||||
# Response already sent by _handle_immediate_command
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error processing message: {}", e)
|
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
|
||||||
channel=msg.channel,
|
|
||||||
chat_id=msg.chat_id,
|
|
||||||
content=f"Sorry, I encountered an error: {str(e)}"
|
|
||||||
))
|
|
||||||
finally:
|
|
||||||
self._active_tasks.pop(msg.session_key, None)
|
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Close MCP connections."""
|
"""Close MCP connections."""
|
||||||
|
|||||||
@@ -268,17 +268,15 @@ When you have completed the task, provide a clear summary of your findings or ac
|
|||||||
async def cancel_by_session(self, session_key: str) -> int:
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
"""Cancel all subagents spawned under the given session. Returns count cancelled."""
|
"""Cancel all subagents spawned under the given session. Returns count cancelled."""
|
||||||
task_ids = list(self._session_tasks.get(session_key, []))
|
task_ids = list(self._session_tasks.get(session_key, []))
|
||||||
cancelled = 0
|
to_cancel: list[asyncio.Task] = []
|
||||||
for tid in task_ids:
|
for tid in task_ids:
|
||||||
t = self._running_tasks.get(tid)
|
t = self._running_tasks.get(tid)
|
||||||
if t and not t.done():
|
if t and not t.done():
|
||||||
t.cancel()
|
t.cancel()
|
||||||
try:
|
to_cancel.append(t)
|
||||||
await t
|
if to_cancel:
|
||||||
except (asyncio.CancelledError, Exception):
|
await asyncio.gather(*to_cancel, return_exceptions=True)
|
||||||
pass
|
return len(to_cancel)
|
||||||
cancelled += 1
|
|
||||||
return cancelled
|
|
||||||
|
|
||||||
def get_running_count(self) -> int:
|
def get_running_count(self) -> int:
|
||||||
"""Return the number of currently running subagents."""
|
"""Return the number of currently running subagents."""
|
||||||
|
|||||||
Reference in New Issue
Block a user