diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 35a0dfa..c5c8c6d 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -71,6 +71,7 @@ class AgentLoop: "registry.npmjs.org", ) _CLAWHUB_NPM_CACHE_DIR = Path(tempfile.gettempdir()) / "nanobot-npm-cache" + _PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS = 1.5 def __init__( self, @@ -137,7 +138,8 @@ class AgentLoop: self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks - self._background_tasks: list[asyncio.Task] = [] + self._background_tasks: set[asyncio.Task] = set() + self._token_consolidation_tasks: dict[str, asyncio.Task[None]] = {} self._processing_lock = asyncio.Lock() self.memory_consolidator = MemoryConsolidator( workspace=workspace, @@ -933,15 +935,55 @@ class AgentLoop: async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" if self._background_tasks: - await asyncio.gather(*self._background_tasks, return_exceptions=True) + await asyncio.gather(*list(self._background_tasks), return_exceptions=True) self._background_tasks.clear() + self._token_consolidation_tasks.clear() await self._reset_mcp_connections() - def _schedule_background(self, coro) -> None: + def _track_background_task(self, task: asyncio.Task) -> asyncio.Task: + """Track a background task until completion.""" + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + + def _schedule_background(self, coro) -> asyncio.Task: """Schedule a coroutine as a tracked background task (drained on shutdown).""" task = asyncio.create_task(coro) - self._background_tasks.append(task) - task.add_done_callback(self._background_tasks.remove) + return self._track_background_task(task) + + def _ensure_background_token_consolidation(self, session: Session) -> asyncio.Task[None]: + """Ensure at most one token-consolidation task runs per session.""" + existing = self._token_consolidation_tasks.get(session.key) + if existing and not existing.done(): + return existing + + task = asyncio.create_task(self.memory_consolidator.maybe_consolidate_by_tokens(session)) + self._token_consolidation_tasks[session.key] = task + self._track_background_task(task) + + def _cleanup(done: asyncio.Task[None]) -> None: + if self._token_consolidation_tasks.get(session.key) is done: + self._token_consolidation_tasks.pop(session.key, None) + + task.add_done_callback(_cleanup) + return task + + async def _run_preflight_token_consolidation(self, session: Session) -> None: + """Give token consolidation a short head start, then continue in background if needed.""" + task = self._ensure_background_token_consolidation(session) + try: + await asyncio.wait_for( + asyncio.shield(task), + timeout=self._PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS, + ) + except asyncio.TimeoutError: + logger.warning( + "Token consolidation still running for {} after {:.1f}s; continuing in background", + session.key, + self._PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS, + ) + except Exception: + logger.exception("Preflight token consolidation failed for {}", session.key) def stop(self) -> None: """Stop the agent loop.""" @@ -967,7 +1009,7 @@ class AgentLoop: persona = self._get_session_persona(session) language = self._get_session_language(session) await self._connect_mcp() - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self._run_preflight_token_consolidation(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) # Subagent results should be assistant role, other system messages use user role @@ -984,7 +1026,7 @@ class AgentLoop: final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) + self._ensure_background_token_consolidation(session) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -1022,7 +1064,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content="\n".join(help_lines(language)), ) await self._connect_mcp() - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self._run_preflight_token_consolidation(session) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): @@ -1057,7 +1099,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) + self._ensure_background_token_consolidation(session) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index f8244e5..56ad6cc 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -31,6 +31,9 @@ class Session: updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) last_consolidated: int = 0 # Number of messages already consolidated to files + _persisted_message_count: int = field(default=0, init=False, repr=False) + _persisted_metadata_state: str = field(default="", init=False, repr=False) + _requires_full_save: bool = field(default=False, init=False, repr=False) def add_message(self, role: str, content: str, **kwargs: Any) -> None: """Add a message to the session.""" @@ -97,6 +100,7 @@ class Session: self.messages = [] self.last_consolidated = 0 self.updated_at = datetime.now() + self._requires_full_save = True class SessionManager: @@ -178,33 +182,87 @@ class SessionManager: else: messages.append(data) - return Session( + session = Session( key=key, messages=messages, created_at=created_at or datetime.now(), + updated_at=datetime.fromtimestamp(path.stat().st_mtime), metadata=metadata, last_consolidated=last_consolidated ) + self._mark_persisted(session) + return session except Exception as e: logger.warning("Failed to load session {}: {}", key, e) return None + @staticmethod + def _metadata_state(session: Session) -> str: + """Serialize metadata fields that require a checkpoint line.""" + return json.dumps( + { + "key": session.key, + "created_at": session.created_at.isoformat(), + "metadata": session.metadata, + "last_consolidated": session.last_consolidated, + }, + ensure_ascii=False, + sort_keys=True, + ) + + @staticmethod + def _metadata_line(session: Session) -> dict[str, Any]: + """Build a metadata checkpoint record.""" + return { + "_type": "metadata", + "key": session.key, + "created_at": session.created_at.isoformat(), + "updated_at": session.updated_at.isoformat(), + "metadata": session.metadata, + "last_consolidated": session.last_consolidated + } + + @staticmethod + def _write_jsonl_line(handle: Any, payload: dict[str, Any]) -> None: + handle.write(json.dumps(payload, ensure_ascii=False) + "\n") + + def _mark_persisted(self, session: Session) -> None: + session._persisted_message_count = len(session.messages) + session._persisted_metadata_state = self._metadata_state(session) + session._requires_full_save = False + + def _rewrite_session_file(self, path: Path, session: Session) -> None: + with open(path, "w", encoding="utf-8") as f: + self._write_jsonl_line(f, self._metadata_line(session)) + for msg in session.messages: + self._write_jsonl_line(f, msg) + self._mark_persisted(session) + def save(self, session: Session) -> None: """Save a session to disk.""" path = self._get_session_path(session.key) + metadata_state = self._metadata_state(session) + needs_full_rewrite = ( + session._requires_full_save + or not path.exists() + or session._persisted_message_count > len(session.messages) + ) - with open(path, "w", encoding="utf-8") as f: - metadata_line = { - "_type": "metadata", - "key": session.key, - "created_at": session.created_at.isoformat(), - "updated_at": session.updated_at.isoformat(), - "metadata": session.metadata, - "last_consolidated": session.last_consolidated - } - f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n") - for msg in session.messages: - f.write(json.dumps(msg, ensure_ascii=False) + "\n") + if needs_full_rewrite: + session.updated_at = datetime.now() + self._rewrite_session_file(path, session) + else: + new_messages = session.messages[session._persisted_message_count:] + metadata_changed = metadata_state != session._persisted_metadata_state + + if new_messages or metadata_changed: + session.updated_at = datetime.now() + with open(path, "a", encoding="utf-8") as f: + for msg in new_messages: + self._write_jsonl_line(f, msg) + if metadata_changed: + self._write_jsonl_line(f, self._metadata_line(session)) + self._mark_persisted(session) self._cache[session.key] = session @@ -223,19 +281,24 @@ class SessionManager: for path in self.sessions_dir.glob("*.jsonl"): try: - # Read just the metadata line + created_at = None + key = path.stem.replace("_", ":", 1) with open(path, encoding="utf-8") as f: first_line = f.readline().strip() if first_line: data = json.loads(first_line) if data.get("_type") == "metadata": - key = data.get("key") or path.stem.replace("_", ":", 1) - sessions.append({ - "key": key, - "created_at": data.get("created_at"), - "updated_at": data.get("updated_at"), - "path": str(path) - }) + key = data.get("key") or key + created_at = data.get("created_at") + + # Incremental saves append messages without rewriting the first metadata line, + # so use file mtime as the session's latest activity timestamp. + sessions.append({ + "key": key, + "created_at": created_at, + "updated_at": datetime.fromtimestamp(path.stat().st_mtime).isoformat(), + "path": str(path) + }) except Exception: continue diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py index b0f3dda..3e5411f 100644 --- a/tests/test_loop_consolidation_tokens.py +++ b/tests/test_loop_consolidation_tokens.py @@ -1,9 +1,10 @@ +import asyncio from unittest.mock import AsyncMock, MagicMock import pytest -from nanobot.agent.loop import AgentLoop import nanobot.agent.memory as memory_module +from nanobot.agent.loop import AgentLoop from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse @@ -188,3 +189,36 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> assert "consolidate" in order assert "llm" in order assert order.index("consolidate") < order.index("llm") + + +@pytest.mark.asyncio +async def test_slow_preflight_consolidation_continues_in_background(tmp_path, monkeypatch) -> None: + order: list[str] = [] + + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + monkeypatch.setattr(loop, "_PREFLIGHT_CONSOLIDATION_BUDGET_SECONDS", 0.01) + + release = asyncio.Event() + + async def slow_consolidation(_session): + order.append("consolidate-start") + await release.wait() + order.append("consolidate-end") + + async def track_llm(*args, **kwargs): + order.append("llm") + return LLMResponse(content="ok", tool_calls=[]) + + loop.memory_consolidator.maybe_consolidate_by_tokens = slow_consolidation # type: ignore[method-assign] + loop.provider.chat_with_retry = track_llm + + await loop.process_direct("hello", session_key="cli:test") + + assert "consolidate-start" in order + assert "llm" in order + assert "consolidate-end" not in order + + release.set() + await loop.close_mcp() + + assert "consolidate-end" in order diff --git a/tests/test_session_manager_persistence.py b/tests/test_session_manager_persistence.py new file mode 100644 index 0000000..f4c4c9d --- /dev/null +++ b/tests/test_session_manager_persistence.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import json +import os +import time +from datetime import datetime +from pathlib import Path + +from nanobot.session.manager import SessionManager + + +def _read_jsonl(path: Path) -> list[dict]: + return [ + json.loads(line) + for line in path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + + +def test_save_appends_only_new_messages(tmp_path: Path) -> None: + manager = SessionManager(tmp_path) + session = manager.get_or_create("qq:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi") + manager.save(session) + + path = manager._get_session_path(session.key) + original_text = path.read_text(encoding="utf-8") + + session.add_message("user", "next") + manager.save(session) + + lines = _read_jsonl(path) + assert path.read_text(encoding="utf-8").startswith(original_text) + assert sum(1 for line in lines if line.get("_type") == "metadata") == 1 + assert [line["content"] for line in lines if line.get("role")] == ["hello", "hi", "next"] + + +def test_save_appends_metadata_checkpoint_without_rewriting_history(tmp_path: Path) -> None: + manager = SessionManager(tmp_path) + session = manager.get_or_create("qq:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi") + manager.save(session) + + path = manager._get_session_path(session.key) + original_text = path.read_text(encoding="utf-8") + + session.last_consolidated = 2 + manager.save(session) + + lines = _read_jsonl(path) + assert path.read_text(encoding="utf-8").startswith(original_text) + assert sum(1 for line in lines if line.get("_type") == "metadata") == 2 + assert lines[-1]["_type"] == "metadata" + assert lines[-1]["last_consolidated"] == 2 + + manager.invalidate(session.key) + reloaded = manager.get_or_create("qq:test") + assert reloaded.last_consolidated == 2 + assert [message["content"] for message in reloaded.messages] == ["hello", "hi"] + + +def test_clear_rewrites_session_file(tmp_path: Path) -> None: + manager = SessionManager(tmp_path) + session = manager.get_or_create("qq:test") + session.add_message("user", "hello") + session.add_message("assistant", "hi") + manager.save(session) + + path = manager._get_session_path(session.key) + session.clear() + manager.save(session) + + lines = _read_jsonl(path) + assert len(lines) == 1 + assert lines[0]["_type"] == "metadata" + assert lines[0]["last_consolidated"] == 0 + + manager.invalidate(session.key) + reloaded = manager.get_or_create("qq:test") + assert reloaded.messages == [] + assert reloaded.last_consolidated == 0 + + +def test_list_sessions_uses_file_mtime_for_append_only_updates(tmp_path: Path) -> None: + manager = SessionManager(tmp_path) + session = manager.get_or_create("qq:test") + session.add_message("user", "hello") + manager.save(session) + + path = manager._get_session_path(session.key) + stale_time = time.time() - 3600 + os.utime(path, (stale_time, stale_time)) + + before = datetime.fromisoformat(manager.list_sessions()[0]["updated_at"]) + assert before.timestamp() < time.time() - 3000 + + session.add_message("assistant", "hi") + manager.save(session) + + after = datetime.fromisoformat(manager.list_sessions()[0]["updated_at"]) + assert after > before +