"""Session management for conversation history."""
|
|
|
|
import json
|
|
import shutil
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from nanobot.config.paths import get_legacy_sessions_dir
|
|
from nanobot.utils.helpers import ensure_dir, safe_filename
|
|
|
|
|
|
@dataclass
class Session:
    """A conversation session.

    Stores messages in JSONL format for easy reading and persistence.

    Important: the ``messages`` list is treated as append-only for LLM cache
    efficiency. ``get_history()`` never mutates it; it only returns a window
    that starts at ``last_consolidated``. (Presumably the consolidation
    process that writes MEMORY.md/HISTORY.md advances ``last_consolidated``;
    that code is not part of this class.)
    """

    key: str  # channel:chat_id
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files

    # Persistence bookkeeping. These fields are excluded from __init__ and
    # __repr__; within this class only clear() writes to one of them.
    _persisted_message_count: int = field(default=0, init=False, repr=False)
    _persisted_metadata_state: str = field(default="", init=False, repr=False)
    _requires_full_save: bool = field(default=False, init=False, repr=False)

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Append a message to the session and bump ``updated_at``.

        Args:
            role: Message role stored under the ``"role"`` key.
            content: Message text stored under the ``"content"`` key.
            **kwargs: Extra fields merged into the message dict. They are
                applied last, so they override ``role``/``content``/
                ``timestamp`` if the same keys are passed.
        """
        # Read the clock once so the stored timestamp and updated_at agree.
        now = datetime.now()
        self.messages.append(
            {
                "role": role,
                "content": content,
                "timestamp": now.isoformat(),
                **kwargs,
            }
        )
        self.updated_at = now

    @staticmethod
    def _find_legal_start(messages: list[dict[str, Any]]) -> int:
        """Find first index where every tool result has a matching assistant tool_call.

        Walks ``messages`` once, collecting the ids announced by assistant
        ``tool_calls``. Whenever a ``tool`` message references an id that has
        not been announced, everything up to and including that message is
        discarded and collection starts over.

        Returns:
            The index of the first message that can be kept; ``0`` when no
            orphan tool result is found, ``len(messages)`` when the last
            message is itself an orphan.
        """
        announced: set[str] = set()
        start = 0
        for index, msg in enumerate(messages):
            role = msg.get("role")
            if role == "assistant":
                for call in msg.get("tool_calls") or []:
                    if isinstance(call, dict) and call.get("id"):
                        announced.add(str(call["id"]))
            elif role == "tool":
                call_id = msg.get("tool_call_id")
                if call_id and str(call_id) not in announced:
                    # Orphan result: restart just after it. Nothing between
                    # the new start and this message remains, so the set of
                    # announced ids is simply emptied.
                    start = index + 1
                    announced.clear()
        return start

    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.

        Args:
            max_messages: Size of the trailing window taken from the
                unconsolidated messages. Note that ``0`` yields the *whole*
                unconsolidated list, because the slice ``[-0:]`` is ``[0:]``.

        Returns:
            New dicts holding ``role`` and ``content`` plus, when present on
            the stored message, ``tool_calls``, ``tool_call_id`` and ``name``.
            Other stored fields (e.g. ``timestamp``) are not included.
        """
        window = self.messages[self.last_consolidated:][-max_messages:]

        # Drop leading non-user messages to avoid starting mid-turn when
        # possible; if the window holds no user message it is left as is.
        first_user = next(
            (i for i, message in enumerate(window) if message.get("role") == "user"),
            None,
        )
        if first_user is not None:
            window = window[first_user:]

        # Some providers reject orphan tool results if the matching assistant
        # tool_calls message fell outside the fixed-size history window.
        start = self._find_legal_start(window)
        if start:
            window = window[start:]

        history: list[dict[str, Any]] = []
        for message in window:
            entry: dict[str, Any] = {
                "role": message["role"],
                "content": message.get("content", ""),
            }
            for optional_key in ("tool_calls", "tool_call_id", "name"):
                if optional_key in message:
                    entry[optional_key] = message[optional_key]
            history.append(entry)
        return history

    def clear(self) -> None:
        """Drop all messages, reset the consolidation offset and flag a full save.

        ``metadata`` and ``created_at`` are left untouched.
        """
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()
        # Appending cannot express "messages were removed", so flag the
        # session as needing a complete save.
        self._requires_full_save = True
|
|
|
|
|
|
class SessionManager:
    """Manages conversation sessions.

    Sessions are stored as JSONL files in the sessions directory
    (``<workspace>/sessions``). Each file holds one JSON object per line:
    message records, plus metadata records marked with ``"_type": "metadata"``.
    Loaded or saved sessions are kept in an in-memory cache keyed by session key.
    """

    def __init__(self, workspace: Path):
        """Set up paths and an empty cache.

        Args:
            workspace: Root directory; sessions live in its ``sessions`` child.
        """
        self.workspace = workspace
        # ensure_dir is a project helper; presumably it creates the directory
        # if missing and returns the path.
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        self.legacy_sessions_dir = get_legacy_sessions_dir()
        self._cache: dict[str, Session] = {}

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        safe_key = safe_filename(key.replace(":", "_"))
        return self.sessions_dir / f"{safe_key}.jsonl"

    def _get_legacy_session_path(self, key: str) -> Path:
        """Legacy global session path (~/.nanobot/sessions/)."""
        safe_key = safe_filename(key.replace(":", "_"))
        return self.legacy_sessions_dir / f"{safe_key}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """Get an existing session or create a new one.

        Args:
            key: Session key (usually channel:chat_id).

        Returns:
            The cached session if present; otherwise the session loaded from
            disk, or a fresh empty one when loading yields nothing. The result
            is stored in the cache.
        """
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Session | None:
        """Load a session from disk.

        If the session file is missing but a file exists at the legacy path,
        the legacy file is moved into place first (failures are logged and
        otherwise ignored).

        Lines that are not valid JSON objects are skipped with a warning
        instead of discarding the whole session; this keeps the rest of the
        history usable when, for example, the last line was only partially
        written. A session loaded this way is flagged for a full rewrite so
        the next ``save`` replaces the damaged file.

        Returns:
            The loaded session, or ``None`` when no file exists or the file
            cannot be read at all.
        """
        path = self._get_session_path(key)
        if not path.exists():
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)

        if not path.exists():
            return None

        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at = None
            last_consolidated = 0
            skipped = 0

            with open(path, encoding="utf-8") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue

                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                    if not isinstance(data, dict):
                        # Valid JSON, but not a record this format uses.
                        skipped += 1
                        continue

                    if data.get("_type") == "metadata":
                        # Several metadata lines may exist (save() appends
                        # checkpoints); the last one read wins.
                        metadata = data.get("metadata", {})
                        created_at = (
                            datetime.fromisoformat(data["created_at"])
                            if data.get("created_at")
                            else None
                        )
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            if skipped:
                logger.warning(
                    "Skipped {} malformed line(s) while loading session {}",
                    skipped,
                    key,
                )
                # Dropped lines may leave the stored offset past the end.
                last_consolidated = min(last_consolidated, len(messages))

            session = Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=datetime.fromtimestamp(path.stat().st_mtime),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
            self._mark_persisted(session)
            if skipped:
                # Appending after a truncated final line would glue the next
                # record onto it; force a clean rewrite instead.
                session._requires_full_save = True
            return session
        except Exception as e:
            logger.warning("Failed to load session {}: {}", key, e)
            return None

    @staticmethod
    def _metadata_state(session: Session) -> str:
        """Serialize metadata fields that require a checkpoint line.

        ``updated_at`` is deliberately absent, so a mere timestamp bump does
        not count as a metadata change.
        """
        return json.dumps(
            {
                "key": session.key,
                "created_at": session.created_at.isoformat(),
                "metadata": session.metadata,
                "last_consolidated": session.last_consolidated,
            },
            ensure_ascii=False,
            sort_keys=True,
        )

    @staticmethod
    def _metadata_line(session: Session) -> dict[str, Any]:
        """Build a metadata checkpoint record."""
        return {
            "_type": "metadata",
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "last_consolidated": session.last_consolidated,
        }

    @staticmethod
    def _write_jsonl_line(handle: Any, payload: dict[str, Any]) -> None:
        """Write ``payload`` to ``handle`` as one JSON line."""
        handle.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _mark_persisted(self, session: Session) -> None:
        """Record that the session's current state matches what is on disk."""
        session._persisted_message_count = len(session.messages)
        session._persisted_metadata_state = self._metadata_state(session)
        session._requires_full_save = False

    def _rewrite_session_file(self, path: Path, session: Session) -> None:
        """Overwrite ``path`` with a metadata line followed by every message."""
        with open(path, "w", encoding="utf-8") as f:
            self._write_jsonl_line(f, self._metadata_line(session))
            for msg in session.messages:
                self._write_jsonl_line(f, msg)
        self._mark_persisted(session)

    def save(self, session: Session) -> None:
        """Save a session to disk.

        The file is rewritten from scratch when the session is flagged for a
        full save, when the file does not exist yet, or when the session now
        holds fewer messages than were last persisted. Otherwise only the
        messages added since the last save are appended, followed by a
        metadata checkpoint line if the metadata state changed. When nothing
        changed, the file is not touched. The session is cached either way.
        """
        path = self._get_session_path(session.key)
        metadata_state = self._metadata_state(session)

        needs_full_rewrite = (
            session._requires_full_save
            or not path.exists()
            or session._persisted_message_count > len(session.messages)
        )

        if needs_full_rewrite:
            session.updated_at = datetime.now()
            self._rewrite_session_file(path, session)
        else:
            new_messages = session.messages[session._persisted_message_count:]
            metadata_changed = metadata_state != session._persisted_metadata_state

            if new_messages or metadata_changed:
                session.updated_at = datetime.now()
                with open(path, "a", encoding="utf-8") as f:
                    for msg in new_messages:
                        self._write_jsonl_line(f, msg)
                    if metadata_changed:
                        self._write_jsonl_line(f, self._metadata_line(session))
                self._mark_persisted(session)

        self._cache[session.key] = session

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def list_sessions(self) -> list[dict[str, Any]]:
        """List all sessions.

        Returns:
            List of session info dicts (``key``, ``created_at``,
            ``updated_at``, ``path``), most recently modified first. Files
            that cannot be read are logged and left out.
        """
        sessions = []

        for path in self.sessions_dir.glob("*.jsonl"):
            try:
                created_at = None
                # Fallback key: turn the first "_" of the file stem back into
                # ":"; a metadata first line, when present, overrides it.
                key = path.stem.replace("_", ":", 1)
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()

                if first_line:
                    data = json.loads(first_line)
                    if data.get("_type") == "metadata":
                        key = data.get("key") or key
                        created_at = data.get("created_at")

                # Incremental saves append messages without rewriting the first metadata line,
                # so use file mtime as the session's latest activity timestamp.
                sessions.append({
                    "key": key,
                    "created_at": created_at,
                    "updated_at": datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
                    "path": str(path),
                })
            except Exception as e:
                # Keep listing the remaining sessions, but leave a trace.
                logger.warning("Skipping unreadable session file {}: {}", path, e)
                continue

        return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
|