refactor(memory): switch consolidation to token-based context windows

Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
This commit is contained in:
Re-bin
2026-03-10 19:55:06 +00:00
parent 4784eb4128
commit 62ccda43b9
13 changed files with 709 additions and 911 deletions

View File

@@ -11,18 +11,12 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
try:
import tiktoken # type: ignore
except Exception: # pragma: no cover - optional dependency
tiktoken = None
from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
@@ -60,11 +54,8 @@ class AgentLoop:
max_iterations: int = 40,
temperature: float = 0.1,
max_tokens: int = 4096,
memory_window: int | None = None, # backward-compat only (unused)
reasoning_effort: str | None = None,
max_tokens_input: int = 128_000,
compression_start_ratio: float = 0.7,
compression_target_ratio: float = 0.4,
context_window_tokens: int = 65_536,
brave_api_key: str | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None,
@@ -82,18 +73,9 @@ class AgentLoop:
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
self.temperature = temperature
# max_tokens: per-call output token cap (maxTokensOutput in config)
self.max_tokens = max_tokens
# Keep legacy attribute for older call sites/tests; compression no longer uses it.
self.memory_window = memory_window
self.reasoning_effort = reasoning_effort
# max_tokens_input: model native context window (maxTokensInput in config)
self.max_tokens_input = max_tokens_input
# Token-based compression watermarks (fractions of available input budget)
self.compression_start_ratio = compression_start_ratio
self.compression_target_ratio = compression_target_ratio
# Reserve tokens for safety margin
self._reserve_tokens = 1000
self.context_window_tokens = context_window_tokens
self.brave_api_key = brave_api_key
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@@ -123,382 +105,23 @@ class AgentLoop:
self._mcp_connected = False
self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
self._last_turn_prompt_tokens: int = 0
self._last_turn_prompt_source: str = "none"
self._processing_lock = asyncio.Lock()
self.memory_consolidator = MemoryConsolidator(
workspace=workspace,
provider=provider,
model=self.model,
sessions=self.sessions,
context_window_tokens=context_window_tokens,
build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions,
)
self._register_default_tools()
@staticmethod
def _estimate_prompt_tokens(
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> int:
    """Fallback prompt-token estimate via tiktoken; returns 0 when unavailable."""
    if tiktoken is None:
        return 0
    try:
        encoder = tiktoken.get_encoding("cl100k_base")
        text_chunks: list[str] = []
        for message in messages:
            body = message.get("content")
            if isinstance(body, str):
                text_chunks.append(body)
            elif isinstance(body, list):
                # Multimodal content: only the text parts count.
                for piece in body:
                    if not (isinstance(piece, dict) and piece.get("type") == "text"):
                        continue
                    if snippet := piece.get("text", ""):
                        text_chunks.append(snippet)
        if tools:
            # Tool schemas are part of the prompt; approximate via their JSON form.
            text_chunks.append(json.dumps(tools, ensure_ascii=False))
        return len(encoder.encode("\n".join(text_chunks)))
    except Exception:
        # Any tiktoken failure degrades to "unknown" rather than crashing.
        return 0
def _estimate_prompt_tokens_chain(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> tuple[int, str]:
"""Unified prompt-token estimation: provider counter -> tiktoken."""
provider_counter = getattr(self.provider, "estimate_prompt_tokens", None)
if callable(provider_counter):
try:
tokens, source = provider_counter(messages, tools, self.model)
if isinstance(tokens, (int, float)) and tokens > 0:
return int(tokens), str(source or "provider_counter")
except Exception:
logger.debug("Provider token counter failed; fallback to tiktoken")
estimated = self._estimate_prompt_tokens(messages, tools)
if estimated > 0:
return int(estimated), "tiktoken"
return 0, "none"
@staticmethod
def _estimate_completion_tokens(content: str) -> int:
    """Fallback completion-token estimate via tiktoken; returns 0 when unavailable."""
    if tiktoken is None:
        return 0
    try:
        return len(tiktoken.get_encoding("cl100k_base").encode(content or ""))
    except Exception:
        # Degrade to "unknown" on any encoder failure.
        return 0
def _get_compressed_until(self, session: Session) -> int:
"""Read/normalize compressed boundary and migrate old metadata format."""
raw = session.metadata.get("_compressed_until", 0)
try:
compressed_until = int(raw)
except (TypeError, ValueError):
compressed_until = 0
if compressed_until <= 0:
ranges = session.metadata.get("_compressed_ranges")
if isinstance(ranges, list):
inferred = 0
for item in ranges:
if not isinstance(item, (list, tuple)) or len(item) != 2:
continue
try:
inferred = max(inferred, int(item[1]))
except (TypeError, ValueError):
continue
compressed_until = inferred
compressed_until = max(0, min(compressed_until, len(session.messages)))
session.metadata["_compressed_until"] = compressed_until
# 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段
session.metadata.pop("_compressed_ranges", None)
# 注意:不要删除 _cumulative_tokens压缩逻辑需要它来跟踪累积 token 计数
return compressed_until
def _set_compressed_until(self, session: Session, idx: int) -> None:
"""Persist a contiguous compressed boundary."""
session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages)))
session.metadata.pop("_compressed_ranges", None)
# 注意:不要删除 _cumulative_tokens压缩逻辑需要它来跟踪累积 token 计数
@staticmethod
def _estimate_message_tokens(message: dict[str, Any]) -> int:
    """Rough token estimate for one persisted message (always at least 1).

    Uses tiktoken when available; otherwise falls back to the ~4 chars
    per token heuristic.
    """
    chunks: list[str] = []
    body = message.get("content")
    if isinstance(body, str):
        chunks.append(body)
    elif isinstance(body, list):
        for piece in body:
            if isinstance(piece, dict) and piece.get("type") == "text":
                if text := piece.get("text", ""):
                    chunks.append(text)
            else:
                # Non-text parts are approximated by their JSON form.
                chunks.append(json.dumps(piece, ensure_ascii=False))
    elif body is not None:
        chunks.append(json.dumps(body, ensure_ascii=False))
    for field in ("name", "tool_call_id"):
        value = message.get(field)
        if isinstance(value, str) and value:
            chunks.append(value)
    if message.get("tool_calls"):
        chunks.append(json.dumps(message["tool_calls"], ensure_ascii=False))
    text_blob = "\n".join(chunks)
    if not text_blob:
        return 1
    if tiktoken is not None:
        try:
            return max(1, len(tiktoken.get_encoding("cl100k_base").encode(text_blob)))
        except Exception:
            pass
    # Fallback heuristic: roughly 4 characters per token.
    return max(1, len(text_blob) // 4)
def _pick_compression_chunk_by_tokens(
self,
session: Session,
reduction_tokens: int,
*,
tail_keep: int = 12,
) -> tuple[int, int, int] | None:
"""
Pick one contiguous old chunk so its estimated size is roughly enough
to reduce `reduction_tokens`.
"""
messages = session.messages
start = self._get_compressed_until(session)
if len(messages) - start <= tail_keep + 2:
return None
end_limit = len(messages) - tail_keep
if end_limit - start < 2:
return None
target = max(1, reduction_tokens)
end = start
collected = 0
while end < end_limit and collected < target:
collected += self._estimate_message_tokens(messages[end])
end += 1
if end - start < 2:
end = min(end_limit, start + 2)
collected = sum(self._estimate_message_tokens(m) for m in messages[start:end])
if end - start < 2:
return None
return start, end, collected
def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""
Estimate current full prompt tokens for this session view
(system + compressed history view + runtime/user placeholder + tools).
"""
history = self._build_compressed_history_view(session)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self.context.build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
)
return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions())
async def _maybe_compress_history(
    self,
    session: Session,
) -> None:
    """
    End-of-turn policy:
    - Estimate current prompt usage from persisted session view.
    - If above start ratio, perform one best-effort compression chunk.

    Side effects on success: advances the session's compressed boundary,
    saves the session, and logs before/after token counts.
    """
    if not session.messages:
        # Empty session: reset the boundary so stale metadata can't linger.
        self._set_compressed_until(session, 0)
        return
    # Input budget = native context window minus the output cap and a safety reserve.
    budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens)
    start_threshold = int(budget * self.compression_start_ratio)
    target_threshold = int(budget * self.compression_target_ratio)
    # Guarantee target < start so a triggered pass always has room to reduce.
    if target_threshold >= start_threshold:
        target_threshold = max(0, start_threshold - 1)
    # Prefer provider usage prompt tokens from the turn-ending call.
    # If unavailable, fall back to estimator chain.
    raw_prompt_tokens = session.metadata.get("_last_prompt_tokens")
    if isinstance(raw_prompt_tokens, (int, float)) and raw_prompt_tokens > 0:
        current_tokens = int(raw_prompt_tokens)
        token_source = str(session.metadata.get("_last_prompt_source") or "usage_prompt")
    else:
        current_tokens, token_source = self._estimate_session_prompt_tokens(session)
    current_ratio = current_tokens / budget if budget else 0.0
    if current_tokens <= 0:
        # No usable estimate at all -- skip rather than compress blindly.
        logger.debug("Compression skip {}: token estimate unavailable", session.key)
        return
    if current_tokens < start_threshold:
        # Below the high-watermark: nothing to do this turn.
        logger.debug(
            "Compression idle {}: {}/{} ({:.1%}) via {}",
            session.key,
            current_tokens,
            budget,
            current_ratio,
            token_source,
        )
        return
    logger.info(
        "Compression trigger {}: {}/{} ({:.1%}) via {}",
        session.key,
        current_tokens,
        budget,
        current_ratio,
        token_source,
    )
    # Reduce down to the target watermark, but at least by the start-target
    # gap so repeated passes make measurable progress.
    reduction_by_target = max(0, current_tokens - target_threshold)
    reduction_by_delta = max(1, start_threshold - target_threshold)
    reduction_need = max(reduction_by_target, reduction_by_delta)
    chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10)
    if chunk_range is None:
        logger.info("Compression skipped for {}: no compressible chunk", session.key)
        return
    start_idx, end_idx, estimated_chunk_tokens = chunk_range
    chunk = session.messages[start_idx:end_idx]
    if len(chunk) < 2:
        # A single-message chunk is not worth a consolidation call.
        return
    logger.info(
        "Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})",
        session.key,
        start_idx,
        end_idx - 1,
        len(chunk),
        estimated_chunk_tokens,
        reduction_need,
    )
    # Archive the chunk into MEMORY/HISTORY via the memory store's LLM call.
    success, _ = await self.context.memory.consolidate_chunk(
        chunk,
        self.provider,
        self.model,
    )
    if not success:
        # Leave the boundary untouched; we retry on a later turn.
        logger.warning("Compression aborted for {}: consolidation failed", session.key)
        return
    self._set_compressed_until(session, end_idx)
    self.sessions.save(session)
    # Re-estimate to report the actual reduction achieved.
    after_tokens, after_source = self._estimate_session_prompt_tokens(session)
    after_ratio = after_tokens / budget if budget else 0.0
    reduced = max(0, current_tokens - after_tokens)
    reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0
    logger.info(
        "Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})",
        session.key,
        after_tokens,
        budget,
        after_ratio,
        after_source,
        reduced,
        reduced_ratio,
    )
def _schedule_background_compression(self, session_key: str) -> None:
"""Schedule best-effort background compression for a session."""
existing = self._compression_tasks.get(session_key)
if existing is not None and not existing.done():
return
async def _runner() -> None:
session = self.sessions.get_or_create(session_key)
try:
await self._maybe_compress_history(session)
except Exception:
logger.exception("Background compression failed for {}", session_key)
task = asyncio.create_task(_runner())
self._compression_tasks[session_key] = task
def _cleanup(t: asyncio.Task) -> None:
cur = self._compression_tasks.get(session_key)
if cur is t:
self._compression_tasks.pop(session_key, None)
try:
t.result()
except BaseException:
pass
task.add_done_callback(_cleanup)
async def wait_for_background_compression(self, timeout_s: float | None = None) -> None:
    """Wait for currently scheduled compression tasks.

    Args:
        timeout_s: Maximum seconds to wait; None waits indefinitely.

    Bug fix: the previous implementation awaited
    `asyncio.wait_for(asyncio.gather(*pending), timeout_s)`. On timeout,
    `wait_for` cancels the gather future, and gather propagates that
    cancellation into the still-running compression tasks -- killing the
    exact work the warning claimed was "still running". `asyncio.wait`
    leaves timed-out tasks running.
    """
    pending = [t for t in self._compression_tasks.values() if not t.done()]
    if not pending:
        return
    logger.info("Waiting for {} background compression task(s)", len(pending))
    if timeout_s is None:
        # Unbounded wait: collect results, swallowing per-task exceptions
        # (they are already logged by the scheduling done-callback).
        await asyncio.gather(*pending, return_exceptions=True)
        return
    # asyncio.wait does NOT cancel the awaited tasks on timeout. Task
    # exceptions remain attached to the tasks and are retrieved by the
    # done-callback installed when they were scheduled.
    _, still_pending = await asyncio.wait(pending, timeout=timeout_s)
    if still_pending:
        logger.warning(
            "Background compression wait timed out after {}s ({} task(s) still running)",
            timeout_s,
            len([t for t in self._compression_tasks.values() if not t.done()]),
        )
def _build_compressed_history_view(
self,
session: Session,
) -> list[dict]:
"""Build non-destructive history view using the compressed boundary."""
compressed_until = self._get_compressed_until(session)
if compressed_until <= 0:
return session.get_history(max_messages=0)
notice_msg: dict[str, Any] = {
"role": "assistant",
"content": (
"As your assistant, I have compressed earlier context. "
"If you need details, please check memory/HISTORY.md."
),
}
tail: list[dict[str, Any]] = []
for msg in session.messages[compressed_until:]:
entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")}
for k in ("tool_calls", "tool_call_id", "name"):
if k in msg:
entry[k] = msg[k]
tail.append(entry)
# Drop leading non-user entries from tail to avoid orphan tool blocks.
for i, m in enumerate(tail):
if m.get("role") == "user":
tail = tail[i:]
break
else:
tail = []
return [notice_msg, *tail]
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = self.workspace if self.restrict_to_workspace else None
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(ValidateDeployJSONTool())
self.tools.register(ValidateUsageYAMLTool())
self.tools.register(HuggingFaceModelSearchTool())
self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
@@ -563,24 +186,12 @@ class AgentLoop:
self,
initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None,
) -> tuple[str | None, list[str], list[dict], int, str]:
"""
Run the agent iteration loop.
Returns:
(final_content, tools_used, messages, total_tokens_this_turn, token_source)
total_tokens_this_turn: total tokens (prompt + completion) for this turn
token_source: provider_total / provider_sum / provider_prompt /
provider_counter+tiktoken_completion / tiktoken / none
"""
) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop."""
messages = initial_messages
iteration = 0
final_content = None
tools_used: list[str] = []
total_tokens_this_turn = 0
token_source = "none"
self._last_turn_prompt_tokens = 0
self._last_turn_prompt_source = "none"
while iteration < self.max_iterations:
iteration += 1
@@ -596,63 +207,6 @@ class AgentLoop:
reasoning_effort=self.reasoning_effort,
)
# Prefer provider usage from the turn-ending model call; fallback to tiktoken.
# Calculate total tokens (prompt + completion) for this turn.
usage = response.usage or {}
t_tokens = usage.get("total_tokens")
p_tokens = usage.get("prompt_tokens")
c_tokens = usage.get("completion_tokens")
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
total_tokens_this_turn = int(t_tokens)
token_source = "provider_total"
if isinstance(p_tokens, (int, float)) and p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(c_tokens, (int, float)):
prompt_derived = int(t_tokens) - int(c_tokens)
if prompt_derived > 0:
self._last_turn_prompt_tokens = prompt_derived
self._last_turn_prompt_source = "usage_total_minus_completion"
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
# If we have both prompt and completion tokens, sum them
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
token_source = "provider_sum"
if p_tokens > 0:
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
total_tokens_this_turn = int(p_tokens)
token_source = "provider_prompt"
self._last_turn_prompt_tokens = int(p_tokens)
self._last_turn_prompt_source = "usage_prompt"
else:
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
estimated_completion = self._estimate_completion_tokens(response.content or "")
total_tokens_this_turn = estimated_prompt + estimated_completion
if estimated_prompt > 0:
self._last_turn_prompt_tokens = int(estimated_prompt)
self._last_turn_prompt_source = str(prompt_source or "tiktoken")
if total_tokens_this_turn > 0:
token_source = (
"tiktoken"
if prompt_source == "tiktoken"
else f"{prompt_source}+tiktoken_completion"
)
if total_tokens_this_turn <= 0:
total_tokens_this_turn = 0
token_source = "none"
logger.debug(
"Turn token usage: source={}, total={}, prompt={}, completion={}",
token_source,
total_tokens_this_turn,
p_tokens if isinstance(p_tokens, (int, float)) else None,
c_tokens if isinstance(c_tokens, (int, float)) else None,
)
if response.has_tool_calls:
if on_progress:
thought = self._strip_think(response.content)
@@ -707,7 +261,7 @@ class AgentLoop:
"without completing the task. You can try breaking the task into smaller steps."
)
return final_content, tools_used, messages, total_tokens_this_turn, token_source
return final_content, tools_used, messages
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@@ -732,9 +286,6 @@ class AgentLoop:
"""Cancel all active tasks and subagents for the session."""
tasks = self._active_tasks.pop(msg.session_key, [])
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
comp = self._compression_tasks.get(msg.session_key)
if comp is not None and not comp.done() and comp.cancel():
cancelled += 1
for t in tasks:
try:
await t
@@ -781,9 +332,6 @@ class AgentLoop:
def stop(self) -> None:
"""Stop the agent loop."""
self._running = False
for task in list(self._compression_tasks.values()):
if not task.done():
task.cancel()
logger.info("Agent loop stopping")
async def _process_message(
@@ -800,22 +348,17 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = self._build_compressed_history_view(session)
history = session.get_history(max_messages=0)
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
)
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
self._schedule_background_compression(session.key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.")
@@ -829,14 +372,7 @@ class AgentLoop:
cmd = msg.content.strip().lower()
if cmd == "/new":
try:
# 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中
if session.messages:
ok, _ = await self.context.memory.consolidate_chunk(
session.messages,
self.provider,
self.model,
)
if not ok:
if not await self.memory_consolidator.archive_unconsolidated(session):
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
@@ -859,23 +395,20 @@ class AgentLoop:
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
# 正常对话:使用压缩后的历史视图(压缩在回合结束后进行)
history = self._build_compressed_history_view(session)
history = session.get_history(max_messages=0)
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id,
)
# Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:")
if session_key and session_key.startswith("cron:"):
if initial_messages and initial_messages[0].get("role") == "system":
initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}"
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {})
@@ -885,23 +418,16 @@ class AgentLoop:
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
final_content, _, all_msgs, total_tokens_this_turn, token_source = await self._run_agent_loop(
final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
)
if final_content is None:
final_content = "I've completed processing but have no response to give."
if self._last_turn_prompt_tokens > 0:
session.metadata["_last_prompt_tokens"] = self._last_turn_prompt_tokens
session.metadata["_last_prompt_source"] = self._last_turn_prompt_source
else:
session.metadata.pop("_last_prompt_tokens", None)
session.metadata.pop("_last_prompt_source", None)
self._save_turn(session, all_msgs, 1 + len(history), total_tokens_this_turn)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
self._schedule_background_compression(session.key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
@@ -913,7 +439,7 @@ class AgentLoop:
metadata=msg.metadata or {},
)
def _save_turn(self, session: Session, messages: list[dict], skip: int, total_tokens_this_turn: int = 0) -> None:
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
@@ -948,14 +474,6 @@ class AgentLoop:
session.messages.append(entry)
session.updated_at = datetime.now()
# Update cumulative token count for compression tracking
if total_tokens_this_turn > 0:
current_cumulative = session.metadata.get("_cumulative_tokens", 0)
if isinstance(current_cumulative, (int, float)):
session.metadata["_cumulative_tokens"] = int(current_cumulative) + total_tokens_this_turn
else:
session.metadata["_cumulative_tokens"] = total_tokens_this_turn
async def process_direct(
self,
content: str,

View File

@@ -2,17 +2,19 @@
from __future__ import annotations
import asyncio
import json
import weakref
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Callable
from loguru import logger
from nanobot.utils.helpers import ensure_dir
from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain
if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session
from nanobot.session.manager import Session, SessionManager
_SAVE_MEMORY_TOOL = [
@@ -26,7 +28,7 @@ _SAVE_MEMORY_TOOL = [
"properties": {
"history_entry": {
"type": "string",
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
"description": "A paragraph summarizing key events/decisions/topics. "
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
},
"memory_update": {
@@ -42,6 +44,20 @@ _SAVE_MEMORY_TOOL = [
]
def _ensure_text(value: Any) -> str:
"""Normalize tool-call payload values to text for file storage."""
return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
"""Normalize provider tool-call arguments to the expected dict shape."""
if isinstance(args, str):
args = json.loads(args)
if isinstance(args, list):
return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None
class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -66,29 +82,27 @@ class MemoryStore:
long_term = self.read_long_term()
return f"## Long-term Memory\n{long_term}" if long_term else ""
async def consolidate_chunk(
@staticmethod
def _format_messages(messages: list[dict]) -> str:
lines = []
for message in messages:
if not message.get("content"):
continue
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
lines.append(
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
)
return "\n".join(lines)
async def consolidate(
self,
messages: list[dict],
provider: LLMProvider,
model: str,
) -> tuple[bool, str | None]:
"""Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call.
Returns (success, None).
- success: True on success (including no-op), False on failure.
- The second return value is reserved for future use (e.g. RAG-style summaries) and is
always None in the current implementation.
"""
) -> bool:
"""Consolidate the provided message chunk into MEMORY.md + HISTORY.md."""
if not messages:
return True, None
lines = []
for m in messages:
if not m.get("content"):
continue
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
return True
current_memory = self.read_long_term()
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
@@ -97,24 +111,12 @@ class MemoryStore:
{current_memory or "(empty)"}
## Conversation to Process
{chr(10).join(lines)}"""
{self._format_messages(messages)}"""
try:
response = await provider.chat_with_retry(
messages=[
{
"role": "system",
"content": (
"You are a memory consolidation agent.\n"
"Your job is to:\n"
"1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n"
" - Write 1 paragraph of 25 sentences that starts with [YYYY-MM-DD HH:MM].\n"
" - Include concrete names, IDs and numbers so it is easy to search with grep.\n"
"2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n"
"3) Optionally return a short context_summary (13 sentences) that will replace the raw messages in future dialogue history.\n\n"
"Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary."
),
},
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
{"role": "user", "content": prompt},
],
tools=_SAVE_MEMORY_TOOL,
@@ -123,35 +125,160 @@ class MemoryStore:
if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
return False, None
return False
args = response.tool_calls[0].arguments
# Some providers return arguments as a JSON string instead of dict
if isinstance(args, str):
args = json.loads(args)
# Some providers return arguments as a list (handle edge case)
if isinstance(args, list):
if args and isinstance(args[0], dict):
args = args[0]
else:
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
return False, None
if not isinstance(args, dict):
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
return False, None
args = _normalize_save_memory_args(response.tool_calls[0].arguments)
if args is None:
logger.warning("Memory consolidation: unexpected save_memory arguments")
return False
if entry := args.get("history_entry"):
if not isinstance(entry, str):
entry = json.dumps(entry, ensure_ascii=False)
self.append_history(entry)
self.append_history(_ensure_text(entry))
if update := args.get("memory_update"):
if not isinstance(update, str):
update = json.dumps(update, ensure_ascii=False)
update = _ensure_text(update)
if update != current_memory:
self.write_long_term(update)
logger.info("Memory consolidation done for {} messages", len(messages))
return True, None
return True
except Exception:
logger.exception("Memory consolidation failed")
return False, None
return False
class MemoryConsolidator:
"""Owns consolidation policy, locking, and session offset updates."""
_MAX_CONSOLIDATION_ROUNDS = 5
def __init__(
self,
workspace: Path,
provider: LLMProvider,
model: str,
sessions: SessionManager,
context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]],
):
self.store = MemoryStore(workspace)
self.provider = provider
self.model = model
self.sessions = sessions
self.context_window_tokens = context_window_tokens
self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock())
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
"""Archive a selected message chunk into persistent memory."""
return await self.store.consolidate(messages, self.provider, self.model)
def pick_consolidation_boundary(
self,
session: Session,
tokens_to_remove: int,
) -> tuple[int, int] | None:
"""Pick a user-turn boundary that removes enough old prompt tokens."""
start = session.last_consolidated
if start >= len(session.messages) or tokens_to_remove <= 0:
return None
removed_tokens = 0
last_boundary: tuple[int, int] | None = None
for idx in range(start, len(session.messages)):
message = session.messages[idx]
if idx > start and message.get("role") == "user":
last_boundary = (idx, removed_tokens)
if removed_tokens >= tokens_to_remove:
return last_boundary
removed_tokens += estimate_message_tokens(message)
return last_boundary
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self._build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
)
return estimate_prompt_tokens_chain(
self.provider,
self.model,
probe_messages,
self._get_tool_definitions(),
)
async def archive_unconsolidated(self, session: Session) -> bool:
"""Archive the full unconsolidated tail for /new-style session rollover."""
lock = self.get_lock(session.key)
async with lock:
snapshot = session.messages[session.last_consolidated:]
if not snapshot:
return True
return await self.consolidate_messages(snapshot)
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
    """Loop: archive old messages until prompt fits within half the context window."""
    # Nothing to do on empty history or when token budgeting is disabled.
    if not session.messages or self.context_window_tokens <= 0:
        return
    lock = self.get_lock(session.key)
    async with lock:
        # Hysteresis: consolidation triggers only once the estimate reaches
        # the full window, but then archives down to half the window so the
        # next few turns do not immediately re-trigger it.
        target = self.context_window_tokens // 2
        estimated, source = self.estimate_session_prompt_tokens(session)
        if estimated <= 0:
            # Estimator produced no usable number; skip this pass.
            return
        if estimated < self.context_window_tokens:
            logger.debug(
                "Token consolidation idle {}: {}/{} via {}",
                session.key,
                estimated,
                self.context_window_tokens,
                source,
            )
            return
        # Bounded loop: each round archives one chunk, then re-estimates.
        for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
            if round_num and estimated <= target:
                return
            if estimated <= target:
                return
            # Ask for a user-turn boundary that would free the overshoot.
            boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
            if boundary is None:
                logger.debug(
                    "Token consolidation: no safe boundary for {} (round {})",
                    session.key,
                    round_num,
                )
                return
            end_idx = boundary[0]
            chunk = session.messages[session.last_consolidated:end_idx]
            if not chunk:
                return
            logger.info(
                "Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
                round_num,
                session.key,
                estimated,
                self.context_window_tokens,
                source,
                len(chunk),
            )
            if not await self.consolidate_messages(chunk):
                # Archival failed: leave messages and watermark untouched.
                return
            # Advance and persist the watermark only after a successful archive.
            session.last_consolidated = end_idx
            self.sessions.save(session)
            estimated, source = self.estimate_session_prompt_tokens(session)
            if estimated <= 0:
                return

View File

@@ -191,6 +191,8 @@ def onboard():
save_config(Config())
console.print(f"[green]✓[/green] Created config at {config_path}")
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
# Create workspace
workspace = get_workspace_path()
@@ -283,6 +285,16 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
return loaded
def _print_deprecated_memory_window_notice(config: Config) -> None:
"""Warn when running with old memoryWindow-only config."""
if config.agents.defaults.should_warn_deprecated_memory_window:
console.print(
"[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
"`contextWindowTokens`. `memoryWindow` is ignored; run "
"[cyan]nanobot onboard[/cyan] to refresh your config template."
)
# ============================================================================
# Gateway / Server
# ============================================================================
@@ -310,6 +322,7 @@ def gateway(
logging.basicConfig(level=logging.DEBUG)
config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
port = port if port is not None else config.gateway.port
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
@@ -329,12 +342,10 @@ def gateway(
workspace=config.workspace_path,
model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens_output,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
reasoning_effort=config.agents.defaults.reasoning_effort,
max_tokens_input=config.agents.defaults.max_tokens_input,
compression_start_ratio=config.agents.defaults.compression_start_ratio,
compression_target_ratio=config.agents.defaults.compression_target_ratio,
context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
@@ -496,6 +507,7 @@ def agent(
from nanobot.cron.service import CronService
config = _load_runtime_config(config, workspace)
_print_deprecated_memory_window_notice(config)
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
@@ -516,12 +528,10 @@ def agent(
workspace=config.workspace_path,
model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens_output,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations,
reasoning_effort=config.agents.defaults.reasoning_effort,
max_tokens_input=config.agents.defaults.max_tokens_input,
compression_start_ratio=config.agents.defaults.compression_start_ratio,
compression_target_ratio=config.agents.defaults.compression_target_ratio,
context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,

View File

@@ -190,22 +190,11 @@ class SlackConfig(Base):
class QQConfig(Base):
"""QQ channel configuration.
Supports two implementations:
1. Official botpy SDK: requires app_id and secret
2. OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token)
"""
"""QQ channel configuration using botpy SDK."""
enabled: bool = False
# Official botpy SDK fields
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
# OneBot protocol fields
api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700")
ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse")
bot_qq: int | None = None # Bot's QQ number (for filtering self messages)
access_token: str = "" # Optional access token for OneBot API
allow_from: list[str] = Field(
default_factory=list
) # Allowed user openids (empty = public access)
@@ -238,20 +227,19 @@ class AgentDefaults(Base):
provider: str = (
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
)
# 原生上下文最大窗口(通常对应模型的 max_input_tokens / max_context_tokens
# 默认按照主流大模型(如 GPT-4o、Claude 3.x 等)的 128k 上下文给一个宽松上限,实际应根据所选模型文档手动调整。
max_tokens_input: int = 128_000
# 默认单次回复的最大输出 token 上限(调用时可按需要再做截断或比例分配)
# 8192 足以覆盖大多数实际对话/工具使用场景,同样可按需手动调整。
max_tokens_output: int = 8192
# 会话历史压缩触发比例:当估算的输入 token 使用量 >= maxTokensInput * compressionStartRatio 时开始压缩。
compression_start_ratio: float = 0.7
# 会话历史压缩目标比例:每轮压缩后尽量把估算的输入 token 使用量压到 maxTokensInput * compressionTargetRatio 附近。
compression_target_ratio: float = 0.4
max_tokens: int = 8192
context_window_tokens: int = 65_536
temperature: float = 0.1
max_tool_iterations: int = 40
# Deprecated compatibility field: accepted from old configs but ignored at runtime.
memory_window: int | None = Field(default=None, exclude=True)
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
@property
def should_warn_deprecated_memory_window(self) -> bool:
"""Return True when old memoryWindow is present without contextWindowTokens."""
return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
class AgentsConfig(Base):
"""Agent configuration."""

View File

@@ -9,6 +9,7 @@ from typing import Any
from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename
@@ -29,6 +30,7 @@ class Session:
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
"""Add a message to the session."""
@@ -42,13 +44,9 @@ class Session:
self.updated_at = datetime.now()
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""
Return messages for LLM input, aligned to a user turn.
- max_messages > 0 时只保留最近 max_messages 条;
- max_messages <= 0 时不做条数截断,返回全部消息。
"""
sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:]
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid orphaned tool_result blocks
for i, m in enumerate(sliced):
@@ -68,7 +66,7 @@ class Session:
def clear(self) -> None:
"""Clear all messages and reset session to initial state."""
self.messages = []
self.metadata = {}
self.last_consolidated = 0
self.updated_at = datetime.now()
@@ -82,7 +80,7 @@ class SessionManager:
def __init__(self, workspace: Path):
self.workspace = workspace
self.sessions_dir = ensure_dir(self.workspace / "sessions")
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
self.legacy_sessions_dir = get_legacy_sessions_dir()
self._cache: dict[str, Session] = {}
def _get_session_path(self, key: str) -> Path:
@@ -134,6 +132,7 @@ class SessionManager:
messages = []
metadata = {}
created_at = None
last_consolidated = 0
with open(path, encoding="utf-8") as f:
for line in f:
@@ -146,6 +145,7 @@ class SessionManager:
if data.get("_type") == "metadata":
metadata = data.get("metadata", {})
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
last_consolidated = data.get("last_consolidated", 0)
else:
messages.append(data)
@@ -154,6 +154,7 @@ class SessionManager:
messages=messages,
created_at=created_at or datetime.now(),
metadata=metadata,
last_consolidated=last_consolidated
)
except Exception as e:
logger.warning("Failed to load session {}: {}", key, e)
@@ -170,6 +171,7 @@ class SessionManager:
"created_at": session.created_at.isoformat(),
"updated_at": session.updated_at.isoformat(),
"metadata": session.metadata,
"last_consolidated": session.last_consolidated
}
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
for msg in session.messages:

View File

@@ -1,8 +1,12 @@
"""Utility functions for nanobot."""
import json
import re
from datetime import datetime
from pathlib import Path
from typing import Any
import tiktoken
def detect_image_mime(data: bytes) -> str | None:
@@ -68,6 +72,87 @@ def split_message(content: str, max_len: int = 2000) -> list[str]:
return chunks
def estimate_prompt_tokens(
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
) -> int:
    """Estimate prompt tokens with tiktoken.

    Counts only textual content: plain string bodies and ``{"type":
    "text"}`` parts. Tool schemas are counted via their JSON dump.
    Returns 0 when tokenization fails, signalling "unknown" to callers.
    """
    try:
        encoder = tiktoken.get_encoding("cl100k_base")
        fragments: list[str] = []
        for message in messages:
            body = message.get("content")
            if isinstance(body, str):
                fragments.append(body)
            elif isinstance(body, list):
                for piece in body:
                    if isinstance(piece, dict) and piece.get("type") == "text":
                        text = piece.get("text", "")
                        if text:
                            fragments.append(text)
        if tools:
            fragments.append(json.dumps(tools, ensure_ascii=False))
        return len(encoder.encode("\n".join(fragments)))
    except Exception:
        return 0
def estimate_message_tokens(message: dict[str, Any]) -> int:
    """Estimate prompt tokens contributed by one persisted message.

    Folds textual content, non-text parts (JSON-dumped), tool metadata
    and tool calls into a single payload, then tokenizes it. Falls back
    to a rough chars/4 heuristic if tiktoken fails; always returns >= 1.
    """
    pieces: list[str] = []
    body = message.get("content")
    if isinstance(body, str):
        pieces.append(body)
    elif isinstance(body, list):
        for part in body:
            if isinstance(part, dict) and part.get("type") == "text":
                snippet = part.get("text", "")
                if snippet:
                    pieces.append(snippet)
            else:
                pieces.append(json.dumps(part, ensure_ascii=False))
    elif body is not None:
        pieces.append(json.dumps(body, ensure_ascii=False))
    for field in ("name", "tool_call_id"):
        candidate = message.get(field)
        if isinstance(candidate, str) and candidate:
            pieces.append(candidate)
    if message.get("tool_calls"):
        pieces.append(json.dumps(message["tool_calls"], ensure_ascii=False))
    payload = "\n".join(pieces)
    if not payload:
        return 1
    try:
        return max(1, len(tiktoken.get_encoding("cl100k_base").encode(payload)))
    except Exception:
        return max(1, len(payload) // 4)
def estimate_prompt_tokens_chain(
provider: Any,
model: str | None,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
) -> tuple[int, str]:
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
if callable(provider_counter):
try:
tokens, source = provider_counter(messages, tools, model)
if isinstance(tokens, (int, float)) and tokens > 0:
return int(tokens), str(source or "provider_counter")
except Exception:
pass
estimated = estimate_prompt_tokens(messages, tools)
if estimated > 0:
return int(estimated), "tiktoken"
return 0, "none"
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
"""Sync bundled templates to workspace. Only creates missing files."""
from importlib.resources import files as pkg_files

View File

@@ -44,6 +44,7 @@ dependencies = [
"json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0",
"openai>=2.8.0",
"tiktoken>=0.12.0,<1.0.0",
]
[project.optional-dependencies]

View File

@@ -267,6 +267,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
    """Running the agent with a legacy memoryWindow config prints the hint."""
    mock_agent_runtime["config"].agents.defaults.memory_window = 100
    outcome = runner.invoke(app, ["agent", "-m", "hello"])
    assert outcome.exit_code == 0
    assert "memoryWindow" in outcome.stdout
    assert "contextWindowTokens" in outcome.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -327,6 +337,29 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
assert seen["workspace"] == override
assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
    """Gateway startup with a legacy memoryWindow config prints the deprecation hint."""
    config_file = tmp_path / "instance" / "config.json"
    config_file.parent.mkdir(parents=True)
    config_file.write_text("{}")
    # In-memory config carrying only the deprecated field.
    config = Config()
    config.agents.defaults.memory_window = 100
    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
    # Abort startup right after config handling: the generator-throw lambda
    # raises the sentinel from inside provider creation.
    monkeypatch.setattr(
        "nanobot.cli.commands._make_provider",
        lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
    )
    result = runner.invoke(app, ["gateway", "--config", str(config_file)])
    # Seeing the sentinel proves the warning code path ran before the stop.
    assert isinstance(result.exception, _StopGateway)
    assert "memoryWindow" in result.stdout
    assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)

View File

@@ -0,0 +1,88 @@
import json
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
    """A legacy memoryWindow config loads, keeps maxTokens and flags the warning."""
    legacy_path = tmp_path / "config.json"
    legacy_path.write_text(
        json.dumps(
            {
                "agents": {
                    "defaults": {
                        "maxTokens": 1234,
                        "memoryWindow": 42,
                    }
                }
            }
        ),
        encoding="utf-8",
    )
    defaults = load_config(legacy_path).agents.defaults
    assert defaults.max_tokens == 1234
    assert defaults.context_window_tokens == 65_536
    assert defaults.should_warn_deprecated_memory_window is True
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
    """Round-tripping a legacy config drops memoryWindow and emits the new key."""
    legacy_path = tmp_path / "config.json"
    legacy_path.write_text(
        json.dumps(
            {
                "agents": {
                    "defaults": {
                        "maxTokens": 2222,
                        "memoryWindow": 30,
                    }
                }
            }
        ),
        encoding="utf-8",
    )
    save_config(load_config(legacy_path), legacy_path)
    written = json.loads(legacy_path.read_text(encoding="utf-8"))
    defaults = written["agents"]["defaults"]
    assert defaults["maxTokens"] == 2222
    assert defaults["contextWindowTokens"] == 65_536
    assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
    """`onboard` rewrites a legacy config file to the new template keys."""
    legacy_path = tmp_path / "config.json"
    workspace_dir = tmp_path / "workspace"
    legacy_path.write_text(
        json.dumps(
            {
                "agents": {
                    "defaults": {
                        "maxTokens": 3333,
                        "memoryWindow": 50,
                    }
                }
            }
        ),
        encoding="utf-8",
    )
    monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: legacy_path)
    monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace_dir)
    outcome = runner.invoke(app, ["onboard"], input="n\n")
    assert outcome.exit_code == 0
    assert "contextWindowTokens" in outcome.stdout
    rewritten = json.loads(legacy_path.read_text(encoding="utf-8"))
    defaults = rewritten["agents"]["defaults"]
    assert defaults["maxTokens"] == 3333
    assert defaults["contextWindowTokens"] == 65_536
    assert "memoryWindow" not in defaults

View File

@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
"""Test that consolidation tasks are deduplicated and serialized."""
class TestNewCommandArchival:
"""Test /new archival behavior with the simplified consolidation flow."""
@pytest.mark.asyncio
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
"""Concurrent messages above memory_window spawn only one consolidation task."""
@staticmethod
def _make_loop(tmp_path: Path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
return loop
@pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session)
before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
if archive_all:
async def _failing_consolidate(_messages) -> bool:
return False
return True
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "failed" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
async def _fake_consolidate(messages) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
archived_count = len(messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count == 3, (
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
assert archived_count == 3
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
async def _ok_consolidate(_messages) -> bool:
return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)

View File

@@ -0,0 +1,190 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
    """Build an AgentLoop whose stub provider reports a fixed token estimate."""
    stub = MagicMock()
    stub.get_default_model.return_value = "test-model"
    stub.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
    stub.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
    agent = AgentLoop(
        bus=MessageBus(),
        provider=stub,
        workspace=tmp_path,
        model="test-model",
        context_window_tokens=context_window_tokens,
    )
    agent.tools.get_definitions = MagicMock(return_value=[])
    return agent
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
    """A prompt estimate under the context window must not trigger archival."""
    agent = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
    agent.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    await agent.process_direct("hello", session_key="cli:test")
    agent.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
    """An over-budget prompt estimate must archive at least one chunk."""
    agent = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    agent.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    chat = agent.sessions.get_or_create("cli:test")
    chat.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
    ]
    agent.sessions.save(chat)
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
    await agent.process_direct("hello", session_key="cli:test")
    assert agent.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
    """Archival must cut at the last user-turn boundary, keeping the live turn."""
    agent = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
    agent.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    chat = agent.sessions.get_or_create("cli:test")
    chat.messages = [
        {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
        {"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
        {"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
        {"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
        {"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
    ]
    agent.sessions.save(chat)
    per_message = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
    monkeypatch.setattr(
        memory_module,
        "estimate_message_tokens",
        lambda message: per_message[message["content"]],
    )
    await agent.memory_consolidator.maybe_consolidate_by_tokens(chat)
    archived = agent.memory_consolidator.consolidate_messages.await_args.args[0]
    assert [entry["content"] for entry in archived] == ["u1", "a1", "u2", "a2"]
    assert chat.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
    """Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
    agent = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    agent.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    sess = agent.sessions.get_or_create("cli:test")
    turns = ["u1", "a1", "u2", "a2", "u3", "a3", "u4"]
    sess.messages = [
        {
            "role": "user" if idx % 2 == 0 else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:0{idx}",
        }
        for idx, text in enumerate(turns)
    ]
    agent.sessions.save(sess)
    # Successive prompt estimates: 500 → 300 → 80; the first two exceed the budget.
    estimates = iter([(500, "test"), (300, "test")])
    agent.memory_consolidator.estimate_session_prompt_tokens = (  # type: ignore[method-assign]
        lambda _session: next(estimates, (80, "test"))
    )
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
    await agent.memory_consolidator.maybe_consolidate_by_tokens(sess)
    # Two consolidation rounds were needed before dropping under the threshold.
    assert agent.memory_consolidator.consolidate_messages.await_count == 2
    assert sess.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
    """Once triggered, consolidation should continue until it drops below half threshold."""
    agent = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
    agent.memory_consolidator.consolidate_messages = AsyncMock(return_value=True)  # type: ignore[method-assign]
    sess = agent.sessions.get_or_create("cli:test")
    turns = ["u1", "a1", "u2", "a2", "u3", "a3", "u4"]
    sess.messages = [
        {
            "role": "user" if idx % 2 == 0 else "assistant",
            "content": text,
            "timestamp": f"2026-01-01T00:00:0{idx}",
        }
        for idx, text in enumerate(turns)
    ]
    agent.sessions.save(sess)
    # 500 triggers; 150 is below the trigger but above half of it, so keep going; 80 stops.
    estimates = iter([(500, "test"), (150, "test")])
    agent.memory_consolidator.estimate_session_prompt_tokens = (  # type: ignore[method-assign]
        lambda _session: next(estimates, (80, "test"))
    )
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
    await agent.memory_consolidator.maybe_consolidate_by_tokens(sess)
    assert agent.memory_consolidator.consolidate_messages.await_count == 2
    assert sess.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
    """Verify preflight consolidation runs before the LLM call in process_direct."""
    events: list[str] = []
    agent = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)

    async def record_consolidate(messages):
        # Stand-in consolidator: only records that it ran.
        events.append("consolidate")
        return True

    agent.memory_consolidator.consolidate_messages = record_consolidate  # type: ignore[method-assign]

    async def record_llm(*args, **kwargs):
        # Stand-in LLM: records the call order relative to consolidation.
        events.append("llm")
        return LLMResponse(content="ok", tool_calls=[])

    agent.provider.chat_with_retry = record_llm
    sess = agent.sessions.get_or_create("cli:test")
    turns = [("user", "u1"), ("assistant", "a1"), ("user", "u2")]
    sess.messages = [
        {"role": role, "content": text, "timestamp": f"2026-01-01T00:00:0{idx}"}
        for idx, (role, text) in enumerate(turns)
    ]
    agent.sessions.save(sess)
    monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
    # First estimate is over budget (forces preflight consolidation); later ones are not.
    estimates = iter([(1000, "test")])
    agent.memory_consolidator.estimate_session_prompt_tokens = (  # type: ignore[method-assign]
        lambda _session: next(estimates, (80, "test"))
    )
    await agent.process_direct("hello", session_key="cli:test")
    assert "consolidate" in events
    assert "llm" in events
    assert events.index("consolidate") < events.index("llm")

View File

@@ -7,7 +7,7 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import pytest
@@ -15,15 +15,12 @@ from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50):
"""Create a mock session with messages."""
session = MagicMock()
session.messages = [
def _make_messages(message_count: int = 30):
"""Create a list of mock messages."""
return [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update):
@@ -74,9 +71,9 @@ class TestMemoryConsolidationTypeHandling:
)
)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -95,9 +92,9 @@ class TestMemoryConsolidationTypeHandling:
)
)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -131,9 +128,9 @@ class TestMemoryConsolidationTypeHandling:
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -147,22 +144,22 @@ class TestMemoryConsolidationTypeHandling:
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count."""
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = provider.chat
session = _make_session(message_count=10)
messages: list[dict] = []
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat.assert_not_called()
@@ -189,9 +186,9 @@ class TestMemoryConsolidationTypeHandling:
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -215,9 +212,9 @@ class TestMemoryConsolidationTypeHandling:
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@@ -239,9 +236,9 @@ class TestMemoryConsolidationTypeHandling:
)
provider.chat = AsyncMock(return_value=response)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@@ -255,7 +252,7 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.",
),
])
session = _make_session(message_count=60)
messages = _make_messages(message_count=60)
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
@@ -263,7 +260,7 @@ class TestMemoryConsolidationTypeHandling:
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert provider.calls == 2

View File

@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic:
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")