fix(compression): prefer provider prompt token usage
This commit is contained in:
@@ -124,6 +124,8 @@ class AgentLoop:
         self._mcp_connecting = False
         self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
         self._compression_tasks: dict[str, asyncio.Task] = {}  # session_key -> task
+        self._last_turn_prompt_tokens: int = 0
+        self._last_turn_prompt_source: str = "none"
         self._processing_lock = asyncio.Lock()
         self._register_default_tools()
 
@@ -324,7 +326,15 @@ class AgentLoop:
         if target_threshold >= start_threshold:
             target_threshold = max(0, start_threshold - 1)
 
-        current_tokens, token_source = self._estimate_session_prompt_tokens(session)
+        # Prefer provider usage prompt tokens from the turn-ending call.
+        # If unavailable, fall back to estimator chain.
+        raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
+        if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
+            current_tokens = int(raw_prompt_tokens)
+            token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
+        else:
+            current_tokens, token_source = self._estimate_session_prompt_tokens(session)
+
         current_ratio = current_tokens / budget if budget else 0.0
         if current_tokens <= 0:
             logger.debug("Compression skip {}: token estimate unavailable", session.key)
@@ -569,6 +579,8 @@ class AgentLoop:
         tools_used: list[str] = []
         total_tokens_this_turn = 0
         token_source = "none"
+        self._last_turn_prompt_tokens = 0
+        self._last_turn_prompt_source = "none"
 
         while iteration < self.max_iterations:
             iteration += 1
@@ -594,19 +606,35 @@ class AgentLoop:
             if isinstance(t_tokens, (int, float)) and t_tokens > 0:
                 total_tokens_this_turn = int(t_tokens)
                 token_source = "provider_total"
+                if isinstance(p_tokens, (int, float)) and p_tokens > 0:
+                    self._last_turn_prompt_tokens = int(p_tokens)
+                    self._last_turn_prompt_source = "usage_prompt"
+                elif isinstance(c_tokens, (int, float)):
+                    prompt_derived = int(t_tokens) - int(c_tokens)
+                    if prompt_derived > 0:
+                        self._last_turn_prompt_tokens = prompt_derived
+                        self._last_turn_prompt_source = "usage_total_minus_completion"
             elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
                 # If we have both prompt and completion tokens, sum them
                 total_tokens_this_turn = int(p_tokens) + int(c_tokens)
                 token_source = "provider_sum"
+                if p_tokens > 0:
+                    self._last_turn_prompt_tokens = int(p_tokens)
+                    self._last_turn_prompt_source = "usage_prompt"
             elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
                 # Fallback: use prompt tokens only (completion might be 0 for tool calls)
                 total_tokens_this_turn = int(p_tokens)
                 token_source = "provider_prompt"
+                self._last_turn_prompt_tokens = int(p_tokens)
+                self._last_turn_prompt_source = "usage_prompt"
             else:
                 # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
                 estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
                 estimated_completion = self._estimate_completion_tokens(response.content or "")
                 total_tokens_this_turn = estimated_prompt + estimated_completion
+                if estimated_prompt > 0:
+                    self._last_turn_prompt_tokens = int(estimated_prompt)
+                    self._last_turn_prompt_source = str(prompt_source or "tiktoken")
                 if total_tokens_this_turn > 0:
                     token_source = (
                         "tiktoken"
@@ -779,6 +807,12 @@ class AgentLoop:
             current_message=msg.content, channel=channel, chat_id=chat_id,
         )
         final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
+        if self._last_turn_prompt_tokens > 0:
+            session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
+            session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
+        else:
+            session.metadata.pop("_last_prompt_tokens", None)
+            session.metadata.pop("_last_prompt_source", None)
         self._save_turn(session, all_msgs, 1 + len(history))
         self.sessions.save(session)
         self._schedule_background_compression(session.key)
@@ -858,6 +892,13 @@ class AgentLoop:
         if final_content is None:
             final_content = "I've completed processing but have no response to give."
 
+        if self._last_turn_prompt_tokens > 0:
+            session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
+            session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
+        else:
+            session.metadata.pop("_last_prompt_tokens", None)
+            session.metadata.pop("_last_prompt_source", None)
+
         self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
         self.sessions.save(session)
         self._schedule_background_compression(session.key)
Reference in New Issue
Block a user