From dbc518098e913d2f382121820dd58bbaf7a04234 Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 14:20:16 +0800 Subject: [PATCH] refactor: implement token-based context compression mechanism Major changes: - Replace message-count-based memory window with token-budget-based compression - Add max_tokens_input, compression_start_ratio, compression_target_ratio config - Implement _maybe_compress_history() that triggers based on prompt token usage - Use _build_compressed_history_view() to provide compressed history to LLM - Refactor MemoryStore.consolidate() -> consolidate_chunk() for chunk-based compression - Remove last_consolidated from Session, use _compressed_until metadata instead - Add background compression scheduling to avoid blocking message processing Key improvements: - Compression now based on actual token usage, not arbitrary message counts - Better handling of long conversations with large context windows - Non-destructive compression: old messages remain in session, but excluded from prompt - Automatic compression when history exceeds configured token thresholds --- nanobot/agent/loop.py | 521 +++++++++++++++++++++++++++++++++---- nanobot/agent/memory.py | 62 ++--- nanobot/config/schema.py | 25 +- nanobot/session/manager.py | 20 +- 4 files changed, 529 insertions(+), 99 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ca9a06e..696e2a7 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -5,19 +5,24 @@ from __future__ import annotations import asyncio import json import re -import weakref from contextlib import AsyncExitStack from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger +try: + import tiktoken # type: ignore +except Exception: # pragma: no cover - optional dependency + tiktoken = None + from nanobot.agent.context import ContextBuilder -from nanobot.agent.memory import MemoryStore from nanobot.agent.subagent import SubagentManager from 
nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.spawn import SpawnTool @@ -55,8 +60,11 @@ class AgentLoop: max_iterations: int = 40, temperature: float = 0.1, max_tokens: int = 4096, - memory_window: int = 100, + memory_window: int | None = None, # backward-compat only (unused) reasoning_effort: str | None = None, + max_tokens_input: int = 128_000, + compression_start_ratio: float = 0.7, + compression_target_ratio: float = 0.4, brave_api_key: str | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -74,9 +82,18 @@ class AgentLoop: self.model = model or provider.get_default_model() self.max_iterations = max_iterations self.temperature = temperature + # max_tokens: per-call output token cap (maxTokensOutput in config) self.max_tokens = max_tokens + # Keep legacy attribute for older call sites/tests; compression no longer uses it. 
self.memory_window = memory_window self.reasoning_effort = reasoning_effort + # max_tokens_input: model native context window (maxTokensInput in config) + self.max_tokens_input = max_tokens_input + # Token-based compression watermarks (fractions of available input budget) + self.compression_start_ratio = compression_start_ratio + self.compression_target_ratio = compression_target_ratio + # Reserve tokens for safety margin + self._reserve_tokens = 1000 self.brave_api_key = brave_api_key self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -105,18 +122,373 @@ class AgentLoop: self._mcp_stack: AsyncExitStack | None = None self._mcp_connected = False self._mcp_connecting = False - self._consolidating: set[str] = set() # Session keys with consolidation in progress - self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks - self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks + self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task self._processing_lock = asyncio.Lock() self._register_default_tools() + @staticmethod + def _estimate_prompt_tokens( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> int: + """Estimate prompt tokens with tiktoken (fallback only).""" + if tiktoken is None: + return 0 + + try: + enc = tiktoken.get_encoding("cl100k_base") + parts: list[str] = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + if tools: + parts.append(json.dumps(tools, ensure_ascii=False)) + return len(enc.encode("\n".join(parts))) + except Exception: + return 0 + + def _estimate_prompt_tokens_chain( + 
self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> tuple[int, str]: + """Unified prompt-token estimation: provider counter -> tiktoken.""" + provider_counter = getattr(self.provider, "estimate_prompt_tokens", None) + if callable(provider_counter): + try: + tokens, source = provider_counter(messages, tools, self.model) + if isinstance(tokens, (int, float)) and tokens > 0: + return int(tokens), str(source or "provider_counter") + except Exception: + logger.debug("Provider token counter failed; fallback to tiktoken") + + estimated = self._estimate_prompt_tokens(messages, tools) + if estimated > 0: + return int(estimated), "tiktoken" + return 0, "none" + + @staticmethod + def _estimate_completion_tokens(content: str) -> int: + """Estimate completion tokens with tiktoken (fallback only).""" + if tiktoken is None: + return 0 + try: + enc = tiktoken.get_encoding("cl100k_base") + return len(enc.encode(content or "")) + except Exception: + return 0 + + def _get_compressed_until(self, session: Session) -> int: + """Read/normalize compressed boundary and migrate old metadata format.""" + raw = session.metadata.get("_compressed_until", 0) + try: + compressed_until = int(raw) + except (TypeError, ValueError): + compressed_until = 0 + + if compressed_until <= 0: + ranges = session.metadata.get("_compressed_ranges") + if isinstance(ranges, list): + inferred = 0 + for item in ranges: + if not isinstance(item, (list, tuple)) or len(item) != 2: + continue + try: + inferred = max(inferred, int(item[1])) + except (TypeError, ValueError): + continue + compressed_until = inferred + + compressed_until = max(0, min(compressed_until, len(session.messages))) + session.metadata["_compressed_until"] = compressed_until + # 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段 + session.metadata.pop("_compressed_ranges", None) + session.metadata.pop("_cumulative_tokens", None) + return compressed_until + + def _set_compressed_until(self, session: Session, idx: int) -> None: + 
"""Persist a contiguous compressed boundary.""" + session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages))) + session.metadata.pop("_compressed_ranges", None) + session.metadata.pop("_cumulative_tokens", None) + + @staticmethod + def _estimate_message_tokens(message: dict[str, Any]) -> int: + """Rough token estimate for a single persisted message.""" + content = message.get("content") + parts: list[str] = [] + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + else: + parts.append(json.dumps(part, ensure_ascii=False)) + elif content is not None: + parts.append(json.dumps(content, ensure_ascii=False)) + + for key in ("name", "tool_call_id"): + val = message.get(key) + if isinstance(val, str) and val: + parts.append(val) + if message.get("tool_calls"): + parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) + + payload = "\n".join(parts) + if not payload: + return 1 + if tiktoken is not None: + try: + enc = tiktoken.get_encoding("cl100k_base") + return max(1, len(enc.encode(payload))) + except Exception: + pass + return max(1, len(payload) // 4) + + def _pick_compression_chunk_by_tokens( + self, + session: Session, + reduction_tokens: int, + *, + tail_keep: int = 12, + ) -> tuple[int, int, int] | None: + """ + Pick one contiguous old chunk so its estimated size is roughly enough + to reduce `reduction_tokens`. 
+ """ + messages = session.messages + start = self._get_compressed_until(session) + if len(messages) - start <= tail_keep + 2: + return None + + end_limit = len(messages) - tail_keep + if end_limit - start < 2: + return None + + target = max(1, reduction_tokens) + end = start + collected = 0 + while end < end_limit and collected < target: + collected += self._estimate_message_tokens(messages[end]) + end += 1 + + if end - start < 2: + end = min(end_limit, start + 2) + collected = sum(self._estimate_message_tokens(m) for m in messages[start:end]) + if end - start < 2: + return None + return start, end, collected + + def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: + """ + Estimate current full prompt tokens for this session view + (system + compressed history view + runtime/user placeholder + tools). + """ + history = self._build_compressed_history_view(session) + channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) + probe_messages = self.context.build_messages( + history=history, + current_message="[token-probe]", + channel=channel, + chat_id=chat_id, + ) + return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions()) + + async def _maybe_compress_history( + self, + session: Session, + ) -> None: + """ + End-of-turn policy: + - Estimate current prompt usage from persisted session view. + - If above start ratio, perform one best-effort compression chunk. 
+ """ + if not session.messages: + self._set_compressed_until(session, 0) + return + + budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens) + start_threshold = int(budget * self.compression_start_ratio) + target_threshold = int(budget * self.compression_target_ratio) + if target_threshold >= start_threshold: + target_threshold = max(0, start_threshold - 1) + + current_tokens, token_source = self._estimate_session_prompt_tokens(session) + current_ratio = current_tokens / budget if budget else 0.0 + if current_tokens <= 0: + logger.debug("Compression skip {}: token estimate unavailable", session.key) + return + if current_tokens < start_threshold: + logger.debug( + "Compression idle {}: {}/{} ({:.1%}) via {}", + session.key, + current_tokens, + budget, + current_ratio, + token_source, + ) + return + logger.info( + "Compression trigger {}: {}/{} ({:.1%}) via {}", + session.key, + current_tokens, + budget, + current_ratio, + token_source, + ) + + reduction_by_target = max(0, current_tokens - target_threshold) + reduction_by_delta = max(1, start_threshold - target_threshold) + reduction_need = max(reduction_by_target, reduction_by_delta) + + chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10) + if chunk_range is None: + logger.info("Compression skipped for {}: no compressible chunk", session.key) + return + + start_idx, end_idx, estimated_chunk_tokens = chunk_range + chunk = session.messages[start_idx:end_idx] + if len(chunk) < 2: + return + + logger.info( + "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})", + session.key, + start_idx, + end_idx - 1, + len(chunk), + estimated_chunk_tokens, + reduction_need, + ) + success, _ = await self.context.memory.consolidate_chunk( + chunk, + self.provider, + self.model, + ) + if not success: + logger.warning("Compression aborted for {}: consolidation failed", session.key) + return + + self._set_compressed_until(session, end_idx) + 
self.sessions.save(session) + + after_tokens, after_source = self._estimate_session_prompt_tokens(session) + after_ratio = after_tokens / budget if budget else 0.0 + reduced = max(0, current_tokens - after_tokens) + reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0 + logger.info( + "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})", + session.key, + after_tokens, + budget, + after_ratio, + after_source, + reduced, + reduced_ratio, + ) + + def _schedule_background_compression(self, session_key: str) -> None: + """Schedule best-effort background compression for a session.""" + existing = self._compression_tasks.get(session_key) + if existing is not None and not existing.done(): + return + + async def _runner() -> None: + session = self.sessions.get_or_create(session_key) + try: + await self._maybe_compress_history(session) + except Exception: + logger.exception("Background compression failed for {}", session_key) + + task = asyncio.create_task(_runner()) + self._compression_tasks[session_key] = task + + def _cleanup(t: asyncio.Task) -> None: + cur = self._compression_tasks.get(session_key) + if cur is t: + self._compression_tasks.pop(session_key, None) + try: + t.result() + except BaseException: + pass + + task.add_done_callback(_cleanup) + + async def wait_for_background_compression(self, timeout_s: float | None = None) -> None: + """Wait for currently scheduled compression tasks.""" + pending = [t for t in self._compression_tasks.values() if not t.done()] + if not pending: + return + + logger.info("Waiting for {} background compression task(s)", len(pending)) + waiter = asyncio.gather(*pending, return_exceptions=True) + if timeout_s is None: + await waiter + return + + try: + await asyncio.wait_for(waiter, timeout=timeout_s) + except asyncio.TimeoutError: + logger.warning( + "Background compression wait timed out after {}s ({} task(s) still running)", + timeout_s, + len([t for t in self._compression_tasks.values() if not 
t.done()]), + ) + + def _build_compressed_history_view( + self, + session: Session, + ) -> list[dict]: + """Build non-destructive history view using the compressed boundary.""" + compressed_until = self._get_compressed_until(session) + if compressed_until <= 0: + return session.get_history(max_messages=0) + + notice_msg: dict[str, Any] = { + "role": "assistant", + "content": ( + "As your assistant, I have compressed earlier context. " + "If you need details, please check memory/HISTORY.md." + ), + } + + tail: list[dict[str, Any]] = [] + for msg in session.messages[compressed_until:]: + entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")} + for k in ("tool_calls", "tool_call_id", "name"): + if k in msg: + entry[k] = msg[k] + tail.append(entry) + + # Drop leading non-user entries from tail to avoid orphan tool blocks. + for i, m in enumerate(tail): + if m.get("role") == "user": + tail = tail[i:] + break + else: + tail = [] + + return [notice_msg, *tail] + def _register_default_tools(self) -> None: """Register the default set of tools.""" allowed_dir = self.workspace if self.restrict_to_workspace else None for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) + self.tools.register(ValidateDeployJSONTool()) + self.tools.register(ValidateUsageYAMLTool()) + self.tools.register(HuggingFaceModelSearchTool()) self.tools.register(ExecTool( working_dir=str(self.workspace), timeout=self.exec_config.timeout, @@ -181,25 +553,78 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, - ) -> tuple[str | None, list[str], list[dict]]: - """Run the agent iteration loop. Returns (final_content, tools_used, messages).""" + ) -> tuple[str | None, list[str], list[dict], int, str]: + """ + Run the agent iteration loop. 
+ + Returns: + (final_content, tools_used, messages, total_tokens_this_turn, token_source) + total_tokens_this_turn: total tokens (prompt + completion) for this turn + token_source: provider_total / provider_sum / provider_prompt / + provider_counter+tiktoken_completion / tiktoken / none + """ messages = initial_messages iteration = 0 final_content = None tools_used: list[str] = [] + total_tokens_this_turn = 0 + token_source = "none" while iteration < self.max_iterations: iteration += 1 + tool_defs = self.tools.get_definitions() + response = await self.provider.chat( messages=messages, - tools=self.tools.get_definitions(), + tools=tool_defs, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, reasoning_effort=self.reasoning_effort, ) + # Prefer provider usage from the turn-ending model call; fallback to tiktoken. + # Calculate total tokens (prompt + completion) for this turn. + usage = response.usage or {} + t_tokens = usage.get("total_tokens") + p_tokens = usage.get("prompt_tokens") + c_tokens = usage.get("completion_tokens") + + if isinstance(t_tokens, (int, float)) and t_tokens > 0: + total_tokens_this_turn = int(t_tokens) + token_source = "provider_total" + elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)): + # If we have both prompt and completion tokens, sum them + total_tokens_this_turn = int(p_tokens) + int(c_tokens) + token_source = "provider_sum" + elif isinstance(p_tokens, (int, float)) and p_tokens > 0: + # Fallback: use prompt tokens only (completion might be 0 for tool calls) + total_tokens_this_turn = int(p_tokens) + token_source = "provider_prompt" + else: + # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken. 
+ estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs) + estimated_completion = self._estimate_completion_tokens(response.content or "") + total_tokens_this_turn = estimated_prompt + estimated_completion + if total_tokens_this_turn > 0: + token_source = ( + "tiktoken" + if prompt_source == "tiktoken" + else f"{prompt_source}+tiktoken_completion" + ) + if total_tokens_this_turn <= 0: + total_tokens_this_turn = 0 + token_source = "none" + + logger.debug( + "Turn token usage: source={}, total={}, prompt={}, completion={}", + token_source, + total_tokens_this_turn, + p_tokens if isinstance(p_tokens, (int, float)) else None, + c_tokens if isinstance(c_tokens, (int, float)) else None, + ) + if response.has_tool_calls: if on_progress: thought = self._strip_think(response.content) @@ -254,7 +679,7 @@ class AgentLoop: "without completing the task. You can try breaking the task into smaller steps." ) - return final_content, tools_used, messages + return final_content, tools_used, messages, total_tokens_this_turn, token_source async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -279,6 +704,9 @@ class AgentLoop: """Cancel all active tasks and subagents for the session.""" tasks = self._active_tasks.pop(msg.session_key, []) cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) + comp = self._compression_tasks.get(msg.session_key) + if comp is not None and not comp.done() and comp.cancel(): + cancelled += 1 for t in tasks: try: await t @@ -325,6 +753,9 @@ class AgentLoop: def stop(self) -> None: """Stop the agent loop.""" self._running = False + for task in list(self._compression_tasks.values()): + if not task.done(): + task.cancel() logger.info("Agent loop stopping") async def _process_message( @@ -342,14 +773,15 @@ class AgentLoop: key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) self._set_tool_context(channel, chat_id, 
msg.metadata.get("message_id")) - history = session.get_history(max_messages=self.memory_window) + history = self._build_compressed_history_view(session) messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, ) - final_content, _, all_msgs = await self._run_agent_loop(messages) + final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + self._schedule_background_compression(session.key) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -362,27 +794,27 @@ class AgentLoop: # Slash commands cmd = msg.content.strip().lower() if cmd == "/new": - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - self._consolidating.add(session.key) try: - async with lock: - snapshot = session.messages[session.last_consolidated:] - if snapshot: - temp = Session(key=session.key) - temp.messages = list(snapshot) - if not await self._consolidate_memory(temp, archive_all=True): - return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Memory archival failed, session not cleared. Please try again.", - ) + # 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中 + if session.messages: + ok, _ = await self.context.memory.consolidate_chunk( + session.messages, + self.provider, + self.model, + ) + if not ok: + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. Please try again.", + ) except Exception: logger.exception("/new archival failed for {}", session.key) return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, content="Memory archival failed, session not cleared. 
Please try again.", ) - finally: - self._consolidating.discard(session.key) session.clear() self.sessions.save(session) @@ -393,36 +825,23 @@ class AgentLoop: return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") - unconsolidated = len(session.messages) - session.last_consolidated - if (unconsolidated >= self.memory_window and session.key not in self._consolidating): - self._consolidating.add(session.key) - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - - async def _consolidate_and_unlock(): - try: - async with lock: - await self._consolidate_memory(session) - finally: - self._consolidating.discard(session.key) - _task = asyncio.current_task() - if _task is not None: - self._consolidation_tasks.discard(_task) - - _task = asyncio.create_task(_consolidate_and_unlock()) - self._consolidation_tasks.add(_task) - self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): if isinstance(message_tool, MessageTool): message_tool.start_turn() - history = session.get_history(max_messages=self.memory_window) + # 正常对话:使用压缩后的历史视图(压缩在回合结束后进行) + history = self._build_compressed_history_view(session) initial_messages = self.context.build_messages( history=history, current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, ) + # Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:") + if session_key and session_key.startswith("cron:"): + if initial_messages and initial_messages[0].get("role") == "system": + initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}" async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) @@ -432,7 +851,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, 
content=content, metadata=meta, )) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, _, _ = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, ) @@ -441,6 +860,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + self._schedule_background_compression(session.key) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None @@ -487,13 +907,6 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() - async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: - """Delegate to MemoryStore.consolidate(). Returns True on success.""" - return await MemoryStore(self.workspace).consolidate( - session, self.provider, self.model, - archive_all=archive_all, memory_window=self.memory_window, - ) - async def process_direct( self, content: str, diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 21fe77d..c8896c8 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -66,36 +66,25 @@ class MemoryStore: long_term = self.read_long_term() return f"## Long-term Memory\n{long_term}" if long_term else "" - async def consolidate( + async def consolidate_chunk( self, - session: Session, + messages: list[dict], provider: LLMProvider, model: str, - *, - archive_all: bool = False, - memory_window: int = 50, - ) -> bool: - """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. + ) -> tuple[bool, str | None]: + """Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call. - Returns True on success (including no-op), False on failure. + Returns (success, None). + + - success: True on success (including no-op), False on failure. + - The second return value is reserved for future use (e.g. RAG-style summaries) and is + always None in the current implementation. 
""" - if archive_all: - old_messages = session.messages - keep_count = 0 - logger.info("Memory consolidation (archive_all): {} messages", len(session.messages)) - else: - keep_count = memory_window // 2 - if len(session.messages) <= keep_count: - return True - if len(session.messages) - session.last_consolidated <= 0: - return True - old_messages = session.messages[session.last_consolidated:-keep_count] - if not old_messages: - return True - logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count) + if not messages: + return True, None lines = [] - for m in old_messages: + for m in messages: if not m.get("content"): continue tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" @@ -113,7 +102,19 @@ class MemoryStore: try: response = await provider.chat( messages=[ - {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, + { + "role": "system", + "content": ( + "You are a memory consolidation agent.\n" + "Your job is to:\n" + "1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n" + " - Write 1 paragraph of 2–5 sentences that starts with [YYYY-MM-DD HH:MM].\n" + " - Include concrete names, IDs and numbers so it is easy to search with grep.\n" + "2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n" + "3) Optionally return a short context_summary (1–3 sentences) that will replace the raw messages in future dialogue history.\n\n" + "Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary." 
+ ), + }, {"role": "user", "content": prompt}, ], tools=_SAVE_MEMORY_TOOL, @@ -122,7 +123,7 @@ class MemoryStore: if not response.has_tool_calls: logger.warning("Memory consolidation: LLM did not call save_memory, skipping") - return False + return False, None args = response.tool_calls[0].arguments # Some providers return arguments as a JSON string instead of dict @@ -134,10 +135,10 @@ class MemoryStore: args = args[0] else: logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list") - return False + return False, None if not isinstance(args, dict): logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) - return False + return False, None if entry := args.get("history_entry"): if not isinstance(entry, str): @@ -149,9 +150,8 @@ class MemoryStore: if update != current_memory: self.write_long_term(update) - session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count - logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) - return True + logger.info("Memory consolidation done for {} messages", len(messages)) + return True, None except Exception: logger.exception("Memory consolidation failed") - return False + return False, None diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 803cb61..1ebde20 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -189,11 +189,22 @@ class SlackConfig(Base): class QQConfig(Base): - """QQ channel configuration using botpy SDK.""" + """QQ channel configuration. + + Supports two implementations: + 1. Official botpy SDK: requires app_id and secret + 2. 
OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token) + """ enabled: bool = False + # Official botpy SDK fields app_id: str = "" # 机器人 ID (AppID) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com + # OneBot protocol fields + api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700") + ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse") + bot_qq: int | None = None # Bot's QQ number (for filtering self messages) + access_token: str = "" # Optional access token for OneBot API allow_from: list[str] = Field( default_factory=list ) # Allowed user openids (empty = public access) @@ -226,10 +237,18 @@ class AgentDefaults(Base): provider: str = ( "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection ) - max_tokens: int = 8192 + # Native context window of the model (usually its max_input_tokens / max_context_tokens). + # Defaults to the 128k window of mainstream models (e.g. GPT-4o, Claude 3.x) as a generous upper bound; adjust it to match the documentation of the model actually in use. + max_tokens_input: int = 128_000 + # Default cap on output tokens for a single reply (callers may truncate or apportion it further as needed). + # 8192 covers most real conversation / tool-use scenarios; adjust it manually if needed. + max_tokens_output: int = 8192 + # Compression start ratio: history compression begins once the estimated prompt tokens reach compressionStartRatio of the available input budget (maxTokensInput - maxTokensOutput - a 1000-token safety reserve). + compression_start_ratio: float = 0.7 + # Compression target ratio: each compression pass sizes the chunk it archives so that estimated prompt tokens fall to roughly compressionTargetRatio of that same available input budget. + compression_target_ratio: float = 0.4 temperature: float = 0.1 max_tool_iterations: int = 40 - memory_window: int = 100 reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index f0a6484..1cb8a51 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -9,7 +9,6 @@ from typing import Any from loguru import logger -from nanobot.config.paths import get_legacy_sessions_dir from nanobot.utils.helpers import ensure_dir, safe_filename @@ -30,7 +29,6 @@ class Session: created_at: datetime = field(default_factory=datetime.now) 
updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) - last_consolidated: int = 0 # Number of messages already consolidated to files def add_message(self, role: str, content: str, **kwargs: Any) -> None: """Add a message to the session.""" @@ -44,9 +42,13 @@ class Session: self.updated_at = datetime.now() def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: - """Return unconsolidated messages for LLM input, aligned to a user turn.""" - unconsolidated = self.messages[self.last_consolidated:] - sliced = unconsolidated[-max_messages:] + """ + Return messages for LLM input, aligned to a user turn. + + - max_messages > 0 时只保留最近 max_messages 条; + - max_messages <= 0 时不做条数截断,返回全部消息。 + """ + sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:] # Drop leading non-user messages to avoid orphaned tool_result blocks for i, m in enumerate(sliced): @@ -66,7 +68,7 @@ class Session: def clear(self) -> None: """Clear all messages and reset session to initial state.""" self.messages = [] - self.last_consolidated = 0 + self.metadata = {} self.updated_at = datetime.now() @@ -80,7 +82,7 @@ class SessionManager: def __init__(self, workspace: Path): self.workspace = workspace self.sessions_dir = ensure_dir(self.workspace / "sessions") - self.legacy_sessions_dir = get_legacy_sessions_dir() + self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" self._cache: dict[str, Session] = {} def _get_session_path(self, key: str) -> Path: @@ -132,7 +134,6 @@ class SessionManager: messages = [] metadata = {} created_at = None - last_consolidated = 0 with open(path, encoding="utf-8") as f: for line in f: @@ -145,7 +146,6 @@ class SessionManager: if data.get("_type") == "metadata": metadata = data.get("metadata", {}) created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None - last_consolidated = data.get("last_consolidated", 0) else: messages.append(data) 
@@ -154,7 +154,6 @@ class SessionManager: messages=messages, created_at=created_at or datetime.now(), metadata=metadata, - last_consolidated=last_consolidated ) except Exception as e: logger.warning("Failed to load session {}: {}", key, e) @@ -171,7 +170,6 @@ class SessionManager: "created_at": session.created_at.isoformat(), "updated_at": session.updated_at.isoformat(), "metadata": session.metadata, - "last_consolidated": session.last_consolidated } f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n") for msg in session.messages: