fix(agent): count all message fields in token estimation

estimate_prompt_tokens() only counted the `content` text field, completely
missing tool_calls JSON (~72% of actual payload), reasoning_content,
tool_call_id, name, and per-message framing overhead. This caused the
memory consolidator to never trigger for tool-heavy sessions (e.g. cron
jobs), leading to context window overflow errors from the LLM provider.

Also adds reasoning_content counting and proper per-message overhead to
estimate_message_tokens() for consistent boundary detection.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-22 03:38:58 +00:00
committed by Xubin Ren
parent 48c71bb61e
commit 1c71489121

View File

@@ -115,7 +115,11 @@ def estimate_prompt_tokens(
     messages: list[dict[str, Any]],
     tools: list[dict[str, Any]] | None = None,
 ) -> int:
-    """Estimate prompt tokens with tiktoken."""
+    """Estimate prompt tokens with tiktoken.
+
+    Counts all fields that providers send to the LLM: content, tool_calls,
+    reasoning_content, tool_call_id, name, plus per-message framing overhead.
+    """
     try:
         enc = tiktoken.get_encoding("cl100k_base")
         parts: list[str] = []
@@ -129,9 +133,25 @@ def estimate_prompt_tokens(
                     txt = part.get("text", "")
                     if txt:
                         parts.append(txt)
+            tc = msg.get("tool_calls")
+            if tc:
+                parts.append(json.dumps(tc, ensure_ascii=False))
+            rc = msg.get("reasoning_content")
+            if isinstance(rc, str) and rc:
+                parts.append(rc)
+            for key in ("name", "tool_call_id"):
+                value = msg.get(key)
+                if isinstance(value, str) and value:
+                    parts.append(value)
         if tools:
             parts.append(json.dumps(tools, ensure_ascii=False))
-        return len(enc.encode("\n".join(parts)))
+        per_message_overhead = len(messages) * 4
+        return len(enc.encode("\n".join(parts))) + per_message_overhead
     except Exception:
         return 0
@@ -160,14 +180,18 @@ def estimate_message_tokens(message: dict[str, Any]) -> int:
     if message.get("tool_calls"):
         parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
+    rc = message.get("reasoning_content")
+    if isinstance(rc, str) and rc:
+        parts.append(rc)
     payload = "\n".join(parts)
     if not payload:
-        return 1
+        return 4
     try:
         enc = tiktoken.get_encoding("cl100k_base")
-        return max(1, len(enc.encode(payload)))
+        return max(4, len(enc.encode(payload)) + 4)
     except Exception:
-        return max(1, len(payload) // 4)
+        return max(4, len(payload) // 4 + 4)


 def estimate_prompt_tokens_chain(