fix(memory): validate save_memory payload and raw-archive on repeated failure

- Require both history_entry and memory_update, reject null/empty values
- Fallback to tool_choice=auto when provider rejects forced function call
- After 3 consecutive consolidation failures, raw-archive messages to
  HISTORY.md without LLM summarization to prevent context window overflow
This commit is contained in:
Xubin Ren
2026-03-13 03:53:50 +00:00
parent 60c29702cc
commit 6d3a0ab6c9
2 changed files with 75 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import weakref import weakref
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@@ -74,10 +75,13 @@ def _is_tool_choice_unsupported(content: str | None) -> bool:
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.memory_dir = ensure_dir(workspace / "memory") self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md" self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "HISTORY.md" self.history_file = self.memory_dir / "HISTORY.md"
self._consecutive_failures = 0
def read_long_term(self) -> str: def read_long_term(self) -> str:
if self.memory_file.exists(): if self.memory_file.exists():
@@ -159,39 +163,60 @@ class MemoryStore:
len(response.content or ""), len(response.content or ""),
(response.content or "")[:200], (response.content or "")[:200],
) )
return False return self._fail_or_raw_archive(messages)
args = _normalize_save_memory_args(response.tool_calls[0].arguments) args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None: if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments") logger.warning("Memory consolidation: unexpected save_memory arguments")
return False return self._fail_or_raw_archive(messages)
if "history_entry" not in args or "memory_update" not in args: if "history_entry" not in args or "memory_update" not in args:
logger.warning("Memory consolidation: save_memory payload missing required fields") logger.warning("Memory consolidation: save_memory payload missing required fields")
return False return self._fail_or_raw_archive(messages)
entry = args["history_entry"] entry = args["history_entry"]
update = args["memory_update"] update = args["memory_update"]
if entry is None or update is None: if entry is None or update is None:
logger.warning("Memory consolidation: save_memory payload contains null required fields") logger.warning("Memory consolidation: save_memory payload contains null required fields")
return False return self._fail_or_raw_archive(messages)
entry = _ensure_text(entry).strip() entry = _ensure_text(entry).strip()
if not entry: if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization") logger.warning("Memory consolidation: history_entry is empty after normalization")
return False return self._fail_or_raw_archive(messages)
self.append_history(entry) self.append_history(entry)
update = _ensure_text(update) update = _ensure_text(update)
if update != current_memory: if update != current_memory:
self.write_long_term(update) self.write_long_term(update)
self._consecutive_failures = 0
logger.info("Memory consolidation done for {} messages", len(messages)) logger.info("Memory consolidation done for {} messages", len(messages))
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return self._fail_or_raw_archive(messages)
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
"""Increment failure count; after threshold, raw-archive messages and return True."""
self._consecutive_failures += 1
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
return False return False
self._raw_archive(messages)
self._consecutive_failures = 0
return True
def _raw_archive(self, messages: list[dict]) -> None:
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
self.append_history(
f"[{ts}] [RAW] {len(messages)} messages\n"
f"{self._format_messages(messages)}"
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
class MemoryConsolidator: class MemoryConsolidator:

View File

@@ -431,3 +431,48 @@ class TestMemoryConsolidationTypeHandling:
assert result is False assert result is False
assert not store.history_file.exists() assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1