Move consolidation policy into MemoryConsolidator, keep backward compatibility for legacy config, and compress history by token budget instead of message count.
187 lines
5.8 KiB
Python
187 lines
5.8 KiB
Python
"""Utility functions for nanobot."""
|
|
|
|
import json
|
|
import re
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import tiktoken
|
|
|
|
|
|
def detect_image_mime(data: bytes) -> str | None:
|
|
"""Detect image MIME type from magic bytes, ignoring file extension."""
|
|
if data[:8] == b"\x89PNG\r\n\x1a\n":
|
|
return "image/png"
|
|
if data[:3] == b"\xff\xd8\xff":
|
|
return "image/jpeg"
|
|
if data[:6] in (b"GIF87a", b"GIF89a"):
|
|
return "image/gif"
|
|
if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
|
return "image/webp"
|
|
return None
|
|
|
|
|
|
def ensure_dir(path: Path) -> Path:
    """Create *path* (including any missing parents) if needed and return it.

    A directory that already exists is accepted as-is.
    """
    path.mkdir(exist_ok=True, parents=True)
    return path
|
|
|
|
|
|
def timestamp() -> str:
    """Return the current time (naive ``datetime.now()``) as an ISO 8601 string."""
    now = datetime.now()
    return now.isoformat()
|
|
|
|
|
|
# Characters that safe_filename replaces: < > : " / \ | ? *
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')


def safe_filename(name: str) -> str:
    """Return *name* with each unsafe path character swapped for "_".

    Surrounding whitespace is stripped from the result after the substitution.
    """
    sanitized = _UNSAFE_CHARS.sub("_", name)
    return sanitized.strip()
|
|
|
|
|
|
def split_message(content: str, max_len: int = 2000) -> list[str]:
    """
    Split content into chunks within max_len, preferring line breaks.

    Each chunk is cut at the last newline inside the first ``max_len``
    characters, falling back to the last space and finally to a hard cut at
    ``max_len``. Leading whitespace is stripped from the remainder before the
    next chunk is taken.

    Args:
        content: The text content to split.
        max_len: Maximum length per chunk (default 2000 for Discord compatibility).
            Must be at least 1.

    Returns:
        List of message chunks, each within max_len. Empty content yields [].

    Raises:
        ValueError: If content is non-empty and max_len is smaller than 1.
    """
    if not content:
        return []
    if max_len < 1:
        # With a non-positive limit the loop below cannot shrink the remaining
        # text and would never terminate.
        raise ValueError(f"max_len must be at least 1, got {max_len}")

    chunks: list[str] = []
    while len(content) > max_len:
        window = content[:max_len]
        # Try to break at newline first, then space, then hard break.
        # A separator at index 0 is skipped because it would yield an empty chunk.
        pos = window.rfind("\n")
        if pos <= 0:
            pos = window.rfind(" ")
        if pos <= 0:
            pos = max_len
        chunks.append(content[:pos])
        content = content[pos:].lstrip()
    # Whatever is left (if anything) already fits in one chunk.
    if content:
        chunks.append(content)
    return chunks
|
|
|
|
|
|
def estimate_prompt_tokens(
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
) -> int:
|
|
"""Estimate prompt tokens with tiktoken."""
|
|
try:
|
|
enc = tiktoken.get_encoding("cl100k_base")
|
|
parts: list[str] = []
|
|
for msg in messages:
|
|
content = msg.get("content")
|
|
if isinstance(content, str):
|
|
parts.append(content)
|
|
elif isinstance(content, list):
|
|
for part in content:
|
|
if isinstance(part, dict) and part.get("type") == "text":
|
|
txt = part.get("text", "")
|
|
if txt:
|
|
parts.append(txt)
|
|
if tools:
|
|
parts.append(json.dumps(tools, ensure_ascii=False))
|
|
return len(enc.encode("\n".join(parts)))
|
|
except Exception:
|
|
return 0
|
|
|
|
|
|
def estimate_message_tokens(message: dict[str, Any]) -> int:
    """Estimate how many prompt tokens a single persisted message contributes.

    The estimate covers the message content (a string, a list of parts, or any
    other non-None value dumped as JSON), the ``name`` and ``tool_call_id``
    fields when they are non-empty strings, and the JSON dump of
    ``tool_calls`` when present.

    Returns:
        At least 1. The count comes from tiktoken's cl100k_base encoding; if
        that raises, it falls back to one token per four characters.
    """

    def _as_json(value: Any) -> str:
        return json.dumps(value, ensure_ascii=False)

    body = message.get("content")
    segments: list[str] = []

    if isinstance(body, str):
        segments.append(body)
    elif isinstance(body, list):
        for piece in body:
            if not (isinstance(piece, dict) and piece.get("type") == "text"):
                # Anything that is not a text part is counted through its JSON form.
                segments.append(_as_json(piece))
                continue
            piece_text = piece.get("text", "")
            if piece_text:
                segments.append(piece_text)
    elif body is not None:
        segments.append(_as_json(body))

    for field in ("name", "tool_call_id"):
        field_value = message.get(field)
        if isinstance(field_value, str) and field_value:
            segments.append(field_value)

    if message.get("tool_calls"):
        segments.append(_as_json(message["tool_calls"]))

    joined = "\n".join(segments)
    if not joined:
        # Nothing to encode still counts as one token.
        return 1

    try:
        token_count = len(tiktoken.get_encoding("cl100k_base").encode(joined))
    except Exception:
        # Rough fallback when the encoder is unavailable or fails.
        token_count = len(joined) // 4
    return max(1, token_count)
|
|
|
|
|
|
def estimate_prompt_tokens_chain(
|
|
provider: Any,
|
|
model: str | None,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
) -> tuple[int, str]:
|
|
"""Estimate prompt tokens via provider counter first, then tiktoken fallback."""
|
|
provider_counter = getattr(provider, "estimate_prompt_tokens", None)
|
|
if callable(provider_counter):
|
|
try:
|
|
tokens, source = provider_counter(messages, tools, model)
|
|
if isinstance(tokens, (int, float)) and tokens > 0:
|
|
return int(tokens), str(source or "provider_counter")
|
|
except Exception:
|
|
pass
|
|
|
|
estimated = estimate_prompt_tokens(messages, tools)
|
|
if estimated > 0:
|
|
return int(estimated), "tiktoken"
|
|
return 0, "none"
|
|
|
|
|
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
    """Copy bundled templates into *workspace* without overwriting anything.

    Top-level ``*.md`` templates and ``memory/MEMORY.md`` are copied from the
    ``templates`` directory of the ``nanobot`` package, an empty
    ``memory/HISTORY.md`` is created, and a ``skills`` directory is ensured.
    Files that already exist are left as they are.

    Args:
        workspace: Destination directory.
        silent: When False, each created file is reported through a rich Console.

    Returns:
        Workspace-relative paths of the files that were created; an empty list
        when the bundled templates directory cannot be located.
    """
    from importlib.resources import files as pkg_files

    try:
        template_root = pkg_files("nanobot") / "templates"
    except Exception:
        return []
    if not template_root.is_dir():
        return []

    created: list[str] = []

    def _copy_if_missing(source, target: Path):
        # Existing files are never overwritten.
        if target.exists():
            return
        target.parent.mkdir(parents=True, exist_ok=True)
        # A falsy source (None) means "create an empty file".
        text = source.read_text(encoding="utf-8") if source else ""
        target.write_text(text, encoding="utf-8")
        created.append(str(target.relative_to(workspace)))

    for entry in template_root.iterdir():
        if entry.name.endswith(".md"):
            _copy_if_missing(entry, workspace / entry.name)

    memory_dir = workspace / "memory"
    _copy_if_missing(template_root / "memory" / "MEMORY.md", memory_dir / "MEMORY.md")
    _copy_if_missing(None, memory_dir / "HISTORY.md")
    (workspace / "skills").mkdir(exist_ok=True)

    if created and not silent:
        from rich.console import Console

        console = Console()
        for name in created:
            console.print(f" [dim]Created {name}[/dim]")
    return created
|