From b24d6ffc941f7ff755898fa94485bab51e4415d4 Mon Sep 17 00:00:00 2001 From: shenchengtsi Date: Tue, 10 Mar 2026 11:32:11 +0800 Subject: [PATCH] fix(memory): validate save_memory payload before persisting --- nanobot/agent/memory.py | 33 ++++++--- tests/test_memory_consolidation_types.py | 94 +++++++++++++++++++++++- 2 files changed, 116 insertions(+), 11 deletions(-) diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 21fe77d..add014b 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -139,15 +139,30 @@ class MemoryStore: logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) return False - if entry := args.get("history_entry"): - if not isinstance(entry, str): - entry = json.dumps(entry, ensure_ascii=False) - self.append_history(entry) - if update := args.get("memory_update"): - if not isinstance(update, str): - update = json.dumps(update, ensure_ascii=False) - if update != current_memory: - self.write_long_term(update) + if "history_entry" not in args or "memory_update" not in args: + logger.warning("Memory consolidation: save_memory payload missing required fields") + return False + + entry = args["history_entry"] + update = args["memory_update"] + + if entry is None or update is None: + logger.warning("Memory consolidation: save_memory payload contains null required fields") + return False + + if not isinstance(entry, str): + entry = json.dumps(entry, ensure_ascii=False) + if not isinstance(update, str): + update = json.dumps(update, ensure_ascii=False) + + entry = entry.strip() + if not entry: + logger.warning("Memory consolidation: history_entry is empty after normalization") + return False + + self.append_history(entry) + if update != current_memory: + self.write_long_term(update) session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py index ff15584..4ba1ecd 100644 --- a/tests/test_memory_consolidation_types.py +++ b/tests/test_memory_consolidation_types.py @@ -97,7 +97,6 @@ class TestMemoryConsolidationTypeHandling: store = MemoryStore(tmp_path) provider = AsyncMock() - # Simulate arguments being a JSON string (not yet parsed) response = LLMResponse( content=None, tool_calls=[ @@ -152,7 +151,6 @@ class TestMemoryConsolidationTypeHandling: store = MemoryStore(tmp_path) provider = AsyncMock() - # Simulate arguments being a list containing a dict response = LLMResponse( content=None, tool_calls=[ @@ -220,3 +218,95 @@ class TestMemoryConsolidationTypeHandling: result = await store.consolidate(session, provider, "test-model", memory_window=50) assert result is False + + @pytest.mark.asyncio + async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None: + """Do not persist partial results when required fields are missing.""" + store = MemoryStore(tmp_path) + provider = AsyncMock() + provider.chat = AsyncMock( + return_value=LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest( + id="call_1", + name="save_memory", + arguments={"memory_update": "# Memory\nOnly memory update"}, + ) + ], + ) + ) + session = _make_session(message_count=60) + + result = await store.consolidate(session, provider, "test-model", memory_window=50) + + assert result is False + assert not store.history_file.exists() + assert not store.memory_file.exists() + assert session.last_consolidated == 0 + + @pytest.mark.asyncio + async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None: + """Do not append history if memory_update is missing.""" + store = MemoryStore(tmp_path) + provider = AsyncMock() + provider.chat = AsyncMock( + return_value=LLMResponse( + content=None, + tool_calls=[ + ToolCallRequest( + id="call_1", + name="save_memory", + arguments={"history_entry": "[2026-01-01] Partial output."}, + ) + ], + ) + ) + session = _make_session(message_count=60) + + result = await store.consolidate(session, provider, "test-model", memory_window=50) + + assert result is False + assert not store.history_file.exists() + assert not store.memory_file.exists() + assert session.last_consolidated == 0 + + @pytest.mark.asyncio + async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None: + """Null required fields should be rejected before persistence.""" + store = MemoryStore(tmp_path) + provider = AsyncMock() + provider.chat = AsyncMock( + return_value=_make_tool_response( + history_entry=None, + memory_update="# Memory\nUser likes testing.", + ) + ) + session = _make_session(message_count=60) + + result = await store.consolidate(session, provider, "test-model", memory_window=50) + + assert result is False + assert not store.history_file.exists() + assert not store.memory_file.exists() + assert session.last_consolidated == 0 + + @pytest.mark.asyncio + async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None: + """Empty history entries should be rejected to avoid blank archival records.""" + store = MemoryStore(tmp_path) + provider = AsyncMock() + provider.chat = AsyncMock( + return_value=_make_tool_response( + history_entry=" ", + memory_update="# Memory\nUser likes testing.", + ) + ) + session = _make_session(message_count=60) + + result = await store.consolidate(session, provider, "test-model", memory_window=50) + + assert result is False + assert not store.history_file.exists() + assert not store.memory_file.exists() + assert session.last_consolidated == 0