Merge PR #1618: support Azure OpenAI
This commit is contained in:
@@ -675,6 +675,7 @@ Config file: `~/.nanobot/config.json`
|
|||||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||||
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||||
|
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||||
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ def onboard():
|
|||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
model = config.agents.defaults.model
|
||||||
provider_name = config.get_provider_name(model)
|
provider_name = config.get_provider_name(model)
|
||||||
@@ -231,6 +232,20 @@ def _make_provider(config: Config):
|
|||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||||
|
if provider_name == "azure_openai":
|
||||||
|
if not p or not p.api_key or not p.api_base:
|
||||||
|
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||||
|
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||||
|
console.print("Use the model field to specify the deployment name.")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
return AzureOpenAIProvider(
|
||||||
|
api_key=p.api_key,
|
||||||
|
api_base=p.api_base,
|
||||||
|
default_model=model,
|
||||||
|
)
|
||||||
|
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
spec = find_by_name(provider_name)
|
spec = find_by_name(provider_name)
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ class ProvidersConfig(Base):
|
|||||||
"""Configuration for LLM providers."""
|
"""Configuration for LLM providers."""
|
||||||
|
|
||||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||||
|
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
|
||||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||||
|
|||||||
@@ -3,5 +3,6 @@
|
|||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider"]
|
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
||||||
|
|||||||
210
nanobot/providers/azure_openai_provider.py
Normal file
210
nanobot/providers/azure_openai_provider.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json_repair
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
# Message keys Azure accepts; anything else is stripped before sending.
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})


class AzureOpenAIProvider(LLMProvider):
    """
    Azure OpenAI provider with API version 2024-10-21 compliance.

    Features:
    - Hardcoded API version 2024-10-21
    - Uses model field as Azure deployment name in URL path
    - Uses api-key header instead of Authorization Bearer
    - Uses max_completion_tokens instead of max_tokens
    - Direct HTTP calls, bypasses LiteLLM
    """

    def __init__(
        self,
        api_key: str = "",
        api_base: str = "",
        default_model: str = "gpt-5.2-chat",
    ):
        """
        Args:
            api_key: Azure OpenAI key, sent via the ``api-key`` header.
            api_base: Azure resource endpoint, e.g. ``https://<resource>.openai.azure.com``.
            default_model: Deployment name used when ``chat()`` receives no model.

        Raises:
            ValueError: If api_key or api_base is empty.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.api_version = "2024-10-21"  # pinned API version (see class docstring)

        # Validate required parameters
        if not api_key:
            raise ValueError("Azure OpenAI api_key is required")
        if not api_base:
            raise ValueError("Azure OpenAI api_base is required")

        # Normalize exactly once: urljoin() drops the final path segment of a
        # base URL that lacks a trailing slash, so guarantee one here.
        self.api_base = api_base if api_base.endswith("/") else api_base + "/"

    def _build_chat_url(self, deployment_name: str) -> str:
        """Build the Azure OpenAI chat completions URL.

        Format:
        https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
        """
        # __init__ guarantees self.api_base ends with "/", so urljoin keeps
        # the full base path intact.
        url = urljoin(
            self.api_base,
            f"openai/deployments/{deployment_name}/chat/completions",
        )
        return f"{url}?api-version={self.api_version}"

    def _build_headers(self) -> dict[str, str]:
        """Build headers for Azure OpenAI API with api-key header."""
        return {
            "Content-Type": "application/json",
            "api-key": self.api_key,  # Azure OpenAI uses api-key header, not Authorization
            "x-session-affinity": uuid.uuid4().hex,  # For cache locality
        }

    @staticmethod
    def _supports_temperature(
        deployment_name: str,
        reasoning_effort: str | None = None,
    ) -> bool:
        """Return True when temperature is likely supported for this deployment."""
        # Reasoning requests, and the reasoning model families (gpt-5/o1/o3/o4),
        # reject the temperature parameter.
        if reasoning_effort:
            return False
        name = deployment_name.lower()
        return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))

    def _prepare_request_payload(
        self,
        deployment_name: str,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
    ) -> dict[str, Any]:
        """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
        payload: dict[str, Any] = {
            "messages": self._sanitize_request_messages(
                self._sanitize_empty_content(messages),
                _AZURE_MSG_KEYS,
            ),
            "max_completion_tokens": max(1, max_tokens),  # Azure API 2024-10-21 uses max_completion_tokens
        }

        if self._supports_temperature(deployment_name, reasoning_effort):
            payload["temperature"] = temperature

        if reasoning_effort:
            payload["reasoning_effort"] = reasoning_effort

        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        return payload

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request to Azure OpenAI.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions in OpenAI format.
            model: Model identifier (used as deployment name).
            max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning effort parameter.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        deployment_name = model or self.default_model
        url = self._build_chat_url(deployment_name)
        headers = self._build_headers()
        payload = self._prepare_request_payload(
            deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
        )

        try:
            async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
                response = await client.post(url, headers=headers, json=payload)
                if response.status_code != 200:
                    return LLMResponse(
                        content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
                        finish_reason="error",
                    )

                response_data = response.json()
                return self._parse_response(response_data)

        except Exception as e:
            # Network/timeout/JSON-decode failures degrade to an error response
            # instead of raising, matching the provider's error contract.
            return LLMResponse(
                content=f"Error calling Azure OpenAI: {repr(e)}",
                finish_reason="error",
            )

    def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
        """Parse Azure OpenAI response into our standard format."""
        try:
            choice = response["choices"][0]
            message = choice["message"]

            tool_calls = []
            if message.get("tool_calls"):
                for tc in message["tool_calls"]:
                    # Parse arguments from JSON string if needed
                    args = tc["function"]["arguments"]
                    if isinstance(args, str):
                        args = json_repair.loads(args)

                    tool_calls.append(
                        ToolCallRequest(
                            id=tc["id"],
                            name=tc["function"]["name"],
                            arguments=args,
                        )
                    )

            usage = {}
            if response.get("usage"):
                usage_data = response["usage"]
                usage = {
                    "prompt_tokens": usage_data.get("prompt_tokens", 0),
                    "completion_tokens": usage_data.get("completion_tokens", 0),
                    "total_tokens": usage_data.get("total_tokens", 0),
                }

            reasoning_content = message.get("reasoning_content") or None

            return LLMResponse(
                content=message.get("content"),
                tool_calls=tool_calls,
                finish_reason=choice.get("finish_reason", "stop"),
                usage=usage,
                reasoning_content=reasoning_content,
            )

        except (KeyError, IndexError, TypeError) as e:
            # TypeError covers structurally wrong payloads (e.g. "message" is
            # None or a tool_calls entry is not a dict), not just missing keys;
            # previously those escaped this dedicated parse-error path.
            return LLMResponse(
                content=f"Error parsing Azure OpenAI response: {str(e)}",
                finish_reason="error",
            )

    def get_default_model(self) -> str:
        """Get the default model (also used as default deployment name)."""
        return self.default_model
||||||
@@ -87,6 +87,20 @@ class LLMProvider(ABC):
|
|||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
def _sanitize_request_messages(
    messages: list[dict[str, Any]],
    allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
    """Keep only provider-safe message keys and normalize assistant content."""
    result: list[dict[str, Any]] = []
    for original in messages:
        filtered = {key: value for key, value in original.items() if key in allowed_keys}
        # Strict providers require a "content" key even when the assistant
        # turn only carries tool_calls.
        if filtered.get("role") == "assistant":
            filtered.setdefault("content", None)
        result.append(filtered)
    return result
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||||
sanitized = []
|
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||||
id_map: dict[str, str] = {}
|
id_map: dict[str, str] = {}
|
||||||
|
|
||||||
def map_id(value: Any) -> Any:
|
def map_id(value: Any) -> Any:
|
||||||
@@ -188,12 +188,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
return value
|
return value
|
||||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||||
|
|
||||||
for msg in messages:
|
for clean in sanitized:
|
||||||
clean = {k: v for k, v in msg.items() if k in allowed}
|
|
||||||
# Strict providers require "content" even when assistant only has tool_calls
|
|
||||||
if clean.get("role") == "assistant" and "content" not in clean:
|
|
||||||
clean["content"] = None
|
|
||||||
|
|
||||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||||
# shortening, otherwise strict providers reject the broken linkage.
|
# shortening, otherwise strict providers reject the broken linkage.
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
@@ -209,7 +204,6 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
sanitized.append(clean)
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|||||||
@@ -79,6 +79,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
litellm_prefix="",
|
litellm_prefix="",
|
||||||
is_direct=True,
|
is_direct=True,
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
|
||||||
|
ProviderSpec(
|
||||||
|
name="azure_openai",
|
||||||
|
keywords=("azure", "azure-openai"),
|
||||||
|
env_key="",
|
||||||
|
display_name="Azure OpenAI",
|
||||||
|
litellm_prefix="",
|
||||||
|
is_direct=True,
|
||||||
|
),
|
||||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||||
# Gateways can route any model, so they win in fallback.
|
# Gateways can route any model, so they win in fallback.
|
||||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||||
|
|||||||
399
tests/test_azure_openai_provider.py
Normal file
399
tests/test_azure_openai_provider.py
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_openai_provider_init():
    """Test AzureOpenAIProvider initialization without deployment_name."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )

    # __init__ normalizes api_base to carry a trailing slash.
    assert provider.api_key == "test-key"
    assert provider.api_base == "https://test-resource.openai.azure.com/"
    assert provider.default_model == "gpt-4o-deployment"
    assert provider.api_version == "2024-10-21"


def test_azure_openai_provider_init_validation():
    """Test AzureOpenAIProvider initialization validation."""
    # Missing api_key
    with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
        AzureOpenAIProvider(api_key="", api_base="https://test.com")

    # Missing api_base
    with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
        AzureOpenAIProvider(api_key="test", api_base="")


def test_build_chat_url():
    """Test Azure OpenAI URL building with different deployment names."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Test various deployment names
    test_cases = [
        ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
        ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
        ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
    ]

    for deployment_name, expected_url in test_cases:
        url = provider._build_chat_url(deployment_name)
        assert url == expected_url


def test_build_chat_url_api_base_without_slash():
    """Test URL building when api_base doesn't end with slash."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",  # No trailing slash
        default_model="gpt-4o",
    )

    url = provider._build_chat_url("test-deployment")
    expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
    assert url == expected


def test_build_headers():
    """Test Azure OpenAI header building with api-key authentication."""
    provider = AzureOpenAIProvider(
        api_key="test-api-key-123",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    headers = provider._build_headers()
    assert headers["Content-Type"] == "application/json"
    assert headers["api-key"] == "test-api-key-123"  # Azure OpenAI specific header
    assert "x-session-affinity" in headers


def test_prepare_request_payload():
    """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    messages = [{"role": "user", "content": "Hello"}]
    payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)

    assert payload["messages"] == messages
    assert payload["max_completion_tokens"] == 1500  # Azure API 2024-10-21 uses max_completion_tokens
    assert payload["temperature"] == 0.8
    assert "tools" not in payload

    # Test with tools
    tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
    payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
    assert payload_with_tools["tools"] == tools
    assert payload_with_tools["tool_choice"] == "auto"

    # Test with reasoning_effort: temperature must be dropped for reasoning requests.
    payload_with_reasoning = provider._prepare_request_payload(
        "gpt-5-chat", messages, reasoning_effort="medium"
    )
    assert payload_with_reasoning["reasoning_effort"] == "medium"
    assert "temperature" not in payload_with_reasoning


def test_prepare_request_payload_sanitizes_messages():
    """Test Azure payload strips non-standard message keys before sending."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    messages = [
        {
            "role": "assistant",
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
            "reasoning_content": "hidden chain-of-thought",
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
            "extra_field": "should be removed",
        },
    ]

    payload = provider._prepare_request_payload("gpt-4o", messages)

    # reasoning_content / extra_field are stripped; assistant gains content=None.
    assert payload["messages"] == [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
        },
    ]


@pytest.mark.asyncio
async def test_chat_success():
    """Test successful chat request using model as deployment name."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )

    # Mock response data
    mock_response_data = {
        "choices": [{
            "message": {
                "content": "Hello! How can I help you today?",
                "role": "assistant"
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 12,
            "completion_tokens": 18,
            "total_tokens": 30
        }
    }

    # Patch the AsyncClient context manager so no real network I/O happens.
    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        # json() is synchronous on httpx responses, hence plain Mock.
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        # Test with specific model (deployment name)
        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages, model="custom-deployment")

        assert isinstance(result, LLMResponse)
        assert result.content == "Hello! How can I help you today?"
        assert result.finish_reason == "stop"
        assert result.usage["prompt_tokens"] == 12
        assert result.usage["completion_tokens"] == 18
        assert result.usage["total_tokens"] == 30

        # Verify URL was built with the provided model as deployment name
        call_args = mock_context.post.call_args
        expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
        assert call_args[0][0] == expected_url


@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
    """Test that chat uses default_model when no model is specified."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="default-deployment",
    )

    mock_response_data = {
        "choices": [{
            "message": {"content": "Response", "role": "assistant"},
            "finish_reason": "stop"
        }],
        "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
    }

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Test"}]
        await provider.chat(messages)  # No model specified

        # Verify URL was built with default model as deployment name
        call_args = mock_context.post.call_args
        expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
        assert call_args[0][0] == expected_url


@pytest.mark.asyncio
async def test_chat_with_tool_calls():
    """Test chat request with tool calls in response."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Mock response with tool calls
    mock_response_data = {
        "choices": [{
            "message": {
                "content": None,
                "role": "assistant",
                "tool_calls": [{
                    "id": "call_12345",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "San Francisco"}'
                    }
                }]
            },
            "finish_reason": "tool_calls"
        }],
        "usage": {
            "prompt_tokens": 20,
            "completion_tokens": 15,
            "total_tokens": 35
        }
    }

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "What's the weather?"}]
        tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
        result = await provider.chat(messages, tools=tools, model="weather-model")

        assert isinstance(result, LLMResponse)
        assert result.content is None
        assert result.finish_reason == "tool_calls"
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].name == "get_weather"
        # JSON-string arguments are decoded into a dict.
        assert result.tool_calls[0].arguments == {"location": "San Francisco"}


@pytest.mark.asyncio
async def test_chat_api_error():
    """Test chat request API error handling."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 401
        mock_response.text = "Invalid authentication credentials"

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages)

        # Non-200 responses become an error LLMResponse, not an exception.
        assert isinstance(result, LLMResponse)
        assert "Azure OpenAI API Error 401" in result.content
        assert "Invalid authentication credentials" in result.content
        assert result.finish_reason == "error"


@pytest.mark.asyncio
async def test_chat_connection_error():
    """Test chat request connection error handling."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    with patch("httpx.AsyncClient") as mock_client:
        mock_context = AsyncMock()
        mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages)

        # Transport-level exceptions are swallowed into an error LLMResponse.
        assert isinstance(result, LLMResponse)
        assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
        assert result.finish_reason == "error"


def test_parse_response_malformed():
    """Test response parsing with malformed data."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Test with missing choices
    malformed_response = {"usage": {"prompt_tokens": 10}}
    result = provider._parse_response(malformed_response)

    assert isinstance(result, LLMResponse)
    assert "Error parsing Azure OpenAI response" in result.content
    assert result.finish_reason == "error"


def test_get_default_model():
    """Test get_default_model method."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="my-custom-deployment",
    )

    assert provider.get_default_model() == "my-custom-deployment"


if __name__ == "__main__":
    # Run basic tests
    print("Running basic Azure OpenAI provider tests...")

    # Test initialization
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )
    print("✅ Provider initialization successful")

    # Test URL building
    url = provider._build_chat_url("my-deployment")
    expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
    assert url == expected
    print("✅ URL building works correctly")

    # Test headers
    headers = provider._build_headers()
    assert headers["api-key"] == "test-key"
    assert headers["Content-Type"] == "application/json"
    print("✅ Header building works correctly")

    # Test payload preparation
    messages = [{"role": "user", "content": "Test"}]
    payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
    assert payload["max_completion_tokens"] == 1000  # Azure 2024-10-21 format
    print("✅ Payload preparation works correctly")

    print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||||
Reference in New Issue
Block a user