From c3a4b16e76df8e001b63d8142669725b765a8918 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 20 Mar 2026 07:53:18 +0000 Subject: [PATCH] refactor: optimize onboard wizard - mask secrets, remove emoji, reduce repetition - Mask sensitive fields (api_key/token/secret/password) in all display surfaces, showing only the last 4 characters - Replace all emoji with pure ASCII labels for consistent cross-platform terminal rendering - Extract _print_summary_panel helper, eliminating 5x duplicate table construction in _show_summary - Replace 3 one-line wrapper functions with declarative _SETTINGS_SECTIONS dispatch tables and _MENU_DISPATCH in run_onboard - Extract _handle_model_field / _handle_context_window_field into a _FIELD_HANDLERS registry, shrinking _configure_pydantic_model - Return FieldTypeInfo NamedTuple from _get_field_type_info for clarity - Replace global mutable _PROVIDER_INFO / _CHANNEL_INFO with @lru_cache - Use vars() instead of dir() in _get_channel_info for reliable config class discovery - Defer litellm import in model_info.py so non-wizard CLI paths stay fast - Clarify README Quick Start wording (Add -> Configure) --- README.md | 5 +- nanobot/cli/commands.py | 20 +- nanobot/cli/model_info.py | 13 +- nanobot/cli/onboard_wizard.py | 594 +++++++++++++++------------------ tests/test_commands.py | 38 ++- tests/test_config_migration.py | 4 +- tests/test_onboard_logic.py | 15 +- 7 files changed, 344 insertions(+), 345 deletions(-) diff --git a/README.md b/README.md index 9fbec37..8ac23a0 100644 --- a/README.md +++ b/README.md @@ -191,9 +191,11 @@ nanobot channels login nanobot onboard ``` +Use `nanobot onboard --wizard` if you want the interactive setup wizard. + **2. Configure** (`~/.nanobot/config.json`) -Add or merge these **two parts** into your config (other options have defaults). +Configure these **two parts** in your config (other options have defaults). *Set your API key* (e.g. OpenRouter, recommended for global users): ```json @@ -1288,6 +1290,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo | Command | Description | |---------|-------------| | `nanobot onboard` | Initialize config & workspace at `~/.nanobot/` | +| `nanobot onboard --wizard` | Launch the interactive onboarding wizard | | `nanobot onboard -c -w ` | Initialize or refresh a specific instance config and workspace | | `nanobot agent -m "..."` | Chat with the agent | | `nanobot agent -w ` | Chat against a specific workspace | diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index efea399..de49668 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -264,7 +264,7 @@ def main( def onboard( workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), - interactive: bool = typer.Option(True, "--interactive/--no-interactive", help="Use interactive wizard"), + wizard: bool = typer.Option(False, "--wizard", help="Use interactive wizard"), ): """Initialize nanobot configuration and workspace.""" from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path @@ -284,7 +284,7 @@ def onboard( # Create or update config if config_path.exists(): - if interactive: + if wizard: config = _apply_workspace_override(load_config(config_path)) else: console.print(f"[yellow]Config already exists at {config_path}[/yellow]") @@ -300,13 +300,13 @@ def onboard( console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") else: config = _apply_workspace_override(Config()) - # In interactive mode, don't save yet - the wizard will handle saving if should_save=True - if not interactive: + # In wizard mode, don't save yet - the wizard will handle saving if should_save=True + if not wizard: save_config(config, config_path) console.print(f"[green]✓[/green] Created config at {config_path}") # Run interactive wizard if enabled - if interactive: + if wizard: from nanobot.cli.onboard_wizard import run_onboard try: @@ -336,14 +336,16 @@ def onboard( sync_workspace_templates(workspace_path) agent_cmd = 'nanobot agent -m "Hello!"' - if config_path: + gateway_cmd = "nanobot gateway" + if config: agent_cmd += f" --config {config_path}" + gateway_cmd += f" --config {config_path}" console.print(f"\n{__logo__} nanobot is ready!") console.print("\nNext steps:") - if interactive: - console.print(" 1. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]") - console.print(" 2. Start gateway: [cyan]nanobot gateway[/cyan]") + if wizard: + console.print(f" 1. Chat: [cyan]{agent_cmd}[/cyan]") + console.print(f" 2. Start gateway: [cyan]{gateway_cmd}[/cyan]") else: console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]") console.print(" Get one at: https://openrouter.ai/keys") diff --git a/nanobot/cli/model_info.py b/nanobot/cli/model_info.py index 2bcd4af..520370c 100644 --- a/nanobot/cli/model_info.py +++ b/nanobot/cli/model_info.py @@ -8,13 +8,18 @@ from __future__ import annotations from functools import lru_cache from typing import Any -import litellm + +def _litellm(): + """Lazy accessor for litellm (heavy import deferred until actually needed).""" + import litellm as _ll + + return _ll @lru_cache(maxsize=1) def _get_model_cost_map() -> dict[str, Any]: """Get litellm's model cost map (cached).""" - return getattr(litellm, "model_cost", {}) + return getattr(_litellm(), "model_cost", {}) @lru_cache(maxsize=1) @@ -30,7 +35,7 @@ def get_all_models() -> list[str]: models.add(k) # From models_by_provider (more complete provider coverage) - for provider_models in getattr(litellm, "models_by_provider", {}).values(): + for provider_models in getattr(_litellm(), "models_by_provider", {}).values(): if isinstance(provider_models, (set, list)): models.update(provider_models) @@ -126,7 +131,7 @@ def get_model_context_limit(model: str, provider: str = "auto") -> int | None: # Fall back to litellm's get_max_tokens (returns max_output_tokens typically) try: - result = litellm.get_max_tokens(model) + result = _litellm().get_max_tokens(model) if result and result > 0: return result except (KeyError, ValueError, AttributeError): diff --git a/nanobot/cli/onboard_wizard.py b/nanobot/cli/onboard_wizard.py index ea41bc8..f661375 100644 --- a/nanobot/cli/onboard_wizard.py +++ b/nanobot/cli/onboard_wizard.py @@ -3,9 +3,13 @@ import json import types from dataclasses import dataclass -from typing import Any, get_args, get_origin +from functools import lru_cache +from typing import Any, NamedTuple, get_args, get_origin -import questionary +try: + import questionary +except ModuleNotFoundError: # pragma: no cover - exercised in environments without wizard deps + questionary = None from loguru import logger from pydantic import BaseModel from rich.console import Console @@ -37,7 +41,7 @@ class OnboardResult: _SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = { "reasoning_effort": ( ["low", "medium", "high"], - "low / medium / high — enables LLM thinking mode", + "low / medium / high - enables LLM thinking mode", ), } @@ -46,6 +50,16 @@ _SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = { _BACK_PRESSED = object() # Sentinel value for back navigation +def _get_questionary(): + """Return questionary or raise a clear error when wizard deps are unavailable.""" + if questionary is None: + raise RuntimeError( + "Interactive onboarding requires the optional 'questionary' dependency. " + "Install project dependencies and rerun with --wizard." + ) + return questionary + + def _select_with_back( prompt: str, choices: list[str], default: str | None = None ) -> str | None | object: @@ -87,7 +101,7 @@ def _select_with_back( items = [] for i, choice in enumerate(choices): if i == selected_index: - items.append(("class:selected", f"→ {choice}\n")) + items.append(("class:selected", f"> {choice}\n")) else: items.append(("", f" {choice}\n")) return items @@ -96,7 +110,7 @@ def _select_with_back( menu_control = FormattedTextControl(get_menu_text) menu_window = Window(content=menu_control, height=len(choices)) - prompt_control = FormattedTextControl(lambda: [("class:question", f"→ {prompt}")]) + prompt_control = FormattedTextControl(lambda: [("class:question", f"> {prompt}")]) prompt_window = Window(content=prompt_control, height=1) layout = Layout(HSplit([prompt_window, menu_window])) @@ -154,21 +168,22 @@ def _select_with_back( # --- Type Introspection --- -def _get_field_type_info(field_info) -> tuple[str, Any]: - """Extract field type info from Pydantic field. +class FieldTypeInfo(NamedTuple): + """Result of field type introspection.""" - Returns: (type_name, inner_type) - - type_name: "str", "int", "float", "bool", "list", "dict", "model" - - inner_type: for list, the item type; for model, the model class - """ + type_name: str + inner_type: Any + + +def _get_field_type_info(field_info) -> FieldTypeInfo: + """Extract field type info from Pydantic field.""" annotation = field_info.annotation if annotation is None: - return "str", None + return FieldTypeInfo("str", None) origin = get_origin(annotation) args = get_args(annotation) - # Handle Optional[T] / T | None if origin is types.UnionType: non_none_args = [a for a in args if a is not type(None)] if len(non_none_args) == 1: @@ -176,33 +191,18 @@ def _get_field_type_info(field_info) -> tuple[str, Any]: origin = get_origin(annotation) args = get_args(annotation) - # Check for list + _SIMPLE_TYPES: dict[type, str] = {bool: "bool", int: "int", float: "float"} + if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"): - if args: - return "list", args[0] - return "list", str - - # Check for dict + return FieldTypeInfo("list", args[0] if args else str) if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"): - return "dict", None - - # Check for bool - if annotation is bool or (hasattr(annotation, "__name__") and annotation.__name__ == "bool"): - return "bool", None - - # Check for int - if annotation is int or (hasattr(annotation, "__name__") and annotation.__name__ == "int"): - return "int", None - - # Check for float - if annotation is float or (hasattr(annotation, "__name__") and annotation.__name__ == "float"): - return "float", None - - # Check if it's a nested BaseModel + return FieldTypeInfo("dict", None) + for py_type, name in _SIMPLE_TYPES.items(): + if annotation is py_type: + return FieldTypeInfo(name, None) if isinstance(annotation, type) and issubclass(annotation, BaseModel): - return "model", annotation - - return "str", None + return FieldTypeInfo("model", annotation) + return FieldTypeInfo("str", None) def _get_field_display_name(field_key: str, field_info) -> str: @@ -226,13 +226,33 @@ def _get_field_display_name(field_key: str, field_info) -> str: return name.replace("_", " ").title() +# --- Sensitive Field Masking --- + +_SENSITIVE_KEYWORDS = frozenset({"api_key", "token", "secret", "password", "credentials"}) + + +def _is_sensitive_field(field_name: str) -> bool: + """Check if a field name indicates sensitive content.""" + return any(kw in field_name.lower() for kw in _SENSITIVE_KEYWORDS) + + +def _mask_value(value: str) -> str: + """Mask a sensitive value, showing only the last 4 characters.""" + if len(value) <= 4: + return "****" + return "*" * (len(value) - 4) + value[-4:] + + # --- Value Formatting --- -def _format_value(value: Any, rich: bool = True) -> str: - """Format a value for display.""" +def _format_value(value: Any, rich: bool = True, field_name: str = "") -> str: + """Format a value for display, masking sensitive fields.""" if value is None or value == "" or value == {} or value == []: return "[dim]not set[/dim]" if rich else "[not set]" + if field_name and _is_sensitive_field(field_name) and isinstance(value, str): + masked = _mask_value(value) + return f"[dim]{masked}[/dim]" if rich else masked if isinstance(value, list): return ", ".join(str(v) for v in value) if isinstance(value, dict): @@ -260,10 +280,10 @@ def _show_config_panel(display_name: str, model: BaseModel, fields: list) -> Non table.add_column("Field", style="cyan") table.add_column("Value") - for field_name, field_info in fields: - value = getattr(model, field_name, None) - display = _get_field_display_name(field_name, field_info) - formatted = _format_value(value, rich=True) + for fname, field_info in fields: + value = getattr(model, fname, None) + display = _get_field_display_name(fname, field_info) + formatted = _format_value(value, rich=True, field_name=fname) table.add_row(display, formatted) console.print(Panel(table, title=f"[bold]{display_name}[/bold]", border_style="blue")) @@ -299,7 +319,7 @@ def _show_section_header(title: str, subtitle: str = "") -> None: def _input_bool(display_name: str, current: bool | None) -> bool | None: """Get boolean input via confirm dialog.""" - return questionary.confirm( + return _get_questionary().confirm( display_name, default=bool(current) if current is not None else False, ).ask() @@ -309,7 +329,7 @@ def _input_text(display_name: str, current: Any, field_type: str) -> Any: """Get text input and parse based on field type.""" default = _format_value_for_input(current, field_type) - value = questionary.text(f"{display_name}:", default=default).ask() + value = _get_questionary().text(f"{display_name}:", default=default).ask() if value is None or value == "": return None @@ -318,13 +338,13 @@ def _input_text(display_name: str, current: Any, field_type: str) -> Any: try: return int(value) except ValueError: - console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]") + console.print("[yellow]! Invalid number format, value not saved[/yellow]") return None elif field_type == "float": try: return float(value) except ValueError: - console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]") + console.print("[yellow]! Invalid number format, value not saved[/yellow]") return None elif field_type == "list": return [v.strip() for v in value.split(",") if v.strip()] @@ -332,7 +352,7 @@ def _input_text(display_name: str, current: Any, field_type: str) -> Any: try: return json.loads(value) except json.JSONDecodeError: - console.print("[yellow]⚠ Invalid JSON format, value not saved[/yellow]") + console.print("[yellow]! Invalid JSON format, value not saved[/yellow]") return None return value @@ -345,7 +365,7 @@ def _input_with_existing( has_existing = current is not None and current != "" and current != {} and current != [] if has_existing and not isinstance(current, list): - choice = questionary.select( + choice = _get_questionary().select( display_name, choices=["Enter new value", "Keep existing value"], default="Keep existing value", @@ -395,12 +415,12 @@ def _input_model_with_autocomplete( display=model, ) - value = questionary.autocomplete( + value = _get_questionary().autocomplete( f"{display_name}:", choices=[""], # Placeholder, actual completions from completer completer=DynamicModelCompleter(provider), default=default, - qmark="→", + qmark=">", ).ask() return value if value else None @@ -415,9 +435,9 @@ def _input_context_window_with_recommendation( choices = ["Enter new value"] if current_val: choices.append("Keep existing value") - choices.append("🔍 Get recommended value") + choices.append("[?] Get recommended value") - choice = questionary.select( + choice = _get_questionary().select( display_name, choices=choices, default="Enter new value", @@ -429,25 +449,25 @@ def _input_context_window_with_recommendation( if choice == "Keep existing value": return None - if choice == "🔍 Get recommended value": + if choice == "[?] Get recommended value": # Get the model name from the model object model_name = getattr(model_obj, "model", None) if not model_name: - console.print("[yellow]⚠ Please configure the model field first[/yellow]") + console.print("[yellow]! Please configure the model field first[/yellow]") return None provider = _get_current_provider(model_obj) context_limit = get_model_context_limit(model_name, provider) if context_limit: - console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]") + console.print(f"[green]+ Recommended context window: {format_token_count(context_limit)} tokens[/green]") return context_limit else: - console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]") + console.print("[yellow]! Could not fetch model info, please enter manually[/yellow]") # Fall through to manual input # Manual input - value = questionary.text( + value = _get_questionary().text( f"{display_name}:", default=str(current_val) if current_val else "", ).ask() @@ -458,10 +478,38 @@ def _input_context_window_with_recommendation( try: return int(value) except ValueError: - console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]") + console.print("[yellow]! Invalid number format, value not saved[/yellow]") return None +def _handle_model_field( + working_model: BaseModel, field_name: str, field_display: str, current_value: Any +) -> None: + """Handle the 'model' field with autocomplete and context-window auto-fill.""" + provider = _get_current_provider(working_model) + new_value = _input_model_with_autocomplete(field_display, current_value, provider) + if new_value is not None and new_value != current_value: + setattr(working_model, field_name, new_value) + _try_auto_fill_context_window(working_model, new_value) + + +def _handle_context_window_field( + working_model: BaseModel, field_name: str, field_display: str, current_value: Any +) -> None: + """Handle context_window_tokens with recommendation lookup.""" + new_value = _input_context_window_with_recommendation( + field_display, current_value, working_model + ) + if new_value is not None: + setattr(working_model, field_name, new_value) + + +_FIELD_HANDLERS: dict[str, Any] = { + "model": _handle_model_field, + "context_window_tokens": _handle_context_window_field, +} + + def _configure_pydantic_model( model: BaseModel, display_name: str, @@ -476,35 +524,32 @@ def _configure_pydantic_model( skip_fields = skip_fields or set() working_model = model.model_copy(deep=True) - fields = [] - for field_name, field_info in type(working_model).model_fields.items(): - if field_name in skip_fields: - continue - fields.append((field_name, field_info)) - + fields = [ + (name, info) + for name, info in type(working_model).model_fields.items() + if name not in skip_fields + ] if not fields: console.print(f"[dim]{display_name}: No configurable fields[/dim]") return working_model def get_choices() -> list[str]: - choices = [] - for field_name, field_info in fields: - value = getattr(working_model, field_name, None) - display = _get_field_display_name(field_name, field_info) - formatted = _format_value(value, rich=False) - choices.append(f"{display}: {formatted}") - return choices + ["✓ Done"] + items = [] + for fname, finfo in fields: + value = getattr(working_model, fname, None) + display = _get_field_display_name(fname, finfo) + formatted = _format_value(value, rich=False, field_name=fname) + items.append(f"{display}: {formatted}") + return items + ["[Done]"] while True: _show_config_panel(display_name, working_model, fields) choices = get_choices() - answer = _select_with_back("Select field to configure:", choices) if answer is _BACK_PRESSED or answer is None: return None - - if answer == "✓ Done": + if answer == "[Done]": return working_model field_idx = next((i for i, c in enumerate(choices) if c == answer), -1) @@ -513,45 +558,30 @@ def _configure_pydantic_model( field_name, field_info = fields[field_idx] current_value = getattr(working_model, field_name, None) - field_type, _ = _get_field_type_info(field_info) + ftype = _get_field_type_info(field_info) field_display = _get_field_display_name(field_name, field_info) - if field_type == "model": - nested_model = current_value - created_nested_model = nested_model is None - if nested_model is None: - _, nested_cls = _get_field_type_info(field_info) - if nested_cls: - nested_model = nested_cls() - - if nested_model and isinstance(nested_model, BaseModel): - updated_nested_model = _configure_pydantic_model(nested_model, field_display) - if updated_nested_model is not None: - setattr(working_model, field_name, updated_nested_model) - elif created_nested_model: + # Nested Pydantic model - recurse + if ftype.type_name == "model": + nested = current_value + created = nested is None + if nested is None and ftype.inner_type: + nested = ftype.inner_type() + if nested and isinstance(nested, BaseModel): + updated = _configure_pydantic_model(nested, field_display) + if updated is not None: + setattr(working_model, field_name, updated) + elif created: setattr(working_model, field_name, None) continue - # Special handling for model field (autocomplete) - if field_name == "model": - provider = _get_current_provider(working_model) - new_value = _input_model_with_autocomplete(field_display, current_value, provider) - if new_value is not None and new_value != current_value: - setattr(working_model, field_name, new_value) - # Auto-fill context_window_tokens if it's at default value - _try_auto_fill_context_window(working_model, new_value) + # Registered special-field handlers + handler = _FIELD_HANDLERS.get(field_name) + if handler: + handler(working_model, field_name, field_display, current_value) continue - # Special handling for context_window_tokens field - if field_name == "context_window_tokens": - new_value = _input_context_window_with_recommendation( - field_display, current_value, working_model - ) - if new_value is not None: - setattr(working_model, field_name, new_value) - continue - - # Special handling for select fields with hints (e.g., reasoning_effort) + # Select fields with hints (e.g. reasoning_effort) if field_name in _SELECT_FIELD_HINTS: choices_list, hint = _SELECT_FIELD_HINTS[field_name] select_choices = choices_list + ["(clear/unset)"] @@ -567,14 +597,13 @@ def _configure_pydantic_model( setattr(working_model, field_name, new_value) continue - if field_type == "bool": + # Generic field input + if ftype.type_name == "bool": new_value = _input_bool(field_display, current_value) - if new_value is not None: - setattr(working_model, field_name, new_value) else: - new_value = _input_with_existing(field_display, current_value, field_type) - if new_value is not None: - setattr(working_model, field_name, new_value) + new_value = _input_with_existing(field_display, current_value, ftype.type_name) + if new_value is not None: + setattr(working_model, field_name, new_value) def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None: @@ -605,32 +634,28 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None if context_limit: setattr(model, "context_window_tokens", context_limit) - console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]") + console.print(f"[green]+ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]") else: - console.print("[dim]ℹ Could not auto-fill context window (model not in database)[/dim]") + console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]") # --- Provider Configuration --- -_PROVIDER_INFO: dict[str, tuple[str, bool, bool, str]] | None = None - - +@lru_cache(maxsize=1) def _get_provider_info() -> dict[str, tuple[str, bool, bool, str]]: """Get provider info from registry (cached).""" - global _PROVIDER_INFO - if _PROVIDER_INFO is None: - from nanobot.providers.registry import PROVIDERS + from nanobot.providers.registry import PROVIDERS - _PROVIDER_INFO = {} - for spec in PROVIDERS: - _PROVIDER_INFO[spec.name] = ( - spec.display_name or spec.name, - spec.is_gateway, - spec.is_local, - spec.default_api_base, - ) - return _PROVIDER_INFO + return { + spec.name: ( + spec.display_name or spec.name, + spec.is_gateway, + spec.is_local, + spec.default_api_base, + ) + for spec in PROVIDERS + } def _get_provider_names() -> dict[str, str]: @@ -671,23 +696,23 @@ def _configure_providers(config: Config) -> None: for name, display in _get_provider_names().items(): provider = getattr(config.providers, name, None) if provider and provider.api_key: - choices.append(f"{display} ✓") + choices.append(f"{display} *") else: choices.append(display) - return choices + ["← Back"] + return choices + ["<- Back"] while True: try: choices = get_provider_choices() answer = _select_with_back("Select provider:", choices) - if answer is _BACK_PRESSED or answer is None or answer == "← Back": + if answer is _BACK_PRESSED or answer is None or answer == "<- Back": break # Type guard: answer is now guaranteed to be a string assert isinstance(answer, str) - # Extract provider name from choice (remove " ✓" suffix if present) - provider_name = answer.replace(" ✓", "") + # Extract provider name from choice (remove " *" suffix if present) + provider_name = answer.replace(" *", "") # Find the actual provider key from display names for name, display in _get_provider_names().items(): if display == provider_name: @@ -702,51 +727,45 @@ def _configure_providers(config: Config) -> None: # --- Channel Configuration --- +@lru_cache(maxsize=1) def _get_channel_info() -> dict[str, tuple[str, type[BaseModel]]]: """Get channel info (display name + config class) from channel modules.""" import importlib from nanobot.channels.registry import discover_all - result = {} + result: dict[str, tuple[str, type[BaseModel]]] = {} for name, channel_cls in discover_all().items(): try: mod = importlib.import_module(f"nanobot.channels.{name}") - config_cls = None - display_name = name.capitalize() - for attr_name in dir(mod): - attr = getattr(mod, attr_name) - if isinstance(attr, type) and issubclass(attr, BaseModel) and attr is not BaseModel: - if "Config" in attr_name: - config_cls = attr - if hasattr(channel_cls, "display_name"): - display_name = channel_cls.display_name - break - + config_cls = next( + ( + attr + for attr in vars(mod).values() + if isinstance(attr, type) + and issubclass(attr, BaseModel) + and attr is not BaseModel + and attr.__name__.endswith("Config") + ), + None, + ) if config_cls: + display_name = getattr(channel_cls, "display_name", name.capitalize()) result[name] = (display_name, config_cls) except Exception: logger.warning(f"Failed to load channel module: {name}") return result -_CHANNEL_INFO: dict[str, tuple[str, type[BaseModel]]] | None = None - - def _get_channel_names() -> dict[str, str]: """Get channel display names.""" - global _CHANNEL_INFO - if _CHANNEL_INFO is None: - _CHANNEL_INFO = _get_channel_info() - return {name: info[0] for name, info in _CHANNEL_INFO.items() if name} + return {name: info[0] for name, info in _get_channel_info().items()} def _get_channel_config_class(channel: str) -> type[BaseModel] | None: """Get channel config class.""" - global _CHANNEL_INFO - if _CHANNEL_INFO is None: - _CHANNEL_INFO = _get_channel_info() - return _CHANNEL_INFO.get(channel, (None, None))[1] + entry = _get_channel_info().get(channel) + return entry[1] if entry else None def _configure_channel(config: Config, channel_name: str) -> None: @@ -779,13 +798,13 @@ def _configure_channels(config: Config) -> None: _show_section_header("Chat Channels", "Select a channel to configure connection settings") channel_names = list(_get_channel_names().keys()) - choices = channel_names + ["← Back"] + choices = channel_names + ["<- Back"] while True: try: answer = _select_with_back("Select channel:", choices) - if answer is _BACK_PRESSED or answer is None or answer == "← Back": + if answer is _BACK_PRESSED or answer is None or answer == "<- Back": break # Type guard: answer is now guaranteed to be a string @@ -798,113 +817,87 @@ def _configure_channels(config: Config) -> None: # --- General Settings --- +_SETTINGS_SECTIONS: dict[str, tuple[str, str, set[str] | None]] = { + "Agent Settings": ("Agent Defaults", "Configure default model, temperature, and behavior", None), + "Gateway": ("Gateway Settings", "Configure server host, port, and heartbeat", None), + "Tools": ("Tools Settings", "Configure web search, shell exec, and other tools", {"mcp_servers"}), +} + +_SETTINGS_GETTER = { + "Agent Settings": lambda c: c.agents.defaults, + "Gateway": lambda c: c.gateway, + "Tools": lambda c: c.tools, +} + +_SETTINGS_SETTER = { + "Agent Settings": lambda c, v: setattr(c.agents, "defaults", v), + "Gateway": lambda c, v: setattr(c, "gateway", v), + "Tools": lambda c, v: setattr(c, "tools", v), +} + def _configure_general_settings(config: Config, section: str) -> None: - """Configure a general settings section.""" - section_map = { - "Agent Settings": (config.agents.defaults, "Agent Defaults"), - "Gateway": (config.gateway, "Gateway Settings"), - "Tools": (config.tools, "Tools Settings"), - "Channel Common": (config.channels, "Channel Common Settings"), - } - - if section not in section_map: + """Configure a general settings section (header + model edit + writeback).""" + meta = _SETTINGS_SECTIONS.get(section) + if not meta: return + display_name, subtitle, skip = meta + _show_section_header(section, subtitle) - model, display_name = section_map[section] - - if section == "Tools": - updated_model = _configure_pydantic_model( - model, - display_name, - skip_fields={"mcp_servers"}, - ) - else: - updated_model = _configure_pydantic_model(model, display_name) - - if updated_model is None: - return - - if section == "Agent Settings": - config.agents.defaults = updated_model - elif section == "Gateway": - config.gateway = updated_model - elif section == "Tools": - config.tools = updated_model - elif section == "Channel Common": - config.channels = updated_model - - -def _configure_agents(config: Config) -> None: - """Configure agent settings.""" - _show_section_header("Agent Settings", "Configure default model, temperature, and behavior") - _configure_general_settings(config, "Agent Settings") - - -def _configure_gateway(config: Config) -> None: - """Configure gateway settings.""" - _show_section_header("Gateway", "Configure server host, port, and heartbeat") - _configure_general_settings(config, "Gateway") - - -def _configure_tools(config: Config) -> None: - """Configure tools settings.""" - _show_section_header("Tools", "Configure web search, shell exec, and other tools") - _configure_general_settings(config, "Tools") + model = _SETTINGS_GETTER[section](config) + updated = _configure_pydantic_model(model, display_name, skip_fields=skip) + if updated is not None: + _SETTINGS_SETTER[section](config, updated) # --- Summary --- -def _summarize_model(obj: BaseModel, indent: int = 2) -> list[tuple[str, str]]: +def _summarize_model(obj: BaseModel) -> list[tuple[str, str]]: """Recursively summarize a Pydantic model. Returns list of (field, value) tuples.""" - items = [] - + items: list[tuple[str, str]] = [] for field_name, field_info in type(obj).model_fields.items(): value = getattr(obj, field_name, None) - field_type, _ = _get_field_type_info(field_info) - if value is None or value == "" or value == {} or value == []: continue - display = _get_field_display_name(field_name, field_info) - - if field_type == "model" and isinstance(value, BaseModel): - nested_items = _summarize_model(value, indent) - for nested_field, nested_value in nested_items: + ftype = _get_field_type_info(field_info) + if ftype.type_name == "model" and isinstance(value, BaseModel): + for nested_field, nested_value in _summarize_model(value): items.append((f"{display}.{nested_field}", nested_value)) continue - - formatted = _format_value(value, rich=False) + formatted = _format_value(value, rich=False, field_name=field_name) if formatted != "[not set]": items.append((display, formatted)) - return items +def _print_summary_panel(rows: list[tuple[str, str]], title: str) -> None: + """Build a two-column summary panel and print it.""" + if not rows: + return + table = Table(show_header=False, box=None, padding=(0, 2)) + table.add_column("Setting", style="cyan") + table.add_column("Value") + for field, value in rows: + table.add_row(field, value) + console.print(Panel(table, title=f"[bold]{title}[/bold]", border_style="blue")) + + def _show_summary(config: Config) -> None: """Display configuration summary using rich.""" console.print() - # Providers table - provider_table = Table(show_header=False, box=None, padding=(0, 2)) - provider_table.add_column("Provider", style="cyan") - provider_table.add_column("Status") - + # Providers + provider_rows = [] for name, display in _get_provider_names().items(): provider = getattr(config.providers, name, None) - if provider and provider.api_key: - provider_table.add_row(display, "[green]✓ configured[/green]") - else: - provider_table.add_row(display, "[dim]not configured[/dim]") - - console.print(Panel(provider_table, title="[bold]LLM Providers[/bold]", border_style="blue")) - - # Channels table - channel_table = Table(show_header=False, box=None, padding=(0, 2)) - channel_table.add_column("Channel", style="cyan") - channel_table.add_column("Status") + status = "[green]configured[/green]" if (provider and provider.api_key) else "[dim]not configured[/dim]" + provider_rows.append((display, status)) + _print_summary_panel(provider_rows, "LLM Providers") + # Channels + channel_rows = [] for name, display in _get_channel_names().items(): channel = getattr(config.channels, name, None) if channel: @@ -913,54 +906,20 @@ def _show_summary(config: Config) -> None: if isinstance(channel, dict) else getattr(channel, "enabled", False) ) - if enabled: - channel_table.add_row(display, "[green]✓ enabled[/green]") - else: - channel_table.add_row(display, "[dim]disabled[/dim]") + status = "[green]enabled[/green]" if enabled else "[dim]disabled[/dim]" else: - channel_table.add_row(display, "[dim]not configured[/dim]") + status = "[dim]not configured[/dim]" + channel_rows.append((display, status)) + _print_summary_panel(channel_rows, "Chat Channels") - console.print(Panel(channel_table, title="[bold]Chat Channels[/bold]", border_style="blue")) - - # Agent Settings - agent_items = _summarize_model(config.agents.defaults) - if agent_items: - agent_table = Table(show_header=False, box=None, padding=(0, 2)) - agent_table.add_column("Setting", style="cyan") - agent_table.add_column("Value") - for field, value in agent_items: - agent_table.add_row(field, value) - console.print(Panel(agent_table, title="[bold]Agent Settings[/bold]", border_style="blue")) - - # Gateway - gateway_items = _summarize_model(config.gateway) - if gateway_items: - gw_table = Table(show_header=False, box=None, padding=(0, 2)) - gw_table.add_column("Setting", style="cyan") - gw_table.add_column("Value") - for field, value in gateway_items: - gw_table.add_row(field, value) - console.print(Panel(gw_table, title="[bold]Gateway[/bold]", border_style="blue")) - - # Tools - tools_items = _summarize_model(config.tools) - if tools_items: - tools_table = Table(show_header=False, box=None, padding=(0, 2)) - tools_table.add_column("Setting", style="cyan") - tools_table.add_column("Value") - for field, value in tools_items: - tools_table.add_row(field, value) - console.print(Panel(tools_table, title="[bold]Tools[/bold]", border_style="blue")) - - # Channel Common - channel_common_items = _summarize_model(config.channels) - if channel_common_items: - cc_table = Table(show_header=False, box=None, padding=(0, 2)) - cc_table.add_column("Setting", style="cyan") - cc_table.add_column("Value") - for field, value in channel_common_items: - cc_table.add_row(field, value) - console.print(Panel(cc_table, title="[bold]Channel Common[/bold]", border_style="blue")) + # Settings sections + for title, model in [ + ("Agent Settings", config.agents.defaults), + ("Gateway", config.gateway), + ("Tools", config.tools), + ("Channel Common", config.channels), + ]: + _print_summary_panel(_summarize_model(model), title) # --- Main Entry Point --- @@ -976,20 +935,20 @@ def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str: if not has_unsaved_changes: return "discard" - answer = questionary.select( + answer = _get_questionary().select( "You have unsaved changes. What would you like to do?", choices=[ - "💾 Save and Exit", - "🗑️ Exit Without Saving", - "↩ Resume Editing", + "[S] Save and Exit", + "[X] Exit Without Saving", + "[R] Resume Editing", ], - default="↩ Resume Editing", - qmark="→", + default="[R] Resume Editing", + qmark=">", ).ask() - if answer == "💾 Save and Exit": + if answer == "[S] Save and Exit": return "save" - if answer == "🗑️ Exit Without Saving": + if answer == "[X] Exit Without Saving": return "discard" return "resume" @@ -1001,6 +960,8 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: initial_config: Optional pre-loaded config to use as starting point. If None, loads from config file or creates new default. """ + _get_questionary() + if initial_config is not None: base_config = initial_config.model_copy(deep=True) else: @@ -1017,19 +978,19 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: _show_main_menu_header() try: - answer = questionary.select( + answer = _get_questionary().select( "What would you like to configure?", choices=[ - "🔌 LLM Provider", - "💬 Chat Channel", - "🤖 Agent Settings", - "🌐 Gateway", - "🔧 Tools", - "📋 View Configuration Summary", - "💾 Save and Exit", - "🗑️ Exit Without Saving", + "[P] LLM Provider", + "[C] Chat Channel", + "[A] Agent Settings", + "[G] Gateway", + "[T] Tools", + "[V] View Configuration Summary", + "[S] Save and Exit", + "[X] Exit Without Saving", ], - qmark="→", + qmark=">", ).ask() except KeyboardInterrupt: answer = None @@ -1042,19 +1003,20 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult: return OnboardResult(config=original_config, should_save=False) continue - if answer == "🔌 LLM Provider": - _configure_providers(config) - elif answer == "💬 Chat Channel": - _configure_channels(config) - elif answer == "🤖 Agent Settings": - _configure_agents(config) - elif answer == "🌐 Gateway": - _configure_gateway(config) - elif answer == "🔧 Tools": - _configure_tools(config) - elif answer == "📋 View Configuration Summary": - _show_summary(config) - elif answer == "💾 Save and Exit": + _MENU_DISPATCH = { + "[P] LLM Provider": lambda: _configure_providers(config), + "[C] Chat Channel": lambda: _configure_channels(config), + "[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"), + "[G] Gateway": lambda: _configure_general_settings(config, "Gateway"), + "[T] Tools": lambda: _configure_general_settings(config, "Tools"), + "[V] View Configuration Summary": lambda: _show_summary(config), + } + + if answer == "[S] Save and Exit": return OnboardResult(config=config, should_save=True) - elif answer == "🗑️ Exit Without Saving": + if answer == "[X] Exit Without Saving": return OnboardResult(config=original_config, should_save=False) + + action_fn = _MENU_DISPATCH.get(answer) + if action_fn: + action_fn() diff --git a/tests/test_commands.py b/tests/test_commands.py index 38af553..08ed59e 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -61,7 +61,7 @@ def test_onboard_fresh_install(mock_paths): """No existing config — should create from scratch.""" config_file, workspace_dir, mock_ws = mock_paths - result = runner.invoke(app, ["onboard", "--no-interactive"]) + result = runner.invoke(app, ["onboard"]) assert result.exit_code == 0 assert "Created config" in result.stdout @@ -79,7 +79,7 @@ def test_onboard_existing_config_refresh(mock_paths): config_file, workspace_dir, _ = mock_paths config_file.write_text('{"existing": true}') - result = runner.invoke(app, ["onboard", "--no-interactive"], input="n\n") + result = runner.invoke(app, ["onboard"], input="n\n") assert result.exit_code == 0 assert "Config already exists" in result.stdout @@ -93,7 +93,7 @@ def test_onboard_existing_config_overwrite(mock_paths): config_file, workspace_dir, _ = mock_paths config_file.write_text('{"existing": true}') - result = runner.invoke(app, ["onboard", "--no-interactive"], input="y\n") + result = runner.invoke(app, ["onboard"], input="y\n") assert result.exit_code == 0 assert "Config already exists" in result.stdout @@ -107,7 +107,7 @@ def test_onboard_existing_workspace_safe_create(mock_paths): workspace_dir.mkdir(parents=True) config_file.write_text("{}") - result = runner.invoke(app, ["onboard", "--no-interactive"], input="n\n") + result = runner.invoke(app, ["onboard"], input="n\n") assert result.exit_code == 0 assert "Created workspace" not in result.stdout @@ -130,6 +130,7 @@ def test_onboard_help_shows_workspace_and_config_options(): assert "-w" in stripped_output assert "--config" in stripped_output assert "-c" in stripped_output + assert "--wizard" in stripped_output assert "--dir" not in stripped_output @@ -143,7 +144,7 @@ def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_path lambda initial_config: OnboardResult(config=initial_config, should_save=False), ) - result = runner.invoke(app, ["onboard"]) + result = runner.invoke(app, ["onboard", "--wizard"]) assert result.exit_code == 0 assert "No changes were saved" in result.stdout @@ -159,7 +160,7 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch) result = runner.invoke( app, - ["onboard", "--config", str(config_path), "--workspace", str(workspace_path), "--no-interactive"], + ["onboard", "--config", str(config_path), "--workspace", str(workspace_path)], ) assert result.exit_code == 0 @@ -173,6 +174,31 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch) assert f"--config {resolved_config}" in compact_output +def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch): + config_path = tmp_path / "instance" / "config.json" + workspace_path = tmp_path / "workspace" + + from nanobot.cli.onboard_wizard import OnboardResult + + monkeypatch.setattr( + "nanobot.cli.onboard_wizard.run_onboard", + lambda initial_config: OnboardResult(config=initial_config, should_save=True), + ) + monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {}) + + result = runner.invoke( + app, + ["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)], + ) + + assert result.exit_code == 0 + stripped_output = _strip_ansi(result.stdout) + compact_output = stripped_output.replace("\n", "") + resolved_config = str(config_path.resolve()) + assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output + assert f"nanobot gateway --config {resolved_config}" in compact_output + + def test_config_matches_github_copilot_codex_with_hyphen_prefix(): config = Config() config.agents.defaults.model = "github-copilot/gpt-5.3-codex" diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py index 28e0feb..7728c26 100644 --- a/tests/test_config_migration.py +++ b/tests/test_config_migration.py @@ -75,7 +75,7 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) from typer.testing import CliRunner from nanobot.cli.commands import app runner = CliRunner() - result = runner.invoke(app, ["onboard", "--no-interactive"], input="n\n") + result = runner.invoke(app, ["onboard"], input="n\n") assert result.exit_code == 0 assert "contextWindowTokens" in result.stdout @@ -127,7 +127,7 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) from typer.testing import CliRunner from nanobot.cli.commands import app runner = CliRunner() - result = runner.invoke(app, ["onboard", "--no-interactive"], input="n\n") + result = runner.invoke(app, ["onboard"], input="n\n") assert result.exit_code == 0 saved = json.loads(config_path.read_text(encoding="utf-8")) diff --git a/tests/test_onboard_logic.py b/tests/test_onboard_logic.py index 5ac08a5..fbcb4fb 100644 --- a/tests/test_onboard_logic.py +++ b/tests/test_onboard_logic.py @@ -401,7 +401,7 @@ class TestConfigurePydanticModelDrafts: if token == "first": return choices[0] if token == "done": - return "✓ Done" + return "[Done]" if token == "back": return _BACK_PRESSED return token @@ -461,9 +461,9 @@ class TestRunOnboardExitBehavior: responses = iter( [ - "🤖 Configure Agent Settings", + "[A] Agent Settings", KeyboardInterrupt(), - "🗑️ Exit Without Saving", + "[X] Exit Without Saving", ] ) @@ -479,12 +479,13 @@ class TestRunOnboardExitBehavior: def fake_select(*_args, **_kwargs): return FakePrompt(next(responses)) - def fake_configure_agents(config): - config.agents.defaults.model = "test/provider-model" + def fake_configure_general_settings(config, section): + if section == "Agent Settings": + config.agents.defaults.model = "test/provider-model" monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None) - monkeypatch.setattr(onboard_wizard.questionary, "select", fake_select) - monkeypatch.setattr(onboard_wizard, "_configure_agents", fake_configure_agents) + monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select)) + monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings) result = run_onboard(initial_config=initial_config)