fix(azure): sanitize messages and handle temperature

This commit is contained in:
Re-bin
2026-03-07 03:57:57 +00:00
parent 7c074e4684
commit 576ad12ef1
4 changed files with 89 additions and 17 deletions

View File

@@ -11,6 +11,8 @@ import json_repair
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
class AzureOpenAIProvider(LLMProvider):
"""
@@ -67,19 +69,38 @@ class AzureOpenAIProvider(LLMProvider):
"x-session-affinity": uuid.uuid4().hex, # For cache locality
}
@staticmethod
def _supports_temperature(
deployment_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when temperature is likely supported for this deployment."""
if reasoning_effort:
return False
name = deployment_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _prepare_request_payload(
self,
deployment_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> dict[str, Any]:
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
payload: dict[str, Any] = {
"messages": self._sanitize_empty_content(messages),
"messages": self._sanitize_request_messages(
self._sanitize_empty_content(messages),
_AZURE_MSG_KEYS,
),
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
}
if self._supports_temperature(deployment_name, reasoning_effort):
payload["temperature"] = temperature
if reasoning_effort:
payload["reasoning_effort"] = reasoning_effort
@@ -116,7 +137,7 @@ class AzureOpenAIProvider(LLMProvider):
url = self._build_chat_url(deployment_name)
headers = self._build_headers()
payload = self._prepare_request_payload(
messages, tools, max_tokens, reasoning_effort
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
)
try:

View File

@@ -87,6 +87,20 @@ class LLMProvider(ABC):
result.append(msg)
return result
@staticmethod
def _sanitize_request_messages(
messages: list[dict[str, Any]],
allowed_keys: frozenset[str],
) -> list[dict[str, Any]]:
"""Keep only provider-safe message keys and normalize assistant content."""
sanitized = []
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed_keys}
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
sanitized.append(clean)
return sanitized
@abstractmethod
async def chat(
self,

View File

@@ -180,7 +180,7 @@ class LiteLLMProvider(LLMProvider):
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
"""Strip non-standard keys and ensure assistant messages have a content key."""
allowed = _ALLOWED_MSG_KEYS | extra_keys
sanitized = []
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
id_map: dict[str, str] = {}
def map_id(value: Any) -> Any:
@@ -188,12 +188,7 @@ class LiteLLMProvider(LLMProvider):
return value
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
for msg in messages:
clean = {k: v for k, v in msg.items() if k in allowed}
# Strict providers require "content" even when assistant only has tool_calls
if clean.get("role") == "assistant" and "content" not in clean:
clean["content"] = None
for clean in sanitized:
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
# shortening, otherwise strict providers reject the broken linkage.
if isinstance(clean.get("tool_calls"), list):
@@ -209,7 +204,6 @@ class LiteLLMProvider(LLMProvider):
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
sanitized.append(clean)
return sanitized
async def chat(