fix(memory): reserve completion headroom for consolidation

Trigger token-based consolidation before prompt usage fills the full context window: the budget now reserves room for completion tokens and a safety buffer for tokenizer estimation drift, so the LLM request stays safely within the model's context window.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-23 03:48:12 +00:00
committed by Xubin Ren
parent 8f5c2d1a06
commit aba0b83a77
3 changed files with 16 additions and 3 deletions

View File

@@ -115,6 +115,7 @@ class AgentLoop:
context_window_tokens=context_window_tokens, context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages, build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions, get_tool_definitions=self.tools.get_definitions,
max_completion_tokens=provider.generation.max_tokens,
) )
self._register_default_tools() self._register_default_tools()

View File

@@ -224,6 +224,8 @@ class MemoryConsolidator:
_MAX_CONSOLIDATION_ROUNDS = 5 _MAX_CONSOLIDATION_ROUNDS = 5
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
def __init__( def __init__(
self, self,
workspace: Path, workspace: Path,
@@ -233,12 +235,14 @@ class MemoryConsolidator:
context_window_tokens: int, context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]], build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]], get_tool_definitions: Callable[[], list[dict[str, Any]]],
max_completion_tokens: int = 4096,
): ):
self.store = MemoryStore(workspace) self.store = MemoryStore(workspace)
self.provider = provider self.provider = provider
self.model = model self.model = model
self.sessions = sessions self.sessions = sessions
self.context_window_tokens = context_window_tokens self.context_window_tokens = context_window_tokens
self.max_completion_tokens = max_completion_tokens
self._build_messages = build_messages self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
@@ -300,17 +304,22 @@ class MemoryConsolidator:
return True return True
async def maybe_consolidate_by_tokens(self, session: Session) -> None: async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within half the context window.""" """Loop: archive old messages until prompt fits within safe budget.
The budget reserves space for completion tokens and a safety buffer
so the LLM request never exceeds the context window.
"""
if not session.messages or self.context_window_tokens <= 0: if not session.messages or self.context_window_tokens <= 0:
return return
lock = self.get_lock(session.key) lock = self.get_lock(session.key)
async with lock: async with lock:
target = self.context_window_tokens // 2 budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
target = budget // 2
estimated, source = self.estimate_session_prompt_tokens(session) estimated, source = self.estimate_session_prompt_tokens(session)
if estimated <= 0: if estimated <= 0:
return return
if estimated < self.context_window_tokens: if estimated < budget:
logger.debug( logger.debug(
"Token consolidation idle {}: {}/{} via {}", "Token consolidation idle {}: {}/{} via {}",
session.key, session.key,

View File

@@ -9,8 +9,10 @@ from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop: def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
from nanobot.providers.base import GenerationSettings
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings(max_tokens=0)
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
_response = LLMResponse(content="ok", tool_calls=[]) _response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response) provider.chat_with_retry = AsyncMock(return_value=_response)
@@ -24,6 +26,7 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -
context_window_tokens=context_window_tokens, context_window_tokens=context_window_tokens,
) )
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
loop.memory_consolidator._SAFETY_BUFFER = 0
return loop return loop