From dbc518098e913d2f382121820dd58bbaf7a04234 Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 14:20:16 +0800 Subject: [PATCH 1/7] refactor: implement token-based context compression mechanism Major changes: - Replace message-count-based memory window with token-budget-based compression - Add max_tokens_input, compression_start_ratio, compression_target_ratio config - Implement _maybe_compress_history() that triggers based on prompt token usage - Use _build_compressed_history_view() to provide compressed history to LLM - Refactor MemoryStore.consolidate() -> consolidate_chunk() for chunk-based compression - Remove last_consolidated from Session, use _compressed_until metadata instead - Add background compression scheduling to avoid blocking message processing Key improvements: - Compression now based on actual token usage, not arbitrary message counts - Better handling of long conversations with large context windows - Non-destructive compression: old messages remain in session, but excluded from prompt - Automatic compression when history exceeds configured token thresholds --- nanobot/agent/loop.py | 521 +++++++++++++++++++++++++++++++++---- nanobot/agent/memory.py | 62 ++--- nanobot/config/schema.py | 25 +- nanobot/session/manager.py | 20 +- 4 files changed, 529 insertions(+), 99 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ca9a06e..696e2a7 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -5,19 +5,24 @@ from __future__ import annotations import asyncio import json import re -import weakref from contextlib import AsyncExitStack from pathlib import Path from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger +try: + import tiktoken # type: ignore +except Exception: # pragma: no cover - optional dependency + tiktoken = None + from nanobot.agent.context import ContextBuilder -from nanobot.agent.memory import MemoryStore from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.spawn import SpawnTool @@ -55,8 +60,11 @@ class AgentLoop: max_iterations: int = 40, temperature: float = 0.1, max_tokens: int = 4096, - memory_window: int = 100, + memory_window: int | None = None, # backward-compat only (unused) reasoning_effort: str | None = None, + max_tokens_input: int = 128_000, + compression_start_ratio: float = 0.7, + compression_target_ratio: float = 0.4, brave_api_key: str | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -74,9 +82,18 @@ class AgentLoop: self.model = model or provider.get_default_model() self.max_iterations = max_iterations self.temperature = temperature + # max_tokens: per-call output token cap (maxTokensOutput in config) self.max_tokens = max_tokens + # Keep legacy attribute for older call sites/tests; compression no longer uses it. self.memory_window = memory_window self.reasoning_effort = reasoning_effort + # max_tokens_input: model native context window (maxTokensInput in config) + self.max_tokens_input = max_tokens_input + # Token-based compression watermarks (fractions of available input budget) + self.compression_start_ratio = compression_start_ratio + self.compression_target_ratio = compression_target_ratio + # Reserve tokens for safety margin + self._reserve_tokens = 1000 self.brave_api_key = brave_api_key self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -105,18 +122,373 @@ class AgentLoop: self._mcp_stack: AsyncExitStack | None = None self._mcp_connected = False self._mcp_connecting = False - self._consolidating: set[str] = set() # Session keys with consolidation in progress - self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks - self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks + self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task self._processing_lock = asyncio.Lock() self._register_default_tools() + @staticmethod + def _estimate_prompt_tokens( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> int: + """Estimate prompt tokens with tiktoken (fallback only).""" + if tiktoken is None: + return 0 + + try: + enc = tiktoken.get_encoding("cl100k_base") + parts: list[str] = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + if tools: + parts.append(json.dumps(tools, ensure_ascii=False)) + return len(enc.encode("\n".join(parts))) + except Exception: + return 0 + + def _estimate_prompt_tokens_chain( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> tuple[int, str]: + """Unified prompt-token estimation: provider counter -> tiktoken.""" + provider_counter = getattr(self.provider, "estimate_prompt_tokens", None) + if callable(provider_counter): + try: + tokens, source = provider_counter(messages, tools, self.model) + if isinstance(tokens, (int, float)) and tokens > 0: + return int(tokens), str(source or "provider_counter") + except Exception: + logger.debug("Provider token counter failed; fallback to tiktoken") + + estimated = self._estimate_prompt_tokens(messages, tools) + if estimated > 0: + return int(estimated), "tiktoken" + return 0, "none" + + @staticmethod + def _estimate_completion_tokens(content: str) -> int: + """Estimate completion tokens with tiktoken (fallback only).""" + if tiktoken is None: + return 0 + try: + enc = tiktoken.get_encoding("cl100k_base") + return len(enc.encode(content or "")) + except Exception: + return 0 + + def _get_compressed_until(self, session: Session) -> int: + """Read/normalize compressed boundary and migrate old metadata format.""" + raw = session.metadata.get("_compressed_until", 0) + try: + compressed_until = int(raw) + except (TypeError, ValueError): + compressed_until = 0 + + if compressed_until <= 0: + ranges = session.metadata.get("_compressed_ranges") + if isinstance(ranges, list): + inferred = 0 + for item in ranges: + if not isinstance(item, (list, tuple)) or len(item) != 2: + continue + try: + inferred = max(inferred, int(item[1])) + except (TypeError, ValueError): + continue + compressed_until = inferred + + compressed_until = max(0, min(compressed_until, len(session.messages))) + session.metadata["_compressed_until"] = compressed_until + # 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段 + session.metadata.pop("_compressed_ranges", None) + session.metadata.pop("_cumulative_tokens", None) + return compressed_until + + def _set_compressed_until(self, session: Session, idx: int) -> None: + """Persist a contiguous compressed boundary.""" + session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages))) + session.metadata.pop("_compressed_ranges", None) + session.metadata.pop("_cumulative_tokens", None) + + @staticmethod + def _estimate_message_tokens(message: dict[str, Any]) -> int: + """Rough token estimate for a single persisted message.""" + content = message.get("content") + parts: list[str] = [] + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + else: + parts.append(json.dumps(part, ensure_ascii=False)) + elif content is not None: + parts.append(json.dumps(content, ensure_ascii=False)) + + for key in ("name", "tool_call_id"): + val = message.get(key) + if isinstance(val, str) and val: + parts.append(val) + if message.get("tool_calls"): + parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) + + payload = "\n".join(parts) + if not payload: + return 1 + if tiktoken is not None: + try: + enc = tiktoken.get_encoding("cl100k_base") + return max(1, len(enc.encode(payload))) + except Exception: + pass + return max(1, len(payload) // 4) + + def _pick_compression_chunk_by_tokens( + self, + session: Session, + reduction_tokens: int, + *, + tail_keep: int = 12, + ) -> tuple[int, int, int] | None: + """ + Pick one contiguous old chunk so its estimated size is roughly enough + to reduce `reduction_tokens`. + """ + messages = session.messages + start = self._get_compressed_until(session) + if len(messages) - start <= tail_keep + 2: + return None + + end_limit = len(messages) - tail_keep + if end_limit - start < 2: + return None + + target = max(1, reduction_tokens) + end = start + collected = 0 + while end < end_limit and collected < target: + collected += self._estimate_message_tokens(messages[end]) + end += 1 + + if end - start < 2: + end = min(end_limit, start + 2) + collected = sum(self._estimate_message_tokens(m) for m in messages[start:end]) + if end - start < 2: + return None + return start, end, collected + + def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: + """ + Estimate current full prompt tokens for this session view + (system + compressed history view + runtime/user placeholder + tools). + """ + history = self._build_compressed_history_view(session) + channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) + probe_messages = self.context.build_messages( + history=history, + current_message="[token-probe]", + channel=channel, + chat_id=chat_id, + ) + return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions()) + + async def _maybe_compress_history( + self, + session: Session, + ) -> None: + """ + End-of-turn policy: + - Estimate current prompt usage from persisted session view. + - If above start ratio, perform one best-effort compression chunk. + """ + if not session.messages: + self._set_compressed_until(session, 0) + return + + budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens) + start_threshold = int(budget * self.compression_start_ratio) + target_threshold = int(budget * self.compression_target_ratio) + if target_threshold >= start_threshold: + target_threshold = max(0, start_threshold - 1) + + current_tokens, token_source = self._estimate_session_prompt_tokens(session) + current_ratio = current_tokens / budget if budget else 0.0 + if current_tokens <= 0: + logger.debug("Compression skip {}: token estimate unavailable", session.key) + return + if current_tokens < start_threshold: + logger.debug( + "Compression idle {}: {}/{} ({:.1%}) via {}", + session.key, + current_tokens, + budget, + current_ratio, + token_source, + ) + return + logger.info( + "Compression trigger {}: {}/{} ({:.1%}) via {}", + session.key, + current_tokens, + budget, + current_ratio, + token_source, + ) + + reduction_by_target = max(0, current_tokens - target_threshold) + reduction_by_delta = max(1, start_threshold - target_threshold) + reduction_need = max(reduction_by_target, reduction_by_delta) + + chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10) + if chunk_range is None: + logger.info("Compression skipped for {}: no compressible chunk", session.key) + return + + start_idx, end_idx, estimated_chunk_tokens = chunk_range + chunk = session.messages[start_idx:end_idx] + if len(chunk) < 2: + return + + logger.info( + "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})", + session.key, + start_idx, + end_idx - 1, + len(chunk), + estimated_chunk_tokens, + reduction_need, + ) + success, _ = await self.context.memory.consolidate_chunk( + chunk, + self.provider, + self.model, + ) + if not success: + logger.warning("Compression aborted for {}: consolidation failed", session.key) + return + + self._set_compressed_until(session, end_idx) + self.sessions.save(session) + + after_tokens, after_source = self._estimate_session_prompt_tokens(session) + after_ratio = after_tokens / budget if budget else 0.0 + reduced = max(0, current_tokens - after_tokens) + reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0 + logger.info( + "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})", + session.key, + after_tokens, + budget, + after_ratio, + after_source, + reduced, + reduced_ratio, + ) + + def _schedule_background_compression(self, session_key: str) -> None: + """Schedule best-effort background compression for a session.""" + existing = self._compression_tasks.get(session_key) + if existing is not None and not existing.done(): + return + + async def _runner() -> None: + session = self.sessions.get_or_create(session_key) + try: + await self._maybe_compress_history(session) + except Exception: + logger.exception("Background compression failed for {}", session_key) + + task = asyncio.create_task(_runner()) + self._compression_tasks[session_key] = task + + def _cleanup(t: asyncio.Task) -> None: + cur = self._compression_tasks.get(session_key) + if cur is t: + self._compression_tasks.pop(session_key, None) + try: + t.result() + except BaseException: + pass + + task.add_done_callback(_cleanup) + + async def wait_for_background_compression(self, timeout_s: float | None = None) -> None: + """Wait for currently scheduled compression tasks.""" + pending = [t for t in self._compression_tasks.values() if not t.done()] + if not pending: + return + + logger.info("Waiting for {} background compression task(s)", len(pending)) + waiter = asyncio.gather(*pending, return_exceptions=True) + if timeout_s is None: + await waiter + return + + try: + await asyncio.wait_for(waiter, timeout=timeout_s) + except asyncio.TimeoutError: + logger.warning( + "Background compression wait timed out after {}s ({} task(s) still running)", + timeout_s, + len([t for t in self._compression_tasks.values() if not t.done()]), + ) + + def _build_compressed_history_view( + self, + session: Session, + ) -> list[dict]: + """Build non-destructive history view using the compressed boundary.""" + compressed_until = self._get_compressed_until(session) + if compressed_until <= 0: + return session.get_history(max_messages=0) + + notice_msg: dict[str, Any] = { + "role": "assistant", + "content": ( + "As your assistant, I have compressed earlier context. " + "If you need details, please check memory/HISTORY.md." + ), + } + + tail: list[dict[str, Any]] = [] + for msg in session.messages[compressed_until:]: + entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")} + for k in ("tool_calls", "tool_call_id", "name"): + if k in msg: + entry[k] = msg[k] + tail.append(entry) + + # Drop leading non-user entries from tail to avoid orphan tool blocks. + for i, m in enumerate(tail): + if m.get("role") == "user": + tail = tail[i:] + break + else: + tail = [] + + return [notice_msg, *tail] + def _register_default_tools(self) -> None: """Register the default set of tools.""" allowed_dir = self.workspace if self.restrict_to_workspace else None for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) + self.tools.register(ValidateDeployJSONTool()) + self.tools.register(ValidateUsageYAMLTool()) + self.tools.register(HuggingFaceModelSearchTool()) self.tools.register(ExecTool( working_dir=str(self.workspace), timeout=self.exec_config.timeout, @@ -181,25 +553,78 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, - ) -> tuple[str | None, list[str], list[dict]]: - """Run the agent iteration loop. Returns (final_content, tools_used, messages).""" + ) -> tuple[str | None, list[str], list[dict], int, str]: + """ + Run the agent iteration loop. + + Returns: + (final_content, tools_used, messages, total_tokens_this_turn, token_source) + total_tokens_this_turn: total tokens (prompt + completion) for this turn + token_source: provider_total / provider_sum / provider_prompt / + provider_counter+tiktoken_completion / tiktoken / none + """ messages = initial_messages iteration = 0 final_content = None tools_used: list[str] = [] + total_tokens_this_turn = 0 + token_source = "none" while iteration < self.max_iterations: iteration += 1 + tool_defs = self.tools.get_definitions() + response = await self.provider.chat( messages=messages, - tools=self.tools.get_definitions(), + tools=tool_defs, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, reasoning_effort=self.reasoning_effort, ) + # Prefer provider usage from the turn-ending model call; fallback to tiktoken. + # Calculate total tokens (prompt + completion) for this turn. + usage = response.usage or {} + t_tokens = usage.get("total_tokens") + p_tokens = usage.get("prompt_tokens") + c_tokens = usage.get("completion_tokens") + + if isinstance(t_tokens, (int, float)) and t_tokens > 0: + total_tokens_this_turn = int(t_tokens) + token_source = "provider_total" + elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)): + # If we have both prompt and completion tokens, sum them + total_tokens_this_turn = int(p_tokens) + int(c_tokens) + token_source = "provider_sum" + elif isinstance(p_tokens, (int, float)) and p_tokens > 0: + # Fallback: use prompt tokens only (completion might be 0 for tool calls) + total_tokens_this_turn = int(p_tokens) + token_source = "provider_prompt" + else: + # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken. + estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs) + estimated_completion = self._estimate_completion_tokens(response.content or "") + total_tokens_this_turn = estimated_prompt + estimated_completion + if total_tokens_this_turn > 0: + token_source = ( + "tiktoken" + if prompt_source == "tiktoken" + else f"{prompt_source}+tiktoken_completion" + ) + if total_tokens_this_turn <= 0: + total_tokens_this_turn = 0 + token_source = "none" + + logger.debug( + "Turn token usage: source={}, total={}, prompt={}, completion={}", + token_source, + total_tokens_this_turn, + p_tokens if isinstance(p_tokens, (int, float)) else None, + c_tokens if isinstance(c_tokens, (int, float)) else None, + ) + if response.has_tool_calls: if on_progress: thought = self._strip_think(response.content) @@ -254,7 +679,7 @@ class AgentLoop: "without completing the task. You can try breaking the task into smaller steps." ) - return final_content, tools_used, messages + return final_content, tools_used, messages, total_tokens_this_turn, token_source async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -279,6 +704,9 @@ class AgentLoop: """Cancel all active tasks and subagents for the session.""" tasks = self._active_tasks.pop(msg.session_key, []) cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) + comp = self._compression_tasks.get(msg.session_key) + if comp is not None and not comp.done() and comp.cancel(): + cancelled += 1 for t in tasks: try: await t @@ -325,6 +753,9 @@ class AgentLoop: def stop(self) -> None: """Stop the agent loop.""" self._running = False + for task in list(self._compression_tasks.values()): + if not task.done(): + task.cancel() logger.info("Agent loop stopping") async def _process_message( @@ -342,14 +773,15 @@ class AgentLoop: key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) - history = session.get_history(max_messages=self.memory_window) + history = self._build_compressed_history_view(session) messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, ) - final_content, _, all_msgs = await self._run_agent_loop(messages) + final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + self._schedule_background_compression(session.key) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -362,27 +794,27 @@ class AgentLoop: # Slash commands cmd = msg.content.strip().lower() if cmd == "/new": - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - self._consolidating.add(session.key) try: - async with lock: - snapshot = session.messages[session.last_consolidated:] - if snapshot: - temp = Session(key=session.key) - temp.messages = list(snapshot) - if not await self._consolidate_memory(temp, archive_all=True): - return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Memory archival failed, session not cleared. Please try again.", - ) + # 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中 + if session.messages: + ok, _ = await self.context.memory.consolidate_chunk( + session.messages, + self.provider, + self.model, + ) + if not ok: + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. Please try again.", + ) except Exception: logger.exception("/new archival failed for {}", session.key) return OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, + channel=msg.channel, + chat_id=msg.chat_id, content="Memory archival failed, session not cleared. Please try again.", ) - finally: - self._consolidating.discard(session.key) session.clear() self.sessions.save(session) @@ -393,36 +825,23 @@ class AgentLoop: return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") - unconsolidated = len(session.messages) - session.last_consolidated - if (unconsolidated >= self.memory_window and session.key not in self._consolidating): - self._consolidating.add(session.key) - lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock()) - - async def _consolidate_and_unlock(): - try: - async with lock: - await self._consolidate_memory(session) - finally: - self._consolidating.discard(session.key) - _task = asyncio.current_task() - if _task is not None: - self._consolidation_tasks.discard(_task) - - _task = asyncio.create_task(_consolidate_and_unlock()) - self._consolidation_tasks.add(_task) - self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): if isinstance(message_tool, MessageTool): message_tool.start_turn() - history = session.get_history(max_messages=self.memory_window) + # 正常对话:使用压缩后的历史视图(压缩在回合结束后进行) + history = self._build_compressed_history_view(session) initial_messages = self.context.build_messages( history=history, current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, ) + # Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:") + if session_key and session_key.startswith("cron:"): + if initial_messages and initial_messages[0].get("role") == "system": + initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}" async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) @@ -432,7 +851,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, )) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, _, _ = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, ) @@ -441,6 +860,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) + self._schedule_background_compression(session.key) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None @@ -487,13 +907,6 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() - async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: - """Delegate to MemoryStore.consolidate(). Returns True on success.""" - return await MemoryStore(self.workspace).consolidate( - session, self.provider, self.model, - archive_all=archive_all, memory_window=self.memory_window, - ) - async def process_direct( self, content: str, diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 21fe77d..c8896c8 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -66,36 +66,25 @@ class MemoryStore: long_term = self.read_long_term() return f"## Long-term Memory\n{long_term}" if long_term else "" - async def consolidate( + async def consolidate_chunk( self, - session: Session, + messages: list[dict], provider: LLMProvider, model: str, - *, - archive_all: bool = False, - memory_window: int = 50, - ) -> bool: - """Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call. + ) -> tuple[bool, str | None]: + """Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call. - Returns True on success (including no-op), False on failure. + Returns (success, None). + + - success: True on success (including no-op), False on failure. + - The second return value is reserved for future use (e.g. RAG-style summaries) and is + always None in the current implementation. """ - if archive_all: - old_messages = session.messages - keep_count = 0 - logger.info("Memory consolidation (archive_all): {} messages", len(session.messages)) - else: - keep_count = memory_window // 2 - if len(session.messages) <= keep_count: - return True - if len(session.messages) - session.last_consolidated <= 0: - return True - old_messages = session.messages[session.last_consolidated:-keep_count] - if not old_messages: - return True - logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count) + if not messages: + return True, None lines = [] - for m in old_messages: + for m in messages: if not m.get("content"): continue tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" @@ -113,7 +102,19 @@ class MemoryStore: try: response = await provider.chat( messages=[ - {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, + { + "role": "system", + "content": ( + "You are a memory consolidation agent.\n" + "Your job is to:\n" + "1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n" + " - Write 1 paragraph of 2–5 sentences that starts with [YYYY-MM-DD HH:MM].\n" + " - Include concrete names, IDs and numbers so it is easy to search with grep.\n" + "2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n" + "3) Optionally return a short context_summary (1–3 sentences) that will replace the raw messages in future dialogue history.\n\n" + "Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary." + ), + }, {"role": "user", "content": prompt}, ], tools=_SAVE_MEMORY_TOOL, @@ -122,7 +123,7 @@ class MemoryStore: if not response.has_tool_calls: logger.warning("Memory consolidation: LLM did not call save_memory, skipping") - return False + return False, None args = response.tool_calls[0].arguments # Some providers return arguments as a JSON string instead of dict @@ -134,10 +135,10 @@ class MemoryStore: args = args[0] else: logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list") - return False + return False, None if not isinstance(args, dict): logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) - return False + return False, None if entry := args.get("history_entry"): if not isinstance(entry, str): @@ -149,9 +150,8 @@ class MemoryStore: if update != current_memory: self.write_long_term(update) - session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count - logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated) - return True + logger.info("Memory consolidation done for {} messages", len(messages)) + return True, None except Exception: logger.exception("Memory consolidation failed") - return False + return False, None diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 803cb61..1ebde20 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -189,11 +189,22 @@ class SlackConfig(Base): class QQConfig(Base): - """QQ channel configuration using botpy SDK.""" + """QQ channel configuration. + + Supports two implementations: + 1. Official botpy SDK: requires app_id and secret + 2. OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token) + """ enabled: bool = False + # Official botpy SDK fields app_id: str = "" # 机器人 ID (AppID) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com + # OneBot protocol fields + api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700") + ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse") + bot_qq: int | None = None # Bot's QQ number (for filtering self messages) + access_token: str = "" # Optional access token for OneBot API allow_from: list[str] = Field( default_factory=list ) # Allowed user openids (empty = public access) @@ -226,10 +237,18 @@ class AgentDefaults(Base): provider: str = ( "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection ) - max_tokens: int = 8192 + # 原生上下文最大窗口(通常对应模型的 max_input_tokens / max_context_tokens) + # 默认按照主流大模型(如 GPT-4o、Claude 3.x 等)的 128k 上下文给一个宽松上限,实际应根据所选模型文档手动调整。 + max_tokens_input: int = 128_000 + # 默认单次回复的最大输出 token 上限(调用时可按需要再做截断或比例分配) + # 8192 足以覆盖大多数实际对话/工具使用场景,同样可按需手动调整。 + max_tokens_output: int = 8192 + # 会话历史压缩触发比例:当估算的输入 token 使用量 >= maxTokensInput * compressionStartRatio 时开始压缩。 + compression_start_ratio: float = 0.7 + # 会话历史压缩目标比例:每轮压缩后尽量把估算的输入 token 使用量压到 maxTokensInput * compressionTargetRatio 附近。 + compression_target_ratio: float = 0.4 temperature: float = 0.1 max_tool_iterations: int = 40 - memory_window: int = 100 reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index f0a6484..1cb8a51 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -9,7 +9,6 @@ from typing import Any from loguru import logger -from nanobot.config.paths import get_legacy_sessions_dir from nanobot.utils.helpers import ensure_dir, safe_filename @@ -30,7 +29,6 @@ class Session: created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) - last_consolidated: int = 0 # Number of messages already consolidated to files def add_message(self, role: str, content: str, **kwargs: Any) -> None: """Add a message to the session.""" @@ -44,9 +42,13 @@ class Session: self.updated_at = datetime.now() def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: - """Return unconsolidated messages for LLM input, aligned to a user turn.""" - unconsolidated = self.messages[self.last_consolidated:] - sliced = unconsolidated[-max_messages:] + """ + Return messages for LLM input, aligned to a user turn. + + - max_messages > 0 时只保留最近 max_messages 条; + - max_messages <= 0 时不做条数截断,返回全部消息。 + """ + sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:] # Drop leading non-user messages to avoid orphaned tool_result blocks for i, m in enumerate(sliced): @@ -66,7 +68,7 @@ class Session: def clear(self) -> None: """Clear all messages and reset session to initial state.""" self.messages = [] - self.last_consolidated = 0 + self.metadata = {} self.updated_at = datetime.now() @@ -80,7 +82,7 @@ class SessionManager: def __init__(self, workspace: Path): self.workspace = workspace self.sessions_dir = ensure_dir(self.workspace / "sessions") - self.legacy_sessions_dir = get_legacy_sessions_dir() + self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" self._cache: dict[str, Session] = {} def _get_session_path(self, key: str) -> Path: @@ -132,7 +134,6 @@ class SessionManager: messages = [] metadata = {} created_at = None - last_consolidated = 0 with open(path, encoding="utf-8") as f: for line in f: @@ -145,7 +146,6 @@ class SessionManager: if data.get("_type") == "metadata": metadata = data.get("metadata", {}) created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None - last_consolidated = data.get("last_consolidated", 0) else: messages.append(data) @@ -154,7 +154,6 @@ class SessionManager: messages=messages, created_at=created_at or datetime.now(), metadata=metadata, - last_consolidated=last_consolidated ) except Exception as e: logger.warning("Failed to load session {}: {}", key, e) @@ -171,7 +170,6 @@ class SessionManager: "created_at": session.created_at.isoformat(), "updated_at": session.updated_at.isoformat(), "metadata": session.metadata, - "last_consolidated": session.last_consolidated } f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n") for msg in session.messages: From 2dcb4de422ddec8c0f114dc6b0fdce06b9388b8f Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 15:04:38 +0800 Subject: [PATCH 2/7] fix(commands): update AgentLoop calls to use token-based compression parameters --- nanobot/cli/commands.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 2c8d6d3..cf29cc5 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -330,8 +330,10 @@ def gateway( temperature=config.agents.defaults.temperature, max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, - memory_window=config.agents.defaults.memory_window, reasoning_effort=config.agents.defaults.reasoning_effort, + max_tokens_input=config.agents.defaults.max_tokens_input, + compression_start_ratio=config.agents.defaults.compression_start_ratio, + compression_target_ratio=config.agents.defaults.compression_target_ratio, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -515,8 +517,10 @@ def agent( temperature=config.agents.defaults.temperature, max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, - memory_window=config.agents.defaults.memory_window, reasoning_effort=config.agents.defaults.reasoning_effort, + max_tokens_input=config.agents.defaults.max_tokens_input, + compression_start_ratio=config.agents.defaults.compression_start_ratio, + compression_target_ratio=config.agents.defaults.compression_target_ratio, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, From 2706d3c317be7325795e9dac74d07512e57112f4 Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 15:20:34 +0800 Subject: [PATCH 3/7] fix(commands): use max_tokens_output instead of max_tokens from AgentDefaults --- nanobot/cli/commands.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cf29cc5..18c9d56 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -328,7 +328,7 @@ def gateway( workspace=config.workspace_path, model=config.agents.defaults.model, temperature=config.agents.defaults.temperature, - max_tokens=config.agents.defaults.max_tokens, + max_tokens=config.agents.defaults.max_tokens_output, max_iterations=config.agents.defaults.max_tool_iterations, reasoning_effort=config.agents.defaults.reasoning_effort, max_tokens_input=config.agents.defaults.max_tokens_input, @@ -515,7 +515,7 @@ def agent( workspace=config.workspace_path, model=config.agents.defaults.model, temperature=config.agents.defaults.temperature, - max_tokens=config.agents.defaults.max_tokens, + max_tokens=config.agents.defaults.max_tokens_output, max_iterations=config.agents.defaults.max_tool_iterations, reasoning_effort=config.agents.defaults.reasoning_effort, max_tokens_input=config.agents.defaults.max_tokens_input, From a984e0df3752f6a8883a0e9b6d8efee4abd7f9dd Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 15:23:55 +0800 Subject: [PATCH 4/7] feat(loop): add history message count logging in compression --- nanobot/agent/loop.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 696e2a7..5d316ea 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -362,6 +362,7 @@ class AgentLoop: if len(chunk) < 2: return + before_msg_count = len(session.messages) logger.info( "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})", session.key, @@ -383,12 +384,13 @@ class AgentLoop: self._set_compressed_until(session, end_idx) self.sessions.save(session) + after_msg_count = len(session.messages) after_tokens, after_source = self._estimate_session_prompt_tokens(session) after_ratio = after_tokens / budget if budget else 0.0 reduced = max(0, current_tokens - after_tokens) reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0 logger.info( - "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})", + "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%}), history: {} -> {}", session.key, after_tokens, budget, @@ -396,6 +398,8 @@ class AgentLoop: after_source, reduced, reduced_ratio, + before_msg_count, + after_msg_count, ) def _schedule_background_compression(self, session_key: str) -> None: From 1b16d48390b3fded3438f4fdbc3f0ae0a0379878 Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 15:26:49 +0800 Subject: [PATCH 5/7] fix(loop): update _cumulative_tokens in _save_turn and preserve it in compression methods --- nanobot/agent/loop.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 5d316ea..5e01b79 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -211,14 +211,14 @@ class AgentLoop: session.metadata["_compressed_until"] = compressed_until # 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段 session.metadata.pop("_compressed_ranges", None) - session.metadata.pop("_cumulative_tokens", None) + # 注意:不要删除 _cumulative_tokens,压缩逻辑需要它来跟踪累积 token 计数 return compressed_until def _set_compressed_until(self, session: Session, idx: int) -> None: """Persist a contiguous compressed boundary.""" session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages))) session.metadata.pop("_compressed_ranges", None) - session.metadata.pop("_cumulative_tokens", None) + # 注意:不要删除 _cumulative_tokens,压缩逻辑需要它来跟踪累积 token 计数 @staticmethod def _estimate_message_tokens(message: dict[str, Any]) -> int: @@ -362,7 +362,6 @@ class AgentLoop: if len(chunk) < 2: return - before_msg_count = len(session.messages) logger.info( "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})", session.key, @@ -384,13 +383,12 @@ class AgentLoop: self._set_compressed_until(session, end_idx) self.sessions.save(session) - after_msg_count = len(session.messages) after_tokens, after_source = self._estimate_session_prompt_tokens(session) after_ratio = after_tokens / budget if budget else 0.0 reduced = max(0, current_tokens - after_tokens) reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0 logger.info( - "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%}), history: {} -> {}", + "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})", session.key, after_tokens, budget, @@ -398,8 +396,6 @@ class AgentLoop: after_source, reduced, reduced_ratio, - before_msg_count, - after_msg_count, ) def _schedule_background_compression(self, session_key: str) -> None: @@ -855,14 +851,14 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, )) - final_content, _, all_msgs, _, _ = await self._run_agent_loop( + final_content, _, all_msgs, total_tokens_this_turn, token_source = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, ) if final_content is None: final_content = "I've completed processing but have no response to give." - self._save_turn(session, all_msgs, 1 + len(history)) + self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn) self.sessions.save(session) self._schedule_background_compression(session.key) @@ -876,7 +872,7 @@ class AgentLoop: metadata=msg.metadata or {}, ) - def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: + def _save_turn(self, session: Session, messages: list[dict], skip: int, total_tokens_this_turn: int = 0) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime for m in messages[skip:]: @@ -910,6 +906,14 @@ class AgentLoop: entry.setdefault("timestamp", datetime.now().isoformat()) session.messages.append(entry) session.updated_at = datetime.now() + + # Update cumulative token count for compression tracking + if total_tokens_this_turn > 0: + current_cumulative = session.metadata.get("_cumulative_tokens", 0) + if isinstance(current_cumulative, (int, float)): + session.metadata["_cumulative_tokens"] = int(current_cumulative) + total_tokens_this_turn + else: + session.metadata["_cumulative_tokens"] = total_tokens_this_turn async def process_direct( self, From 274edc5451c1d0f79eda80c76127f497ec6923e9 Mon Sep 17 00:00:00 2001 From: VITOHJL Date: Sun, 8 Mar 2026 17:25:59 +0800 Subject: [PATCH 6/7] fix(compression): prefer provider prompt token usage --- nanobot/agent/loop.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 5e01b79..4f6a051 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -124,6 +124,8 @@ class AgentLoop: self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task + self._last_turn_prompt_tokens: int = 0 + self._last_turn_prompt_source: str = "none" self._processing_lock = asyncio.Lock() self._register_default_tools() @@ -324,7 +326,15 @@ class AgentLoop: if target_threshold >= start_threshold: target_threshold = max(0, start_threshold - 1) - current_tokens, token_source = self._estimate_session_prompt_tokens(session) + # Prefer provider usage prompt tokens from the turn-ending call. + # If unavailable, fall back to estimator chain. + raw_prompt_tokens = session.metadata.get("_last_prompt_tokens") + if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0: + current_tokens = int(raw_prompt_tokens) + token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt") + else: + current_tokens, token_source = self._estimate_session_prompt_tokens(session) + current_ratio = current_tokens / budget if budget else 0.0 if current_tokens <= 0: logger.debug("Compression skip {}: token estimate unavailable", session.key) @@ -569,6 +579,8 @@ class AgentLoop: tools_used: list[str] = [] total_tokens_this_turn = 0 token_source = "none" + self._last_turn_prompt_tokens = 0 + self._last_turn_prompt_source = "none" while iteration < self.max_iterations: iteration += 1 @@ -594,19 +606,35 @@ class AgentLoop: if isinstance(t_tokens, (int, float)) and t_tokens > 0: total_tokens_this_turn = int(t_tokens) token_source = "provider_total" + if isinstance(p_tokens, (int, float)) and p_tokens > 0: + self._last_turn_prompt_tokens = int(p_tokens) + self._last_turn_prompt_source = "usage_prompt" + elif isinstance(c_tokens, (int, float)): + prompt_derived = int(t_tokens) - int(c_tokens) + if prompt_derived > 0: + self._last_turn_prompt_tokens = prompt_derived + self._last_turn_prompt_source = "usage_total_minus_completion" elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)): # If we have both prompt and completion tokens, sum them total_tokens_this_turn = int(p_tokens) + int(c_tokens) token_source = "provider_sum" + if p_tokens > 0: + self._last_turn_prompt_tokens = int(p_tokens) + self._last_turn_prompt_source = "usage_prompt" elif isinstance(p_tokens, (int, float)) and p_tokens > 0: # Fallback: use prompt tokens only (completion might be 0 for tool calls) total_tokens_this_turn = int(p_tokens) token_source = "provider_prompt" + self._last_turn_prompt_tokens = int(p_tokens) + self._last_turn_prompt_source = "usage_prompt" else: # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken. estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs) estimated_completion = self._estimate_completion_tokens(response.content or "") total_tokens_this_turn = estimated_prompt + estimated_completion + if estimated_prompt > 0: + self._last_turn_prompt_tokens = int(estimated_prompt) + self._last_turn_prompt_source = str(prompt_source or "tiktoken") if total_tokens_this_turn > 0: token_source = ( "tiktoken" @@ -779,6 +807,12 @@ class AgentLoop: current_message=msg.content, channel=channel, chat_id=chat_id, ) final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages) + if self._last_turn_prompt_tokens > 0: + session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens + session.metadata["_last_prompt_source"] = self._last_turn_prompt_source + else: + session.metadata.pop("_last_prompt_tokens", None) + session.metadata.pop("_last_prompt_source", None) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) self._schedule_background_compression(session.key) @@ -858,6 +892,13 @@ class AgentLoop: if final_content is None: final_content = "I've completed processing but have no response to give." + if self._last_turn_prompt_tokens > 0: + session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens + session.metadata["_last_prompt_source"] = self._last_turn_prompt_source + else: + session.metadata.pop("_last_prompt_tokens", None) + session.metadata.pop("_last_prompt_source", None) + self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn) self.sessions.save(session) self._schedule_background_compression(session.key) From 62ccda43b980d53c5ac7a79adf8edf43294f1fdb Mon Sep 17 00:00:00 2001 From: Re-bin Date: Tue, 10 Mar 2026 19:55:06 +0000 Subject: [PATCH 7/7] refactor(memory): switch consolidation to token-based context windows Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count. --- nanobot/agent/loop.py | 544 ++--------------------- nanobot/agent/memory.py | 243 +++++++--- nanobot/cli/commands.py | 26 +- nanobot/config/schema.py | 32 +- nanobot/session/manager.py | 20 +- nanobot/utils/helpers.py | 85 ++++ pyproject.toml | 1 + tests/test_commands.py | 33 ++ tests/test_config_migration.py | 88 ++++ tests/test_consolidate_offset.py | 297 ++----------- tests/test_loop_consolidation_tokens.py | 190 ++++++++ tests/test_memory_consolidation_types.py | 51 +-- tests/test_message_tool_suppress.py | 10 +- 13 files changed, 709 insertions(+), 911 deletions(-) create mode 100644 tests/test_config_migration.py create mode 100644 tests/test_loop_consolidation_tokens.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ba35a23..8605a09 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -11,18 +11,12 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger -try: - import tiktoken # type: ignore -except Exception: # pragma: no cover - optional dependency - tiktoken = None - from nanobot.agent.context import ContextBuilder +from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool -from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool from nanobot.agent.tools.message import MessageTool -from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.spawn import SpawnTool @@ -60,11 +54,8 @@ class AgentLoop: max_iterations: int = 40, temperature: float = 0.1, max_tokens: int = 4096, - memory_window: int | None = None, # backward-compat only (unused) reasoning_effort: str | None = None, - max_tokens_input: int = 128_000, - compression_start_ratio: float = 0.7, - compression_target_ratio: float = 0.4, + context_window_tokens: int = 65_536, brave_api_key: str | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -82,18 +73,9 @@ class AgentLoop: self.model = model or provider.get_default_model() self.max_iterations = max_iterations self.temperature = temperature - # max_tokens: per-call output token cap (maxTokensOutput in config) self.max_tokens = max_tokens - # Keep legacy attribute for older call sites/tests; compression no longer uses it. - self.memory_window = memory_window self.reasoning_effort = reasoning_effort - # max_tokens_input: model native context window (maxTokensInput in config) - self.max_tokens_input = max_tokens_input - # Token-based compression watermarks (fractions of available input budget) - self.compression_start_ratio = compression_start_ratio - self.compression_target_ratio = compression_target_ratio - # Reserve tokens for safety margin - self._reserve_tokens = 1000 + self.context_window_tokens = context_window_tokens self.brave_api_key = brave_api_key self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -123,382 +105,23 @@ class AgentLoop: self._mcp_connected = False self._mcp_connecting = False self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks - self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task - self._last_turn_prompt_tokens: int = 0 - self._last_turn_prompt_source: str = "none" self._processing_lock = asyncio.Lock() + self.memory_consolidator = MemoryConsolidator( + workspace=workspace, + provider=provider, + model=self.model, + sessions=self.sessions, + context_window_tokens=context_window_tokens, + build_messages=self.context.build_messages, + get_tool_definitions=self.tools.get_definitions, + ) self._register_default_tools() - @staticmethod - def _estimate_prompt_tokens( - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - ) -> int: - """Estimate prompt tokens with tiktoken (fallback only).""" - if tiktoken is None: - return 0 - - try: - enc = tiktoken.get_encoding("cl100k_base") - parts: list[str] = [] - for msg in messages: - content = msg.get("content") - if isinstance(content, str): - parts.append(content) - elif isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - txt = part.get("text", "") - if txt: - parts.append(txt) - if tools: - parts.append(json.dumps(tools, ensure_ascii=False)) - return len(enc.encode("\n".join(parts))) - except Exception: - return 0 - - def _estimate_prompt_tokens_chain( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - ) -> tuple[int, str]: - """Unified prompt-token estimation: provider counter -> tiktoken.""" - provider_counter = getattr(self.provider, "estimate_prompt_tokens", None) - if callable(provider_counter): - try: - tokens, source = provider_counter(messages, tools, self.model) - if isinstance(tokens, (int, float)) and tokens > 0: - return int(tokens), str(source or "provider_counter") - except Exception: - logger.debug("Provider token counter failed; fallback to tiktoken") - - estimated = self._estimate_prompt_tokens(messages, tools) - if estimated > 0: - return int(estimated), "tiktoken" - return 0, "none" - - @staticmethod - def _estimate_completion_tokens(content: str) -> int: - """Estimate completion tokens with tiktoken (fallback only).""" - if tiktoken is None: - return 0 - try: - enc = tiktoken.get_encoding("cl100k_base") - return len(enc.encode(content or "")) - except Exception: - return 0 - - def _get_compressed_until(self, session: Session) -> int: - """Read/normalize compressed boundary and migrate old metadata format.""" - raw = session.metadata.get("_compressed_until", 0) - try: - compressed_until = int(raw) - except (TypeError, ValueError): - compressed_until = 0 - - if compressed_until <= 0: - ranges = session.metadata.get("_compressed_ranges") - if isinstance(ranges, list): - inferred = 0 - for item in ranges: - if not isinstance(item, (list, tuple)) or len(item) != 2: - continue - try: - inferred = max(inferred, int(item[1])) - except (TypeError, ValueError): - continue - compressed_until = inferred - - compressed_until = max(0, min(compressed_until, len(session.messages))) - session.metadata["_compressed_until"] = compressed_until - # 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段 - session.metadata.pop("_compressed_ranges", None) - # 注意:不要删除 _cumulative_tokens,压缩逻辑需要它来跟踪累积 token 计数 - return compressed_until - - def _set_compressed_until(self, session: Session, idx: int) -> None: - """Persist a contiguous compressed boundary.""" - session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages))) - session.metadata.pop("_compressed_ranges", None) - # 注意:不要删除 _cumulative_tokens,压缩逻辑需要它来跟踪累积 token 计数 - - @staticmethod - def _estimate_message_tokens(message: dict[str, Any]) -> int: - """Rough token estimate for a single persisted message.""" - content = message.get("content") - parts: list[str] = [] - if isinstance(content, str): - parts.append(content) - elif isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - txt = part.get("text", "") - if txt: - parts.append(txt) - else: - parts.append(json.dumps(part, ensure_ascii=False)) - elif content is not None: - parts.append(json.dumps(content, ensure_ascii=False)) - - for key in ("name", "tool_call_id"): - val = message.get(key) - if isinstance(val, str) and val: - parts.append(val) - if message.get("tool_calls"): - parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) - - payload = "\n".join(parts) - if not payload: - return 1 - if tiktoken is not None: - try: - enc = tiktoken.get_encoding("cl100k_base") - return max(1, len(enc.encode(payload))) - except Exception: - pass - return max(1, len(payload) // 4) - - def _pick_compression_chunk_by_tokens( - self, - session: Session, - reduction_tokens: int, - *, - tail_keep: int = 12, - ) -> tuple[int, int, int] | None: - """ - Pick one contiguous old chunk so its estimated size is roughly enough - to reduce `reduction_tokens`. - """ - messages = session.messages - start = self._get_compressed_until(session) - if len(messages) - start <= tail_keep + 2: - return None - - end_limit = len(messages) - tail_keep - if end_limit - start < 2: - return None - - target = max(1, reduction_tokens) - end = start - collected = 0 - while end < end_limit and collected < target: - collected += self._estimate_message_tokens(messages[end]) - end += 1 - - if end - start < 2: - end = min(end_limit, start + 2) - collected = sum(self._estimate_message_tokens(m) for m in messages[start:end]) - if end - start < 2: - return None - return start, end, collected - - def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: - """ - Estimate current full prompt tokens for this session view - (system + compressed history view + runtime/user placeholder + tools). - """ - history = self._build_compressed_history_view(session) - channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) - probe_messages = self.context.build_messages( - history=history, - current_message="[token-probe]", - channel=channel, - chat_id=chat_id, - ) - return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions()) - - async def _maybe_compress_history( - self, - session: Session, - ) -> None: - """ - End-of-turn policy: - - Estimate current prompt usage from persisted session view. - - If above start ratio, perform one best-effort compression chunk. - """ - if not session.messages: - self._set_compressed_until(session, 0) - return - - budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens) - start_threshold = int(budget * self.compression_start_ratio) - target_threshold = int(budget * self.compression_target_ratio) - if target_threshold >= start_threshold: - target_threshold = max(0, start_threshold - 1) - - # Prefer provider usage prompt tokens from the turn-ending call. - # If unavailable, fall back to estimator chain. - raw_prompt_tokens = session.metadata.get("_last_prompt_tokens") - if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0: - current_tokens = int(raw_prompt_tokens) - token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt") - else: - current_tokens, token_source = self._estimate_session_prompt_tokens(session) - - current_ratio = current_tokens / budget if budget else 0.0 - if current_tokens <= 0: - logger.debug("Compression skip {}: token estimate unavailable", session.key) - return - if current_tokens < start_threshold: - logger.debug( - "Compression idle {}: {}/{} ({:.1%}) via {}", - session.key, - current_tokens, - budget, - current_ratio, - token_source, - ) - return - logger.info( - "Compression trigger {}: {}/{} ({:.1%}) via {}", - session.key, - current_tokens, - budget, - current_ratio, - token_source, - ) - - reduction_by_target = max(0, current_tokens - target_threshold) - reduction_by_delta = max(1, start_threshold - target_threshold) - reduction_need = max(reduction_by_target, reduction_by_delta) - - chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10) - if chunk_range is None: - logger.info("Compression skipped for {}: no compressible chunk", session.key) - return - - start_idx, end_idx, estimated_chunk_tokens = chunk_range - chunk = session.messages[start_idx:end_idx] - if len(chunk) < 2: - return - - logger.info( - "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})", - session.key, - start_idx, - end_idx - 1, - len(chunk), - estimated_chunk_tokens, - reduction_need, - ) - success, _ = await self.context.memory.consolidate_chunk( - chunk, - self.provider, - self.model, - ) - if not success: - logger.warning("Compression aborted for {}: consolidation failed", session.key) - return - - self._set_compressed_until(session, end_idx) - self.sessions.save(session) - - after_tokens, after_source = self._estimate_session_prompt_tokens(session) - after_ratio = after_tokens / budget if budget else 0.0 - reduced = max(0, current_tokens - after_tokens) - reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0 - logger.info( - "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})", - session.key, - after_tokens, - budget, - after_ratio, - after_source, - reduced, - reduced_ratio, - ) - - def _schedule_background_compression(self, session_key: str) -> None: - """Schedule best-effort background compression for a session.""" - existing = self._compression_tasks.get(session_key) - if existing is not None and not existing.done(): - return - - async def _runner() -> None: - session = self.sessions.get_or_create(session_key) - try: - await self._maybe_compress_history(session) - except Exception: - logger.exception("Background compression failed for {}", session_key) - - task = asyncio.create_task(_runner()) - self._compression_tasks[session_key] = task - - def _cleanup(t: asyncio.Task) -> None: - cur = self._compression_tasks.get(session_key) - if cur is t: - self._compression_tasks.pop(session_key, None) - try: - t.result() - except BaseException: - pass - - task.add_done_callback(_cleanup) - - async def wait_for_background_compression(self, timeout_s: float | None = None) -> None: - """Wait for currently scheduled compression tasks.""" - pending = [t for t in self._compression_tasks.values() if not t.done()] - if not pending: - return - - logger.info("Waiting for {} background compression task(s)", len(pending)) - waiter = asyncio.gather(*pending, return_exceptions=True) - if timeout_s is None: - await waiter - return - - try: - await asyncio.wait_for(waiter, timeout=timeout_s) - except asyncio.TimeoutError: - logger.warning( - "Background compression wait timed out after {}s ({} task(s) still running)", - timeout_s, - len([t for t in self._compression_tasks.values() if not t.done()]), - ) - - def _build_compressed_history_view( - self, - session: Session, - ) -> list[dict]: - """Build non-destructive history view using the compressed boundary.""" - compressed_until = self._get_compressed_until(session) - if compressed_until <= 0: - return session.get_history(max_messages=0) - - notice_msg: dict[str, Any] = { - "role": "assistant", - "content": ( - "As your assistant, I have compressed earlier context. " - "If you need details, please check memory/HISTORY.md." - ), - } - - tail: list[dict[str, Any]] = [] - for msg in session.messages[compressed_until:]: - entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")} - for k in ("tool_calls", "tool_call_id", "name"): - if k in msg: - entry[k] = msg[k] - tail.append(entry) - - # Drop leading non-user entries from tail to avoid orphan tool blocks. - for i, m in enumerate(tail): - if m.get("role") == "user": - tail = tail[i:] - break - else: - tail = [] - - return [notice_msg, *tail] - def _register_default_tools(self) -> None: """Register the default set of tools.""" allowed_dir = self.workspace if self.restrict_to_workspace else None for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) - self.tools.register(ValidateDeployJSONTool()) - self.tools.register(ValidateUsageYAMLTool()) - self.tools.register(HuggingFaceModelSearchTool()) self.tools.register(ExecTool( working_dir=str(self.workspace), timeout=self.exec_config.timeout, @@ -563,24 +186,12 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, - ) -> tuple[str | None, list[str], list[dict], int, str]: - """ - Run the agent iteration loop. - - Returns: - (final_content, tools_used, messages, total_tokens_this_turn, token_source) - total_tokens_this_turn: total tokens (prompt + completion) for this turn - token_source: provider_total / provider_sum / provider_prompt / - provider_counter+tiktoken_completion / tiktoken / none - """ + ) -> tuple[str | None, list[str], list[dict]]: + """Run the agent iteration loop.""" messages = initial_messages iteration = 0 final_content = None tools_used: list[str] = [] - total_tokens_this_turn = 0 - token_source = "none" - self._last_turn_prompt_tokens = 0 - self._last_turn_prompt_source = "none" while iteration < self.max_iterations: iteration += 1 @@ -596,63 +207,6 @@ class AgentLoop: reasoning_effort=self.reasoning_effort, ) - # Prefer provider usage from the turn-ending model call; fallback to tiktoken. - # Calculate total tokens (prompt + completion) for this turn. - usage = response.usage or {} - t_tokens = usage.get("total_tokens") - p_tokens = usage.get("prompt_tokens") - c_tokens = usage.get("completion_tokens") - - if isinstance(t_tokens, (int, float)) and t_tokens > 0: - total_tokens_this_turn = int(t_tokens) - token_source = "provider_total" - if isinstance(p_tokens, (int, float)) and p_tokens > 0: - self._last_turn_prompt_tokens = int(p_tokens) - self._last_turn_prompt_source = "usage_prompt" - elif isinstance(c_tokens, (int, float)): - prompt_derived = int(t_tokens) - int(c_tokens) - if prompt_derived > 0: - self._last_turn_prompt_tokens = prompt_derived - self._last_turn_prompt_source = "usage_total_minus_completion" - elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)): - # If we have both prompt and completion tokens, sum them - total_tokens_this_turn = int(p_tokens) + int(c_tokens) - token_source = "provider_sum" - if p_tokens > 0: - self._last_turn_prompt_tokens = int(p_tokens) - self._last_turn_prompt_source = "usage_prompt" - elif isinstance(p_tokens, (int, float)) and p_tokens > 0: - # Fallback: use prompt tokens only (completion might be 0 for tool calls) - total_tokens_this_turn = int(p_tokens) - token_source = "provider_prompt" - self._last_turn_prompt_tokens = int(p_tokens) - self._last_turn_prompt_source = "usage_prompt" - else: - # Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken. - estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs) - estimated_completion = self._estimate_completion_tokens(response.content or "") - total_tokens_this_turn = estimated_prompt + estimated_completion - if estimated_prompt > 0: - self._last_turn_prompt_tokens = int(estimated_prompt) - self._last_turn_prompt_source = str(prompt_source or "tiktoken") - if total_tokens_this_turn > 0: - token_source = ( - "tiktoken" - if prompt_source == "tiktoken" - else f"{prompt_source}+tiktoken_completion" - ) - if total_tokens_this_turn <= 0: - total_tokens_this_turn = 0 - token_source = "none" - - logger.debug( - "Turn token usage: source={}, total={}, prompt={}, completion={}", - token_source, - total_tokens_this_turn, - p_tokens if isinstance(p_tokens, (int, float)) else None, - c_tokens if isinstance(c_tokens, (int, float)) else None, - ) - if response.has_tool_calls: if on_progress: thought = self._strip_think(response.content) @@ -707,7 +261,7 @@ class AgentLoop: "without completing the task. You can try breaking the task into smaller steps." ) - return final_content, tools_used, messages, total_tokens_this_turn, token_source + return final_content, tools_used, messages async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -732,9 +286,6 @@ class AgentLoop: """Cancel all active tasks and subagents for the session.""" tasks = self._active_tasks.pop(msg.session_key, []) cancelled = sum(1 for t in tasks if not t.done() and t.cancel()) - comp = self._compression_tasks.get(msg.session_key) - if comp is not None and not comp.done() and comp.cancel(): - cancelled += 1 for t in tasks: try: await t @@ -781,9 +332,6 @@ class AgentLoop: def stop(self) -> None: """Stop the agent loop.""" self._running = False - for task in list(self._compression_tasks.values()): - if not task.done(): - task.cancel() logger.info("Agent loop stopping") async def _process_message( @@ -800,22 +348,17 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) - history = self._build_compressed_history_view(session) + history = session.get_history(max_messages=0) messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, ) - final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages) - if self._last_turn_prompt_tokens > 0: - session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens - session.metadata["_last_prompt_source"] = self._last_turn_prompt_source - else: - session.metadata.pop("_last_prompt_tokens", None) - session.metadata.pop("_last_prompt_source", None) + final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background_compression(session.key) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -829,19 +372,12 @@ class AgentLoop: cmd = msg.content.strip().lower() if cmd == "/new": try: - # 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中 - if session.messages: - ok, _ = await self.context.memory.consolidate_chunk( - session.messages, - self.provider, - self.model, + if not await self.memory_consolidator.archive_unconsolidated(session): + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="Memory archival failed, session not cleared. Please try again.", ) - if not ok: - return OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content="Memory archival failed, session not cleared. Please try again.", - ) except Exception: logger.exception("/new archival failed for {}", session.key) return OutboundMessage( @@ -859,23 +395,20 @@ class AgentLoop: return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands") + await self.memory_consolidator.maybe_consolidate_by_tokens(session) + self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): if isinstance(message_tool, MessageTool): message_tool.start_turn() - # 正常对话:使用压缩后的历史视图(压缩在回合结束后进行) - history = self._build_compressed_history_view(session) + history = session.get_history(max_messages=0) initial_messages = self.context.build_messages( history=history, current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, ) - # Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:") - if session_key and session_key.startswith("cron:"): - if initial_messages and initial_messages[0].get("role") == "system": - initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}" async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: meta = dict(msg.metadata or {}) @@ -885,23 +418,16 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, )) - final_content, _, all_msgs, total_tokens_this_turn, token_source = await self._run_agent_loop( + final_content, _, all_msgs = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, ) if final_content is None: final_content = "I've completed processing but have no response to give." - if self._last_turn_prompt_tokens > 0: - session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens - session.metadata["_last_prompt_source"] = self._last_turn_prompt_source - else: - session.metadata.pop("_last_prompt_tokens", None) - session.metadata.pop("_last_prompt_source", None) - - self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn) + self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background_compression(session.key) + await self.memory_consolidator.maybe_consolidate_by_tokens(session) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None @@ -913,7 +439,7 @@ class AgentLoop: metadata=msg.metadata or {}, ) - def _save_turn(self, session: Session, messages: list[dict], skip: int, total_tokens_this_turn: int = 0) -> None: + def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None: """Save new-turn messages into session, truncating large tool results.""" from datetime import datetime for m in messages[skip:]: @@ -947,14 +473,6 @@ class AgentLoop: entry.setdefault("timestamp", datetime.now().isoformat()) session.messages.append(entry) session.updated_at = datetime.now() - - # Update cumulative token count for compression tracking - if total_tokens_this_turn > 0: - current_cumulative = session.metadata.get("_cumulative_tokens", 0) - if isinstance(current_cumulative, (int, float)): - session.metadata["_cumulative_tokens"] = int(current_cumulative) + total_tokens_this_turn - else: - session.metadata["_cumulative_tokens"] = total_tokens_this_turn async def process_direct( self, diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index e29788a..cd5f54f 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -2,17 +2,19 @@ from __future__ import annotations +import asyncio import json +import weakref from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable from loguru import logger -from nanobot.utils.helpers import ensure_dir +from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain if TYPE_CHECKING: from nanobot.providers.base import LLMProvider - from nanobot.session.manager import Session + from nanobot.session.manager import Session, SessionManager _SAVE_MEMORY_TOOL = [ @@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [ "properties": { "history_entry": { "type": "string", - "description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. " + "description": "A paragraph summarizing key events/decisions/topics. " "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", }, "memory_update": { @@ -42,6 +44,20 @@ _SAVE_MEMORY_TOOL = [ ] +def _ensure_text(value: Any) -> str: + """Normalize tool-call payload values to text for file storage.""" + return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False) + + +def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None: + """Normalize provider tool-call arguments to the expected dict shape.""" + if isinstance(args, str): + args = json.loads(args) + if isinstance(args, list): + return args[0] if args and isinstance(args[0], dict) else None + return args if isinstance(args, dict) else None + + class MemoryStore: """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" @@ -66,29 +82,27 @@ class MemoryStore: long_term = self.read_long_term() return f"## Long-term Memory\n{long_term}" if long_term else "" - async def consolidate_chunk( + @staticmethod + def _format_messages(messages: list[dict]) -> str: + lines = [] + for message in messages: + if not message.get("content"): + continue + tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else "" + lines.append( + f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}" + ) + return "\n".join(lines) + + async def consolidate( self, messages: list[dict], provider: LLMProvider, model: str, - ) -> tuple[bool, str | None]: - """Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call. - - Returns (success, None). - - - success: True on success (including no-op), False on failure. - - The second return value is reserved for future use (e.g. RAG-style summaries) and is - always None in the current implementation. - """ + ) -> bool: + """Consolidate the provided message chunk into MEMORY.md + HISTORY.md.""" if not messages: - return True, None - - lines = [] - for m in messages: - if not m.get("content"): - continue - tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else "" - lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}") + return True current_memory = self.read_long_term() prompt = f"""Process this conversation and call the save_memory tool with your consolidation. @@ -97,24 +111,12 @@ class MemoryStore: {current_memory or "(empty)"} ## Conversation to Process -{chr(10).join(lines)}""" +{self._format_messages(messages)}""" try: response = await provider.chat_with_retry( messages=[ - { - "role": "system", - "content": ( - "You are a memory consolidation agent.\n" - "Your job is to:\n" - "1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n" - " - Write 1 paragraph of 2–5 sentences that starts with [YYYY-MM-DD HH:MM].\n" - " - Include concrete names, IDs and numbers so it is easy to search with grep.\n" - "2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n" - "3) Optionally return a short context_summary (1–3 sentences) that will replace the raw messages in future dialogue history.\n\n" - "Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary." - ), - }, + {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, {"role": "user", "content": prompt}, ], tools=_SAVE_MEMORY_TOOL, @@ -123,35 +125,160 @@ class MemoryStore: if not response.has_tool_calls: logger.warning("Memory consolidation: LLM did not call save_memory, skipping") - return False, None + return False - args = response.tool_calls[0].arguments - # Some providers return arguments as a JSON string instead of dict - if isinstance(args, str): - args = json.loads(args) - # Some providers return arguments as a list (handle edge case) - if isinstance(args, list): - if args and isinstance(args[0], dict): - args = args[0] - else: - logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list") - return False, None - if not isinstance(args, dict): - logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__) - return False, None + args = _normalize_save_memory_args(response.tool_calls[0].arguments) + if args is None: + logger.warning("Memory consolidation: unexpected save_memory arguments") + return False if entry := args.get("history_entry"): - if not isinstance(entry, str): - entry = json.dumps(entry, ensure_ascii=False) - self.append_history(entry) + self.append_history(_ensure_text(entry)) if update := args.get("memory_update"): - if not isinstance(update, str): - update = json.dumps(update, ensure_ascii=False) + update = _ensure_text(update) if update != current_memory: self.write_long_term(update) logger.info("Memory consolidation done for {} messages", len(messages)) - return True, None + return True except Exception: logger.exception("Memory consolidation failed") - return False, None + return False + + +class MemoryConsolidator: + """Owns consolidation policy, locking, and session offset updates.""" + + _MAX_CONSOLIDATION_ROUNDS = 5 + + def __init__( + self, + workspace: Path, + provider: LLMProvider, + model: str, + sessions: SessionManager, + context_window_tokens: int, + build_messages: Callable[..., list[dict[str, Any]]], + get_tool_definitions: Callable[[], list[dict[str, Any]]], + ): + self.store = MemoryStore(workspace) + self.provider = provider + self.model = model + self.sessions = sessions + self.context_window_tokens = context_window_tokens + self._build_messages = build_messages + self._get_tool_definitions = get_tool_definitions + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + + def get_lock(self, session_key: str) -> asyncio.Lock: + """Return the shared consolidation lock for one session.""" + return self._locks.setdefault(session_key, asyncio.Lock()) + + async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: + """Archive a selected message chunk into persistent memory.""" + return await self.store.consolidate(messages, self.provider, self.model) + + def pick_consolidation_boundary( + self, + session: Session, + tokens_to_remove: int, + ) -> tuple[int, int] | None: + """Pick a user-turn boundary that removes enough old prompt tokens.""" + start = session.last_consolidated + if start >= len(session.messages) or tokens_to_remove <= 0: + return None + + removed_tokens = 0 + last_boundary: tuple[int, int] | None = None + for idx in range(start, len(session.messages)): + message = session.messages[idx] + if idx > start and message.get("role") == "user": + last_boundary = (idx, removed_tokens) + if removed_tokens >= tokens_to_remove: + return last_boundary + removed_tokens += estimate_message_tokens(message) + + return last_boundary + + def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: + """Estimate current prompt size for the normal session history view.""" + history = session.get_history(max_messages=0) + channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) + probe_messages = self._build_messages( + history=history, + current_message="[token-probe]", + channel=channel, + chat_id=chat_id, + ) + return estimate_prompt_tokens_chain( + self.provider, + self.model, + probe_messages, + self._get_tool_definitions(), + ) + + async def archive_unconsolidated(self, session: Session) -> bool: + """Archive the full unconsolidated tail for /new-style session rollover.""" + lock = self.get_lock(session.key) + async with lock: + snapshot = session.messages[session.last_consolidated:] + if not snapshot: + return True + return await self.consolidate_messages(snapshot) + + async def maybe_consolidate_by_tokens(self, session: Session) -> None: + """Loop: archive old messages until prompt fits within half the context window.""" + if not session.messages or self.context_window_tokens <= 0: + return + + lock = self.get_lock(session.key) + async with lock: + target = self.context_window_tokens // 2 + estimated, source = self.estimate_session_prompt_tokens(session) + if estimated <= 0: + return + if estimated < self.context_window_tokens: + logger.debug( + "Token consolidation idle {}: {}/{} via {}", + session.key, + estimated, + self.context_window_tokens, + source, + ) + return + + for round_num in range(self._MAX_CONSOLIDATION_ROUNDS): + if estimated <= target: + return + + boundary = self.pick_consolidation_boundary(session, max(1, estimated - target)) + if boundary is None: + logger.debug( + "Token consolidation: no safe boundary for {} (round {})", + session.key, + round_num, + ) + return + + end_idx = boundary[0] + chunk = session.messages[session.last_consolidated:end_idx] + if not chunk: + return + + logger.info( + "Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs", + round_num, + session.key, + estimated, + self.context_window_tokens, + source, + len(chunk), + ) + if not await self.consolidate_messages(chunk): + return + session.last_consolidated = end_idx + self.sessions.save(session) + + estimated, source = self.estimate_session_prompt_tokens(session) + if estimated <= 0: + return diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 36e2a53..cf69450 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -191,6 +191,8 @@ def onboard(): save_config(Config()) console.print(f"[green]✓[/green] Created config at {config_path}") + console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]") + # Create workspace workspace = get_workspace_path() @@ -283,6 +285,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None return loaded +def _print_deprecated_memory_window_notice(config: Config) -> None: + """Warn when running with old memoryWindow-only config.""" + if config.agents.defaults.should_warn_deprecated_memory_window: + console.print( + "[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without " + "`contextWindowTokens`. `memoryWindow` is ignored; run " + "[cyan]nanobot onboard[/cyan] to refresh your config template." + ) + + # ============================================================================ # Gateway / Server # ============================================================================ @@ -310,6 +322,7 @@ def gateway( logging.basicConfig(level=logging.DEBUG) config = _load_runtime_config(config, workspace) + _print_deprecated_memory_window_notice(config) port = port if port is not None else config.gateway.port console.print(f"{__logo__} Starting nanobot gateway on port {port}...") @@ -329,12 +342,10 @@ def gateway( workspace=config.workspace_path, model=config.agents.defaults.model, temperature=config.agents.defaults.temperature, - max_tokens=config.agents.defaults.max_tokens_output, + max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, reasoning_effort=config.agents.defaults.reasoning_effort, - max_tokens_input=config.agents.defaults.max_tokens_input, - compression_start_ratio=config.agents.defaults.compression_start_ratio, - compression_target_ratio=config.agents.defaults.compression_target_ratio, + context_window_tokens=config.agents.defaults.context_window_tokens, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -496,6 +507,7 @@ def agent( from nanobot.cron.service import CronService config = _load_runtime_config(config, workspace) + _print_deprecated_memory_window_notice(config) sync_workspace_templates(config.workspace_path) bus = MessageBus() @@ -516,12 +528,10 @@ def agent( workspace=config.workspace_path, model=config.agents.defaults.model, temperature=config.agents.defaults.temperature, - max_tokens=config.agents.defaults.max_tokens_output, + max_tokens=config.agents.defaults.max_tokens, max_iterations=config.agents.defaults.max_tool_iterations, reasoning_effort=config.agents.defaults.reasoning_effort, - max_tokens_input=config.agents.defaults.max_tokens_input, - compression_start_ratio=config.agents.defaults.compression_start_ratio, - compression_target_ratio=config.agents.defaults.compression_target_ratio, + context_window_tokens=config.agents.defaults.context_window_tokens, brave_api_key=config.tools.web.search.api_key or None, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 0e41d12..a2de239 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -190,22 +190,11 @@ class SlackConfig(Base): class QQConfig(Base): - """QQ channel configuration. - - Supports two implementations: - 1. Official botpy SDK: requires app_id and secret - 2. OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token) - """ + """QQ channel configuration using botpy SDK.""" enabled: bool = False - # Official botpy SDK fields app_id: str = "" # 机器人 ID (AppID) from q.qq.com secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com - # OneBot protocol fields - api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700") - ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse") - bot_qq: int | None = None # Bot's QQ number (for filtering self messages) - access_token: str = "" # Optional access token for OneBot API allow_from: list[str] = Field( default_factory=list ) # Allowed user openids (empty = public access) @@ -238,20 +227,19 @@ class AgentDefaults(Base): provider: str = ( "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection ) - # 原生上下文最大窗口(通常对应模型的 max_input_tokens / max_context_tokens) - # 默认按照主流大模型(如 GPT-4o、Claude 3.x 等)的 128k 上下文给一个宽松上限,实际应根据所选模型文档手动调整。 - max_tokens_input: int = 128_000 - # 默认单次回复的最大输出 token 上限(调用时可按需要再做截断或比例分配) - # 8192 足以覆盖大多数实际对话/工具使用场景,同样可按需手动调整。 - max_tokens_output: int = 8192 - # 会话历史压缩触发比例:当估算的输入 token 使用量 >= maxTokensInput * compressionStartRatio 时开始压缩。 - compression_start_ratio: float = 0.7 - # 会话历史压缩目标比例:每轮压缩后尽量把估算的输入 token 使用量压到 maxTokensInput * compressionTargetRatio 附近。 - compression_target_ratio: float = 0.4 + max_tokens: int = 8192 + context_window_tokens: int = 65_536 temperature: float = 0.1 max_tool_iterations: int = 40 + # Deprecated compatibility field: accepted from old configs but ignored at runtime. + memory_window: int | None = Field(default=None, exclude=True) reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode + @property + def should_warn_deprecated_memory_window(self) -> bool: + """Return True when old memoryWindow is present without contextWindowTokens.""" + return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set + class AgentsConfig(Base): """Agent configuration.""" diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 1cb8a51..f0a6484 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -9,6 +9,7 @@ from typing import Any from loguru import logger +from nanobot.config.paths import get_legacy_sessions_dir from nanobot.utils.helpers import ensure_dir, safe_filename @@ -29,6 +30,7 @@ class Session: created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) metadata: dict[str, Any] = field(default_factory=dict) + last_consolidated: int = 0 # Number of messages already consolidated to files def add_message(self, role: str, content: str, **kwargs: Any) -> None: """Add a message to the session.""" @@ -42,13 +44,9 @@ class Session: self.updated_at = datetime.now() def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: - """ - Return messages for LLM input, aligned to a user turn. - - - max_messages > 0 时只保留最近 max_messages 条; - - max_messages <= 0 时不做条数截断,返回全部消息。 - """ - sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:] + """Return unconsolidated messages for LLM input, aligned to a user turn.""" + unconsolidated = self.messages[self.last_consolidated:] + sliced = unconsolidated[-max_messages:] # Drop leading non-user messages to avoid orphaned tool_result blocks for i, m in enumerate(sliced): @@ -68,7 +66,7 @@ class Session: def clear(self) -> None: """Clear all messages and reset session to initial state.""" self.messages = [] - self.metadata = {} + self.last_consolidated = 0 self.updated_at = datetime.now() @@ -82,7 +80,7 @@ class SessionManager: def __init__(self, workspace: Path): self.workspace = workspace self.sessions_dir = ensure_dir(self.workspace / "sessions") - self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions" + self.legacy_sessions_dir = get_legacy_sessions_dir() self._cache: dict[str, Session] = {} def _get_session_path(self, key: str) -> Path: @@ -134,6 +132,7 @@ class SessionManager: messages = [] metadata = {} created_at = None + last_consolidated = 0 with open(path, encoding="utf-8") as f: for line in f: @@ -146,6 +145,7 @@ class SessionManager: if data.get("_type") == "metadata": metadata = data.get("metadata", {}) created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None + last_consolidated = data.get("last_consolidated", 0) else: messages.append(data) @@ -154,6 +154,7 @@ class SessionManager: messages=messages, created_at=created_at or datetime.now(), metadata=metadata, + last_consolidated=last_consolidated ) except Exception as e: logger.warning("Failed to load session {}: {}", key, e) @@ -170,6 +171,7 @@ class SessionManager: "created_at": session.created_at.isoformat(), "updated_at": session.updated_at.isoformat(), "metadata": session.metadata, + "last_consolidated": session.last_consolidated } f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n") for msg in session.messages: diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 57c60dc..9242ba6 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -1,8 +1,12 @@ """Utility functions for nanobot.""" +import json import re from datetime import datetime from pathlib import Path +from typing import Any + +import tiktoken def detect_image_mime(data: bytes) -> str | None: @@ -68,6 +72,87 @@ def split_message(content: str, max_len: int = 2000) -> list[str]: return chunks +def estimate_prompt_tokens( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, +) -> int: + """Estimate prompt tokens with tiktoken.""" + try: + enc = tiktoken.get_encoding("cl100k_base") + parts: list[str] = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + txt = part.get("text", "") + if txt: + parts.append(txt) + if tools: + parts.append(json.dumps(tools, ensure_ascii=False)) + return len(enc.encode("\n".join(parts))) + except Exception: + return 0 + + +def estimate_message_tokens(message: dict[str, Any]) -> int: + """Estimate prompt tokens contributed by one persisted message.""" + content = message.get("content") + parts: list[str] = [] + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text", "") + if text: + parts.append(text) + else: + parts.append(json.dumps(part, ensure_ascii=False)) + elif content is not None: + parts.append(json.dumps(content, ensure_ascii=False)) + + for key in ("name", "tool_call_id"): + value = message.get(key) + if isinstance(value, str) and value: + parts.append(value) + if message.get("tool_calls"): + parts.append(json.dumps(message["tool_calls"], ensure_ascii=False)) + + payload = "\n".join(parts) + if not payload: + return 1 + try: + enc = tiktoken.get_encoding("cl100k_base") + return max(1, len(enc.encode(payload))) + except Exception: + return max(1, len(payload) // 4) + + +def estimate_prompt_tokens_chain( + provider: Any, + model: str | None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, +) -> tuple[int, str]: + """Estimate prompt tokens via provider counter first, then tiktoken fallback.""" + provider_counter = getattr(provider, "estimate_prompt_tokens", None) + if callable(provider_counter): + try: + tokens, source = provider_counter(messages, tools, model) + if isinstance(tokens, (int, float)) and tokens > 0: + return int(tokens), str(source or "provider_counter") + except Exception: + pass + + estimated = estimate_prompt_tokens(messages, tools) + if estimated > 0: + return int(estimated), "tiktoken" + return 0, "none" + + def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: """Sync bundled templates to workspace. Only creates missing files.""" from importlib.resources import files as pkg_files diff --git a/pyproject.toml b/pyproject.toml index 62cf616..0344348 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "json-repair>=0.57.0,<1.0.0", "chardet>=3.0.2,<6.0.0", "openai>=2.8.0", + "tiktoken>=0.12.0,<1.0.0", ] [project.optional-dependencies] diff --git a/tests/test_commands.py b/tests/test_commands.py index 5e3760a..1375a3a 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -267,6 +267,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime, assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path +def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime): + mock_agent_runtime["config"].agents.defaults.memory_window = 100 + + result = runner.invoke(app, ["agent", "-m", "hello"]) + + assert result.exit_code == 0 + assert "memoryWindow" in result.stdout + assert "contextWindowTokens" in result.stdout + + def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) @@ -327,6 +337,29 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) assert seen["workspace"] == override assert config.workspace_path == override + +def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.memory_window = 100 + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr( + "nanobot.cli.commands._make_provider", + lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGateway) + assert "memoryWindow" in result.stdout + assert "contextWindowTokens" in result.stdout + def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py new file mode 100644 index 0000000..62e601e --- /dev/null +++ b/tests/test_config_migration.py @@ -0,0 +1,88 @@ +import json + +from typer.testing import CliRunner + +from nanobot.cli.commands import app +from nanobot.config.loader import load_config, save_config + +runner = CliRunner() + + +def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 1234, + "memoryWindow": 42, + } + } + } + ), + encoding="utf-8", + ) + + config = load_config(config_path) + + assert config.agents.defaults.max_tokens == 1234 + assert config.agents.defaults.context_window_tokens == 65_536 + assert config.agents.defaults.should_warn_deprecated_memory_window is True + + +def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 2222, + "memoryWindow": 30, + } + } + } + ), + encoding="utf-8", + ) + + config = load_config(config_path) + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + defaults = saved["agents"]["defaults"] + + assert defaults["maxTokens"] == 2222 + assert defaults["contextWindowTokens"] == 65_536 + assert "memoryWindow" not in defaults + + +def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + workspace = tmp_path / "workspace" + config_path.write_text( + json.dumps( + { + "agents": { + "defaults": { + "maxTokens": 3333, + "memoryWindow": 50, + } + } + } + ), + encoding="utf-8", + ) + + monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) + monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace) + + result = runner.invoke(app, ["onboard"], input="n\n") + + assert result.exit_code == 0 + assert "contextWindowTokens" in result.stdout + saved = json.loads(config_path.read_text(encoding="utf-8")) + defaults = saved["agents"]["defaults"] + assert defaults["maxTokens"] == 3333 + assert defaults["contextWindowTokens"] == 65_536 + assert "memoryWindow" not in defaults diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index a3213dd..7d12338 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions: assert_messages_content(old_messages, 10, 34) -class TestConsolidationDeduplicationGuard: - """Test that consolidation tasks are deduplicated and serialized.""" +class TestNewCommandArchival: + """Test /new archival behavior with the simplified consolidation flow.""" - @pytest.mark.asyncio - async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None: - """Concurrent messages above memory_window spawn only one consolidation task.""" + @staticmethod + def _make_loop(tmp_path: Path): from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (10_000, "test") loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 + bus=bus, + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=1, ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - consolidation_calls = 0 - - async def _fake_consolidate(_session, archive_all: bool = False) -> None: - nonlocal consolidation_calls - consolidation_calls += 1 - await asyncio.sleep(0.05) - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await loop._process_message(msg) - await asyncio.sleep(0.1) - - assert consolidation_calls == 1, ( - f"Expected exactly 1 consolidation, got {consolidation_calls}" - ) - - @pytest.mark.asyncio - async def test_new_command_guard_prevents_concurrent_consolidation( - self, tmp_path: Path - ) -> None: - """/new command does not run consolidation concurrently with in-flight consolidation.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - consolidation_calls = 0 - active = 0 - max_active = 0 - - async def _fake_consolidate(_session, archive_all: bool = False) -> None: - nonlocal consolidation_calls, active, max_active - consolidation_calls += 1 - active += 1 - max_active = max(max_active, active) - await asyncio.sleep(0.05) - active -= 1 - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - - new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - await loop._process_message(new_msg) - await asyncio.sleep(0.1) - - assert consolidation_calls == 2, ( - f"Expected normal + /new consolidations, got {consolidation_calls}" - ) - assert max_active == 1, ( - f"Expected serialized consolidation, observed concurrency={max_active}" - ) - - @pytest.mark.asyncio - async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None: - """create_task results are tracked in _consolidation_tasks while in flight.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - started = asyncio.Event() - - async def _slow_consolidate(_session, archive_all: bool = False) -> None: - started.set() - await asyncio.sleep(0.1) - - loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - - await started.wait() - assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight" - - await asyncio.sleep(0.15) - assert len(loop._consolidation_tasks) == 0, ( - "Task reference must be removed after completion" - ) - - @pytest.mark.asyncio - async def test_new_waits_for_inflight_consolidation_and_preserves_messages( - self, tmp_path: Path - ) -> None: - """/new waits for in-flight consolidation and archives before clear.""" - from nanobot.agent.loop import AgentLoop - from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) - - session = loop.sessions.get_or_create("cli:test") - for i in range(15): - session.add_message("user", f"msg{i}") - session.add_message("assistant", f"resp{i}") - loop.sessions.save(session) - - started = asyncio.Event() - release = asyncio.Event() - archived_count = 0 - - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: - nonlocal archived_count - if archive_all: - archived_count = len(sess.messages) - return True - started.set() - await release.wait() - return True - - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await started.wait() - - new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - pending_new = asyncio.create_task(loop._process_message(new_msg)) - - await asyncio.sleep(0.02) - assert not pending_new.done(), "/new should wait while consolidation is in-flight" - - release.set() - response = await pending_new - assert response is not None - assert "new session started" in response.content.lower() - assert archived_count > 0, "Expected /new archival to process a non-empty snapshot" - - session_after = loop.sessions.get_or_create("cli:test") - assert session_after.messages == [], "Session should be cleared after successful archival" + return loop @pytest.mark.asyncio async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None: - """/new must keep session data if archive step reports failure.""" - from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(5): session.add_message("user", f"msg{i}") @@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard: loop.sessions.save(session) before_count = len(session.messages) - async def _failing_consolidate(sess, archive_all: bool = False) -> bool: - if archive_all: - return False - return True + async def _failing_consolidate(_messages) -> bool: + return False - loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign] + loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) assert response is not None assert "failed" in response.content.lower() - session_after = loop.sessions.get_or_create("cli:test") - assert len(session_after.messages) == before_count, ( - "Session must remain intact when /new archival fails" - ) + assert len(loop.sessions.get_or_create("cli:test").messages) == before_count @pytest.mark.asyncio - async def test_new_archives_only_unconsolidated_messages_after_inflight_task( - self, tmp_path: Path - ) -> None: - """/new should archive only messages not yet consolidated by prior task.""" - from nanobot.agent.loop import AgentLoop + async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None: from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(15): session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") + session.last_consolidated = len(session.messages) - 3 loop.sessions.save(session) - started = asyncio.Event() - release = asyncio.Event() archived_count = -1 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(messages) -> bool: nonlocal archived_count - if archive_all: - archived_count = len(sess.messages) - return True - - started.set() - await release.wait() - sess.last_consolidated = len(sess.messages) - 3 + archived_count = len(messages) return True - loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign] - - msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello") - await loop._process_message(msg) - await started.wait() + loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") - pending_new = asyncio.create_task(loop._process_message(new_msg)) - await asyncio.sleep(0.02) - assert not pending_new.done() - - release.set() - response = await pending_new + response = await loop._process_message(new_msg) assert response is not None assert "new session started" in response.content.lower() - assert archived_count == 3, ( - f"Expected only unconsolidated tail to archive, got {archived_count}" - ) + assert archived_count == 3 @pytest.mark.asyncio async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None: - """/new clears session and returns confirmation.""" - from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage - from nanobot.bus.queue import MessageBus - from nanobot.providers.base import LLMResponse - - bus = MessageBus() - provider = MagicMock() - provider.get_default_model.return_value = "test-model" - loop = AgentLoop( - bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10 - ) - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) - loop.tools.get_definitions = MagicMock(return_value=[]) + loop = self._make_loop(tmp_path) session = loop.sessions.get_or_create("cli:test") for i in range(3): session.add_message("user", f"msg{i}") session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - async def _ok_consolidate(sess, archive_all: bool = False) -> bool: + async def _ok_consolidate(_messages) -> bool: return True - loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] + loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py new file mode 100644 index 0000000..b0f3dda --- /dev/null +++ b/tests/test_loop_consolidation_tokens.py @@ -0,0 +1,190 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.loop import AgentLoop +import nanobot.agent.memory as memory_module +from nanobot.bus.queue import MessageBus +from nanobot.providers.base import LLMResponse + + +def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop: + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[])) + + loop = AgentLoop( + bus=MessageBus(), + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=context_window_tokens, + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + return loop + + +@pytest.mark.asyncio +async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None: + loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + await loop.process_direct("hello", session_key="cli:test") + + loop.memory_consolidator.consolidate_messages.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None: + loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500) + + await loop.process_direct("hello", session_key="cli:test") + + assert loop.memory_consolidator.consolidate_messages.await_count >= 1 + + +@pytest.mark.asyncio +async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None: + loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + ] + loop.sessions.save(session) + + token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120} + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]]) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0] + assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"] + assert session.last_consolidated == 4 + + +@pytest.mark.asyncio +async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None: + """Verify maybe_consolidate_by_tokens keeps looping until under threshold.""" + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"}, + {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"}, + ] + loop.sessions.save(session) + + call_count = [0] + def mock_estimate(_session): + call_count[0] += 1 + if call_count[0] == 1: + return (500, "test") + if call_count[0] == 2: + return (300, "test") + return (80, "test") + + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert session.last_consolidated == 6 + + +@pytest.mark.asyncio +async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None: + """Once triggered, consolidation should continue until it drops below half threshold.""" + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"}, + {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"}, + {"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"}, + {"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"}, + ] + loop.sessions.save(session) + + call_count = [0] + + def mock_estimate(_session): + call_count[0] += 1 + if call_count[0] == 1: + return (500, "test") + if call_count[0] == 2: + return (150, "test") + return (80, "test") + + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) + + await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + + assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert session.last_consolidated == 6 + + +@pytest.mark.asyncio +async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None: + """Verify preflight consolidation runs before the LLM call in process_direct.""" + order: list[str] = [] + + loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) + + async def track_consolidate(messages): + order.append("consolidate") + return True + loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign] + + async def track_llm(*args, **kwargs): + order.append("llm") + return LLMResponse(content="ok", tool_calls=[]) + loop.provider.chat_with_retry = track_llm + + session = loop.sessions.get_or_create("cli:test") + session.messages = [ + {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, + {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"}, + {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"}, + ] + loop.sessions.save(session) + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500) + + call_count = [0] + def mock_estimate(_session): + call_count[0] += 1 + return (1000 if call_count[0] <= 1 else 80, "test") + loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + + await loop.process_direct("hello", session_key="cli:test") + + assert "consolidate" in order + assert "llm" in order + assert order.index("consolidate") < order.index("llm") diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py index 2605bf7..0263f01 100644 --- a/tests/test_memory_consolidation_types.py +++ b/tests/test_memory_consolidation_types.py @@ -7,7 +7,7 @@ tool call response, it should serialize them to JSON instead of raising TypeErro import json from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock import pytest @@ -15,15 +15,12 @@ from nanobot.agent.memory import MemoryStore from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -def _make_session(message_count: int = 30, memory_window: int = 50): - """Create a mock session with messages.""" - session = MagicMock() - session.messages = [ +def _make_messages(message_count: int = 30): + """Create a list of mock messages.""" + return [ {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"} for i in range(message_count) ] - session.last_consolidated = 0 - return session def _make_tool_response(history_entry, memory_update): @@ -74,9 +71,9 @@ class TestMemoryConsolidationTypeHandling: ) ) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert store.history_file.exists() @@ -95,9 +92,9 @@ class TestMemoryConsolidationTypeHandling: ) ) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert store.history_file.exists() @@ -131,9 +128,9 @@ class TestMemoryConsolidationTypeHandling: ) provider.chat = AsyncMock(return_value=response) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert "User discussed testing." in store.history_file.read_text() @@ -147,22 +144,22 @@ class TestMemoryConsolidationTypeHandling: return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[]) ) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False assert not store.history_file.exists() @pytest.mark.asyncio - async def test_skips_when_few_messages(self, tmp_path: Path) -> None: - """Consolidation should be a no-op when messages < keep_count.""" + async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None: + """Consolidation should be a no-op when the selected chunk is empty.""" store = MemoryStore(tmp_path) provider = AsyncMock() provider.chat_with_retry = provider.chat - session = _make_session(message_count=10) + messages: list[dict] = [] - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True provider.chat.assert_not_called() @@ -189,9 +186,9 @@ class TestMemoryConsolidationTypeHandling: ) provider.chat = AsyncMock(return_value=response) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert "User discussed testing." in store.history_file.read_text() @@ -215,9 +212,9 @@ class TestMemoryConsolidationTypeHandling: ) provider.chat = AsyncMock(return_value=response) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False @@ -239,9 +236,9 @@ class TestMemoryConsolidationTypeHandling: ) provider.chat = AsyncMock(return_value=response) provider.chat_with_retry = provider.chat - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is False @@ -255,7 +252,7 @@ class TestMemoryConsolidationTypeHandling: memory_update="# Memory\nUser likes testing.", ), ]) - session = _make_session(message_count=60) + messages = _make_messages(message_count=60) delays: list[int] = [] async def _fake_sleep(delay: int) -> None: @@ -263,7 +260,7 @@ class TestMemoryConsolidationTypeHandling: monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) - result = await store.consolidate(session, provider, "test-model", memory_window=50) + result = await store.consolidate(messages, provider, "test-model") assert result is True assert provider.calls == 2 diff --git a/tests/test_message_tool_suppress.py b/tests/test_message_tool_suppress.py index 63b0fd1..1091de4 100644 --- a/tests/test_message_tool_suppress.py +++ b/tests/test_message_tool_suppress.py @@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10) + return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") class TestMessageToolSuppressLogic: @@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic: LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="Done", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) sent: list[OutboundMessage] = [] @@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic: LLMResponse(content="", tool_calls=[tool_call]), LLMResponse(content="I've sent the email.", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) sent: list[OutboundMessage] = [] @@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic: @pytest.mark.asyncio async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) - loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[])) loop.tools.get_definitions = MagicMock(return_value=[]) msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi") @@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic: ), LLMResponse(content="Done", tool_calls=[]), ]) - loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.execute = AsyncMock(return_value="ok")