Merge main into pr-436

tests/test_azure_openai_provider.py (new file, 399 lines)
@@ -0,0 +1,399 @@
"""Test Azure OpenAI provider implementation (updated for model-based deployment names)."""

from unittest.mock import AsyncMock, Mock, patch

import pytest

from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import LLMResponse


def test_azure_openai_provider_init():
    """Test AzureOpenAIProvider initialization without deployment_name."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )

    assert provider.api_key == "test-key"
    assert provider.api_base == "https://test-resource.openai.azure.com/"
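    # api_base was given without a trailing slash; the assertion above implies
    # the provider normalizes it by appending one.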
    assert provider.default_model == "gpt-4o-deployment"
    assert provider.api_version == "2024-10-21"


def test_azure_openai_provider_init_validation():
    """Test AzureOpenAIProvider initialization validation."""
    # Missing api_key
    with pytest.raises(ValueError, match="Azure OpenAI api_key is required"):
        AzureOpenAIProvider(api_key="", api_base="https://test.com")

    # Missing api_base
    with pytest.raises(ValueError, match="Azure OpenAI api_base is required"):
        AzureOpenAIProvider(api_key="test", api_base="")


def test_build_chat_url():
    """Test Azure OpenAI URL building with different deployment names."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Test various deployment names
    test_cases = [
        ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"),
        ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"),
        ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"),
    ]

    for deployment_name, expected_url in test_cases:
        url = provider._build_chat_url(deployment_name)
        assert url == expected_url


def test_build_chat_url_api_base_without_slash():
    """Test URL building when api_base doesn't end with slash."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",  # No trailing slash
        default_model="gpt-4o",
    )

    url = provider._build_chat_url("test-deployment")
    expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21"
    assert url == expected


def test_build_headers():
    """Test Azure OpenAI header building with api-key authentication."""
    provider = AzureOpenAIProvider(
        api_key="test-api-key-123",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    headers = provider._build_headers()
    assert headers["Content-Type"] == "application/json"
    assert headers["api-key"] == "test-api-key-123"  # Azure OpenAI specific header
    assert "x-session-affinity" in headers


def test_prepare_request_payload():
    """Test request payload preparation with Azure OpenAI 2024-10-21 compliance."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    messages = [{"role": "user", "content": "Hello"}]
    payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8)

    assert payload["messages"] == messages
    assert payload["max_completion_tokens"] == 1500  # Azure API 2024-10-21 uses max_completion_tokens
    assert payload["temperature"] == 0.8
    assert "tools" not in payload

    # Test with tools
    tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
    payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools)
    assert payload_with_tools["tools"] == tools
    assert payload_with_tools["tool_choice"] == "auto"

    # Test with reasoning_effort
    payload_with_reasoning = provider._prepare_request_payload(
        "gpt-5-chat", messages, reasoning_effort="medium"
    )
    assert payload_with_reasoning["reasoning_effort"] == "medium"
    assert "temperature" not in payload_with_reasoning


def test_prepare_request_payload_sanitizes_messages():
    """Test Azure payload strips non-standard message keys before sending."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    messages = [
        {
            "role": "assistant",
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
            "reasoning_content": "hidden chain-of-thought",
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
            "extra_field": "should be removed",
        },
    ]

    payload = provider._prepare_request_payload("gpt-4o", messages)

    assert payload["messages"] == [
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "name": "x",
            "content": "ok",
        },
    ]


@pytest.mark.asyncio
async def test_chat_success():
    """Test successful chat request using model as deployment name."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )

    # Mock response data
    mock_response_data = {
        "choices": [{
            "message": {
                "content": "Hello! How can I help you today?",
                "role": "assistant"
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 12,
            "completion_tokens": 18,
            "total_tokens": 30
        }
    }

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        # Test with specific model (deployment name)
        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages, model="custom-deployment")

        assert isinstance(result, LLMResponse)
        assert result.content == "Hello! How can I help you today?"
        assert result.finish_reason == "stop"
        assert result.usage["prompt_tokens"] == 12
        assert result.usage["completion_tokens"] == 18
        assert result.usage["total_tokens"] == 30

        # Verify URL was built with the provided model as deployment name
        call_args = mock_context.post.call_args
        expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21"
        assert call_args[0][0] == expected_url


@pytest.mark.asyncio
async def test_chat_uses_default_model_when_no_model_provided():
    """Test that chat uses default_model when no model is specified."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="default-deployment",
    )

    mock_response_data = {
        "choices": [{
            "message": {"content": "Response", "role": "assistant"},
            "finish_reason": "stop"
        }],
        "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10}
    }

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Test"}]
        await provider.chat(messages)  # No model specified

        # Verify URL was built with default model as deployment name
        call_args = mock_context.post.call_args
        expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21"
        assert call_args[0][0] == expected_url


@pytest.mark.asyncio
async def test_chat_with_tool_calls():
    """Test chat request with tool calls in response."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Mock response with tool calls
    mock_response_data = {
        "choices": [{
            "message": {
                "content": None,
                "role": "assistant",
                "tool_calls": [{
                    "id": "call_12345",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "San Francisco"}'
                    }
                }]
            },
            "finish_reason": "tool_calls"
        }],
        "usage": {
            "prompt_tokens": 20,
            "completion_tokens": 15,
            "total_tokens": 35
        }
    }

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 200
        mock_response.json = Mock(return_value=mock_response_data)

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "What's the weather?"}]
        tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}]
        result = await provider.chat(messages, tools=tools, model="weather-model")

        assert isinstance(result, LLMResponse)
        assert result.content is None
        assert result.finish_reason == "tool_calls"
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].name == "get_weather"
        assert result.tool_calls[0].arguments == {"location": "San Francisco"}


@pytest.mark.asyncio
async def test_chat_api_error():
    """Test chat request API error handling."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    with patch("httpx.AsyncClient") as mock_client:
        mock_response = AsyncMock()
        mock_response.status_code = 401
        mock_response.text = "Invalid authentication credentials"

        mock_context = AsyncMock()
        mock_context.post = AsyncMock(return_value=mock_response)
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages)

        assert isinstance(result, LLMResponse)
        assert "Azure OpenAI API Error 401" in result.content
        assert "Invalid authentication credentials" in result.content
        assert result.finish_reason == "error"


@pytest.mark.asyncio
async def test_chat_connection_error():
    """Test chat request connection error handling."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    with patch("httpx.AsyncClient") as mock_client:
        mock_context = AsyncMock()
        mock_context.post = AsyncMock(side_effect=Exception("Connection failed"))
        mock_client.return_value.__aenter__.return_value = mock_context

        messages = [{"role": "user", "content": "Hello"}]
        result = await provider.chat(messages)

        assert isinstance(result, LLMResponse)
        assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content
        assert result.finish_reason == "error"


def test_parse_response_malformed():
    """Test response parsing with malformed data."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o",
    )

    # Test with missing choices
    malformed_response = {"usage": {"prompt_tokens": 10}}
    result = provider._parse_response(malformed_response)

    assert isinstance(result, LLMResponse)
    assert "Error parsing Azure OpenAI response" in result.content
    assert result.finish_reason == "error"


def test_get_default_model():
    """Test get_default_model method."""
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="my-custom-deployment",
    )

    assert provider.get_default_model() == "my-custom-deployment"


if __name__ == "__main__":
    # Run basic tests
    print("Running basic Azure OpenAI provider tests...")

    # Test initialization
    provider = AzureOpenAIProvider(
        api_key="test-key",
        api_base="https://test-resource.openai.azure.com",
        default_model="gpt-4o-deployment",
    )
    print("✅ Provider initialization successful")

    # Test URL building
    url = provider._build_chat_url("my-deployment")
    expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21"
    assert url == expected
    print("✅ URL building works correctly")

    # Test headers
    headers = provider._build_headers()
    assert headers["api-key"] == "test-key"
    assert headers["Content-Type"] == "application/json"
    print("✅ Header building works correctly")

    # Test payload preparation
    messages = [{"role": "user", "content": "Test"}]
    payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000)
    assert payload["max_completion_tokens"] == 1000  # Azure 2024-10-21 format
    print("✅ Payload preparation works correctly")

    print("✅ All basic tests passed! Updated test file is working correctly.")

tests/test_cli_input.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prompt_toolkit.formatted_text import HTML

from nanobot.cli import commands


@pytest.fixture
def mock_prompt_session():
    """Mock the global prompt session."""
    mock_session = MagicMock()
    mock_session.prompt_async = AsyncMock()
    with patch("nanobot.cli.commands._PROMPT_SESSION", mock_session), \
         patch("nanobot.cli.commands.patch_stdout"):
        yield mock_session


@pytest.mark.asyncio
async def test_read_interactive_input_async_returns_input(mock_prompt_session):
    """Test that _read_interactive_input_async returns the user input from prompt_session."""
    mock_prompt_session.prompt_async.return_value = "hello world"

    result = await commands._read_interactive_input_async()

    assert result == "hello world"
    mock_prompt_session.prompt_async.assert_called_once()
    args, _ = mock_prompt_session.prompt_async.call_args
    assert isinstance(args[0], HTML)  # Verify HTML prompt is used


@pytest.mark.asyncio
async def test_read_interactive_input_async_handles_eof(mock_prompt_session):
    """Test that EOFError converts to KeyboardInterrupt."""
    mock_prompt_session.prompt_async.side_effect = EOFError()

    with pytest.raises(KeyboardInterrupt):
        await commands._read_interactive_input_async()


def test_init_prompt_session_creates_session():
    """Test that _init_prompt_session initializes the global session."""
    # Ensure global is None before test
    commands._PROMPT_SESSION = None

    with patch("nanobot.cli.commands.PromptSession") as MockSession, \
         patch("nanobot.cli.commands.FileHistory") as MockHistory, \
         patch("pathlib.Path.home") as mock_home:

        mock_home.return_value = MagicMock()

        commands._init_prompt_session()

        assert commands._PROMPT_SESSION is not None
        MockSession.assert_called_once()
        _, kwargs = MockSession.call_args
        assert kwargs["multiline"] is False
        assert kwargs["enable_open_in_editor"] is False

tests/test_commands.py (new file, 130 lines)
@@ -0,0 +1,130 @@
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
from typer.testing import CliRunner

from nanobot.cli.commands import app
from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model

runner = CliRunner()


@pytest.fixture
def mock_paths():
    """Mock config/workspace paths for test isolation."""
    with patch("nanobot.config.loader.get_config_path") as mock_cp, \
         patch("nanobot.config.loader.save_config") as mock_sc, \
         patch("nanobot.config.loader.load_config") as mock_lc, \
         patch("nanobot.utils.helpers.get_workspace_path") as mock_ws:

        base_dir = Path("./test_onboard_data")
        if base_dir.exists():
            shutil.rmtree(base_dir)
        base_dir.mkdir()

        config_file = base_dir / "config.json"
        workspace_dir = base_dir / "workspace"

        mock_cp.return_value = config_file
        mock_ws.return_value = workspace_dir
        mock_sc.side_effect = lambda config: config_file.write_text("{}")
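        # The save_config stub writes an empty JSON file, so config_file.exists()
        # checks later in these tests observe a "saved" config on disk.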

        yield config_file, workspace_dir

        if base_dir.exists():
            shutil.rmtree(base_dir)


def test_onboard_fresh_install(mock_paths):
    """No existing config — should create from scratch."""
    config_file, workspace_dir = mock_paths

    result = runner.invoke(app, ["onboard"])

    assert result.exit_code == 0
    assert "Created config" in result.stdout
    assert "Created workspace" in result.stdout
    assert "nanobot is ready" in result.stdout
    assert config_file.exists()
    assert (workspace_dir / "AGENTS.md").exists()
    assert (workspace_dir / "memory" / "MEMORY.md").exists()


def test_onboard_existing_config_refresh(mock_paths):
    """Config exists, user declines overwrite — should refresh (load-merge-save)."""
    config_file, workspace_dir = mock_paths
    config_file.write_text('{"existing": true}')

    result = runner.invoke(app, ["onboard"], input="n\n")

    assert result.exit_code == 0
    assert "Config already exists" in result.stdout
    assert "existing values preserved" in result.stdout
    assert workspace_dir.exists()
    assert (workspace_dir / "AGENTS.md").exists()


def test_onboard_existing_config_overwrite(mock_paths):
    """Config exists, user confirms overwrite — should reset to defaults."""
    config_file, workspace_dir = mock_paths
    config_file.write_text('{"existing": true}')

    result = runner.invoke(app, ["onboard"], input="y\n")

    assert result.exit_code == 0
    assert "Config already exists" in result.stdout
    assert "Config reset to defaults" in result.stdout
    assert workspace_dir.exists()


def test_onboard_existing_workspace_safe_create(mock_paths):
    """Workspace exists — should not recreate, but still add missing templates."""
    config_file, workspace_dir = mock_paths
    workspace_dir.mkdir(parents=True)
    config_file.write_text("{}")

    result = runner.invoke(app, ["onboard"], input="n\n")

    assert result.exit_code == 0
    assert "Created workspace" not in result.stdout
    assert "Created AGENTS.md" in result.stdout
    assert (workspace_dir / "AGENTS.md").exists()


def test_config_matches_github_copilot_codex_with_hyphen_prefix():
    config = Config()
    config.agents.defaults.model = "github-copilot/gpt-5.3-codex"

    assert config.get_provider_name() == "github_copilot"


def test_config_matches_openai_codex_with_hyphen_prefix():
    config = Config()
    config.agents.defaults.model = "openai-codex/gpt-5.1-codex"

    assert config.get_provider_name() == "openai_codex"


def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
    spec = find_by_model("github-copilot/gpt-5.3-codex")

    assert spec is not None
    assert spec.name == "github_copilot"


def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
    provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex")

    resolved = provider._resolve_model("github-copilot/gpt-5.3-codex")

    assert resolved == "github_copilot/gpt-5.3-codex"


def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
    assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
    assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"

tests/test_consolidate_offset.py (new file, 821 lines)
@@ -0,0 +1,821 @@
"""Test session management with cache-friendly message handling."""

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest
from pathlib import Path
from nanobot.session.manager import Session, SessionManager

# Test constants
MEMORY_WINDOW = 50
KEEP_COUNT = MEMORY_WINDOW // 2  # 25


def create_session_with_messages(key: str, count: int, role: str = "user") -> Session:
    """Create a session and add the specified number of messages.

    Args:
        key: Session identifier
        count: Number of messages to add
        role: Message role (default: "user")

    Returns:
        Session with the specified messages
    """
    session = Session(key=key)
    for i in range(count):
        session.add_message(role, f"msg{i}")
    return session


def assert_messages_content(messages: list, start_index: int, end_index: int) -> None:
    """Assert that messages contain expected content from start to end index.

    Args:
        messages: List of message dictionaries
        start_index: Expected first message index
        end_index: Expected last message index
    """
    assert len(messages) > 0
    assert messages[0]["content"] == f"msg{start_index}"
    assert messages[-1]["content"] == f"msg{end_index}"


def get_old_messages(session: Session, last_consolidated: int, keep_count: int) -> list:
    """Extract messages that would be consolidated using the standard slice logic.

    Args:
        session: The session containing messages
        last_consolidated: Index of last consolidated message
        keep_count: Number of recent messages to keep

    Returns:
        List of messages that would be consolidated
    """
    return session.messages[last_consolidated:-keep_count]
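
# A worked example of this slice, matching the tests below: with 60 messages,
# last_consolidated=0 and keep_count=25, messages[0:-25] selects msg0..msg34
# (35 messages to consolidate) and leaves msg35..msg59 as the recent tail.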


class TestSessionLastConsolidated:
    """Test last_consolidated tracking to avoid duplicate processing."""

    def test_initial_last_consolidated_zero(self) -> None:
        """Test that new session starts with last_consolidated=0."""
        session = Session(key="test:initial")
        assert session.last_consolidated == 0

    def test_last_consolidated_persistence(self, tmp_path) -> None:
        """Test that last_consolidated persists across save/load."""
        manager = SessionManager(Path(tmp_path))
        session1 = create_session_with_messages("test:persist", 20)
        session1.last_consolidated = 15
        manager.save(session1)

        session2 = manager.get_or_create("test:persist")
        assert session2.last_consolidated == 15
        assert len(session2.messages) == 20

    def test_clear_resets_last_consolidated(self) -> None:
        """Test that clear() resets last_consolidated to 0."""
        session = create_session_with_messages("test:clear", 10)
        session.last_consolidated = 5

        session.clear()
        assert len(session.messages) == 0
        assert session.last_consolidated == 0


class TestSessionImmutableHistory:
    """Test Session message immutability for cache efficiency."""

    def test_initial_state(self) -> None:
        """Test that new session has empty messages list."""
        session = Session(key="test:initial")
        assert len(session.messages) == 0

    def test_add_messages_appends_only(self) -> None:
        """Test that adding messages only appends, never modifies."""
        session = Session(key="test:preserve")
        session.add_message("user", "msg1")
        session.add_message("assistant", "resp1")
        session.add_message("user", "msg2")
        assert len(session.messages) == 3
        assert session.messages[0]["content"] == "msg1"

    def test_get_history_returns_most_recent(self) -> None:
        """Test get_history returns the most recent messages."""
        session = Session(key="test:history")
        for i in range(10):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")

        history = session.get_history(max_messages=6)
        assert len(history) == 6
        assert history[0]["content"] == "msg7"
        assert history[-1]["content"] == "resp9"

    def test_get_history_with_all_messages(self) -> None:
        """Test get_history with max_messages larger than actual."""
        session = create_session_with_messages("test:all", 5)
        history = session.get_history(max_messages=100)
        assert len(history) == 5
        assert history[0]["content"] == "msg0"

    def test_get_history_stable_for_same_session(self) -> None:
        """Test that get_history returns same content for same max_messages."""
        session = create_session_with_messages("test:stable", 20)
        history1 = session.get_history(max_messages=10)
        history2 = session.get_history(max_messages=10)
        assert history1 == history2

    def test_messages_list_never_modified(self) -> None:
        """Test that messages list is never modified after creation."""
        session = create_session_with_messages("test:immutable", 5)
        original_len = len(session.messages)

        session.get_history(max_messages=2)
        assert len(session.messages) == original_len

        for _ in range(10):
            session.get_history(max_messages=3)
        assert len(session.messages) == original_len


class TestSessionPersistence:
    """Test Session persistence and reload."""

    @pytest.fixture
    def temp_manager(self, tmp_path):
        return SessionManager(Path(tmp_path))

    def test_persistence_roundtrip(self, temp_manager):
        """Test that messages persist across save/load."""
        session1 = create_session_with_messages("test:persistence", 20)
        temp_manager.save(session1)

        session2 = temp_manager.get_or_create("test:persistence")
        assert len(session2.messages) == 20
        assert session2.messages[0]["content"] == "msg0"
        assert session2.messages[-1]["content"] == "msg19"

    def test_get_history_after_reload(self, temp_manager):
        """Test that get_history works correctly after reload."""
        session1 = create_session_with_messages("test:reload", 30)
        temp_manager.save(session1)

        session2 = temp_manager.get_or_create("test:reload")
        history = session2.get_history(max_messages=10)
        assert len(history) == 10
        assert history[0]["content"] == "msg20"
        assert history[-1]["content"] == "msg29"

    def test_clear_resets_session(self, temp_manager):
        """Test that clear() properly resets session."""
        session = create_session_with_messages("test:clear", 10)
        assert len(session.messages) == 10

        session.clear()
        assert len(session.messages) == 0


class TestConsolidationTriggerConditions:
    """Test consolidation trigger conditions and logic."""

    def test_consolidation_needed_when_messages_exceed_window(self):
        """Test consolidation logic: should trigger when messages > memory_window."""
        session = create_session_with_messages("test:trigger", 60)

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated

        assert total_messages > MEMORY_WINDOW
        assert messages_to_process > 0

        expected_consolidate_count = total_messages - KEEP_COUNT
        assert expected_consolidate_count == 35

    def test_consolidation_skipped_when_within_keep_count(self):
        """Test consolidation skipped when total messages <= keep_count."""
        session = create_session_with_messages("test:skip", 20)

        total_messages = len(session.messages)
        assert total_messages <= KEEP_COUNT

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_consolidation_skipped_when_no_new_messages(self):
        """Test consolidation skipped when messages_to_process <= 0."""
        session = create_session_with_messages("test:already_consolidated", 40)
        session.last_consolidated = len(session.messages) - KEEP_COUNT  # 15

        # Add a few more messages
        for i in range(40, 42):
            session.add_message("user", f"msg{i}")

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated
        assert messages_to_process > 0

        # Simulate last_consolidated catching up
        session.last_consolidated = total_messages - KEEP_COUNT
        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0


class TestLastConsolidatedEdgeCases:
    """Test last_consolidated edge cases and data corruption scenarios."""

    def test_last_consolidated_exceeds_message_count(self):
        """Test behavior when last_consolidated > len(messages) (data corruption)."""
        session = create_session_with_messages("test:corruption", 10)
        session.last_consolidated = 20

        total_messages = len(session.messages)
        messages_to_process = total_messages - session.last_consolidated
        assert messages_to_process <= 0

        old_messages = get_old_messages(session, session.last_consolidated, 5)
        assert len(old_messages) == 0

    def test_last_consolidated_negative_value(self):
        """Test behavior with negative last_consolidated (invalid state)."""
        session = create_session_with_messages("test:negative", 10)
        session.last_consolidated = -5

        keep_count = 3
        old_messages = get_old_messages(session, session.last_consolidated, keep_count)

        # messages[-5:-3] with 10 messages gives indices 5,6
        assert len(old_messages) == 2
        assert old_messages[0]["content"] == "msg5"
        assert old_messages[-1]["content"] == "msg6"

    def test_messages_added_after_consolidation(self):
        """Test correct behavior when new messages arrive after consolidation."""
        session = create_session_with_messages("test:new_messages", 40)
        session.last_consolidated = len(session.messages) - KEEP_COUNT  # 15

        # Add new messages after consolidation
        for i in range(40, 50):
            session.add_message("user", f"msg{i}")

        total_messages = len(session.messages)
        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        expected_consolidate_count = total_messages - KEEP_COUNT - session.last_consolidated

        assert len(old_messages) == expected_consolidate_count
        assert_messages_content(old_messages, 15, 24)

    def test_slice_behavior_when_indices_overlap(self):
        """Test slice behavior when last_consolidated >= total - keep_count."""
        session = create_session_with_messages("test:overlap", 30)
        session.last_consolidated = 12

        old_messages = get_old_messages(session, session.last_consolidated, 20)
        assert len(old_messages) == 0


class TestArchiveAllMode:
    """Test archive_all mode (used by /new command)."""

    def test_archive_all_consolidates_everything(self):
        """Test archive_all=True consolidates all messages."""
        session = create_session_with_messages("test:archive_all", 50)

        archive_all = True
        if archive_all:
            old_messages = session.messages
            assert len(old_messages) == 50

        assert session.last_consolidated == 0

    def test_archive_all_resets_last_consolidated(self):
        """Test that archive_all mode resets last_consolidated to 0."""
        session = create_session_with_messages("test:reset", 40)
        session.last_consolidated = 15

        archive_all = True
        if archive_all:
            session.last_consolidated = 0

        assert session.last_consolidated == 0
        assert len(session.messages) == 40

    def test_archive_all_vs_normal_consolidation(self):
        """Test difference between archive_all and normal consolidation."""
        # Normal consolidation
        session1 = create_session_with_messages("test:normal", 60)
        session1.last_consolidated = len(session1.messages) - KEEP_COUNT

        # archive_all mode
        session2 = create_session_with_messages("test:all", 60)
        session2.last_consolidated = 0

        assert session1.last_consolidated == 35
        assert len(session1.messages) == 60
        assert session2.last_consolidated == 0
        assert len(session2.messages) == 60


class TestCacheImmutability:
    """Test that consolidation doesn't modify session.messages (cache safety)."""

    def test_consolidation_does_not_modify_messages_list(self):
        """Test that consolidation leaves messages list unchanged."""
        session = create_session_with_messages("test:immutable", 50)

        original_messages = session.messages.copy()
        original_len = len(session.messages)
        session.last_consolidated = original_len - KEEP_COUNT

        assert len(session.messages) == original_len
        assert session.messages == original_messages

    def test_get_history_does_not_modify_messages(self):
        """Test that get_history doesn't modify messages list."""
        session = create_session_with_messages("test:history_immutable", 40)
        original_messages = [m.copy() for m in session.messages]

        for _ in range(5):
            history = session.get_history(max_messages=10)
            assert len(history) == 10

        assert len(session.messages) == 40
        for i, msg in enumerate(session.messages):
            assert msg["content"] == original_messages[i]["content"]

    def test_consolidation_only_updates_last_consolidated(self):
        """Test that consolidation only updates last_consolidated field."""
        session = create_session_with_messages("test:field_only", 60)

        original_messages = session.messages.copy()
        original_key = session.key
        original_metadata = session.metadata.copy()

        session.last_consolidated = len(session.messages) - KEEP_COUNT

        assert session.messages == original_messages
        assert session.key == original_key
        assert session.metadata == original_metadata
        assert session.last_consolidated == 35


class TestSliceLogic:
    """Test the slice logic: messages[last_consolidated:-keep_count]."""

    def test_slice_extracts_correct_range(self):
        """Test that slice extracts the correct message range."""
        session = create_session_with_messages("test:slice", 60)

        old_messages = get_old_messages(session, 0, KEEP_COUNT)

        assert len(old_messages) == 35
        assert_messages_content(old_messages, 0, 34)

        remaining = session.messages[-KEEP_COUNT:]
        assert len(remaining) == 25
        assert_messages_content(remaining, 35, 59)

    def test_slice_with_partial_consolidation(self):
        """Test slice when some messages already consolidated."""
        session = create_session_with_messages("test:partial", 70)

        last_consolidated = 30
        old_messages = get_old_messages(session, last_consolidated, KEEP_COUNT)

        assert len(old_messages) == 15
        assert_messages_content(old_messages, 30, 44)

    def test_slice_with_various_keep_counts(self):
        """Test slice behavior with different keep_count values."""
        session = create_session_with_messages("test:keep_counts", 50)

        test_cases = [(10, 40), (20, 30), (30, 20), (40, 10)]

        for keep_count, expected_count in test_cases:
            old_messages = session.messages[0:-keep_count]
            assert len(old_messages) == expected_count

    def test_slice_when_keep_count_exceeds_messages(self):
        """Test slice when keep_count > len(messages)."""
        session = create_session_with_messages("test:exceed", 10)

        old_messages = session.messages[0:-20]
        assert len(old_messages) == 0


class TestEmptyAndBoundarySessions:
    """Test empty sessions and boundary conditions."""

    def test_empty_session_consolidation(self):
        """Test consolidation behavior with empty session."""
        session = Session(key="test:empty")

        assert len(session.messages) == 0
        assert session.last_consolidated == 0

        messages_to_process = len(session.messages) - session.last_consolidated
        assert messages_to_process == 0

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_single_message_session(self):
        """Test consolidation with single message."""
        session = Session(key="test:single")
        session.add_message("user", "only message")

        assert len(session.messages) == 1

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_exactly_keep_count_messages(self):
        """Test session with exactly keep_count messages."""
        session = create_session_with_messages("test:exact", KEEP_COUNT)

        assert len(session.messages) == KEEP_COUNT

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 0

    def test_just_over_keep_count(self):
        """Test session with one message over keep_count."""
        session = create_session_with_messages("test:over", KEEP_COUNT + 1)

        assert len(session.messages) == 26

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 1
        assert old_messages[0]["content"] == "msg0"

    def test_very_large_session(self):
        """Test consolidation with very large message count."""
        session = create_session_with_messages("test:large", 1000)

        assert len(session.messages) == 1000

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)
        assert len(old_messages) == 975
        assert_messages_content(old_messages, 0, 974)

        remaining = session.messages[-KEEP_COUNT:]
        assert len(remaining) == 25
        assert_messages_content(remaining, 975, 999)

    def test_session_with_gaps_in_consolidation(self):
        """Test session with potential gaps in consolidation history."""
        session = create_session_with_messages("test:gaps", 50)
        session.last_consolidated = 10

        # Add more messages
        for i in range(50, 60):
            session.add_message("user", f"msg{i}")

        old_messages = get_old_messages(session, session.last_consolidated, KEEP_COUNT)

        expected_count = 60 - KEEP_COUNT - 10
        assert len(old_messages) == expected_count
        assert_messages_content(old_messages, 10, 34)


class TestConsolidationDeduplicationGuard:
    """Test that consolidation tasks are deduplicated and serialized."""

    @pytest.mark.asyncio
    async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
        """Concurrent messages above memory_window spawn only one consolidation task."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(15):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        consolidation_calls = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal consolidation_calls
            consolidation_calls += 1
            await asyncio.sleep(0.05)

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await loop._process_message(msg)
        await loop._process_message(msg)
        await asyncio.sleep(0.1)

        assert consolidation_calls == 1, (
            f"Expected exactly 1 consolidation, got {consolidation_calls}"
        )

    @pytest.mark.asyncio
    async def test_new_command_guard_prevents_concurrent_consolidation(
        self, tmp_path: Path
    ) -> None:
        """/new command does not run consolidation concurrently with in-flight consolidation."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(15):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        consolidation_calls = 0
        active = 0
        max_active = 0

        async def _fake_consolidate(_session, archive_all: bool = False) -> None:
            nonlocal consolidation_calls, active, max_active
            consolidation_calls += 1
            active += 1
            max_active = max(max_active, active)
            await asyncio.sleep(0.05)
            active -= 1

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await loop._process_message(msg)

        new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        await loop._process_message(new_msg)
        await asyncio.sleep(0.1)

        assert consolidation_calls == 2, (
            f"Expected normal + /new consolidations, got {consolidation_calls}"
        )
        assert max_active == 1, (
            f"Expected serialized consolidation, observed concurrency={max_active}"
        )

    @pytest.mark.asyncio
    async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
        """create_task results are tracked in _consolidation_tasks while in flight."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(15):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        started = asyncio.Event()

        async def _slow_consolidate(_session, archive_all: bool = False) -> None:
            started.set()
            await asyncio.sleep(0.1)

        loop._consolidate_memory = _slow_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await loop._process_message(msg)

        await started.wait()
        assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"

        await asyncio.sleep(0.15)
        assert len(loop._consolidation_tasks) == 0, (
            "Task reference must be removed after completion"
        )

    @pytest.mark.asyncio
    async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
        self, tmp_path: Path
    ) -> None:
        """/new waits for in-flight consolidation and archives before clear."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(15):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = 0

        async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return True
            started.set()
            await release.wait()
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await loop._process_message(msg)
        await started.wait()

        new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        pending_new = asyncio.create_task(loop._process_message(new_msg))

        await asyncio.sleep(0.02)
        assert not pending_new.done(), "/new should wait while consolidation is in-flight"

        release.set()
        response = await pending_new
        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"

        session_after = loop.sessions.get_or_create("cli:test")
        assert session_after.messages == [], "Session should be cleared after successful archival"

    @pytest.mark.asyncio
    async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
        """/new must keep session data if archive step reports failure."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(5):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)
        before_count = len(session.messages)

        async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
            if archive_all:
                return False
            return True

        loop._consolidate_memory = _failing_consolidate  # type: ignore[method-assign]

        new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        response = await loop._process_message(new_msg)

        assert response is not None
        assert "failed" in response.content.lower()
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == before_count, (
            "Session must remain intact when /new archival fails"
        )

    @pytest.mark.asyncio
    async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
        self, tmp_path: Path
    ) -> None:
        """/new should archive only messages not yet consolidated by prior task."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )

        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(15):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        started = asyncio.Event()
        release = asyncio.Event()
        archived_count = -1

        async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
            nonlocal archived_count
            if archive_all:
                archived_count = len(sess.messages)
                return True

            started.set()
            await release.wait()
            sess.last_consolidated = len(sess.messages) - 3
            return True

        loop._consolidate_memory = _fake_consolidate  # type: ignore[method-assign]

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
        await loop._process_message(msg)
        await started.wait()

        new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        pending_new = asyncio.create_task(loop._process_message(new_msg))
        await asyncio.sleep(0.02)
        assert not pending_new.done()

        release.set()
        response = await pending_new

        assert response is not None
        assert "new session started" in response.content.lower()
        assert archived_count == 3, (
            f"Expected only unconsolidated tail to archive, got {archived_count}"
        )

    @pytest.mark.asyncio
    async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
        """/new clears session and returns confirmation."""
        from nanobot.agent.loop import AgentLoop
        from nanobot.bus.events import InboundMessage
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(
            bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
        )
        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        session = loop.sessions.get_or_create("cli:test")
        for i in range(3):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        loop.sessions.save(session)

        async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
            return True

        loop._consolidate_memory = _ok_consolidate  # type: ignore[method-assign]

        new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        response = await loop._process_message(new_msg)

        assert response is not None
        assert "new session started" in response.content.lower()
        assert loop.sessions.get_or_create("cli:test").messages == []

tests/test_context_prompt_cache.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Tests for cache-friendly prompt construction."""

from __future__ import annotations

from datetime import datetime as real_datetime
from pathlib import Path
import datetime as datetime_module

from nanobot.agent.context import ContextBuilder


class _FakeDatetime(real_datetime):
    current = real_datetime(2026, 2, 24, 13, 59)

    @classmethod
    def now(cls, tz=None):  # type: ignore[override]
        return cls.current
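        # Subclassing the real datetime keeps isinstance checks passing while
        # pinning now() to a fixed, test-controlled value.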


def _make_workspace(tmp_path: Path) -> Path:
    workspace = tmp_path / "workspace"
    workspace.mkdir(parents=True)
    return workspace


def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) -> None:
    """System prompt should not change just because wall clock minute changes."""
    monkeypatch.setattr(datetime_module, "datetime", _FakeDatetime)

    workspace = _make_workspace(tmp_path)
    builder = ContextBuilder(workspace)

    _FakeDatetime.current = real_datetime(2026, 2, 24, 13, 59)
    prompt1 = builder.build_system_prompt()

    _FakeDatetime.current = real_datetime(2026, 2, 24, 14, 0)
    prompt2 = builder.build_system_prompt()

    assert prompt1 == prompt2


def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
    """Runtime metadata should be merged with the user message."""
    workspace = _make_workspace(tmp_path)
    builder = ContextBuilder(workspace)

    messages = builder.build_messages(
        history=[],
        current_message="Return exactly: OK",
        channel="cli",
        chat_id="direct",
    )

    assert messages[0]["role"] == "system"
    assert "## Current Session" not in messages[0]["content"]

    # Runtime context is now merged with user message into a single message
    assert messages[-1]["role"] == "user"
    user_content = messages[-1]["content"]
    assert isinstance(user_content, str)
    assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
    assert "Current Time:" in user_content
    assert "Channel: cli" in user_content
    assert "Chat ID: direct" in user_content
    assert "Return exactly: OK" in user_content
61
tests/test_cron_service.py
Normal file
@@ -0,0 +1,61 @@
import asyncio

import pytest

from nanobot.cron.service import CronService
from nanobot.cron.types import CronSchedule


def test_add_job_rejects_unknown_timezone(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")

    with pytest.raises(ValueError, match="unknown timezone 'America/Vancovuer'"):
        service.add_job(
            name="tz typo",
            schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancovuer"),
            message="hello",
        )

    assert service.list_jobs(include_disabled=True) == []


def test_add_job_accepts_valid_timezone(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")

    job = service.add_job(
        name="tz ok",
        schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="America/Vancouver"),
        message="hello",
    )

    assert job.schedule.tz == "America/Vancouver"
    assert job.state.next_run_at_ms is not None


@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"
    called: list[str] = []

    async def on_job(job) -> None:
        called.append(job.id)

    service = CronService(store_path, on_job=on_job)
    job = service.add_job(
        name="external-disable",
        schedule=CronSchedule(kind="every", every_ms=200),
        message="hello",
    )
    await service.start()
    try:
        # Wait slightly to ensure the file mtime is definitively different
        await asyncio.sleep(0.05)
        external = CronService(store_path)
        updated = external.enable_job(job.id, enabled=False)
        assert updated is not None
        assert updated.enabled is False

        await asyncio.sleep(0.35)
        assert called == []
    finally:
        service.stop()
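

# A minimal sketch of the store-reload behavior the test above relies on.
# Hypothetical re-implementation for illustration only: `store_path`,
# `_last_mtime_ns`, and `_load` are assumed names, not the actual CronService API.
def _maybe_reload_sketch(service) -> None:
    mtime_ns = service.store_path.stat().st_mtime_ns
    if mtime_ns != getattr(service, "_last_mtime_ns", None):
        service._load()  # pick up jobs enabled/disabled by another process
        service._last_mtime_ns = mtime_ns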
@@ -169,7 +169,8 @@ async def test_send_uses_smtp_and_reply_subject(monkeypatch) -> None:


@pytest.mark.asyncio
async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
async def test_send_skips_reply_when_auto_reply_disabled(monkeypatch) -> None:
    """When auto_reply_enabled=False, replies should be skipped but proactive sends allowed."""
    class FakeSMTP:
        def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
            self.sent_messages: list[EmailMessage] = []
@@ -201,6 +202,11 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    cfg = _make_config()
    cfg.auto_reply_enabled = False
    channel = EmailChannel(cfg, MessageBus())

    # Mark alice as someone who sent us an email (making this a "reply")
    channel._last_subject_by_chat["alice@example.com"] = "Previous email"

    # Reply should be skipped (auto_reply_enabled=False)
    await channel.send(
        OutboundMessage(
            channel="email",
@@ -210,6 +216,7 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    )
    assert fake_instances == []

    # Reply with force_send=True should be sent
    await channel.send(
        OutboundMessage(
            channel="email",
@@ -222,6 +229,56 @@ async def test_send_skips_when_auto_reply_disabled(monkeypatch) -> None:
    assert len(fake_instances[0].sent_messages) == 1


@pytest.mark.asyncio
async def test_send_proactive_email_when_auto_reply_disabled(monkeypatch) -> None:
    """Proactive emails (not replies) should be sent even when auto_reply_enabled=False."""
    class FakeSMTP:
        def __init__(self, _host: str, _port: int, timeout: int = 30) -> None:
            self.sent_messages: list[EmailMessage] = []

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def starttls(self, context=None):
            return None

        def login(self, _user: str, _pw: str):
            return None

        def send_message(self, msg: EmailMessage):
            self.sent_messages.append(msg)

    fake_instances: list[FakeSMTP] = []

    def _smtp_factory(host: str, port: int, timeout: int = 30):
        instance = FakeSMTP(host, port, timeout=timeout)
        fake_instances.append(instance)
        return instance

    monkeypatch.setattr("nanobot.channels.email.smtplib.SMTP", _smtp_factory)

    cfg = _make_config()
    cfg.auto_reply_enabled = False
    channel = EmailChannel(cfg, MessageBus())

    # bob@example.com has never sent us an email (proactive send)
    # This should be sent even with auto_reply_enabled=False
    await channel.send(
        OutboundMessage(
            channel="email",
            chat_id="bob@example.com",
            content="Hello, this is a proactive email.",
        )
    )
    assert len(fake_instances) == 1
    assert len(fake_instances[0].sent_messages) == 1
    sent = fake_instances[0].sent_messages[0]
    assert sent["To"] == "bob@example.com"


@pytest.mark.asyncio
async def test_send_skips_when_consent_not_granted(monkeypatch) -> None:
    class FakeSMTP:
65
tests/test_feishu_post_content.py
Normal file
@@ -0,0 +1,65 @@
from nanobot.channels.feishu import FeishuChannel, _extract_post_content


def test_extract_post_content_supports_post_wrapper_shape() -> None:
    payload = {
        "post": {
            "zh_cn": {
                "title": "日报",
                "content": [
                    [
                        {"tag": "text", "text": "完成"},
                        {"tag": "img", "image_key": "img_1"},
                    ]
                ],
            }
        }
    }

    text, image_keys = _extract_post_content(payload)

    assert text == "日报 完成"
    assert image_keys == ["img_1"]


def test_extract_post_content_keeps_direct_shape_behavior() -> None:
    payload = {
        "title": "Daily",
        "content": [
            [
                {"tag": "text", "text": "report"},
                {"tag": "img", "image_key": "img_a"},
                {"tag": "img", "image_key": "img_b"},
            ]
        ],
    }

    text, image_keys = _extract_post_content(payload)

    assert text == "Daily report"
    assert image_keys == ["img_a", "img_b"]


def test_register_optional_event_keeps_builder_when_method_missing() -> None:
    class Builder:
        pass

    builder = Builder()
    same = FeishuChannel._register_optional_event(builder, "missing", object())
    assert same is builder


def test_register_optional_event_calls_supported_method() -> None:
    called = []

    class Builder:
        def register_event(self, handler):
            called.append(handler)
            return self

    builder = Builder()
    handler = object()
    same = FeishuChannel._register_optional_event(builder, "register_event", handler)

    assert same is builder
    assert called == [handler]
104
tests/test_feishu_table_split.py
Normal file
@@ -0,0 +1,104 @@
"""Tests for FeishuChannel._split_elements_by_table_limit.

Feishu cards reject messages that contain more than one table element
(API error 11310: card table number over limit). The helper splits a flat
list of card elements into groups so that each group contains at most one
table, allowing nanobot to send multiple cards instead of failing.
"""

from nanobot.channels.feishu import FeishuChannel


def _md(text: str) -> dict:
    return {"tag": "markdown", "content": text}


def _table() -> dict:
    return {
        "tag": "table",
        "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
        "rows": [{"c0": "v"}],
        "page_size": 2,
    }


split = FeishuChannel._split_elements_by_table_limit
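

# One splitting rule consistent with the tests below (a hypothetical
# re-implementation for illustration; the real helper may differ): each table
# closes its group as soon as another table is still to come, so non-table
# elements always ride along with the card of the nearest following table,
# and anything after the last table stays with it.
def _split_sketch(elements: list[dict]) -> list[list[dict]]:
    remaining = sum(1 for e in elements if e.get("tag") == "table")
    groups: list[list[dict]] = [[]]
    for el in elements:
        groups[-1].append(el)
        if el.get("tag") == "table":
            remaining -= 1
            if remaining > 0:
                groups.append([])  # the next table starts a fresh card
    return groups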


def test_empty_list_returns_single_empty_group() -> None:
    assert split([]) == [[]]


def test_no_tables_returns_single_group() -> None:
    els = [_md("hello"), _md("world")]
    result = split(els)
    assert result == [els]


def test_single_table_stays_in_one_group() -> None:
    els = [_md("intro"), _table(), _md("outro")]
    result = split(els)
    assert len(result) == 1
    assert result[0] == els


def test_two_tables_split_into_two_groups() -> None:
    # Use different row values so the two tables are not equal
    t1 = {
        "tag": "table",
        "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
        "rows": [{"c0": "table-one"}],
        "page_size": 2,
    }
    t2 = {
        "tag": "table",
        "columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
        "rows": [{"c0": "table-two"}],
        "page_size": 2,
    }
    els = [_md("before"), t1, _md("between"), t2, _md("after")]
    result = split(els)
    assert len(result) == 2
    # First group: text before table-1 + table-1
    assert t1 in result[0]
    assert t2 not in result[0]
    # Second group: text between tables + table-2 + text after
    assert t2 in result[1]
    assert t1 not in result[1]


def test_three_tables_split_into_three_groups() -> None:
    tables = [
        {"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
        for i in range(3)
    ]
    els = tables[:]
    result = split(els)
    assert len(result) == 3
    for i, group in enumerate(result):
        assert tables[i] in group


def test_leading_markdown_stays_with_first_table() -> None:
    intro = _md("intro")
    t = _table()
    result = split([intro, t])
    assert len(result) == 1
    assert result[0] == [intro, t]


def test_trailing_markdown_after_second_table() -> None:
    t1, t2 = _table(), _table()
    tail = _md("end")
    result = split([t1, t2, tail])
    assert len(result) == 2
    assert result[1] == [t2, tail]


def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
    head = _md("head")
    t1, t2 = _table(), _table()
    result = split([head, t1, t2])
    # head + t1 in group 0; t2 in group 1
    assert result[0] == [head, t1]
    assert result[1] == [t2]
117
tests/test_heartbeat_service.py
Normal file
@@ -0,0 +1,117 @@
import asyncio

import pytest

from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMResponse, ToolCallRequest


class DummyProvider:
    def __init__(self, responses: list[LLMResponse]):
        self._responses = list(responses)

    async def chat(self, *args, **kwargs) -> LLMResponse:
        if self._responses:
            return self._responses.pop(0)
        return LLMResponse(content="", tool_calls=[])


@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
    provider = DummyProvider([])

    service = HeartbeatService(
        workspace=tmp_path,
        provider=provider,
        model="openai/gpt-4o-mini",
        interval_s=9999,
        enabled=True,
    )

    await service.start()
    first_task = service._task
    await service.start()

    assert service._task is first_task

    service.stop()
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_decide_returns_skip_when_no_tool_call(tmp_path) -> None:
    provider = DummyProvider([LLMResponse(content="no tool call", tool_calls=[])])
    service = HeartbeatService(
        workspace=tmp_path,
        provider=provider,
        model="openai/gpt-4o-mini",
    )

    action, tasks = await service._decide("heartbeat content")
    assert action == "skip"
    assert tasks == ""
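

# Hypothetical shape of the decision step exercised above (illustration only;
# the real HeartbeatService._decide may build its prompt and tool schema
# differently): map a `heartbeat` tool call to an (action, tasks) pair and
# default to "skip" when the model does not call the tool.
async def _decide_sketch(provider, model: str, heartbeat_content: str) -> tuple[str, str]:
    response = await provider.chat(
        messages=[{"role": "user", "content": heartbeat_content}],
        model=model,
        tools=[],  # the real service would pass the `heartbeat` tool definition here
    )
    for call in response.tool_calls:
        if call.name == "heartbeat":
            args = call.arguments if isinstance(call.arguments, dict) else {}
            return str(args.get("action", "skip")), str(args.get("tasks", ""))
    return "skip", ""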


@pytest.mark.asyncio
async def test_trigger_now_executes_when_decision_is_run(tmp_path) -> None:
    (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")

    provider = DummyProvider([
        LLMResponse(
            content="",
            tool_calls=[
                ToolCallRequest(
                    id="hb_1",
                    name="heartbeat",
                    arguments={"action": "run", "tasks": "check open tasks"},
                )
            ],
        )
    ])

    called_with: list[str] = []

    async def _on_execute(tasks: str) -> str:
        called_with.append(tasks)
        return "done"

    service = HeartbeatService(
        workspace=tmp_path,
        provider=provider,
        model="openai/gpt-4o-mini",
        on_execute=_on_execute,
    )

    result = await service.trigger_now()
    assert result == "done"
    assert called_with == ["check open tasks"]


@pytest.mark.asyncio
async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
    (tmp_path / "HEARTBEAT.md").write_text("- [ ] do thing", encoding="utf-8")

    provider = DummyProvider([
        LLMResponse(
            content="",
            tool_calls=[
                ToolCallRequest(
                    id="hb_1",
                    name="heartbeat",
                    arguments={"action": "skip"},
                )
            ],
        )
    ])

    async def _on_execute(tasks: str) -> str:
        return tasks

    service = HeartbeatService(
        workspace=tmp_path,
        provider=provider,
        model="openai/gpt-4o-mini",
        on_execute=_on_execute,
    )

    assert await service.trigger_now() is None
41
tests/test_loop_save_turn.py
Normal file
@@ -0,0 +1,41 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.session.manager import Session


def _mk_loop() -> AgentLoop:
    loop = AgentLoop.__new__(AgentLoop)
    loop._TOOL_RESULT_MAX_CHARS = 500
    return loop


def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
    loop = _mk_loop()
    session = Session(key="test:runtime-only")
    runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"

    loop._save_turn(
        session,
        [{"role": "user", "content": [{"type": "text", "text": runtime}]}],
        skip=0,
    )
    assert session.messages == []


def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
    loop = _mk_loop()
    session = Session(key="test:image")
    runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"

    loop._save_turn(
        session,
        [{
            "role": "user",
            "content": [
                {"type": "text", "text": runtime},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
            ],
        }],
        skip=0,
    )
    assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
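

# Hypothetical sanitization consistent with the two tests above (illustration
# only; the real AgentLoop._save_turn may differ): strip runtime-context text
# parts and replace images with a placeholder, dropping the turn entirely when
# nothing user-authored remains.
def _sanitize_user_parts(parts: list[dict], runtime_tag: str) -> list[dict]:
    kept: list[dict] = []
    for part in parts:
        if part.get("type") == "text" and part.get("text", "").startswith(runtime_tag):
            continue  # runtime metadata never enters the persisted transcript
        if part.get("type") == "image_url":
            kept.append({"type": "text", "text": "[image]"})  # placeholder, not the data URI
        else:
            kept.append(part)
    return kept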
1318
tests/test_matrix_channel.py
Normal file
File diff suppressed because it is too large
222
tests/test_memory_consolidation_types.py
Normal file
@@ -0,0 +1,222 @@
"""Test MemoryStore.consolidate() handles non-string tool call arguments.

Regression test for https://github.com/HKUDS/nanobot/issues/1042
When memory consolidation receives dict values instead of strings from the LLM
tool call response, it should serialize them to JSON instead of raising TypeError.
"""

import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest


def _make_session(message_count: int = 30, memory_window: int = 50):
    """Create a mock session with messages."""
    session = MagicMock()
    session.messages = [
        {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
        for i in range(message_count)
    ]
    session.last_consolidated = 0
    return session


def _make_tool_response(history_entry, memory_update):
    """Create an LLMResponse with a save_memory tool call."""
    return LLMResponse(
        content=None,
        tool_calls=[
            ToolCallRequest(
                id="call_1",
                name="save_memory",
                arguments={
                    "history_entry": history_entry,
                    "memory_update": memory_update,
                },
            )
        ],
    )


class TestMemoryConsolidationTypeHandling:
    """Test that consolidation handles various argument types correctly."""

    @pytest.mark.asyncio
    async def test_string_arguments_work(self, tmp_path: Path) -> None:
        """Normal case: LLM returns string arguments."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=_make_tool_response(
                history_entry="[2026-01-01] User discussed testing.",
                memory_update="# Memory\nUser likes testing.",
            )
        )
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is True
        assert store.history_file.exists()
        assert "[2026-01-01] User discussed testing." in store.history_file.read_text()
        assert "User likes testing." in store.memory_file.read_text()

    @pytest.mark.asyncio
    async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None:
        """Issue #1042: LLM returns dict instead of string — must not raise TypeError."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=_make_tool_response(
                history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."},
                memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
            )
        )
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is True
        assert store.history_file.exists()
        history_content = store.history_file.read_text()
        parsed = json.loads(history_content.strip())
        assert parsed["summary"] == "User discussed testing."

        memory_content = store.memory_file.read_text()
        parsed_mem = json.loads(memory_content)
        assert "User likes testing" in parsed_mem["facts"]

    @pytest.mark.asyncio
    async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None:
        """Some providers return arguments as a JSON string instead of a parsed dict."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()

        # Simulate arguments being a JSON string (not yet parsed)
        response = LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(
                    id="call_1",
                    name="save_memory",
                    arguments=json.dumps({
                        "history_entry": "[2026-01-01] User discussed testing.",
                        "memory_update": "# Memory\nUser likes testing.",
                    }),
                )
            ],
        )
        provider.chat = AsyncMock(return_value=response)
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is True
        assert "User discussed testing." in store.history_file.read_text()

    @pytest.mark.asyncio
    async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None:
        """When the LLM doesn't use the save_memory tool, return False."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        provider.chat = AsyncMock(
            return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
        )
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is False
        assert not store.history_file.exists()

    @pytest.mark.asyncio
    async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
        """Consolidation should be a no-op when messages < keep_count."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()
        session = _make_session(message_count=10)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is True
        provider.chat.assert_not_called()

    @pytest.mark.asyncio
    async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
        """Some providers return arguments as a list - extract the first element if it's a dict."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()

        # Simulate arguments being a list containing a dict
        response = LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(
                    id="call_1",
                    name="save_memory",
                    arguments=[{
                        "history_entry": "[2026-01-01] User discussed testing.",
                        "memory_update": "# Memory\nUser likes testing.",
                    }],
                )
            ],
        )
        provider.chat = AsyncMock(return_value=response)
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is True
        assert "User discussed testing." in store.history_file.read_text()
        assert "User likes testing." in store.memory_file.read_text()

    @pytest.mark.asyncio
    async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
        """Empty list arguments should return False."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()

        response = LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(
                    id="call_1",
                    name="save_memory",
                    arguments=[],
                )
            ],
        )
        provider.chat = AsyncMock(return_value=response)
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is False

    @pytest.mark.asyncio
    async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
        """A list with non-dict content should return False."""
        store = MemoryStore(tmp_path)
        provider = AsyncMock()

        response = LLMResponse(
            content=None,
            tool_calls=[
                ToolCallRequest(
                    id="call_1",
                    name="save_memory",
                    arguments=["string", "content"],
                )
            ],
        )
        provider.chat = AsyncMock(return_value=response)
        session = _make_session(message_count=60)

        result = await store.consolidate(session, provider, "test-model", memory_window=50)

        assert result is False
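

# A hypothetical normalization helper consistent with the behaviors pinned
# down above (illustration only; MemoryStore.consolidate() may implement this
# differently): accept a dict, a raw JSON string, or a single-element list,
# and coerce non-string field values to JSON text.
def _normalize_save_memory_args(arguments):
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except ValueError:
            return None  # unparseable argument payload -> consolidation fails
    if isinstance(arguments, list):
        if not arguments or not isinstance(arguments[0], dict):
            return None  # empty list or non-dict content -> consolidation fails
        arguments = arguments[0]
    if not isinstance(arguments, dict):
        return None
    return {
        key: value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)
        for key, value in arguments.items()
    }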
10
tests/test_message_tool.py
Normal file
@@ -0,0 +1,10 @@
import pytest

from nanobot.agent.tools.message import MessageTool


@pytest.mark.asyncio
async def test_message_tool_returns_error_when_no_target_context() -> None:
    tool = MessageTool()
    result = await tool.execute(content="test")
    assert result == "Error: No target channel/chat specified"
132
tests/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,132 @@
"""Test message tool suppress logic for final replies."""

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest


def _make_loop(tmp_path: Path) -> AgentLoop:
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)


class TestMessageToolSuppressLogic:
    """The final reply is suppressed only when the message tool already sent to the same target."""

    @pytest.mark.asyncio
    async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
        )
        calls = iter([
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        if isinstance(mt, MessageTool):
            mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
        result = await loop._process_message(msg)

        assert len(sent) == 1
        assert result is None  # suppressed

    @pytest.mark.asyncio
    async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
        )
        calls = iter([
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="I've sent the email.", tool_calls=[]),
        ])
        loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        if isinstance(mt, MessageTool):
            mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
        result = await loop._process_message(msg)

        assert len(sent) == 1
        assert sent[0].channel == "email"
        assert result is not None  # not suppressed
        assert result.channel == "feishu"

    @pytest.mark.asyncio
    async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
        result = await loop._process_message(msg)

        assert result is not None
        assert "Hello" in result.content

    @pytest.mark.asyncio
    async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
        calls = iter([
            LLMResponse(
                content="Visible<think>hidden</think>",
                tool_calls=[tool_call],
                reasoning_content="secret reasoning",
                thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
            ),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])
        loop.tools.execute = AsyncMock(return_value="ok")

        progress: list[tuple[str, bool]] = []

        async def on_progress(content: str, *, tool_hint: bool = False) -> None:
            progress.append((content, tool_hint))

        final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)

        assert final_content == "Done"
        assert progress == [
            ("Visible", False),
            ('read_file("foo.txt")', True),
        ]
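

# The suppression rule the first two tests pin down, written as a hypothetical
# predicate (illustration only; the real AgentLoop logic may differ): drop the
# final reply only if the message tool already delivered to the same
# channel/chat during this turn.
def _should_suppress_final_reply(sent: list[OutboundMessage], channel: str, chat_id: str) -> bool:
    return any(m.channel == channel and m.chat_id == chat_id for m in sent)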


class TestMessageToolTurnTracking:

    def test_sent_in_turn_tracks_same_target(self) -> None:
        tool = MessageTool()
        tool.set_context("feishu", "chat1")
        assert not tool._sent_in_turn
        tool._sent_in_turn = True
        assert tool._sent_in_turn

    def test_start_turn_resets(self) -> None:
        tool = MessageTool()
        tool._sent_in_turn = True
        tool.start_turn()
        assert not tool._sent_in_turn
167
tests/test_task_cancel.py
Normal file
@@ -0,0 +1,167 @@
"""Tests for /stop task cancellation."""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


def _make_loop():
    """Create a minimal AgentLoop with mocked dependencies."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    workspace = MagicMock()
    workspace.__truediv__ = MagicMock(return_value=MagicMock())

    with patch("nanobot.agent.loop.ContextBuilder"), \
            patch("nanobot.agent.loop.SessionManager"), \
            patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
        MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
        loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
        return loop, bus


class TestHandleStop:
    @pytest.mark.asyncio
    async def test_stop_no_active_task(self):
        from nanobot.bus.events import InboundMessage

        loop, bus = _make_loop()
        msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
        await loop._handle_stop(msg)
        out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
        assert "No active task" in out.content

    @pytest.mark.asyncio
    async def test_stop_cancels_active_task(self):
        from nanobot.bus.events import InboundMessage

        loop, bus = _make_loop()
        cancelled = asyncio.Event()

        async def slow_task():
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                cancelled.set()
                raise

        task = asyncio.create_task(slow_task())
        await asyncio.sleep(0)
        loop._active_tasks["test:c1"] = [task]

        msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
        await loop._handle_stop(msg)

        assert cancelled.is_set()
        out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
        assert "stopped" in out.content.lower()

    @pytest.mark.asyncio
    async def test_stop_cancels_multiple_tasks(self):
        from nanobot.bus.events import InboundMessage

        loop, bus = _make_loop()
        events = [asyncio.Event(), asyncio.Event()]

        async def slow(idx):
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                events[idx].set()
                raise

        tasks = [asyncio.create_task(slow(i)) for i in range(2)]
        await asyncio.sleep(0)
        loop._active_tasks["test:c1"] = tasks

        msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="/stop")
        await loop._handle_stop(msg)

        assert all(e.is_set() for e in events)
        out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
        assert "2 task" in out.content
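

# Hypothetical shape of the /stop handler these tests exercise (illustration
# only; the real AgentLoop._handle_stop publishes its reply on the outbound
# bus and may differ in detail): cancel every task registered under the
# session key and confirm how many were stopped.
async def _handle_stop_sketch(loop, msg) -> str:
    key = f"{msg.channel}:{msg.chat_id}"
    tasks = loop._active_tasks.pop(key, [])
    if not tasks:
        return "No active task to stop."
    for task in tasks:
        task.cancel()
    # Wait for cancellation to land so the confirmation is truthful.
    await asyncio.gather(*tasks, return_exceptions=True)
    return f"Stopped {len(tasks)} task(s)."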


class TestDispatch:
    @pytest.mark.asyncio
    async def test_dispatch_processes_and_publishes(self):
        from nanobot.bus.events import InboundMessage, OutboundMessage

        loop, bus = _make_loop()
        msg = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="hello")
        loop._process_message = AsyncMock(
            return_value=OutboundMessage(channel="test", chat_id="c1", content="hi")
        )
        await loop._dispatch(msg)
        out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
        assert out.content == "hi"

    @pytest.mark.asyncio
    async def test_processing_lock_serializes(self):
        from nanobot.bus.events import InboundMessage, OutboundMessage

        loop, bus = _make_loop()
        order = []

        async def mock_process(m, **kwargs):
            order.append(f"start-{m.content}")
            await asyncio.sleep(0.05)
            order.append(f"end-{m.content}")
            return OutboundMessage(channel="test", chat_id="c1", content=m.content)

        loop._process_message = mock_process
        msg1 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="a")
        msg2 = InboundMessage(channel="test", sender_id="u1", chat_id="c1", content="b")

        t1 = asyncio.create_task(loop._dispatch(msg1))
        t2 = asyncio.create_task(loop._dispatch(msg2))
        await asyncio.gather(t1, t2)
        assert order == ["start-a", "end-a", "start-b", "end-b"]


class TestSubagentCancellation:
    @pytest.mark.asyncio
    async def test_cancel_by_session(self):
        from nanobot.agent.subagent import SubagentManager
        from nanobot.bus.queue import MessageBus

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)

        cancelled = asyncio.Event()

        async def slow():
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                cancelled.set()
                raise

        task = asyncio.create_task(slow())
        await asyncio.sleep(0)
        mgr._running_tasks["sub-1"] = task
        mgr._session_tasks["test:c1"] = {"sub-1"}

        count = await mgr.cancel_by_session("test:c1")
        assert count == 1
        assert cancelled.is_set()

    @pytest.mark.asyncio
    async def test_cancel_by_session_no_tasks(self):
        from nanobot.agent.subagent import SubagentManager
        from nanobot.bus.queue import MessageBus

        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
        assert await mgr.cancel_by_session("nonexistent") == 0
169
tests/test_telegram_channel.py
Normal file
@@ -0,0 +1,169 @@
from types import SimpleNamespace

import pytest

from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig


class _FakeHTTPXRequest:
    instances: list["_FakeHTTPXRequest"] = []

    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs
        self.__class__.instances.append(self)


class _FakeUpdater:
    def __init__(self, on_start_polling) -> None:
        self._on_start_polling = on_start_polling

    async def start_polling(self, **kwargs) -> None:
        self._on_start_polling()


class _FakeBot:
    def __init__(self) -> None:
        self.sent_messages: list[dict] = []

    async def get_me(self):
        return SimpleNamespace(username="nanobot_test")

    async def set_my_commands(self, commands) -> None:
        self.commands = commands

    async def send_message(self, **kwargs) -> None:
        self.sent_messages.append(kwargs)


class _FakeApp:
    def __init__(self, on_start_polling) -> None:
        self.bot = _FakeBot()
        self.updater = _FakeUpdater(on_start_polling)
        self.handlers = []
        self.error_handlers = []

    def add_error_handler(self, handler) -> None:
        self.error_handlers.append(handler)

    def add_handler(self, handler) -> None:
        self.handlers.append(handler)

    async def initialize(self) -> None:
        pass

    async def start(self) -> None:
        pass


class _FakeBuilder:
    def __init__(self, app: _FakeApp) -> None:
        self.app = app
        self.token_value = None
        self.request_value = None
        self.get_updates_request_value = None

    def token(self, token: str):
        self.token_value = token
        return self

    def request(self, request):
        self.request_value = request
        return self

    def get_updates_request(self, request):
        self.get_updates_request_value = request
        return self

    def proxy(self, _proxy):
        raise AssertionError("builder.proxy should not be called when request is set")

    def get_updates_proxy(self, _proxy):
        raise AssertionError("builder.get_updates_proxy should not be called when request is set")

    def build(self):
        return self.app


@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
    config = TelegramConfig(
        enabled=True,
        token="123:abc",
        allow_from=["*"],
        proxy="http://127.0.0.1:7890",
    )
    bus = MessageBus()
    channel = TelegramChannel(config, bus)
    app = _FakeApp(lambda: setattr(channel, "_running", False))
    builder = _FakeBuilder(app)

    monkeypatch.setattr("nanobot.channels.telegram.HTTPXRequest", _FakeHTTPXRequest)
    monkeypatch.setattr(
        "nanobot.channels.telegram.Application",
        SimpleNamespace(builder=lambda: builder),
    )

    await channel.start()

    assert len(_FakeHTTPXRequest.instances) == 1
    assert _FakeHTTPXRequest.instances[0].kwargs["proxy"] == config.proxy
    assert builder.request_value is _FakeHTTPXRequest.instances[0]
    assert builder.get_updates_request_value is _FakeHTTPXRequest.instances[0]


def test_derive_topic_session_key_uses_thread_id() -> None:
    message = SimpleNamespace(
        chat=SimpleNamespace(type="supergroup"),
        chat_id=-100123,
        message_thread_id=42,
    )

    assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
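

# A hypothetical derivation matching the assertion above (illustration only;
# the real TelegramChannel helper may handle more cases): supergroup messages
# carrying a thread id get a per-topic session key.
def _derive_topic_session_key_sketch(message) -> str:
    key = f"telegram:{message.chat_id}"
    if getattr(message.chat, "type", None) == "supergroup" and getattr(message, "message_thread_id", None):
        key += f":topic:{message.message_thread_id}"
    return key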


def test_get_extension_falls_back_to_original_filename() -> None:
    channel = TelegramChannel(TelegramConfig(), MessageBus())

    assert channel._get_extension("file", None, "report.pdf") == ".pdf"
    assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"


@pytest.mark.asyncio
async def test_send_progress_keeps_message_in_topic() -> None:
    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)

    await channel.send(
        OutboundMessage(
            channel="telegram",
            chat_id="123",
            content="hello",
            metadata={"_progress": True, "message_thread_id": 42},
        )
    )

    assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42


@pytest.mark.asyncio
async def test_send_reply_infers_topic_from_message_id_cache() -> None:
    config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], reply_to_message=True)
    channel = TelegramChannel(config, MessageBus())
    channel._app = _FakeApp(lambda: None)
    channel._message_threads[("123", 10)] = 42

    await channel.send(
        OutboundMessage(
            channel="telegram",
            chat_id="123",
            content="hello",
            metadata={"message_id": 10},
        )
    )

    assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
    assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
@@ -2,6 +2,7 @@ from typing import Any

from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool


class SampleTool(Tool):
@@ -86,3 +87,253 @@ async def test_registry_returns_validation_error() -> None:
    reg.register(SampleTool())
    result = await reg.execute("sample", {"query": "hi"})
    assert "Invalid parameters" in result


def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
    cmd = r"type C:\user\workspace\txt"
    paths = ExecTool._extract_absolute_paths(cmd)
    assert paths == [r"C:\user\workspace\txt"]


def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
    cmd = ".venv/bin/python script.py"
    paths = ExecTool._extract_absolute_paths(cmd)
    assert "/bin/python" not in paths


def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
    cmd = "cat /tmp/data.txt > /tmp/out.txt"
    paths = ExecTool._extract_absolute_paths(cmd)
    assert "/tmp/data.txt" in paths
    assert "/tmp/out.txt" in paths
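

# One regex-based sketch matching the three expectations above (hypothetical;
# the real ExecTool._extract_absolute_paths may differ): capture Windows
# drive-letter paths and POSIX absolute paths, while the lookbehind rejects
# relative segments such as ".venv/bin/python".
import re

_ABS_PATH_RE = re.compile(r"[A-Za-z]:\\[^\s\"']+|(?<![\w.\-])/[^\s\"'>]+")

def _extract_absolute_paths_sketch(cmd: str) -> list[str]:
    return _ABS_PATH_RE.findall(cmd)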


# --- cast_params tests ---


class CastTestTool(Tool):
    """Minimal tool for testing cast_params."""

    def __init__(self, schema: dict[str, Any]) -> None:
        self._schema = schema

    @property
    def name(self) -> str:
        return "cast_test"

    @property
    def description(self) -> str:
        return "test tool for casting"

    @property
    def parameters(self) -> dict[str, Any]:
        return self._schema

    async def execute(self, **kwargs: Any) -> str:
        return "ok"
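

# A hypothetical value-coercion routine consistent with the expectations in
# the tests below (illustration only; the real Tool.cast_params may differ):
# best-effort casting of string values toward the schema type, preserving the
# original value whenever coercion would be lossy or ambiguous.
_TRUE_WORDS = {"true", "1", "yes"}
_FALSE_WORDS = {"false", "0", "no"}

def _cast_value_sketch(value: Any, spec: dict[str, Any]) -> Any:
    expected = spec.get("type")
    if value is None:
        return value  # None passes through for every type
    if expected == "boolean" and isinstance(value, str):
        low = value.strip().lower()
        if low in _TRUE_WORDS:
            return True
        if low in _FALSE_WORDS:
            return False
        return value  # invalid boolean strings are preserved for validation
    if expected == "integer" and isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return value
    if expected == "number" and isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            return value
    if expected == "array" and isinstance(value, list) and "items" in spec:
        return [_cast_value_sketch(v, spec["items"]) for v in value]
    if expected == "object" and isinstance(value, dict):
        props = spec.get("properties", {})
        return {k: _cast_value_sketch(v, props.get(k, {})) for k, v in value.items()}
    return value  # bools are never silently numified; non-lists never wrapped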


def test_cast_params_string_to_int() -> None:
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"count": {"type": "integer"}},
        }
    )
    result = tool.cast_params({"count": "42"})
    assert result["count"] == 42
    assert isinstance(result["count"], int)


def test_cast_params_string_to_number() -> None:
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"rate": {"type": "number"}},
        }
    )
    result = tool.cast_params({"rate": "3.14"})
    assert result["rate"] == 3.14
    assert isinstance(result["rate"], float)


def test_cast_params_string_to_bool() -> None:
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"enabled": {"type": "boolean"}},
        }
    )
    assert tool.cast_params({"enabled": "true"})["enabled"] is True
    assert tool.cast_params({"enabled": "false"})["enabled"] is False
    assert tool.cast_params({"enabled": "1"})["enabled"] is True


def test_cast_params_array_items() -> None:
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {
                "nums": {"type": "array", "items": {"type": "integer"}},
            },
        }
    )
    result = tool.cast_params({"nums": ["1", "2", "3"]})
    assert result["nums"] == [1, 2, 3]


def test_cast_params_nested_object() -> None:
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {
                "config": {
                    "type": "object",
                    "properties": {
                        "port": {"type": "integer"},
                        "debug": {"type": "boolean"},
                    },
                },
            },
        }
    )
    result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
    assert result["config"]["port"] == 8080
    assert result["config"]["debug"] is True


def test_cast_params_bool_not_cast_to_int() -> None:
    """Booleans should not be silently cast to integers."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"count": {"type": "integer"}},
        }
    )
    result = tool.cast_params({"count": True})
    assert result["count"] is True
    errors = tool.validate_params(result)
    assert any("count should be integer" in e for e in errors)


def test_cast_params_preserves_empty_string() -> None:
    """Empty strings should be preserved for string type."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        }
    )
    result = tool.cast_params({"name": ""})
    assert result["name"] == ""


def test_cast_params_bool_string_false() -> None:
    """Test that 'false', '0', 'no' strings convert to False."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"flag": {"type": "boolean"}},
        }
    )
    assert tool.cast_params({"flag": "false"})["flag"] is False
    assert tool.cast_params({"flag": "False"})["flag"] is False
    assert tool.cast_params({"flag": "0"})["flag"] is False
    assert tool.cast_params({"flag": "no"})["flag"] is False
    assert tool.cast_params({"flag": "NO"})["flag"] is False


def test_cast_params_bool_string_invalid() -> None:
    """Invalid boolean strings should not be cast."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"flag": {"type": "boolean"}},
        }
    )
    # Invalid strings should be preserved (validation will catch them)
    result = tool.cast_params({"flag": "random"})
    assert result["flag"] == "random"
    result = tool.cast_params({"flag": "maybe"})
    assert result["flag"] == "maybe"


def test_cast_params_invalid_string_to_int() -> None:
    """Invalid strings should not be cast to integer."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"count": {"type": "integer"}},
        }
    )
    result = tool.cast_params({"count": "abc"})
    assert result["count"] == "abc"  # Original value preserved
    result = tool.cast_params({"count": "12.5.7"})
    assert result["count"] == "12.5.7"


def test_cast_params_invalid_string_to_number() -> None:
    """Invalid strings should not be cast to number."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"rate": {"type": "number"}},
        }
    )
    result = tool.cast_params({"rate": "not_a_number"})
    assert result["rate"] == "not_a_number"


def test_validate_params_bool_not_accepted_as_number() -> None:
    """Booleans should not pass number validation."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"rate": {"type": "number"}},
        }
    )
    errors = tool.validate_params({"rate": False})
    assert any("rate should be number" in e for e in errors)


def test_cast_params_none_values() -> None:
    """Test None handling for different types."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": {"type": "integer"},
                "items": {"type": "array"},
                "config": {"type": "object"},
            },
        }
    )
    result = tool.cast_params(
        {
            "name": None,
            "count": None,
            "items": None,
            "config": None,
        }
    )
    # None should be preserved for all types
    assert result["name"] is None
    assert result["count"] is None
    assert result["items"] is None
    assert result["config"] is None


def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
    """Single values should NOT be automatically wrapped into arrays."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"items": {"type": "array"}},
        }
    )
    # Non-array values should be preserved (validation will catch them)
    result = tool.cast_params({"items": 5})
    assert result["items"] == 5  # Not wrapped to [5]
    result = tool.cast_params({"items": "text"})
    assert result["items"] == "text"  # Not wrapped to ["text"]