diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index 3f325aa..bd79b00 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -11,6 +11,8 @@ import json_repair from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"}) + class AzureOpenAIProvider(LLMProvider): """ @@ -67,19 +69,38 @@ class AzureOpenAIProvider(LLMProvider): "x-session-affinity": uuid.uuid4().hex, # For cache locality } + @staticmethod + def _supports_temperature( + deployment_name: str, + reasoning_effort: str | None = None, + ) -> bool: + """Return True when temperature is likely supported for this deployment.""" + if reasoning_effort: + return False + name = deployment_name.lower() + return not any(token in name for token in ("gpt-5", "o1", "o3", "o4")) + def _prepare_request_payload( self, + deployment_name: str, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, max_tokens: int = 4096, + temperature: float = 0.7, reasoning_effort: str | None = None, ) -> dict[str, Any]: """Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" payload: dict[str, Any] = { - "messages": self._sanitize_empty_content(messages), + "messages": self._sanitize_request_messages( + self._sanitize_empty_content(messages), + _AZURE_MSG_KEYS, + ), "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens } + if self._supports_temperature(deployment_name, reasoning_effort): + payload["temperature"] = temperature + if reasoning_effort: payload["reasoning_effort"] = reasoning_effort @@ -116,7 +137,7 @@ class AzureOpenAIProvider(LLMProvider): url = self._build_chat_url(deployment_name) headers = self._build_headers() payload = self._prepare_request_payload( - messages, tools, max_tokens, reasoning_effort + deployment_name, messages, tools, max_tokens, temperature, reasoning_effort ) try: diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 55bd805..0f73544 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -87,6 +87,20 @@ class LLMProvider(ABC): result.append(msg) return result + @staticmethod + def _sanitize_request_messages( + messages: list[dict[str, Any]], + allowed_keys: frozenset[str], + ) -> list[dict[str, Any]]: + """Keep only provider-safe message keys and normalize assistant content.""" + sanitized = [] + for msg in messages: + clean = {k: v for k, v in msg.items() if k in allowed_keys} + if clean.get("role") == "assistant" and "content" not in clean: + clean["content"] = None + sanitized.append(clean) + return sanitized + @abstractmethod async def chat( self, diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 2fd6c18..cb67635 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -180,7 +180,7 @@ class LiteLLMProvider(LLMProvider): def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: """Strip non-standard keys and ensure assistant messages have a content key.""" allowed = _ALLOWED_MSG_KEYS | extra_keys - sanitized = [] + sanitized = LLMProvider._sanitize_request_messages(messages, allowed) id_map: dict[str, str] = {} def map_id(value: Any) -> Any: @@ -188,12 +188,7 @@ class LiteLLMProvider(LLMProvider): return value return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value)) - for msg in messages: - clean = {k: v for k, v in msg.items() if k in allowed} - # Strict providers require "content" even when assistant only has tool_calls - if clean.get("role") == "assistant" and "content" not in clean: - clean["content"] = None - + for clean in sanitized: # Keep assistant tool_calls[].id and tool tool_call_id in sync after # shortening, otherwise strict providers reject the broken linkage. if isinstance(clean.get("tool_calls"), list): @@ -209,7 +204,6 @@ class LiteLLMProvider(LLMProvider): if "tool_call_id" in clean and clean["tool_call_id"]: clean["tool_call_id"] = map_id(clean["tool_call_id"]) - sanitized.append(clean) return sanitized async def chat( diff --git a/tests/test_azure_openai_provider.py b/tests/test_azure_openai_provider.py index 680ddf4..77f36d4 100644 --- a/tests/test_azure_openai_provider.py +++ b/tests/test_azure_openai_provider.py @@ -1,9 +1,9 @@ """Test Azure OpenAI provider implementation (updated for model-based deployment names).""" -import asyncio -import pytest from unittest.mock import AsyncMock, Mock, patch +import pytest + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.base import LLMResponse @@ -89,22 +89,65 @@ def test_prepare_request_payload(): ) messages = [{"role": "user", "content": "Hello"}] - payload = provider._prepare_request_payload(messages, max_tokens=1500) + payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8) assert payload["messages"] == messages assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens - assert "temperature" not in payload # Temperature not included in payload + assert payload["temperature"] == 0.8 assert "tools" not in payload # Test with tools tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] - payload_with_tools = provider._prepare_request_payload(messages, tools=tools) + payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools) assert payload_with_tools["tools"] == tools assert payload_with_tools["tool_choice"] == "auto" # Test with reasoning_effort - payload_with_reasoning = provider._prepare_request_payload(messages, reasoning_effort="medium") + payload_with_reasoning = provider._prepare_request_payload( + "gpt-5-chat", messages, reasoning_effort="medium" + ) assert payload_with_reasoning["reasoning_effort"] == "medium" + assert "temperature" not in payload_with_reasoning + + +def test_prepare_request_payload_sanitizes_messages(): + """Test Azure payload strips non-standard message keys before sending.""" + provider = AzureOpenAIProvider( + api_key="test-key", + api_base="https://test-resource.openai.azure.com", + default_model="gpt-4o", + ) + + messages = [ + { + "role": "assistant", + "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + "reasoning_content": "hidden chain-of-thought", + }, + { + "role": "tool", + "tool_call_id": "call_123", + "name": "x", + "content": "ok", + "extra_field": "should be removed", + }, + ] + + payload = provider._prepare_request_payload("gpt-4o", messages) + + assert payload["messages"] == [ + { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + }, + { + "role": "tool", + "tool_call_id": "call_123", + "name": "x", + "content": "ok", + }, + ] @pytest.mark.asyncio @@ -349,7 +392,7 @@ if __name__ == "__main__": # Test payload preparation messages = [{"role": "user", "content": "Test"}] - payload = provider._prepare_request_payload(messages, max_tokens=1000) + payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000) assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format print("✅ Payload preparation works correctly")