Merge remote-tracking branch 'origin/main'
Some checks failed
Test Suite / test (3.12) (push) Has been cancelled
Test Suite / test (3.11) (push) Has been cancelled
Test Suite / test (3.13) (push) Has been cancelled

# Conflicts:
#	.gitignore
#	nanobot/agent/loop.py
#	nanobot/agent/memory.py
This commit is contained in:
Hua
2026-03-16 18:52:43 +08:00
6 changed files with 342 additions and 47 deletions

3
.gitignore vendored
View File

@@ -22,4 +22,5 @@ poetry.lock
.pytest_cache/ .pytest_cache/
botpy.log botpy.log
nano.*.save nano.*.save
.DS_Store
uv.lock

View File

@@ -117,6 +117,7 @@ class AgentLoop:
self._mcp_connected = False self._mcp_connected = False
self._mcp_connecting = False self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._background_tasks: list[asyncio.Task] = []
self._processing_lock = asyncio.Lock() self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator( self.memory_consolidator = MemoryConsolidator(
workspace=workspace, workspace=workspace,
@@ -536,7 +537,10 @@ class AgentLoop:
) )
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
"""Close MCP connections.""" """Drain pending background archives, then close MCP connections."""
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
if self._mcp_stack: if self._mcp_stack:
try: try:
await self._mcp_stack.aclose() await self._mcp_stack.aclose()
@@ -544,6 +548,12 @@ class AgentLoop:
pass # MCP SDK cancel scope cleanup is noisy but harmless pass # MCP SDK cancel scope cleanup is noisy but harmless
self._mcp_stack = None self._mcp_stack = None
def _schedule_background(self, coro) -> None:
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
task = asyncio.create_task(coro)
self._background_tasks.append(task)
task.add_done_callback(self._background_tasks.remove)
def stop(self) -> None: def stop(self) -> None:
"""Stop the agent loop.""" """Stop the agent loop."""
self._running = False self._running = False
@@ -579,7 +589,7 @@ class AgentLoop:
final_content, _, all_msgs = await self._run_agent_loop(messages) final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.") content=final_content or "Background task completed.")
@@ -594,24 +604,14 @@ class AgentLoop:
# Slash commands # Slash commands
cmd = self._command_name(msg.content) cmd = self._command_name(msg.content)
if cmd == "/new": if cmd == "/new":
try: snapshot = session.messages[session.last_consolidated:]
if not await self.memory_consolidator.archive_unconsolidated(session):
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=text(language, "memory_archival_failed_session"),
)
except Exception:
logger.exception("/new archival failed for {}", session.key)
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=text(language, "memory_archival_failed_session"),
)
session.clear() session.clear()
self.sessions.save(session) self.sessions.save(session)
self.sessions.invalidate(session.key) self.sessions.invalidate(session.key)
if snapshot:
self._schedule_background(self.memory_consolidator.archive_messages(session, snapshot))
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content=text(language, "new_session_started")) content=text(language, "new_session_started"))
if cmd in {"/lang", "/language"}: if cmd in {"/lang", "/language"}:
@@ -657,7 +657,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None return None

View File

@@ -3,8 +3,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextvars
import json import json
import weakref import weakref
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@@ -77,10 +79,13 @@ def _is_tool_choice_unsupported(content: str | None) -> bool:
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.memory_dir = ensure_dir(workspace / "memory") self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md" self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "HISTORY.md" self.history_file = self.memory_dir / "HISTORY.md"
self._consecutive_failures = 0
def read_long_term(self) -> str: def read_long_term(self) -> str:
if self.memory_file.exists(): if self.memory_file.exists():
@@ -162,25 +167,60 @@ class MemoryStore:
len(response.content or ""), len(response.content or ""),
(response.content or "")[:200], (response.content or "")[:200],
) )
return False return self._fail_or_raw_archive(messages)
args = _normalize_save_memory_args(response.tool_calls[0].arguments) args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None: if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments") logger.warning("Memory consolidation: unexpected save_memory arguments")
return False return self._fail_or_raw_archive(messages)
if entry := args.get("history_entry"): if "history_entry" not in args or "memory_update" not in args:
self.append_history(_ensure_text(entry)) logger.warning("Memory consolidation: save_memory payload missing required fields")
if update := args.get("memory_update"): return self._fail_or_raw_archive(messages)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
entry = args["history_entry"]
update = args["memory_update"]
if entry is None or update is None:
logger.warning("Memory consolidation: save_memory payload contains null required fields")
return self._fail_or_raw_archive(messages)
entry = _ensure_text(entry).strip()
if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization")
return self._fail_or_raw_archive(messages)
self.append_history(entry)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
self._consecutive_failures = 0
logger.info("Memory consolidation done for {} messages", len(messages)) logger.info("Memory consolidation done for {} messages", len(messages))
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return self._fail_or_raw_archive(messages)
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
"""Increment failure count; after threshold, raw-archive messages and return True."""
self._consecutive_failures += 1
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
return False return False
self._raw_archive(messages)
self._consecutive_failures = 0
return True
def _raw_archive(self, messages: list[dict]) -> None:
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
self.append_history(
f"[{ts}] [RAW] {len(messages)} messages\n"
f"{self._format_messages(messages)}"
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
class MemoryConsolidator: class MemoryConsolidator:
@@ -206,6 +246,11 @@ class MemoryConsolidator:
self._build_messages = build_messages self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
self._stores: dict[Path, MemoryStore] = {}
self._active_session: contextvars.ContextVar[Session | None] = contextvars.ContextVar(
"memory_consolidation_session",
default=None,
)
def _get_persona(self, session: Session) -> str: def _get_persona(self, session: Session) -> str:
"""Resolve the active persona for a session.""" """Resolve the active persona for a session."""
@@ -219,15 +264,23 @@ class MemoryConsolidator:
def _get_store(self, session: Session) -> MemoryStore: def _get_store(self, session: Session) -> MemoryStore:
"""Return the memory store associated with the active persona.""" """Return the memory store associated with the active persona."""
return MemoryStore(persona_workspace(self.workspace, self._get_persona(session))) store_root = persona_workspace(self.workspace, self._get_persona(session))
return self._stores.setdefault(store_root, MemoryStore(store_root))
def _get_default_store(self) -> MemoryStore:
"""Return the default persona store for session-less consolidation contexts."""
store_root = persona_workspace(self.workspace, DEFAULT_PERSONA)
return self._stores.setdefault(store_root, MemoryStore(store_root))
def get_lock(self, session_key: str) -> asyncio.Lock: def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session.""" """Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock()) return self._locks.setdefault(session_key, asyncio.Lock())
async def consolidate_messages(self, session: Session, messages: list[dict[str, object]]) -> bool: async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
"""Archive a selected message chunk into persistent memory.""" """Archive a selected message chunk into persistent memory."""
return await self._get_store(session).consolidate(messages, self.provider, self.model) session = self._active_session.get()
store = self._get_store(session) if session is not None else self._get_default_store()
return await store.consolidate(messages, self.provider, self.model)
def pick_consolidation_boundary( def pick_consolidation_boundary(
self, self,
@@ -270,14 +323,37 @@ class MemoryConsolidator:
self._get_tool_definitions(), self._get_tool_definitions(),
) )
async def _archive_messages_locked(
self,
session: Session,
messages: list[dict[str, object]],
) -> bool:
"""Archive messages with guaranteed persistence (retries until raw-dump fallback)."""
if not messages:
return True
token = self._active_session.set(session)
try:
for _ in range(self._get_store(session)._MAX_FAILURES_BEFORE_RAW_ARCHIVE):
if await self.consolidate_messages(messages):
return True
finally:
self._active_session.reset(token)
return True
async def archive_messages(self, session: Session, messages: list[dict[str, object]]) -> bool:
"""Archive messages in the background with session-scoped memory persistence."""
lock = self.get_lock(session.key)
async with lock:
return await self._archive_messages_locked(session, messages)
async def archive_unconsolidated(self, session: Session) -> bool: async def archive_unconsolidated(self, session: Session) -> bool:
"""Archive the full unconsolidated tail for /new-style session rollover.""" """Archive the full unconsolidated tail for persona switch and similar rollover flows."""
lock = self.get_lock(session.key) lock = self.get_lock(session.key)
async with lock: async with lock:
snapshot = session.messages[session.last_consolidated:] snapshot = session.messages[session.last_consolidated:]
if not snapshot: if not snapshot:
return True return True
return await self.consolidate_messages(session, snapshot) return await self._archive_messages_locked(session, snapshot)
async def maybe_consolidate_by_tokens(self, session: Session) -> None: async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within half the context window.""" """Loop: archive old messages until prompt fits within half the context window."""
@@ -327,8 +403,12 @@ class MemoryConsolidator:
source, source,
len(chunk), len(chunk),
) )
if not await self.consolidate_messages(session, chunk): token = self._active_session.set(session)
return try:
if not await self.consolidate_messages(chunk):
return
finally:
self._active_session.reset(token)
session.last_consolidated = end_idx session.last_consolidated = end_idx
self.sessions.save(session) self.sessions.save(session)

View File

@@ -43,23 +43,52 @@ class Session:
self.messages.append(msg) self.messages.append(msg)
self.updated_at = datetime.now() self.updated_at = datetime.now()
@staticmethod
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
"""Find first index where every tool result has a matching assistant tool_call."""
declared: set[str] = set()
start = 0
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
start = i + 1
declared.clear()
for prev in messages[start:i + 1]:
if prev.get("role") == "assistant":
for tc in prev.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
return start
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a user turn.""" """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:] unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:] sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid orphaned tool_result blocks # Drop leading non-user messages to avoid starting mid-turn when possible.
for i, m in enumerate(sliced): for i, message in enumerate(sliced):
if m.get("role") == "user": if message.get("role") == "user":
sliced = sliced[i:] sliced = sliced[i:]
break break
# Some providers reject orphan tool results if the matching assistant
# tool_calls message fell outside the fixed-size history window.
start = self._find_legal_start(sliced)
if start:
sliced = sliced[start:]
out: list[dict[str, Any]] = [] out: list[dict[str, Any]] = []
for m in sliced: for message in sliced:
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")} entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")}
for k in ("tool_calls", "tool_call_id", "name"): for key in ("tool_calls", "tool_call_id", "name"):
if k in m: if key in message:
entry[k] = m[k] entry[key] = message[key]
out.append(entry) out.append(entry)
return out return out

View File

@@ -505,7 +505,8 @@ class TestNewCommandArchival:
return loop return loop
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None:
"""/new clears session immediately; archive_messages retries until raw dump."""
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path) loop = self._make_loop(tmp_path)
@@ -514,9 +515,12 @@ class TestNewCommandArchival:
session.add_message("user", f"msg{i}") session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
loop.sessions.save(session) loop.sessions.save(session)
before_count = len(session.messages)
call_count = 0
async def _failing_consolidate(_messages) -> bool: async def _failing_consolidate(_messages) -> bool:
nonlocal call_count
call_count += 1
return False return False
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
@@ -525,8 +529,13 @@ class TestNewCommandArchival:
response = await loop._process_message(new_msg) response = await loop._process_message(new_msg)
assert response is not None assert response is not None
assert "failed" in response.content.lower() assert "new session started" in response.content.lower()
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == 0
await loop.close_mcp()
assert call_count == 3 # retried up to raw-archive threshold
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None: async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
@@ -554,6 +563,8 @@ class TestNewCommandArchival:
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
await loop.close_mcp()
assert archived_count == 3 assert archived_count == 3
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -578,3 +589,31 @@ class TestNewCommandArchival:
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
assert loop.sessions.get_or_create("cli:test").messages == [] assert loop.sessions.get_or_create("cli:test").messages == []
@pytest.mark.asyncio
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
"""close_mcp waits for background tasks to complete."""
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
archived = asyncio.Event()
async def _slow_consolidate(_messages) -> bool:
await asyncio.sleep(0.1)
archived.set()
return True
loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
assert not archived.is_set()
await loop.close_mcp()
assert archived.is_set()

View File

@@ -0,0 +1,146 @@
from nanobot.session.manager import Session
def _assert_no_orphans(history: list[dict]) -> None:
"""Assert every tool result in history has a matching assistant tool_call."""
declared = {
tc["id"]
for m in history if m.get("role") == "assistant"
for tc in (m.get("tool_calls") or [])
}
orphans = [
m.get("tool_call_id") for m in history
if m.get("role") == "tool" and m.get("tool_call_id") not in declared
]
assert orphans == [], f"orphan tool_call_ids: {orphans}"
def _tool_turn(prefix: str, idx: int) -> list[dict]:
"""Helper: one assistant with 2 tool_calls + 2 tool results."""
return [
{
"role": "assistant",
"content": None,
"tool_calls": [
{"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
{"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"},
]
# --- Original regression test (from PR 2075) ---
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
    """A fixed-size window that splits tool turns must yield no orphan results."""
    session = Session(key="telegram:test")
    session.messages.append({"role": "user", "content": "old turn"})
    for turn_idx in range(20):
        session.messages.extend(_tool_turn("old", turn_idx))
    session.messages.append({"role": "user", "content": "problem turn"})
    for turn_idx in range(25):
        session.messages.extend(_tool_turn("cur", turn_idx))
    session.messages.append({"role": "user", "content": "new telegram question"})
    _assert_no_orphans(session.get_history(max_messages=100))
# --- Positive test: legitimate pairs survive trimming ---
def test_legitimate_tool_pairs_preserved_after_trim():
    """Complete tool-call groups within the window must not be dropped."""
    session = Session(key="test:positive")
    session.messages.append({"role": "user", "content": "hello"})
    for turn_idx in range(5):
        session.messages.extend(_tool_turn("ok", turn_idx))
    session.messages.append({"role": "assistant", "content": "done"})
    history = session.get_history(max_messages=500)
    _assert_no_orphans(history)
    # 5 turns x 2 tool results each must all survive.
    result_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"]
    assert len(result_ids) == 10
    assert history[0]["role"] == "user"
# --- last_consolidated > 0 ---
def test_orphan_trim_with_last_consolidated():
    """Orphan trimming works correctly when session is partially consolidated."""
    session = Session(key="test:consolidated")
    for i in range(10):
        session.messages.append({"role": "user", "content": f"old {i}"})
        session.messages.extend(_tool_turn("cons", i))
    session.last_consolidated = 30
    session.messages.append({"role": "user", "content": "recent"})
    for i in range(15):
        session.messages.extend(_tool_turn("new", i))
    session.messages.append({"role": "user", "content": "latest"})
    history = session.get_history(max_messages=20)
    _assert_no_orphans(history)
    # Only post-consolidation ("new_") tool results may appear in the window.
    for message in history:
        if message.get("role") == "tool":
            assert message["tool_call_id"].startswith("new_")
# --- Edge: no tool messages at all ---
def test_no_tool_messages_unchanged():
    """Plain user/assistant alternation passes through the window untouched."""
    session = Session(key="test:plain")
    for i in range(5):
        session.messages += [
            {"role": "user", "content": f"q{i}"},
            {"role": "assistant", "content": f"a{i}"},
        ]
    history = session.get_history(max_messages=6)
    assert len(history) == 6
    _assert_no_orphans(history)
# --- Edge: all leading messages are orphan tool results ---
def test_all_orphan_prefix_stripped():
    """If the window starts with orphan tool results and nothing else, they're all dropped."""
    session = Session(key="test:all-orphan")
    session.messages.extend([
        {"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"},
        {"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"},
    ])
    session.messages.append({"role": "user", "content": "fresh start"})
    session.messages.append({"role": "assistant", "content": "hi"})
    history = session.get_history(max_messages=500)
    _assert_no_orphans(history)
    assert history[0]["role"] == "user"
    assert len(history) == 2
# --- Edge: empty session ---
def test_empty_session_history():
    """A session with no messages yields an empty history list."""
    empty_session = Session(key="test:empty")
    assert empty_session.get_history(max_messages=500) == []
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
def test_window_cuts_mid_tool_group():
    """If the window starts between an assistant's tool results, the partial group is trimmed."""
    session = Session(key="test:mid-cut")
    split_assistant = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
            {"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
        ],
    }
    session.messages.append({"role": "user", "content": "setup"})
    session.messages.append(split_assistant)
    session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"})
    session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"})
    session.messages.append({"role": "user", "content": "next"})
    session.messages.extend(_tool_turn("intact", 0))
    session.messages.append({"role": "assistant", "content": "final"})
    # Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b,
    # leaving orphan tool results for split_a at the front.
    _assert_no_orphans(session.get_history(max_messages=6))