fix(azure): sanitize messages and handle temperature
This commit is contained in:
@@ -11,6 +11,8 @@ import json_repair
|
|||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIProvider(LLMProvider):
|
class AzureOpenAIProvider(LLMProvider):
|
||||||
"""
|
"""
|
||||||
@@ -67,19 +69,38 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _supports_temperature(
|
||||||
|
deployment_name: str,
|
||||||
|
reasoning_effort: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True when temperature is likely supported for this deployment."""
|
||||||
|
if reasoning_effort:
|
||||||
|
return False
|
||||||
|
name = deployment_name.lower()
|
||||||
|
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||||
|
|
||||||
def _prepare_request_payload(
|
def _prepare_request_payload(
|
||||||
self,
|
self,
|
||||||
|
deployment_name: str,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"messages": self._sanitize_empty_content(messages),
|
"messages": self._sanitize_request_messages(
|
||||||
|
self._sanitize_empty_content(messages),
|
||||||
|
_AZURE_MSG_KEYS,
|
||||||
|
),
|
||||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||||
|
payload["temperature"] = temperature
|
||||||
|
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
payload["reasoning_effort"] = reasoning_effort
|
payload["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
@@ -116,7 +137,7 @@ class AzureOpenAIProvider(LLMProvider):
|
|||||||
url = self._build_chat_url(deployment_name)
|
url = self._build_chat_url(deployment_name)
|
||||||
headers = self._build_headers()
|
headers = self._build_headers()
|
||||||
payload = self._prepare_request_payload(
|
payload = self._prepare_request_payload(
|
||||||
messages, tools, max_tokens, reasoning_effort
|
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -87,6 +87,20 @@ class LLMProvider(ABC):
|
|||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_request_messages(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
allowed_keys: frozenset[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Keep only provider-safe message keys and normalize assistant content."""
|
||||||
|
sanitized = []
|
||||||
|
for msg in messages:
|
||||||
|
clean = {k: v for k, v in msg.items() if k in allowed_keys}
|
||||||
|
if clean.get("role") == "assistant" and "content" not in clean:
|
||||||
|
clean["content"] = None
|
||||||
|
sanitized.append(clean)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||||
sanitized = []
|
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||||
id_map: dict[str, str] = {}
|
id_map: dict[str, str] = {}
|
||||||
|
|
||||||
def map_id(value: Any) -> Any:
|
def map_id(value: Any) -> Any:
|
||||||
@@ -188,12 +188,7 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
return value
|
return value
|
||||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||||
|
|
||||||
for msg in messages:
|
for clean in sanitized:
|
||||||
clean = {k: v for k, v in msg.items() if k in allowed}
|
|
||||||
# Strict providers require "content" even when assistant only has tool_calls
|
|
||||||
if clean.get("role") == "assistant" and "content" not in clean:
|
|
||||||
clean["content"] = None
|
|
||||||
|
|
||||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||||
# shortening, otherwise strict providers reject the broken linkage.
|
# shortening, otherwise strict providers reject the broken linkage.
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
@@ -209,7 +204,6 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
sanitized.append(clean)
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
@@ -89,22 +89,65 @@ def test_prepare_request_payload():
|
|||||||
)
|
)
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
payload = provider._prepare_request_payload(messages, max_tokens=1500)
|
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||||
|
|
||||||
assert payload["messages"] == messages
|
assert payload["messages"] == messages
|
||||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||||
assert "temperature" not in payload # Temperature not included in payload
|
assert payload["temperature"] == 0.8
|
||||||
assert "tools" not in payload
|
assert "tools" not in payload
|
||||||
|
|
||||||
# Test with tools
|
# Test with tools
|
||||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||||
payload_with_tools = provider._prepare_request_payload(messages, tools=tools)
|
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||||
assert payload_with_tools["tools"] == tools
|
assert payload_with_tools["tools"] == tools
|
||||||
assert payload_with_tools["tool_choice"] == "auto"
|
assert payload_with_tools["tool_choice"] == "auto"
|
||||||
|
|
||||||
# Test with reasoning_effort
|
# Test with reasoning_effort
|
||||||
payload_with_reasoning = provider._prepare_request_payload(messages, reasoning_effort="medium")
|
payload_with_reasoning = provider._prepare_request_payload(
|
||||||
|
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||||
|
)
|
||||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||||
|
assert "temperature" not in payload_with_reasoning
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_request_payload_sanitizes_messages():
    """Azure payload messages keep only whitelisted keys.

    The assistant message loses ``reasoning_content`` and gains
    ``content=None``; the tool message loses ``extra_field``.
    """
    assistant_msg = {
        "role": "assistant",
        "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        "reasoning_content": "hidden chain-of-thought",
    }
    tool_msg = {
        "role": "tool",
        "tool_call_id": "call_123",
        "name": "x",
        "content": "ok",
        "extra_field": "should be removed",
    }

    # Expected output is spelled out with its own literals (not derived from
    # the inputs) so in-place mutation of the inputs cannot mask a failure.
    expected_messages = [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
        },
    ]

    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )
    payload = provider._prepare_request_payload("gpt-4o", [assistant_msg, tool_msg])

    assert payload["messages"] == expected_messages
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -349,7 +392,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Test payload preparation
|
# Test payload preparation
|
||||||
messages = [{"role": "user", "content": "Test"}]
|
messages = [{"role": "user", "content": "Test"}]
|
||||||
payload = provider._prepare_request_payload(messages, max_tokens=1000)
|
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||||
print("✅ Payload preparation works correctly")
|
print("✅ Payload preparation works correctly")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user