fix(memory): validate save_memory payload before persisting

This commit is contained in:
shenchengtsi
2026-03-10 11:32:11 +08:00
parent 99b896f5d4
commit b24d6ffc94
2 changed files with 116 additions and 11 deletions

View File

@@ -139,15 +139,30 @@ class MemoryStore:
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False return False
if entry := args.get("history_entry"): if "history_entry" not in args or "memory_update" not in args:
if not isinstance(entry, str): logger.warning("Memory consolidation: save_memory payload missing required fields")
entry = json.dumps(entry, ensure_ascii=False) return False
self.append_history(entry)
if update := args.get("memory_update"): entry = args["history_entry"]
if not isinstance(update, str): update = args["memory_update"]
update = json.dumps(update, ensure_ascii=False)
if update != current_memory: if entry is None or update is None:
self.write_long_term(update) logger.warning("Memory consolidation: save_memory payload contains null required fields")
return False
if not isinstance(entry, str):
entry = json.dumps(entry, ensure_ascii=False)
if not isinstance(update, str):
update = json.dumps(update, ensure_ascii=False)
entry = entry.strip()
if not entry:
logger.warning("Memory consolidation: history_entry is empty after normalization")
return False
self.append_history(entry)
if update != current_memory:
self.write_long_term(update)
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)

View File

@@ -97,7 +97,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -152,7 +151,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path) store = MemoryStore(tmp_path)
provider = AsyncMock() provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse( response = LLMResponse(
content=None, content=None,
tool_calls=[ tool_calls=[
@@ -220,3 +218,95 @@ class TestMemoryConsolidationTypeHandling:
result = await store.consolidate(session, provider, "test-model", memory_window=50) result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
assert session.last_consolidated == 0
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
assert session.last_consolidated == 0
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
assert session.last_consolidated == 0
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
assert session.last_consolidated == 0