diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index dfb4a25..efea399 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -300,16 +300,22 @@ def onboard( console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") else: config = _apply_workspace_override(Config()) - save_config(config, config_path) - console.print(f"[green]✓[/green] Created config at {config_path}") + # In interactive mode, don't save yet - the wizard will handle saving if should_save=True + if not interactive: + save_config(config, config_path) + console.print(f"[green]✓[/green] Created config at {config_path}") # Run interactive wizard if enabled if interactive: from nanobot.cli.onboard_wizard import run_onboard try: - # Pass the config with workspace override applied as initial config - config = run_onboard(initial_config=config) + result = run_onboard(initial_config=config) + if not result.should_save: + console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]") + return + + config = result.config save_config(config, config_path) console.print(f"[green]✓[/green] Config saved at {config_path}") except Exception as e: diff --git a/nanobot/cli/onboard_wizard.py b/nanobot/cli/onboard_wizard.py index a4c06f3..ea41bc8 100644 --- a/nanobot/cli/onboard_wizard.py +++ b/nanobot/cli/onboard_wizard.py @@ -2,7 +2,8 @@ import json import types -from typing import Any, Callable, get_args, get_origin +from dataclasses import dataclass +from typing import Any, get_args, get_origin import questionary from loguru import logger @@ -21,6 +22,14 @@ from nanobot.config.schema import Config console = Console() + +@dataclass +class OnboardResult: + """Result of an onboarding session.""" + + config: Config + should_save: bool + # --- Field Hints for Select Fields --- # Maps field names to (choices, hint_text) # To add a new select field with hints, add an entry: @@ -458,83 +467,88 @@ def _configure_pydantic_model( display_name: str, *, skip_fields: set[str] | None = None, - finalize_hook: Callable | None = None, -) -> None: - """Configure a Pydantic model interactively.""" +) -> BaseModel | None: + """Configure a Pydantic model interactively. + + Returns the updated model only when the user explicitly selects "Done". + Back and cancel actions discard the section draft. + """ skip_fields = skip_fields or set() + working_model = model.model_copy(deep=True) fields = [] - for field_name, field_info in type(model).model_fields.items(): + for field_name, field_info in type(working_model).model_fields.items(): if field_name in skip_fields: continue fields.append((field_name, field_info)) if not fields: console.print(f"[dim]{display_name}: No configurable fields[/dim]") - return + return working_model def get_choices() -> list[str]: choices = [] for field_name, field_info in fields: - value = getattr(model, field_name, None) + value = getattr(working_model, field_name, None) display = _get_field_display_name(field_name, field_info) formatted = _format_value(value, rich=False) choices.append(f"{display}: {formatted}") return choices + ["✓ Done"] while True: - _show_config_panel(display_name, model, fields) + _show_config_panel(display_name, working_model, fields) choices = get_choices() answer = _select_with_back("Select field to configure:", choices) - if answer is _BACK_PRESSED: - # User pressed Escape or Left arrow - go back - if finalize_hook: - finalize_hook(model) - break + if answer is _BACK_PRESSED or answer is None: + return None - if answer == "✓ Done" or answer is None: - if finalize_hook: - finalize_hook(model) - break + if answer == "✓ Done": + return working_model field_idx = next((i for i, c in enumerate(choices) if c == answer), -1) if field_idx < 0 or field_idx >= len(fields): - break + return None field_name, field_info = fields[field_idx] - current_value = getattr(model, field_name, None) + current_value = getattr(working_model, field_name, None) field_type, _ = _get_field_type_info(field_info) field_display = _get_field_display_name(field_name, field_info) if field_type == "model": nested_model = current_value + created_nested_model = nested_model is None if nested_model is None: _, nested_cls = _get_field_type_info(field_info) if nested_cls: nested_model = nested_cls() - setattr(model, field_name, nested_model) if nested_model and isinstance(nested_model, BaseModel): - _configure_pydantic_model(nested_model, field_display) + updated_nested_model = _configure_pydantic_model(nested_model, field_display) + if updated_nested_model is not None: + setattr(working_model, field_name, updated_nested_model) + elif created_nested_model: + setattr(working_model, field_name, None) continue # Special handling for model field (autocomplete) if field_name == "model": - provider = _get_current_provider(model) + provider = _get_current_provider(working_model) new_value = _input_model_with_autocomplete(field_display, current_value, provider) if new_value is not None and new_value != current_value: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) # Auto-fill context_window_tokens if it's at default value - _try_auto_fill_context_window(model, new_value) + _try_auto_fill_context_window(working_model, new_value) continue # Special handling for context_window_tokens field if field_name == "context_window_tokens": - new_value = _input_context_window_with_recommendation(field_display, current_value, model) + new_value = _input_context_window_with_recommendation( + field_display, current_value, working_model + ) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) continue # Special handling for select fields with hints (e.g., reasoning_effort) @@ -542,23 +556,25 @@ def _configure_pydantic_model( choices_list, hint = _SELECT_FIELD_HINTS[field_name] select_choices = choices_list + ["(clear/unset)"] console.print(f"[dim] Hint: {hint}[/dim]") - new_value = _select_with_back(field_display, select_choices, default=current_value or select_choices[0]) + new_value = _select_with_back( + field_display, select_choices, default=current_value or select_choices[0] + ) if new_value is _BACK_PRESSED: continue if new_value == "(clear/unset)": - setattr(model, field_name, None) + setattr(working_model, field_name, None) elif new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) continue if field_type == "bool": new_value = _input_bool(field_display, current_value) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) else: new_value = _input_with_existing(field_display, current_value, field_type) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None: @@ -637,10 +653,12 @@ def _configure_provider(config: Config, provider_name: str) -> None: if default_api_base and not provider_config.api_base: provider_config.api_base = default_api_base - _configure_pydantic_model( + updated_provider = _configure_pydantic_model( provider_config, display_name, ) + if updated_provider is not None: + setattr(config.providers, provider_name, updated_provider) def _configure_providers(config: Config) -> None: @@ -747,15 +765,13 @@ def _configure_channel(config: Config, channel_name: str) -> None: model = config_cls.model_validate(channel_dict) if channel_dict else config_cls() - def finalize(model: BaseModel): - new_dict = model.model_dump(by_alias=True, exclude_none=True) - setattr(config.channels, channel_name, new_dict) - - _configure_pydantic_model( + updated_channel = _configure_pydantic_model( model, display_name, - finalize_hook=finalize, ) + if updated_channel is not None: + new_dict = updated_channel.model_dump(by_alias=True, exclude_none=True) + setattr(config.channels, channel_name, new_dict) def _configure_channels(config: Config) -> None: @@ -798,13 +814,25 @@ def _configure_general_settings(config: Config, section: str) -> None: model, display_name = section_map[section] if section == "Tools": - _configure_pydantic_model( + updated_model = _configure_pydantic_model( model, display_name, skip_fields={"mcp_servers"}, ) else: - _configure_pydantic_model(model, display_name) + updated_model = _configure_pydantic_model(model, display_name) + + if updated_model is None: + return + + if section == "Agent Settings": + config.agents.defaults = updated_model + elif section == "Gateway": + config.gateway = updated_model + elif section == "Tools": + config.tools = updated_model + elif section == "Channel Common": + config.channels = updated_model def _configure_agents(config: Config) -> None: @@ -938,7 +966,35 @@ def _show_summary(config: Config) -> None: # --- Main Entry Point --- -def run_onboard(initial_config: Config | None = None) -> Config: +def _has_unsaved_changes(original: Config, current: Config) -> bool: + """Return True when the onboarding session has committed changes.""" + return original.model_dump(by_alias=True) != current.model_dump(by_alias=True) + + +def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str: + """Resolve how to leave the main menu.""" + if not has_unsaved_changes: + return "discard" + + answer = questionary.select( + "You have unsaved changes. What would you like to do?", + choices=[ + "💾 Save and Exit", + "🗑️ Exit Without Saving", + "↩ Resume Editing", + ], + default="↩ Resume Editing", + qmark="→", + ).ask() + + if answer == "💾 Save and Exit": + return "save" + if answer == "🗑️ Exit Without Saving": + return "discard" + return "resume" + + +def run_onboard(initial_config: Config | None = None) -> OnboardResult: """Run the interactive onboarding questionnaire. Args: @@ -946,50 +1002,59 @@ def run_onboard(initial_config: Config | None = None) -> Config: If None, loads from config file or creates new default. """ if initial_config is not None: - config = initial_config + base_config = initial_config.model_copy(deep=True) else: config_path = get_config_path() if config_path.exists(): - config = load_config() + base_config = load_config() else: - config = Config() + base_config = Config() + + original_config = base_config.model_copy(deep=True) + config = base_config.model_copy(deep=True) while True: - try: - _show_main_menu_header() + _show_main_menu_header() + try: answer = questionary.select( "What would you like to configure?", choices=[ - "🔌 Configure LLM Provider", - "💬 Configure Chat Channel", - "🤖 Configure Agent Settings", - "🌐 Configure Gateway", - "🔧 Configure Tools", + "🔌 LLM Provider", + "💬 Chat Channel", + "🤖 Agent Settings", + "🌐 Gateway", + "🔧 Tools", "📋 View Configuration Summary", "💾 Save and Exit", + "🗑️ Exit Without Saving", ], qmark="→", ).ask() - - if answer == "🔌 Configure LLM Provider": - _configure_providers(config) - elif answer == "💬 Configure Chat Channel": - _configure_channels(config) - elif answer == "🤖 Configure Agent Settings": - _configure_agents(config) - elif answer == "🌐 Configure Gateway": - _configure_gateway(config) - elif answer == "🔧 Configure Tools": - _configure_tools(config) - elif answer == "📋 View Configuration Summary": - _show_summary(config) - elif answer == "💾 Save and Exit": - break except KeyboardInterrupt: - console.print( - "\n\n[yellow]Operation cancelled. Use 'Save and Exit' to save changes.[/yellow]" - ) - break + answer = None - return config + if answer is None: + action = _prompt_main_menu_exit(_has_unsaved_changes(original_config, config)) + if action == "save": + return OnboardResult(config=config, should_save=True) + if action == "discard": + return OnboardResult(config=original_config, should_save=False) + continue + + if answer == "🔌 LLM Provider": + _configure_providers(config) + elif answer == "💬 Chat Channel": + _configure_channels(config) + elif answer == "🤖 Agent Settings": + _configure_agents(config) + elif answer == "🌐 Gateway": + _configure_gateway(config) + elif answer == "🔧 Tools": + _configure_tools(config) + elif answer == "📋 View Configuration Summary": + _show_summary(config) + elif answer == "💾 Save and Exit": + return OnboardResult(config=config, should_save=True) + elif answer == "🗑️ Exit Without Saving": + return OnboardResult(config=original_config, should_save=False) diff --git a/tests/test_commands.py b/tests/test_commands.py index d374d0c..38af553 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -15,7 +15,7 @@ from nanobot.providers.registry import find_by_model runner = CliRunner() -class _StopGateway(RuntimeError): +class _StopGatewayError(RuntimeError): pass @@ -133,6 +133,24 @@ def test_onboard_help_shows_workspace_and_config_options(): assert "--dir" not in stripped_output +def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch): + config_file, workspace_dir, _ = mock_paths + + from nanobot.cli.onboard_wizard import OnboardResult + + monkeypatch.setattr( + "nanobot.cli.onboard_wizard.run_onboard", + lambda initial_config: OnboardResult(config=initial_config, should_save=False), + ) + + result = runner.invoke(app, ["onboard"]) + + assert result.exit_code == 0 + assert "No changes were saved" in result.stdout + assert not config_file.exists() + assert not workspace_dir.exists() + + def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch): config_path = tmp_path / "instance" / "config.json" workspace_path = tmp_path / "workspace" @@ -438,12 +456,12 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["config_path"] == config_file.resolve() assert seen["workspace"] == Path(config.agents.defaults.workspace) @@ -466,7 +484,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke( @@ -474,7 +492,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) ["gateway", "--config", str(config_file), "--workspace", str(override)], ) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["workspace"] == override assert config.workspace_path == override @@ -492,12 +510,12 @@ def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Pat monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "memoryWindow" in result.stdout assert "contextWindowTokens" in result.stdout @@ -521,13 +539,13 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path - raise _StopGateway("stop") + raise _StopGatewayError("stop") monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" @@ -544,12 +562,12 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "port 18791" in result.stdout @@ -566,10 +584,10 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "port 18792" in result.stdout diff --git a/tests/test_onboard_logic.py b/tests/test_onboard_logic.py index a7c8d96..5ac08a5 100644 --- a/tests/test_onboard_logic.py +++ b/tests/test_onboard_logic.py @@ -7,18 +7,24 @@ without testing the interactive UI components. import json from pathlib import Path from types import SimpleNamespace -from typing import Any +from typing import Any, cast import pytest from pydantic import BaseModel, Field +from nanobot.cli import onboard_wizard + # Import functions to test from nanobot.cli.commands import _merge_missing_defaults from nanobot.cli.onboard_wizard import ( + _BACK_PRESSED, + _configure_pydantic_model, _format_value, _get_field_display_name, _get_field_type_info, + run_onboard, ) +from nanobot.config.schema import Config from nanobot.utils.helpers import sync_workspace_templates @@ -371,3 +377,116 @@ class TestProviderChannelInfo: for provider_name, value in info.items(): assert isinstance(value, tuple) assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var) + + +class _SimpleDraftModel(BaseModel): + api_key: str = "" + + +class _NestedDraftModel(BaseModel): + api_key: str = "" + + +class _OuterDraftModel(BaseModel): + nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel) + + +class TestConfigurePydanticModelDrafts: + @staticmethod + def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"): + sequence = iter(tokens) + + def fake_select(_prompt, choices, default=None): + token = next(sequence) + if token == "first": + return choices[0] + if token == "done": + return "✓ Done" + if token == "back": + return _BACK_PRESSED + return token + + monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select) + monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value + ) + + def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch): + model = _SimpleDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "back"]) + + result = _configure_pydantic_model(model, "Simple") + + assert result is None + assert model.api_key == "" + + def test_completing_section_returns_updated_draft(self, monkeypatch): + model = _SimpleDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "done"]) + + result = _configure_pydantic_model(model, "Simple") + + assert result is not None + updated = cast(_SimpleDraftModel, result) + assert updated.api_key == "secret" + assert model.api_key == "" + + def test_nested_section_back_discards_nested_edits(self, monkeypatch): + model = _OuterDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"]) + + result = _configure_pydantic_model(model, "Outer") + + assert result is not None + updated = cast(_OuterDraftModel, result) + assert updated.nested.api_key == "" + assert model.nested.api_key == "" + + def test_nested_section_done_commits_nested_edits(self, monkeypatch): + model = _OuterDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"]) + + result = _configure_pydantic_model(model, "Outer") + + assert result is not None + updated = cast(_OuterDraftModel, result) + assert updated.nested.api_key == "secret" + assert model.nested.api_key == "" + + +class TestRunOnboardExitBehavior: + def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch): + initial_config = Config() + + responses = iter( + [ + "🤖 Configure Agent Settings", + KeyboardInterrupt(), + "🗑️ Exit Without Saving", + ] + ) + + class FakePrompt: + def __init__(self, response): + self.response = response + + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_configure_agents(config): + config.agents.defaults.model = "test/provider-model" + + monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None) + monkeypatch.setattr(onboard_wizard.questionary, "select", fake_select) + monkeypatch.setattr(onboard_wizard, "_configure_agents", fake_configure_agents) + + result = run_onboard(initial_config=initial_config) + + assert result.should_save is False + assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)