fix: keep truncated session history tool-call consistent
This commit is contained in:
@@ -44,37 +44,27 @@ class Session:
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_ids(messages: list[dict[str, Any]]) -> set[str]:
|
||||
ids: set[str] = set()
|
||||
for message in messages:
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
if tool_id:
|
||||
ids.add(str(tool_id))
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
def _trim_orphan_tool_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Drop the oldest prefix that contains tool results without matching tool_calls."""
|
||||
trimmed = list(messages)
|
||||
while trimmed:
|
||||
tool_call_ids = cls._tool_call_ids(trimmed)
|
||||
cut_to = None
|
||||
for index, message in enumerate(trimmed):
|
||||
if message.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if tool_call_id and str(tool_call_id) not in tool_call_ids:
|
||||
cut_to = index + 1
|
||||
break
|
||||
if cut_to is None:
|
||||
break
|
||||
trimmed = trimmed[cut_to:]
|
||||
return trimmed
|
||||
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
|
||||
"""Find first index where every tool result has a matching assistant tool_call."""
|
||||
declared: set[str] = set()
|
||||
start = 0
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
elif role == "tool":
|
||||
tid = msg.get("tool_call_id")
|
||||
if tid and str(tid) not in declared:
|
||||
start = i + 1
|
||||
declared.clear()
|
||||
for prev in messages[start:i + 1]:
|
||||
if prev.get("role") == "assistant":
|
||||
for tc in prev.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
return start
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||
@@ -89,7 +79,9 @@ class Session:
|
||||
|
||||
# Some providers reject orphan tool results if the matching assistant
|
||||
# tool_calls message fell outside the fixed-size history window.
|
||||
sliced = self._trim_orphan_tool_messages(sliced)
|
||||
start = self._find_legal_start(sliced)
|
||||
if start:
|
||||
sliced = sliced[start:]
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for message in sliced:
|
||||
|
||||
@@ -1,52 +1,146 @@
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
def _assert_no_orphans(history: list[dict]) -> None:
|
||||
"""Assert every tool result in history has a matching assistant tool_call."""
|
||||
declared = {
|
||||
tc["id"]
|
||||
for m in history if m.get("role") == "assistant"
|
||||
for tc in (m.get("tool_calls") or [])
|
||||
}
|
||||
orphans = [
|
||||
m.get("tool_call_id") for m in history
|
||||
if m.get("role") == "tool" and m.get("tool_call_id") not in declared
|
||||
]
|
||||
assert orphans == [], f"orphan tool_call_ids: {orphans}"
|
||||
|
||||
|
||||
def _tool_turn(prefix: str, idx: int) -> list[dict]:
|
||||
"""Helper: one assistant with 2 tool_calls + 2 tool results."""
|
||||
return [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"},
|
||||
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"},
|
||||
]
|
||||
|
||||
|
||||
# --- Original regression test (from PR 2075) ---
|
||||
|
||||
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
|
||||
session = Session(key="telegram:test")
|
||||
session.messages.append({"role": "user", "content": "old turn"})
|
||||
|
||||
for i in range(20):
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": f"old_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": f"old_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_a", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_b", "name": "y", "content": "ok"})
|
||||
|
||||
session.messages.extend(_tool_turn("old", i))
|
||||
session.messages.append({"role": "user", "content": "problem turn"})
|
||||
for i in range(25):
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": f"cur_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": f"cur_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_a", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_b", "name": "y", "content": "ok"})
|
||||
|
||||
session.messages.extend(_tool_turn("cur", i))
|
||||
session.messages.append({"role": "user", "content": "new telegram question"})
|
||||
|
||||
history = session.get_history(max_messages=100)
|
||||
assistant_ids = {
|
||||
tool_call["id"]
|
||||
for message in history
|
||||
if message.get("role") == "assistant"
|
||||
for tool_call in (message.get("tool_calls") or [])
|
||||
}
|
||||
orphan_tool_ids = [
|
||||
message.get("tool_call_id")
|
||||
for message in history
|
||||
if message.get("role") == "tool" and message.get("tool_call_id") not in assistant_ids
|
||||
]
|
||||
_assert_no_orphans(history)
|
||||
|
||||
assert orphan_tool_ids == []
|
||||
|
||||
# --- Positive test: legitimate pairs survive trimming ---
|
||||
|
||||
def test_legitimate_tool_pairs_preserved_after_trim():
|
||||
"""Complete tool-call groups within the window must not be dropped."""
|
||||
session = Session(key="test:positive")
|
||||
session.messages.append({"role": "user", "content": "hello"})
|
||||
for i in range(5):
|
||||
session.messages.extend(_tool_turn("ok", i))
|
||||
session.messages.append({"role": "assistant", "content": "done"})
|
||||
|
||||
history = session.get_history(max_messages=500)
|
||||
_assert_no_orphans(history)
|
||||
tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"]
|
||||
assert len(tool_ids) == 10
|
||||
assert history[0]["role"] == "user"
|
||||
|
||||
|
||||
# --- last_consolidated > 0 ---
|
||||
|
||||
def test_orphan_trim_with_last_consolidated():
|
||||
"""Orphan trimming works correctly when session is partially consolidated."""
|
||||
session = Session(key="test:consolidated")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"old {i}"})
|
||||
session.messages.extend(_tool_turn("cons", i))
|
||||
session.last_consolidated = 30
|
||||
|
||||
session.messages.append({"role": "user", "content": "recent"})
|
||||
for i in range(15):
|
||||
session.messages.extend(_tool_turn("new", i))
|
||||
session.messages.append({"role": "user", "content": "latest"})
|
||||
|
||||
history = session.get_history(max_messages=20)
|
||||
_assert_no_orphans(history)
|
||||
assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history)
|
||||
|
||||
|
||||
# --- Edge: no tool messages at all ---
|
||||
|
||||
def test_no_tool_messages_unchanged():
|
||||
session = Session(key="test:plain")
|
||||
for i in range(5):
|
||||
session.messages.append({"role": "user", "content": f"q{i}"})
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
|
||||
history = session.get_history(max_messages=6)
|
||||
assert len(history) == 6
|
||||
_assert_no_orphans(history)
|
||||
|
||||
|
||||
# --- Edge: all leading messages are orphan tool results ---
|
||||
|
||||
def test_all_orphan_prefix_stripped():
|
||||
"""If the window starts with orphan tool results and nothing else, they're all dropped."""
|
||||
session = Session(key="test:all-orphan")
|
||||
session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"})
|
||||
session.messages.append({"role": "user", "content": "fresh start"})
|
||||
session.messages.append({"role": "assistant", "content": "hi"})
|
||||
|
||||
history = session.get_history(max_messages=500)
|
||||
_assert_no_orphans(history)
|
||||
assert history[0]["role"] == "user"
|
||||
assert len(history) == 2
|
||||
|
||||
|
||||
# --- Edge: empty session ---
|
||||
|
||||
def test_empty_session_history():
|
||||
session = Session(key="test:empty")
|
||||
history = session.get_history(max_messages=500)
|
||||
assert history == []
|
||||
|
||||
|
||||
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
|
||||
|
||||
def test_window_cuts_mid_tool_group():
|
||||
"""If the window starts between an assistant's tool results, the partial group is trimmed."""
|
||||
session = Session(key="test:mid-cut")
|
||||
session.messages.append({"role": "user", "content": "setup"})
|
||||
session.messages.append({
|
||||
"role": "assistant", "content": None,
|
||||
"tool_calls": [
|
||||
{"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||
{"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||
],
|
||||
})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"})
|
||||
session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"})
|
||||
session.messages.append({"role": "user", "content": "next"})
|
||||
session.messages.extend(_tool_turn("intact", 0))
|
||||
session.messages.append({"role": "assistant", "content": "final"})
|
||||
|
||||
# Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b,
|
||||
# leaving orphan tool results for split_a at the front.
|
||||
history = session.get_history(max_messages=6)
|
||||
_assert_no_orphans(history)
|
||||
|
||||
Reference in New Issue
Block a user