Fix orphan tool results in truncated session history

This commit is contained in:
rise
2026-03-16 08:56:39 +08:00
committed by Xubin Ren
parent 94b5956309
commit db276bdf2b
2 changed files with 98 additions and 9 deletions

View File

@@ -43,23 +43,60 @@ class Session:
self.messages.append(msg)
self.updated_at = datetime.now()
@staticmethod
def _tool_call_ids(messages: list[dict[str, Any]]) -> set[str]:
ids: set[str] = set()
for message in messages:
if message.get("role") != "assistant":
continue
for tool_call in message.get("tool_calls") or []:
if not isinstance(tool_call, dict):
continue
tool_id = tool_call.get("id")
if tool_id:
ids.add(str(tool_id))
return ids
@classmethod
def _trim_orphan_tool_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Drop the oldest prefix that contains tool results without matching tool_calls."""
trimmed = list(messages)
while trimmed:
tool_call_ids = cls._tool_call_ids(trimmed)
cut_to = None
for index, message in enumerate(trimmed):
if message.get("role") != "tool":
continue
tool_call_id = message.get("tool_call_id")
if tool_call_id and str(tool_call_id) not in tool_call_ids:
cut_to = index + 1
break
if cut_to is None:
break
trimmed = trimmed[cut_to:]
return trimmed
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid orphaned tool_result blocks
for i, m in enumerate(sliced):
if m.get("role") == "user":
# Drop leading non-user messages to avoid starting mid-turn when possible.
for i, message in enumerate(sliced):
if message.get("role") == "user":
sliced = sliced[i:]
break
# Some providers reject orphan tool results if the matching assistant
# tool_calls message fell outside the fixed-size history window.
sliced = self._trim_orphan_tool_messages(sliced)
out: list[dict[str, Any]] = []
for m in sliced:
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
for k in ("tool_calls", "tool_call_id", "name"):
if k in m:
entry[k] = m[k]
for message in sliced:
entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")}
for key in ("tool_calls", "tool_call_id", "name"):
if key in message:
entry[key] = message[key]
out.append(entry)
return out

View File

@@ -0,0 +1,52 @@
from nanobot.session.manager import Session
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
    """History returned for a truncated window must contain no orphan tool results.

    Builds two long tool-heavy turns, requests a 100-message window and
    checks that every ``tool`` message in the result refers to a tool call
    declared by an ``assistant`` message in the same result.
    """
    # NOTE(review): assuming a fresh Session starts with last_consolidated == 0,
    # the 100-message window still seems to include the "problem turn" user
    # message, so user alignment alone may already satisfy the assertion.
    # Confirm this test fails without the orphan-trimming fix.

    def tool_round(first_id, second_id):
        # One assistant message issuing two calls, followed by both results.
        return [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {"id": first_id, "type": "function", "function": {"name": "x", "arguments": "{}"}},
                    {"id": second_id, "type": "function", "function": {"name": "y", "arguments": "{}"}},
                ],
            },
            {"role": "tool", "tool_call_id": first_id, "name": "x", "content": "ok"},
            {"role": "tool", "tool_call_id": second_id, "name": "y", "content": "ok"},
        ]

    session = Session(key="telegram:test")

    session.messages.append({"role": "user", "content": "old turn"})
    for i in range(20):
        for item in tool_round(f"old_{i}_a", f"old_{i}_b"):
            session.messages.append(item)

    session.messages.append({"role": "user", "content": "problem turn"})
    for i in range(25):
        for item in tool_round(f"cur_{i}_a", f"cur_{i}_b"):
            session.messages.append(item)

    session.messages.append({"role": "user", "content": "new telegram question"})

    history = session.get_history(max_messages=100)

    declared_ids = set()
    for message in history:
        if message.get("role") == "assistant":
            for tool_call in message.get("tool_calls") or []:
                declared_ids.add(tool_call["id"])

    orphan_tool_ids = []
    for message in history:
        if message.get("role") != "tool":
            continue
        if message.get("tool_call_id") not in declared_ids:
            orphan_tool_ids.append(message.get("tool_call_id"))

    assert orphan_tool_ids == []