fix(compression): prefer provider prompt token usage

This commit is contained in:
VITOHJL
2026-03-08 17:25:59 +08:00
parent 1b16d48390
commit 274edc5451

View File

@@ -124,6 +124,8 @@ class AgentLoop:
self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
self._last_turn_prompt_tokens: int = 0
self._last_turn_prompt_source: str = "none"
self._processing_lock = asyncio.Lock()
self._register_default_tools()
@@ -324,7 +326,15 @@ class AgentLoop:
if target_threshold >= start_threshold:
target_threshold = max(0, start_threshold - 1)
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
# Prefer provider usage prompt tokens from the turn-ending call.
# If unavailable, fall back to estimator chain.
raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
current_tokens = int(raw_prompt_tokens)
token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
else:
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
current_ratio = current_tokens / budget if budget else 0.0
if current_tokens <= 0:
logger.debug("Compression skip {}: token estimate unavailable", session.key)
@@ -569,6 +579,8 @@ class AgentLoop:
tools_used: list[str] = []
total_tokens_this_turn = 0
token_source = "none"
self._last_turn_prompt_tokens = 0
self._last_turn_prompt_source = "none"
while iteration < self.max_iterations:
iteration += 1
@@ -594,19 +606,35 @@ class AgentLoop:
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
total_tokens_this_turn = int(t_tokens)
token_source = "provider_total"
if isinstance(p_tokens, (int, float)) and p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(c_tokens, (int, float)):
prompt_derived = int(t_tokens) - int(c_tokens)
if prompt_derived > 0:
self._last_turn_prompt_tokens = prompt_derived
self._last_turn_prompt_source = "usage_total_minus_completion"
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
# If we have both prompt and completion tokens, sum them
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
token_source = "provider_sum"
if p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
total_tokens_this_turn = int(p_tokens)
token_source = "provider_prompt"
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
else:
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
estimated_completion = self._estimate_completion_tokens(response.content or "")
total_tokens_this_turn = estimated_prompt + estimated_completion
if estimated_prompt > 0:
self._last_turn_prompt_tokens = int(estimated_prompt)
self._last_turn_prompt_source = str(prompt_source or "tiktoken")
if total_tokens_this_turn > 0:
token_source = (
"tiktoken"
@@ -779,6 +807,12 @@ class AgentLoop:
current_message=msg.content, channel=channel, chat_id=chat_id,
)
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
self._schedule_background_compression(session.key)
@@ -858,6 +892,13 @@ class AgentLoop:
if final_content is None:
final_content = "I've completed processing but have no response to give."
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
self.sessions.save(session)
self._schedule_background_compression(session.key)