diff --git a/.gitignore b/.gitignore index 0d392d3..8d88b4d 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ poetry.lock .pytest_cache/ botpy.log nano.*.save - +.DS_Store +uv.lock diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 9e58f27..42bb5ed 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -117,6 +117,7 @@ class AgentLoop: self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks + self._background_tasks: list[asyncio.Task] = [] self._processing_lock = asyncio.Lock() self.memory_consolidator = MemoryConsolidator( workspace=workspace, @@ -536,7 +537,10 @@ class AgentLoop: ) async def close_mcp(self) -> None: - """Close MCP connections.""" + """Drain pending background archives, then close MCP connections.""" + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() if self._mcp_stack: try: await self._mcp_stack.aclose() @@ -544,6 +548,12 @@ class AgentLoop: pass # MCP SDK cancel scope cleanup is noisy but harmless self._mcp_stack = None + def _schedule_background(self, coro) -> None: + """Schedule a coroutine as a tracked background task (drained on shutdown).""" + task = asyncio.create_task(coro) + self._background_tasks.append(task) + task.add_done_callback(self._background_tasks.remove) + def stop(self) -> None: """Stop the agent loop.""" self._running = False @@ -579,7 +589,7 @@ class AgentLoop: final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -594,24 +604,14 @@ class AgentLoop: # Slash commands cmd = self._command_name(msg.content) if cmd == "/new": - try: - if not await self.memory_consolidator.archive_unconsolidated(session): - return OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=text(language, "memory_archival_failed_session"), - ) - except Exception: - logger.exception("/new archival failed for {}", session.key) - return OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=text(language, "memory_archival_failed_session"), - ) - + snapshot = session.messages[session.last_consolidated:] session.clear() self.sessions.save(session) self.sessions.invalidate(session.key) + + if snapshot: + self._schedule_background(self.memory_consolidator.archive_messages(session, snapshot)) + return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=text(language, "new_session_started")) if cmd in {"/lang", "/language"}: @@ -657,7 +657,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 2f306d9..ae003a0 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +import contextvars import json import weakref +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -77,10 +79,13 @@ def _is_tool_choice_unsupported(content: str | None) -> bool: class MemoryStore: """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" + _MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3 + def __init__(self, workspace: Path): self.memory_dir = ensure_dir(workspace / "memory") self.memory_file = self.memory_dir / "MEMORY.md" self.history_file = self.memory_dir / "HISTORY.md" + self._consecutive_failures = 0 def read_long_term(self) -> str: if self.memory_file.exists(): @@ -162,25 +167,60 @@ class MemoryStore: len(response.content or ""), (response.content or "")[:200], ) - return False + return self._fail_or_raw_archive(messages) args = _normalize_save_memory_args(response.tool_calls[0].arguments) if args is None: logger.warning("Memory consolidation: unexpected save_memory arguments") - return False + return self._fail_or_raw_archive(messages) - if entry := args.get("history_entry"): - self.append_history(_ensure_text(entry)) - if update := args.get("memory_update"): - update = _ensure_text(update) - if update != current_memory: - self.write_long_term(update) + if "history_entry" not in args or "memory_update" not in args: + logger.warning("Memory consolidation: save_memory payload missing required fields") + return self._fail_or_raw_archive(messages) + entry = args["history_entry"] + update = args["memory_update"] + + if entry is None or update is None: + logger.warning("Memory consolidation: save_memory payload contains null required fields") + return self._fail_or_raw_archive(messages) + + entry = _ensure_text(entry).strip() + if not entry: + logger.warning("Memory consolidation: history_entry is empty after normalization") + return self._fail_or_raw_archive(messages) + + self.append_history(entry) + update = _ensure_text(update) + if update != current_memory: + self.write_long_term(update) + + self._consecutive_failures = 0 logger.info("Memory consolidation done for {} messages", len(messages)) return True except Exception: logger.exception("Memory consolidation failed") + return self._fail_or_raw_archive(messages) + + def _fail_or_raw_archive(self, messages: list[dict]) -> bool: + """Increment failure count; after threshold, raw-archive messages and return True.""" + self._consecutive_failures += 1 + if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE: return False + self._raw_archive(messages) + self._consecutive_failures = 0 + return True + + def _raw_archive(self, messages: list[dict]) -> None: + """Fallback: dump raw messages to HISTORY.md without LLM summarization.""" + ts = datetime.now().strftime("%Y-%m-%d %H:%M") + self.append_history( + f"[{ts}] [RAW] {len(messages)} messages\n" + f"{self._format_messages(messages)}" + ) + logger.warning( + "Memory consolidation degraded: raw-archived {} messages", len(messages) + ) class MemoryConsolidator: @@ -206,6 +246,11 @@ class MemoryConsolidator: self._build_messages = build_messages self._get_tool_definitions = get_tool_definitions self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + self._stores: dict[Path, MemoryStore] = {} + self._active_session: contextvars.ContextVar[Session | None] = contextvars.ContextVar( + "memory_consolidation_session", + default=None, + ) def _get_persona(self, session: Session) -> str: """Resolve the active persona for a session.""" @@ -219,15 +264,23 @@ class MemoryConsolidator: def _get_store(self, session: Session) -> MemoryStore: """Return the memory store associated with the active persona.""" - return MemoryStore(persona_workspace(self.workspace, self._get_persona(session))) + store_root = persona_workspace(self.workspace, self._get_persona(session)) + return self._stores.setdefault(store_root, MemoryStore(store_root)) + + def _get_default_store(self) -> MemoryStore: + """Return the default persona store for session-less consolidation contexts.""" + store_root = persona_workspace(self.workspace, DEFAULT_PERSONA) + return self._stores.setdefault(store_root, MemoryStore(store_root)) def get_lock(self, session_key: str) -> asyncio.Lock: """Return the shared consolidation lock for one session.""" return self._locks.setdefault(session_key, asyncio.Lock()) - async def consolidate_messages(self, session: Session, messages: list[dict[str, object]]) -> bool: + async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: """Archive a selected message chunk into persistent memory.""" - return await self._get_store(session).consolidate(messages, self.provider, self.model) + session = self._active_session.get() + store = self._get_store(session) if session is not None else self._get_default_store() + return await store.consolidate(messages, self.provider, self.model) def pick_consolidation_boundary( self, @@ -270,14 +323,37 @@ class MemoryConsolidator: self._get_tool_definitions(), ) + async def _archive_messages_locked( + self, + session: Session, + messages: list[dict[str, object]], + ) -> bool: + """Archive messages with guaranteed persistence (retries until raw-dump fallback).""" + if not messages: + return True + token = self._active_session.set(session) + try: + for _ in range(self._get_store(session)._MAX_FAILURES_BEFORE_RAW_ARCHIVE): + if await self.consolidate_messages(messages): + return True + finally: + self._active_session.reset(token) + return True + + async def archive_messages(self, session: Session, messages: list[dict[str, object]]) -> bool: + """Archive messages in the background with session-scoped memory persistence.""" + lock = self.get_lock(session.key) + async with lock: + return await self._archive_messages_locked(session, messages) + async def archive_unconsolidated(self, session: Session) -> bool: - """Archive the full unconsolidated tail for /new-style session rollover.""" + """Archive the full unconsolidated tail for persona switch and similar rollover flows.""" lock = self.get_lock(session.key) async with lock: snapshot = session.messages[session.last_consolidated:] if not snapshot: return True - return await self.consolidate_messages(session, snapshot) + return await self._archive_messages_locked(session, snapshot) async def maybe_consolidate_by_tokens(self, session: Session) -> None: """Loop: archive old messages until prompt fits within half the context window.""" @@ -327,8 +403,12 @@ class MemoryConsolidator: source, len(chunk), ) - if not await self.consolidate_messages(session, chunk): - return + token = self._active_session.set(session) + try: + if not await self.consolidate_messages(chunk): + return + finally: + self._active_session.reset(token) session.last_consolidated = end_idx self.sessions.save(session) diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index f0a6484..f8244e5 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -43,23 +43,52 @@ class Session: self.messages.append(msg) self.updated_at = datetime.now() + @staticmethod + def _find_legal_start(messages: list[dict[str, Any]]) -> int: + """Find first index where every tool result has a matching assistant tool_call.""" + declared: set[str] = set() + start = 0 + for i, msg in enumerate(messages): + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + elif role == "tool": + tid = msg.get("tool_call_id") + if tid and str(tid) not in declared: + start = i + 1 + declared.clear() + for prev in messages[start:i + 1]: + if prev.get("role") == "assistant": + for tc in prev.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + return start + def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: - """Return unconsolidated messages for LLM input, aligned to a user turn.""" + """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" unconsolidated = self.messages[self.last_consolidated:] sliced = unconsolidated[-max_messages:] - # Drop leading non-user messages to avoid orphaned tool_result blocks - for i, m in enumerate(sliced): - if m.get("role") == "user": + # Drop leading non-user messages to avoid starting mid-turn when possible. + for i, message in enumerate(sliced): + if message.get("role") == "user": sliced = sliced[i:] break + # Some providers reject orphan tool results if the matching assistant + # tool_calls message fell outside the fixed-size history window. + start = self._find_legal_start(sliced) + if start: + sliced = sliced[start:] + out: list[dict[str, Any]] = [] - for m in sliced: - entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")} - for k in ("tool_calls", "tool_call_id", "name"): - if k in m: - entry[k] = m[k] + for message in sliced: + entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")} + for key in ("tool_calls", "tool_call_id", "name"): + if key in message: + entry[key] = message[key] out.append(entry) return out diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 7d12338..21e1e78 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -505,7 +505,8 @@ class TestNewCommandArchival: return loop @pytest.mark.asyncio - async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: + async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None: + """/new clears session immediately; archive_messages retries until raw dump.""" from nanobot.bus.events import InboundMessage loop = self._make_loop(tmp_path) @@ -514,9 +515,12 @@ class TestNewCommandArchival: session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - before_count = len(session.messages) + + call_count = 0 async def _failing_consolidate(_messages) -> bool: + nonlocal call_count + call_count += 1 return False loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] @@ -525,8 +529,13 @@ class TestNewCommandArchival: response = await loop._process_message(new_msg) assert response is not None - assert "failed" in response.content.lower() - assert len(loop.sessions.get_or_create("cli:test").messages) == before_count + assert "new session started" in response.content.lower() + + session_after = loop.sessions.get_or_create("cli:test") + assert len(session_after.messages) == 0 + + await loop.close_mcp() + assert call_count == 3 # retried up to raw-archive threshold @pytest.mark.asyncio async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None: @@ -554,6 +563,8 @@ class TestNewCommandArchival: assert response is not None assert "new session started" in response.content.lower() + + await loop.close_mcp() assert archived_count == 3 @pytest.mark.asyncio @@ -578,3 +589,31 @@ class TestNewCommandArchival: assert response is not None assert "new session started" in response.content.lower() assert loop.sessions.get_or_create("cli:test").messages == [] + + @pytest.mark.asyncio + async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None: + """close_mcp waits for background tasks to complete.""" + from nanobot.bus.events import InboundMessage + + loop = self._make_loop(tmp_path) + session = loop.sessions.get_or_create("cli:test") + for i in range(3): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + loop.sessions.save(session) + + archived = asyncio.Event() + + async def _slow_consolidate(_messages) -> bool: + await asyncio.sleep(0.1) + archived.set() + return True + + loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign] + + new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") + await loop._process_message(new_msg) + + assert not archived.is_set() + await loop.close_mcp() + assert archived.is_set() diff --git a/tests/test_session_manager_history.py b/tests/test_session_manager_history.py new file mode 100644 index 0000000..4f56344 --- /dev/null +++ b/tests/test_session_manager_history.py @@ -0,0 +1,146 @@ +from nanobot.session.manager import Session + + +def _assert_no_orphans(history: list[dict]) -> None: + """Assert every tool result in history has a matching assistant tool_call.""" + declared = { + tc["id"] + for m in history if m.get("role") == "assistant" + for tc in (m.get("tool_calls") or []) + } + orphans = [ + m.get("tool_call_id") for m in history + if m.get("role") == "tool" and m.get("tool_call_id") not in declared + ] + assert orphans == [], f"orphan tool_call_ids: {orphans}" + + +def _tool_turn(prefix: str, idx: int) -> list[dict]: + """Helper: one assistant with 2 tool_calls + 2 tool results.""" + return [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"}, + {"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"}, + ] + + +# --- Original regression test (from PR 2075) --- + +def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls(): + session = Session(key="telegram:test") + session.messages.append({"role": "user", "content": "old turn"}) + for i in range(20): + session.messages.extend(_tool_turn("old", i)) + session.messages.append({"role": "user", "content": "problem turn"}) + for i in range(25): + session.messages.extend(_tool_turn("cur", i)) + session.messages.append({"role": "user", "content": "new telegram question"}) + + history = session.get_history(max_messages=100) + _assert_no_orphans(history) + + +# --- Positive test: legitimate pairs survive trimming --- + +def test_legitimate_tool_pairs_preserved_after_trim(): + """Complete tool-call groups within the window must not be dropped.""" + session = Session(key="test:positive") + session.messages.append({"role": "user", "content": "hello"}) + for i in range(5): + session.messages.extend(_tool_turn("ok", i)) + session.messages.append({"role": "assistant", "content": "done"}) + + history = session.get_history(max_messages=500) + _assert_no_orphans(history) + tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"] + assert len(tool_ids) == 10 + assert history[0]["role"] == "user" + + +# --- last_consolidated > 0 --- + +def test_orphan_trim_with_last_consolidated(): + """Orphan trimming works correctly when session is partially consolidated.""" + session = Session(key="test:consolidated") + for i in range(10): + session.messages.append({"role": "user", "content": f"old {i}"}) + session.messages.extend(_tool_turn("cons", i)) + session.last_consolidated = 30 + + session.messages.append({"role": "user", "content": "recent"}) + for i in range(15): + session.messages.extend(_tool_turn("new", i)) + session.messages.append({"role": "user", "content": "latest"}) + + history = session.get_history(max_messages=20) + _assert_no_orphans(history) + assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history) + + +# --- Edge: no tool messages at all --- + +def test_no_tool_messages_unchanged(): + session = Session(key="test:plain") + for i in range(5): + session.messages.append({"role": "user", "content": f"q{i}"}) + session.messages.append({"role": "assistant", "content": f"a{i}"}) + + history = session.get_history(max_messages=6) + assert len(history) == 6 + _assert_no_orphans(history) + + +# --- Edge: all leading messages are orphan tool results --- + +def test_all_orphan_prefix_stripped(): + """If the window starts with orphan tool results and nothing else, they're all dropped.""" + session = Session(key="test:all-orphan") + session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"}) + session.messages.append({"role": "user", "content": "fresh start"}) + session.messages.append({"role": "assistant", "content": "hi"}) + + history = session.get_history(max_messages=500) + _assert_no_orphans(history) + assert history[0]["role"] == "user" + assert len(history) == 2 + + +# --- Edge: empty session --- + +def test_empty_session_history(): + session = Session(key="test:empty") + history = session.get_history(max_messages=500) + assert history == [] + + +# --- Window cuts mid-group: assistant present but some tool results orphaned --- + +def test_window_cuts_mid_tool_group(): + """If the window starts between an assistant's tool results, the partial group is trimmed.""" + session = Session(key="test:mid-cut") + session.messages.append({"role": "user", "content": "setup"}) + session.messages.append({ + "role": "assistant", "content": None, + "tool_calls": [ + {"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + }) + session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"}) + session.messages.append({"role": "user", "content": "next"}) + session.messages.extend(_tool_turn("intact", 0)) + session.messages.append({"role": "assistant", "content": "final"}) + + # Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b, + # leaving orphan tool results for split_a at the front. + history = session.get_history(max_messages=6) + _assert_no_orphans(history)