refactor(memory): switch consolidation to token-based context windows
Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
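For orientation, a minimal sketch of the token-budget watermark policy this change centralizes (the same arithmetic appears in the removed _maybe_compress_history code below). The helper name should_compress is hypothetical and the defaults simply mirror the constructor values in the diff (max_tokens_input, max_tokens, _reserve_tokens, compression_start_ratio, compression_target_ratio); it is not the MemoryConsolidator implementation itself.

# Illustrative sketch only; mirrors the watermark math in the diff below.
def should_compress(prompt_tokens: int,
                    max_tokens_input: int = 128_000,
                    max_tokens_output: int = 4_096,
                    reserve_tokens: int = 1_000,
                    start_ratio: float = 0.7,
                    target_ratio: float = 0.4) -> tuple[bool, int]:
    """Return (trigger, tokens_to_reduce) for the current prompt size."""
    budget = max(1, max_tokens_input - max_tokens_output - reserve_tokens)
    start_threshold = int(budget * start_ratio)    # start compressing above this
    target_threshold = int(budget * target_ratio)  # compress back down toward this
    if prompt_tokens < start_threshold:
        return False, 0
    reduction = max(prompt_tokens - target_threshold, start_threshold - target_threshold)
    return True, reduction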
@@ -11,18 +11,12 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable

from loguru import logger

try:
    import tiktoken  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    tiktoken = None

from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
@@ -60,11 +54,8 @@ class AgentLoop:
        max_iterations: int = 40,
        temperature: float = 0.1,
        max_tokens: int = 4096,
        memory_window: int | None = None,  # backward-compat only (unused)
        reasoning_effort: str | None = None,
        max_tokens_input: int = 128_000,
        compression_start_ratio: float = 0.7,
        compression_target_ratio: float = 0.4,
        context_window_tokens: int = 65_536,
        brave_api_key: str | None = None,
        web_proxy: str | None = None,
        exec_config: ExecToolConfig | None = None,
@@ -82,18 +73,9 @@ class AgentLoop:
        self.model = model or provider.get_default_model()
        self.max_iterations = max_iterations
        self.temperature = temperature
        # max_tokens: per-call output token cap (maxTokensOutput in config)
        self.max_tokens = max_tokens
        # Keep legacy attribute for older call sites/tests; compression no longer uses it.
        self.memory_window = memory_window
        self.reasoning_effort = reasoning_effort
        # max_tokens_input: model native context window (maxTokensInput in config)
        self.max_tokens_input = max_tokens_input
        # Token-based compression watermarks (fractions of available input budget)
        self.compression_start_ratio = compression_start_ratio
        self.compression_target_ratio = compression_target_ratio
        # Reserve tokens for safety margin
        self._reserve_tokens = 1000
        self.context_window_tokens = context_window_tokens
        self.brave_api_key = brave_api_key
        self.web_proxy = web_proxy
        self.exec_config = exec_config or ExecToolConfig()
@@ -123,382 +105,23 @@ class AgentLoop:
        self._mcp_connected = False
        self._mcp_connecting = False
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
        self._compression_tasks: dict[str, asyncio.Task] = {}  # session_key -> task
        self._last_turn_prompt_tokens: int = 0
        self._last_turn_prompt_source: str = "none"
        self._processing_lock = asyncio.Lock()
        self.memory_consolidator = MemoryConsolidator(
            workspace=workspace,
            provider=provider,
            model=self.model,
            sessions=self.sessions,
            context_window_tokens=context_window_tokens,
            build_messages=self.context.build_messages,
            get_tool_definitions=self.tools.get_definitions,
        )
        self._register_default_tools()

    @staticmethod
    def _estimate_prompt_tokens(
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
    ) -> int:
        """Estimate prompt tokens with tiktoken (fallback only)."""
        if tiktoken is None:
            return 0

        try:
            enc = tiktoken.get_encoding("cl100k_base")
            parts: list[str] = []
            for msg in messages:
                content = msg.get("content")
                if isinstance(content, str):
                    parts.append(content)
                elif isinstance(content, list):
                    for part in content:
                        if isinstance(part, dict) and part.get("type") == "text":
                            txt = part.get("text", "")
                            if txt:
                                parts.append(txt)
            if tools:
                parts.append(json.dumps(tools, ensure_ascii=False))
            return len(enc.encode("\n".join(parts)))
        except Exception:
            return 0

    def _estimate_prompt_tokens_chain(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
    ) -> tuple[int, str]:
        """Unified prompt-token estimation: provider counter -> tiktoken."""
        provider_counter = getattr(self.provider, "estimate_prompt_tokens", None)
        if callable(provider_counter):
            try:
                tokens, source = provider_counter(messages, tools, self.model)
                if isinstance(tokens, (int, float)) and tokens > 0:
                    return int(tokens), str(source or "provider_counter")
            except Exception:
                logger.debug("Provider token counter failed; fallback to tiktoken")

        estimated = self._estimate_prompt_tokens(messages, tools)
        if estimated > 0:
            return int(estimated), "tiktoken"
        return 0, "none"

    @staticmethod
    def _estimate_completion_tokens(content: str) -> int:
        """Estimate completion tokens with tiktoken (fallback only)."""
        if tiktoken is None:
            return 0
        try:
            enc = tiktoken.get_encoding("cl100k_base")
            return len(enc.encode(content or ""))
        except Exception:
            return 0

    def _get_compressed_until(self, session: Session) -> int:
        """Read/normalize compressed boundary and migrate old metadata format."""
        raw = session.metadata.get("_compressed_until", 0)
        try:
            compressed_until = int(raw)
        except (TypeError, ValueError):
            compressed_until = 0

        if compressed_until <= 0:
            ranges = session.metadata.get("_compressed_ranges")
            if isinstance(ranges, list):
                inferred = 0
                for item in ranges:
                    if not isinstance(item, (list, tuple)) or len(item) != 2:
                        continue
                    try:
                        inferred = max(inferred, int(item[1]))
                    except (TypeError, ValueError):
                        continue
                compressed_until = inferred

        compressed_until = max(0, min(compressed_until, len(session.messages)))
        session.metadata["_compressed_until"] = compressed_until
        # Backward compatibility: once a contiguous boundary has been migrated, the legacy field can be dropped.
        session.metadata.pop("_compressed_ranges", None)
        # Note: do not delete _cumulative_tokens; compression needs it to track the cumulative token count.
        return compressed_until

    def _set_compressed_until(self, session: Session, idx: int) -> None:
        """Persist a contiguous compressed boundary."""
        session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages)))
        session.metadata.pop("_compressed_ranges", None)
        # Note: do not delete _cumulative_tokens; compression needs it to track the cumulative token count.

    @staticmethod
    def _estimate_message_tokens(message: dict[str, Any]) -> int:
        """Rough token estimate for a single persisted message."""
        content = message.get("content")
        parts: list[str] = []
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    txt = part.get("text", "")
                    if txt:
                        parts.append(txt)
                else:
                    parts.append(json.dumps(part, ensure_ascii=False))
        elif content is not None:
            parts.append(json.dumps(content, ensure_ascii=False))

        for key in ("name", "tool_call_id"):
            val = message.get(key)
            if isinstance(val, str) and val:
                parts.append(val)
        if message.get("tool_calls"):
            parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))

        payload = "\n".join(parts)
        if not payload:
            return 1
        if tiktoken is not None:
            try:
                enc = tiktoken.get_encoding("cl100k_base")
                return max(1, len(enc.encode(payload)))
            except Exception:
                pass
        return max(1, len(payload) // 4)

    def _pick_compression_chunk_by_tokens(
        self,
        session: Session,
        reduction_tokens: int,
        *,
        tail_keep: int = 12,
    ) -> tuple[int, int, int] | None:
        """
        Pick one contiguous old chunk so its estimated size is roughly enough
        to reduce `reduction_tokens`.
        """
        messages = session.messages
        start = self._get_compressed_until(session)
        if len(messages) - start <= tail_keep + 2:
            return None

        end_limit = len(messages) - tail_keep
        if end_limit - start < 2:
            return None

        target = max(1, reduction_tokens)
        end = start
        collected = 0
        while end < end_limit and collected < target:
            collected += self._estimate_message_tokens(messages[end])
            end += 1

        if end - start < 2:
            end = min(end_limit, start + 2)
            collected = sum(self._estimate_message_tokens(m) for m in messages[start:end])
        if end - start < 2:
            return None
        return start, end, collected

    def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
        """
        Estimate current full prompt tokens for this session view
        (system + compressed history view + runtime/user placeholder + tools).
        """
        history = self._build_compressed_history_view(session)
        channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
        probe_messages = self.context.build_messages(
            history=history,
            current_message="[token-probe]",
            channel=channel,
            chat_id=chat_id,
        )
        return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions())

    async def _maybe_compress_history(
        self,
        session: Session,
    ) -> None:
        """
        End-of-turn policy:
        - Estimate current prompt usage from persisted session view.
        - If above start ratio, perform one best-effort compression chunk.
        """
        if not session.messages:
            self._set_compressed_until(session, 0)
            return

        budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens)
        start_threshold = int(budget * self.compression_start_ratio)
        target_threshold = int(budget * self.compression_target_ratio)
        if target_threshold >= start_threshold:
            target_threshold = max(0, start_threshold - 1)

        # Prefer provider usage prompt tokens from the turn-ending call.
        # If unavailable, fall back to estimator chain.
        raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
        if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
            current_tokens = int(raw_prompt_tokens)
            token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
        else:
            current_tokens, token_source = self._estimate_session_prompt_tokens(session)

        current_ratio = current_tokens / budget if budget else 0.0
        if current_tokens <= 0:
            logger.debug("Compression skip {}: token estimate unavailable", session.key)
            return
        if current_tokens < start_threshold:
            logger.debug(
                "Compression idle {}: {}/{} ({:.1%}) via {}",
                session.key,
                current_tokens,
                budget,
                current_ratio,
                token_source,
            )
            return
        logger.info(
            "Compression trigger {}: {}/{} ({:.1%}) via {}",
            session.key,
            current_tokens,
            budget,
            current_ratio,
            token_source,
        )

        reduction_by_target = max(0, current_tokens - target_threshold)
        reduction_by_delta = max(1, start_threshold - target_threshold)
        reduction_need = max(reduction_by_target, reduction_by_delta)

        chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10)
        if chunk_range is None:
            logger.info("Compression skipped for {}: no compressible chunk", session.key)
            return

        start_idx, end_idx, estimated_chunk_tokens = chunk_range
        chunk = session.messages[start_idx:end_idx]
        if len(chunk) < 2:
            return

        logger.info(
            "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})",
            session.key,
            start_idx,
            end_idx - 1,
            len(chunk),
            estimated_chunk_tokens,
            reduction_need,
        )
        success, _ = await self.context.memory.consolidate_chunk(
            chunk,
            self.provider,
            self.model,
        )
        if not success:
            logger.warning("Compression aborted for {}: consolidation failed", session.key)
            return

        self._set_compressed_until(session, end_idx)
        self.sessions.save(session)

        after_tokens, after_source = self._estimate_session_prompt_tokens(session)
        after_ratio = after_tokens / budget if budget else 0.0
        reduced = max(0, current_tokens - after_tokens)
        reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0
        logger.info(
            "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})",
            session.key,
            after_tokens,
            budget,
            after_ratio,
            after_source,
            reduced,
            reduced_ratio,
        )

    def _schedule_background_compression(self, session_key: str) -> None:
        """Schedule best-effort background compression for a session."""
        existing = self._compression_tasks.get(session_key)
        if existing is not None and not existing.done():
            return

        async def _runner() -> None:
            session = self.sessions.get_or_create(session_key)
            try:
                await self._maybe_compress_history(session)
            except Exception:
                logger.exception("Background compression failed for {}", session_key)

        task = asyncio.create_task(_runner())
        self._compression_tasks[session_key] = task

        def _cleanup(t: asyncio.Task) -> None:
            cur = self._compression_tasks.get(session_key)
            if cur is t:
                self._compression_tasks.pop(session_key, None)
            try:
                t.result()
            except BaseException:
                pass

        task.add_done_callback(_cleanup)

    async def wait_for_background_compression(self, timeout_s: float | None = None) -> None:
        """Wait for currently scheduled compression tasks."""
        pending = [t for t in self._compression_tasks.values() if not t.done()]
        if not pending:
            return

        logger.info("Waiting for {} background compression task(s)", len(pending))
        waiter = asyncio.gather(*pending, return_exceptions=True)
        if timeout_s is None:
            await waiter
            return

        try:
            await asyncio.wait_for(waiter, timeout=timeout_s)
        except asyncio.TimeoutError:
            logger.warning(
                "Background compression wait timed out after {}s ({} task(s) still running)",
                timeout_s,
                len([t for t in self._compression_tasks.values() if not t.done()]),
            )

    def _build_compressed_history_view(
        self,
        session: Session,
    ) -> list[dict]:
        """Build non-destructive history view using the compressed boundary."""
        compressed_until = self._get_compressed_until(session)
        if compressed_until <= 0:
            return session.get_history(max_messages=0)

        notice_msg: dict[str, Any] = {
            "role": "assistant",
            "content": (
                "As your assistant, I have compressed earlier context. "
                "If you need details, please check memory/HISTORY.md."
            ),
        }

        tail: list[dict[str, Any]] = []
        for msg in session.messages[compressed_until:]:
            entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")}
            for k in ("tool_calls", "tool_call_id", "name"):
                if k in msg:
                    entry[k] = msg[k]
            tail.append(entry)

        # Drop leading non-user entries from tail to avoid orphan tool blocks.
        for i, m in enumerate(tail):
            if m.get("role") == "user":
                tail = tail[i:]
                break
        else:
            tail = []

        return [notice_msg, *tail]

    def _register_default_tools(self) -> None:
        """Register the default set of tools."""
        allowed_dir = self.workspace if self.restrict_to_workspace else None
        for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
            self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
        self.tools.register(ValidateDeployJSONTool())
        self.tools.register(ValidateUsageYAMLTool())
        self.tools.register(HuggingFaceModelSearchTool())
        self.tools.register(ExecTool(
            working_dir=str(self.workspace),
            timeout=self.exec_config.timeout,
@@ -563,24 +186,12 @@ class AgentLoop:
        self,
        initial_messages: list[dict],
        on_progress: Callable[..., Awaitable[None]] | None = None,
    ) -> tuple[str | None, list[str], list[dict], int, str]:
        """
        Run the agent iteration loop.

        Returns:
            (final_content, tools_used, messages, total_tokens_this_turn, token_source)
            total_tokens_this_turn: total tokens (prompt + completion) for this turn
            token_source: provider_total / provider_sum / provider_prompt /
                provider_counter+tiktoken_completion / tiktoken / none
        """
    ) -> tuple[str | None, list[str], list[dict]]:
        """Run the agent iteration loop."""
        messages = initial_messages
        iteration = 0
        final_content = None
        tools_used: list[str] = []
        total_tokens_this_turn = 0
        token_source = "none"
        self._last_turn_prompt_tokens = 0
        self._last_turn_prompt_source = "none"

        while iteration < self.max_iterations:
            iteration += 1
@@ -596,63 +207,6 @@ class AgentLoop:
                reasoning_effort=self.reasoning_effort,
            )

            # Prefer provider usage from the turn-ending model call; fallback to tiktoken.
            # Calculate total tokens (prompt + completion) for this turn.
            usage = response.usage or {}
            t_tokens = usage.get("total_tokens")
            p_tokens = usage.get("prompt_tokens")
            c_tokens = usage.get("completion_tokens")

            if isinstance(t_tokens, (int, float)) and t_tokens > 0:
                total_tokens_this_turn = int(t_tokens)
                token_source = "provider_total"
                if isinstance(p_tokens, (int, float)) and p_tokens > 0:
                    self._last_turn_prompt_tokens = int(p_tokens)
                    self._last_turn_prompt_source = "usage_prompt"
                elif isinstance(c_tokens, (int, float)):
                    prompt_derived = int(t_tokens) - int(c_tokens)
                    if prompt_derived > 0:
                        self._last_turn_prompt_tokens = prompt_derived
                        self._last_turn_prompt_source = "usage_total_minus_completion"
            elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
                # If we have both prompt and completion tokens, sum them
                total_tokens_this_turn = int(p_tokens) + int(c_tokens)
                token_source = "provider_sum"
                if p_tokens > 0:
                    self._last_turn_prompt_tokens = int(p_tokens)
                    self._last_turn_prompt_source = "usage_prompt"
            elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
                # Fallback: use prompt tokens only (completion might be 0 for tool calls)
                total_tokens_this_turn = int(p_tokens)
                token_source = "provider_prompt"
                self._last_turn_prompt_tokens = int(p_tokens)
                self._last_turn_prompt_source = "usage_prompt"
            else:
                # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
                estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
                estimated_completion = self._estimate_completion_tokens(response.content or "")
                total_tokens_this_turn = estimated_prompt + estimated_completion
                if estimated_prompt > 0:
                    self._last_turn_prompt_tokens = int(estimated_prompt)
                    self._last_turn_prompt_source = str(prompt_source or "tiktoken")
                if total_tokens_this_turn > 0:
                    token_source = (
                        "tiktoken"
                        if prompt_source == "tiktoken"
                        else f"{prompt_source}+tiktoken_completion"
                    )
            if total_tokens_this_turn <= 0:
                total_tokens_this_turn = 0
                token_source = "none"

            logger.debug(
                "Turn token usage: source={}, total={}, prompt={}, completion={}",
                token_source,
                total_tokens_this_turn,
                p_tokens if isinstance(p_tokens, (int, float)) else None,
                c_tokens if isinstance(c_tokens, (int, float)) else None,
            )

            if response.has_tool_calls:
                if on_progress:
                    thought = self._strip_think(response.content)
@@ -707,7 +261,7 @@ class AgentLoop:
                "without completing the task. You can try breaking the task into smaller steps."
            )

        return final_content, tools_used, messages, total_tokens_this_turn, token_source
        return final_content, tools_used, messages

    async def run(self) -> None:
        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@@ -732,9 +286,6 @@ class AgentLoop:
        """Cancel all active tasks and subagents for the session."""
        tasks = self._active_tasks.pop(msg.session_key, [])
        cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
        comp = self._compression_tasks.get(msg.session_key)
        if comp is not None and not comp.done() and comp.cancel():
            cancelled += 1
        for t in tasks:
            try:
                await t
@@ -781,9 +332,6 @@ class AgentLoop:
    def stop(self) -> None:
        """Stop the agent loop."""
        self._running = False
        for task in list(self._compression_tasks.values()):
            if not task.done():
                task.cancel()
        logger.info("Agent loop stopping")

    async def _process_message(
@@ -800,22 +348,17 @@ class AgentLoop:
        logger.info("Processing system message from {}", msg.sender_id)
        key = f"{channel}:{chat_id}"
        session = self.sessions.get_or_create(key)
        await self.memory_consolidator.maybe_consolidate_by_tokens(session)
        self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
        history = self._build_compressed_history_view(session)
        history = session.get_history(max_messages=0)
        messages = self.context.build_messages(
            history=history,
            current_message=msg.content, channel=channel, chat_id=chat_id,
        )
        final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
        if self._last_turn_prompt_tokens > 0:
            session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
            session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
        else:
            session.metadata.pop("_last_prompt_tokens", None)
            session.metadata.pop("_last_prompt_source", None)
        final_content, _, all_msgs = await self._run_agent_loop(messages)
        self._save_turn(session, all_msgs, 1 + len(history))
        self.sessions.save(session)
        self._schedule_background_compression(session.key)
        await self.memory_consolidator.maybe_consolidate_by_tokens(session)
        return OutboundMessage(channel=channel, chat_id=chat_id,
                               content=final_content or "Background task completed.")

@@ -829,19 +372,12 @@ class AgentLoop:
        cmd = msg.content.strip().lower()
        if cmd == "/new":
            try:
                # Before clearing the session, archive the full current conversation into MEMORY/HISTORY.
                if session.messages:
                    ok, _ = await self.context.memory.consolidate_chunk(
                        session.messages,
                        self.provider,
                        self.model,
                if not await self.memory_consolidator.archive_unconsolidated(session):
                    return OutboundMessage(
                        channel=msg.channel,
                        chat_id=msg.chat_id,
                        content="Memory archival failed, session not cleared. Please try again.",
                    )
                if not ok:
                    return OutboundMessage(
                        channel=msg.channel,
                        chat_id=msg.chat_id,
                        content="Memory archival failed, session not cleared. Please try again.",
                    )
            except Exception:
                logger.exception("/new archival failed for {}", session.key)
                return OutboundMessage(
@@ -859,23 +395,20 @@ class AgentLoop:
            return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
                                   content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")

        await self.memory_consolidator.maybe_consolidate_by_tokens(session)

        self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
        if message_tool := self.tools.get("message"):
            if isinstance(message_tool, MessageTool):
                message_tool.start_turn()

        # Normal conversation: use the compressed history view (compression runs after the turn ends).
        history = self._build_compressed_history_view(session)
        history = session.get_history(max_messages=0)
        initial_messages = self.context.build_messages(
            history=history,
            current_message=msg.content,
            media=msg.media if msg.media else None,
            channel=msg.channel, chat_id=msg.chat_id,
        )
        # Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:")
        if session_key and session_key.startswith("cron:"):
            if initial_messages and initial_messages[0].get("role") == "system":
                initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}"

        async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
            meta = dict(msg.metadata or {})
@@ -885,23 +418,16 @@ class AgentLoop:
                channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
            ))

        final_content, _, all_msgs, total_tokens_this_turn, token_source = await self._run_agent_loop(
        final_content, _, all_msgs = await self._run_agent_loop(
            initial_messages, on_progress=on_progress or _bus_progress,
        )

        if final_content is None:
            final_content = "I've completed processing but have no response to give."

        if self._last_turn_prompt_tokens > 0:
            session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
            session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
        else:
            session.metadata.pop("_last_prompt_tokens", None)
            session.metadata.pop("_last_prompt_source", None)

        self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
        self._save_turn(session, all_msgs, 1 + len(history))
        self.sessions.save(session)
        self._schedule_background_compression(session.key)
        await self.memory_consolidator.maybe_consolidate_by_tokens(session)

        if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
            return None
@@ -913,7 +439,7 @@ class AgentLoop:
            metadata=msg.metadata or {},
        )

    def _save_turn(self, session: Session, messages: list[dict], skip: int, total_tokens_this_turn: int = 0) -> None:
    def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
        """Save new-turn messages into session, truncating large tool results."""
        from datetime import datetime
        for m in messages[skip:]:
@@ -947,14 +473,6 @@ class AgentLoop:
            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
        session.updated_at = datetime.now()

        # Update cumulative token count for compression tracking
        if total_tokens_this_turn > 0:
            current_cumulative = session.metadata.get("_cumulative_tokens", 0)
            if isinstance(current_cumulative, (int, float)):
                session.metadata["_cumulative_tokens"] = int(current_cumulative) + total_tokens_this_turn
            else:
                session.metadata["_cumulative_tokens"] = total_tokens_this_turn

    async def process_direct(
        self,