fix(azure): sanitize messages and handle temperature

This commit is contained in:
Re-bin
2026-03-07 03:57:57 +00:00
parent 7c074e4684
commit 576ad12ef1
4 changed files with 89 additions and 17 deletions

View File

@@ -11,6 +11,8 @@ import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
# Message fields kept when chat messages are sanitized for the Azure request
# payload; every other key on a message is dropped before sending.
_AZURE_MSG_KEYS = frozenset(
    ("role", "content", "tool_calls", "tool_call_id", "name")
)
class AzureOpenAIProvider(LLMProvider): class AzureOpenAIProvider(LLMProvider):
""" """
@@ -67,19 +69,38 @@ class AzureOpenAIProvider(LLMProvider):
"x-session-affinity": uuid.uuid4().hex, # For cache locality "x-session-affinity": uuid.uuid4().hex, # For cache locality
} }
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload( def _prepare_request_payload(
self, self,
deployment_name: str,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096, max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None, reasoning_effort: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = { payload: dict[str, Any] = {
"messages": self._sanitize_empty_content(messages), "messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
} }
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort: if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort payload["reasoning_effort"] = reasoning_effort
@@ -116,7 +137,7 @@ class AzureOpenAIProvider(LLMProvider):
url = self._build_chat_url(deployment_name) url = self._build_chat_url(deployment_name)
headers = self._build_headers() headers = self._build_headers()
payload = self._prepare_request_payload( payload = self._prepare_request_payload(
messages, tools, max_tokens, reasoning_effort deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
) )
try: try:

View File

@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg) result.append(msg)
return result return result
@staticmethod
def _sanitize_request_messages(
    messages: list[dict[str, Any]],
    allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
    """Return copies of *messages* restricted to *allowed_keys*.

    Each message is rebuilt as a new dict holding only the permitted keys;
    the input dicts are not modified. An assistant message that ends up
    without a "content" key gets ``"content": None`` added.
    """

    def _strip(message: dict[str, Any]) -> dict[str, Any]:
        kept = {key: message[key] for key in message if key in allowed_keys}
        if kept.get("role") == "assistant":
            # Only fills the gap; an existing content value is left as-is.
            kept.setdefault("content", None)
        return kept

    return [_strip(message) for message in messages]
@abstractmethod @abstractmethod
async def chat( async def chat(
self, self,

View File

@@ -180,7 +180,7 @@ class LiteLLMProvider(LLMProvider):
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key.""" """Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = [] sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {} id_map: dict[str, str] = {}
def map_id(value: Any) -> Any: def map_id(value: Any) -> Any:
@@ -188,12 +188,7 @@ class LiteLLMProvider(LLMProvider):
return value return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value)) return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for msg in messages: for clean in sanitized:
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
# Keep assistant tool_calls[].id and tool tool_call_id in sync after # Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage. # shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list): if isinstance(clean.get("tool_calls"), list):
@@ -209,7 +204,6 @@ class LiteLLMProvider(LLMProvider):
if "tool_call_id" in clean and clean["tool_call_id"]: if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"]) clean["tool_call_id"] = map_id(clean["tool_call_id"])
sanitized.append(clean)
return sanitized return sanitized
async def chat( async def chat(

View File

@@ -1,9 +1,9 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names).""" """Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
import asyncio
import pytest
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
@@ -89,22 +89,65 @@ def test_prepare_request_payload():
) )
messages = [{"role": "user", "content": "Hello"}] messages = [{"role": "user", "content": "Hello"}]
payload = provider._prepare_request_payload(messages, max_tokens=1500) payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
assert payload["messages"] == messages assert payload["messages"] == messages
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
assert "temperature" not in payload # Temperature not included in payload assert payload["temperature"] == 0.8
assert "tools" not in payload assert "tools" not in payload
# Test with tools # Test with tools
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
payload_with_tools = provider._prepare_request_payload(messages, tools=tools) payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
assert payload_with_tools["tools"] == tools assert payload_with_tools["tools"] == tools
assert payload_with_tools["tool_choice"] == "auto" assert payload_with_tools["tool_choice"] == "auto"
# Test with reasoning_effort # Test with reasoning_effort
payload_with_reasoning = provider._prepare_request_payload(messages, reasoning_effort="medium") payload_with_reasoning = provider._prepare_request_payload(
"gpt-5-chat", messages, reasoning_effort="medium"
)
assert payload_with_reasoning["reasoning_effort"] == "medium" assert payload_with_reasoning["reasoning_effort"] == "medium"
assert "temperature" not in payload_with_reasoning
def test_prepare_request_payload_sanitizes_messages():
    """Azure payload messages must not carry non-standard keys.

    Also checks that an assistant message with only tool_calls is sent with
    an explicit ``content`` of None.
    """
    assistant_with_extras = {
        "role": "assistant",
        "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        "reasoning_content": "hidden chain-of-thought",
    }
    tool_with_extras = {
        "role": "tool",
        "tool_call_id": "call_123",
        "name": "x",
        "content": "ok",
        "extra_field": "should be removed",
    }
    # Expected output is spelled out independently of the inputs above so the
    # comparison cannot be satisfied by shared objects.
    expected_messages = [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
        },
    ]

    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )
    payload = provider._prepare_request_payload(
        "gpt-4o", [assistant_with_extras, tool_with_extras]
    )

    assert payload["messages"] == expected_messages
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -349,7 +392,7 @@ if __name__ == "__main__":
# Test payload preparation # Test payload preparation
messages = [{"role": "user", "content": "Test"}] messages = [{"role": "user", "content": "Test"}]
payload = provider._prepare_request_payload(messages, max_tokens=1000) payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
print("✅ Payload preparation works correctly") print("✅ Payload preparation works correctly")