feat(agent): replace global lock with per-session locks for concurrent dispatch
Replace the single _processing_lock (asyncio.Lock) with per-session locks so that different sessions can process LLM requests concurrently, while messages within the same session remain serialised. An optional global concurrency cap is available via the NANOBOT_MAX_CONCURRENT_REQUESTS env var (default 3; a value <= 0 means unlimited). Also re-binds the tool context before each tool execution round so that concurrent sessions do not clobber each other's routing info. Tested in production and manually reviewed.

(cherry picked from commit c397bb4229e8c3b7f99acea7ffe4bea15e73e957)
This commit is contained in:
@@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
from contextlib import AsyncExitStack
|
||||
from contextlib import AsyncExitStack, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
@@ -103,7 +104,12 @@ class AgentLoop:
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._background_tasks: list[asyncio.Task] = []
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||
asyncio.Semaphore(_max) if _max > 0 else None
|
||||
)
|
||||
self.memory_consolidator = MemoryConsolidator(
|
||||
workspace=workspace,
|
||||
provider=provider,
|
||||
@@ -193,6 +199,10 @@ class AgentLoop:
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
@@ -270,11 +280,27 @@ class AgentLoop:
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
for tc in response.tool_calls:
|
||||
tools_used.append(tc.name)
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
|
||||
# Re-bind tool context right before execution so that
|
||||
# concurrent sessions don't clobber each other's routing.
|
||||
self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
# Execute all tool calls concurrently — the LLM batches
|
||||
# independent calls in a single response on purpose.
|
||||
# return_exceptions=True ensures all results are collected
|
||||
# even if one tool is cancelled or raises BaseException.
|
||||
results = await asyncio.gather(*(
|
||||
self.tools.execute(tc.name, tc.arguments)
|
||||
for tc in response.tool_calls
|
||||
), return_exceptions=True)
|
||||
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
if isinstance(result, BaseException):
|
||||
result = f"Error: {type(result).__name__}: {result}"
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
@@ -337,8 +363,10 @@ class AgentLoop:
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
@@ -422,7 +450,10 @@ class AgentLoop:
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
messages, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
@@ -469,6 +500,8 @@ class AgentLoop:
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
|
||||
Reference in New Issue
Block a user