fix(compression): prefer provider prompt token usage
This commit is contained in:
@@ -124,6 +124,8 @@ class AgentLoop:
|
|||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
|
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
|
||||||
|
self._last_turn_prompt_tokens: int = 0
|
||||||
|
self._last_turn_prompt_source: str = "none"
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
@@ -324,7 +326,15 @@ class AgentLoop:
|
|||||||
if target_threshold >= start_threshold:
|
if target_threshold >= start_threshold:
|
||||||
target_threshold = max(0, start_threshold - 1)
|
target_threshold = max(0, start_threshold - 1)
|
||||||
|
|
||||||
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
|
# Prefer provider usage prompt tokens from the turn-ending call.
|
||||||
|
# If unavailable, fall back to estimator chain.
|
||||||
|
raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
|
||||||
|
if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
|
||||||
|
current_tokens = int(raw_prompt_tokens)
|
||||||
|
token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
|
||||||
|
else:
|
||||||
|
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
|
||||||
|
|
||||||
current_ratio = current_tokens / budget if budget else 0.0
|
current_ratio = current_tokens / budget if budget else 0.0
|
||||||
if current_tokens <= 0:
|
if current_tokens <= 0:
|
||||||
logger.debug("Compression skip {}: token estimate unavailable", session.key)
|
logger.debug("Compression skip {}: token estimate unavailable", session.key)
|
||||||
@@ -569,6 +579,8 @@ class AgentLoop:
|
|||||||
tools_used: list[str] = []
|
tools_used: list[str] = []
|
||||||
total_tokens_this_turn = 0
|
total_tokens_this_turn = 0
|
||||||
token_source = "none"
|
token_source = "none"
|
||||||
|
self._last_turn_prompt_tokens = 0
|
||||||
|
self._last_turn_prompt_source = "none"
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
while iteration < self.max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
@@ -594,19 +606,35 @@ class AgentLoop:
|
|||||||
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
|
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
|
||||||
total_tokens_this_turn = int(t_tokens)
|
total_tokens_this_turn = int(t_tokens)
|
||||||
token_source = "provider_total"
|
token_source = "provider_total"
|
||||||
|
if isinstance(p_tokens, (int, float)) and p_tokens > 0:
|
||||||
|
self._last_turn_prompt_tokens = int(p_tokens)
|
||||||
|
self._last_turn_prompt_source = "usage_prompt"
|
||||||
|
elif isinstance(c_tokens, (int, float)):
|
||||||
|
prompt_derived = int(t_tokens) - int(c_tokens)
|
||||||
|
if prompt_derived > 0:
|
||||||
|
self._last_turn_prompt_tokens = prompt_derived
|
||||||
|
self._last_turn_prompt_source = "usage_total_minus_completion"
|
||||||
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
|
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
|
||||||
# If we have both prompt and completion tokens, sum them
|
# If we have both prompt and completion tokens, sum them
|
||||||
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
|
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
|
||||||
token_source = "provider_sum"
|
token_source = "provider_sum"
|
||||||
|
if p_tokens > 0:
|
||||||
|
self._last_turn_prompt_tokens = int(p_tokens)
|
||||||
|
self._last_turn_prompt_source = "usage_prompt"
|
||||||
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
|
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
|
||||||
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
|
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
|
||||||
total_tokens_this_turn = int(p_tokens)
|
total_tokens_this_turn = int(p_tokens)
|
||||||
token_source = "provider_prompt"
|
token_source = "provider_prompt"
|
||||||
|
self._last_turn_prompt_tokens = int(p_tokens)
|
||||||
|
self._last_turn_prompt_source = "usage_prompt"
|
||||||
else:
|
else:
|
||||||
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
|
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
|
||||||
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
|
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
|
||||||
estimated_completion = self._estimate_completion_tokens(response.content or "")
|
estimated_completion = self._estimate_completion_tokens(response.content or "")
|
||||||
total_tokens_this_turn = estimated_prompt + estimated_completion
|
total_tokens_this_turn = estimated_prompt + estimated_completion
|
||||||
|
if estimated_prompt > 0:
|
||||||
|
self._last_turn_prompt_tokens = int(estimated_prompt)
|
||||||
|
self._last_turn_prompt_source = str(prompt_source or "tiktoken")
|
||||||
if total_tokens_this_turn > 0:
|
if total_tokens_this_turn > 0:
|
||||||
token_source = (
|
token_source = (
|
||||||
"tiktoken"
|
"tiktoken"
|
||||||
@@ -779,6 +807,12 @@ class AgentLoop:
|
|||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
|
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
|
||||||
|
if self._last_turn_prompt_tokens > 0:
|
||||||
|
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
|
||||||
|
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
|
||||||
|
else:
|
||||||
|
session.metadata.pop("_last_prompt_tokens", None)
|
||||||
|
session.metadata.pop("_last_prompt_source", None)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background_compression(session.key)
|
self._schedule_background_compression(session.key)
|
||||||
@@ -858,6 +892,13 @@ class AgentLoop:
|
|||||||
if final_content is None:
|
if final_content is None:
|
||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
|
if self._last_turn_prompt_tokens > 0:
|
||||||
|
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
|
||||||
|
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
|
||||||
|
else:
|
||||||
|
session.metadata.pop("_last_prompt_tokens", None)
|
||||||
|
session.metadata.pop("_last_prompt_source", None)
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
|
self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background_compression(session.key)
|
self._schedule_background_compression(session.key)
|
||||||
|
|||||||
Reference in New Issue
Block a user