Merge main into pr-1482
This commit is contained in:
399
tests/test_azure_openai_provider.py
Normal file
399
tests/test_azure_openai_provider.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def test_azure_openai_provider_init():
|
||||
"""Test AzureOpenAIProvider initialization without deployment_name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
assert provider.api_key == "test-key"
|
||||
assert provider.api_base == "https://test-resource.openai.azure.com/"
|
||||
assert provider.default_model == "gpt-4o-deployment"
|
||||
assert provider.api_version == "2024-10-21"
|
||||
|
||||
|
||||
def test_azure_openai_provider_init_validation():
|
||||
"""Test AzureOpenAIProvider initialization validation."""
|
||||
# Missing api_key
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
|
||||
AzureOpenAIProvider(api_key="", api_base="https://test.com")
|
||||
|
||||
# Missing api_base
|
||||
with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
|
||||
AzureOpenAIProvider(api_key="test", api_base="")
|
||||
|
||||
|
||||
def test_build_chat_url():
|
||||
"""Test Azure OpenAI URL building with different deployment names."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test various deployment names
|
||||
test_cases = [
|
||||
("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
|
||||
("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
|
||||
("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
|
||||
]
|
||||
|
||||
for deployment_name, expected_url in test_cases:
|
||||
url = provider._build_chat_url(deployment_name)
|
||||
assert url == expected_url
|
||||
|
||||
|
||||
def test_build_chat_url_api_base_without_slash():
|
||||
"""Test URL building when api_base doesn't end with slash."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com", # No trailing slash
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
url = provider._build_chat_url("test-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
|
||||
|
||||
def test_build_headers():
|
||||
"""Test Azure OpenAI header building with api-key authentication."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-api-key-123",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
headers = provider._build_headers()
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header
|
||||
assert "x-session-affinity" in headers
|
||||
|
||||
|
||||
def test_prepare_request_payload():
|
||||
"""Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)
|
||||
|
||||
assert payload["messages"] == messages
|
||||
assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens
|
||||
assert payload["temperature"] == 0.8
|
||||
assert "tools" not in payload
|
||||
|
||||
# Test with tools
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
|
||||
assert payload_with_tools["tools"] == tools
|
||||
assert payload_with_tools["tool_choice"] == "auto"
|
||||
|
||||
# Test with reasoning_effort
|
||||
payload_with_reasoning = provider._prepare_request_payload(
|
||||
"gpt-5-chat", messages, reasoning_effort="medium"
|
||||
)
|
||||
assert payload_with_reasoning["reasoning_effort"] == "medium"
|
||||
assert "temperature" not in payload_with_reasoning
|
||||
|
||||
|
||||
def test_prepare_request_payload_sanitizes_messages():
|
||||
"""Test Azure payload strips non-standard message keys before sending."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
"reasoning_content": "hidden chain-of-thought",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
"extra_field": "should be removed",
|
||||
},
|
||||
]
|
||||
|
||||
payload = provider._prepare_request_payload("gpt-4o", messages)
|
||||
|
||||
assert payload["messages"] == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_123",
|
||||
"name": "x",
|
||||
"content": "ok",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_success():
|
||||
"""Test successful chat request using model as deployment name."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
|
||||
# Mock response data
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you today?",
|
||||
"role": "assistant"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 18,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
# Test with specific model (deployment name)
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages, model="custom-deployment")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.usage["prompt_tokens"] == 12
|
||||
assert result.usage["completion_tokens"] == 18
|
||||
assert result.usage["total_tokens"] == 30
|
||||
|
||||
# Verify URL was built with the provided model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_uses_default_model_when_no_model_provided():
|
||||
"""Test that chat uses default_model when no model is specified."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="default-deployment",
|
||||
)
|
||||
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {"content": "Response", "role": "assistant"},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
await provider.chat(messages) # No model specified
|
||||
|
||||
# Verify URL was built with default model as deployment name
|
||||
call_args = mock_context.post.call_args
|
||||
expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_tool_calls():
|
||||
"""Test chat request with tool calls in response."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Mock response with tool calls
|
||||
mock_response_data = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_12345",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "San Francisco"}'
|
||||
}
|
||||
}]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 15,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = Mock(return_value=mock_response_data)
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "What's the weather?"}]
|
||||
tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
|
||||
result = await provider.chat(messages, tools=tools, model="weather-model")
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert result.content is None
|
||||
assert result.finish_reason == "tool_calls"
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "get_weather"
|
||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_api_error():
|
||||
"""Test chat request API error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Invalid authentication credentials"
|
||||
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Azure OpenAI API Error 401" in result.content
|
||||
assert "Invalid authentication credentials" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_connection_error():
|
||||
"""Test chat request connection error handling."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_context = AsyncMock()
|
||||
mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.return_value.__aenter__.return_value = mock_context
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await provider.chat(messages)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_parse_response_malformed():
|
||||
"""Test response parsing with malformed data."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o",
|
||||
)
|
||||
|
||||
# Test with missing choices
|
||||
malformed_response = {"usage": {"prompt_tokens": 10}}
|
||||
result = provider._parse_response(malformed_response)
|
||||
|
||||
assert isinstance(result, LLMResponse)
|
||||
assert "Error parsing Azure OpenAI response" in result.content
|
||||
assert result.finish_reason == "error"
|
||||
|
||||
|
||||
def test_get_default_model():
|
||||
"""Test get_default_model method."""
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="my-custom-deployment",
|
||||
)
|
||||
|
||||
assert provider.get_default_model() == "my-custom-deployment"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Running basic Azure OpenAI provider tests...")
|
||||
|
||||
# Test initialization
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key="test-key",
|
||||
api_base="https://test-resource.openai.azure.com",
|
||||
default_model="gpt-4o-deployment",
|
||||
)
|
||||
print("✅ Provider initialization successful")
|
||||
|
||||
# Test URL building
|
||||
url = provider._build_chat_url("my-deployment")
|
||||
expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
|
||||
assert url == expected
|
||||
print("✅ URL building works correctly")
|
||||
|
||||
# Test headers
|
||||
headers = provider._build_headers()
|
||||
assert headers["api-key"] == "test-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
print("✅ Header building works correctly")
|
||||
|
||||
# Test payload preparation
|
||||
messages = [{"role": "user", "content": "Test"}]
|
||||
payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
|
||||
assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format
|
||||
print("✅ Payload preparation works correctly")
|
||||
|
||||
print("✅ All basic tests passed! Updated test file is working correctly.")
|
||||
@@ -40,7 +40,7 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->
|
||||
|
||||
|
||||
def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
"""Runtime metadata should be a separate user message before the actual user message."""
|
||||
"""Runtime metadata should be merged with the user message."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
@@ -54,13 +54,12 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "## Current Session" not in messages[0]["content"]
|
||||
|
||||
assert messages[-2]["role"] == "user"
|
||||
runtime_content = messages[-2]["content"]
|
||||
assert isinstance(runtime_content, str)
|
||||
assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
|
||||
assert "Current Time:" in runtime_content
|
||||
assert "Channel: cli" in runtime_content
|
||||
assert "Chat ID: direct" in runtime_content
|
||||
|
||||
# Runtime context is now merged with user message into a single message
|
||||
assert messages[-1]["role"] == "user"
|
||||
assert messages[-1]["content"] == "Return exactly: OK"
|
||||
user_content = messages[-1]["content"]
|
||||
assert isinstance(user_content, str)
|
||||
assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
|
||||
assert "Current Time:" in user_content
|
||||
assert "Channel: cli" in user_content
|
||||
assert "Chat ID: direct" in user_content
|
||||
assert "Return exactly: OK" in user_content
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
|
||||
monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"cron",
|
||||
"add",
|
||||
"--name",
|
||||
"demo",
|
||||
"--message",
|
||||
"hello",
|
||||
"--cron",
|
||||
"0 9 * * *",
|
||||
"--tz",
|
||||
"America/Vancovuer",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
|
||||
assert not (tmp_path / "cron" / "jobs.json").exists()
|
||||
@@ -48,6 +48,8 @@ async def test_running_service_honors_external_disable(tmp_path) -> None:
|
||||
)
|
||||
await service.start()
|
||||
try:
|
||||
# Wait slightly to ensure file mtime is definitively different
|
||||
await asyncio.sleep(0.05)
|
||||
external = CronService(store_path)
|
||||
updated = external.enable_job(job.id, enabled=False)
|
||||
assert updated is not None
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from nanobot.channels.feishu import _extract_post_content
|
||||
from nanobot.channels.feishu import FeishuChannel, _extract_post_content
|
||||
|
||||
|
||||
def test_extract_post_content_supports_post_wrapper_shape() -> None:
|
||||
@@ -38,3 +38,28 @@ def test_extract_post_content_keeps_direct_shape_behavior() -> None:
|
||||
|
||||
assert text == "Daily report"
|
||||
assert image_keys == ["img_a", "img_b"]
|
||||
|
||||
|
||||
def test_register_optional_event_keeps_builder_when_method_missing() -> None:
|
||||
class Builder:
|
||||
pass
|
||||
|
||||
builder = Builder()
|
||||
same = FeishuChannel._register_optional_event(builder, "missing", object())
|
||||
assert same is builder
|
||||
|
||||
|
||||
def test_register_optional_event_calls_supported_method() -> None:
|
||||
called = []
|
||||
|
||||
class Builder:
|
||||
def register_event(self, handler):
|
||||
called.append(handler)
|
||||
return self
|
||||
|
||||
builder = Builder()
|
||||
handler = object()
|
||||
same = FeishuChannel._register_optional_event(builder, "register_event", handler)
|
||||
|
||||
assert same is builder
|
||||
assert called == [handler]
|
||||
|
||||
104
tests/test_feishu_table_split.py
Normal file
104
tests/test_feishu_table_split.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for FeishuChannel._split_elements_by_table_limit.
|
||||
|
||||
Feishu cards reject messages that contain more than one table element
|
||||
(API error 11310: card table number over limit). The helper splits a flat
|
||||
list of card elements into groups so that each group contains at most one
|
||||
table, allowing nanobot to send multiple cards instead of failing.
|
||||
"""
|
||||
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
|
||||
|
||||
def _md(text: str) -> dict:
|
||||
return {"tag": "markdown", "content": text}
|
||||
|
||||
|
||||
def _table() -> dict:
|
||||
return {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "v"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
|
||||
|
||||
split = FeishuChannel._split_elements_by_table_limit
|
||||
|
||||
|
||||
def test_empty_list_returns_single_empty_group() -> None:
|
||||
assert split([]) == [[]]
|
||||
|
||||
|
||||
def test_no_tables_returns_single_group() -> None:
|
||||
els = [_md("hello"), _md("world")]
|
||||
result = split(els)
|
||||
assert result == [els]
|
||||
|
||||
|
||||
def test_single_table_stays_in_one_group() -> None:
|
||||
els = [_md("intro"), _table(), _md("outro")]
|
||||
result = split(els)
|
||||
assert len(result) == 1
|
||||
assert result[0] == els
|
||||
|
||||
|
||||
def test_two_tables_split_into_two_groups() -> None:
|
||||
# Use different row values so the two tables are not equal
|
||||
t1 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
|
||||
"rows": [{"c0": "table-one"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
t2 = {
|
||||
"tag": "table",
|
||||
"columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
|
||||
"rows": [{"c0": "table-two"}],
|
||||
"page_size": 2,
|
||||
}
|
||||
els = [_md("before"), t1, _md("between"), t2, _md("after")]
|
||||
result = split(els)
|
||||
assert len(result) == 2
|
||||
# First group: text before table-1 + table-1
|
||||
assert t1 in result[0]
|
||||
assert t2 not in result[0]
|
||||
# Second group: text between tables + table-2 + text after
|
||||
assert t2 in result[1]
|
||||
assert t1 not in result[1]
|
||||
|
||||
|
||||
def test_three_tables_split_into_three_groups() -> None:
|
||||
tables = [
|
||||
{"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
|
||||
for i in range(3)
|
||||
]
|
||||
els = tables[:]
|
||||
result = split(els)
|
||||
assert len(result) == 3
|
||||
for i, group in enumerate(result):
|
||||
assert tables[i] in group
|
||||
|
||||
|
||||
def test_leading_markdown_stays_with_first_table() -> None:
|
||||
intro = _md("intro")
|
||||
t = _table()
|
||||
result = split([intro, t])
|
||||
assert len(result) == 1
|
||||
assert result[0] == [intro, t]
|
||||
|
||||
|
||||
def test_trailing_markdown_after_second_table() -> None:
|
||||
t1, t2 = _table(), _table()
|
||||
tail = _md("end")
|
||||
result = split([t1, t2, tail])
|
||||
assert len(result) == 2
|
||||
assert result[1] == [t2, tail]
|
||||
|
||||
|
||||
def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
|
||||
head = _md("head")
|
||||
t1, t2 = _table(), _table()
|
||||
result = split([head, t1, t2])
|
||||
# head + t1 in group 0; t2 in group 1
|
||||
assert result[0] == [head, t1]
|
||||
assert result[1] == [t2]
|
||||
@@ -159,6 +159,7 @@ class _FakeAsyncClient:
|
||||
|
||||
|
||||
def _make_config(**kwargs) -> MatrixConfig:
|
||||
kwargs.setdefault("allow_from", ["*"])
|
||||
return MatrixConfig(
|
||||
enabled=True,
|
||||
homeserver="https://matrix.org",
|
||||
@@ -274,7 +275,7 @@ async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_invite_joins_when_allow_list_is_empty() -> None:
|
||||
async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
|
||||
channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
@@ -284,9 +285,22 @@ async def test_room_invite_joins_when_allow_list_is_empty() -> None:
|
||||
|
||||
await channel._on_room_invite(room, event)
|
||||
|
||||
assert client.join_calls == ["!room:matrix.org"]
|
||||
assert client.join_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_invite_joins_when_sender_allowed() -> None:
|
||||
channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
room = SimpleNamespace(room_id="!room:matrix.org")
|
||||
event = SimpleNamespace(sender="@alice:matrix.org")
|
||||
|
||||
await channel._on_room_invite(room, event)
|
||||
|
||||
assert client.join_calls == ["!room:matrix.org"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_invite_respects_allow_list_when_configured() -> None:
|
||||
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
|
||||
@@ -1163,6 +1177,8 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
|
||||
assert "!room:matrix.org" in channel._typing_tasks
|
||||
assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)
|
||||
|
||||
await channel.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_clears_typing_when_send_fails() -> None:
|
||||
|
||||
@@ -145,3 +145,78 @@ class TestMemoryConsolidationTypeHandling:
|
||||
|
||||
assert result is True
|
||||
provider.chat.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
|
||||
"""Some providers return arguments as a list - extract first element if it's a dict."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
# Simulate arguments being a list containing a dict
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[{
|
||||
"history_entry": "[2026-01-01] User discussed testing.",
|
||||
"memory_update": "# Memory\nUser likes testing.",
|
||||
}],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is True
|
||||
assert "User discussed testing." in store.history_file.read_text()
|
||||
assert "User likes testing." in store.memory_file.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
|
||||
"""Empty list arguments should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
|
||||
"""List with non-dict content should return False."""
|
||||
store = MemoryStore(tmp_path)
|
||||
provider = AsyncMock()
|
||||
|
||||
response = LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_1",
|
||||
name="save_memory",
|
||||
arguments=["string", "content"],
|
||||
)
|
||||
],
|
||||
)
|
||||
provider.chat = AsyncMock(return_value=response)
|
||||
session = _make_session(message_count=60)
|
||||
|
||||
result = await store.consolidate(session, provider, "test-model", memory_window=50)
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -86,6 +86,35 @@ class TestMessageToolSuppressLogic:
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
calls = iter([
|
||||
LLMResponse(
|
||||
content="Visible<think>hidden</think>",
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content="secret reasoning",
|
||||
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
progress: list[tuple[str, bool]] = []
|
||||
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
("Visible", False),
|
||||
('read_file("foo.txt")', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
|
||||
162
tests/test_telegram_channel.py
Normal file
162
tests/test_telegram_channel.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
class _FakeHTTPXRequest:
|
||||
instances: list["_FakeHTTPXRequest"] = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.kwargs = kwargs
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
|
||||
class _FakeUpdater:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self._on_start_polling = on_start_polling
|
||||
|
||||
async def start_polling(self, **kwargs) -> None:
|
||||
self._on_start_polling()
|
||||
|
||||
|
||||
class _FakeBot:
|
||||
def __init__(self) -> None:
|
||||
self.sent_messages: list[dict] = []
|
||||
|
||||
async def get_me(self):
|
||||
return SimpleNamespace(username="nanobot_test")
|
||||
|
||||
async def set_my_commands(self, commands) -> None:
|
||||
self.commands = commands
|
||||
|
||||
async def send_message(self, **kwargs) -> None:
|
||||
self.sent_messages.append(kwargs)
|
||||
|
||||
|
||||
class _FakeApp:
|
||||
def __init__(self, on_start_polling) -> None:
|
||||
self.bot = _FakeBot()
|
||||
self.updater = _FakeUpdater(on_start_polling)
|
||||
self.handlers = []
|
||||
self.error_handlers = []
|
||||
|
||||
def add_error_handler(self, handler) -> None:
|
||||
self.error_handlers.append(handler)
|
||||
|
||||
def add_handler(self, handler) -> None:
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeBuilder:
|
||||
def __init__(self, app: _FakeApp) -> None:
|
||||
self.app = app
|
||||
self.token_value = None
|
||||
self.request_value = None
|
||||
self.get_updates_request_value = None
|
||||
|
||||
def token(self, token: str):
|
||||
self.token_value = token
|
||||
return self
|
||||
|
||||
def request(self, request):
|
||||
self.request_value = request
|
||||
return self
|
||||
|
||||
def get_updates_request(self, request):
|
||||
self.get_updates_request_value = request
|
||||
return self
|
||||
|
||||
def proxy(self, _proxy):
|
||||
raise AssertionError("builder.proxy should not be called when request is set")
|
||||
|
||||
def get_updates_proxy(self, _proxy):
|
||||
raise AssertionError("builder.get_updates_proxy should not be called when request is set")
|
||||
|
||||
def build(self):
|
||||
return self.app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
|
||||
config = TelegramConfig(
|
||||
enabled=True,
|
||||
token="123:abc",
|
||||
allow_from=["*"],
|
||||
proxy="http://127.0.0.1:7890",
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(config, bus)
|
||||
app = _FakeApp(lambda: setattr(channel, "_running", False))
|
||||
builder = _FakeBuilder(app)
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.telegram.Application",
|
||||
SimpleNamespace(builder=lambda: builder),
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert len(_FakeHTTPXRequest.instances) == 1
|
||||
assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
|
||||
assert builder.request_value is _FakeHTTPXRequest.instances[0]
|
||||
assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]
|
||||
|
||||
|
||||
def test_derive_topic_session_key_uses_thread_id() -> None:
|
||||
message = SimpleNamespace(
|
||||
chat=SimpleNamespace(type="supergroup"),
|
||||
chat_id=-100123,
|
||||
message_thread_id=42,
|
||||
)
|
||||
|
||||
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_keeps_message_in_topic() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"_progress": True, "message_thread_id": 42},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
|
||||
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
|
||||
channel = TelegramChannel(config, MessageBus())
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._message_threads[("123", 10)] = 42
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123",
|
||||
content="hello",
|
||||
metadata={"message_id": 10},
|
||||
)
|
||||
)
|
||||
|
||||
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
|
||||
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
|
||||
@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True
|
||||
errors = tool.validate_params(result)
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
|
||||
def test_cast_params_bool_string_false() -> None:
|
||||
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||
|
||||
|
||||
def test_cast_params_bool_string_invalid() -> None:
|
||||
"""Invalid boolean strings should not be cast."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
# Invalid strings should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"flag": "random"})
|
||||
assert result["flag"] == "random"
|
||||
result = tool.cast_params({"flag": "maybe"})
|
||||
assert result["flag"] == "maybe"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_int() -> None:
|
||||
"""Invalid strings should not be cast to integer."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "abc"})
|
||||
assert result["count"] == "abc" # Original value preserved
|
||||
result = tool.cast_params({"count": "12.5.7"})
|
||||
assert result["count"] == "12.5.7"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_number() -> None:
|
||||
"""Invalid strings should not be cast to number."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "not_a_number"})
|
||||
assert result["rate"] == "not_a_number"
|
||||
|
||||
|
||||
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||
"""Booleans should not pass number validation."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"rate": False})
|
||||
assert any("rate should be number" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_none_values() -> None:
|
||||
"""Test None handling for different types."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"items": {"type": "array"},
|
||||
"config": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params(
|
||||
{
|
||||
"name": None,
|
||||
"count": None,
|
||||
"items": None,
|
||||
"config": None,
|
||||
}
|
||||
)
|
||||
# None should be preserved for all types
|
||||
assert result["name"] is None
|
||||
assert result["count"] is None
|
||||
assert result["items"] is None
|
||||
assert result["config"] is None
|
||||
|
||||
|
||||
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
"""Single values should NOT be automatically wrapped into arrays."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
}
|
||||
)
|
||||
# Non-array values should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"items": 5})
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
Reference in New Issue
Block a user