fix(memory): validate save_memory payload before persisting
This commit is contained in:
@@ -139,15 +139,30 @@ class MemoryStore:
|
|||||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if "history_entry" not in args or "memory_update" not in args:
|
||||||
if not isinstance(entry, str):
|
logger.warning("Memory consolidation: save_memory payload missing required fields")
|
||||||
entry = json.dumps(entry, ensure_ascii=False)
|
return False
|
||||||
self.append_history(entry)
|
|
||||||
if update := args.get("memory_update"):
|
entry = args["history_entry"]
|
||||||
if not isinstance(update, str):
|
update = args["memory_update"]
|
||||||
update = json.dumps(update, ensure_ascii=False)
|
|
||||||
if update != current_memory:
|
if entry is None or update is None:
|
||||||
self.write_long_term(update)
|
logger.warning("Memory consolidation: save_memory payload contains null required fields")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not isinstance(entry, str):
|
||||||
|
entry = json.dumps(entry, ensure_ascii=False)
|
||||||
|
if not isinstance(update, str):
|
||||||
|
update = json.dumps(update, ensure_ascii=False)
|
||||||
|
|
||||||
|
entry = entry.strip()
|
||||||
|
if not entry:
|
||||||
|
logger.warning("Memory consolidation: history_entry is empty after normalization")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.append_history(entry)
|
||||||
|
if update != current_memory:
|
||||||
|
self.write_long_term(update)
|
||||||
|
|
||||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||||
|
|||||||
@@ -97,7 +97,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a JSON string (not yet parsed)
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -152,7 +151,6 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
store = MemoryStore(tmp_path)
|
store = MemoryStore(tmp_path)
|
||||||
provider = AsyncMock()
|
provider = AsyncMock()
|
||||||
|
|
||||||
# Simulate arguments being a list containing a dict
|
|
||||||
response = LLMResponse(
|
response = LLMResponse(
|
||||||
content=None,
|
content=None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -220,3 +218,95 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not persist partial results when required fields are missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"memory_update": "# Memory\nOnly memory update"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Do not append history if memory_update is missing."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_1",
|
||||||
|
name="save_memory",
|
||||||
|
arguments={"history_entry": "[2026-01-01] Partial output."},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Null required fields should be rejected before persistence."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=None,
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
|
||||||
|
"""Empty history entries should be rejected to avoid blank archival records."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry=" ",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session = _make_session(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert not store.history_file.exists()
|
||||||
|
assert not store.memory_file.exists()
|
||||||
|
assert session.last_consolidated == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user