From 92f3d5a8b317321902cbe8ce4200e9d1e8431bfb Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 16 Mar 2026 09:21:21 +0000 Subject: [PATCH] fix: keep truncated session history tool-call consistent --- nanobot/session/manager.py | 56 ++++----- tests/test_session_manager_history.py | 172 ++++++++++++++++++++------ 2 files changed, 157 insertions(+), 71 deletions(-) diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index acb6d7f..f8244e5 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -44,37 +44,27 @@ class Session: self.updated_at = datetime.now() @staticmethod - def _tool_call_ids(messages: list[dict[str, Any]]) -> set[str]: - ids: set[str] = set() - for message in messages: - if message.get("role") != "assistant": - continue - for tool_call in message.get("tool_calls") or []: - if not isinstance(tool_call, dict): - continue - tool_id = tool_call.get("id") - if tool_id: - ids.add(str(tool_id)) - return ids - - @classmethod - def _trim_orphan_tool_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Drop the oldest prefix that contains tool results without matching tool_calls.""" - trimmed = list(messages) - while trimmed: - tool_call_ids = cls._tool_call_ids(trimmed) - cut_to = None - for index, message in enumerate(trimmed): - if message.get("role") != "tool": - continue - tool_call_id = message.get("tool_call_id") - if tool_call_id and str(tool_call_id) not in tool_call_ids: - cut_to = index + 1 - break - if cut_to is None: - break - trimmed = trimmed[cut_to:] - return trimmed + def _find_legal_start(messages: list[dict[str, Any]]) -> int: + """Find first index where every tool result has a matching assistant tool_call.""" + declared: set[str] = set() + start = 0 + for i, msg in enumerate(messages): + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + elif role == "tool": + tid = 
msg.get("tool_call_id") + if tid and str(tid) not in declared: + start = i + 1 + declared.clear() + for prev in messages[start:i + 1]: + if prev.get("role") == "assistant": + for tc in prev.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + return start def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" @@ -89,7 +79,9 @@ class Session: # Some providers reject orphan tool results if the matching assistant # tool_calls message fell outside the fixed-size history window. - sliced = self._trim_orphan_tool_messages(sliced) + start = self._find_legal_start(sliced) + if start: + sliced = sliced[start:] out: list[dict[str, Any]] = [] for message in sliced: diff --git a/tests/test_session_manager_history.py b/tests/test_session_manager_history.py index a0effac..4f56344 100644 --- a/tests/test_session_manager_history.py +++ b/tests/test_session_manager_history.py @@ -1,52 +1,146 @@ from nanobot.session.manager import Session +def _assert_no_orphans(history: list[dict]) -> None: + """Assert every tool result in history has a matching assistant tool_call.""" + declared = { + tc["id"] + for m in history if m.get("role") == "assistant" + for tc in (m.get("tool_calls") or []) + } + orphans = [ + m.get("tool_call_id") for m in history + if m.get("role") == "tool" and m.get("tool_call_id") not in declared + ] + assert orphans == [], f"orphan tool_call_ids: {orphans}" + + +def _tool_turn(prefix: str, idx: int) -> list[dict]: + """Helper: one assistant with 2 tool_calls + 2 tool results.""" + return [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + }, + {"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", 
"name": "x", "content": "ok"}, + {"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"}, + ] + + +# --- Original regression test (from PR 2075) --- + def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls(): session = Session(key="telegram:test") session.messages.append({"role": "user", "content": "old turn"}) - for i in range(20): - session.messages.append( - { - "role": "assistant", - "content": None, - "tool_calls": [ - {"id": f"old_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, - {"id": f"old_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, - ], - } - ) - session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_a", "name": "x", "content": "ok"}) - session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_b", "name": "y", "content": "ok"}) - + session.messages.extend(_tool_turn("old", i)) session.messages.append({"role": "user", "content": "problem turn"}) for i in range(25): - session.messages.append( - { - "role": "assistant", - "content": None, - "tool_calls": [ - {"id": f"cur_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, - {"id": f"cur_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, - ], - } - ) - session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_a", "name": "x", "content": "ok"}) - session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_b", "name": "y", "content": "ok"}) - + session.messages.extend(_tool_turn("cur", i)) session.messages.append({"role": "user", "content": "new telegram question"}) history = session.get_history(max_messages=100) - assistant_ids = { - tool_call["id"] - for message in history - if message.get("role") == "assistant" - for tool_call in (message.get("tool_calls") or []) - } - orphan_tool_ids = [ - message.get("tool_call_id") - for message in history - if message.get("role") == "tool" and message.get("tool_call_id") not in 
assistant_ids - ] + _assert_no_orphans(history) - assert orphan_tool_ids == [] + +# --- Positive test: legitimate pairs survive trimming --- + +def test_legitimate_tool_pairs_preserved_after_trim(): + """Complete tool-call groups within the window must not be dropped.""" + session = Session(key="test:positive") + session.messages.append({"role": "user", "content": "hello"}) + for i in range(5): + session.messages.extend(_tool_turn("ok", i)) + session.messages.append({"role": "assistant", "content": "done"}) + + history = session.get_history(max_messages=500) + _assert_no_orphans(history) + tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"] + assert len(tool_ids) == 10 + assert history[0]["role"] == "user" + + +# --- last_consolidated > 0 --- + +def test_orphan_trim_with_last_consolidated(): + """Orphan trimming works correctly when session is partially consolidated.""" + session = Session(key="test:consolidated") + for i in range(10): + session.messages.append({"role": "user", "content": f"old {i}"}) + session.messages.extend(_tool_turn("cons", i)) + session.last_consolidated = 30 + + session.messages.append({"role": "user", "content": "recent"}) + for i in range(15): + session.messages.extend(_tool_turn("new", i)) + session.messages.append({"role": "user", "content": "latest"}) + + history = session.get_history(max_messages=20) + _assert_no_orphans(history) + assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history) + + +# --- Edge: no tool messages at all --- + +def test_no_tool_messages_unchanged(): + session = Session(key="test:plain") + for i in range(5): + session.messages.append({"role": "user", "content": f"q{i}"}) + session.messages.append({"role": "assistant", "content": f"a{i}"}) + + history = session.get_history(max_messages=6) + assert len(history) == 6 + _assert_no_orphans(history) + + +# --- Edge: all leading messages are orphan tool results --- + +def test_all_orphan_prefix_stripped(): + 
"""If the window starts with orphan tool results and nothing else, they're all dropped.""" + session = Session(key="test:all-orphan") + session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"}) + session.messages.append({"role": "user", "content": "fresh start"}) + session.messages.append({"role": "assistant", "content": "hi"}) + + history = session.get_history(max_messages=500) + _assert_no_orphans(history) + assert history[0]["role"] == "user" + assert len(history) == 2 + + +# --- Edge: empty session --- + +def test_empty_session_history(): + session = Session(key="test:empty") + history = session.get_history(max_messages=500) + assert history == [] + + +# --- Window cuts mid-group: assistant present but some tool results orphaned --- + +def test_window_cuts_mid_tool_group(): + """If the window starts between an assistant's tool results, the partial group is trimmed.""" + session = Session(key="test:mid-cut") + session.messages.append({"role": "user", "content": "setup"}) + session.messages.append({ + "role": "assistant", "content": None, + "tool_calls": [ + {"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + }) + session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"}) + session.messages.append({"role": "user", "content": "next"}) + session.messages.extend(_tool_turn("intact", 0)) + session.messages.append({"role": "assistant", "content": "final"}) + + # A window of 6 cuts off the "setup" user msg, the assistant declaring split_a/split_b, + # and the split_a result, leaving the orphan split_b tool result at the front. 
+ history = session.get_history(max_messages=6) + _assert_no_orphans(history)