Fix orphan tool results in truncated session history
This commit is contained in:
@@ -43,23 +43,60 @@ class Session:
|
|||||||
self.messages.append(msg)
|
self.messages.append(msg)
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_call_ids(messages: list[dict[str, Any]]) -> set[str]:
|
||||||
|
ids: set[str] = set()
|
||||||
|
for message in messages:
|
||||||
|
if message.get("role") != "assistant":
|
||||||
|
continue
|
||||||
|
for tool_call in message.get("tool_calls") or []:
|
||||||
|
if not isinstance(tool_call, dict):
|
||||||
|
continue
|
||||||
|
tool_id = tool_call.get("id")
|
||||||
|
if tool_id:
|
||||||
|
ids.add(str(tool_id))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _trim_orphan_tool_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Drop the oldest prefix that contains tool results without matching tool_calls."""
|
||||||
|
trimmed = list(messages)
|
||||||
|
while trimmed:
|
||||||
|
tool_call_ids = cls._tool_call_ids(trimmed)
|
||||||
|
cut_to = None
|
||||||
|
for index, message in enumerate(trimmed):
|
||||||
|
if message.get("role") != "tool":
|
||||||
|
continue
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
if tool_call_id and str(tool_call_id) not in tool_call_ids:
|
||||||
|
cut_to = index + 1
|
||||||
|
break
|
||||||
|
if cut_to is None:
|
||||||
|
break
|
||||||
|
trimmed = trimmed[cut_to:]
|
||||||
|
return trimmed
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||||
unconsolidated = self.messages[self.last_consolidated:]
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
sliced = unconsolidated[-max_messages:]
|
sliced = unconsolidated[-max_messages:]
|
||||||
|
|
||||||
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
# Drop leading non-user messages to avoid starting mid-turn when possible.
|
||||||
for i, m in enumerate(sliced):
|
for i, message in enumerate(sliced):
|
||||||
if m.get("role") == "user":
|
if message.get("role") == "user":
|
||||||
sliced = sliced[i:]
|
sliced = sliced[i:]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Some providers reject orphan tool results if the matching assistant
|
||||||
|
# tool_calls message fell outside the fixed-size history window.
|
||||||
|
sliced = self._trim_orphan_tool_messages(sliced)
|
||||||
|
|
||||||
out: list[dict[str, Any]] = []
|
out: list[dict[str, Any]] = []
|
||||||
for m in sliced:
|
for message in sliced:
|
||||||
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
|
entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")}
|
||||||
for k in ("tool_calls", "tool_call_id", "name"):
|
for key in ("tool_calls", "tool_call_id", "name"):
|
||||||
if k in m:
|
if key in message:
|
||||||
entry[k] = m[k]
|
entry[key] = message[key]
|
||||||
out.append(entry)
|
out.append(entry)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
52
tests/test_session_manager_history.py
Normal file
52
tests/test_session_manager_history.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
from nanobot.session.manager import Session
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
|
||||||
|
session = Session(key="telegram:test")
|
||||||
|
session.messages.append({"role": "user", "content": "old turn"})
|
||||||
|
|
||||||
|
for i in range(20):
|
||||||
|
session.messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{"id": f"old_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||||
|
{"id": f"old_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_a", "name": "x", "content": "ok"})
|
||||||
|
session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_b", "name": "y", "content": "ok"})
|
||||||
|
|
||||||
|
session.messages.append({"role": "user", "content": "problem turn"})
|
||||||
|
for i in range(25):
|
||||||
|
session.messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [
|
||||||
|
{"id": f"cur_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
|
||||||
|
{"id": f"cur_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_a", "name": "x", "content": "ok"})
|
||||||
|
session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_b", "name": "y", "content": "ok"})
|
||||||
|
|
||||||
|
session.messages.append({"role": "user", "content": "new telegram question"})
|
||||||
|
|
||||||
|
history = session.get_history(max_messages=100)
|
||||||
|
assistant_ids = {
|
||||||
|
tool_call["id"]
|
||||||
|
for message in history
|
||||||
|
if message.get("role") == "assistant"
|
||||||
|
for tool_call in (message.get("tool_calls") or [])
|
||||||
|
}
|
||||||
|
orphan_tool_ids = [
|
||||||
|
message.get("tool_call_id")
|
||||||
|
for message in history
|
||||||
|
if message.get("role") == "tool" and message.get("tool_call_id") not in assistant_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
assert orphan_tool_ids == []
|
||||||
Reference in New Issue
Block a user