feat(onboard): add model autocomplete and auto-fill context window

- Add model_info.py module with litellm-based model lookup
- Provide autocomplete suggestions for model names
- Auto-fill context_window_tokens when model changes (only at default)
- Add "Get recommended value" option for manual context lookup
- Dynamically load provider keywords from registry (no hardcoding)

Resolves #2018
This commit is contained in:
chengyongru
2026-03-15 19:14:17 +08:00
committed by Xubin Ren
parent f127af0481
commit 3369613727
2 changed files with 384 additions and 0 deletions

226
nanobot/cli/model_info.py Normal file
View File

@@ -0,0 +1,226 @@
"""Model information helpers for the onboard wizard.
Provides model context window lookup and autocomplete suggestions using litellm.
"""
from __future__ import annotations
from functools import lru_cache
from typing import Any
import litellm
@lru_cache(maxsize=1)
def _get_model_cost_map() -> dict[str, Any]:
"""Get litellm's model cost map (cached)."""
return getattr(litellm, "model_cost", {})
@lru_cache(maxsize=1)
def get_all_models() -> list[str]:
"""Get all known model names from litellm.
"""
models = set()
# From model_cost (has pricing info)
cost_map = _get_model_cost_map()
for k in cost_map.keys():
if k != "sample_spec":
models.add(k)
# From models_by_provider (more complete provider coverage)
for provider_models in getattr(litellm, "models_by_provider", {}).values():
if isinstance(provider_models, (set, list)):
models.update(provider_models)
return sorted(models)
def _normalize_model_name(model: str) -> str:
"""Normalize model name for comparison."""
return model.lower().replace("-", "_").replace(".", "")
def find_model_info(model_name: str) -> dict[str, Any] | None:
"""Find model info with fuzzy matching.
Args:
model_name: Model name in any common format
Returns:
Model info dict or None if not found
"""
cost_map = _get_model_cost_map()
if not cost_map:
return None
# Direct match
if model_name in cost_map:
return cost_map[model_name]
# Extract base name (without provider prefix)
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
base_normalized = _normalize_model_name(base_name)
candidates = []
for key, info in cost_map.items():
if key == "sample_spec":
continue
key_base = key.split("/")[-1] if "/" in key else key
key_base_normalized = _normalize_model_name(key_base)
# Score the match
score = 0
# Exact base name match (highest priority)
if base_normalized == key_base_normalized:
score = 100
# Base name contains model
elif base_normalized in key_base_normalized:
score = 80
# Model contains base name
elif key_base_normalized in base_normalized:
score = 70
# Partial match
elif base_normalized[:10] in key_base_normalized:
score = 50
if score > 0:
# Prefer models with max_input_tokens
if info.get("max_input_tokens"):
score += 10
candidates.append((score, key, info))
if not candidates:
return None
# Return the best match
candidates.sort(key=lambda x: (-x[0], x[1]))
return candidates[0][2]
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
"""Get the maximum input context tokens for a model.
Args:
model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o")
provider: Provider name for informational purposes (not yet used for filtering)
Returns:
Maximum input tokens, or None if unknown
Note:
The provider parameter is currently informational only. Future versions may
use it to prefer provider-specific model variants in the lookup.
"""
# First try fuzzy search in model_cost (has more accurate max_input_tokens)
info = find_model_info(model)
if info:
# Prefer max_input_tokens (this is what we want for context window)
max_input = info.get("max_input_tokens")
if max_input and isinstance(max_input, int):
return max_input
# Fall back to litellm's get_max_tokens (returns max_output_tokens typically)
try:
result = litellm.get_max_tokens(model)
if result and result > 0:
return result
except (KeyError, ValueError, AttributeError):
# Model not found in litellm's database or invalid response
pass
# Last resort: use max_tokens from model_cost
if info:
max_tokens = info.get("max_tokens")
if max_tokens and isinstance(max_tokens, int):
return max_tokens
return None
@lru_cache(maxsize=1)
def _get_provider_keywords() -> dict[str, list[str]]:
"""Build provider keywords mapping from nanobot's provider registry.
Returns:
Dict mapping provider name to list of keywords for model filtering.
"""
try:
from nanobot.providers.registry import PROVIDERS
mapping = {}
for spec in PROVIDERS:
if spec.keywords:
mapping[spec.name] = list(spec.keywords)
return mapping
except ImportError:
return {}
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
"""Get autocomplete suggestions for model names.
Args:
partial: Partial model name typed by user
provider: Provider name for filtering (e.g., "openrouter", "minimax")
limit: Maximum number of suggestions to return
Returns:
List of matching model names
"""
all_models = get_all_models()
if not all_models:
return []
partial_lower = partial.lower()
partial_normalized = _normalize_model_name(partial)
# Get provider keywords from registry
provider_keywords = _get_provider_keywords()
# Filter by provider if specified
allowed_keywords = None
if provider and provider != "auto":
allowed_keywords = provider_keywords.get(provider.lower())
matches = []
for model in all_models:
model_lower = model.lower()
# Apply provider filter
if allowed_keywords:
if not any(kw in model_lower for kw in allowed_keywords):
continue
# Match against partial input
if not partial:
matches.append(model)
continue
if partial_lower in model_lower:
# Score by position of match (earlier = better)
pos = model_lower.find(partial_lower)
score = 100 - pos
matches.append((score, model))
elif partial_normalized in _normalize_model_name(model):
score = 50
matches.append((score, model))
# Sort by score if we have scored matches
if matches and isinstance(matches[0], tuple):
matches.sort(key=lambda x: (-x[0], x[1]))
matches = [m[1] for m in matches]
else:
matches.sort()
return matches[:limit]
def format_token_count(tokens: int) -> str:
"""Format token count for display (e.g., 200000 -> '200,000')."""
return f"{tokens:,}"

View File

@@ -11,6 +11,11 @@ from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from nanobot.cli.model_info import (
format_token_count,
get_model_context_limit,
get_model_suggestions,
)
from nanobot.config.loader import get_config_path, load_config
from nanobot.config.schema import Config
@@ -224,6 +229,109 @@ def _input_with_existing(
# --- Pydantic Model Configuration ---
def _get_current_provider(model: BaseModel) -> str:
"""Get the current provider setting from a model (if available)."""
if hasattr(model, "provider"):
return getattr(model, "provider", "auto") or "auto"
return "auto"
def _input_model_with_autocomplete(
display_name: str, current: Any, provider: str
) -> str | None:
"""Get model input with autocomplete suggestions.
"""
from prompt_toolkit.completion import Completer, Completion
default = str(current) if current else ""
class DynamicModelCompleter(Completer):
"""Completer that dynamically fetches model suggestions."""
def __init__(self, provider_name: str):
self.provider = provider_name
def get_completions(self, document, complete_event):
text = document.text_before_cursor
suggestions = get_model_suggestions(text, provider=self.provider, limit=50)
for model in suggestions:
# Skip if model doesn't contain the typed text
if text.lower() not in model.lower():
continue
yield Completion(
model,
start_position=-len(text),
display=model,
)
value = questionary.autocomplete(
f"{display_name}:",
choices=[""], # Placeholder, actual completions from completer
completer=DynamicModelCompleter(provider),
default=default,
qmark="",
).ask()
return value if value else None
def _input_context_window_with_recommendation(
display_name: str, current: Any, model_obj: BaseModel
) -> int | None:
"""Get context window input with option to fetch recommended value."""
current_val = current if current else ""
choices = ["Enter new value"]
if current_val:
choices.append("Keep existing value")
choices.append("🔍 Get recommended value")
choice = questionary.select(
display_name,
choices=choices,
default="Enter new value",
).ask()
if choice is None:
return None
if choice == "Keep existing value":
return None
if choice == "🔍 Get recommended value":
# Get the model name from the model object
model_name = getattr(model_obj, "model", None)
if not model_name:
console.print("[yellow]⚠ Please configure the model field first[/yellow]")
return None
provider = _get_current_provider(model_obj)
context_limit = get_model_context_limit(model_name, provider)
if context_limit:
console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]")
return context_limit
else:
console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]")
# Fall through to manual input
# Manual input
value = questionary.text(
f"{display_name}:",
default=str(current_val) if current_val else "",
).ask()
if value is None or value == "":
return None
try:
return int(value)
except ValueError:
console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]")
return None
def _configure_pydantic_model(
model: BaseModel,
display_name: str,
@@ -289,6 +397,23 @@ def _configure_pydantic_model(
_configure_pydantic_model(nested_model, field_display)
continue
# Special handling for model field (autocomplete)
if field_name == "model":
provider = _get_current_provider(model)
new_value = _input_model_with_autocomplete(field_display, current_value, provider)
if new_value is not None and new_value != current_value:
setattr(model, field_name, new_value)
# Auto-fill context_window_tokens if it's at default value
_try_auto_fill_context_window(model, new_value)
continue
# Special handling for context_window_tokens field
if field_name == "context_window_tokens":
new_value = _input_context_window_with_recommendation(field_display, current_value, model)
if new_value is not None:
setattr(model, field_name, new_value)
continue
if field_type == "bool":
new_value = _input_bool(field_display, current_value)
if new_value is not None:
@@ -299,6 +424,39 @@ def _configure_pydantic_model(
setattr(model, field_name, new_value)
def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None:
"""Try to auto-fill context_window_tokens if it's at default value.
Note:
This function imports AgentDefaults from nanobot.config.schema to get
the default context_window_tokens value. If the schema changes, this
coupling needs to be updated accordingly.
"""
# Check if context_window_tokens field exists
if not hasattr(model, "context_window_tokens"):
return
current_context = getattr(model, "context_window_tokens", None)
# Check if current value is the default (65536)
# We only auto-fill if the user hasn't changed it from default
from nanobot.config.schema import AgentDefaults
default_context = AgentDefaults.model_fields["context_window_tokens"].default
if current_context != default_context:
return # User has customized it, don't override
provider = _get_current_provider(model)
context_limit = get_model_context_limit(new_model_name, provider)
if context_limit:
setattr(model, "context_window_tokens", context_limit)
console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]")
else:
console.print("[dim] Could not auto-fill context window (model not in database)[/dim]")
# --- Provider Configuration ---