fix(session): avoid blocking large chat cleanup
This commit is contained in:
@@ -71,6 +71,7 @@ class AgentLoop:
|
|||||||
"registry.npmjs.org",
|
"registry.npmjs.org",
|
||||||
)
|
)
|
||||||
_CLAWHUB_NPM_CACHE_DIR = Path(tempfile.gettempdir()) / "nanobot-npm-cache"
|
_CLAWHUB_NPM_CACHE_DIR = Path(tempfile.gettempdir()) / "nanobot-npm-cache"
|
||||||
|
_PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS = 1.5
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -137,7 +138,8 @@ class AgentLoop:
|
|||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._background_tasks: list[asyncio.Task] = []
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
self._token_consolidation_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
self.memory_consolidator = MemoryConsolidator(
|
self.memory_consolidator = MemoryConsolidator(
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
@@ -933,15 +935,55 @@ class AgentLoop:
|
|||||||
async def close_mcp(self) -> None:
    """Wait for tracked background tasks, then reset the MCP connections.

    Every task currently in ``self._background_tasks`` is awaited with
    ``return_exceptions=True`` so a failing task cannot interrupt shutdown.
    The per-session token-consolidation registry is emptied afterwards.
    """
    pending = list(self._background_tasks)  # snapshot: done-callbacks mutate the set
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
        self._background_tasks.clear()
    self._token_consolidation_tasks.clear()
    await self._reset_mcp_connections()
|
||||||
|
|
||||||
def _schedule_background(self, coro) -> None:
|
def _track_background_task(self, task: asyncio.Task) -> asyncio.Task:
    """Register *task* in ``self._background_tasks`` until it completes.

    A done-callback discards the task from the set again, so the set only
    holds tasks that are still running. Returns the same task object.
    """
    tracked = self._background_tasks
    tracked.add(task)
    # ``discard`` (not ``remove``) so an already-cleared set does not raise.
    task.add_done_callback(tracked.discard)
    return task
|
||||||
|
|
||||||
|
def _schedule_background(self, coro) -> asyncio.Task:
|
||||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||||
task = asyncio.create_task(coro)
|
task = asyncio.create_task(coro)
|
||||||
self._background_tasks.append(task)
|
return self._track_background_task(task)
|
||||||
task.add_done_callback(self._background_tasks.remove)
|
|
||||||
|
def _ensure_background_token_consolidation(self, session: Session) -> asyncio.Task[None]:
    """Return the session's running token-consolidation task, starting one if needed.

    If a task is registered for ``session.key`` and has not finished, it is
    returned unchanged. Otherwise a new task wrapping
    ``memory_consolidator.maybe_consolidate_by_tokens(session)`` is created,
    registered under ``session.key`` and tracked as a background task.
    """
    running = self._token_consolidation_tasks.get(session.key)
    if running and not running.done():
        return running

    consolidation = asyncio.create_task(
        self.memory_consolidator.maybe_consolidate_by_tokens(session)
    )
    self._token_consolidation_tasks[session.key] = consolidation
    self._track_background_task(consolidation)

    def _forget_when_done(finished: asyncio.Task[None]) -> None:
        # Only drop the registry entry if it still points at this task; a
        # newer task registered under the same key must stay registered.
        if self._token_consolidation_tasks.get(session.key) is not finished:
            return
        self._token_consolidation_tasks.pop(session.key, None)

    consolidation.add_done_callback(_forget_when_done)
    return consolidation
|
||||||
|
|
||||||
|
async def _run_preflight_token_consolidation(self, session: Session) -> None:
    """Wait briefly for token consolidation, then let it continue in the background.

    The consolidation task is awaited for at most
    ``_PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS``. It is wrapped in
    ``asyncio.shield`` so that hitting the timeout cancels only the wait,
    not the task itself. A timeout is logged as a warning; any other
    exception is logged and swallowed.
    """
    consolidation = self._ensure_background_token_consolidation(session)
    try:
        budget = self._PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS
        shielded = asyncio.shield(consolidation)
        await asyncio.wait_for(shielded, timeout=budget)
    except asyncio.TimeoutError:
        logger.warning(
            "Token consolidation still running for {} after {:.1f}s; continuing in background",
            session.key,
            self._PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS,
        )
    except Exception:
        logger.exception("Preflight token consolidation failed for {}", session.key)
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the agent loop."""
|
"""Stop the agent loop."""
|
||||||
@@ -967,7 +1009,7 @@ class AgentLoop:
|
|||||||
persona = self._get_session_persona(session)
|
persona = self._get_session_persona(session)
|
||||||
language = self._get_session_language(session)
|
language = self._get_session_language(session)
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self._run_preflight_token_consolidation(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
# Subagent results should be assistant role, other system messages use user role
|
# Subagent results should be assistant role, other system messages use user role
|
||||||
@@ -984,7 +1026,7 @@ class AgentLoop:
|
|||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._ensure_background_token_consolidation(session)
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@@ -1022,7 +1064,7 @@ class AgentLoop:
|
|||||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(help_lines(language)),
|
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(help_lines(language)),
|
||||||
)
|
)
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self._run_preflight_token_consolidation(session)
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
@@ -1057,7 +1099,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._ensure_background_token_consolidation(session)
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -31,6 +31,9 @@ class Session:
|
|||||||
updated_at: datetime = field(default_factory=datetime.now)
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||||
|
_persisted_message_count: int = field(default=0, init=False, repr=False)
|
||||||
|
_persisted_metadata_state: str = field(default="", init=False, repr=False)
|
||||||
|
_requires_full_save: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||||
"""Add a message to the session."""
|
"""Add a message to the session."""
|
||||||
@@ -97,6 +100,7 @@ class Session:
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
self.last_consolidated = 0
|
self.last_consolidated = 0
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
self._requires_full_save = True
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
@@ -178,23 +182,38 @@ class SessionManager:
|
|||||||
else:
|
else:
|
||||||
messages.append(data)
|
messages.append(data)
|
||||||
|
|
||||||
return Session(
|
session = Session(
|
||||||
key=key,
|
key=key,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
created_at=created_at or datetime.now(),
|
created_at=created_at or datetime.now(),
|
||||||
|
updated_at=datetime.fromtimestamp(path.stat().st_mtime),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
last_consolidated=last_consolidated
|
last_consolidated=last_consolidated
|
||||||
)
|
)
|
||||||
|
self._mark_persisted(session)
|
||||||
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load session {}: {}", key, e)
|
logger.warning("Failed to load session {}: {}", key, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def save(self, session: Session) -> None:
|
@staticmethod
|
||||||
"""Save a session to disk."""
|
def _metadata_state(session: Session) -> str:
|
||||||
path = self._get_session_path(session.key)
|
"""Serialize metadata fields that require a checkpoint line."""
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"key": session.key,
|
||||||
|
"created_at": session.created_at.isoformat(),
|
||||||
|
"metadata": session.metadata,
|
||||||
|
"last_consolidated": session.last_consolidated,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
sort_keys=True,
|
||||||
|
)
|
||||||
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
@staticmethod
|
||||||
metadata_line = {
|
def _metadata_line(session: Session) -> dict[str, Any]:
|
||||||
|
"""Build a metadata checkpoint record."""
|
||||||
|
return {
|
||||||
"_type": "metadata",
|
"_type": "metadata",
|
||||||
"key": session.key,
|
"key": session.key,
|
||||||
"created_at": session.created_at.isoformat(),
|
"created_at": session.created_at.isoformat(),
|
||||||
@@ -202,9 +221,48 @@ class SessionManager:
|
|||||||
"metadata": session.metadata,
|
"metadata": session.metadata,
|
||||||
"last_consolidated": session.last_consolidated
|
"last_consolidated": session.last_consolidated
|
||||||
}
|
}
|
||||||
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
|
||||||
|
@staticmethod
def _write_jsonl_line(handle: Any, payload: dict[str, Any]) -> None:
    """Write *payload* to *handle* as a single JSON line.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) and the
    line is terminated with ``"\\n"``.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    handle.write(f"{encoded}\n")
|
||||||
|
|
||||||
|
def _mark_persisted(self, session: Session) -> None:
    """Record on *session* that its current messages and metadata are on disk.

    Stores the current message count and serialized metadata state on the
    session and clears its ``_requires_full_save`` flag.
    """
    message_total = len(session.messages)
    session._persisted_message_count = message_total
    state_snapshot = self._metadata_state(session)
    session._persisted_metadata_state = state_snapshot
    session._requires_full_save = False
|
||||||
|
|
||||||
|
def _rewrite_session_file(self, path: Path, session: Session) -> None:
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
self._write_jsonl_line(f, self._metadata_line(session))
|
||||||
for msg in session.messages:
|
for msg in session.messages:
|
||||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
self._write_jsonl_line(f, msg)
|
||||||
|
self._mark_persisted(session)
|
||||||
|
|
||||||
|
def save(self, session: Session) -> None:
    """Save a session to disk, appending when possible instead of rewriting.

    The whole file is rewritten when the session is flagged for a full
    save, when the file does not exist, when the session has never been
    persisted or loaded by this manager, or when the session now holds
    fewer messages than were last persisted. Otherwise only the messages
    added since the last save are appended, followed by a metadata
    checkpoint line if the metadata changed. ``session.updated_at`` is
    refreshed whenever something is written, and the session is always
    stored in the in-memory cache.
    """
    path = self._get_session_path(session.key)
    metadata_state = self._metadata_state(session)
    needs_full_rewrite = (
        session._requires_full_save
        or not path.exists()
        # An empty persisted state means this object was never loaded from or
        # written to disk. If a file exists anyway (e.g. a corrupt file that
        # failed to load), appending would stack the session on top of stale
        # content, so overwrite it instead.
        or not session._persisted_metadata_state
        or session._persisted_message_count > len(session.messages)
    )

    if needs_full_rewrite:
        session.updated_at = datetime.now()
        self._rewrite_session_file(path, session)
    else:
        new_messages = session.messages[session._persisted_message_count:]
        metadata_changed = metadata_state != session._persisted_metadata_state

        if new_messages or metadata_changed:
            session.updated_at = datetime.now()
            with open(path, "a", encoding="utf-8") as f:
                for msg in new_messages:
                    self._write_jsonl_line(f, msg)
                # The checkpoint goes after the new messages so the last
                # metadata line in the file reflects the latest state.
                if metadata_changed:
                    self._write_jsonl_line(f, self._metadata_line(session))
            self._mark_persisted(session)

    self._cache[session.key] = session
|
||||||
|
|
||||||
@@ -223,17 +281,22 @@ class SessionManager:
|
|||||||
|
|
||||||
for path in self.sessions_dir.glob("*.jsonl"):
|
for path in self.sessions_dir.glob("*.jsonl"):
|
||||||
try:
|
try:
|
||||||
# Read just the metadata line
|
created_at = None
|
||||||
|
key = path.stem.replace("_", ":", 1)
|
||||||
with open(path, encoding="utf-8") as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
first_line = f.readline().strip()
|
first_line = f.readline().strip()
|
||||||
if first_line:
|
if first_line:
|
||||||
data = json.loads(first_line)
|
data = json.loads(first_line)
|
||||||
if data.get("_type") == "metadata":
|
if data.get("_type") == "metadata":
|
||||||
key = data.get("key") or path.stem.replace("_", ":", 1)
|
key = data.get("key") or key
|
||||||
|
created_at = data.get("created_at")
|
||||||
|
|
||||||
|
# Incremental saves append messages without rewriting the first metadata line,
|
||||||
|
# so use file mtime as the session's latest activity timestamp.
|
||||||
sessions.append({
|
sessions.append({
|
||||||
"key": key,
|
"key": key,
|
||||||
"created_at": data.get("created_at"),
|
"created_at": created_at,
|
||||||
"updated_at": data.get("updated_at"),
|
"updated_at": datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
|
||||||
"path": str(path)
|
"path": str(path)
|
||||||
})
|
})
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
|
||||||
import nanobot.agent.memory as memory_module
|
import nanobot.agent.memory as memory_module
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
@@ -188,3 +189,36 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
|||||||
assert "consolidate" in order
|
assert "consolidate" in order
|
||||||
assert "llm" in order
|
assert "llm" in order
|
||||||
assert order.index("consolidate") < order.index("llm")
|
assert order.index("consolidate") < order.index("llm")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_slow_preflight_consolidation_continues_in_background(tmp_path, monkeypatch) -> None:
    """A consolidation slower than the preflight budget must not block the LLM call.

    The consolidation coroutine is held on an event, so it can only finish
    after the test releases it and drains background tasks via ``close_mcp``.
    """
    events: list[str] = []

    loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    # Shrink the budget so the preflight wait times out almost immediately.
    monkeypatch.setattr(loop, "_PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS", 0.01)

    gate = asyncio.Event()

    async def slow_consolidation(_session):
        events.append("consolidate-start")
        await gate.wait()
        events.append("consolidate-end")

    async def track_llm(*args, **kwargs):
        events.append("llm")
        return LLMResponse(content="ok", tool_calls=[])

    loop.memory_consolidator.maybe_consolidate_by_tokens = slow_consolidation  # type: ignore[method-assign]
    loop.provider.chat_with_retry = track_llm

    await loop.process_direct("hello", session_key="cli:test")

    # Consolidation started and the LLM ran, but consolidation has not finished.
    for expected in ("consolidate-start", "llm"):
        assert expected in events
    assert "consolidate-end" not in events

    gate.set()
    await loop.close_mcp()

    assert "consolidate-end" in events
|
||||||
|
|||||||
104
tests/test_session_manager_persistence.py
Normal file
104
tests/test_session_manager_persistence.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
def _read_jsonl(path: Path) -> list[dict]:
    """Parse a UTF-8 JSONL file into a list of records, skipping blank lines."""
    records: list[dict] = []
    for raw in path.read_text(encoding="utf-8").splitlines():
        if not raw.strip():
            continue
        records.append(json.loads(raw))
    return records
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_appends_only_new_messages(tmp_path: Path) -> None:
    """A second save keeps the existing file content and appends only the new message."""
    manager = SessionManager(tmp_path)
    session = manager.get_or_create("qq:test")
    for role, text in (("user", "hello"), ("assistant", "hi")):
        session.add_message(role, text)
    manager.save(session)

    path = manager._get_session_path(session.key)
    text_before = path.read_text(encoding="utf-8")

    session.add_message("user", "next")
    manager.save(session)

    records = _read_jsonl(path)
    text_after = path.read_text(encoding="utf-8")
    metadata_records = [rec for rec in records if rec.get("_type") == "metadata"]
    contents = [rec["content"] for rec in records if rec.get("role")]

    # The earlier bytes are an untouched prefix of the new file.
    assert text_after.startswith(text_before)
    assert len(metadata_records) == 1
    assert contents == ["hello", "hi", "next"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_appends_metadata_checkpoint_without_rewriting_history(tmp_path: Path) -> None:
    """A metadata-only change appends a checkpoint line and survives a reload."""
    manager = SessionManager(tmp_path)
    session = manager.get_or_create("qq:test")
    for role, text in (("user", "hello"), ("assistant", "hi")):
        session.add_message(role, text)
    manager.save(session)

    path = manager._get_session_path(session.key)
    text_before = path.read_text(encoding="utf-8")

    session.last_consolidated = 2
    manager.save(session)

    records = _read_jsonl(path)
    text_after = path.read_text(encoding="utf-8")
    metadata_records = [rec for rec in records if rec.get("_type") == "metadata"]
    newest = records[-1]

    assert text_after.startswith(text_before)
    assert len(metadata_records) == 2
    assert newest["_type"] == "metadata"
    assert newest["last_consolidated"] == 2

    # Drop the cached object so the next lookup reads the file again.
    manager.invalidate(session.key)
    reloaded = manager.get_or_create("qq:test")
    assert reloaded.last_consolidated == 2
    assert [entry["content"] for entry in reloaded.messages] == ["hello", "hi"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_rewrites_session_file(tmp_path: Path) -> None:
    """Saving after ``clear()`` leaves a file holding only a fresh metadata line."""
    manager = SessionManager(tmp_path)
    session = manager.get_or_create("qq:test")
    for role, text in (("user", "hello"), ("assistant", "hi")):
        session.add_message(role, text)
    manager.save(session)

    path = manager._get_session_path(session.key)
    session.clear()
    manager.save(session)

    records = _read_jsonl(path)
    assert len(records) == 1
    only_record = records[0]
    assert only_record["_type"] == "metadata"
    assert only_record["last_consolidated"] == 0

    # Drop the cached object so the next lookup reads the file again.
    manager.invalidate(session.key)
    reloaded = manager.get_or_create("qq:test")
    assert reloaded.messages == []
    assert reloaded.last_consolidated == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_sessions_uses_file_mtime_for_append_only_updates(tmp_path: Path) -> None:
    """``list_sessions`` reports a newer ``updated_at`` after an appending save."""
    manager = SessionManager(tmp_path)
    session = manager.get_or_create("qq:test")
    session.add_message("user", "hello")
    manager.save(session)

    # Backdate the file by an hour so the next save is clearly newer.
    path = manager._get_session_path(session.key)
    an_hour_ago = time.time() - 3600
    os.utime(path, (an_hour_ago, an_hour_ago))

    listed_before = manager.list_sessions()[0]["updated_at"]
    before = datetime.fromisoformat(listed_before)
    assert before.timestamp() < time.time() - 3000

    session.add_message("assistant", "hi")
    manager.save(session)

    listed_after = manager.list_sessions()[0]["updated_at"]
    after = datetime.fromisoformat(listed_after)
    assert after > before
|
||||||
|
|
||||||
Reference in New Issue
Block a user