merge: PR #1810 validate save_memory payload + raw-archive fallback

This commit is contained in:
Xubin Ren
2026-03-13 03:54:53 +00:00
2 changed files with 180 additions and 10 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import weakref import weakref
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@@ -74,10 +75,13 @@ def _is_tool_choice_unsupported(content: str | None) -> bool:
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
_MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3
def __init__(self, workspace: Path): def __init__(self, workspace: Path):
self.memory_dir = ensure_dir(workspace / "memory") self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md" self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "HISTORY.md" self.history_file = self.memory_dir / "HISTORY.md"
self._consecutive_failures = 0
def read_long_term(self) -> str: def read_long_term(self) -> str:
if self.memory_file.exists(): if self.memory_file.exists():
@@ -159,25 +163,60 @@ class MemoryStore:
len(response.content or ""), len(response.content or ""),
(response.content or "")[:200], (response.content or "")[:200],
) )
return False return self._fail_or_raw_archive(messages)
args = _normalize_save_memory_args(response.tool_calls[0].arguments) args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None: if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments") logger.warning("Memory consolidation: unexpected save_memory arguments")
return False return self._fail_or_raw_archive(messages)
if entry := args.get("history_entry"): if "history_entry" not in args or "memory_update" not in args:
self.append_history(_ensure_text(entry)) logger.warning("Memory consolidation: save_memory payload missing required fields")
if update := args.get("memory_update"): return self._fail_or_raw_archive(messages)
entry = args["history_entry"]
update = args["memory_update"]
if entry is None or update is None:
logger.warning("Memory consolidation: save_memory payload contains null required fields")
return self._fail_or_raw_archive(messages)
entry = _ensure_text(entry).strip()
if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization")
return self._fail_or_raw_archive(messages)
self.append_history(entry)
update = _ensure_text(update) update = _ensure_text(update)
if update != current_memory: if update != current_memory:
self.write_long_term(update) self.write_long_term(update)
self._consecutive_failures = 0
logger.info("Memory consolidation done for {} messages", len(messages)) logger.info("Memory consolidation done for {} messages", len(messages))
return True return True
except Exception: except Exception:
logger.exception("Memory consolidation failed") logger.exception("Memory consolidation failed")
return self._fail_or_raw_archive(messages)
def _fail_or_raw_archive(self, messages: list[dict]) -> bool:
"""Increment failure count; after threshold, raw-archive messages and return True."""
self._consecutive_failures += 1
if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE:
return False return False
self._raw_archive(messages)
self._consecutive_failures = 0
return True
def _raw_archive(self, messages: list[dict]) -> None:
"""Fallback: dump raw messages to HISTORY.md without LLM summarization."""
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
self.append_history(
f"[{ts}] [RAW] {len(messages)} messages\n"
f"{self._format_messages(messages)}"
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
class MemoryConsolidator: class MemoryConsolidator:

View File

@@ -112,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -170,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -242,6 +240,94 @@ class TestMemoryConsolidationTypeHandling:
assert result is False assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None: async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
@@ -345,3 +431,48 @@ class TestMemoryConsolidationTypeHandling:
assert result is False assert result is False
assert not store.history_file.exists() assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1