"""Session management for conversation history."""

import json
import shutil
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any

from loguru import logger

from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename


@dataclass
class Session:
    """
    A conversation session.

    Stores messages in JSONL format for easy reading and persistence.

    Important: Messages are append-only for LLM cache efficiency.
    The consolidation process writes summaries to MEMORY.md/HISTORY.md
    but does NOT modify the messages list; get_history() merely skips the
    first ``last_consolidated`` messages.
    """

    key: str  # channel:chat_id
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files
    # Bookkeeping used by SessionManager.save() to decide between an
    # incremental append and a full rewrite of the session file.
    _persisted_message_count: int = field(default=0, init=False, repr=False)
    _persisted_metadata_state: str = field(default="", init=False, repr=False)
    _requires_full_save: bool = field(default=False, init=False, repr=False)

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Append a message to the session and bump ``updated_at``.

        Args:
            role: Message role (e.g. "user", "assistant", "tool").
            content: Message text.
            **kwargs: Extra fields stored on the message as-is. Because they
                are merged last, they override ``role``/``content``/
                ``timestamp`` when the same key is given.
        """
        # One clock read so the message timestamp and updated_at agree.
        now = datetime.now()
        msg = {
            "role": role,
            "content": content,
            "timestamp": now.isoformat(),
            **kwargs,
        }
        self.messages.append(msg)
        self.updated_at = now

    @staticmethod
    def _find_legal_start(messages: list[dict[str, Any]]) -> int:
        """Find first index where every tool result has a matching assistant tool_call.

        Scans ``messages`` in order, tracking the tool-call ids declared by
        assistant messages. Whenever a tool message references an id that has
        not been declared, everything up to and including that message is
        considered unusable: the start moves just past it and the set of
        declared ids is reset.

        Returns:
            The index of the first message of the legal suffix (0 when no
            orphan tool result exists; ``len(messages)`` when the last
            message is itself an orphan).
        """
        declared: set[str] = set()
        start = 0
        for i, msg in enumerate(messages):
            role = msg.get("role")
            if role == "assistant":
                for tc in msg.get("tool_calls") or []:
                    if isinstance(tc, dict) and tc.get("id"):
                        declared.add(str(tc["id"]))
            elif role == "tool":
                tid = msg.get("tool_call_id")
                if tid and str(tid) not in declared:
                    # Orphan tool result: restart after it. Nothing lies
                    # between the new start and the current position, so
                    # there are no earlier declarations to re-collect.
                    start = i + 1
                    declared.clear()
        return start

    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.

        Args:
            max_messages: Size of the trailing window taken from the
                unconsolidated messages. Note that, by slice semantics,
                ``0`` selects the whole unconsolidated list.

        Returns:
            A list of new dicts holding only ``role`` and ``content`` plus
            ``tool_calls``, ``tool_call_id`` and ``name`` when present on the
            stored message.
        """
        unconsolidated = self.messages[self.last_consolidated:]
        sliced = unconsolidated[-max_messages:]

        # Drop leading non-user messages to avoid starting mid-turn when possible.
        for i, message in enumerate(sliced):
            if message.get("role") == "user":
                sliced = sliced[i:]
                break

        # Some providers reject orphan tool results if the matching assistant
        # tool_calls message fell outside the fixed-size history window.
        start = self._find_legal_start(sliced)
        if start:
            sliced = sliced[start:]

        out: list[dict[str, Any]] = []
        for message in sliced:
            entry: dict[str, Any] = {"role": message["role"], "content": message.get("content", "")}
            for key in ("tool_calls", "tool_call_id", "name"):
                if key in message:
                    entry[key] = message[key]
            out.append(entry)
        return out

    def clear(self) -> None:
        """Clear all messages and reset session to initial state."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()
        # The on-disk file still holds the old messages, so the next save
        # must rewrite it instead of appending.
        self._requires_full_save = True


class SessionManager:
    """
    Manages conversation sessions.

    Sessions are stored as JSONL files in the sessions directory.
    """

    def __init__(self, workspace: Path):
        """Set up the sessions directory under ``workspace`` and an empty cache."""
        self.workspace = workspace
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        self.legacy_sessions_dir = get_legacy_sessions_dir()
        self._cache: dict[str, Session] = {}

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        safe_key = safe_filename(key.replace(":", "_"))
        return self.sessions_dir / f"{safe_key}.jsonl"

    def _get_legacy_session_path(self, key: str) -> Path:
        """Legacy global session path (~/.nanobot/sessions/)."""
        safe_key = safe_filename(key.replace(":", "_"))
        return self.legacy_sessions_dir / f"{safe_key}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """
        Get an existing session or create a new one.

        Args:
            key: Session key (usually channel:chat_id).

        Returns:
            The session.
        """
        if key in self._cache:
            return self._cache[key]

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Session | None:
        """Load a session from disk.

        If no file exists at the current path but one exists at the legacy
        path, the legacy file is moved into place first (a failed move is
        logged and otherwise ignored).

        Each non-empty line is parsed as JSON. Records with
        ``"_type": "metadata"`` update the session metadata (the last such
        record wins); every other record is a message. Lines that are not
        valid JSON objects are skipped with a warning so that one damaged
        line does not make the entire history unreadable.

        Returns:
            The loaded session, or ``None`` when no file exists or loading
            fails.
        """
        path = self._get_session_path(key)
        if not path.exists():
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)

        if not path.exists():
            return None

        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: datetime | None = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError as e:
                        logger.warning(
                            "Skipping malformed line {} in session {}: {}", line_no, key, e
                        )
                        continue
                    if not isinstance(data, dict):
                        logger.warning(
                            "Skipping non-object line {} in session {}", line_no, key
                        )
                        continue

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = (
                            datetime.fromisoformat(data["created_at"])
                            if data.get("created_at")
                            else None
                        )
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            session = Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                # Appends do not rewrite earlier metadata records, so the
                # file mtime is used as the last-activity time.
                updated_at=datetime.fromtimestamp(path.stat().st_mtime),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
            self._mark_persisted(session)
            return session
        except Exception as e:
            logger.warning("Failed to load session {}: {}", key, e)
            return None

    @staticmethod
    def _metadata_state(session: Session) -> str:
        """Serialize metadata fields that require a checkpoint line.

        ``updated_at`` is deliberately not part of this state; save()
        compares the returned string with the last persisted one to decide
        whether a new metadata record has to be appended.
        """
        return json.dumps(
            {
                "key": session.key,
                "created_at": session.created_at.isoformat(),
                "metadata": session.metadata,
                "last_consolidated": session.last_consolidated,
            },
            ensure_ascii=False,
            sort_keys=True,
        )

    @staticmethod
    def _metadata_line(session: Session) -> dict[str, Any]:
        """Build a metadata checkpoint record."""
        return {
            "_type": "metadata",
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "last_consolidated": session.last_consolidated,
        }

    @staticmethod
    def _write_jsonl_line(handle: Any, payload: dict[str, Any]) -> None:
        """Write ``payload`` to ``handle`` as one JSON line."""
        handle.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _mark_persisted(self, session: Session) -> None:
        """Record that the session's current messages and metadata are on disk."""
        session._persisted_message_count = len(session.messages)
        session._persisted_metadata_state = self._metadata_state(session)
        session._requires_full_save = False

    def _rewrite_session_file(self, path: Path, session: Session) -> None:
        """Overwrite ``path`` with a metadata record followed by every message."""
        with open(path, "w", encoding="utf-8") as f:
            self._write_jsonl_line(f, self._metadata_line(session))
            for msg in session.messages:
                self._write_jsonl_line(f, msg)
        self._mark_persisted(session)

    def save(self, session: Session) -> None:
        """Save a session to disk.

        The file is rewritten from scratch when the session demands it (for
        example after ``clear()``), when the file does not exist yet, or when
        the session now holds fewer messages than were persisted. Otherwise
        only messages added since the last save are appended, followed by a
        fresh metadata record if the checkpointed metadata changed. When
        nothing changed, the file is left untouched.

        The session is (re)registered in the in-memory cache in all cases.
        """
        path = self._get_session_path(session.key)
        metadata_state = self._metadata_state(session)
        needs_full_rewrite = (
            session._requires_full_save
            or not path.exists()
            or session._persisted_message_count > len(session.messages)
        )

        if needs_full_rewrite:
            session.updated_at = datetime.now()
            self._rewrite_session_file(path, session)
        else:
            new_messages = session.messages[session._persisted_message_count:]
            metadata_changed = metadata_state != session._persisted_metadata_state

            if new_messages or metadata_changed:
                session.updated_at = datetime.now()
                with open(path, "a", encoding="utf-8") as f:
                    for msg in new_messages:
                        self._write_jsonl_line(f, msg)
                    if metadata_changed:
                        self._write_jsonl_line(f, self._metadata_line(session))
                self._mark_persisted(session)

        self._cache[session.key] = session

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def list_sessions(self) -> list[dict[str, Any]]:
        """
        List all sessions.

        Only the first line of each ``*.jsonl`` file is inspected; when it
        is a metadata record its ``key`` and ``created_at`` are used,
        otherwise the key is derived from the file name. Files that cannot
        be read are logged and left out of the result.

        Returns:
            List of session info dicts (``key``, ``created_at``,
            ``updated_at``, ``path``), most recently modified first.
        """
        sessions = []

        for path in self.sessions_dir.glob("*.jsonl"):
            try:
                created_at = None
                key = path.stem.replace("_", ":", 1)
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                if first_line:
                    data = json.loads(first_line)
                    if data.get("_type") == "metadata":
                        key = data.get("key") or key
                        created_at = data.get("created_at")

                # Incremental saves append messages without rewriting the first metadata line,
                # so use file mtime as the session's latest activity timestamp.
                sessions.append({
                    "key": key,
                    "created_at": created_at,
                    "updated_at": datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
                    "path": str(path),
                })
            except Exception as e:
                # Best effort: a single unreadable file must not break the listing.
                logger.warning("Skipping unreadable session file {}: {}", path, e)
                continue

        return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)