feat(onboard): add model autocomplete and auto-fill context window
- Add model_info.py module with litellm-based model lookup
- Provide autocomplete suggestions for model names
- Auto-fill context_window_tokens when model changes (only at default)
- Add "Get recommended value" option for manual context lookup
- Dynamically load provider keywords from registry (no hardcoding)

Resolves #2018
This commit is contained in:
226
nanobot/cli/model_info.py
Normal file
226
nanobot/cli/model_info.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Model information helpers for the onboard wizard.
|
||||
|
||||
Provides model context window lookup and autocomplete suggestions using litellm.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _get_model_cost_map() -> dict[str, Any]:
    """Return litellm's model cost map, computed once and memoized."""
    cost_map = getattr(litellm, "model_cost", {})
    return cost_map
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_all_models() -> list[str]:
    """Return a sorted list of every model name known to litellm.

    Combines keys of litellm's cost map (which carries pricing info) with
    litellm's per-provider model lists for broader provider coverage.
    """
    # Cost-map keys, minus litellm's documentation placeholder entry.
    names = {key for key in _get_model_cost_map() if key != "sample_spec"}

    # Per-provider model lists fill in models missing from the cost map.
    by_provider = getattr(litellm, "models_by_provider", {})
    for provider_models in by_provider.values():
        if isinstance(provider_models, (set, list)):
            names.update(provider_models)

    return sorted(names)
|
||||
|
||||
|
||||
def _normalize_model_name(model: str) -> str:
|
||||
"""Normalize model name for comparison."""
|
||||
return model.lower().replace("-", "_").replace(".", "")
|
||||
|
||||
|
||||
def find_model_info(model_name: str) -> dict[str, Any] | None:
    """Find model info in litellm's cost map, with fuzzy matching.

    Args:
        model_name: Model name in any common format, with or without a
            provider prefix (e.g. "anthropic/claude-3.5-sonnet").

    Returns:
        The litellm info dict for the best match, or None if the cost map
        is unavailable, the name is empty, or nothing matches.
    """
    cost_map = _get_model_cost_map()
    if not cost_map or not model_name:
        # Guard the empty-name case explicitly: an empty normalized name
        # would otherwise substring-match every entry in the cost map.
        return None

    # Exact key match wins outright.
    if model_name in cost_map:
        return cost_map[model_name]

    # Compare on the base name, ignoring any "provider/" prefix.
    base_name = model_name.split("/")[-1] if "/" in model_name else model_name
    base_normalized = _normalize_model_name(base_name)

    candidates: list[tuple[int, str, dict[str, Any]]] = []

    for key, info in cost_map.items():
        if key == "sample_spec":  # litellm's documentation placeholder
            continue

        key_base = key.split("/")[-1] if "/" in key else key
        key_base_normalized = _normalize_model_name(key_base)

        # Score the match: exact > containment > loose prefix overlap.
        score = 0
        if base_normalized == key_base_normalized:
            score = 100
        elif base_normalized in key_base_normalized:
            score = 80
        elif key_base_normalized in base_normalized:
            score = 70
        elif base_normalized[:10] in key_base_normalized:
            # Loose fallback: the first 10 normalized chars appear in the key.
            score = 50

        if score > 0:
            # Prefer entries that actually carry max_input_tokens, since
            # callers want the input context window.
            if info.get("max_input_tokens"):
                score += 10
            candidates.append((score, key, info))

    if not candidates:
        return None

    # Highest score first; ties broken alphabetically for determinism.
    candidates.sort(key=lambda item: (-item[0], item[1]))
    return candidates[0][2]
|
||||
|
||||
|
||||
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
    """Return the maximum input context tokens for a model.

    Args:
        model: Model name (e.g. "claude-3.5-sonnet", "gpt-4o").
        provider: Provider name, currently informational only. Future
            versions may use it to prefer provider-specific variants.

    Returns:
        Maximum input tokens, or None if unknown.
    """
    # Fuzzy lookup in model_cost first: its max_input_tokens is the most
    # accurate figure for the input context window.
    info = find_model_info(model)
    if info is not None:
        max_input = info.get("max_input_tokens")
        if isinstance(max_input, int) and max_input:
            return max_input

    # Next, litellm's own lookup (typically reports max_output_tokens).
    try:
        fallback = litellm.get_max_tokens(model)
        if fallback and fallback > 0:
            return fallback
    except (KeyError, ValueError, AttributeError):
        # Model not in litellm's database, or an invalid response.
        pass

    # Last resort: the generic max_tokens field from model_cost.
    if info is not None:
        max_tokens = info.get("max_tokens")
        if isinstance(max_tokens, int) and max_tokens:
            return max_tokens

    return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_provider_keywords() -> dict[str, list[str]]:
|
||||
"""Build provider keywords mapping from nanobot's provider registry.
|
||||
|
||||
Returns:
|
||||
Dict mapping provider name to list of keywords for model filtering.
|
||||
"""
|
||||
try:
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
mapping = {}
|
||||
for spec in PROVIDERS:
|
||||
if spec.keywords:
|
||||
mapping[spec.name] = list(spec.keywords)
|
||||
return mapping
|
||||
except ImportError:
|
||||
return {}
|
||||
|
||||
|
||||
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
    """Get autocomplete suggestions for model names.

    Args:
        partial: Partial model name typed by the user.
        provider: Provider name for filtering (e.g. "openrouter"); "auto"
            or an unrecognized provider disables filtering.
        limit: Maximum number of suggestions to return.

    Returns:
        Matching model names, best matches first.
    """
    all_models = get_all_models()
    if not all_models:
        return []

    partial_lower = partial.lower()
    partial_normalized = _normalize_model_name(partial)

    # Provider keywords from the registry drive optional filtering.
    allowed_keywords = None
    if provider and provider != "auto":
        allowed_keywords = _get_provider_keywords().get(provider.lower())

    scored: list[tuple[int, str]] = []
    unscored: list[str] = []

    for model in all_models:
        model_lower = model.lower()

        # Drop models matching none of the provider's keywords.
        if allowed_keywords and not any(kw in model_lower for kw in allowed_keywords):
            continue

        if not partial:
            # No input yet: every (provider-filtered) model qualifies.
            unscored.append(model)
        elif partial_lower in model_lower:
            # Earlier match position scores higher.
            scored.append((100 - model_lower.find(partial_lower), model))
        elif partial_normalized in _normalize_model_name(model):
            # Fuzzy match after normalization (hyphens/dots ignored).
            scored.append((50, model))

    if scored:
        scored.sort(key=lambda item: (-item[0], item[1]))
        return [name for _, name in scored][:limit]

    unscored.sort()
    return unscored[:limit]
|
||||
|
||||
|
||||
def format_token_count(tokens: int) -> str:
    """Format a token count with thousands separators (200000 -> '200,000')."""
    return format(tokens, ",")
|
||||
@@ -11,6 +11,11 @@ from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from nanobot.cli.model_info import (
|
||||
format_token_count,
|
||||
get_model_context_limit,
|
||||
get_model_suggestions,
|
||||
)
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
@@ -224,6 +229,109 @@ def _input_with_existing(
|
||||
# --- Pydantic Model Configuration ---
|
||||
|
||||
|
||||
def _get_current_provider(model: BaseModel) -> str:
|
||||
"""Get the current provider setting from a model (if available)."""
|
||||
if hasattr(model, "provider"):
|
||||
return getattr(model, "provider", "auto") or "auto"
|
||||
return "auto"
|
||||
|
||||
|
||||
def _input_model_with_autocomplete(
    display_name: str, current: Any, provider: str
) -> str | None:
    """Prompt for a model name with live autocomplete suggestions.

    Returns:
        The entered value, or None when input was empty or cancelled.
    """
    from prompt_toolkit.completion import Completer, Completion

    class DynamicModelCompleter(Completer):
        """Fetches model suggestions on every keystroke."""

        def __init__(self, provider_name: str):
            self.provider = provider_name

        def get_completions(self, document, complete_event):
            typed = document.text_before_cursor
            typed_lower = typed.lower()
            for model in get_model_suggestions(typed, provider=self.provider, limit=50):
                # Only surface suggestions containing the literal typed text.
                # NOTE(review): this also drops normalized fuzzy matches the
                # suggester returned — presumably intentional; verify.
                if typed_lower in model.lower():
                    yield Completion(model, start_position=-len(typed), display=model)

    default = str(current) if current else ""
    answer = questionary.autocomplete(
        f"{display_name}:",
        choices=[""],  # Placeholder, actual completions from completer
        completer=DynamicModelCompleter(provider),
        default=default,
        qmark="→",
    ).ask()

    return answer or None
|
||||
|
||||
|
||||
def _input_context_window_with_recommendation(
    display_name: str, current: Any, model_obj: BaseModel
) -> int | None:
    """Get context window input with an option to fetch the recommended value.

    Args:
        display_name: Field label shown in the prompt.
        current: The field's current value (falsy means unset).
        model_obj: The pydantic model being edited; its "model" and
            "provider" fields drive the recommendation lookup.

    Returns:
        The new token count, or None to keep the existing value.
    """
    current_val = current if current else ""

    choices = ["Enter new value"]
    if current_val:
        choices.append("Keep existing value")
    choices.append("🔍 Get recommended value")

    choice = questionary.select(
        display_name,
        choices=choices,
        default="Enter new value",
    ).ask()

    # None means the prompt was cancelled (e.g. Ctrl-C).
    if choice is None or choice == "Keep existing value":
        return None

    if choice == "🔍 Get recommended value":
        # The recommendation needs the model name to look up.
        model_name = getattr(model_obj, "model", None)
        if not model_name:
            console.print("[yellow]⚠ Please configure the model field first[/yellow]")
            return None

        provider = _get_current_provider(model_obj)
        context_limit = get_model_context_limit(model_name, provider)

        if context_limit:
            console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]")
            return context_limit
        console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]")
        # Fall through to manual input.

    # Manual input.
    value = questionary.text(
        f"{display_name}:",
        default=str(current_val) if current_val else "",
    ).ask()

    if value is None or value == "":
        return None

    try:
        # Accept thousands separators ("200,000" or "200_000"): recommended
        # values are displayed with commas, and users may type them back.
        return int(value.strip().replace(",", ""))
    except ValueError:
        console.print("[yellow]⚠ Invalid number format, value not saved[/yellow]")
        return None
|
||||
|
||||
|
||||
def _configure_pydantic_model(
|
||||
model: BaseModel,
|
||||
display_name: str,
|
||||
@@ -289,6 +397,23 @@ def _configure_pydantic_model(
|
||||
_configure_pydantic_model(nested_model, field_display)
|
||||
continue
|
||||
|
||||
# Special handling for model field (autocomplete)
|
||||
if field_name == "model":
|
||||
provider = _get_current_provider(model)
|
||||
new_value = _input_model_with_autocomplete(field_display, current_value, provider)
|
||||
if new_value is not None and new_value != current_value:
|
||||
setattr(model, field_name, new_value)
|
||||
# Auto-fill context_window_tokens if it's at default value
|
||||
_try_auto_fill_context_window(model, new_value)
|
||||
continue
|
||||
|
||||
# Special handling for context_window_tokens field
|
||||
if field_name == "context_window_tokens":
|
||||
new_value = _input_context_window_with_recommendation(field_display, current_value, model)
|
||||
if new_value is not None:
|
||||
setattr(model, field_name, new_value)
|
||||
continue
|
||||
|
||||
if field_type == "bool":
|
||||
new_value = _input_bool(field_display, current_value)
|
||||
if new_value is not None:
|
||||
@@ -299,6 +424,39 @@ def _configure_pydantic_model(
|
||||
setattr(model, field_name, new_value)
|
||||
|
||||
|
||||
def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None:
    """Auto-fill context_window_tokens when it is still at the default.

    Only overrides the field when its current value equals the schema
    default, so a user-customized value is never clobbered.

    Note:
        Reads the default from nanobot.config.schema.AgentDefaults; if
        that schema changes, this coupling must be updated accordingly.
    """
    if not hasattr(model, "context_window_tokens"):
        return

    from nanobot.config.schema import AgentDefaults

    default_context = AgentDefaults.model_fields["context_window_tokens"].default
    if getattr(model, "context_window_tokens", None) != default_context:
        # Already customized by the user — leave it alone.
        return

    provider = _get_current_provider(model)
    context_limit = get_model_context_limit(new_model_name, provider)

    if context_limit:
        model.context_window_tokens = context_limit
        console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]")
    else:
        console.print("[dim]ℹ Could not auto-fill context window (model not in database)[/dim]")
|
||||
|
||||
|
||||
# --- Provider Configuration ---
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user