feat(agent): replace global lock with per-session locks for concurrent dispatch

Replace the single _processing_lock (asyncio.Lock) with per-session locks
so that different sessions can process LLM requests concurrently, while
messages within the same session remain serialised.

An optional global concurrency cap is available via the
NANOBOT_MAX_CONCURRENT_REQUESTS env var (default 3, <=0 for unlimited).

Also re-binds tool context before each tool execution round to prevent
concurrent sessions from clobbering each other's routing info.

Tested in production and manually reviewed.

(cherry picked from commit c397bb4229e8c3b7f99acea7ffe4bea15e73e957)
This commit is contained in:
gem12
2026-03-21 22:55:10 +08:00
committed by Xubin Ren
parent 20494a2c52
commit 97fe9ab7d4

View File

@@ -5,8 +5,9 @@ from __future__ import annotations
 import asyncio
 import json
 import re
+import os
 import time
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, nullcontext
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -103,7 +104,12 @@ class AgentLoop:
         self._mcp_connecting = False
         self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
         self._background_tasks: list[asyncio.Task] = []
-        self._processing_lock = asyncio.Lock()
+        self._session_locks: dict[str, asyncio.Lock] = {}
+        # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
+        _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
+        self._concurrency_gate: asyncio.Semaphore | None = (
+            asyncio.Semaphore(_max) if _max > 0 else None
+        )
         self.memory_consolidator = MemoryConsolidator(
             workspace=workspace,
             provider=provider,
@@ -193,6 +199,10 @@ class AgentLoop:
         on_progress: Callable[..., Awaitable[None]] | None = None,
         on_stream: Callable[[str], Awaitable[None]] | None = None,
         on_stream_end: Callable[..., Awaitable[None]] | None = None,
+        *,
+        channel: str = "cli",
+        chat_id: str = "direct",
+        message_id: str | None = None,
     ) -> tuple[str | None, list[str], list[dict]]:
         """Run the agent iteration loop.
@@ -270,11 +280,27 @@ class AgentLoop:
                 thinking_blocks=response.thinking_blocks,
             )
-            for tool_call in response.tool_calls:
-                tools_used.append(tool_call.name)
-                args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
-                logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
-                result = await self.tools.execute(tool_call.name, tool_call.arguments)
+            for tc in response.tool_calls:
+                tools_used.append(tc.name)
+                args_str = json.dumps(tc.arguments, ensure_ascii=False)
+                logger.info("Tool call: {}({})", tc.name, args_str[:200])
+            # Re-bind tool context right before execution so that
+            # concurrent sessions don't clobber each other's routing.
+            self._set_tool_context(channel, chat_id, message_id)
+            # Execute all tool calls concurrently — the LLM batches
+            # independent calls in a single response on purpose.
+            # return_exceptions=True ensures all results are collected
+            # even if one tool is cancelled or raises BaseException.
+            results = await asyncio.gather(*(
+                self.tools.execute(tc.name, tc.arguments)
+                for tc in response.tool_calls
+            ), return_exceptions=True)
+            for tool_call, result in zip(response.tool_calls, results):
+                if isinstance(result, BaseException):
+                    result = f"Error: {type(result).__name__}: {result}"
                 messages = self.context.add_tool_result(
                     messages, tool_call.id, tool_call.name, result
                 )
@@ -337,8 +363,10 @@ class AgentLoop:
         task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)

     async def _dispatch(self, msg: InboundMessage) -> None:
-        """Process a message under the global lock."""
-        async with self._processing_lock:
+        """Process a message: per-session serial, cross-session concurrent."""
+        lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
+        gate = self._concurrency_gate or nullcontext()
+        async with lock, gate:
             try:
                 on_stream = on_stream_end = None
                 if msg.metadata.get("_wants_stream"):
@@ -422,7 +450,10 @@ class AgentLoop:
             current_message=msg.content, channel=channel, chat_id=chat_id,
             current_role=current_role,
         )
-        final_content, _, all_msgs = await self._run_agent_loop(messages)
+        final_content, _, all_msgs = await self._run_agent_loop(
+            messages, channel=channel, chat_id=chat_id,
+            message_id=msg.metadata.get("message_id"),
+        )
         self._save_turn(session, all_msgs, 1 + len(history))
         self.sessions.save(session)
         self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
@@ -469,6 +500,8 @@ class AgentLoop:
             on_progress=on_progress or _bus_progress,
             on_stream=on_stream,
             on_stream_end=on_stream_end,
+            channel=msg.channel, chat_id=msg.chat_id,
+            message_id=msg.metadata.get("message_id"),
         )
         if final_content is None: