refactor(memory): switch consolidation to token-based context windows

Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
This commit is contained in:
Re-bin
2026-03-10 19:55:06 +00:00
parent 4784eb4128
commit 62ccda43b9
13 changed files with 709 additions and 911 deletions

View File

@@ -11,18 +11,12 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
try:
import tiktoken # type: ignore
except Exception: # pragma: no cover - optional dependency
tiktoken = None
from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
@@ -60,11 +54,8 @@ class AgentLoop:
max_iterations: int = 40,
temperature: float = 0.1,
max_tokens: int = 4096,
memory_window: int | None = None, # backward-compat only (unused)
reasoning_effort: str | None = None,
max_tokens_input: int = 128_000,
compression_start_ratio: float = 0.7,
compression_target_ratio: float = 0.4,
context_window_tokens: int = 65_536,
brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None,
@@ -82,18 +73,9 @@ class AgentLoop:
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
self.temperature = temperature
# max_tokens: per-call output token cap (maxTokensOutput in config)
self.max_tokens = max_tokens
# Keep legacy attribute for older call sites/tests; compression no longer uses it.
self.memory_window = memory_window
self.reasoning_effort = reasoning_effort
# max_tokens_input: model native context window (maxTokensInput in config)
self.max_tokens_input = max_tokens_input
# Token-based compression watermarks (fractions of available input budget)
self.compression_start_ratio = compression_start_ratio
self.compression_target_ratio = compression_target_ratio
# Reserve tokens for safety margin
self._reserve_tokens = 1000
self.context_window_tokens = context_window_tokens
self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@@ -123,382 +105,23 @@ class AgentLoop:
self._mcp_connected = False
self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
self._last_turn_prompt_tokens: int = 0
self._last_turn_prompt_source: str = "none"
self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator(
workspace=workspace,
provider=provider,
model=self.model,
sessions=self.sessions,
context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions,
)
self._register_default_tools()
@staticmethod
def _estimate_prompt_tokens(
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> int:
    """Best-effort prompt token estimate via tiktoken.

    Returns 0 when tiktoken is unavailable or the encoding step fails,
    so callers can treat 0 as "no estimate available".
    """
    if tiktoken is None:
        return 0
    try:
        encoder = tiktoken.get_encoding("cl100k_base")
        text_parts: list[str] = []
        for msg in messages:
            body = msg.get("content")
            if isinstance(body, str):
                text_parts.append(body)
            elif isinstance(body, list):
                # Only plain text fragments contribute to the estimate.
                text_parts.extend(
                    piece["text"]
                    for piece in body
                    if isinstance(piece, dict)
                    and piece.get("type") == "text"
                    and piece.get("text")
                )
        if tools:
            text_parts.append(json.dumps(tools, ensure_ascii=False))
        return len(encoder.encode("\n".join(text_parts)))
    except Exception:
        return 0
def _estimate_prompt_tokens_chain(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> tuple[int, str]:
"""Unified prompt-token estimation: provider counter -> tiktoken."""
provider_counter = getattr(self.provider, "estimate_prompt_tokens", None)
if callable(provider_counter):
try:
tokens, source = provider_counter(messages, tools, self.model)
if isinstance(tokens, (int, float)) and tokens > 0:
return int(tokens), str(source or "provider_counter")
except Exception:
logger.debug("Provider token counter failed; fallback to tiktoken")
estimated = self._estimate_prompt_tokens(messages, tools)
if estimated > 0:
return int(estimated), "tiktoken"
return 0, "none"
@staticmethod
def _estimate_completion_tokens(content: str) -> int:
    """Best-effort completion token estimate via tiktoken.

    Returns 0 when tiktoken is unavailable or encoding fails.
    """
    if tiktoken is None:
        return 0
    try:
        return len(tiktoken.get_encoding("cl100k_base").encode(content or ""))
    except Exception:
        return 0
def _get_compressed_until(self, session: Session) -> int:
"""Read/normalize compressed boundary and migrate old metadata format."""
raw = session.metadata.get("_compressed_until", 0)
try:
compressed_until = int(raw)
except (TypeError, ValueError):
compressed_until = 0
if compressed_until <= 0:
ranges = session.metadata.get("_compressed_ranges")
if isinstance(ranges, list):
inferred = 0
for item in ranges:
if not isinstance(item, (list, tuple)) or len(item) != 2:
continue
try:
inferred = max(inferred, int(item[1]))
except (TypeError, ValueError):
continue
compressed_until = inferred
compressed_until = max(0, min(compressed_until, len(session.messages)))
session.metadata["_compressed_until"] = compressed_until
# 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段
session.metadata.pop("_compressed_ranges", None)
# 注意:不要删除 _cumulative_tokens压缩逻辑需要它来跟踪累积 token 计数
return compressed_until
def _set_compressed_until(self, session: Session, idx: int) -> None:
"""Persist a contiguous compressed boundary."""
session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages)))
session.metadata.pop("_compressed_ranges", None)
# 注意:不要删除 _cumulative_tokens压缩逻辑需要它来跟踪累积 token 计数
@staticmethod
def _estimate_message_tokens(message: dict[str, Any]) -> int:
"""Rough token estimate for a single persisted message."""
content = message.get("content")
parts: list[str] = []
if isinstance(content, str):
parts.append(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
txt = part.get("text", "")
if txt:
parts.append(txt)
else:
parts.append(json.dumps(part, ensure_ascii=False))
elif content is not None:
parts.append(json.dumps(content, ensure_ascii=False))
for key in ("name", "tool_call_id"):
val = message.get(key)
if isinstance(val, str) and val:
parts.append(val)
if message.get("tool_calls"):
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
payload = "\n".join(parts)
if not payload:
return 1
if tiktoken is not None:
try:
enc = tiktoken.get_encoding("cl100k_base")
return max(1, len(enc.encode(payload)))
except Exception:
pass
return max(1, len(payload) // 4)
def _pick_compression_chunk_by_tokens(
self,
session: Session,
reduction_tokens: int,
*,
tail_keep: int = 12,
) -> tuple[int, int, int] | None:
"""
Pick one contiguous old chunk so its estimated size is roughly enough
to reduce `reduction_tokens`.
"""
messages = session.messages
start = self._get_compressed_until(session)
if len(messages) - start <= tail_keep + 2:
return None
end_limit = len(messages) - tail_keep
if end_limit - start < 2:
return None
target = max(1, reduction_tokens)
end = start
collected = 0
while end < end_limit and collected < target:
collected += self._estimate_message_tokens(messages[end])
end += 1
if end - start < 2:
end = min(end_limit, start + 2)
collected = sum(self._estimate_message_tokens(m) for m in messages[start:end])
if end - start < 2:
return None
return start, end, collected
def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""
Estimate current full prompt tokens for this session view
(system + compressed history view + runtime/user placeholder + tools).
"""
history = self._build_compressed_history_view(session)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self.context.build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
)
return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions())
async def _maybe_compress_history(
self,
session: Session,
) -> None:
"""
End-of-turn compression policy (best-effort, at most one chunk per call):
- Compute the usable input budget and its start/target watermarks.
- Estimate current prompt usage from the persisted session view, preferring
  provider-reported usage of the turn-ending call over the estimator chain.
- If above the start watermark, consolidate one contiguous chunk of old
  messages into memory and advance the compressed boundary.
"""
# Empty session: reset the compressed boundary and bail out.
if not session.messages:
self._set_compressed_until(session, 0)
return
# Budget = model input window minus per-call output cap minus safety reserve.
budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens)
start_threshold = int(budget * self.compression_start_ratio)
target_threshold = int(budget * self.compression_target_ratio)
# Guarantee target < start so a triggered compression always has headroom.
if target_threshold >= start_threshold:
target_threshold = max(0, start_threshold - 1)
# Prefer provider usage prompt tokens from the turn-ending call.
# If unavailable, fall back to estimator chain.
raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
current_tokens = int(raw_prompt_tokens)
token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
else:
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
current_ratio = current_tokens / budget if budget else 0.0
# No usable estimate: skip rather than compress blindly.
if current_tokens <= 0:
logger.debug("Compression skip {}: token estimate unavailable", session.key)
return
# Below the start watermark: nothing to do this turn.
if current_tokens < start_threshold:
logger.debug(
"Compression idle {}: {}/{} ({:.1%}) via {}",
session.key,
current_tokens,
budget,
current_ratio,
token_source,
)
return
logger.info(
"Compression trigger {}: {}/{} ({:.1%}) via {}",
session.key,
current_tokens,
budget,
current_ratio,
token_source,
)
# Aim to reduce usage down to the target watermark, but at least by the
# start-target gap so a single pass makes measurable progress.
reduction_by_target = max(0, current_tokens - target_threshold)
reduction_by_delta = max(1, start_threshold - target_threshold)
reduction_need = max(reduction_by_target, reduction_by_delta)
chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10)
if chunk_range is None:
logger.info("Compression skipped for {}: no compressible chunk", session.key)
return
start_idx, end_idx, estimated_chunk_tokens = chunk_range
chunk = session.messages[start_idx:end_idx]
if len(chunk) < 2:
return
logger.info(
"Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})",
session.key,
start_idx,
end_idx - 1,
len(chunk),
estimated_chunk_tokens,
reduction_need,
)
# Summarize the chunk into long-term memory before hiding it from the view.
success, _ = await self.context.memory.consolidate_chunk(
chunk,
self.provider,
self.model,
)
if not success:
logger.warning("Compression aborted for {}: consolidation failed", session.key)
return
# Only advance the boundary (and persist) after successful consolidation.
self._set_compressed_until(session, end_idx)
self.sessions.save(session)
# Re-estimate after compression so the effect is observable in the logs.
after_tokens, after_source = self._estimate_session_prompt_tokens(session)
after_ratio = after_tokens / budget if budget else 0.0
reduced = max(0, current_tokens - after_tokens)
reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0
logger.info(
"Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})",
session.key,
after_tokens,
budget,
after_ratio,
after_source,
reduced,
reduced_ratio,
)
def _schedule_background_compression(self, session_key: str) -> None:
    """Kick off best-effort background compression for one session.

    No-op when a compression task for this session is already in flight.
    """
    running = self._compression_tasks.get(session_key)
    if running is not None and not running.done():
        return

    async def _compress() -> None:
        target = self.sessions.get_or_create(session_key)
        try:
            await self._maybe_compress_history(target)
        except Exception:
            logger.exception("Background compression failed for {}", session_key)

    task = asyncio.create_task(_compress())
    self._compression_tasks[session_key] = task

    def _forget(done: asyncio.Task) -> None:
        # Drop the bookkeeping entry only if it still points at this task.
        if self._compression_tasks.get(session_key) is done:
            self._compression_tasks.pop(session_key, None)
        try:
            # Retrieve the result so exceptions are not reported as unhandled.
            done.result()
        except BaseException:
            pass

    task.add_done_callback(_forget)
async def wait_for_background_compression(self, timeout_s: float | None = None) -> None:
    """Await the compression tasks scheduled so far.

    With ``timeout_s`` set, stops waiting after the deadline and logs how
    many tasks are still running; never raises on timeout.
    """
    outstanding = [task for task in self._compression_tasks.values() if not task.done()]
    if not outstanding:
        return
    logger.info("Waiting for {} background compression task(s)", len(outstanding))
    gathered = asyncio.gather(*outstanding, return_exceptions=True)
    if timeout_s is None:
        await gathered
        return
    try:
        await asyncio.wait_for(gathered, timeout=timeout_s)
    except asyncio.TimeoutError:
        still_running = [t for t in self._compression_tasks.values() if not t.done()]
        logger.warning(
            "Background compression wait timed out after {}s ({} task(s) still running)",
            timeout_s,
            len(still_running),
        )
def _build_compressed_history_view(
self,
session: Session,
) -> list[dict]:
"""Build non-destructive history view using the compressed boundary."""
compressed_until = self._get_compressed_until(session)
if compressed_until <= 0:
return session.get_history(max_messages=0)
notice_msg: dict[str, Any] = {
"role": "assistant",
"content": (
"As your assistant, I have compressed earlier context. "
"If you need details, please check memory/HISTORY.md."
),
}
tail: list[dict[str, Any]] = []
for msg in session.messages[compressed_until:]:
entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")}
for k in ("tool_calls", "tool_call_id", "name"):
if k in msg:
entry[k] = msg[k]
tail.append(entry)
# Drop leading non-user entries from tail to avoid orphan tool blocks.
for i, m in enumerate(tail):
if m.get("role") == "user":
tail = tail[i:]
break
else:
tail = []
return [notice_msg, *tail]
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = self.workspace if self.restrict_to_workspace else None
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(ValidateDeployJSONTool())
self.tools.register(ValidateUsageYAMLTool())
self.tools.register(HuggingFaceModelSearchTool())
self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
@@ -563,24 +186,12 @@ class AgentLoop:
self,
initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict], int, str]:
"""
Run the agent iteration loop.
Returns:
(final_content, tools_used, messages, total_tokens_this_turn, token_source)
total_tokens_this_turn: total tokens (prompt + completion) for this turn
token_source: provider_total / provider_sum / provider_prompt /
provider_counter+tiktoken_completion / tiktoken / none
"""
) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop."""
messages = initial_messages
iteration = 0
final_content = None
tools_used: list[str] = []
total_tokens_this_turn = 0
token_source = "none"
self._last_turn_prompt_tokens = 0
self._last_turn_prompt_source = "none"
while iteration < self.max_iterations:
iteration += 1
@@ -596,63 +207,6 @@ class AgentLoop:
reasoning_effort=self.reasoning_effort,
)
# Prefer provider usage from the turn-ending model call; fallback to tiktoken.
# Calculate total tokens (prompt + completion) for this turn.
usage = response.usage or {}
t_tokens = usage.get("total_tokens")
p_tokens = usage.get("prompt_tokens")
c_tokens = usage.get("completion_tokens")
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
total_tokens_this_turn = int(t_tokens)
token_source = "provider_total"
if isinstance(p_tokens, (int, float)) and p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(c_tokens, (int, float)):
prompt_derived = int(t_tokens) - int(c_tokens)
if prompt_derived > 0:
self._last_turn_prompt_tokens = prompt_derived
self._last_turn_prompt_source = "usage_total_minus_completion"
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
# If we have both prompt and completion tokens, sum them
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
token_source = "provider_sum"
if p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
total_tokens_this_turn = int(p_tokens)
token_source = "provider_prompt"
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
else:
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
estimated_completion = self._estimate_completion_tokens(response.content or "")
total_tokens_this_turn = estimated_prompt + estimated_completion
if estimated_prompt > 0:
self._last_turn_prompt_tokens = int(estimated_prompt)
self._last_turn_prompt_source = str(prompt_source or "tiktoken")
if total_tokens_this_turn > 0:
token_source = (
"tiktoken"
if prompt_source == "tiktoken"
else f"{prompt_source}+tiktoken_completion"
)
if total_tokens_this_turn <= 0:
total_tokens_this_turn = 0
token_source = "none"
logger.debug(
"Turn token usage: source={}, total={}, prompt={}, completion={}",
token_source,
total_tokens_this_turn,
p_tokens if isinstance(p_tokens, (int, float)) else None,
c_tokens if isinstance(c_tokens, (int, float)) else None,
)
if response.has_tool_calls:
if on_progress:
thought = self._strip_think(response.content)
@@ -707,7 +261,7 @@ class AgentLoop:
"without completing the task. You can try breaking the task into smaller steps."
)
return final_content, tools_used, messages, total_tokens_this_turn, token_source
return final_content, tools_used, messages
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@@ -732,9 +286,6 @@ class AgentLoop:
"""Cancel all active tasks and subagents for the session."""
tasks = self._active_tasks.pop(msg.session_key, [])
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
comp = self._compression_tasks.get(msg.session_key)
if comp is not None and not comp.done() and comp.cancel():
cancelled += 1
for t in tasks:
try:
await t
@@ -781,9 +332,6 @@ class AgentLoop:
def stop(self) -> None:
    """Stop the agent loop and cancel any in-flight compression tasks."""
    self._running = False
    for pending in list(self._compression_tasks.values()):
        if not pending.done():
            pending.cancel()
    logger.info("Agent loop stopping")
async def _process_message(
@@ -800,22 +348,17 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = self._build_compressed_history_view(session)
history = session.get_history(max_messages=0)
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
)
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
self._schedule_background_compression(session.key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.")
@@ -829,19 +372,12 @@ class AgentLoop:
cmd = msg.content.strip().lower()
if cmd == "/new":
try:
# 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中
if session.messages:
ok, _ = await self.context.memory.consolidate_chunk(
session.messages,
self.provider,
self.model,
if not await self.memory_consolidator.archive_unconsolidated(session):
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
if not ok:
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
)
except Exception:
logger.exception("/new archival failed for {}", session.key)
return OutboundMessage(
@@ -859,23 +395,20 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
# 正常对话:使用压缩后的历史视图(压缩在回合结束后进行)
history = self._build_compressed_history_view(session)
history = session.get_history(max_messages=0)
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id,
)
# Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:")
if session_key and session_key.startswith("cron:"):
if initial_messages and initial_messages[0].get("role") == "system":
initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}"
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {})
@@ -885,23 +418,16 @@ class AgentLoop:
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
final_content, _, all_msgs, total_tokens_this_turn, token_source = await self._run_agent_loop(
final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
)
if final_content is None:
final_content = "I've completed processing but have no response to give."
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
self._schedule_background_compression(session.key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
@@ -913,7 +439,7 @@ class AgentLoop:
metadata=msg.metadata or {},
)
def _save_turn(self, session: Session, messages: list[dict], skip: int, total_tokens_this_turn: int = 0) -> None:
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
@@ -947,14 +473,6 @@ class AgentLoop:
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
session.updated_at = datetime.now()
# Update cumulative token count for compression tracking
if total_tokens_this_turn > 0:
current_cumulative = session.metadata.get("_cumulative_tokens", 0)
if isinstance(current_cumulative, (int, float)):
session.metadata["_cumulative_tokens"] = int(current_cumulative) + total_tokens_this_turn
else:
session.metadata["_cumulative_tokens"] = total_tokens_this_turn
async def process_direct(
self,