diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index f0a6484..acb6d7f 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -43,23 +43,60 @@ class Session: self.messages.append(msg) self.updated_at = datetime.now() + @staticmethod + def _tool_call_ids(messages: list[dict[str, Any]]) -> set[str]: + ids: set[str] = set() + for message in messages: + if message.get("role") != "assistant": + continue + for tool_call in message.get("tool_calls") or []: + if not isinstance(tool_call, dict): + continue + tool_id = tool_call.get("id") + if tool_id: + ids.add(str(tool_id)) + return ids + + @classmethod + def _trim_orphan_tool_messages(cls, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Drop the oldest prefix that contains tool results without matching tool_calls.""" + trimmed = list(messages) + while trimmed: + tool_call_ids = cls._tool_call_ids(trimmed) + cut_to = None + for index, message in enumerate(trimmed): + if message.get("role") != "tool": + continue + tool_call_id = message.get("tool_call_id") + if tool_call_id and str(tool_call_id) not in tool_call_ids: + cut_to = index + 1 + break + if cut_to is None: + break + trimmed = trimmed[cut_to:] + return trimmed + def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: - """Return unconsolidated messages for LLM input, aligned to a user turn.""" + """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" unconsolidated = self.messages[self.last_consolidated:] sliced = unconsolidated[-max_messages:] - # Drop leading non-user messages to avoid orphaned tool_result blocks - for i, m in enumerate(sliced): - if m.get("role") == "user": + # Drop leading non-user messages to avoid starting mid-turn when possible. + for i, message in enumerate(sliced): + if message.get("role") == "user": sliced = sliced[i:] break + # Some providers reject orphan tool results if the matching assistant + # tool_calls message fell outside the fixed-size history window. + sliced = self._trim_orphan_tool_messages(sliced) + out: list[dict[str, Any]] = [] - for m in sliced: - entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")} - for k in ("tool_calls", "tool_call_id", "name"): - if k in m: - entry[k] = m[k] + for message in sliced: + entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")} + for key in ("tool_calls", "tool_call_id", "name"): + if key in message: + entry[key] = message[key] out.append(entry) return out diff --git a/tests/test_session_manager_history.py b/tests/test_session_manager_history.py new file mode 100644 index 0000000..a0effac --- /dev/null +++ b/tests/test_session_manager_history.py @@ -0,0 +1,52 @@ +from nanobot.session.manager import Session + + +def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls(): + session = Session(key="telegram:test") + session.messages.append({"role": "user", "content": "old turn"}) + + for i in range(20): + session.messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": f"old_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": f"old_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + } + ) + session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_a", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": f"old_{i}_b", "name": "y", "content": "ok"}) + + session.messages.append({"role": "user", "content": "problem turn"}) + for i in range(25): + session.messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"id": f"cur_{i}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}}, + {"id": f"cur_{i}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}}, + ], + } + ) + session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_a", "name": "x", "content": "ok"}) + session.messages.append({"role": "tool", "tool_call_id": f"cur_{i}_b", "name": "y", "content": "ok"}) + + session.messages.append({"role": "user", "content": "new telegram question"}) + + history = session.get_history(max_messages=100) + assistant_ids = { + tool_call["id"] + for message in history + if message.get("role") == "assistant" + for tool_call in (message.get("tool_calls") or []) + } + orphan_tool_ids = [ + message.get("tool_call_id") + for message in history + if message.get("role") == "tool" and message.get("tool_call_id") not in assistant_ids + ] + + assert orphan_tool_ids == []