From 336961372793c8c73c5c7172b7cb13b1f29f8fe0 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Sun, 15 Mar 2026 19:14:17 +0800 Subject: [PATCH] feat(onboard): add model autocomplete and auto-fill context window - Add model_info.py module with litellm-based model lookup - Provide autocomplete suggestions for model names - Auto-fill context_window_tokens when model changes (only at default) - Add "Get recommended value" option for manual context lookup - Dynamically load provider keywords from registry (no hardcoding) Resolves #2018 --- nanobot/cli/model_info.py | 226 ++++++++++++++++++++++++++++++++++ nanobot/cli/onboard_wizard.py | 158 ++++++++++++++++++++++++ 2 files changed, 384 insertions(+) create mode 100644 nanobot/cli/model_info.py diff --git a/nanobot/cli/model_info.py b/nanobot/cli/model_info.py new file mode 100644 index 0000000..2bcd4af --- /dev/null +++ b/nanobot/cli/model_info.py @@ -0,0 +1,226 @@ +"""Model information helpers for the onboard wizard. + +Provides model context window lookup and autocomplete suggestions using litellm. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Any + +import litellm + + +@lru_cache(maxsize=1) +def _get_model_cost_map() -> dict[str, Any]: + """Get litellm's model cost map (cached).""" + return getattr(litellm, "model_cost", {}) + + +@lru_cache(maxsize=1) +def get_all_models() -> list[str]: + """Get all known model names from litellm. + """ + models = set() + + # From model_cost (has pricing info) + cost_map = _get_model_cost_map() + for k in cost_map.keys(): + if k != "sample_spec": + models.add(k) + + # From models_by_provider (more complete provider coverage) + for provider_models in getattr(litellm, "models_by_provider", {}).values(): + if isinstance(provider_models, (set, list)): + models.update(provider_models) + + return sorted(models) + + +def _normalize_model_name(model: str) -> str: + """Normalize model name for comparison.""" + return model.lower().replace("-", "_").replace(".", "") + + +def find_model_info(model_name: str) -> dict[str, Any] | None: + """Find model info with fuzzy matching. + + Args: + model_name: Model name in any common format + + Returns: + Model info dict or None if not found + """ + cost_map = _get_model_cost_map() + if not cost_map: + return None + + # Direct match + if model_name in cost_map: + return cost_map[model_name] + + # Extract base name (without provider prefix) + base_name = model_name.split("/")[-1] if "/" in model_name else model_name + base_normalized = _normalize_model_name(base_name) + + candidates = [] + + for key, info in cost_map.items(): + if key == "sample_spec": + continue + + key_base = key.split("/")[-1] if "/" in key else key + key_base_normalized = _normalize_model_name(key_base) + + # Score the match + score = 0 + + # Exact base name match (highest priority) + if base_normalized == key_base_normalized: + score = 100 + # Base name contains model + elif base_normalized in key_base_normalized: + score = 80 + # Model contains base name + elif key_base_normalized in base_normalized: + score = 70 + # Partial match + elif base_normalized[:10] in key_base_normalized: + score = 50 + + if score > 0: + # Prefer models with max_input_tokens + if info.get("max_input_tokens"): + score += 10 + candidates.append((score, key, info)) + + if not candidates: + return None + + # Return the best match + candidates.sort(key=lambda x: (-x[0], x[1])) + return candidates[0][2] + + +def get_model_context_limit(model: str, provider: str = "auto") -> int | None: + """Get the maximum input context tokens for a model. + + Args: + model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o") + provider: Provider name for informational purposes (not yet used for filtering) + + Returns: + Maximum input tokens, or None if unknown + + Note: + The provider parameter is currently informational only. Future versions may + use it to prefer provider-specific model variants in the lookup. + """ + # First try fuzzy search in model_cost (has more accurate max_input_tokens) + info = find_model_info(model) + if info: + # Prefer max_input_tokens (this is what we want for context window) + max_input = info.get("max_input_tokens") + if max_input and isinstance(max_input, int): + return max_input + + # Fall back to litellm's get_max_tokens (returns max_output_tokens typically) + try: + result = litellm.get_max_tokens(model) + if result and result > 0: + return result + except (KeyError, ValueError, AttributeError): + # Model not found in litellm's database or invalid response + pass + + # Last resort: use max_tokens from model_cost + if info: + max_tokens = info.get("max_tokens") + if max_tokens and isinstance(max_tokens, int): + return max_tokens + + return None + + +@lru_cache(maxsize=1) +def _get_provider_keywords() -> dict[str, list[str]]: + """Build provider keywords mapping from nanobot's provider registry. + + Returns: + Dict mapping provider name to list of keywords for model filtering. + """ + try: + from nanobot.providers.registry import PROVIDERS + + mapping = {} + for spec in PROVIDERS: + if spec.keywords: + mapping[spec.name] = list(spec.keywords) + return mapping + except ImportError: + return {} + + +def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]: + """Get autocomplete suggestions for model names. + + Args: + partial: Partial model name typed by user + provider: Provider name for filtering (e.g., "openrouter", "minimax") + limit: Maximum number of suggestions to return + + Returns: + List of matching model names + """ + all_models = get_all_models() + if not all_models: + return [] + + partial_lower = partial.lower() + partial_normalized = _normalize_model_name(partial) + + # Get provider keywords from registry + provider_keywords = _get_provider_keywords() + + # Filter by provider if specified + allowed_keywords = None + if provider and provider != "auto": + allowed_keywords = provider_keywords.get(provider.lower()) + + matches = [] + + for model in all_models: + model_lower = model.lower() + + # Apply provider filter + if allowed_keywords: + if not any(kw in model_lower for kw in allowed_keywords): + continue + + # Match against partial input + if not partial: + matches.append(model) + continue + + if partial_lower in model_lower: + # Score by position of match (earlier = better) + pos = model_lower.find(partial_lower) + score = 100 - pos + matches.append((score, model)) + elif partial_normalized in _normalize_model_name(model): + score = 50 + matches.append((score, model)) + + # Sort by score if we have scored matches + if matches and isinstance(matches[0], tuple): + matches.sort(key=lambda x: (-x[0], x[1])) + matches = [m[1] for m in matches] + else: + matches.sort() + + return matches[:limit] + + +def format_token_count(tokens: int) -> str: + """Format token count for display (e.g., 200000 -> '200,000').""" + return f"{tokens:,}" diff --git a/nanobot/cli/onboard_wizard.py b/nanobot/cli/onboard_wizard.py index e755fa1..debd544 100644 --- a/nanobot/cli/onboard_wizard.py +++ b/nanobot/cli/onboard_wizard.py @@ -11,6 +11,11 @@ from rich.console import Console from rich.panel import Panel from rich.table import Table +from nanobot.cli.model_info import ( + format_token_count, + get_model_context_limit, + get_model_suggestions, +) from nanobot.config.loader import get_config_path, load_config from nanobot.config.schema import Config @@ -224,6 +229,109 @@ def _input_with_existing( # --- Pydantic Model Configuration --- +def _get_current_provider(model: BaseModel) -> str: + """Get the current provider setting from a model (if available).""" + if hasattr(model, "provider"): + return getattr(model, "provider", "auto") or "auto" + return "auto" + + +def _input_model_with_autocomplete( + display_name: str, current: Any, provider: str +) -> str | None: + """Get model input with autocomplete suggestions. + + """ + from prompt_toolkit.completion import Completer, Completion + + default = str(current) if current else "" + + class DynamicModelCompleter(Completer): + """Completer that dynamically fetches model suggestions.""" + + def __init__(self, provider_name: str): + self.provider = provider_name + + def get_completions(self, document, complete_event): + text = document.text_before_cursor + suggestions = get_model_suggestions(text, provider=self.provider, limit=50) + for model in suggestions: + # Skip if model doesn't contain the typed text + if text.lower() not in model.lower(): + continue + yield Completion( + model, + start_position=-len(text), + display=model, + ) + + value = questionary.autocomplete( + f"{display_name}:", + choices=[""], # Placeholder, actual completions from completer + completer=DynamicModelCompleter(provider), + default=default, + qmark="→", + ).ask() + + return value if value else None + + +def _input_context_window_with_recommendation( + display_name: str, current: Any, model_obj: BaseModel +) -> int | None: + """Get context window input with option to fetch recommended value.""" + current_val = current if current else "" + + choices = ["Enter new value"] + if current_val: + choices.append("Keep existing value") + choices.append("🔍 Get recommended value") + + choice = questionary.select( + display_name, + choices=choices, + default="Enter new value", + ).ask() + + if choice is None: + return None + + if choice == "Keep existing value": + return None + + if choice == "🔍 Get recommended value": + # Get the model name from the model object + model_name = getattr(model_obj, "model", None) + if not model_name: + console.print("[yellow]⚠ Please configure the model field first[/yellow]") + return None + + provider = _get_current_provider(model_obj) + context_limit = get_model_context_limit(model_name, provider) + + if context_limit: + console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]") + return context_limit + else: + console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]") + # Fall through to manual input + + # Manual input + value = questionary.text( + f"{display_name}:", + default=str(current_val) if current_val else "", + ).ask() + + if value is None or value == "": + return None + + try: + return int(value) + except ValueError: + console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]") + return None + + def _configure_pydantic_model( model: BaseModel, display_name: str, @@ -289,6 +397,23 @@ def _configure_pydantic_model( _configure_pydantic_model(nested_model, field_display) continue + # Special handling for model field (autocomplete) + if field_name == "model": + provider = _get_current_provider(model) + new_value = _input_model_with_autocomplete(field_display, current_value, provider) + if new_value is not None and new_value != current_value: + setattr(model, field_name, new_value) + # Auto-fill context_window_tokens if it's at default value + _try_auto_fill_context_window(model, new_value) + continue + + # Special handling for context_window_tokens field + if field_name == "context_window_tokens": + new_value = _input_context_window_with_recommendation(field_display, current_value, model) + if new_value is not None: + setattr(model, field_name, new_value) + continue + if field_type == "bool": new_value = _input_bool(field_display, current_value) if new_value is not None: @@ -299,6 +424,39 @@ def _configure_pydantic_model( setattr(model, field_name, new_value) +def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None: + """Try to auto-fill context_window_tokens if it's at default value. + + Note: + This function imports AgentDefaults from nanobot.config.schema to get + the default context_window_tokens value. If the schema changes, this + coupling needs to be updated accordingly. + """ + # Check if context_window_tokens field exists + if not hasattr(model, "context_window_tokens"): + return + + current_context = getattr(model, "context_window_tokens", None) + + # Check if current value is the default (65536) + # We only auto-fill if the user hasn't changed it from default + from nanobot.config.schema import AgentDefaults + + default_context = AgentDefaults.model_fields["context_window_tokens"].default + + if current_context != default_context: + return # User has customized it, don't override + + provider = _get_current_provider(model) + context_limit = get_model_context_limit(new_model_name, provider) + + if context_limit: + setattr(model, "context_window_tokens", context_limit) + console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]") + else: + console.print("[dim]ℹ Could not auto-fill context window (model not in database)[/dim]") + + # --- Provider Configuration ---