Fix the temperature issue: remove the temperature parameter entirely
@@ -71,16 +71,13 @@ class AzureOpenAIProvider(LLMProvider):
         self,
         messages: list[dict[str, Any]],
         tools: list[dict[str, Any]] | None = None,
-        model: str | None = None,
         max_tokens: int = 4096,
-        temperature: float = 0.7,
         reasoning_effort: str | None = None,
     ) -> dict[str, Any]:
         """Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
         payload: dict[str, Any] = {
             "messages": self._sanitize_empty_content(messages),
             "max_completion_tokens": max(1, max_tokens),  # Azure API 2024-10-21 uses max_completion_tokens
-            "temperature": 1,
         }
 
         if reasoning_effort:
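Net effect of this hunk: the helper no longer accepts any sampling parameters, and the payload carries no temperature key at all, so the service default of 1 applies implicitly. A minimal standalone sketch of the resulting payload construction (simplified for illustration; the real method also sanitizes messages and handles tools and reasoning_effort):

    from typing import Any

    def prepare_payload(messages: list[dict[str, Any]], max_tokens: int = 4096) -> dict[str, Any]:
        # Azure API version 2024-10-21 expects max_completion_tokens, and
        # omitting temperature entirely lets the service default (1) apply.
        return {
            "messages": messages,
            "max_completion_tokens": max(1, max_tokens),
        }

    assert "temperature" not in prepare_payload([{"role": "user", "content": "hi"}])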
@@ -119,7 +116,7 @@ class AzureOpenAIProvider(LLMProvider):
         url = self._build_chat_url(deployment_name)
         headers = self._build_headers()
         payload = self._prepare_request_payload(
-            messages, tools, model, max_tokens, temperature, reasoning_effort
+            messages, tools, max_tokens, reasoning_effort
         )
 
         try:
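Because the removed arguments were passed positionally here, this call-site update has to land together with the signature change; any stale caller still using the old order would fail fast rather than silently misbind. Illustration with a hypothetical leftover caller:

    # Old-style call against the new four-parameter signature:
    provider._prepare_request_payload(
        messages, tools, model, max_tokens, temperature, reasoning_effort
    )
    # TypeError: _prepare_request_payload() takes from 2 to 5 positional
    # arguments but 7 were given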
@@ -131,7 +128,7 @@ class AzureOpenAIProvider(LLMProvider):
                 finish_reason="error",
             )
 
-            response_data = await response.json()
+            response_data = response.json()
             return self._parse_response(response_data)
 
         except Exception as e:
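Dropping the await here is a genuine bug fix, not a style change: in httpx, Response.json() is a plain synchronous method even when the response comes from an AsyncClient (only streaming reads such as aread() are coroutines), and awaiting the dict it returns raises a TypeError. A minimal sketch of the corrected pattern, with a placeholder URL and payload:

    import httpx

    async def fetch_completion(url: str, payload: dict) -> dict:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload)  # the network I/O is async...
            return response.json()  # ...but JSON decoding is synchronous: no await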
@@ -2,7 +2,7 @@
 
 import asyncio
 import pytest
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 from nanobot.providers.base import LLMResponse
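The added Mock import supports the test fixes below. An attribute patched with AsyncMock returns a coroutine when called, so mock_response.json would have handed the code an awaitable where the now-synchronous call expects a dict. A quick illustration of the difference:

    from unittest.mock import AsyncMock, Mock

    response = AsyncMock()
    response.json = AsyncMock(return_value={"ok": True})
    print(response.json())  # <coroutine object ...> -- would need an await

    response.json = Mock(return_value={"ok": True})
    print(response.json())  # {'ok': True} -- matches httpx's sync .json()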
@@ -89,11 +89,11 @@ def test_prepare_request_payload():
     )
 
     messages = [{"role": "user", "content": "Hello"}]
-    payload = provider._prepare_request_payload(messages, max_tokens=1500, temperature=0.8)
+    payload = provider._prepare_request_payload(messages, max_tokens=1500)
 
     assert payload["messages"] == messages
     assert payload["max_completion_tokens"] == 1500  # Azure API 2024-10-21 uses max_completion_tokens
-    assert payload["temperature"] == 0.8
+    assert "temperature" not in payload  # Temperature not included in payload
     assert "tools" not in payload
 
     # Test with tools
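Checking for the key's absence is the right inversion of the old equality assertion: it also guards against a default value quietly reappearing. A hypothetical companion test (not part of this commit) could pin the exact key set, assuming a provider fixture equivalent to the setup this test builds inline:

    def test_payload_has_only_expected_keys(provider):
        # Hypothetical: with no tools or reasoning_effort, exactly these
        # two keys should remain, per the assertions above.
        payload = provider._prepare_request_payload(
            [{"role": "user", "content": "Hello"}], max_tokens=100
        )
        assert set(payload) == {"messages", "max_completion_tokens"}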
@@ -135,7 +135,7 @@ async def test_chat_success():
     with patch("httpx.AsyncClient") as mock_client:
         mock_response = AsyncMock()
         mock_response.status_code = 200
-        mock_response.json = AsyncMock(return_value=mock_response_data)
+        mock_response.json = Mock(return_value=mock_response_data)
 
         mock_context = AsyncMock()
         mock_context.post = AsyncMock(return_value=mock_response)
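The same one-line change repeats in the two tests that follow. The resulting wiring deliberately mixes the two mock types, matching httpx's real interface; a sketch of the corrected setup (mock_response_data stands in for each test's fixture dict, and the __aenter__ line is an assumption about how the patched AsyncClient is connected):

    mock_response = AsyncMock()
    mock_response.status_code = 200
    mock_response.json = Mock(return_value=mock_response_data)  # sync, like httpx

    mock_context = AsyncMock()
    mock_context.post = AsyncMock(return_value=mock_response)  # awaited by the provider
    mock_client.return_value.__aenter__.return_value = mock_context  # assumed wiring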
@@ -178,7 +178,7 @@ async def test_chat_uses_default_model_when_no_model_provided():
     with patch("httpx.AsyncClient") as mock_client:
         mock_response = AsyncMock()
         mock_response.status_code = 200
-        mock_response.json = AsyncMock(return_value=mock_response_data)
+        mock_response.json = Mock(return_value=mock_response_data)
 
         mock_context = AsyncMock()
         mock_context.post = AsyncMock(return_value=mock_response)
@@ -228,7 +228,7 @@ async def test_chat_with_tool_calls():
     with patch("httpx.AsyncClient") as mock_client:
         mock_response = AsyncMock()
         mock_response.status_code = 200
-        mock_response.json = AsyncMock(return_value=mock_response_data)
+        mock_response.json = Mock(return_value=mock_response_data)
 
         mock_context = AsyncMock()
         mock_context.post = AsyncMock(return_value=mock_response)