refactor: implement token-based context compression mechanism
Major changes: - Replace message-count-based memory window with token-budget-based compression - Add max_tokens_input, compression_start_ratio, compression_target_ratio config - Implement _maybe_compress_history() that triggers based on prompt token usage - Use _build_compressed_history_view() to provide compressed history to LLM - Refactor MemoryStore.consolidate() -> consolidate_chunk() for chunk-based compression - Remove last_consolidated from Session, use _compressed_until metadata instead - Add background compression scheduling to avoid blocking message processing Key improvements: - Compression now based on actual token usage, not arbitrary message counts - Better handling of long conversations with large context windows - Non-destructive compression: old messages remain in session, but excluded from prompt - Automatic compression when history exceeds configured token thresholds
This commit is contained in:
@@ -5,19 +5,24 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import weakref
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import tiktoken # type: ignore
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
tiktoken = None
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
@@ -55,8 +60,11 @@ class AgentLoop:
|
||||
max_iterations: int = 40,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 100,
|
||||
memory_window: int | None = None, # backward-compat only (unused)
|
||||
reasoning_effort: str | None = None,
|
||||
max_tokens_input: int = 128_000,
|
||||
compression_start_ratio: float = 0.7,
|
||||
compression_target_ratio: float = 0.4,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
@@ -74,9 +82,18 @@ class AgentLoop:
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
# max_tokens: per-call output token cap (maxTokensOutput in config)
|
||||
self.max_tokens = max_tokens
|
||||
# Keep legacy attribute for older call sites/tests; compression no longer uses it.
|
||||
self.memory_window = memory_window
|
||||
self.reasoning_effort = reasoning_effort
|
||||
# max_tokens_input: model native context window (maxTokensInput in config)
|
||||
self.max_tokens_input = max_tokens_input
|
||||
# Token-based compression watermarks (fractions of available input budget)
|
||||
self.compression_start_ratio = compression_start_ratio
|
||||
self.compression_target_ratio = compression_target_ratio
|
||||
# Reserve tokens for safety margin
|
||||
self._reserve_tokens = 1000
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@@ -105,18 +122,373 @@ class AgentLoop:
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self._register_default_tools()
|
||||
|
||||
@staticmethod
|
||||
def _estimate_prompt_tokens(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> int:
|
||||
"""Estimate prompt tokens with tiktoken (fallback only)."""
|
||||
if tiktoken is None:
|
||||
return 0
|
||||
|
||||
try:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
parts: list[str] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
txt = part.get("text", "")
|
||||
if txt:
|
||||
parts.append(txt)
|
||||
if tools:
|
||||
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||
return len(enc.encode("\n".join(parts)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _estimate_prompt_tokens_chain(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Unified prompt-token estimation: provider counter -> tiktoken."""
|
||||
provider_counter = getattr(self.provider, "estimate_prompt_tokens", None)
|
||||
if callable(provider_counter):
|
||||
try:
|
||||
tokens, source = provider_counter(messages, tools, self.model)
|
||||
if isinstance(tokens, (int, float)) and tokens > 0:
|
||||
return int(tokens), str(source or "provider_counter")
|
||||
except Exception:
|
||||
logger.debug("Provider token counter failed; fallback to tiktoken")
|
||||
|
||||
estimated = self._estimate_prompt_tokens(messages, tools)
|
||||
if estimated > 0:
|
||||
return int(estimated), "tiktoken"
|
||||
return 0, "none"
|
||||
|
||||
@staticmethod
|
||||
def _estimate_completion_tokens(content: str) -> int:
|
||||
"""Estimate completion tokens with tiktoken (fallback only)."""
|
||||
if tiktoken is None:
|
||||
return 0
|
||||
try:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
return len(enc.encode(content or ""))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _get_compressed_until(self, session: Session) -> int:
|
||||
"""Read/normalize compressed boundary and migrate old metadata format."""
|
||||
raw = session.metadata.get("_compressed_until", 0)
|
||||
try:
|
||||
compressed_until = int(raw)
|
||||
except (TypeError, ValueError):
|
||||
compressed_until = 0
|
||||
|
||||
if compressed_until <= 0:
|
||||
ranges = session.metadata.get("_compressed_ranges")
|
||||
if isinstance(ranges, list):
|
||||
inferred = 0
|
||||
for item in ranges:
|
||||
if not isinstance(item, (list, tuple)) or len(item) != 2:
|
||||
continue
|
||||
try:
|
||||
inferred = max(inferred, int(item[1]))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
compressed_until = inferred
|
||||
|
||||
compressed_until = max(0, min(compressed_until, len(session.messages)))
|
||||
session.metadata["_compressed_until"] = compressed_until
|
||||
# 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段
|
||||
session.metadata.pop("_compressed_ranges", None)
|
||||
session.metadata.pop("_cumulative_tokens", None)
|
||||
return compressed_until
|
||||
|
||||
def _set_compressed_until(self, session: Session, idx: int) -> None:
|
||||
"""Persist a contiguous compressed boundary."""
|
||||
session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages)))
|
||||
session.metadata.pop("_compressed_ranges", None)
|
||||
session.metadata.pop("_cumulative_tokens", None)
|
||||
|
||||
@staticmethod
|
||||
def _estimate_message_tokens(message: dict[str, Any]) -> int:
|
||||
"""Rough token estimate for a single persisted message."""
|
||||
content = message.get("content")
|
||||
parts: list[str] = []
|
||||
if isinstance(content, str):
|
||||
parts.append(content)
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
txt = part.get("text", "")
|
||||
if txt:
|
||||
parts.append(txt)
|
||||
else:
|
||||
parts.append(json.dumps(part, ensure_ascii=False))
|
||||
elif content is not None:
|
||||
parts.append(json.dumps(content, ensure_ascii=False))
|
||||
|
||||
for key in ("name", "tool_call_id"):
|
||||
val = message.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
parts.append(val)
|
||||
if message.get("tool_calls"):
|
||||
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||
|
||||
payload = "\n".join(parts)
|
||||
if not payload:
|
||||
return 1
|
||||
if tiktoken is not None:
|
||||
try:
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
return max(1, len(enc.encode(payload)))
|
||||
except Exception:
|
||||
pass
|
||||
return max(1, len(payload) // 4)
|
||||
|
||||
def _pick_compression_chunk_by_tokens(
|
||||
self,
|
||||
session: Session,
|
||||
reduction_tokens: int,
|
||||
*,
|
||||
tail_keep: int = 12,
|
||||
) -> tuple[int, int, int] | None:
|
||||
"""
|
||||
Pick one contiguous old chunk so its estimated size is roughly enough
|
||||
to reduce `reduction_tokens`.
|
||||
"""
|
||||
messages = session.messages
|
||||
start = self._get_compressed_until(session)
|
||||
if len(messages) - start <= tail_keep + 2:
|
||||
return None
|
||||
|
||||
end_limit = len(messages) - tail_keep
|
||||
if end_limit - start < 2:
|
||||
return None
|
||||
|
||||
target = max(1, reduction_tokens)
|
||||
end = start
|
||||
collected = 0
|
||||
while end < end_limit and collected < target:
|
||||
collected += self._estimate_message_tokens(messages[end])
|
||||
end += 1
|
||||
|
||||
if end - start < 2:
|
||||
end = min(end_limit, start + 2)
|
||||
collected = sum(self._estimate_message_tokens(m) for m in messages[start:end])
|
||||
if end - start < 2:
|
||||
return None
|
||||
return start, end, collected
|
||||
|
||||
def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||
"""
|
||||
Estimate current full prompt tokens for this session view
|
||||
(system + compressed history view + runtime/user placeholder + tools).
|
||||
"""
|
||||
history = self._build_compressed_history_view(session)
|
||||
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||
probe_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message="[token-probe]",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
)
|
||||
return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions())
|
||||
|
||||
async def _maybe_compress_history(
|
||||
self,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
End-of-turn policy:
|
||||
- Estimate current prompt usage from persisted session view.
|
||||
- If above start ratio, perform one best-effort compression chunk.
|
||||
"""
|
||||
if not session.messages:
|
||||
self._set_compressed_until(session, 0)
|
||||
return
|
||||
|
||||
budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens)
|
||||
start_threshold = int(budget * self.compression_start_ratio)
|
||||
target_threshold = int(budget * self.compression_target_ratio)
|
||||
if target_threshold >= start_threshold:
|
||||
target_threshold = max(0, start_threshold - 1)
|
||||
|
||||
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
|
||||
current_ratio = current_tokens / budget if budget else 0.0
|
||||
if current_tokens <= 0:
|
||||
logger.debug("Compression skip {}: token estimate unavailable", session.key)
|
||||
return
|
||||
if current_tokens < start_threshold:
|
||||
logger.debug(
|
||||
"Compression idle {}: {}/{} ({:.1%}) via {}",
|
||||
session.key,
|
||||
current_tokens,
|
||||
budget,
|
||||
current_ratio,
|
||||
token_source,
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"Compression trigger {}: {}/{} ({:.1%}) via {}",
|
||||
session.key,
|
||||
current_tokens,
|
||||
budget,
|
||||
current_ratio,
|
||||
token_source,
|
||||
)
|
||||
|
||||
reduction_by_target = max(0, current_tokens - target_threshold)
|
||||
reduction_by_delta = max(1, start_threshold - target_threshold)
|
||||
reduction_need = max(reduction_by_target, reduction_by_delta)
|
||||
|
||||
chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10)
|
||||
if chunk_range is None:
|
||||
logger.info("Compression skipped for {}: no compressible chunk", session.key)
|
||||
return
|
||||
|
||||
start_idx, end_idx, estimated_chunk_tokens = chunk_range
|
||||
chunk = session.messages[start_idx:end_idx]
|
||||
if len(chunk) < 2:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})",
|
||||
session.key,
|
||||
start_idx,
|
||||
end_idx - 1,
|
||||
len(chunk),
|
||||
estimated_chunk_tokens,
|
||||
reduction_need,
|
||||
)
|
||||
success, _ = await self.context.memory.consolidate_chunk(
|
||||
chunk,
|
||||
self.provider,
|
||||
self.model,
|
||||
)
|
||||
if not success:
|
||||
logger.warning("Compression aborted for {}: consolidation failed", session.key)
|
||||
return
|
||||
|
||||
self._set_compressed_until(session, end_idx)
|
||||
self.sessions.save(session)
|
||||
|
||||
after_tokens, after_source = self._estimate_session_prompt_tokens(session)
|
||||
after_ratio = after_tokens / budget if budget else 0.0
|
||||
reduced = max(0, current_tokens - after_tokens)
|
||||
reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0
|
||||
logger.info(
|
||||
"Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})",
|
||||
session.key,
|
||||
after_tokens,
|
||||
budget,
|
||||
after_ratio,
|
||||
after_source,
|
||||
reduced,
|
||||
reduced_ratio,
|
||||
)
|
||||
|
||||
def _schedule_background_compression(self, session_key: str) -> None:
|
||||
"""Schedule best-effort background compression for a session."""
|
||||
existing = self._compression_tasks.get(session_key)
|
||||
if existing is not None and not existing.done():
|
||||
return
|
||||
|
||||
async def _runner() -> None:
|
||||
session = self.sessions.get_or_create(session_key)
|
||||
try:
|
||||
await self._maybe_compress_history(session)
|
||||
except Exception:
|
||||
logger.exception("Background compression failed for {}", session_key)
|
||||
|
||||
task = asyncio.create_task(_runner())
|
||||
self._compression_tasks[session_key] = task
|
||||
|
||||
def _cleanup(t: asyncio.Task) -> None:
|
||||
cur = self._compression_tasks.get(session_key)
|
||||
if cur is t:
|
||||
self._compression_tasks.pop(session_key, None)
|
||||
try:
|
||||
t.result()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
task.add_done_callback(_cleanup)
|
||||
|
||||
async def wait_for_background_compression(self, timeout_s: float | None = None) -> None:
|
||||
"""Wait for currently scheduled compression tasks."""
|
||||
pending = [t for t in self._compression_tasks.values() if not t.done()]
|
||||
if not pending:
|
||||
return
|
||||
|
||||
logger.info("Waiting for {} background compression task(s)", len(pending))
|
||||
waiter = asyncio.gather(*pending, return_exceptions=True)
|
||||
if timeout_s is None:
|
||||
await waiter
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(waiter, timeout=timeout_s)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Background compression wait timed out after {}s ({} task(s) still running)",
|
||||
timeout_s,
|
||||
len([t for t in self._compression_tasks.values() if not t.done()]),
|
||||
)
|
||||
|
||||
def _build_compressed_history_view(
|
||||
self,
|
||||
session: Session,
|
||||
) -> list[dict]:
|
||||
"""Build non-destructive history view using the compressed boundary."""
|
||||
compressed_until = self._get_compressed_until(session)
|
||||
if compressed_until <= 0:
|
||||
return session.get_history(max_messages=0)
|
||||
|
||||
notice_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": (
|
||||
"As your assistant, I have compressed earlier context. "
|
||||
"If you need details, please check memory/HISTORY.md."
|
||||
),
|
||||
}
|
||||
|
||||
tail: list[dict[str, Any]] = []
|
||||
for msg in session.messages[compressed_until:]:
|
||||
entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")}
|
||||
for k in ("tool_calls", "tool_call_id", "name"):
|
||||
if k in msg:
|
||||
entry[k] = msg[k]
|
||||
tail.append(entry)
|
||||
|
||||
# Drop leading non-user entries from tail to avoid orphan tool blocks.
|
||||
for i, m in enumerate(tail):
|
||||
if m.get("role") == "user":
|
||||
tail = tail[i:]
|
||||
break
|
||||
else:
|
||||
tail = []
|
||||
|
||||
return [notice_msg, *tail]
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ValidateDeployJSONTool())
|
||||
self.tools.register(ValidateUsageYAMLTool())
|
||||
self.tools.register(HuggingFaceModelSearchTool())
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
@@ -181,25 +553,78 @@ class AgentLoop:
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||
) -> tuple[str | None, list[str], list[dict], int, str]:
|
||||
"""
|
||||
Run the agent iteration loop.
|
||||
|
||||
Returns:
|
||||
(final_content, tools_used, messages, total_tokens_this_turn, token_source)
|
||||
total_tokens_this_turn: total tokens (prompt + completion) for this turn
|
||||
token_source: provider_total / provider_sum / provider_prompt /
|
||||
provider_counter+tiktoken_completion / tiktoken / none
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
total_tokens_this_turn = 0
|
||||
token_source = "none"
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
tool_defs = self.tools.get_definitions()
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
)
|
||||
|
||||
# Prefer provider usage from the turn-ending model call; fallback to tiktoken.
|
||||
# Calculate total tokens (prompt + completion) for this turn.
|
||||
usage = response.usage or {}
|
||||
t_tokens = usage.get("total_tokens")
|
||||
p_tokens = usage.get("prompt_tokens")
|
||||
c_tokens = usage.get("completion_tokens")
|
||||
|
||||
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
|
||||
total_tokens_this_turn = int(t_tokens)
|
||||
token_source = "provider_total"
|
||||
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
|
||||
# If we have both prompt and completion tokens, sum them
|
||||
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
|
||||
token_source = "provider_sum"
|
||||
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
|
||||
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
|
||||
total_tokens_this_turn = int(p_tokens)
|
||||
token_source = "provider_prompt"
|
||||
else:
|
||||
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
|
||||
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
|
||||
estimated_completion = self._estimate_completion_tokens(response.content or "")
|
||||
total_tokens_this_turn = estimated_prompt + estimated_completion
|
||||
if total_tokens_this_turn > 0:
|
||||
token_source = (
|
||||
"tiktoken"
|
||||
if prompt_source == "tiktoken"
|
||||
else f"{prompt_source}+tiktoken_completion"
|
||||
)
|
||||
if total_tokens_this_turn <= 0:
|
||||
total_tokens_this_turn = 0
|
||||
token_source = "none"
|
||||
|
||||
logger.debug(
|
||||
"Turn token usage: source={}, total={}, prompt={}, completion={}",
|
||||
token_source,
|
||||
total_tokens_this_turn,
|
||||
p_tokens if isinstance(p_tokens, (int, float)) else None,
|
||||
c_tokens if isinstance(c_tokens, (int, float)) else None,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
@@ -254,7 +679,7 @@ class AgentLoop:
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
return final_content, tools_used, messages, total_tokens_this_turn, token_source
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
@@ -279,6 +704,9 @@ class AgentLoop:
|
||||
"""Cancel all active tasks and subagents for the session."""
|
||||
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||
comp = self._compression_tasks.get(msg.session_key)
|
||||
if comp is not None and not comp.done() and comp.cancel():
|
||||
cancelled += 1
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
@@ -325,6 +753,9 @@ class AgentLoop:
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
for task in list(self._compression_tasks.values()):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(
|
||||
@@ -342,14 +773,15 @@ class AgentLoop:
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
history = self._build_compressed_history_view(session)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
self._schedule_background_compression(session.key)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
@@ -362,27 +794,27 @@ class AgentLoop:
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
self._consolidating.add(session.key)
|
||||
try:
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if snapshot:
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
# 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中
|
||||
if session.messages:
|
||||
ok, _ = await self.context.memory.consolidate_chunk(
|
||||
session.messages,
|
||||
self.provider,
|
||||
self.model,
|
||||
)
|
||||
if not ok:
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
@@ -393,36 +825,23 @@ class AgentLoop:
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||
|
||||
unconsolidated = len(session.messages) - session.last_consolidated
|
||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||
self._consolidating.add(session.key)
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
_task = asyncio.current_task()
|
||||
if _task is not None:
|
||||
self._consolidation_tasks.discard(_task)
|
||||
|
||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||
self._consolidation_tasks.add(_task)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
# 正常对话:使用压缩后的历史视图(压缩在回合结束后进行)
|
||||
history = self._build_compressed_history_view(session)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
)
|
||||
# Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:")
|
||||
if session_key and session_key.startswith("cron:"):
|
||||
if initial_messages and initial_messages[0].get("role") == "system":
|
||||
initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}"
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
@@ -432,7 +851,7 @@ class AgentLoop:
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
)
|
||||
|
||||
@@ -441,6 +860,7 @@ class AgentLoop:
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
self._schedule_background_compression(session.key)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
@@ -487,13 +907,6 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
||||
return await MemoryStore(self.workspace).consolidate(
|
||||
session, self.provider, self.model,
|
||||
archive_all=archive_all, memory_window=self.memory_window,
|
||||
)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@@ -66,36 +66,25 @@ class MemoryStore:
|
||||
long_term = self.read_long_term()
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
async def consolidate(
|
||||
async def consolidate_chunk(
|
||||
self,
|
||||
session: Session,
|
||||
messages: list[dict],
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
*,
|
||||
archive_all: bool = False,
|
||||
memory_window: int = 50,
|
||||
) -> bool:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||
|
||||
Returns True on success (including no-op), False on failure.
|
||||
Returns (success, None).
|
||||
|
||||
- success: True on success (including no-op), False on failure.
|
||||
- The second return value is reserved for future use (e.g. RAG-style summaries) and is
|
||||
always None in the current implementation.
|
||||
"""
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
||||
else:
|
||||
keep_count = memory_window // 2
|
||||
if len(session.messages) <= keep_count:
|
||||
return True
|
||||
if len(session.messages) - session.last_consolidated <= 0:
|
||||
return True
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return True
|
||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
||||
if not messages:
|
||||
return True, None
|
||||
|
||||
lines = []
|
||||
for m in old_messages:
|
||||
for m in messages:
|
||||
if not m.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||
@@ -113,7 +102,19 @@ class MemoryStore:
|
||||
try:
|
||||
response = await provider.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a memory consolidation agent.\n"
|
||||
"Your job is to:\n"
|
||||
"1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n"
|
||||
" - Write 1 paragraph of 2–5 sentences that starts with [YYYY-MM-DD HH:MM].\n"
|
||||
" - Include concrete names, IDs and numbers so it is easy to search with grep.\n"
|
||||
"2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n"
|
||||
"3) Optionally return a short context_summary (1–3 sentences) that will replace the raw messages in future dialogue history.\n\n"
|
||||
"Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=_SAVE_MEMORY_TOOL,
|
||||
@@ -122,7 +123,7 @@ class MemoryStore:
|
||||
|
||||
if not response.has_tool_calls:
|
||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
args = response.tool_calls[0].arguments
|
||||
# Some providers return arguments as a JSON string instead of dict
|
||||
@@ -134,10 +135,10 @@ class MemoryStore:
|
||||
args = args[0]
|
||||
else:
|
||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
||||
return False
|
||||
return False, None
|
||||
if not isinstance(args, dict):
|
||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||
return False
|
||||
return False, None
|
||||
|
||||
if entry := args.get("history_entry"):
|
||||
if not isinstance(entry, str):
|
||||
@@ -149,9 +150,8 @@ class MemoryStore:
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||
return True
|
||||
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||
return True, None
|
||||
except Exception:
|
||||
logger.exception("Memory consolidation failed")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
@@ -189,11 +189,22 @@ class SlackConfig(Base):
|
||||
|
||||
|
||||
class QQConfig(Base):
|
||||
"""QQ channel configuration using botpy SDK."""
|
||||
"""QQ channel configuration.
|
||||
|
||||
Supports two implementations:
|
||||
1. Official botpy SDK: requires app_id and secret
|
||||
2. OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token)
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
# Official botpy SDK fields
|
||||
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
|
||||
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
||||
# OneBot protocol fields
|
||||
api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700")
|
||||
ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse")
|
||||
bot_qq: int | None = None # Bot's QQ number (for filtering self messages)
|
||||
access_token: str = "" # Optional access token for OneBot API
|
||||
allow_from: list[str] = Field(
|
||||
default_factory=list
|
||||
) # Allowed user openids (empty = public access)
|
||||
@@ -226,10 +237,18 @@ class AgentDefaults(Base):
|
||||
provider: str = (
|
||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
# 原生上下文最大窗口(通常对应模型的 max_input_tokens / max_context_tokens)
|
||||
# 默认按照主流大模型(如 GPT-4o、Claude 3.x 等)的 128k 上下文给一个宽松上限,实际应根据所选模型文档手动调整。
|
||||
max_tokens_input: int = 128_000
|
||||
# 默认单次回复的最大输出 token 上限(调用时可按需要再做截断或比例分配)
|
||||
# 8192 足以覆盖大多数实际对话/工具使用场景,同样可按需手动调整。
|
||||
max_tokens_output: int = 8192
|
||||
# 会话历史压缩触发比例:当估算的输入 token 使用量 >= maxTokensInput * compressionStartRatio 时开始压缩。
|
||||
compression_start_ratio: float = 0.7
|
||||
# 会话历史压缩目标比例:每轮压缩后尽量把估算的输入 token 使用量压到 maxTokensInput * compressionTargetRatio 附近。
|
||||
compression_target_ratio: float = 0.4
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
memory_window: int = 100
|
||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
|
||||
|
||||
@@ -30,7 +29,6 @@ class Session:
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||
"""Add a message to the session."""
|
||||
@@ -44,9 +42,13 @@ class Session:
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
"""
|
||||
Return messages for LLM input, aligned to a user turn.
|
||||
|
||||
- max_messages > 0 时只保留最近 max_messages 条;
|
||||
- max_messages <= 0 时不做条数截断,返回全部消息。
|
||||
"""
|
||||
sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:]
|
||||
|
||||
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
||||
for i, m in enumerate(sliced):
|
||||
@@ -66,7 +68,7 @@ class Session:
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages and reset session to initial state."""
|
||||
self.messages = []
|
||||
self.last_consolidated = 0
|
||||
self.metadata = {}
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
|
||||
@@ -80,7 +82,7 @@ class SessionManager:
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
||||
self.legacy_sessions_dir = get_legacy_sessions_dir()
|
||||
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
|
||||
self._cache: dict[str, Session] = {}
|
||||
|
||||
def _get_session_path(self, key: str) -> Path:
|
||||
@@ -132,7 +134,6 @@ class SessionManager:
|
||||
messages = []
|
||||
metadata = {}
|
||||
created_at = None
|
||||
last_consolidated = 0
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
@@ -145,7 +146,6 @@ class SessionManager:
|
||||
if data.get("_type") == "metadata":
|
||||
metadata = data.get("metadata", {})
|
||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||
last_consolidated = data.get("last_consolidated", 0)
|
||||
else:
|
||||
messages.append(data)
|
||||
|
||||
@@ -154,7 +154,6 @@ class SessionManager:
|
||||
messages=messages,
|
||||
created_at=created_at or datetime.now(),
|
||||
metadata=metadata,
|
||||
last_consolidated=last_consolidated
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load session {}: {}", key, e)
|
||||
@@ -171,7 +170,6 @@ class SessionManager:
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
"metadata": session.metadata,
|
||||
"last_consolidated": session.last_consolidated
|
||||
}
|
||||
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
||||
for msg in session.messages:
|
||||
|
||||
Reference in New Issue
Block a user