diff --git a/Dockerfile b/Dockerfile
index 8132747..3682fb1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
# Install Node.js 20 for the WhatsApp bridge
RUN apt-get update && \
- apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
+ apt-get install -y --no-install-recommends curl ca-certificates gnupg git openssh-client && \
mkdir -p /etc/apt/keyrings && \
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
@@ -26,6 +26,8 @@ COPY bridge/ bridge/
RUN uv pip install --system --no-cache .
# Build the WhatsApp bridge
+RUN git config --global url."https://github.com/".insteadOf "ssh://git@github.com/" && git config --global --add url."https://github.com/".insteadOf "git@github.com:"
+
WORKDIR /app/bridge
RUN npm install && npm run build
WORKDIR /app
diff --git a/README.md b/README.md
index 9fbec37..64ae157 100644
--- a/README.md
+++ b/README.md
@@ -191,9 +191,11 @@ nanobot channels login
nanobot onboard
```
+Use `nanobot onboard --wizard` if you want the interactive setup wizard.
+
**2. Configure** (`~/.nanobot/config.json`)
-Add or merge these **two parts** into your config (other options have defaults).
+Configure these **two parts** in your config (other options have defaults).
*Set your API key* (e.g. OpenRouter, recommended for global users):
```json
@@ -809,6 +811,7 @@ Config file: `~/.nanobot/config.json`
OpenAI Codex (OAuth)
Codex uses OAuth instead of API keys. Requires a ChatGPT Plus or Pro account.
+No `providers.openaiCodex` block is needed in `config.json`; `nanobot provider login` stores the OAuth session outside config.
**1. Login:**
```bash
@@ -841,6 +844,44 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
+
+
+GitHub Copilot (OAuth)
+
+GitHub Copilot uses OAuth instead of API keys. Requires a [GitHub account with a Copilot plan](https://github.com/features/copilot/plans).
+No `providers.githubCopilot` block is needed in `config.json`; `nanobot provider login` stores the OAuth session outside config.
+
+**1. Login:**
+```bash
+nanobot provider login github-copilot
+```
+
+**2. Set model** (merge into `~/.nanobot/config.json`):
+```json
+{
+ "agents": {
+ "defaults": {
+ "model": "github-copilot/gpt-4.1"
+ }
+ }
+}
+```
+
+**3. Chat:**
+```bash
+nanobot agent -m "Hello!"
+
+# Target a specific workspace/config locally
+nanobot agent -c ~/.nanobot-telegram/config.json -m "Hello!"
+
+# One-off workspace override on top of that config
+nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -m "Hello!"
+```
+
+> Docker users: use `docker run -it` for interactive OAuth login.
+
+
+
Custom Provider (Any OpenAI-compatible API)
@@ -1161,6 +1202,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
| Option | Default | Description |
|--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
+| `tools.exec.enable` | `true` | When `false`, the shell `exec` tool is not registered at all. Use this to completely disable shell command execution. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |
@@ -1288,6 +1330,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
| Command | Description |
|---------|-------------|
| `nanobot onboard` | Initialize config & workspace at `~/.nanobot/` |
+| `nanobot onboard --wizard` | Launch the interactive onboarding wizard |
| `nanobot onboard -c -w ` | Initialize or refresh a specific instance config and workspace |
| `nanobot agent -m "..."` | Chat with the agent |
| `nanobot agent -w ` | Chat against a specific workspace |
diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index 10e2813..152b58d 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -120,12 +120,13 @@ class AgentLoop:
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
- self.tools.register(ExecTool(
- working_dir=str(self.workspace),
- timeout=self.exec_config.timeout,
- restrict_to_workspace=self.restrict_to_workspace,
- path_append=self.exec_config.path_append,
- ))
+ if self.exec_config.enable:
+ self.tools.register(ExecTool(
+ working_dir=str(self.workspace),
+ timeout=self.exec_config.timeout,
+ restrict_to_workspace=self.restrict_to_workspace,
+ path_append=self.exec_config.path_append,
+ ))
self.tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
self.tools.register(WebFetchTool(proxy=self.web_proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py
index fc62bf8..2050eed 100644
--- a/nanobot/agent/tools/spawn.py
+++ b/nanobot/agent/tools/spawn.py
@@ -32,7 +32,9 @@ class SpawnTool(Tool):
return (
"Spawn a subagent to handle a task in the background. "
"Use this for complex or time-consuming tasks that can run independently. "
- "The subagent will complete the task and report back when done."
+ "The subagent will complete the task and report back when done. "
+ "For deliverables or existing projects, inspect the workspace first "
+ "and use a dedicated subdirectory when helpful."
)
@property
diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py
index 618e640..be3cb3e 100644
--- a/nanobot/channels/email.py
+++ b/nanobot/channels/email.py
@@ -80,6 +80,21 @@ class EmailChannel(BaseChannel):
"Nov",
"Dec",
)
+ _IMAP_RECONNECT_MARKERS = (
+ "disconnected for inactivity",
+ "eof occurred in violation of protocol",
+ "socket error",
+ "connection reset",
+ "broken pipe",
+ "bye",
+ )
+ _IMAP_MISSING_MAILBOX_MARKERS = (
+ "mailbox doesn't exist",
+ "select failed",
+ "no such mailbox",
+ "can't open mailbox",
+ "does not exist",
+ )
@classmethod
def default_config(cls) -> dict[str, Any]:
@@ -267,8 +282,37 @@ class EmailChannel(BaseChannel):
dedupe: bool,
limit: int,
) -> list[dict[str, Any]]:
- """Fetch messages by arbitrary IMAP search criteria."""
messages: list[dict[str, Any]] = []
+ cycle_uids: set[str] = set()
+
+ for attempt in range(2):
+ try:
+ self._fetch_messages_once(
+ search_criteria,
+ mark_seen,
+ dedupe,
+ limit,
+ messages,
+ cycle_uids,
+ )
+ return messages
+ except Exception as exc:
+ if attempt == 1 or not self._is_stale_imap_error(exc):
+ raise
+ logger.warning("Email IMAP connection went stale, retrying once: {}", exc)
+
+ return messages
+
+ def _fetch_messages_once(
+ self,
+ search_criteria: tuple[str, ...],
+ mark_seen: bool,
+ dedupe: bool,
+ limit: int,
+ messages: list[dict[str, Any]],
+ cycle_uids: set[str],
+ ) -> None:
+ """Fetch messages by arbitrary IMAP search criteria."""
mailbox = self.config.imap_mailbox or "INBOX"
if self.config.imap_use_ssl:
@@ -278,8 +322,15 @@ class EmailChannel(BaseChannel):
try:
client.login(self.config.imap_username, self.config.imap_password)
- status, _ = client.select(mailbox)
+ try:
+ status, _ = client.select(mailbox)
+ except Exception as exc:
+ if self._is_missing_mailbox_error(exc):
+ logger.warning("Email mailbox unavailable, skipping poll for {}: {}", mailbox, exc)
+ return messages
+ raise
if status != "OK":
+ logger.warning("Email mailbox select returned {}, skipping poll for {}", status, mailbox)
return messages
status, data = client.search(None, *search_criteria)
@@ -299,6 +350,8 @@ class EmailChannel(BaseChannel):
continue
uid = self._extract_uid(fetched)
+ if uid and uid in cycle_uids:
+ continue
if dedupe and uid and uid in self._processed_uids:
continue
@@ -341,6 +394,8 @@ class EmailChannel(BaseChannel):
}
)
+ if uid:
+ cycle_uids.add(uid)
if dedupe and uid:
self._processed_uids.add(uid)
# mark_seen is the primary dedup; this set is a safety net
@@ -356,7 +411,15 @@ class EmailChannel(BaseChannel):
except Exception:
pass
- return messages
+ @classmethod
+ def _is_stale_imap_error(cls, exc: Exception) -> bool:
+ message = str(exc).lower()
+ return any(marker in message for marker in cls._IMAP_RECONNECT_MARKERS)
+
+ @classmethod
+ def _is_missing_mailbox_error(cls, exc: Exception) -> bool:
+ message = str(exc).lower()
+ return any(marker in message for marker in cls._IMAP_MISSING_MAILBOX_MARKERS)
@classmethod
def _format_imap_date(cls, value: date) -> str:
diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py
index 695689e..5e3d126 100644
--- a/nanobot/channels/feishu.py
+++ b/nanobot/channels/feishu.py
@@ -191,6 +191,10 @@ def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
texts.append(el.get("text", ""))
elif tag == "at":
texts.append(f"@{el.get('user_name', 'user')}")
+ elif tag == "code_block":
+ lang = el.get("language", "")
+ code_text = el.get("text", "")
+ texts.append(f"\n```{lang}\n{code_text}\n```\n")
elif tag == "img" and (key := el.get("image_key")):
images.append(key)
return (" ".join(texts).strip() or None), images
@@ -1039,7 +1043,7 @@ class FeishuChannel(BaseChannel):
event = data.event
message = event.message
sender = event.sender
-
+
# Deduplication check
message_id = message.message_id
if message_id in self._processed_message_ids:
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 49858da..c2b9199 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -11,6 +11,7 @@ from typing import Any, Literal
from loguru import logger
from pydantic import Field
from telegram import BotCommand, ReplyParameters, Update
+from telegram.error import TimedOut
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest
@@ -151,6 +152,10 @@ def _markdown_to_telegram_html(text: str) -> str:
return text
+_SEND_MAX_RETRIES = 3
+_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
+
+
class TelegramConfig(Base):
"""Telegram channel configuration."""
@@ -160,6 +165,8 @@ class TelegramConfig(Base):
proxy: str | None = None
reply_to_message: bool = False
group_policy: Literal["open", "mention"] = "mention"
+ connection_pool_size: int = 32
+ pool_timeout: float = 5.0
class TelegramChannel(BaseChannel):
@@ -226,15 +233,29 @@ class TelegramChannel(BaseChannel):
self._running = True
- # Build the application with larger connection pool to avoid pool-timeout on long runs
- req = HTTPXRequest(
- connection_pool_size=16,
- pool_timeout=5.0,
+ proxy = self.config.proxy or None
+
+ # Separate pools so long-polling (getUpdates) never starves outbound sends.
+ api_request = HTTPXRequest(
+ connection_pool_size=self.config.connection_pool_size,
+ pool_timeout=self.config.pool_timeout,
connect_timeout=30.0,
read_timeout=30.0,
- proxy=self.config.proxy if self.config.proxy else None,
+ proxy=proxy,
+ )
+ poll_request = HTTPXRequest(
+ connection_pool_size=4,
+ pool_timeout=self.config.pool_timeout,
+ connect_timeout=30.0,
+ read_timeout=30.0,
+ proxy=proxy,
+ )
+ builder = (
+ Application.builder()
+ .token(self.config.token)
+ .request(api_request)
+ .get_updates_request(poll_request)
)
- builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
self._app = builder.build()
self._app.add_error_handler(self._on_error)
@@ -365,7 +386,8 @@ class TelegramChannel(BaseChannel):
ok, error = validate_url_target(media_path)
if not ok:
raise ValueError(f"unsafe media URL: {error}")
- await sender(
+ await self._call_with_retry(
+ sender,
chat_id=chat_id,
**{param: media_path},
reply_parameters=reply_params,
@@ -401,6 +423,21 @@ class TelegramChannel(BaseChannel):
else:
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
+ async def _call_with_retry(self, fn, *args, **kwargs):
+ """Call an async Telegram API function with retry on pool/network timeout."""
+ for attempt in range(1, _SEND_MAX_RETRIES + 1):
+ try:
+ return await fn(*args, **kwargs)
+ except TimedOut:
+ if attempt == _SEND_MAX_RETRIES:
+ raise
+ delay = _SEND_RETRY_BASE_DELAY * (2 ** (attempt - 1))
+ logger.warning(
+ "Telegram timeout (attempt {}/{}), retrying in {:.1f}s",
+ attempt, _SEND_MAX_RETRIES, delay,
+ )
+ await asyncio.sleep(delay)
+
async def _send_text(
self,
chat_id: int,
@@ -411,7 +448,8 @@ class TelegramChannel(BaseChannel):
"""Send a plain text message with HTML fallback."""
try:
html = _markdown_to_telegram_html(text)
- await self._app.bot.send_message(
+ await self._call_with_retry(
+ self._app.bot.send_message,
chat_id=chat_id, text=html, parse_mode="HTML",
reply_parameters=reply_params,
**(thread_kwargs or {}),
@@ -419,7 +457,8 @@ class TelegramChannel(BaseChannel):
except Exception as e:
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
- await self._app.bot.send_message(
+ await self._call_with_retry(
+ self._app.bot.send_message,
chat_id=chat_id,
text=text,
reply_parameters=reply_params,
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index 0d4bb3d..8172ad6 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -21,12 +21,11 @@ if sys.platform == "win32":
pass
import typer
-from prompt_toolkit import print_formatted_text
-from prompt_toolkit import PromptSession
+from prompt_toolkit import PromptSession, print_formatted_text
+from prompt_toolkit.application import run_in_terminal
from prompt_toolkit.formatted_text import ANSI, HTML
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
-from prompt_toolkit.application import run_in_terminal
from rich.console import Console
from rich.markdown import Markdown
from rich.table import Table
@@ -39,6 +38,7 @@ from nanobot.utils.helpers import sync_workspace_templates
app = typer.Typer(
name="nanobot",
+ context_settings={"help_option_names": ["-h", "--help"]},
help=f"{__logo__} nanobot - Personal AI Assistant",
no_args_is_help=True,
)
@@ -265,6 +265,7 @@ def main(
def onboard(
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
+ wizard: bool = typer.Option(False, "--wizard", help="Use interactive wizard"),
):
"""Initialize nanobot configuration and workspace."""
from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path
@@ -284,42 +285,69 @@ def onboard(
# Create or update config
if config_path.exists():
- console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
- console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
- console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
- if typer.confirm("Overwrite?"):
- config = _apply_workspace_override(Config())
- save_config(config, config_path)
- console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
- else:
+ if wizard:
config = _apply_workspace_override(load_config(config_path))
- save_config(config, config_path)
- console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
+ else:
+ console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
+ console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
+ console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
+ if typer.confirm("Overwrite?"):
+ config = _apply_workspace_override(Config())
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
+ else:
+ config = _apply_workspace_override(load_config(config_path))
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
else:
config = _apply_workspace_override(Config())
- save_config(config, config_path)
- console.print(f"[green]✓[/green] Created config at {config_path}")
- console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
+ # In wizard mode, don't save yet - the wizard will handle saving if should_save=True
+ if not wizard:
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Created config at {config_path}")
+ # Run interactive wizard if enabled
+ if wizard:
+ from nanobot.cli.onboard_wizard import run_onboard
+
+ try:
+ result = run_onboard(initial_config=config)
+ if not result.should_save:
+ console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]")
+ return
+
+ config = result.config
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config saved at {config_path}")
+ except Exception as e:
+ console.print(f"[red]✗[/red] Error during configuration: {e}")
+ console.print("[yellow]Please run 'nanobot onboard' again to complete setup.[/yellow]")
+ raise typer.Exit(1)
_onboard_plugins(config_path)
# Create workspace, preferring the configured workspace path.
- workspace = get_workspace_path(config.workspace_path)
- if not workspace.exists():
- workspace.mkdir(parents=True, exist_ok=True)
- console.print(f"[green]✓[/green] Created workspace at {workspace}")
+ workspace_path = get_workspace_path(config.workspace_path)
+ if not workspace_path.exists():
+ workspace_path.mkdir(parents=True, exist_ok=True)
+ console.print(f"[green]✓[/green] Created workspace at {workspace_path}")
- sync_workspace_templates(workspace)
+ sync_workspace_templates(workspace_path)
agent_cmd = 'nanobot agent -m "Hello!"'
+ gateway_cmd = "nanobot gateway"
if config:
agent_cmd += f" --config {config_path}"
+ gateway_cmd += f" --config {config_path}"
console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:")
- console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
- console.print(" Get one at: https://openrouter.ai/keys")
- console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
+ if wizard:
+ console.print(f" 1. Chat: [cyan]{agent_cmd}[/cyan]")
+ console.print(f" 2. Start gateway: [cyan]{gateway_cmd}[/cyan]")
+ else:
+ console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
+ console.print(" Get one at: https://openrouter.ai/keys")
+ console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
@@ -363,9 +391,9 @@ def _onboard_plugins(config_path: Path) -> None:
def _make_provider(config: Config):
"""Create the appropriate LLM provider from config."""
+ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
- from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
@@ -434,21 +462,30 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
console.print(f"[dim]Using config: {config_path}[/dim]")
loaded = load_config(config_path)
+ _warn_deprecated_config_keys(config_path)
if workspace:
loaded.agents.defaults.workspace = workspace
return loaded
-def _print_deprecated_memory_window_notice(config: Config) -> None:
- """Warn when running with old memoryWindow-only config."""
- if config.agents.defaults.should_warn_deprecated_memory_window:
+def _warn_deprecated_config_keys(config_path: Path | None) -> None:
+ """Hint users to remove obsolete keys from their config file."""
+ import json
+ from nanobot.config.loader import get_config_path
+
+ path = config_path or get_config_path()
+ try:
+ raw = json.loads(path.read_text(encoding="utf-8"))
+ except Exception:
+ return
+ if "memoryWindow" in raw.get("agents", {}).get("defaults", {}):
console.print(
- "[yellow]Hint:[/yellow] Detected deprecated `memoryWindow` without "
- "`contextWindowTokens`. `memoryWindow` is ignored; run "
- "[cyan]nanobot onboard[/cyan] to refresh your config template."
+ "[dim]Hint: `memoryWindow` in your config is no longer used "
+ "and can be safely removed.[/dim]"
)
+
# ============================================================================
# Gateway / Server
# ============================================================================
@@ -476,7 +513,6 @@ def gateway(
logging.basicConfig(level=logging.DEBUG)
config = _load_runtime_config(config, workspace)
- _print_deprecated_memory_window_notice(config)
port = port if port is not None else config.gateway.port
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
@@ -667,7 +703,6 @@ def agent(
from nanobot.cron.service import CronService
config = _load_runtime_config(config, workspace)
- _print_deprecated_memory_window_notice(config)
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
diff --git a/nanobot/cli/model_info.py b/nanobot/cli/model_info.py
new file mode 100644
index 0000000..520370c
--- /dev/null
+++ b/nanobot/cli/model_info.py
@@ -0,0 +1,231 @@
+"""Model information helpers for the onboard wizard.
+
+Provides model context window lookup and autocomplete suggestions using litellm.
+"""
+
+from __future__ import annotations
+
+from functools import lru_cache
+from typing import Any
+
+
+def _litellm():
+ """Lazy accessor for litellm (heavy import deferred until actually needed)."""
+ import litellm as _ll
+
+ return _ll
+
+
+@lru_cache(maxsize=1)
+def _get_model_cost_map() -> dict[str, Any]:
+ """Get litellm's model cost map (cached)."""
+ return getattr(_litellm(), "model_cost", {})
+
+
+@lru_cache(maxsize=1)
+def get_all_models() -> list[str]:
+ """Get all known model names from litellm.
+ """
+ models = set()
+
+ # From model_cost (has pricing info)
+ cost_map = _get_model_cost_map()
+ for k in cost_map.keys():
+ if k != "sample_spec":
+ models.add(k)
+
+ # From models_by_provider (more complete provider coverage)
+ for provider_models in getattr(_litellm(), "models_by_provider", {}).values():
+ if isinstance(provider_models, (set, list)):
+ models.update(provider_models)
+
+ return sorted(models)
+
+
+def _normalize_model_name(model: str) -> str:
+ """Normalize model name for comparison."""
+ return model.lower().replace("-", "_").replace(".", "")
+
+
+def find_model_info(model_name: str) -> dict[str, Any] | None:
+ """Find model info with fuzzy matching.
+
+ Args:
+ model_name: Model name in any common format
+
+ Returns:
+ Model info dict or None if not found
+ """
+ cost_map = _get_model_cost_map()
+ if not cost_map:
+ return None
+
+ # Direct match
+ if model_name in cost_map:
+ return cost_map[model_name]
+
+ # Extract base name (without provider prefix)
+ base_name = model_name.split("/")[-1] if "/" in model_name else model_name
+ base_normalized = _normalize_model_name(base_name)
+
+ candidates = []
+
+ for key, info in cost_map.items():
+ if key == "sample_spec":
+ continue
+
+ key_base = key.split("/")[-1] if "/" in key else key
+ key_base_normalized = _normalize_model_name(key_base)
+
+ # Score the match
+ score = 0
+
+ # Exact base name match (highest priority)
+ if base_normalized == key_base_normalized:
+ score = 100
+ # Base name contains model
+ elif base_normalized in key_base_normalized:
+ score = 80
+ # Model contains base name
+ elif key_base_normalized in base_normalized:
+ score = 70
+ # Partial match
+ elif base_normalized[:10] in key_base_normalized:
+ score = 50
+
+ if score > 0:
+ # Prefer models with max_input_tokens
+ if info.get("max_input_tokens"):
+ score += 10
+ candidates.append((score, key, info))
+
+ if not candidates:
+ return None
+
+ # Return the best match
+ candidates.sort(key=lambda x: (-x[0], x[1]))
+ return candidates[0][2]
+
+
+def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
+ """Get the maximum input context tokens for a model.
+
+ Args:
+ model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o")
+ provider: Provider name for informational purposes (not yet used for filtering)
+
+ Returns:
+ Maximum input tokens, or None if unknown
+
+ Note:
+ The provider parameter is currently informational only. Future versions may
+ use it to prefer provider-specific model variants in the lookup.
+ """
+ # First try fuzzy search in model_cost (has more accurate max_input_tokens)
+ info = find_model_info(model)
+ if info:
+ # Prefer max_input_tokens (this is what we want for context window)
+ max_input = info.get("max_input_tokens")
+ if max_input and isinstance(max_input, int):
+ return max_input
+
+ # Fall back to litellm's get_max_tokens (returns max_output_tokens typically)
+ try:
+ result = _litellm().get_max_tokens(model)
+ if result and result > 0:
+ return result
+ except (KeyError, ValueError, AttributeError):
+ # Model not found in litellm's database or invalid response
+ pass
+
+ # Last resort: use max_tokens from model_cost
+ if info:
+ max_tokens = info.get("max_tokens")
+ if max_tokens and isinstance(max_tokens, int):
+ return max_tokens
+
+ return None
+
+
+@lru_cache(maxsize=1)
+def _get_provider_keywords() -> dict[str, list[str]]:
+ """Build provider keywords mapping from nanobot's provider registry.
+
+ Returns:
+ Dict mapping provider name to list of keywords for model filtering.
+ """
+ try:
+ from nanobot.providers.registry import PROVIDERS
+
+ mapping = {}
+ for spec in PROVIDERS:
+ if spec.keywords:
+ mapping[spec.name] = list(spec.keywords)
+ return mapping
+ except ImportError:
+ return {}
+
+
+def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
+ """Get autocomplete suggestions for model names.
+
+ Args:
+ partial: Partial model name typed by user
+ provider: Provider name for filtering (e.g., "openrouter", "minimax")
+ limit: Maximum number of suggestions to return
+
+ Returns:
+ List of matching model names
+ """
+ all_models = get_all_models()
+ if not all_models:
+ return []
+
+ partial_lower = partial.lower()
+ partial_normalized = _normalize_model_name(partial)
+
+ # Get provider keywords from registry
+ provider_keywords = _get_provider_keywords()
+
+ # Filter by provider if specified
+ allowed_keywords = None
+ if provider and provider != "auto":
+ allowed_keywords = provider_keywords.get(provider.lower())
+
+ matches = []
+
+ for model in all_models:
+ model_lower = model.lower()
+
+ # Apply provider filter
+ if allowed_keywords:
+ if not any(kw in model_lower for kw in allowed_keywords):
+ continue
+
+ # Match against partial input
+ if not partial:
+ matches.append(model)
+ continue
+
+ if partial_lower in model_lower:
+ # Score by position of match (earlier = better)
+ pos = model_lower.find(partial_lower)
+ score = 100 - pos
+ matches.append((score, model))
+ elif partial_normalized in _normalize_model_name(model):
+ score = 50
+ matches.append((score, model))
+
+ # Sort by score if we have scored matches
+ if matches and isinstance(matches[0], tuple):
+ matches.sort(key=lambda x: (-x[0], x[1]))
+ matches = [m[1] for m in matches]
+ else:
+ matches.sort()
+
+ return matches[:limit]
+
+
+def format_token_count(tokens: int) -> str:
+ """Format token count for display (e.g., 200000 -> '200,000')."""
+ return f"{tokens:,}"
diff --git a/nanobot/cli/onboard_wizard.py b/nanobot/cli/onboard_wizard.py
new file mode 100644
index 0000000..eca86bf
--- /dev/null
+++ b/nanobot/cli/onboard_wizard.py
@@ -0,0 +1,1023 @@
+"""Interactive onboarding questionnaire for nanobot."""
+
+import json
+import types
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, NamedTuple, get_args, get_origin
+
+try:
+ import questionary
+except ModuleNotFoundError: # pragma: no cover - exercised in environments without wizard deps
+ questionary = None
+from loguru import logger
+from pydantic import BaseModel
+from rich.console import Console
+from rich.panel import Panel
+from rich.table import Table
+
+from nanobot.cli.model_info import (
+ format_token_count,
+ get_model_context_limit,
+ get_model_suggestions,
+)
+from nanobot.config.loader import get_config_path, load_config
+from nanobot.config.schema import Config
+
+console = Console()
+
+
+@dataclass
+class OnboardResult:
+ """Result of an onboarding session."""
+
+ config: Config
+ should_save: bool
+
+# --- Field Hints for Select Fields ---
+# Maps field names to (choices, hint_text)
+# To add a new select field with hints, add an entry:
+# "field_name": (["choice1", "choice2", ...], "hint text for the field")
+_SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = {
+ "reasoning_effort": (
+ ["low", "medium", "high"],
+ "low / medium / high - enables LLM thinking mode",
+ ),
+}
+
+# --- Key Bindings for Navigation ---
+
+_BACK_PRESSED = object() # Sentinel value for back navigation
+
+
+def _get_questionary():
+ """Return questionary or raise a clear error when wizard deps are unavailable."""
+ if questionary is None:
+ raise RuntimeError(
+ "Interactive onboarding requires the optional 'questionary' dependency. "
+ "Install project dependencies and rerun with --wizard."
+ )
+ return questionary
+
+
+def _select_with_back(
+ prompt: str, choices: list[str], default: str | None = None
+) -> str | None | object:
+ """Select with Escape/Left arrow support for going back.
+
+ Args:
+ prompt: The prompt text to display.
+ choices: List of choices to select from. Must not be empty.
+ default: The default choice to pre-select. If not in choices, first item is used.
+
+ Returns:
+ _BACK_PRESSED sentinel if user pressed Escape or Left arrow
+ The selected choice string if user confirmed
+ None if user cancelled (Ctrl+C)
+ """
+ from prompt_toolkit.application import Application
+ from prompt_toolkit.key_binding import KeyBindings
+ from prompt_toolkit.keys import Keys
+ from prompt_toolkit.layout import Layout
+ from prompt_toolkit.layout.containers import HSplit, Window
+ from prompt_toolkit.layout.controls import FormattedTextControl
+ from prompt_toolkit.styles import Style
+
+ # Validate choices
+ if not choices:
+ logger.warning("Empty choices list provided to _select_with_back")
+ return None
+
+ # Find default index
+ selected_index = 0
+ if default and default in choices:
+ selected_index = choices.index(default)
+
+ # State holder for the result
+ state: dict[str, str | None | object] = {"result": None}
+
+ # Build menu items (uses closure over selected_index)
+ def get_menu_text():
+ items = []
+ for i, choice in enumerate(choices):
+ if i == selected_index:
+ items.append(("class:selected", f"> {choice}\n"))
+ else:
+ items.append(("", f" {choice}\n"))
+ return items
+
+ # Create layout
+ menu_control = FormattedTextControl(get_menu_text)
+ menu_window = Window(content=menu_control, height=len(choices))
+
+ prompt_control = FormattedTextControl(lambda: [("class:question", f"> {prompt}")])
+ prompt_window = Window(content=prompt_control, height=1)
+
+ layout = Layout(HSplit([prompt_window, menu_window]))
+
+ # Key bindings
+ bindings = KeyBindings()
+
+ @bindings.add(Keys.Up)
+ def _up(event):
+ nonlocal selected_index
+ selected_index = (selected_index - 1) % len(choices)
+ event.app.invalidate()
+
+ @bindings.add(Keys.Down)
+ def _down(event):
+ nonlocal selected_index
+ selected_index = (selected_index + 1) % len(choices)
+ event.app.invalidate()
+
+ @bindings.add(Keys.Enter)
+ def _enter(event):
+ state["result"] = choices[selected_index]
+ event.app.exit()
+
+ @bindings.add("escape")
+ def _escape(event):
+ state["result"] = _BACK_PRESSED
+ event.app.exit()
+
+ @bindings.add(Keys.Left)
+ def _left(event):
+ state["result"] = _BACK_PRESSED
+ event.app.exit()
+
+ @bindings.add(Keys.ControlC)
+ def _ctrl_c(event):
+ state["result"] = None
+ event.app.exit()
+
+ # Style
+ style = Style.from_dict({
+ "selected": "fg:green bold",
+ "question": "fg:cyan",
+ })
+
+ app = Application(layout=layout, key_bindings=bindings, style=style)
+ try:
+ app.run()
+ except Exception:
+ logger.exception("Error in select prompt")
+ return None
+
+ return state["result"]
+
+# --- Type Introspection ---
+
+
+class FieldTypeInfo(NamedTuple):
+ """Result of field type introspection."""
+
+ type_name: str
+ inner_type: Any
+
+
+def _get_field_type_info(field_info) -> FieldTypeInfo:
+ """Extract field type info from Pydantic field."""
+ annotation = field_info.annotation
+ if annotation is None:
+ return FieldTypeInfo("str", None)
+
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ if origin is types.UnionType:
+ non_none_args = [a for a in args if a is not type(None)]
+ if len(non_none_args) == 1:
+ annotation = non_none_args[0]
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ _SIMPLE_TYPES: dict[type, str] = {bool: "bool", int: "int", float: "float"}
+
+ if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"):
+ return FieldTypeInfo("list", args[0] if args else str)
+ if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"):
+ return FieldTypeInfo("dict", None)
+ for py_type, name in _SIMPLE_TYPES.items():
+ if annotation is py_type:
+ return FieldTypeInfo(name, None)
+ if isinstance(annotation, type) and issubclass(annotation, BaseModel):
+ return FieldTypeInfo("model", annotation)
+ return FieldTypeInfo("str", None)
+
+
+def _get_field_display_name(field_key: str, field_info) -> str:
+ """Get display name for a field."""
+ if field_info and field_info.description:
+ return field_info.description
+ name = field_key
+ suffix_map = {
+ "_s": " (seconds)",
+ "_ms": " (ms)",
+ "_url": " URL",
+ "_path": " Path",
+ "_id": " ID",
+ "_key": " Key",
+ "_token": " Token",
+ }
+ for suffix, replacement in suffix_map.items():
+ if name.endswith(suffix):
+ name = name[: -len(suffix)] + replacement
+ break
+ return name.replace("_", " ").title()
+
+
+# --- Sensitive Field Masking ---
+
+_SENSITIVE_KEYWORDS = frozenset({"api_key", "token", "secret", "password", "credentials"})
+
+
+def _is_sensitive_field(field_name: str) -> bool:
+ """Check if a field name indicates sensitive content."""
+ return any(kw in field_name.lower() for kw in _SENSITIVE_KEYWORDS)
+
+
+def _mask_value(value: str) -> str:
+ """Mask a sensitive value, showing only the last 4 characters."""
+ if len(value) <= 4:
+ return "****"
+ return "*" * (len(value) - 4) + value[-4:]
+
+
+# --- Value Formatting ---
+
+
+def _format_value(value: Any, rich: bool = True, field_name: str = "") -> str:
+ """Single recursive entry point for safe value display. Handles any depth."""
+ if value is None or value == "" or value == {} or value == []:
+ return "[dim]not set[/dim]" if rich else "[not set]"
+ if _is_sensitive_field(field_name) and isinstance(value, str):
+ masked = _mask_value(value)
+ return f"[dim]{masked}[/dim]" if rich else masked
+ if isinstance(value, BaseModel):
+ parts = []
+ for fname, _finfo in type(value).model_fields.items():
+ fval = getattr(value, fname, None)
+ formatted = _format_value(fval, rich=False, field_name=fname)
+ if formatted != "[not set]":
+ parts.append(f"{fname}={formatted}")
+ return ", ".join(parts) if parts else ("[dim]not set[/dim]" if rich else "[not set]")
+ if isinstance(value, list):
+ return ", ".join(str(v) for v in value)
+ if isinstance(value, dict):
+ return json.dumps(value)
+ return str(value)
+
+
+def _format_value_for_input(value: Any, field_type: str) -> str:
+ """Format a value for use as input default."""
+ if value is None or value == "":
+ return ""
+ if field_type == "list" and isinstance(value, list):
+ return ",".join(str(v) for v in value)
+ if field_type == "dict" and isinstance(value, dict):
+ return json.dumps(value)
+ return str(value)
+
+
+# --- Rich UI Components ---
+
+
+def _show_config_panel(display_name: str, model: BaseModel, fields: list) -> None:
+ """Display current configuration as a rich table."""
+ table = Table(show_header=False, box=None, padding=(0, 2))
+ table.add_column("Field", style="cyan")
+ table.add_column("Value")
+
+ for fname, field_info in fields:
+ value = getattr(model, fname, None)
+ display = _get_field_display_name(fname, field_info)
+ formatted = _format_value(value, rich=True, field_name=fname)
+ table.add_row(display, formatted)
+
+ console.print(Panel(table, title=f"[bold]{display_name}[/bold]", border_style="blue"))
+
+
+def _show_main_menu_header() -> None:
+ """Display the main menu header."""
+ from nanobot import __logo__, __version__
+
+ console.print()
+ # Use Align.CENTER for the single line of text
+ from rich.align import Align
+
+ console.print(
+ Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]")
+ )
+ console.print()
+
+
+def _show_section_header(title: str, subtitle: str = "") -> None:
+ """Display a section header."""
+ console.print()
+ if subtitle:
+ console.print(
+ Panel(f"[dim]{subtitle}[/dim]", title=f"[bold]{title}[/bold]", border_style="blue")
+ )
+ else:
+ console.print(Panel("", title=f"[bold]{title}[/bold]", border_style="blue"))
+
+
+# --- Input Handlers ---
+
+
+def _input_bool(display_name: str, current: bool | None) -> bool | None:
+ """Get boolean input via confirm dialog."""
+ return _get_questionary().confirm(
+ display_name,
+ default=bool(current) if current is not None else False,
+ ).ask()
+
+
+def _input_text(display_name: str, current: Any, field_type: str) -> Any:
+ """Get text input and parse based on field type."""
+ default = _format_value_for_input(current, field_type)
+
+ value = _get_questionary().text(f"{display_name}:", default=default).ask()
+
+ if value is None or value == "":
+ return None
+
+ if field_type == "int":
+ try:
+ return int(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+ elif field_type == "float":
+ try:
+ return float(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+ elif field_type == "list":
+ return [v.strip() for v in value.split(",") if v.strip()]
+ elif field_type == "dict":
+ try:
+ return json.loads(value)
+ except json.JSONDecodeError:
+ console.print("[yellow]! Invalid JSON format, value not saved[/yellow]")
+ return None
+
+ return value
+
+
+def _input_with_existing(
+ display_name: str, current: Any, field_type: str
+) -> Any:
+ """Handle input with 'keep existing' option for non-empty values."""
+ has_existing = current is not None and current != "" and current != {} and current != []
+
+ if has_existing and not isinstance(current, list):
+ choice = _get_questionary().select(
+ display_name,
+ choices=["Enter new value", "Keep existing value"],
+ default="Keep existing value",
+ ).ask()
+ if choice == "Keep existing value" or choice is None:
+ return None
+
+ return _input_text(display_name, current, field_type)
+
+
+# --- Pydantic Model Configuration ---
+
+
+def _get_current_provider(model: BaseModel) -> str:
+ """Get the current provider setting from a model (if available)."""
+ if hasattr(model, "provider"):
+ return getattr(model, "provider", "auto") or "auto"
+ return "auto"
+
+
+def _input_model_with_autocomplete(
+ display_name: str, current: Any, provider: str
+) -> str | None:
+ """Get model input with autocomplete suggestions.
+
+ """
+ from prompt_toolkit.completion import Completer, Completion
+
+ default = str(current) if current else ""
+
+ class DynamicModelCompleter(Completer):
+ """Completer that dynamically fetches model suggestions."""
+
+ def __init__(self, provider_name: str):
+ self.provider = provider_name
+
+ def get_completions(self, document, complete_event):
+ text = document.text_before_cursor
+ suggestions = get_model_suggestions(text, provider=self.provider, limit=50)
+ for model in suggestions:
+ # Skip if model doesn't contain the typed text
+ if text.lower() not in model.lower():
+ continue
+ yield Completion(
+ model,
+ start_position=-len(text),
+ display=model,
+ )
+
+ value = _get_questionary().autocomplete(
+ f"{display_name}:",
+ choices=[""], # Placeholder, actual completions from completer
+ completer=DynamicModelCompleter(provider),
+ default=default,
+ qmark=">",
+ ).ask()
+
+ return value if value else None
+
+
+def _input_context_window_with_recommendation(
+ display_name: str, current: Any, model_obj: BaseModel
+) -> int | None:
+ """Get context window input with option to fetch recommended value."""
+ current_val = current if current else ""
+
+ choices = ["Enter new value"]
+ if current_val:
+ choices.append("Keep existing value")
+ choices.append("[?] Get recommended value")
+
+ choice = _get_questionary().select(
+ display_name,
+ choices=choices,
+ default="Enter new value",
+ ).ask()
+
+ if choice is None:
+ return None
+
+ if choice == "Keep existing value":
+ return None
+
+ if choice == "[?] Get recommended value":
+ # Get the model name from the model object
+ model_name = getattr(model_obj, "model", None)
+ if not model_name:
+ console.print("[yellow]! Please configure the model field first[/yellow]")
+ return None
+
+ provider = _get_current_provider(model_obj)
+ context_limit = get_model_context_limit(model_name, provider)
+
+ if context_limit:
+ console.print(f"[green]+ Recommended context window: {format_token_count(context_limit)} tokens[/green]")
+ return context_limit
+ else:
+ console.print("[yellow]! Could not fetch model info, please enter manually[/yellow]")
+ # Fall through to manual input
+
+ # Manual input
+ value = _get_questionary().text(
+ f"{display_name}:",
+ default=str(current_val) if current_val else "",
+ ).ask()
+
+ if value is None or value == "":
+ return None
+
+ try:
+ return int(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+
+
+def _handle_model_field(
+ working_model: BaseModel, field_name: str, field_display: str, current_value: Any
+) -> None:
+ """Handle the 'model' field with autocomplete and context-window auto-fill."""
+ provider = _get_current_provider(working_model)
+ new_value = _input_model_with_autocomplete(field_display, current_value, provider)
+ if new_value is not None and new_value != current_value:
+ setattr(working_model, field_name, new_value)
+ _try_auto_fill_context_window(working_model, new_value)
+
+
+def _handle_context_window_field(
+ working_model: BaseModel, field_name: str, field_display: str, current_value: Any
+) -> None:
+ """Handle context_window_tokens with recommendation lookup."""
+ new_value = _input_context_window_with_recommendation(
+ field_display, current_value, working_model
+ )
+ if new_value is not None:
+ setattr(working_model, field_name, new_value)
+
+
+_FIELD_HANDLERS: dict[str, Any] = {
+ "model": _handle_model_field,
+ "context_window_tokens": _handle_context_window_field,
+}
+
+
+def _configure_pydantic_model(
+ model: BaseModel,
+ display_name: str,
+ *,
+ skip_fields: set[str] | None = None,
+) -> BaseModel | None:
+ """Configure a Pydantic model interactively.
+
+ Returns the updated model only when the user explicitly selects "Done".
+ Back and cancel actions discard the section draft.
+ """
+ skip_fields = skip_fields or set()
+ working_model = model.model_copy(deep=True)
+
+ fields = [
+ (name, info)
+ for name, info in type(working_model).model_fields.items()
+ if name not in skip_fields
+ ]
+ if not fields:
+ console.print(f"[dim]{display_name}: No configurable fields[/dim]")
+ return working_model
+
+ def get_choices() -> list[str]:
+ items = []
+ for fname, finfo in fields:
+ value = getattr(working_model, fname, None)
+ display = _get_field_display_name(fname, finfo)
+ formatted = _format_value(value, rich=False, field_name=fname)
+ items.append(f"{display}: {formatted}")
+ return items + ["[Done]"]
+
+ while True:
+ console.clear()
+ _show_config_panel(display_name, working_model, fields)
+ choices = get_choices()
+ answer = _select_with_back("Select field to configure:", choices)
+
+ if answer is _BACK_PRESSED or answer is None:
+ return None
+ if answer == "[Done]":
+ return working_model
+
+ field_idx = next((i for i, c in enumerate(choices) if c == answer), -1)
+ if field_idx < 0 or field_idx >= len(fields):
+ return None
+
+ field_name, field_info = fields[field_idx]
+ current_value = getattr(working_model, field_name, None)
+ ftype = _get_field_type_info(field_info)
+ field_display = _get_field_display_name(field_name, field_info)
+
+ # Nested Pydantic model - recurse
+ if ftype.type_name == "model":
+ nested = current_value
+ created = nested is None
+ if nested is None and ftype.inner_type:
+ nested = ftype.inner_type()
+ if nested and isinstance(nested, BaseModel):
+ updated = _configure_pydantic_model(nested, field_display)
+ if updated is not None:
+ setattr(working_model, field_name, updated)
+ elif created:
+ setattr(working_model, field_name, None)
+ continue
+
+ # Registered special-field handlers
+ handler = _FIELD_HANDLERS.get(field_name)
+ if handler:
+ handler(working_model, field_name, field_display, current_value)
+ continue
+
+ # Select fields with hints (e.g. reasoning_effort)
+ if field_name in _SELECT_FIELD_HINTS:
+ choices_list, hint = _SELECT_FIELD_HINTS[field_name]
+ select_choices = choices_list + ["(clear/unset)"]
+ console.print(f"[dim] Hint: {hint}[/dim]")
+ new_value = _select_with_back(
+ field_display, select_choices, default=current_value or select_choices[0]
+ )
+ if new_value is _BACK_PRESSED:
+ continue
+ if new_value == "(clear/unset)":
+ setattr(working_model, field_name, None)
+ elif new_value is not None:
+ setattr(working_model, field_name, new_value)
+ continue
+
+ # Generic field input
+ if ftype.type_name == "bool":
+ new_value = _input_bool(field_display, current_value)
+ else:
+ new_value = _input_with_existing(field_display, current_value, ftype.type_name)
+ if new_value is not None:
+ setattr(working_model, field_name, new_value)
+
+
+def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None:
+ """Try to auto-fill context_window_tokens if it's at default value.
+
+ Note:
+ This function imports AgentDefaults from nanobot.config.schema to get
+ the default context_window_tokens value. If the schema changes, this
+ coupling needs to be updated accordingly.
+ """
+ # Check if context_window_tokens field exists
+ if not hasattr(model, "context_window_tokens"):
+ return
+
+ current_context = getattr(model, "context_window_tokens", None)
+
+ # Check if current value is the default (65536)
+ # We only auto-fill if the user hasn't changed it from default
+ from nanobot.config.schema import AgentDefaults
+
+ default_context = AgentDefaults.model_fields["context_window_tokens"].default
+
+ if current_context != default_context:
+ return # User has customized it, don't override
+
+ provider = _get_current_provider(model)
+ context_limit = get_model_context_limit(new_model_name, provider)
+
+ if context_limit:
+ setattr(model, "context_window_tokens", context_limit)
+ console.print(f"[green]+ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]")
+ else:
+ console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]")
+
+
+# --- Provider Configuration ---
+
+
+@lru_cache(maxsize=1)
+def _get_provider_info() -> dict[str, tuple[str, bool, bool, str]]:
+ """Get provider info from registry (cached)."""
+ from nanobot.providers.registry import PROVIDERS
+
+ return {
+ spec.name: (
+ spec.display_name or spec.name,
+ spec.is_gateway,
+ spec.is_local,
+ spec.default_api_base,
+ )
+ for spec in PROVIDERS
+ if not spec.is_oauth
+ }
+
+
+def _get_provider_names() -> dict[str, str]:
+ """Get provider display names."""
+ info = _get_provider_info()
+ return {name: data[0] for name, data in info.items() if name}
+
+
+def _configure_provider(config: Config, provider_name: str) -> None:
+ """Configure a single LLM provider."""
+ provider_config = getattr(config.providers, provider_name, None)
+ if provider_config is None:
+ console.print(f"[red]Unknown provider: {provider_name}[/red]")
+ return
+
+ display_name = _get_provider_names().get(provider_name, provider_name)
+ info = _get_provider_info()
+ default_api_base = info.get(provider_name, (None, None, None, None))[3]
+
+ if default_api_base and not provider_config.api_base:
+ provider_config.api_base = default_api_base
+
+ updated_provider = _configure_pydantic_model(
+ provider_config,
+ display_name,
+ )
+ if updated_provider is not None:
+ setattr(config.providers, provider_name, updated_provider)
+
+
+def _configure_providers(config: Config) -> None:
+ """Configure LLM providers."""
+
+ def get_provider_choices() -> list[str]:
+ """Build provider choices with config status indicators."""
+ choices = []
+ for name, display in _get_provider_names().items():
+ provider = getattr(config.providers, name, None)
+ if provider and provider.api_key:
+ choices.append(f"{display} *")
+ else:
+ choices.append(display)
+ return choices + ["<- Back"]
+
+ while True:
+ try:
+ console.clear()
+ _show_section_header("LLM Providers", "Select a provider to configure API key and endpoint")
+ choices = get_provider_choices()
+ answer = _select_with_back("Select provider:", choices)
+
+ if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
+ break
+
+ # Type guard: answer is now guaranteed to be a string
+ assert isinstance(answer, str)
+ # Extract provider name from choice (remove " *" suffix if present)
+ provider_name = answer.replace(" *", "")
+ # Find the actual provider key from display names
+ for name, display in _get_provider_names().items():
+ if display == provider_name:
+ _configure_provider(config, name)
+ break
+
+ except KeyboardInterrupt:
+ console.print("\n[dim]Returning to main menu...[/dim]")
+ break
+
+
+# --- Channel Configuration ---
+
+
+@lru_cache(maxsize=1)
+def _get_channel_info() -> dict[str, tuple[str, type[BaseModel]]]:
+ """Get channel info (display name + config class) from channel modules."""
+ import importlib
+
+ from nanobot.channels.registry import discover_all
+
+ result: dict[str, tuple[str, type[BaseModel]]] = {}
+ for name, channel_cls in discover_all().items():
+ try:
+ mod = importlib.import_module(f"nanobot.channels.{name}")
+ config_name = channel_cls.__name__.replace("Channel", "Config")
+ config_cls = getattr(mod, config_name, None)
+ if config_cls and isinstance(config_cls, type) and issubclass(config_cls, BaseModel):
+ display_name = getattr(channel_cls, "display_name", name.capitalize())
+ result[name] = (display_name, config_cls)
+ except Exception:
+ logger.warning(f"Failed to load channel module: {name}")
+ return result
+
+
+def _get_channel_names() -> dict[str, str]:
+ """Get channel display names."""
+ return {name: info[0] for name, info in _get_channel_info().items()}
+
+
+def _get_channel_config_class(channel: str) -> type[BaseModel] | None:
+ """Get channel config class."""
+ entry = _get_channel_info().get(channel)
+ return entry[1] if entry else None
+
+
+def _configure_channel(config: Config, channel_name: str) -> None:
+ """Configure a single channel."""
+ channel_dict = getattr(config.channels, channel_name, None)
+ if channel_dict is None:
+ channel_dict = {}
+ setattr(config.channels, channel_name, channel_dict)
+
+ display_name = _get_channel_names().get(channel_name, channel_name)
+ config_cls = _get_channel_config_class(channel_name)
+
+ if config_cls is None:
+ console.print(f"[red]No configuration class found for {display_name}[/red]")
+ return
+
+ model = config_cls.model_validate(channel_dict) if channel_dict else config_cls()
+
+ updated_channel = _configure_pydantic_model(
+ model,
+ display_name,
+ )
+ if updated_channel is not None:
+ new_dict = updated_channel.model_dump(by_alias=True, exclude_none=True)
+ setattr(config.channels, channel_name, new_dict)
+
+
+def _configure_channels(config: Config) -> None:
+ """Configure chat channels."""
+ channel_names = list(_get_channel_names().keys())
+ choices = channel_names + ["<- Back"]
+
+ while True:
+ try:
+ console.clear()
+ _show_section_header("Chat Channels", "Select a channel to configure connection settings")
+ answer = _select_with_back("Select channel:", choices)
+
+ if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
+ break
+
+ # Type guard: answer is now guaranteed to be a string
+ assert isinstance(answer, str)
+ _configure_channel(config, answer)
+ except KeyboardInterrupt:
+ console.print("\n[dim]Returning to main menu...[/dim]")
+ break
+
+
+# --- General Settings ---
+
+_SETTINGS_SECTIONS: dict[str, tuple[str, str, set[str] | None]] = {
+ "Agent Settings": ("Agent Defaults", "Configure default model, temperature, and behavior", None),
+ "Gateway": ("Gateway Settings", "Configure server host, port, and heartbeat", None),
+ "Tools": ("Tools Settings", "Configure web search, shell exec, and other tools", {"mcp_servers"}),
+}
+
+_SETTINGS_GETTER = {
+ "Agent Settings": lambda c: c.agents.defaults,
+ "Gateway": lambda c: c.gateway,
+ "Tools": lambda c: c.tools,
+}
+
+_SETTINGS_SETTER = {
+ "Agent Settings": lambda c, v: setattr(c.agents, "defaults", v),
+ "Gateway": lambda c, v: setattr(c, "gateway", v),
+ "Tools": lambda c, v: setattr(c, "tools", v),
+}
+
+
+def _configure_general_settings(config: Config, section: str) -> None:
+ """Configure a general settings section (header + model edit + writeback)."""
+ meta = _SETTINGS_SECTIONS.get(section)
+ if not meta:
+ return
+ display_name, subtitle, skip = meta
+ model = _SETTINGS_GETTER[section](config)
+ updated = _configure_pydantic_model(model, display_name, skip_fields=skip)
+ if updated is not None:
+ _SETTINGS_SETTER[section](config, updated)
+
+
+# --- Summary ---
+
+
+def _summarize_model(obj: BaseModel) -> list[tuple[str, str]]:
+ """Recursively summarize a Pydantic model. Returns list of (field, value) tuples."""
+ items: list[tuple[str, str]] = []
+ for field_name, field_info in type(obj).model_fields.items():
+ value = getattr(obj, field_name, None)
+ if value is None or value == "" or value == {} or value == []:
+ continue
+ display = _get_field_display_name(field_name, field_info)
+ ftype = _get_field_type_info(field_info)
+ if ftype.type_name == "model" and isinstance(value, BaseModel):
+ for nested_field, nested_value in _summarize_model(value):
+ items.append((f"{display}.{nested_field}", nested_value))
+ continue
+ formatted = _format_value(value, rich=False, field_name=field_name)
+ if formatted != "[not set]":
+ items.append((display, formatted))
+ return items
+
+
+def _print_summary_panel(rows: list[tuple[str, str]], title: str) -> None:
+ """Build a two-column summary panel and print it."""
+ if not rows:
+ return
+ table = Table(show_header=False, box=None, padding=(0, 2))
+ table.add_column("Setting", style="cyan")
+ table.add_column("Value")
+ for field, value in rows:
+ table.add_row(field, value)
+ console.print(Panel(table, title=f"[bold]{title}[/bold]", border_style="blue"))
+
+
+def _show_summary(config: Config) -> None:
+ """Display configuration summary using rich."""
+ console.print()
+
+ # Providers
+ provider_rows = []
+ for name, display in _get_provider_names().items():
+ provider = getattr(config.providers, name, None)
+ status = "[green]configured[/green]" if (provider and provider.api_key) else "[dim]not configured[/dim]"
+ provider_rows.append((display, status))
+ _print_summary_panel(provider_rows, "LLM Providers")
+
+ # Channels
+ channel_rows = []
+ for name, display in _get_channel_names().items():
+ channel = getattr(config.channels, name, None)
+ if channel:
+ enabled = (
+ channel.get("enabled", False)
+ if isinstance(channel, dict)
+ else getattr(channel, "enabled", False)
+ )
+ status = "[green]enabled[/green]" if enabled else "[dim]disabled[/dim]"
+ else:
+ status = "[dim]not configured[/dim]"
+ channel_rows.append((display, status))
+ _print_summary_panel(channel_rows, "Chat Channels")
+
+ # Settings sections
+ for title, model in [
+ ("Agent Settings", config.agents.defaults),
+ ("Gateway", config.gateway),
+ ("Tools", config.tools),
+ ("Channel Common", config.channels),
+ ]:
+ _print_summary_panel(_summarize_model(model), title)
+
+
+# --- Main Entry Point ---
+
+
+def _has_unsaved_changes(original: Config, current: Config) -> bool:
+ """Return True when the onboarding session has committed changes."""
+ return original.model_dump(by_alias=True) != current.model_dump(by_alias=True)
+
+
+def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str:
+ """Resolve how to leave the main menu."""
+ if not has_unsaved_changes:
+ return "discard"
+
+ answer = _get_questionary().select(
+ "You have unsaved changes. What would you like to do?",
+ choices=[
+ "[S] Save and Exit",
+ "[X] Exit Without Saving",
+ "[R] Resume Editing",
+ ],
+ default="[R] Resume Editing",
+ qmark=">",
+ ).ask()
+
+ if answer == "[S] Save and Exit":
+ return "save"
+ if answer == "[X] Exit Without Saving":
+ return "discard"
+ return "resume"
+
+
+def run_onboard(initial_config: Config | None = None) -> OnboardResult:
+ """Run the interactive onboarding questionnaire.
+
+ Args:
+ initial_config: Optional pre-loaded config to use as starting point.
+ If None, loads from config file or creates new default.
+ """
+ _get_questionary()
+
+ if initial_config is not None:
+ base_config = initial_config.model_copy(deep=True)
+ else:
+ config_path = get_config_path()
+ if config_path.exists():
+ base_config = load_config()
+ else:
+ base_config = Config()
+
+ original_config = base_config.model_copy(deep=True)
+ config = base_config.model_copy(deep=True)
+
+ while True:
+ console.clear()
+ _show_main_menu_header()
+
+ try:
+ answer = _get_questionary().select(
+ "What would you like to configure?",
+ choices=[
+ "[P] LLM Provider",
+ "[C] Chat Channel",
+ "[A] Agent Settings",
+ "[G] Gateway",
+ "[T] Tools",
+ "[V] View Configuration Summary",
+ "[S] Save and Exit",
+ "[X] Exit Without Saving",
+ ],
+ qmark=">",
+ ).ask()
+ except KeyboardInterrupt:
+ answer = None
+
+ if answer is None:
+ action = _prompt_main_menu_exit(_has_unsaved_changes(original_config, config))
+ if action == "save":
+ return OnboardResult(config=config, should_save=True)
+ if action == "discard":
+ return OnboardResult(config=original_config, should_save=False)
+ continue
+
+ _MENU_DISPATCH = {
+ "[P] LLM Provider": lambda: _configure_providers(config),
+ "[C] Chat Channel": lambda: _configure_channels(config),
+ "[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"),
+ "[G] Gateway": lambda: _configure_general_settings(config, "Gateway"),
+ "[T] Tools": lambda: _configure_general_settings(config, "Tools"),
+ "[V] View Configuration Summary": lambda: _show_summary(config),
+ }
+
+ if answer == "[S] Save and Exit":
+ return OnboardResult(config=config, should_save=True)
+ if answer == "[X] Exit Without Saving":
+ return OnboardResult(config=original_config, should_save=False)
+
+ action_fn = _MENU_DISPATCH.get(answer)
+ if action_fn:
+ action_fn()
diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py
index 7d309e5..7095646 100644
--- a/nanobot/config/loader.py
+++ b/nanobot/config/loader.py
@@ -3,8 +3,10 @@
import json
from pathlib import Path
-from nanobot.config.schema import Config
+import pydantic
+from loguru import logger
+from nanobot.config.schema import Config
# Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None
@@ -41,9 +43,9 @@ def load_config(config_path: Path | None = None) -> Config:
data = json.load(f)
data = _migrate_config(data)
return Config.model_validate(data)
- except (json.JSONDecodeError, ValueError) as e:
- print(f"Warning: Failed to load config from {path}: {e}")
- print("Using default configuration.")
+ except (json.JSONDecodeError, ValueError, pydantic.ValidationError) as e:
+ logger.warning(f"Failed to load config from {path}: {e}")
+ logger.warning("Using default configuration.")
return Config()
@@ -59,7 +61,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True)
- data = config.model_dump(by_alias=True)
+ data = config.model_dump(mode="json", by_alias=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index c067231..c884433 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -38,14 +38,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536
temperature: float = 0.1
max_tool_iterations: int = 40
- # Deprecated compatibility field: accepted from old configs but ignored at runtime.
- memory_window: int | None = Field(default=None, exclude=True)
- reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
-
- @property
- def should_warn_deprecated_memory_window(self) -> bool:
- """Return True when old memoryWindow is present without contextWindowTokens."""
- return self.memory_window is not None and "context_window_tokens" not in self.model_fields_set
+ reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
class AgentsConfig(Base):
@@ -85,8 +78,8 @@ class ProvidersConfig(Base):
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
- openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
- github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
+ openai_codex: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # OpenAI Codex (OAuth)
+ github_copilot: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # GitHub Copilot (OAuth)
class HeartbeatConfig(Base):
@@ -125,10 +118,10 @@ class WebToolsConfig(Base):
class ExecToolConfig(Base):
"""Shell exec tool configuration."""
+ enable: bool = True
timeout: int = 60
path_append: str = ""
-
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py
index 1ed71f0..c956b89 100644
--- a/nanobot/cron/service.py
+++ b/nanobot/cron/service.py
@@ -10,7 +10,7 @@ from typing import Any, Callable, Coroutine
from loguru import logger
-from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule, CronStore
+from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
def _now_ms() -> int:
@@ -63,10 +63,12 @@ def _validate_schedule_for_add(schedule: CronSchedule) -> None:
class CronService:
"""Service for managing and executing scheduled jobs."""
+ _MAX_RUN_HISTORY = 20
+
def __init__(
self,
store_path: Path,
- on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
+ on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
):
self.store_path = store_path
self.on_job = on_job
@@ -113,6 +115,15 @@ class CronService:
last_run_at_ms=j.get("state", {}).get("lastRunAtMs"),
last_status=j.get("state", {}).get("lastStatus"),
last_error=j.get("state", {}).get("lastError"),
+ run_history=[
+ CronRunRecord(
+ run_at_ms=r["runAtMs"],
+ status=r["status"],
+ duration_ms=r.get("durationMs", 0),
+ error=r.get("error"),
+ )
+ for r in j.get("state", {}).get("runHistory", [])
+ ],
),
created_at_ms=j.get("createdAtMs", 0),
updated_at_ms=j.get("updatedAtMs", 0),
@@ -160,6 +171,15 @@ class CronService:
"lastRunAtMs": j.state.last_run_at_ms,
"lastStatus": j.state.last_status,
"lastError": j.state.last_error,
+ "runHistory": [
+ {
+ "runAtMs": r.run_at_ms,
+ "status": r.status,
+ "durationMs": r.duration_ms,
+ "error": r.error,
+ }
+ for r in j.state.run_history
+ ],
},
"createdAtMs": j.created_at_ms,
"updatedAtMs": j.updated_at_ms,
@@ -248,9 +268,8 @@ class CronService:
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
try:
- response = None
if self.on_job:
- response = await self.on_job(job)
+ await self.on_job(job)
job.state.last_status = "ok"
job.state.last_error = None
@@ -261,8 +280,17 @@ class CronService:
job.state.last_error = str(e)
logger.error("Cron: job '{}' failed: {}", job.name, e)
+ end_ms = _now_ms()
job.state.last_run_at_ms = start_ms
- job.updated_at_ms = _now_ms()
+ job.updated_at_ms = end_ms
+
+ job.state.run_history.append(CronRunRecord(
+ run_at_ms=start_ms,
+ status=job.state.last_status,
+ duration_ms=end_ms - start_ms,
+ error=job.state.last_error,
+ ))
+ job.state.run_history = job.state.run_history[-self._MAX_RUN_HISTORY:]
# Handle one-shot jobs
if job.schedule.kind == "at":
@@ -366,6 +394,11 @@ class CronService:
return True
return False
+ def get_job(self, job_id: str) -> CronJob | None:
+ """Get a job by ID."""
+ store = self._load_store()
+ return next((j for j in store.jobs if j.id == job_id), None)
+
def status(self) -> dict:
"""Get service status."""
store = self._load_store()
diff --git a/nanobot/cron/types.py b/nanobot/cron/types.py
index 2b42060..e7b2c43 100644
--- a/nanobot/cron/types.py
+++ b/nanobot/cron/types.py
@@ -29,6 +29,15 @@ class CronPayload:
to: str | None = None # e.g. phone number
+@dataclass
+class CronRunRecord:
+ """A single execution record for a cron job."""
+ run_at_ms: int
+ status: Literal["ok", "error", "skipped"]
+ duration_ms: int = 0
+ error: str | None = None
+
+
@dataclass
class CronJobState:
"""Runtime state of a job."""
@@ -36,6 +45,7 @@ class CronJobState:
last_run_at_ms: int | None = None
last_status: Literal["ok", "error", "skipped"] | None = None
last_error: str | None = None
+ run_history: list[CronRunRecord] = field(default_factory=list)
@dataclass
diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py
index 4bdeb54..3daa0cc 100644
--- a/nanobot/providers/custom_provider.py
+++ b/nanobot/providers/custom_provider.py
@@ -51,6 +51,12 @@ class CustomProvider(LLMProvider):
try:
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
+ # JSONDecodeError.doc / APIError.response.text may carry the raw body
+ # (e.g. "unsupported model: xxx") which is far more useful than the
+ # generic "Expecting value …" message. Truncate to avoid huge HTML pages.
+ body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
+ if body and body.strip():
+ return LLMResponse(content=f"Error: {body.strip()[:500]}", finish_reason="error")
return LLMResponse(content=f"Error: {e}", finish_reason="error")
def _parse(self, response: Any) -> LLMResponse:
diff --git a/pyproject.toml b/pyproject.toml
index 25ef590..75e0893 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"qq-botpy>=1.2.0,<2.0.0",
"python-socks[asyncio]>=2.8.0,<3.0.0",
"prompt-toolkit>=3.0.50,<4.0.0",
+ "questionary>=2.0.0,<3.0.0",
"mcp>=1.26.0,<2.0.0",
"json-repair>=0.57.0,<1.0.0",
"chardet>=3.0.2,<6.0.0",
diff --git a/tests/test_commands.py b/tests/test_commands.py
index a820e77..124802e 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -1,6 +1,5 @@
import json
import re
-import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -13,19 +12,18 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
-
-def _strip_ansi(text):
- """Remove ANSI escape codes from text."""
- ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
- return ansi_escape.sub('', text)
-
runner = CliRunner()
-class _StopGateway(RuntimeError):
+class _StopGatewayError(RuntimeError):
pass
+import shutil
+
+import pytest
+
+
@pytest.fixture
def mock_paths():
"""Mock config/workspace paths for test isolation."""
@@ -117,6 +115,12 @@ def test_onboard_existing_workspace_safe_create(mock_paths):
assert (workspace_dir / "AGENTS.md").exists()
+def _strip_ansi(text):
+ """Remove ANSI escape codes from text."""
+ ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
+ return ansi_escape.sub('', text)
+
+
def test_onboard_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["onboard", "--help"])
@@ -126,9 +130,28 @@ def test_onboard_help_shows_workspace_and_config_options():
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
+ assert "--wizard" in stripped_output
assert "--dir" not in stripped_output
+def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
+ config_file, workspace_dir, _ = mock_paths
+
+ from nanobot.cli.onboard_wizard import OnboardResult
+
+ monkeypatch.setattr(
+ "nanobot.cli.onboard_wizard.run_onboard",
+ lambda initial_config: OnboardResult(config=initial_config, should_save=False),
+ )
+
+ result = runner.invoke(app, ["onboard", "--wizard"])
+
+ assert result.exit_code == 0
+ assert "No changes were saved" in result.stdout
+ assert not config_file.exists()
+ assert not workspace_dir.exists()
+
+
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace"
@@ -151,6 +174,31 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch)
assert f"--config {resolved_config}" in compact_output
+def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch):
+ config_path = tmp_path / "instance" / "config.json"
+ workspace_path = tmp_path / "workspace"
+
+ from nanobot.cli.onboard_wizard import OnboardResult
+
+ monkeypatch.setattr(
+ "nanobot.cli.onboard_wizard.run_onboard",
+ lambda initial_config: OnboardResult(config=initial_config, should_save=True),
+ )
+ monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
+
+ result = runner.invoke(
+ app,
+ ["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)],
+ )
+
+ assert result.exit_code == 0
+ stripped_output = _strip_ansi(result.stdout)
+ compact_output = stripped_output.replace("\n", "")
+ resolved_config = str(config_path.resolve())
+ assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output
+ assert f"nanobot gateway --config {resolved_config}" in compact_output
+
+
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
@@ -165,6 +213,15 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex"
+def test_config_dump_excludes_oauth_provider_blocks():
+ config = Config()
+
+ providers = config.model_dump(by_alias=True)["providers"]
+
+ assert "openaiCodex" not in providers
+ assert "githubCopilot" not in providers
+
+
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
@@ -404,14 +461,15 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
-def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
- mock_agent_runtime["config"].agents.defaults.memory_window = 100
+def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
+ config_file = tmp_path / "config.json"
+ config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))
- result = runner.invoke(app, ["agent", "-m", "hello"])
+ result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
- assert "contextWindowTokens" in result.stdout
+ assert "no longer used" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
@@ -434,12 +492,12 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
- lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
- assert isinstance(result.exception, _StopGateway)
+ assert isinstance(result.exception, _StopGatewayError)
assert seen["config_path"] == config_file.resolve()
assert seen["workspace"] == Path(config.agents.defaults.workspace)
@@ -462,7 +520,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
- lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(
@@ -470,33 +528,11 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
["gateway", "--config", str(config_file), "--workspace", str(override)],
)
- assert isinstance(result.exception, _StopGateway)
+ assert isinstance(result.exception, _StopGatewayError)
assert seen["workspace"] == override
assert config.workspace_path == override
-def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
- config_file = tmp_path / "instance" / "config.json"
- config_file.parent.mkdir(parents=True)
- config_file.write_text("{}")
-
- config = Config()
- config.agents.defaults.memory_window = 100
-
- monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
- monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
- monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
- monkeypatch.setattr(
- "nanobot.cli.commands._make_provider",
- lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
- )
-
- result = runner.invoke(app, ["gateway", "--config", str(config_file)])
-
- assert isinstance(result.exception, _StopGateway)
- assert "memoryWindow" in result.stdout
- assert "contextWindowTokens" in result.stdout
-
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -517,13 +553,13 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
- raise _StopGateway("stop")
+ raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
- assert isinstance(result.exception, _StopGateway)
+ assert isinstance(result.exception, _StopGatewayError)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
@@ -540,12 +576,12 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
- lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
- assert isinstance(result.exception, _StopGateway)
+ assert isinstance(result.exception, _StopGatewayError)
assert "port 18791" in result.stdout
@@ -562,10 +598,10 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
- lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
+ lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
- assert isinstance(result.exception, _StopGateway)
+ assert isinstance(result.exception, _StopGatewayError)
assert "port 18792" in result.stdout
diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py
index 2a446b7..c1c9510 100644
--- a/tests/test_config_migration.py
+++ b/tests/test_config_migration.py
@@ -1,15 +1,9 @@
import json
-from types import SimpleNamespace
-from typer.testing import CliRunner
-
-from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
-runner = CliRunner()
-
-def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
+def test_load_config_keeps_max_tokens_and_ignores_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
@@ -29,7 +23,7 @@ def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
- assert config.agents.defaults.should_warn_deprecated_memory_window is True
+ assert not hasattr(config.agents.defaults, "memory_window")
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
@@ -58,7 +52,7 @@ def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path
assert "memoryWindow" not in defaults
-def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
+def test_onboard_does_not_crash_with_legacy_memory_window(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
@@ -78,18 +72,17 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
+ from typer.testing import CliRunner
+ from nanobot.cli.commands import app
+ runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
- assert "contextWindowTokens" in result.stdout
- saved = json.loads(config_path.read_text(encoding="utf-8"))
- defaults = saved["agents"]["defaults"]
- assert defaults["maxTokens"] == 3333
- assert defaults["contextWindowTokens"] == 65_536
- assert "memoryWindow" not in defaults
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
+ from types import SimpleNamespace
+
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
@@ -125,6 +118,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
},
)
+ from typer.testing import CliRunner
+ from nanobot.cli.commands import app
+ runner = CliRunner()
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py
index 21e1e78..4f2e8f1 100644
--- a/tests/test_consolidate_offset.py
+++ b/tests/test_consolidate_offset.py
@@ -182,7 +182,7 @@ class TestConsolidationTriggerConditions:
"""Test consolidation trigger conditions and logic."""
def test_consolidation_needed_when_messages_exceed_window(self):
- """Test consolidation logic: should trigger when messages > memory_window."""
+ """Test consolidation logic: should trigger when messages exceed the window."""
session = create_session_with_messages("test:trigger", 60)
total_messages = len(session.messages)
diff --git a/tests/test_cron_service.py b/tests/test_cron_service.py
index 9631da5..175c5eb 100644
--- a/tests/test_cron_service.py
+++ b/tests/test_cron_service.py
@@ -1,4 +1,5 @@
import asyncio
+import json
import pytest
@@ -32,6 +33,87 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.state.next_run_at_ms is not None
+@pytest.mark.asyncio
+async def test_execute_job_records_run_history(tmp_path) -> None:
+ store_path = tmp_path / "cron" / "jobs.json"
+ service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
+ job = service.add_job(
+ name="hist",
+ schedule=CronSchedule(kind="every", every_ms=60_000),
+ message="hello",
+ )
+ await service.run_job(job.id)
+
+ loaded = service.get_job(job.id)
+ assert loaded is not None
+ assert len(loaded.state.run_history) == 1
+ rec = loaded.state.run_history[0]
+ assert rec.status == "ok"
+ assert rec.duration_ms >= 0
+ assert rec.error is None
+
+
+@pytest.mark.asyncio
+async def test_run_history_records_errors(tmp_path) -> None:
+ store_path = tmp_path / "cron" / "jobs.json"
+
+ async def fail(_):
+ raise RuntimeError("boom")
+
+ service = CronService(store_path, on_job=fail)
+ job = service.add_job(
+ name="fail",
+ schedule=CronSchedule(kind="every", every_ms=60_000),
+ message="hello",
+ )
+ await service.run_job(job.id)
+
+ loaded = service.get_job(job.id)
+ assert len(loaded.state.run_history) == 1
+ assert loaded.state.run_history[0].status == "error"
+ assert loaded.state.run_history[0].error == "boom"
+
+
+@pytest.mark.asyncio
+async def test_run_history_trimmed_to_max(tmp_path) -> None:
+ store_path = tmp_path / "cron" / "jobs.json"
+ service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
+ job = service.add_job(
+ name="trim",
+ schedule=CronSchedule(kind="every", every_ms=60_000),
+ message="hello",
+ )
+ for _ in range(25):
+ await service.run_job(job.id)
+
+ loaded = service.get_job(job.id)
+ assert len(loaded.state.run_history) == CronService._MAX_RUN_HISTORY
+
+
+@pytest.mark.asyncio
+async def test_run_history_persisted_to_disk(tmp_path) -> None:
+ store_path = tmp_path / "cron" / "jobs.json"
+ service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
+ job = service.add_job(
+ name="persist",
+ schedule=CronSchedule(kind="every", every_ms=60_000),
+ message="hello",
+ )
+ await service.run_job(job.id)
+
+ raw = json.loads(store_path.read_text())
+ history = raw["jobs"][0]["state"]["runHistory"]
+ assert len(history) == 1
+ assert history[0]["status"] == "ok"
+ assert "runAtMs" in history[0]
+ assert "durationMs" in history[0]
+
+ fresh = CronService(store_path)
+ loaded = fresh.get_job(job.id)
+ assert len(loaded.state.run_history) == 1
+ assert loaded.state.run_history[0].status == "ok"
+
+
@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
diff --git a/tests/test_email_channel.py b/tests/test_email_channel.py
index c037ace..23d3ea7 100644
--- a/tests/test_email_channel.py
+++ b/tests/test_email_channel.py
@@ -1,5 +1,6 @@
from email.message import EmailMessage
from datetime import date
+import imaplib
import pytest
@@ -82,6 +83,120 @@ def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
assert items_again == []
+def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
+ raw = _make_raw_email(subject="Invoice", body="Please pay")
+ fail_once = {"pending": True}
+
+ class FlakyIMAP:
+ def __init__(self) -> None:
+ self.store_calls: list[tuple[bytes, str, str]] = []
+ self.search_calls = 0
+
+ def login(self, _user: str, _pw: str):
+ return "OK", [b"logged in"]
+
+ def select(self, _mailbox: str):
+ return "OK", [b"1"]
+
+ def search(self, *_args):
+ self.search_calls += 1
+ if fail_once["pending"]:
+ fail_once["pending"] = False
+ raise imaplib.IMAP4.abort("socket error")
+ return "OK", [b"1"]
+
+ def fetch(self, _imap_id: bytes, _parts: str):
+ return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]
+
+ def store(self, imap_id: bytes, op: str, flags: str):
+ self.store_calls.append((imap_id, op, flags))
+ return "OK", [b""]
+
+ def logout(self):
+ return "BYE", [b""]
+
+ fake_instances: list[FlakyIMAP] = []
+
+ def _factory(_host: str, _port: int):
+ instance = FlakyIMAP()
+ fake_instances.append(instance)
+ return instance
+
+ monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)
+
+ channel = EmailChannel(_make_config(), MessageBus())
+ items = channel._fetch_new_messages()
+
+ assert len(items) == 1
+ assert len(fake_instances) == 2
+ assert fake_instances[0].search_calls == 1
+ assert fake_instances[1].search_calls == 1
+
+
+def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
+ raw_first = _make_raw_email(subject="First", body="First body")
+ raw_second = _make_raw_email(subject="Second", body="Second body")
+ mailbox_state = {
+ b"1": {"uid": b"123", "raw": raw_first, "seen": False},
+ b"2": {"uid": b"124", "raw": raw_second, "seen": False},
+ }
+ fail_once = {"pending": True}
+
+ class FlakyIMAP:
+ def login(self, _user: str, _pw: str):
+ return "OK", [b"logged in"]
+
+ def select(self, _mailbox: str):
+ return "OK", [b"2"]
+
+ def search(self, *_args):
+ unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
+ return "OK", [b" ".join(unseen_ids)]
+
+ def fetch(self, imap_id: bytes, _parts: str):
+ if imap_id == b"2" and fail_once["pending"]:
+ fail_once["pending"] = False
+ raise imaplib.IMAP4.abort("socket error")
+ item = mailbox_state[imap_id]
+ header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
+ return "OK", [(header, item["raw"]), b")"]
+
+ def store(self, imap_id: bytes, _op: str, _flags: str):
+ mailbox_state[imap_id]["seen"] = True
+ return "OK", [b""]
+
+ def logout(self):
+ return "BYE", [b""]
+
+ monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())
+
+ channel = EmailChannel(_make_config(), MessageBus())
+ items = channel._fetch_new_messages()
+
+ assert [item["subject"] for item in items] == ["First", "Second"]
+
+
+def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
+ class MissingMailboxIMAP:
+ def login(self, _user: str, _pw: str):
+ return "OK", [b"logged in"]
+
+ def select(self, _mailbox: str):
+ raise imaplib.IMAP4.error("Mailbox doesn't exist")
+
+ def logout(self):
+ return "BYE", [b""]
+
+ monkeypatch.setattr(
+ "nanobot.channels.email.imaplib.IMAP4_SSL",
+ lambda _h, _p: MissingMailboxIMAP(),
+ )
+
+ channel = EmailChannel(_make_config(), MessageBus())
+
+ assert channel._fetch_new_messages() == []
+
+
def test_extract_text_body_falls_back_to_html() -> None:
msg = EmailMessage()
msg["From"] = "alice@example.com"
diff --git a/tests/test_onboard_logic.py b/tests/test_onboard_logic.py
new file mode 100644
index 0000000..9e0f6f7
--- /dev/null
+++ b/tests/test_onboard_logic.py
@@ -0,0 +1,495 @@
+"""Unit tests for onboard core logic functions.
+
+These tests focus on the business logic behind the onboard wizard,
+without testing the interactive UI components.
+"""
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, cast
+
+import pytest
+from pydantic import BaseModel, Field
+
+from nanobot.cli import onboard_wizard
+
+# Import functions to test
+from nanobot.cli.commands import _merge_missing_defaults
+from nanobot.cli.onboard_wizard import (
+ _BACK_PRESSED,
+ _configure_pydantic_model,
+ _format_value,
+ _get_field_display_name,
+ _get_field_type_info,
+ run_onboard,
+)
+from nanobot.config.schema import Config
+from nanobot.utils.helpers import sync_workspace_templates
+
+
+class TestMergeMissingDefaults:
+ """Tests for _merge_missing_defaults recursive config merging."""
+
+ def test_adds_missing_top_level_keys(self):
+ existing = {"a": 1}
+ defaults = {"a": 1, "b": 2, "c": 3}
+
+ result = _merge_missing_defaults(existing, defaults)
+
+ assert result == {"a": 1, "b": 2, "c": 3}
+
+ def test_preserves_existing_values(self):
+ existing = {"a": "custom_value"}
+ defaults = {"a": "default_value"}
+
+ result = _merge_missing_defaults(existing, defaults)
+
+ assert result == {"a": "custom_value"}
+
+ def test_merges_nested_dicts_recursively(self):
+ existing = {
+ "level1": {
+ "level2": {
+ "existing": "kept",
+ }
+ }
+ }
+ defaults = {
+ "level1": {
+ "level2": {
+ "existing": "replaced",
+ "added": "new",
+ },
+ "level2b": "also_new",
+ }
+ }
+
+ result = _merge_missing_defaults(existing, defaults)
+
+ assert result == {
+ "level1": {
+ "level2": {
+ "existing": "kept",
+ "added": "new",
+ },
+ "level2b": "also_new",
+ }
+ }
+
+ def test_returns_existing_if_not_dict(self):
+ assert _merge_missing_defaults("string", {"a": 1}) == "string"
+ assert _merge_missing_defaults([1, 2, 3], {"a": 1}) == [1, 2, 3]
+ assert _merge_missing_defaults(None, {"a": 1}) is None
+ assert _merge_missing_defaults(42, {"a": 1}) == 42
+
+ def test_returns_existing_if_defaults_not_dict(self):
+ assert _merge_missing_defaults({"a": 1}, "string") == {"a": 1}
+ assert _merge_missing_defaults({"a": 1}, None) == {"a": 1}
+
+ def test_handles_empty_dicts(self):
+ assert _merge_missing_defaults({}, {"a": 1}) == {"a": 1}
+ assert _merge_missing_defaults({"a": 1}, {}) == {"a": 1}
+ assert _merge_missing_defaults({}, {}) == {}
+
+ def test_backfills_channel_config(self):
+ """Real-world scenario: backfill missing channel fields."""
+ existing_channel = {
+ "enabled": False,
+ "appId": "",
+ "secret": "",
+ }
+ default_channel = {
+ "enabled": False,
+ "appId": "",
+ "secret": "",
+ "msgFormat": "plain",
+ "allowFrom": [],
+ }
+
+ result = _merge_missing_defaults(existing_channel, default_channel)
+
+ assert result["msgFormat"] == "plain"
+ assert result["allowFrom"] == []
+
+
+class TestGetFieldTypeInfo:
+ """Tests for _get_field_type_info type extraction."""
+
+ def test_extracts_str_type(self):
+ class Model(BaseModel):
+ field: str
+
+ type_name, inner = _get_field_type_info(Model.model_fields["field"])
+ assert type_name == "str"
+ assert inner is None
+
+ def test_extracts_int_type(self):
+ class Model(BaseModel):
+ count: int
+
+ type_name, inner = _get_field_type_info(Model.model_fields["count"])
+ assert type_name == "int"
+ assert inner is None
+
+ def test_extracts_bool_type(self):
+ class Model(BaseModel):
+ enabled: bool
+
+ type_name, inner = _get_field_type_info(Model.model_fields["enabled"])
+ assert type_name == "bool"
+ assert inner is None
+
+ def test_extracts_float_type(self):
+ class Model(BaseModel):
+ ratio: float
+
+ type_name, inner = _get_field_type_info(Model.model_fields["ratio"])
+ assert type_name == "float"
+ assert inner is None
+
+ def test_extracts_list_type_with_item_type(self):
+ class Model(BaseModel):
+ items: list[str]
+
+ type_name, inner = _get_field_type_info(Model.model_fields["items"])
+ assert type_name == "list"
+ assert inner is str
+
+ def test_extracts_list_type_without_item_type(self):
+ # Plain list without type param falls back to str
+ class Model(BaseModel):
+ items: list # type: ignore
+
+ # Plain list annotation doesn't match list check, returns str
+ type_name, inner = _get_field_type_info(Model.model_fields["items"])
+ assert type_name == "str" # Falls back to str for untyped list
+ assert inner is None
+
+ def test_extracts_dict_type(self):
+ # Plain dict without type param falls back to str
+ class Model(BaseModel):
+ data: dict # type: ignore
+
+ # Plain dict annotation doesn't match dict check, returns str
+ type_name, inner = _get_field_type_info(Model.model_fields["data"])
+ assert type_name == "str" # Falls back to str for untyped dict
+ assert inner is None
+
+ def test_extracts_optional_type(self):
+ class Model(BaseModel):
+ optional: str | None = None
+
+ type_name, inner = _get_field_type_info(Model.model_fields["optional"])
+ # Should unwrap Optional and get str
+ assert type_name == "str"
+ assert inner is None
+
+ def test_extracts_nested_model_type(self):
+ class Inner(BaseModel):
+ x: int
+
+ class Outer(BaseModel):
+ nested: Inner
+
+ type_name, inner = _get_field_type_info(Outer.model_fields["nested"])
+ assert type_name == "model"
+ assert inner is Inner
+
+ def test_handles_none_annotation(self):
+ """Field with None annotation defaults to str."""
+ class Model(BaseModel):
+ field: Any = None
+
+ # Create a mock field_info with None annotation
+ field_info = SimpleNamespace(annotation=None)
+ type_name, inner = _get_field_type_info(field_info)
+ assert type_name == "str"
+ assert inner is None
+
+
+class TestGetFieldDisplayName:
+ """Tests for _get_field_display_name human-readable name generation."""
+
+ def test_uses_description_if_present(self):
+ class Model(BaseModel):
+ api_key: str = Field(description="API Key for authentication")
+
+ name = _get_field_display_name("api_key", Model.model_fields["api_key"])
+ assert name == "API Key for authentication"
+
+ def test_converts_snake_case_to_title(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("user_name", field_info)
+ assert name == "User Name"
+
+ def test_adds_url_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("api_url", field_info)
+ # Title case: "Api Url"
+ assert "Url" in name and "Api" in name
+
+ def test_adds_path_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("file_path", field_info)
+ assert "Path" in name and "File" in name
+
+ def test_adds_id_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("user_id", field_info)
+ # Title case: "User Id"
+ assert "Id" in name and "User" in name
+
+ def test_adds_key_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("api_key", field_info)
+ assert "Key" in name and "Api" in name
+
+ def test_adds_token_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("auth_token", field_info)
+ assert "Token" in name and "Auth" in name
+
+ def test_adds_seconds_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("timeout_s", field_info)
+ # Contains "(Seconds)" with title case
+ assert "(Seconds)" in name or "(seconds)" in name
+
+ def test_adds_ms_suffix(self):
+ field_info = SimpleNamespace(description=None)
+ name = _get_field_display_name("delay_ms", field_info)
+ # Contains "(Ms)" or "(ms)"
+ assert "(Ms)" in name or "(ms)" in name
+
+
+class TestFormatValue:
+ """Tests for _format_value display formatting."""
+
+ def test_formats_none_as_not_set(self):
+ assert "not set" in _format_value(None)
+
+ def test_formats_empty_string_as_not_set(self):
+ assert "not set" in _format_value("")
+
+ def test_formats_empty_dict_as_not_set(self):
+ assert "not set" in _format_value({})
+
+ def test_formats_empty_list_as_not_set(self):
+ assert "not set" in _format_value([])
+
+ def test_formats_string_value(self):
+ result = _format_value("hello")
+ assert "hello" in result
+
+ def test_formats_list_value(self):
+ result = _format_value(["a", "b"])
+ assert "a" in result or "b" in result
+
+ def test_formats_dict_value(self):
+ result = _format_value({"key": "value"})
+ assert "key" in result or "value" in result
+
+ def test_formats_int_value(self):
+ result = _format_value(42)
+ assert "42" in result
+
+ def test_formats_bool_true(self):
+ result = _format_value(True)
+ assert "true" in result.lower() or "✓" in result
+
+ def test_formats_bool_false(self):
+ result = _format_value(False)
+ assert "false" in result.lower() or "✗" in result
+
+
+class TestSyncWorkspaceTemplates:
+ """Tests for sync_workspace_templates file synchronization."""
+
+ def test_creates_missing_files(self, tmp_path):
+ """Should create template files that don't exist."""
+ workspace = tmp_path / "workspace"
+
+ added = sync_workspace_templates(workspace, silent=True)
+
+ # Check that some files were created
+ assert isinstance(added, list)
+ # The actual files depend on the templates directory
+
+ def test_does_not_overwrite_existing_files(self, tmp_path):
+ """Should not overwrite files that already exist."""
+ workspace = tmp_path / "workspace"
+ workspace.mkdir(parents=True)
+ (workspace / "AGENTS.md").write_text("existing content")
+
+ sync_workspace_templates(workspace, silent=True)
+
+ # Existing file should not be changed
+ content = (workspace / "AGENTS.md").read_text()
+ assert content == "existing content"
+
+ def test_creates_memory_directory(self, tmp_path):
+ """Should create memory directory structure."""
+ workspace = tmp_path / "workspace"
+
+ sync_workspace_templates(workspace, silent=True)
+
+ assert (workspace / "memory").exists() or (workspace / "skills").exists()
+
+ def test_returns_list_of_added_files(self, tmp_path):
+ """Should return list of relative paths for added files."""
+ workspace = tmp_path / "workspace"
+
+ added = sync_workspace_templates(workspace, silent=True)
+
+ assert isinstance(added, list)
+ # All paths should be relative to workspace
+ for path in added:
+ assert not Path(path).is_absolute()
+
+
+class TestProviderChannelInfo:
+ """Tests for provider and channel info retrieval."""
+
+ def test_get_provider_names_returns_dict(self):
+ from nanobot.cli.onboard_wizard import _get_provider_names
+
+ names = _get_provider_names()
+ assert isinstance(names, dict)
+ assert len(names) > 0
+ # Should include common providers
+ assert "openai" in names or "anthropic" in names
+ assert "openai_codex" not in names
+ assert "github_copilot" not in names
+
+ def test_get_channel_names_returns_dict(self):
+ from nanobot.cli.onboard_wizard import _get_channel_names
+
+ names = _get_channel_names()
+ assert isinstance(names, dict)
+        # Channel registry may legitimately be empty; only verify the returned shape
+ assert len(names) >= 0
+
+ def test_get_provider_info_returns_valid_structure(self):
+ from nanobot.cli.onboard_wizard import _get_provider_info
+
+ info = _get_provider_info()
+ assert isinstance(info, dict)
+ # Each value should be a tuple with expected structure
+ for provider_name, value in info.items():
+ assert isinstance(value, tuple)
+ assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var)
+
+
+class _SimpleDraftModel(BaseModel):
+ api_key: str = ""
+
+
+class _NestedDraftModel(BaseModel):
+ api_key: str = ""
+
+
+class _OuterDraftModel(BaseModel):
+ nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)
+
+
+class TestConfigurePydanticModelDrafts:
+ @staticmethod
+ def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
+ sequence = iter(tokens)
+
+ def fake_select(_prompt, choices, default=None):
+ token = next(sequence)
+ if token == "first":
+ return choices[0]
+ if token == "done":
+ return "[Done]"
+ if token == "back":
+ return _BACK_PRESSED
+ return token
+
+ monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
+ monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
+ monkeypatch.setattr(
+ onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
+ )
+
+ def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
+ model = _SimpleDraftModel()
+ self._patch_prompt_helpers(monkeypatch, ["first", "back"])
+
+ result = _configure_pydantic_model(model, "Simple")
+
+ assert result is None
+ assert model.api_key == ""
+
+ def test_completing_section_returns_updated_draft(self, monkeypatch):
+ model = _SimpleDraftModel()
+ self._patch_prompt_helpers(monkeypatch, ["first", "done"])
+
+ result = _configure_pydantic_model(model, "Simple")
+
+ assert result is not None
+ updated = cast(_SimpleDraftModel, result)
+ assert updated.api_key == "secret"
+ assert model.api_key == ""
+
+ def test_nested_section_back_discards_nested_edits(self, monkeypatch):
+ model = _OuterDraftModel()
+ self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])
+
+ result = _configure_pydantic_model(model, "Outer")
+
+ assert result is not None
+ updated = cast(_OuterDraftModel, result)
+ assert updated.nested.api_key == ""
+ assert model.nested.api_key == ""
+
+ def test_nested_section_done_commits_nested_edits(self, monkeypatch):
+ model = _OuterDraftModel()
+ self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])
+
+ result = _configure_pydantic_model(model, "Outer")
+
+ assert result is not None
+ updated = cast(_OuterDraftModel, result)
+ assert updated.nested.api_key == "secret"
+ assert model.nested.api_key == ""
+
+
+class TestRunOnboardExitBehavior:
+ def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
+ initial_config = Config()
+
+ responses = iter(
+ [
+ "[A] Agent Settings",
+ KeyboardInterrupt(),
+ "[X] Exit Without Saving",
+ ]
+ )
+
+ class FakePrompt:
+ def __init__(self, response):
+ self.response = response
+
+ def ask(self):
+ if isinstance(self.response, BaseException):
+ raise self.response
+ return self.response
+
+ def fake_select(*_args, **_kwargs):
+ return FakePrompt(next(responses))
+
+ def fake_configure_general_settings(config, section):
+ if section == "Agent Settings":
+ config.agents.defaults.model = "test/provider-model"
+
+ monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
+ monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
+ monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings)
+
+ result = run_onboard(initial_config=initial_config)
+
+ assert result.should_save is False
+ assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)
diff --git a/tests/test_restart_command.py b/tests/test_restart_command.py
index c495347..5cd8aa7 100644
--- a/tests/test_restart_command.py
+++ b/tests/test_restart_command.py
@@ -65,6 +65,18 @@ class TestRestartCommand:
mock_handle.assert_called_once()
+ @pytest.mark.asyncio
+ async def test_run_propagates_external_cancellation(self):
+ """External task cancellation should not be swallowed by the inbound wait loop."""
+ loop, _bus = _make_loop()
+
+ run_task = asyncio.create_task(loop.run())
+ await asyncio.sleep(0.1)
+ run_task.cancel()
+
+ with pytest.raises(asyncio.CancelledError):
+ await asyncio.wait_for(run_task, timeout=1.0)
+
@pytest.mark.asyncio
async def test_help_includes_restart(self):
loop, bus = _make_loop()
diff --git a/tests/test_task_cancel.py b/tests/test_task_cancel.py
index 62ab2cc..5bc2ea9 100644
--- a/tests/test_task_cancel.py
+++ b/tests/test_task_cancel.py
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
-def _make_loop():
+def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
@@ -23,7 +23,7 @@ def _make_loop():
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
- loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
+ loop = AgentLoop(bus=bus, provider=provider, workspace=workspace, exec_config=exec_config)
return loop, bus
@@ -90,6 +90,13 @@ class TestHandleStop:
class TestDispatch:
+ def test_exec_tool_not_registered_when_disabled(self):
+ from nanobot.config.schema import ExecToolConfig
+
+ loop, _bus = _make_loop(exec_config=ExecToolConfig(enable=False))
+
+ assert loop.tools.get("exec") is None
+
@pytest.mark.asyncio
async def test_dispatch_processes_and_publishes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
diff --git a/tests/test_telegram_channel.py b/tests/test_telegram_channel.py
index 414f9de..98b2644 100644
--- a/tests/test_telegram_channel.py
+++ b/tests/test_telegram_channel.py
@@ -18,6 +18,10 @@ class _FakeHTTPXRequest:
self.kwargs = kwargs
self.__class__.instances.append(self)
+ @classmethod
+ def clear(cls) -> None:
+ cls.instances.clear()
+
class _FakeUpdater:
def __init__(self, on_start_polling) -> None:
@@ -144,7 +148,8 @@ def _make_telegram_update(
@pytest.mark.asyncio
-async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
+async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
+ _FakeHTTPXRequest.clear()
config = TelegramConfig(
enabled=True,
token="123:abc",
@@ -164,10 +169,106 @@ async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> No
await channel.start()
- assert len(_FakeHTTPXRequest.instances) == 1
- assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
- assert builder.request_value is _FakeHTTPXRequest.instances[0]
- assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
+ assert len(_FakeHTTPXRequest.instances) == 2
+ api_req, poll_req = _FakeHTTPXRequest.instances
+ assert api_req.kwargs["proxy"] == config.proxy
+ assert poll_req.kwargs["proxy"] == config.proxy
+ assert api_req.kwargs["connection_pool_size"] == 32
+ assert poll_req.kwargs["connection_pool_size"] == 4
+ assert builder.request_value is api_req
+ assert builder.get_updates_request_value is poll_req
+
+
+@pytest.mark.asyncio
+async def test_start_respects_custom_pool_config(monkeypatch) -> None:
+ _FakeHTTPXRequest.clear()
+ config = TelegramConfig(
+ enabled=True,
+ token="123:abc",
+ allow_from=["*"],
+ connection_pool_size=32,
+ pool_timeout=10.0,
+ )
+ bus = MessageBus()
+ channel = TelegramChannel(config, bus)
+ app = _FakeApp(lambda: setattr(channel, "_running", False))
+ builder = _FakeBuilder(app)
+
+ monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
+ monkeypatch.setattr(
+ "nanobot.channels.telegram.Application",
+ SimpleNamespace(builder=lambda: builder),
+ )
+
+ await channel.start()
+
+ api_req = _FakeHTTPXRequest.instances[0]
+ poll_req = _FakeHTTPXRequest.instances[1]
+ assert api_req.kwargs["connection_pool_size"] == 32
+ assert api_req.kwargs["pool_timeout"] == 10.0
+ assert poll_req.kwargs["pool_timeout"] == 10.0
+
+
+@pytest.mark.asyncio
+async def test_send_text_retries_on_timeout() -> None:
+ """_send_text retries on TimedOut before succeeding."""
+ from telegram.error import TimedOut
+
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+
+ call_count = 0
+ original_send = channel._app.bot.send_message
+
+ async def flaky_send(**kwargs):
+ nonlocal call_count
+ call_count += 1
+ if call_count <= 2:
+ raise TimedOut()
+ return await original_send(**kwargs)
+
+ channel._app.bot.send_message = flaky_send
+
+ import nanobot.channels.telegram as tg_mod
+ orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
+ tg_mod._SEND_RETRY_BASE_DELAY = 0.01
+ try:
+ await channel._send_text(123, "hello", None, {})
+ finally:
+ tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
+
+ assert call_count == 3
+ assert len(channel._app.bot.sent_messages) == 1
+
+
+@pytest.mark.asyncio
+async def test_send_text_gives_up_after_max_retries() -> None:
+ """_send_text raises TimedOut after exhausting all retries."""
+ from telegram.error import TimedOut
+
+ channel = TelegramChannel(
+ TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
+ MessageBus(),
+ )
+ channel._app = _FakeApp(lambda: None)
+
+ async def always_timeout(**kwargs):
+ raise TimedOut()
+
+ channel._app.bot.send_message = always_timeout
+
+ import nanobot.channels.telegram as tg_mod
+ orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
+ tg_mod._SEND_RETRY_BASE_DELAY = 0.01
+ try:
+ await channel._send_text(123, "hello", None, {})
+ finally:
+ tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
+
+ assert channel._app.bot.sent_messages == []
def test_derive_topic_session_key_uses_thread_id() -> None:
diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py
index 1d822b3..e817f37 100644
--- a/tests/test_tool_validation.py
+++ b/tests/test_tool_validation.py
@@ -406,3 +406,64 @@ async def test_exec_timeout_capped_at_max() -> None:
# Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result
+
+
+# --- _resolve_type and nullable param tests ---
+
+
+def test_resolve_type_simple_string() -> None:
+ """Simple string type passes through unchanged."""
+ assert Tool._resolve_type("string") == "string"
+
+
+def test_resolve_type_union_with_null() -> None:
+ """Union type ['string', 'null'] resolves to 'string'."""
+ assert Tool._resolve_type(["string", "null"]) == "string"
+
+
+def test_resolve_type_only_null() -> None:
+ """Union type ['null'] resolves to None (no non-null type)."""
+ assert Tool._resolve_type(["null"]) is None
+
+
+def test_resolve_type_none_input() -> None:
+ """None input passes through as None."""
+ assert Tool._resolve_type(None) is None
+
+
+def test_validate_nullable_param_accepts_string() -> None:
+ """Nullable string param should accept a string value."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"name": {"type": ["string", "null"]}},
+ }
+ )
+ errors = tool.validate_params({"name": "hello"})
+ assert errors == []
+
+
+def test_validate_nullable_param_accepts_none() -> None:
+ """Nullable string param should accept None."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"name": {"type": ["string", "null"]}},
+ }
+ )
+ errors = tool.validate_params({"name": None})
+ assert errors == []
+
+
+def test_cast_nullable_param_no_crash() -> None:
+ """cast_params should not crash on nullable type (the original bug)."""
+ tool = CastTestTool(
+ {
+ "type": "object",
+ "properties": {"name": {"type": ["string", "null"]}},
+ }
+ )
+ result = tool.cast_params({"name": "hello"})
+ assert result["name"] == "hello"
+ result = tool.cast_params({"name": None})
+ assert result["name"] is None