fix(memory): validate save_memory payload before persisting
This commit is contained in:
@@ -139,15 +139,30 @@ class MemoryStore:
|
||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||
return False
|
||||
|
||||
if entry := args.get("history_entry"):
|
||||
if not isinstance(entry, str):
|
||||
entry = json.dumps(entry, ensure_ascii=False)
|
||||
self.append_history(entry)
|
||||
if update := args.get("memory_update"):
|
||||
if not isinstance(update, str):
|
||||
update = json.dumps(update, ensure_ascii=False)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
if "history_entry" not in args or "memory_update" not in args:
|
||||
logger.warning("Memory consolidation: save_memory payload missing required fields")
|
||||
return False
|
||||
|
||||
entry = args["history_entry"]
|
||||
update = args["memory_update"]
|
||||
|
||||
if entry is None or update is None:
|
||||
logger.warning("Memory consolidation: save_memory payload contains null required fields")
|
||||
return False
|
||||
|
||||
if not isinstance(entry, str):
|
||||
entry = json.dumps(entry, ensure_ascii=False)
|
||||
if not isinstance(update, str):
|
||||
update = json.dumps(update, ensure_ascii=False)
|
||||
|
||||
entry = entry.strip()
|
||||
if not entry:
|
||||
logger.warning("Memory consolidation: history_entry is empty after normalization")
|
||||
return False
|
||||
|
||||
self.append_history(entry)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||
|
||||
@@ -97,7 +97,6 @@ class TestMemoryConsolidationTypeHandling:
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a JSON string (not yet parsed)
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
@@ -152,7 +151,6 @@ class TestMemoryConsolidationTypeHandling:
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a list containing a dict
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
@@ -220,3 +218,95 @@ class TestMemoryConsolidationTypeHandling:
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Do not persist partial results when required fields are missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={"memory_update": "# Memory\nOnly memory update"},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Do not append history if memory_update is missing."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments={"history_entry": "[2026-01-01] Partial output."},
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Null required fields should be rejected before persistence."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=None,
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
provider.chat = AsyncMock(
|
||||
return_value=_make_tool_response(
|
||||
history_entry=" ",
|
||||
memory_update="# Memory\nUser likes testing.",
|
||||
)
|
||||
)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
assert not store.history_file.exists()
|
||||
assert not store.memory_file.exists()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
Reference in New Issue
Block a user