fix(agent): count all message fields in token estimation
estimate_prompt_tokens() counted only the `content` text of each message. It ignored the tool_calls JSON (~72% of the actual payload), reasoning_content, tool_call_id, name, and the per-message framing overhead. The estimate therefore stayed well below the real prompt size for tool-heavy sessions (e.g. cron jobs), so the memory consolidator never triggered and requests eventually failed with context window overflow errors from the LLM provider.

This commit also adds reasoning_content counting and a per-message overhead to estimate_message_tokens(), for consistent boundary detection.

Made-with: Cursor
This commit is contained in:
@@ -115,7 +115,11 @@ def estimate_prompt_tokens(
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Estimate prompt tokens with tiktoken."""
|
"""Estimate prompt tokens with tiktoken.
|
||||||
|
|
||||||
|
Counts all fields that providers send to the LLM: content, tool_calls,
|
||||||
|
reasoning_content, tool_call_id, name, plus per-message framing overhead.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
@@ -129,9 +133,25 @@ def estimate_prompt_tokens(
|
|||||||
txt = part.get("text", "")
|
txt = part.get("text", "")
|
||||||
if txt:
|
if txt:
|
||||||
parts.append(txt)
|
parts.append(txt)
|
||||||
|
|
||||||
|
tc = msg.get("tool_calls")
|
||||||
|
if tc:
|
||||||
|
parts.append(json.dumps(tc, ensure_ascii=False))
|
||||||
|
|
||||||
|
rc = msg.get("reasoning_content")
|
||||||
|
if isinstance(rc, str) and rc:
|
||||||
|
parts.append(rc)
|
||||||
|
|
||||||
|
for key in ("name", "tool_call_id"):
|
||||||
|
value = msg.get(key)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
parts.append(value)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
parts.append(json.dumps(tools, ensure_ascii=False))
|
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||||
return len(enc.encode("\n".join(parts)))
|
|
||||||
|
per_message_overhead = len(messages) * 4
|
||||||
|
return len(enc.encode("\n".join(parts))) + per_message_overhead
|
||||||
except Exception:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@@ -160,14 +180,18 @@ def estimate_message_tokens(message: dict[str, Any]) -> int:
|
|||||||
if message.get("tool_calls"):
|
if message.get("tool_calls"):
|
||||||
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||||
|
|
||||||
|
rc = message.get("reasoning_content")
|
||||||
|
if isinstance(rc, str) and rc:
|
||||||
|
parts.append(rc)
|
||||||
|
|
||||||
payload = "\n".join(parts)
|
payload = "\n".join(parts)
|
||||||
if not payload:
|
if not payload:
|
||||||
return 1
|
return 4
|
||||||
try:
|
try:
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
return max(1, len(enc.encode(payload)))
|
return max(4, len(enc.encode(payload)) + 4)
|
||||||
except Exception:
|
except Exception:
|
||||||
return max(1, len(payload) // 4)
|
return max(4, len(payload) // 4 + 4)
|
||||||
|
|
||||||
|
|
||||||
def estimate_prompt_tokens_chain(
|
def estimate_prompt_tokens_chain(
|
||||||
|
|||||||
Reference in New Issue
Block a user