refactor: implement token-based context compression mechanism
Major changes: - Replace message-count-based memory window with token-budget-based compression - Add max_tokens_input, compression_start_ratio, compression_target_ratio config - Implement _maybe_compress_history() that triggers based on prompt token usage - Use _build_compressed_history_view() to provide compressed history to LLM - Refactor MemoryStore.consolidate() -> consolidate_chunk() for chunk-based compression - Remove last_consolidated from Session, use _compressed_until metadata instead - Add background compression scheduling to avoid blocking message processing Key improvements: - Compression now based on actual token usage, not arbitrary message counts - Better handling of long conversations with large context windows - Non-destructive compression: old messages remain in session, but excluded from prompt - Automatic compression when history exceeds configured token thresholds
This commit is contained in:
@@ -5,19 +5,24 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import weakref
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tiktoken # type: ignore
|
||||||
|
except Exception: # pragma: no cover - optional dependency
|
||||||
|
tiktoken = None
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryStore
|
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools.huggingface import HuggingFaceModelSearchTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.agent.tools.model_config import ValidateDeployJSONTool, ValidateUsageYAMLTool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
from nanobot.agent.tools.spawn import SpawnTool
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
@@ -55,8 +60,11 @@ class AgentLoop:
|
|||||||
max_iterations: int = 40,
|
max_iterations: int = 40,
|
||||||
temperature: float = 0.1,
|
temperature: float = 0.1,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
memory_window: int = 100,
|
memory_window: int | None = None, # backward-compat only (unused)
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
|
max_tokens_input: int = 128_000,
|
||||||
|
compression_start_ratio: float = 0.7,
|
||||||
|
compression_target_ratio: float = 0.4,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
@@ -74,9 +82,18 @@ class AgentLoop:
|
|||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
# max_tokens: per-call output token cap (maxTokensOutput in config)
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
# Keep legacy attribute for older call sites/tests; compression no longer uses it.
|
||||||
self.memory_window = memory_window
|
self.memory_window = memory_window
|
||||||
self.reasoning_effort = reasoning_effort
|
self.reasoning_effort = reasoning_effort
|
||||||
|
# max_tokens_input: model native context window (maxTokensInput in config)
|
||||||
|
self.max_tokens_input = max_tokens_input
|
||||||
|
# Token-based compression watermarks (fractions of available input budget)
|
||||||
|
self.compression_start_ratio = compression_start_ratio
|
||||||
|
self.compression_target_ratio = compression_target_ratio
|
||||||
|
# Reserve tokens for safety margin
|
||||||
|
self._reserve_tokens = 1000
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@@ -105,18 +122,373 @@ class AgentLoop:
|
|||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stack: AsyncExitStack | None = None
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
|
||||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
|
||||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
|
self._compression_tasks: dict[str, asyncio.Task] = {} # session_key -> task
|
||||||
self._processing_lock = asyncio.Lock()
|
self._processing_lock = asyncio.Lock()
|
||||||
self._register_default_tools()
|
self._register_default_tools()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_prompt_tokens(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Estimate prompt tokens with tiktoken (fallback only)."""
|
||||||
|
if tiktoken is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
parts: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
txt = part.get("text", "")
|
||||||
|
if txt:
|
||||||
|
parts.append(txt)
|
||||||
|
if tools:
|
||||||
|
parts.append(json.dumps(tools, ensure_ascii=False))
|
||||||
|
return len(enc.encode("\n".join(parts)))
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _estimate_prompt_tokens_chain(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> tuple[int, str]:
|
||||||
|
"""Unified prompt-token estimation: provider counter -> tiktoken."""
|
||||||
|
provider_counter = getattr(self.provider, "estimate_prompt_tokens", None)
|
||||||
|
if callable(provider_counter):
|
||||||
|
try:
|
||||||
|
tokens, source = provider_counter(messages, tools, self.model)
|
||||||
|
if isinstance(tokens, (int, float)) and tokens > 0:
|
||||||
|
return int(tokens), str(source or "provider_counter")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Provider token counter failed; fallback to tiktoken")
|
||||||
|
|
||||||
|
estimated = self._estimate_prompt_tokens(messages, tools)
|
||||||
|
if estimated > 0:
|
||||||
|
return int(estimated), "tiktoken"
|
||||||
|
return 0, "none"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_completion_tokens(content: str) -> int:
|
||||||
|
"""Estimate completion tokens with tiktoken (fallback only)."""
|
||||||
|
if tiktoken is None:
|
||||||
|
return 0
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return len(enc.encode(content or ""))
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _get_compressed_until(self, session: Session) -> int:
|
||||||
|
"""Read/normalize compressed boundary and migrate old metadata format."""
|
||||||
|
raw = session.metadata.get("_compressed_until", 0)
|
||||||
|
try:
|
||||||
|
compressed_until = int(raw)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
compressed_until = 0
|
||||||
|
|
||||||
|
if compressed_until <= 0:
|
||||||
|
ranges = session.metadata.get("_compressed_ranges")
|
||||||
|
if isinstance(ranges, list):
|
||||||
|
inferred = 0
|
||||||
|
for item in ranges:
|
||||||
|
if not isinstance(item, (list, tuple)) or len(item) != 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
inferred = max(inferred, int(item[1]))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
compressed_until = inferred
|
||||||
|
|
||||||
|
compressed_until = max(0, min(compressed_until, len(session.messages)))
|
||||||
|
session.metadata["_compressed_until"] = compressed_until
|
||||||
|
# 兼容旧版本:一旦迁移出连续边界,就可以清理旧字段
|
||||||
|
session.metadata.pop("_compressed_ranges", None)
|
||||||
|
session.metadata.pop("_cumulative_tokens", None)
|
||||||
|
return compressed_until
|
||||||
|
|
||||||
|
def _set_compressed_until(self, session: Session, idx: int) -> None:
|
||||||
|
"""Persist a contiguous compressed boundary."""
|
||||||
|
session.metadata["_compressed_until"] = max(0, min(int(idx), len(session.messages)))
|
||||||
|
session.metadata.pop("_compressed_ranges", None)
|
||||||
|
session.metadata.pop("_cumulative_tokens", None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_message_tokens(message: dict[str, Any]) -> int:
|
||||||
|
"""Rough token estimate for a single persisted message."""
|
||||||
|
content = message.get("content")
|
||||||
|
parts: list[str] = []
|
||||||
|
if isinstance(content, str):
|
||||||
|
parts.append(content)
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
txt = part.get("text", "")
|
||||||
|
if txt:
|
||||||
|
parts.append(txt)
|
||||||
|
else:
|
||||||
|
parts.append(json.dumps(part, ensure_ascii=False))
|
||||||
|
elif content is not None:
|
||||||
|
parts.append(json.dumps(content, ensure_ascii=False))
|
||||||
|
|
||||||
|
for key in ("name", "tool_call_id"):
|
||||||
|
val = message.get(key)
|
||||||
|
if isinstance(val, str) and val:
|
||||||
|
parts.append(val)
|
||||||
|
if message.get("tool_calls"):
|
||||||
|
parts.append(json.dumps(message["tool_calls"], ensure_ascii=False))
|
||||||
|
|
||||||
|
payload = "\n".join(parts)
|
||||||
|
if not payload:
|
||||||
|
return 1
|
||||||
|
if tiktoken is not None:
|
||||||
|
try:
|
||||||
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
return max(1, len(enc.encode(payload)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return max(1, len(payload) // 4)
|
||||||
|
|
||||||
|
def _pick_compression_chunk_by_tokens(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
reduction_tokens: int,
|
||||||
|
*,
|
||||||
|
tail_keep: int = 12,
|
||||||
|
) -> tuple[int, int, int] | None:
|
||||||
|
"""
|
||||||
|
Pick one contiguous old chunk so its estimated size is roughly enough
|
||||||
|
to reduce `reduction_tokens`.
|
||||||
|
"""
|
||||||
|
messages = session.messages
|
||||||
|
start = self._get_compressed_until(session)
|
||||||
|
if len(messages) - start <= tail_keep + 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
end_limit = len(messages) - tail_keep
|
||||||
|
if end_limit - start < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target = max(1, reduction_tokens)
|
||||||
|
end = start
|
||||||
|
collected = 0
|
||||||
|
while end < end_limit and collected < target:
|
||||||
|
collected += self._estimate_message_tokens(messages[end])
|
||||||
|
end += 1
|
||||||
|
|
||||||
|
if end - start < 2:
|
||||||
|
end = min(end_limit, start + 2)
|
||||||
|
collected = sum(self._estimate_message_tokens(m) for m in messages[start:end])
|
||||||
|
if end - start < 2:
|
||||||
|
return None
|
||||||
|
return start, end, collected
|
||||||
|
|
||||||
|
def _estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||||
|
"""
|
||||||
|
Estimate current full prompt tokens for this session view
|
||||||
|
(system + compressed history view + runtime/user placeholder + tools).
|
||||||
|
"""
|
||||||
|
history = self._build_compressed_history_view(session)
|
||||||
|
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||||
|
probe_messages = self.context.build_messages(
|
||||||
|
history=history,
|
||||||
|
current_message="[token-probe]",
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
)
|
||||||
|
return self._estimate_prompt_tokens_chain(probe_messages, self.tools.get_definitions())
|
||||||
|
|
||||||
|
async def _maybe_compress_history(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
End-of-turn policy:
|
||||||
|
- Estimate current prompt usage from persisted session view.
|
||||||
|
- If above start ratio, perform one best-effort compression chunk.
|
||||||
|
"""
|
||||||
|
if not session.messages:
|
||||||
|
self._set_compressed_until(session, 0)
|
||||||
|
return
|
||||||
|
|
||||||
|
budget = max(1, self.max_tokens_input - self.max_tokens - self._reserve_tokens)
|
||||||
|
start_threshold = int(budget * self.compression_start_ratio)
|
||||||
|
target_threshold = int(budget * self.compression_target_ratio)
|
||||||
|
if target_threshold >= start_threshold:
|
||||||
|
target_threshold = max(0, start_threshold - 1)
|
||||||
|
|
||||||
|
current_tokens, token_source = self._estimate_session_prompt_tokens(session)
|
||||||
|
current_ratio = current_tokens / budget if budget else 0.0
|
||||||
|
if current_tokens <= 0:
|
||||||
|
logger.debug("Compression skip {}: token estimate unavailable", session.key)
|
||||||
|
return
|
||||||
|
if current_tokens < start_threshold:
|
||||||
|
logger.debug(
|
||||||
|
"Compression idle {}: {}/{} ({:.1%}) via {}",
|
||||||
|
session.key,
|
||||||
|
current_tokens,
|
||||||
|
budget,
|
||||||
|
current_ratio,
|
||||||
|
token_source,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"Compression trigger {}: {}/{} ({:.1%}) via {}",
|
||||||
|
session.key,
|
||||||
|
current_tokens,
|
||||||
|
budget,
|
||||||
|
current_ratio,
|
||||||
|
token_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
reduction_by_target = max(0, current_tokens - target_threshold)
|
||||||
|
reduction_by_delta = max(1, start_threshold - target_threshold)
|
||||||
|
reduction_need = max(reduction_by_target, reduction_by_delta)
|
||||||
|
|
||||||
|
chunk_range = self._pick_compression_chunk_by_tokens(session, reduction_need, tail_keep=10)
|
||||||
|
if chunk_range is None:
|
||||||
|
logger.info("Compression skipped for {}: no compressible chunk", session.key)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_idx, end_idx, estimated_chunk_tokens = chunk_range
|
||||||
|
chunk = session.messages[start_idx:end_idx]
|
||||||
|
if len(chunk) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Compression chunk {}: msgs {}-{} (count={}, est~{}, need~{})",
|
||||||
|
session.key,
|
||||||
|
start_idx,
|
||||||
|
end_idx - 1,
|
||||||
|
len(chunk),
|
||||||
|
estimated_chunk_tokens,
|
||||||
|
reduction_need,
|
||||||
|
)
|
||||||
|
success, _ = await self.context.memory.consolidate_chunk(
|
||||||
|
chunk,
|
||||||
|
self.provider,
|
||||||
|
self.model,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
logger.warning("Compression aborted for {}: consolidation failed", session.key)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._set_compressed_until(session, end_idx)
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
after_tokens, after_source = self._estimate_session_prompt_tokens(session)
|
||||||
|
after_ratio = after_tokens / budget if budget else 0.0
|
||||||
|
reduced = max(0, current_tokens - after_tokens)
|
||||||
|
reduced_ratio = (reduced / current_tokens) if current_tokens > 0 else 0.0
|
||||||
|
logger.info(
|
||||||
|
"Compression done {}: {}/{} ({:.1%}) via {}, reduced={} ({:.1%})",
|
||||||
|
session.key,
|
||||||
|
after_tokens,
|
||||||
|
budget,
|
||||||
|
after_ratio,
|
||||||
|
after_source,
|
||||||
|
reduced,
|
||||||
|
reduced_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _schedule_background_compression(self, session_key: str) -> None:
|
||||||
|
"""Schedule best-effort background compression for a session."""
|
||||||
|
existing = self._compression_tasks.get(session_key)
|
||||||
|
if existing is not None and not existing.done():
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _runner() -> None:
|
||||||
|
session = self.sessions.get_or_create(session_key)
|
||||||
|
try:
|
||||||
|
await self._maybe_compress_history(session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background compression failed for {}", session_key)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_runner())
|
||||||
|
self._compression_tasks[session_key] = task
|
||||||
|
|
||||||
|
def _cleanup(t: asyncio.Task) -> None:
|
||||||
|
cur = self._compression_tasks.get(session_key)
|
||||||
|
if cur is t:
|
||||||
|
self._compression_tasks.pop(session_key, None)
|
||||||
|
try:
|
||||||
|
t.result()
|
||||||
|
except BaseException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
task.add_done_callback(_cleanup)
|
||||||
|
|
||||||
|
async def wait_for_background_compression(self, timeout_s: float | None = None) -> None:
|
||||||
|
"""Wait for currently scheduled compression tasks."""
|
||||||
|
pending = [t for t in self._compression_tasks.values() if not t.done()]
|
||||||
|
if not pending:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Waiting for {} background compression task(s)", len(pending))
|
||||||
|
waiter = asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
if timeout_s is None:
|
||||||
|
await waiter
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(waiter, timeout=timeout_s)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Background compression wait timed out after {}s ({} task(s) still running)",
|
||||||
|
timeout_s,
|
||||||
|
len([t for t in self._compression_tasks.values() if not t.done()]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_compressed_history_view(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Build non-destructive history view using the compressed boundary."""
|
||||||
|
compressed_until = self._get_compressed_until(session)
|
||||||
|
if compressed_until <= 0:
|
||||||
|
return session.get_history(max_messages=0)
|
||||||
|
|
||||||
|
notice_msg: dict[str, Any] = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": (
|
||||||
|
"As your assistant, I have compressed earlier context. "
|
||||||
|
"If you need details, please check memory/HISTORY.md."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
tail: list[dict[str, Any]] = []
|
||||||
|
for msg in session.messages[compressed_until:]:
|
||||||
|
entry: dict[str, Any] = {"role": msg["role"], "content": msg.get("content", "")}
|
||||||
|
for k in ("tool_calls", "tool_call_id", "name"):
|
||||||
|
if k in msg:
|
||||||
|
entry[k] = msg[k]
|
||||||
|
tail.append(entry)
|
||||||
|
|
||||||
|
# Drop leading non-user entries from tail to avoid orphan tool blocks.
|
||||||
|
for i, m in enumerate(tail):
|
||||||
|
if m.get("role") == "user":
|
||||||
|
tail = tail[i:]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tail = []
|
||||||
|
|
||||||
|
return [notice_msg, *tail]
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
|
self.tools.register(ValidateDeployJSONTool())
|
||||||
|
self.tools.register(ValidateUsageYAMLTool())
|
||||||
|
self.tools.register(HuggingFaceModelSearchTool())
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
timeout=self.exec_config.timeout,
|
timeout=self.exec_config.timeout,
|
||||||
@@ -181,25 +553,78 @@ class AgentLoop:
|
|||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
) -> tuple[str | None, list[str], list[dict], int, str]:
|
||||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
"""
|
||||||
|
Run the agent iteration loop.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(final_content, tools_used, messages, total_tokens_this_turn, token_source)
|
||||||
|
total_tokens_this_turn: total tokens (prompt + completion) for this turn
|
||||||
|
token_source: provider_total / provider_sum / provider_prompt /
|
||||||
|
provider_counter+tiktoken_completion / tiktoken / none
|
||||||
|
"""
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
iteration = 0
|
iteration = 0
|
||||||
final_content = None
|
final_content = None
|
||||||
tools_used: list[str] = []
|
tools_used: list[str] = []
|
||||||
|
total_tokens_this_turn = 0
|
||||||
|
token_source = "none"
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
while iteration < self.max_iterations:
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
|
tool_defs = self.tools.get_definitions()
|
||||||
|
|
||||||
response = await self.provider.chat(
|
response = await self.provider.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.tools.get_definitions(),
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
reasoning_effort=self.reasoning_effort,
|
reasoning_effort=self.reasoning_effort,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefer provider usage from the turn-ending model call; fallback to tiktoken.
|
||||||
|
# Calculate total tokens (prompt + completion) for this turn.
|
||||||
|
usage = response.usage or {}
|
||||||
|
t_tokens = usage.get("total_tokens")
|
||||||
|
p_tokens = usage.get("prompt_tokens")
|
||||||
|
c_tokens = usage.get("completion_tokens")
|
||||||
|
|
||||||
|
if isinstance(t_tokens, (int, float)) and t_tokens > 0:
|
||||||
|
total_tokens_this_turn = int(t_tokens)
|
||||||
|
token_source = "provider_total"
|
||||||
|
elif isinstance(p_tokens, (int, float)) and isinstance(c_tokens, (int, float)):
|
||||||
|
# If we have both prompt and completion tokens, sum them
|
||||||
|
total_tokens_this_turn = int(p_tokens) + int(c_tokens)
|
||||||
|
token_source = "provider_sum"
|
||||||
|
elif isinstance(p_tokens, (int, float)) and p_tokens > 0:
|
||||||
|
# Fallback: use prompt tokens only (completion might be 0 for tool calls)
|
||||||
|
total_tokens_this_turn = int(p_tokens)
|
||||||
|
token_source = "provider_prompt"
|
||||||
|
else:
|
||||||
|
# Estimate with unified chain (provider counter -> tiktoken), plus completion tiktoken.
|
||||||
|
estimated_prompt, prompt_source = self._estimate_prompt_tokens_chain(messages, tool_defs)
|
||||||
|
estimated_completion = self._estimate_completion_tokens(response.content or "")
|
||||||
|
total_tokens_this_turn = estimated_prompt + estimated_completion
|
||||||
|
if total_tokens_this_turn > 0:
|
||||||
|
token_source = (
|
||||||
|
"tiktoken"
|
||||||
|
if prompt_source == "tiktoken"
|
||||||
|
else f"{prompt_source}+tiktoken_completion"
|
||||||
|
)
|
||||||
|
if total_tokens_this_turn <= 0:
|
||||||
|
total_tokens_this_turn = 0
|
||||||
|
token_source = "none"
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Turn token usage: source={}, total={}, prompt={}, completion={}",
|
||||||
|
token_source,
|
||||||
|
total_tokens_this_turn,
|
||||||
|
p_tokens if isinstance(p_tokens, (int, float)) else None,
|
||||||
|
c_tokens if isinstance(c_tokens, (int, float)) else None,
|
||||||
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
if on_progress:
|
if on_progress:
|
||||||
thought = self._strip_think(response.content)
|
thought = self._strip_think(response.content)
|
||||||
@@ -254,7 +679,7 @@ class AgentLoop:
|
|||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
return final_content, tools_used, messages
|
return final_content, tools_used, messages, total_tokens_this_turn, token_source
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
@@ -279,6 +704,9 @@ class AgentLoop:
|
|||||||
"""Cancel all active tasks and subagents for the session."""
|
"""Cancel all active tasks and subagents for the session."""
|
||||||
tasks = self._active_tasks.pop(msg.session_key, [])
|
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||||
|
comp = self._compression_tasks.get(msg.session_key)
|
||||||
|
if comp is not None and not comp.done() and comp.cancel():
|
||||||
|
cancelled += 1
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
try:
|
try:
|
||||||
await t
|
await t
|
||||||
@@ -325,6 +753,9 @@ class AgentLoop:
|
|||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the agent loop."""
|
"""Stop the agent loop."""
|
||||||
self._running = False
|
self._running = False
|
||||||
|
for task in list(self._compression_tasks.values()):
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
logger.info("Agent loop stopping")
|
logger.info("Agent loop stopping")
|
||||||
|
|
||||||
async def _process_message(
|
async def _process_message(
|
||||||
@@ -342,14 +773,15 @@ class AgentLoop:
|
|||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
history = self._build_compressed_history_view(session)
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
final_content, _, all_msgs, _, _ = await self._run_agent_loop(messages)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
self._schedule_background_compression(session.key)
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@@ -362,27 +794,27 @@ class AgentLoop:
|
|||||||
# Slash commands
|
# Slash commands
|
||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
try:
|
try:
|
||||||
async with lock:
|
# 在清空会话前,将当前完整对话做一次归档压缩到 MEMORY/HISTORY 中
|
||||||
snapshot = session.messages[session.last_consolidated:]
|
if session.messages:
|
||||||
if snapshot:
|
ok, _ = await self.context.memory.consolidate_chunk(
|
||||||
temp = Session(key=session.key)
|
session.messages,
|
||||||
temp.messages = list(snapshot)
|
self.provider,
|
||||||
if not await self._consolidate_memory(temp, archive_all=True):
|
self.model,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("/new archival failed for {}", session.key)
|
logger.exception("/new archival failed for {}", session.key)
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
content="Memory archival failed, session not cleared. Please try again.",
|
content="Memory archival failed, session not cleared. Please try again.",
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
|
|
||||||
session.clear()
|
session.clear()
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
@@ -393,36 +825,23 @@ class AgentLoop:
|
|||||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||||
|
|
||||||
unconsolidated = len(session.messages) - session.last_consolidated
|
|
||||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
|
||||||
self._consolidating.add(session.key)
|
|
||||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
|
||||||
|
|
||||||
async def _consolidate_and_unlock():
|
|
||||||
try:
|
|
||||||
async with lock:
|
|
||||||
await self._consolidate_memory(session)
|
|
||||||
finally:
|
|
||||||
self._consolidating.discard(session.key)
|
|
||||||
_task = asyncio.current_task()
|
|
||||||
if _task is not None:
|
|
||||||
self._consolidation_tasks.discard(_task)
|
|
||||||
|
|
||||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
|
||||||
self._consolidation_tasks.add(_task)
|
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
if isinstance(message_tool, MessageTool):
|
if isinstance(message_tool, MessageTool):
|
||||||
message_tool.start_turn()
|
message_tool.start_turn()
|
||||||
|
|
||||||
history = session.get_history(max_messages=self.memory_window)
|
# 正常对话:使用压缩后的历史视图(压缩在回合结束后进行)
|
||||||
|
history = self._build_compressed_history_view(session)
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
)
|
)
|
||||||
|
# Add [CRON JOB] identifier for cron sessions (session_key starts with "cron:")
|
||||||
|
if session_key and session_key.startswith("cron:"):
|
||||||
|
if initial_messages and initial_messages[0].get("role") == "system":
|
||||||
|
initial_messages[0]["content"] = f"[CRON JOB] {initial_messages[0]['content']}"
|
||||||
|
|
||||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
@@ -432,7 +851,7 @@ class AgentLoop:
|
|||||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||||
))
|
))
|
||||||
|
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||||
initial_messages, on_progress=on_progress or _bus_progress,
|
initial_messages, on_progress=on_progress or _bus_progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -441,6 +860,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
self._schedule_background_compression(session.key)
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
@@ -487,13 +907,6 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
|
||||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
|
||||||
return await MemoryStore(self.workspace).consolidate(
|
|
||||||
session, self.provider, self.model,
|
|
||||||
archive_all=archive_all, memory_window=self.memory_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|||||||
@@ -66,36 +66,25 @@ class MemoryStore:
|
|||||||
long_term = self.read_long_term()
|
long_term = self.read_long_term()
|
||||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||||
|
|
||||||
async def consolidate(
|
async def consolidate_chunk(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
messages: list[dict],
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
*,
|
) -> tuple[bool, str | None]:
|
||||||
archive_all: bool = False,
|
"""Consolidate a chunk of messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||||
memory_window: int = 50,
|
|
||||||
) -> bool:
|
|
||||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
|
||||||
|
|
||||||
Returns True on success (including no-op), False on failure.
|
Returns (success, None).
|
||||||
|
|
||||||
|
- success: True on success (including no-op), False on failure.
|
||||||
|
- The second return value is reserved for future use (e.g. RAG-style summaries) and is
|
||||||
|
always None in the current implementation.
|
||||||
"""
|
"""
|
||||||
if archive_all:
|
if not messages:
|
||||||
old_messages = session.messages
|
return True, None
|
||||||
keep_count = 0
|
|
||||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
|
||||||
else:
|
|
||||||
keep_count = memory_window // 2
|
|
||||||
if len(session.messages) <= keep_count:
|
|
||||||
return True
|
|
||||||
if len(session.messages) - session.last_consolidated <= 0:
|
|
||||||
return True
|
|
||||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
|
||||||
if not old_messages:
|
|
||||||
return True
|
|
||||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
for m in old_messages:
|
for m in messages:
|
||||||
if not m.get("content"):
|
if not m.get("content"):
|
||||||
continue
|
continue
|
||||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||||
@@ -113,7 +102,19 @@ class MemoryStore:
|
|||||||
try:
|
try:
|
||||||
response = await provider.chat(
|
response = await provider.chat(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are a memory consolidation agent.\n"
|
||||||
|
"Your job is to:\n"
|
||||||
|
"1) Append a concise but grep-friendly entry to HISTORY.md summarizing key events, decisions and topics.\n"
|
||||||
|
" - Write 1 paragraph of 2–5 sentences that starts with [YYYY-MM-DD HH:MM].\n"
|
||||||
|
" - Include concrete names, IDs and numbers so it is easy to search with grep.\n"
|
||||||
|
"2) Update long-term MEMORY.md with stable facts and user preferences as markdown, including all existing facts plus new ones.\n"
|
||||||
|
"3) Optionally return a short context_summary (1–3 sentences) that will replace the raw messages in future dialogue history.\n\n"
|
||||||
|
"Always call the save_memory tool with history_entry, memory_update and (optionally) context_summary."
|
||||||
|
),
|
||||||
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
],
|
],
|
||||||
tools=_SAVE_MEMORY_TOOL,
|
tools=_SAVE_MEMORY_TOOL,
|
||||||
@@ -122,7 +123,7 @@ class MemoryStore:
|
|||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||||
return False
|
return False, None
|
||||||
|
|
||||||
args = response.tool_calls[0].arguments
|
args = response.tool_calls[0].arguments
|
||||||
# Some providers return arguments as a JSON string instead of dict
|
# Some providers return arguments as a JSON string instead of dict
|
||||||
@@ -134,10 +135,10 @@ class MemoryStore:
|
|||||||
args = args[0]
|
args = args[0]
|
||||||
else:
|
else:
|
||||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
||||||
return False
|
return False, None
|
||||||
if not isinstance(args, dict):
|
if not isinstance(args, dict):
|
||||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||||
return False
|
return False, None
|
||||||
|
|
||||||
if entry := args.get("history_entry"):
|
if entry := args.get("history_entry"):
|
||||||
if not isinstance(entry, str):
|
if not isinstance(entry, str):
|
||||||
@@ -149,9 +150,8 @@ class MemoryStore:
|
|||||||
if update != current_memory:
|
if update != current_memory:
|
||||||
self.write_long_term(update)
|
self.write_long_term(update)
|
||||||
|
|
||||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
logger.info("Memory consolidation done for {} messages", len(messages))
|
||||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
return True, None
|
||||||
return True
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Memory consolidation failed")
|
logger.exception("Memory consolidation failed")
|
||||||
return False
|
return False, None
|
||||||
|
|||||||
@@ -189,11 +189,22 @@ class SlackConfig(Base):
|
|||||||
|
|
||||||
|
|
||||||
class QQConfig(Base):
|
class QQConfig(Base):
|
||||||
"""QQ channel configuration using botpy SDK."""
|
"""QQ channel configuration.
|
||||||
|
|
||||||
|
Supports two implementations:
|
||||||
|
1. Official botpy SDK: requires app_id and secret
|
||||||
|
2. OneBot protocol: requires api_url (and optionally ws_reverse_url, bot_qq, access_token)
|
||||||
|
"""
|
||||||
|
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
|
# Official botpy SDK fields
|
||||||
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
|
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
|
||||||
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
||||||
|
# OneBot protocol fields
|
||||||
|
api_url: str = "" # OneBot HTTP API URL (e.g. "http://localhost:5700")
|
||||||
|
ws_reverse_url: str = "" # OneBot WebSocket reverse URL (e.g. "ws://localhost:8080/ws/reverse")
|
||||||
|
bot_qq: int | None = None # Bot's QQ number (for filtering self messages)
|
||||||
|
access_token: str = "" # Optional access token for OneBot API
|
||||||
allow_from: list[str] = Field(
|
allow_from: list[str] = Field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
) # Allowed user openids (empty = public access)
|
) # Allowed user openids (empty = public access)
|
||||||
@@ -226,10 +237,18 @@ class AgentDefaults(Base):
|
|||||||
provider: str = (
|
provider: str = (
|
||||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||||
)
|
)
|
||||||
max_tokens: int = 8192
|
# 原生上下文最大窗口(通常对应模型的 max_input_tokens / max_context_tokens)
|
||||||
|
# 默认按照主流大模型(如 GPT-4o、Claude 3.x 等)的 128k 上下文给一个宽松上限,实际应根据所选模型文档手动调整。
|
||||||
|
max_tokens_input: int = 128_000
|
||||||
|
# 默认单次回复的最大输出 token 上限(调用时可按需要再做截断或比例分配)
|
||||||
|
# 8192 足以覆盖大多数实际对话/工具使用场景,同样可按需手动调整。
|
||||||
|
max_tokens_output: int = 8192
|
||||||
|
# 会话历史压缩触发比例:当估算的输入 token 使用量 >= maxTokensInput * compressionStartRatio 时开始压缩。
|
||||||
|
compression_start_ratio: float = 0.7
|
||||||
|
# 会话历史压缩目标比例:每轮压缩后尽量把估算的输入 token 使用量压到 maxTokensInput * compressionTargetRatio 附近。
|
||||||
|
compression_target_ratio: float = 0.4
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 40
|
||||||
memory_window: int = 100
|
|
||||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.config.paths import get_legacy_sessions_dir
|
|
||||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +29,6 @@ class Session:
|
|||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
updated_at: datetime = field(default_factory=datetime.now)
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
|
||||||
|
|
||||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||||
"""Add a message to the session."""
|
"""Add a message to the session."""
|
||||||
@@ -44,9 +42,13 @@ class Session:
|
|||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
"""
|
||||||
unconsolidated = self.messages[self.last_consolidated:]
|
Return messages for LLM input, aligned to a user turn.
|
||||||
sliced = unconsolidated[-max_messages:]
|
|
||||||
|
- max_messages > 0 时只保留最近 max_messages 条;
|
||||||
|
- max_messages <= 0 时不做条数截断,返回全部消息。
|
||||||
|
"""
|
||||||
|
sliced = self.messages if max_messages <= 0 else self.messages[-max_messages:]
|
||||||
|
|
||||||
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
||||||
for i, m in enumerate(sliced):
|
for i, m in enumerate(sliced):
|
||||||
@@ -66,7 +68,7 @@ class Session:
|
|||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all messages and reset session to initial state."""
|
"""Clear all messages and reset session to initial state."""
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.last_consolidated = 0
|
self.metadata = {}
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
@@ -80,7 +82,7 @@ class SessionManager:
|
|||||||
def __init__(self, workspace: Path):
|
def __init__(self, workspace: Path):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
||||||
self.legacy_sessions_dir = get_legacy_sessions_dir()
|
self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
|
||||||
self._cache: dict[str, Session] = {}
|
self._cache: dict[str, Session] = {}
|
||||||
|
|
||||||
def _get_session_path(self, key: str) -> Path:
|
def _get_session_path(self, key: str) -> Path:
|
||||||
@@ -132,7 +134,6 @@ class SessionManager:
|
|||||||
messages = []
|
messages = []
|
||||||
metadata = {}
|
metadata = {}
|
||||||
created_at = None
|
created_at = None
|
||||||
last_consolidated = 0
|
|
||||||
|
|
||||||
with open(path, encoding="utf-8") as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
@@ -145,7 +146,6 @@ class SessionManager:
|
|||||||
if data.get("_type") == "metadata":
|
if data.get("_type") == "metadata":
|
||||||
metadata = data.get("metadata", {})
|
metadata = data.get("metadata", {})
|
||||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||||
last_consolidated = data.get("last_consolidated", 0)
|
|
||||||
else:
|
else:
|
||||||
messages.append(data)
|
messages.append(data)
|
||||||
|
|
||||||
@@ -154,7 +154,6 @@ class SessionManager:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
created_at=created_at or datetime.now(),
|
created_at=created_at or datetime.now(),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
last_consolidated=last_consolidated
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load session {}: {}", key, e)
|
logger.warning("Failed to load session {}: {}", key, e)
|
||||||
@@ -171,7 +170,6 @@ class SessionManager:
|
|||||||
"created_at": session.created_at.isoformat(),
|
"created_at": session.created_at.isoformat(),
|
||||||
"updated_at": session.updated_at.isoformat(),
|
"updated_at": session.updated_at.isoformat(),
|
||||||
"metadata": session.metadata,
|
"metadata": session.metadata,
|
||||||
"last_consolidated": session.last_consolidated
|
|
||||||
}
|
}
|
||||||
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
||||||
for msg in session.messages:
|
for msg in session.messages:
|
||||||
|
|||||||
Reference in New Issue
Block a user