Merge upstream/main: resolve conflicts with OAuth support
This commit is contained in:
@@ -20,6 +20,7 @@ class LLMResponse:
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
|
||||
@@ -26,18 +26,16 @@ class LiteLLMProvider(LLMProvider):
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Detect gateway / local deployment from api_key and api_base
|
||||
self._gateway = find_gateway(api_key, api_base)
|
||||
|
||||
# Backwards-compatible flags (used by tests and possibly external code)
|
||||
self.is_openrouter = bool(self._gateway and self._gateway.name == "openrouter")
|
||||
self.is_aihubmix = bool(self._gateway and self._gateway.name == "aihubmix")
|
||||
self.is_vllm = bool(self._gateway and self._gateway.is_local)
|
||||
# Detect gateway / local deployment.
|
||||
# provider_name (from config key) is the primary signal;
|
||||
# api_key / api_base are fallback for auto-detection.
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
|
||||
# Configure environment variables
|
||||
if api_key:
|
||||
@@ -48,26 +46,29 @@ class LiteLLMProvider(LLMProvider):
|
||||
|
||||
# Disable LiteLLM logging noise
|
||||
litellm.suppress_debug_info = True
|
||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
||||
litellm.drop_params = True
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||
"""Set environment variables based on detected provider."""
|
||||
if self._gateway:
|
||||
# Gateway / local: direct set (not setdefault)
|
||||
os.environ[self._gateway.env_key] = api_key
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if not spec:
|
||||
return
|
||||
|
||||
# Standard provider: match by model name
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
|
||||
# Gateway/local overrides existing env; standard provider doesn't
|
||||
if self._gateway:
|
||||
os.environ[spec.env_key] = api_key
|
||||
else:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""Resolve model name by applying provider/gateway prefixes."""
|
||||
@@ -131,7 +132,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
# Pass api_base directly for custom endpoints (vLLM, etc.)
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
@@ -183,11 +184,14 @@ class LiteLLMProvider(LLMProvider):
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None)
|
||||
|
||||
return LLMResponse(
|
||||
content=message.content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
|
||||
@@ -265,11 +265,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
),
|
||||
),
|
||||
|
||||
# === Local deployment (fallback: unknown api_base → assume local) ======
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# If api_base is set but doesn't match a known gateway, we land here.
|
||||
# Placed before Groq so vLLM wins the fallback when both are configured.
|
||||
# Detected when config key is "vllm" (provider_name="vllm").
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
@@ -326,16 +325,34 @@ def find_by_model(model: str) -> ProviderSpec | None:
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(api_key: str | None, api_base: str | None) -> ProviderSpec | None:
|
||||
"""Detect gateway/local by api_key prefix or api_base substring.
|
||||
Fallback: unknown api_base → treat as local (vLLM)."""
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""Detect gateway/local provider.
|
||||
|
||||
Priority:
|
||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||
|
||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||
"""
|
||||
# 1. Direct match by config key
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
if api_base:
|
||||
return next((s for s in PROVIDERS if s.is_local), None)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user