feat(agent): replace global lock with per-session locks for concurrent dispatch
Replace the single _processing_lock (asyncio.Lock) with per-session locks so that different sessions can process LLM requests concurrently, while messages within the same session remain serialised. An optional global concurrency cap is available via the NANOBOT_MAX_CONCURRENT_REQUESTS env var (default 3, <=0 for unlimited). Also re-binds tool context before each tool execution round to prevent concurrent sessions from clobbering each other's routing info, and executes the tool calls batched in a single LLM response concurrently via asyncio.gather(..., return_exceptions=True) instead of sequentially, converting any raised exception into an error-string tool result. Tested in production and manually reviewed. (cherry picked from commit c397bb4229e8c3b7f99acea7ffe4bea15e73e957)
This commit is contained in:
@@ -5,8 +5,9 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack, nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
|
|
||||||
@@ -103,7 +104,12 @@ class AgentLoop:
|
|||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._background_tasks: list[asyncio.Task] = []
|
self._background_tasks: list[asyncio.Task] = []
|
||||||
self._processing_lock = asyncio.Lock()
|
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||||
|
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||||
|
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||||
|
asyncio.Semaphore(_max) if _max > 0 else None
|
||||||
|
)
|
||||||
self.memory_consolidator = MemoryConsolidator(
|
self.memory_consolidator = MemoryConsolidator(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@@ -193,6 +199,10 @@ class AgentLoop:
|
|||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
|
*,
|
||||||
|
channel: str = "cli",
|
||||||
|
chat_id: str = "direct",
|
||||||
|
message_id: str | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict]]:
|
||||||
"""Run the agent iteration loop.
|
"""Run the agent iteration loop.
|
||||||
|
|
||||||
@@ -270,11 +280,27 @@ class AgentLoop:
|
|||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool_call in response.tool_calls:
|
for tc in response.tool_calls:
|
||||||
tools_used.append(tool_call.name)
|
tools_used.append(tc.name)
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
|
||||||
|
# Re-bind tool context right before execution so that
|
||||||
|
# concurrent sessions don't clobber each other's routing.
|
||||||
|
self._set_tool_context(channel, chat_id, message_id)
|
||||||
|
|
||||||
|
# Execute all tool calls concurrently — the LLM batches
|
||||||
|
# independent calls in a single response on purpose.
|
||||||
|
# return_exceptions=True ensures all results are collected
|
||||||
|
# even if one tool is cancelled or raises BaseException.
|
||||||
|
results = await asyncio.gather(*(
|
||||||
|
self.tools.execute(tc.name, tc.arguments)
|
||||||
|
for tc in response.tool_calls
|
||||||
|
), return_exceptions=True)
|
||||||
|
|
||||||
|
for tool_call, result in zip(response.tool_calls, results):
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
result = f"Error: {type(result).__name__}: {result}"
|
||||||
messages = self.context.add_tool_result(
|
messages = self.context.add_tool_result(
|
||||||
messages, tool_call.id, tool_call.name, result
|
messages, tool_call.id, tool_call.name, result
|
||||||
)
|
)
|
||||||
@@ -337,8 +363,10 @@ class AgentLoop:
|
|||||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message under the global lock."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
async with self._processing_lock:
|
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||||
|
gate = self._concurrency_gate or nullcontext()
|
||||||
|
async with lock, gate:
|
||||||
try:
|
try:
|
||||||
on_stream = on_stream_end = None
|
on_stream = on_stream_end = None
|
||||||
if msg.metadata.get("_wants_stream"):
|
if msg.metadata.get("_wants_stream"):
|
||||||
@@ -422,7 +450,10 @@ class AgentLoop:
|
|||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
|
messages, channel=channel, chat_id=chat_id,
|
||||||
|
message_id=msg.metadata.get("message_id"),
|
||||||
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
@@ -469,6 +500,8 @@ class AgentLoop:
|
|||||||
on_progress=on_progress or _bus_progress,
|
on_progress=on_progress or _bus_progress,
|
||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None:
|
if final_content is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user