feat(onboard): align setup with config and workspace flags

This commit is contained in:
Xubin Ren
2026-03-17 05:42:49 +00:00
81 changed files with 9243 additions and 1383 deletions

View File

@@ -0,0 +1,228 @@
"""Tests for channel plugin discovery, merging, and config compatibility."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.channels.manager import ChannelManager
from nanobot.config.schema import ChannelsConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakePlugin(BaseChannel):
name = "fakeplugin"
display_name = "Fake Plugin"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
class _FakeTelegram(BaseChannel):
"""Plugin that tries to shadow built-in telegram."""
name = "telegram"
display_name = "Fake Telegram"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
def _make_entry_point(name: str, cls: type):
"""Create a mock entry point that returns *cls* on load()."""
ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
return ep
# ---------------------------------------------------------------------------
# ChannelsConfig extra="allow"
# ---------------------------------------------------------------------------
def test_channels_config_accepts_unknown_keys():
cfg = ChannelsConfig.model_validate({
"myplugin": {"enabled": True, "token": "abc"},
})
extra = cfg.model_extra
assert extra is not None
assert extra["myplugin"]["enabled"] is True
assert extra["myplugin"]["token"] == "abc"
def test_channels_config_getattr_returns_extra():
cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
section = getattr(cfg, "myplugin", None)
assert isinstance(section, dict)
assert section["enabled"] is True
def test_channels_config_builtin_fields_removed():
"""After decoupling, ChannelsConfig has no explicit channel fields."""
cfg = ChannelsConfig()
assert not hasattr(cfg, "telegram")
assert cfg.send_progress is True
assert cfg.send_tool_hints is False
# ---------------------------------------------------------------------------
# discover_plugins
# ---------------------------------------------------------------------------
_EP_TARGET = "importlib.metadata.entry_points"
def test_discover_plugins_loads_entry_points():
from nanobot.channels.registry import discover_plugins
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_plugins_handles_load_error():
from nanobot.channels.registry import discover_plugins
def _boom():
raise RuntimeError("broken")
ep = SimpleNamespace(name="broken", load=_boom)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "broken" not in result
# ---------------------------------------------------------------------------
# discover_all — merge & priority
# ---------------------------------------------------------------------------
def test_discover_all_includes_builtins():
from nanobot.channels.registry import discover_all, discover_channel_names
with patch(_EP_TARGET, return_value=[]):
result = discover_all()
# discover_all() only returns channels that are actually available (dependencies installed)
# discover_channel_names() returns all built-in channel names
# So we check that all actually loaded channels are in the result
for name in result:
assert name in discover_channel_names()
def test_discover_all_includes_external_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_all_builtin_shadows_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("telegram", _FakeTelegram)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "telegram" in result
assert result["telegram"] is not _FakeTelegram
# ---------------------------------------------------------------------------
# Manager _init_channels with dict config (plugin scenario)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_manager_loads_plugin_from_dict_config():
"""ChannelManager should instantiate a plugin channel from a raw dict config."""
from nanobot.channels.manager import ChannelManager
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" in mgr.channels
assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": False},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" not in mgr.channels
# ---------------------------------------------------------------------------
# Built-in channel default_config() and dict->Pydantic conversion
# ---------------------------------------------------------------------------
def test_builtin_channel_default_config():
"""Built-in channels expose default_config() returning a dict with 'enabled': False."""
from nanobot.channels.telegram import TelegramChannel
cfg = TelegramChannel.default_config()
assert isinstance(cfg, dict)
assert cfg["enabled"] is False
assert "token" in cfg
def test_builtin_channel_init_from_dict():
"""Built-in channels accept a raw dict and convert to Pydantic internally."""
from nanobot.channels.telegram import TelegramChannel
bus = MessageBus()
ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
assert ch.config.token == "test-tok"
assert ch.config.allow_from == ["*"]

View File

@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, call, patch
import pytest
from prompt_toolkit.formatted_text import HTML
@@ -57,3 +57,57 @@ def test_init_prompt_session_creates_session():
_, kwargs = MockSession.call_args
assert kwargs["multiline"] is False
assert kwargs["enable_open_in_editor"] is False
def test_thinking_spinner_pause_stops_and_restarts():
"""Pause should stop the active spinner and restart it afterward."""
spinner = MagicMock()
with patch.object(commands.console, "status", return_value=spinner):
thinking = commands._ThinkingSpinner(enabled=True)
with thinking:
with thinking.pause():
pass
assert spinner.method_calls == [
call.start(),
call.stop(),
call.start(),
call.stop(),
]
def test_print_cli_progress_line_pauses_spinner_before_printing():
"""CLI progress output should pause spinner to avoid garbled lines."""
order: list[str] = []
spinner = MagicMock()
spinner.start.side_effect = lambda: order.append("start")
spinner.stop.side_effect = lambda: order.append("stop")
with patch.object(commands.console, "status", return_value=spinner), \
patch.object(commands.console, "print", side_effect=lambda *_args, **_kwargs: order.append("print")):
thinking = commands._ThinkingSpinner(enabled=True)
with thinking:
commands._print_cli_progress_line("tool running", thinking)
assert order == ["start", "stop", "print", "start", "stop"]
@pytest.mark.asyncio
async def test_print_interactive_progress_line_pauses_spinner_before_printing():
"""Interactive progress output should also pause spinner cleanly."""
order: list[str] = []
spinner = MagicMock()
spinner.start.side_effect = lambda: order.append("start")
spinner.stop.side_effect = lambda: order.append("stop")
async def fake_print(_text: str) -> None:
order.append("print")
with patch.object(commands.console, "status", return_value=spinner), \
patch("nanobot.cli.commands._print_interactive_line", side_effect=fake_print):
thinking = commands._ThinkingSpinner(enabled=True)
with thinking:
await commands._print_interactive_progress_line("tool running", thinking)
assert order == ["start", "stop", "print", "start", "stop"]

View File

@@ -1,3 +1,5 @@
import json
import re
import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -5,12 +7,18 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
runner = CliRunner()
@@ -36,9 +44,16 @@ def mock_paths():
mock_cp.return_value = config_file
mock_ws.return_value = workspace_dir
mock_sc.side_effect = lambda config: config_file.write_text("{}")
mock_lc.side_effect = lambda _config_path=None: Config()
yield config_file, workspace_dir
def _save_config(config: Config, config_path: Path | None = None):
target = config_path or config_file
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text(json.dumps(config.model_dump(by_alias=True)), encoding="utf-8")
mock_sc.side_effect = _save_config
yield config_file, workspace_dir, mock_ws
if base_dir.exists():
shutil.rmtree(base_dir)
@@ -46,7 +61,7 @@ def mock_paths():
def test_onboard_fresh_install(mock_paths):
"""No existing config — should create from scratch."""
config_file, workspace_dir = mock_paths
config_file, workspace_dir, mock_ws = mock_paths
result = runner.invoke(app, ["onboard"])
@@ -57,11 +72,13 @@ def test_onboard_fresh_install(mock_paths):
assert config_file.exists()
assert (workspace_dir / "AGENTS.md").exists()
assert (workspace_dir / "memory" / "MEMORY.md").exists()
expected_workspace = Config().workspace_path
assert mock_ws.call_args.args == (expected_workspace,)
def test_onboard_existing_config_refresh(mock_paths):
"""Config exists, user declines overwrite — should refresh (load-merge-save)."""
config_file, workspace_dir = mock_paths
config_file, workspace_dir, _ = mock_paths
config_file.write_text('{"existing": true}')
result = runner.invoke(app, ["onboard"], input="n\n")
@@ -75,7 +92,7 @@ def test_onboard_existing_config_refresh(mock_paths):
def test_onboard_existing_config_overwrite(mock_paths):
"""Config exists, user confirms overwrite — should reset to defaults."""
config_file, workspace_dir = mock_paths
config_file, workspace_dir, _ = mock_paths
config_file.write_text('{"existing": true}')
result = runner.invoke(app, ["onboard"], input="y\n")
@@ -88,7 +105,7 @@ def test_onboard_existing_config_overwrite(mock_paths):
def test_onboard_existing_workspace_safe_create(mock_paths):
"""Workspace exists — should not recreate, but still add missing templates."""
config_file, workspace_dir = mock_paths
config_file, workspace_dir, _ = mock_paths
workspace_dir.mkdir(parents=True)
config_file.write_text("{}")
@@ -100,6 +117,40 @@ def test_onboard_existing_workspace_safe_create(mock_paths):
assert (workspace_dir / "AGENTS.md").exists()
def test_onboard_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["onboard", "--help"])
assert result.exit_code == 0
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
assert "--dir" not in stripped_output
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
config_path = tmp_path / "instance" / "config.json"
workspace_path = tmp_path / "workspace"
monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})
result = runner.invoke(
app,
["onboard", "--config", str(config_path), "--workspace", str(workspace_path)],
)
assert result.exit_code == 0
saved = Config.model_validate(json.loads(config_path.read_text(encoding="utf-8")))
assert saved.workspace_path == workspace_path
assert (workspace_path / "AGENTS.md").exists()
stripped_output = _strip_ansi(result.stdout)
compact_output = stripped_output.replace("\n", "")
resolved_config = str(config_path.resolve())
assert resolved_config in compact_output
assert f"--config {resolved_config}" in compact_output
def test_config_matches_github_copilot_codex_with_hyphen_prefix():
config = Config()
config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
@@ -114,6 +165,64 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex"
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
config = Config()
config.agents.defaults.provider = "ollama"
config.agents.defaults.model = "llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -134,6 +243,33 @@ def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
def test_make_provider_passes_extra_headers_to_custom_provider():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "custom", "model": "gpt-4o-mini"}},
"providers": {
"custom": {
"apiKey": "test-key",
"apiBase": "https://example.com/v1",
"extraHeaders": {
"APP-Code": "demo-app",
"x-session-affinity": "sticky-session",
},
}
},
}
)
with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai:
_make_provider(config)
kwargs = mock_async_openai.call_args.kwargs
assert kwargs["api_key"] == "test-key"
assert kwargs["base_url"] == "https://example.com/v1"
assert kwargs["default_headers"]["APP-Code"] == "demo-app"
assert kwargs["default_headers"]["x-session-affinity"] == "sticky-session"
@pytest.fixture
def mock_agent_runtime(tmp_path):
"""Mock agent command dependencies for focused CLI tests."""
@@ -170,10 +306,11 @@ def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0
assert "--workspace" in result.stdout
assert "-w" in result.stdout
assert "--config" in result.stdout
assert "-c" in result.stdout
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
@@ -267,6 +404,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
mock_agent_runtime["config"].agents.defaults.memory_window = 100
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -328,6 +475,28 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.memory_window = 100
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -356,3 +525,47 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
assert isinstance(result.exception, _StopGateway)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "port 18791" in result.stdout
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
assert isinstance(result.exception, _StopGateway)
assert "port 18792" in result.stdout

View File

@@ -0,0 +1,132 @@
import json
from types import SimpleNamespace
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 1234,
"memoryWindow": 42,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
assert config.agents.defaults.should_warn_deprecated_memory_window is True
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 2222,
"memoryWindow": 30,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 2222
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 3333,
"memoryWindow": 50,
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "contextWindowTokens" in result.stdout
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 3333
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"channels": {
"qq": {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
"qq": SimpleNamespace(
default_config=lambda: {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
"msgFormat": "plain",
}
)
},
)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain"

View File

@@ -480,338 +480,108 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
"""Test that consolidation tasks are deduplicated and serialized."""
class TestNewCommandArchival:
"""Test /new archival behavior with the simplified consolidation flow."""
@pytest.mark.asyncio
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
"""Concurrent messages above memory_window spawn only one consolidation task."""
@staticmethod
def _make_loop(tmp_path: Path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
return loop
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None:
"""/new clears session immediately; archive_messages retries until raw dump."""
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
@pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
if archive_all:
return False
return True
call_count = 0
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
async def _failing_consolidate(_messages) -> bool:
nonlocal call_count
call_count += 1
return False
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "failed" in response.content.lower()
assert "new session started" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
assert len(session_after.messages) == 0
await loop.close_mcp()
assert call_count == 3 # retried up to raw-archive threshold
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
async def _fake_consolidate(messages) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
archived_count = len(messages)
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count == 3, (
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
await loop.close_mcp()
assert archived_count == 3
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
async def _ok_consolidate(_messages) -> bool:
return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
@@ -819,3 +589,31 @@ class TestConsolidationDeduplicationGuard:
assert response is not None
assert "new session started" in response.content.lower()
assert loop.sessions.get_or_create("cli:test").messages == []
@pytest.mark.asyncio
async def test_close_mcp_drains_background_tasks(self, tmp_path: Path) -> None:
"""close_mcp waits for background tasks to complete."""
from nanobot.bus.events import InboundMessage
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
archived = asyncio.Event()
async def _slow_consolidate(_messages) -> bool:
await asyncio.sleep(0.1)
archived.set()
return True
loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
assert not archived.is_set()
await loop.close_mcp()
assert archived.is_set()

View File

@@ -1,10 +1,12 @@
import asyncio
from types import SimpleNamespace
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.dingtalk import DingTalkChannel
from nanobot.config.schema import DingTalkConfig
import nanobot.channels.dingtalk as dingtalk_module
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
from nanobot.channels.dingtalk import DingTalkConfig
class _FakeResponse:
@@ -12,19 +14,31 @@ class _FakeResponse:
self.status_code = status_code
self._json_body = json_body or {}
self.text = "{}"
self.content = b""
self.headers = {"content-type": "application/json"}
def json(self) -> dict:
return self._json_body
class _FakeHttp:
def __init__(self) -> None:
def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
self.calls: list[dict] = []
self._responses = list(responses) if responses else []
async def post(self, url: str, json=None, headers=None):
self.calls.append({"url": url, "json": json, "headers": headers})
def _next_response(self) -> _FakeResponse:
if self._responses:
return self._responses.pop(0)
return _FakeResponse()
async def post(self, url: str, json=None, headers=None, **kwargs):
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
return self._next_response()
async def get(self, url: str, **kwargs):
self.calls.append({"method": "GET", "url": url})
return self._next_response()
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
@@ -64,3 +78,136 @@ async def test_group_send_uses_group_messages_api() -> None:
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeChatbotMessage:
text = None
extensions = {"content": {"recognition": "voice transcript"}}
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "audio"
@staticmethod
def from_dict(_data):
return _FakeChatbotMessage()
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "2",
"conversationId": "conv123",
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert msg.content == "voice transcript"
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"
@pytest.mark.asyncio
async def test_handler_processes_file_message(monkeypatch) -> None:
"""Test that file messages are handled and forwarded with downloaded path."""
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeFileChatbotMessage:
text = None
extensions = {}
image_content = None
rich_text_content = None
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "file"
@staticmethod
def from_dict(_data):
return _FakeFileChatbotMessage()
async def fake_download(download_code, filename, sender_id):
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "1",
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert "[File]" in msg.content
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
@pytest.mark.asyncio
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
"""Test the two-step file download flow (get URL then download content)."""
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
MessageBus(),
)
# Mock access token
async def fake_get_token():
return "test-token"
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
file_content = b"fake file content"
channel._http = _FakeHttp(responses=[
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
_FakeResponse(200),
])
channel._http._responses[1].content = file_content
# Redirect media dir to tmp_path
monkeypatch.setattr(
"nanobot.config.paths.get_media_dir",
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
)
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
assert result is not None
assert result.endswith("test.xlsx")
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
# Verify API calls
assert channel._http.calls[0]["method"] == "POST"
assert "messageFiles/download" in channel._http.calls[0]["url"]
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
assert channel._http.calls[1]["method"] == "GET"

View File

@@ -6,7 +6,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.email import EmailChannel
from nanobot.config.schema import EmailConfig
from nanobot.channels.email import EmailConfig
def _make_config() -> EmailConfig:

63
tests/test_evaluator.py Normal file
View File

@@ -0,0 +1,63 @@
import pytest
from nanobot.utils.evaluator import evaluate_response
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
async def chat(self, *args, **kwargs) -> LLMResponse:
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
def _eval_tool_call(should_notify: bool, reason: str = "") -> LLMResponse:
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="eval_1",
name="evaluate_notification",
arguments={"should_notify": should_notify, "reason": reason},
)
],
)
@pytest.mark.asyncio
async def test_should_notify_true() -> None:
provider = DummyProvider([_eval_tool_call(True, "user asked to be reminded")])
result = await evaluate_response("Task completed with results", "check emails", provider, "m")
assert result is True
@pytest.mark.asyncio
async def test_should_notify_false() -> None:
provider = DummyProvider([_eval_tool_call(False, "routine check, nothing new")])
result = await evaluate_response("All clear, no updates", "check status", provider, "m")
assert result is False
@pytest.mark.asyncio
async def test_fallback_on_error() -> None:
class FailingProvider(DummyProvider):
async def chat(self, *args, **kwargs) -> LLMResponse:
raise RuntimeError("provider down")
provider = FailingProvider([])
result = await evaluate_response("some response", "some task", provider, "m")
assert result is True
@pytest.mark.asyncio
async def test_no_tool_call_fallback() -> None:
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
result = await evaluate_response("some response", "some task", provider, "m")
assert result is True

View File

@@ -0,0 +1,69 @@
"""Tests for exec tool internal URL blocking."""
from __future__ import annotations
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.shell import ExecTool
def _fake_resolve_private(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
def _fake_resolve_localhost(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
def _fake_resolve_public(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
@pytest.mark.asyncio
async def test_exec_blocks_curl_metadata():
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(
command='curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
)
assert "Error" in result
assert "internal" in result.lower() or "private" in result.lower()
@pytest.mark.asyncio
async def test_exec_blocks_wget_localhost():
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
result = await tool.execute(command="wget http://localhost:8080/secret -O /tmp/out")
assert "Error" in result
@pytest.mark.asyncio
async def test_exec_allows_normal_commands():
tool = ExecTool(timeout=5)
result = await tool.execute(command="echo hello")
assert "hello" in result
assert "Error" not in result.split("\n")[0]
@pytest.mark.asyncio
async def test_exec_allows_curl_to_public_url():
"""Commands with public URLs should not be blocked by the internal URL check."""
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
guard_result = tool._guard_command("curl https://example.com/api", "/tmp")
assert guard_result is None
@pytest.mark.asyncio
async def test_exec_blocks_chained_internal_url():
"""Internal URLs buried in chained commands should still be caught."""
tool = ExecTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
)
assert "Error" in result

392
tests/test_feishu_reply.py Normal file
View File

@@ -0,0 +1,392 @@
"""Tests for Feishu message reply (quote) feature."""
import asyncio
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
config = FeishuConfig(
enabled=True,
app_id="cli_test",
app_secret="secret",
allow_from=["*"],
reply_to_message=reply_to_message,
)
channel = FeishuChannel(config, MessageBus())
channel._client = MagicMock()
# _loop is only used by the WebSocket thread bridge; not needed for unit tests
channel._loop = None
return channel
def _make_feishu_event(
*,
message_id: str = "om_001",
chat_id: str = "oc_abc",
chat_type: str = "p2p",
msg_type: str = "text",
content: str = '{"text": "hello"}',
sender_open_id: str = "ou_alice",
parent_id: str | None = None,
root_id: str | None = None,
):
message = SimpleNamespace(
message_id=message_id,
chat_id=chat_id,
chat_type=chat_type,
message_type=msg_type,
content=content,
parent_id=parent_id,
root_id=root_id,
mentions=[],
)
sender = SimpleNamespace(
sender_type="user",
sender_id=SimpleNamespace(open_id=sender_open_id),
)
return SimpleNamespace(event=SimpleNamespace(message=message, sender=sender))
def _make_get_message_response(text: str, msg_type: str = "text", success: bool = True):
"""Build a fake im.v1.message.get response object."""
body = SimpleNamespace(content=json.dumps({"text": text}))
item = SimpleNamespace(msg_type=msg_type, body=body)
data = SimpleNamespace(items=[item])
resp = MagicMock()
resp.success.return_value = success
resp.data = data
resp.code = 0
resp.msg = "ok"
return resp
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
def test_feishu_config_reply_to_message_defaults_false() -> None:
assert FeishuConfig().reply_to_message is False
def test_feishu_config_reply_to_message_can_be_enabled() -> None:
config = FeishuConfig(reply_to_message=True)
assert config.reply_to_message is True
# ---------------------------------------------------------------------------
# _get_message_content_sync tests
# ---------------------------------------------------------------------------
def test_get_message_content_sync_returns_reply_prefix() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.get.return_value = _make_get_message_response("what time is it?")
result = channel._get_message_content_sync("om_parent")
assert result == "[Reply to: what time is it?]"
def test_get_message_content_sync_truncates_long_text() -> None:
channel = _make_feishu_channel()
long_text = "x" * (FeishuChannel._REPLY_CONTEXT_MAX_LEN + 50)
channel._client.im.v1.message.get.return_value = _make_get_message_response(long_text)
result = channel._get_message_content_sync("om_parent")
assert result is not None
assert result.endswith("...]")
inner = result[len("[Reply to: ") : -1]
assert len(inner) == FeishuChannel._REPLY_CONTEXT_MAX_LEN + len("...")
def test_get_message_content_sync_returns_none_on_api_failure() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 230002
resp.msg = "bot not in group"
channel._client.im.v1.message.get.return_value = resp
result = channel._get_message_content_sync("om_parent")
assert result is None
def test_get_message_content_sync_returns_none_for_non_text_type() -> None:
channel = _make_feishu_channel()
body = SimpleNamespace(content=json.dumps({"image_key": "img_1"}))
item = SimpleNamespace(msg_type="image", body=body)
data = SimpleNamespace(items=[item])
resp = MagicMock()
resp.success.return_value = True
resp.data = data
channel._client.im.v1.message.get.return_value = resp
result = channel._get_message_content_sync("om_parent")
assert result is None
def test_get_message_content_sync_returns_none_when_empty_text() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.get.return_value = _make_get_message_response(" ")
result = channel._get_message_content_sync("om_parent")
assert result is None
# ---------------------------------------------------------------------------
# _reply_message_sync tests
# ---------------------------------------------------------------------------
def test_reply_message_sync_returns_true_on_success() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = resp
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is True
channel._client.im.v1.message.reply.assert_called_once()
def test_reply_message_sync_returns_false_on_api_error() -> None:
channel = _make_feishu_channel()
resp = MagicMock()
resp.success.return_value = False
resp.code = 400
resp.msg = "bad request"
resp.get_log_id.return_value = "log_x"
channel._client.im.v1.message.reply.return_value = resp
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is False
def test_reply_message_sync_returns_false_on_exception() -> None:
channel = _make_feishu_channel()
channel._client.im.v1.message.reply.side_effect = RuntimeError("network error")
ok = channel._reply_message_sync("om_parent", "text", '{"text":"hi"}')
assert ok is False
# ---------------------------------------------------------------------------
# send() — reply routing tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_uses_reply_api_when_configured() -> None:
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.reply.assert_called_once()
channel._client.im.v1.message.create.assert_not_called()
@pytest.mark.asyncio
async def test_send_uses_create_api_when_reply_disabled() -> None:
channel = _make_feishu_channel(reply_to_message=False)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_uses_create_api_when_no_message_id() -> None:
channel = _make_feishu_channel(reply_to_message=True)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_skips_reply_for_progress_messages() -> None:
channel = _make_feishu_channel(reply_to_message=True)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="thinking...",
metadata={"message_id": "om_001", "_progress": True},
))
channel._client.im.v1.message.create.assert_called_once()
channel._client.im.v1.message.reply.assert_not_called()
@pytest.mark.asyncio
async def test_send_fallback_to_create_when_reply_fails() -> None:
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = False
reply_resp.code = 400
reply_resp.msg = "error"
reply_resp.get_log_id.return_value = "log_x"
channel._client.im.v1.message.reply.return_value = reply_resp
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
# reply attempted first, then falls back to create
channel._client.im.v1.message.reply.assert_called_once()
channel._client.im.v1.message.create.assert_called_once()
# ---------------------------------------------------------------------------
# _on_message — parent_id / root_id metadata tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_on_message_captures_parent_and_root_id_in_metadata() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
channel._client.im.v1.message.react.return_value = MagicMock(success=lambda: True)
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(
_make_feishu_event(
parent_id="om_parent",
root_id="om_root",
)
)
assert len(captured) == 1
meta = captured[0]["metadata"]
assert meta["parent_id"] == "om_parent"
assert meta["root_id"] == "om_root"
assert meta["message_id"] == "om_001"
@pytest.mark.asyncio
async def test_on_message_parent_and_root_id_none_when_absent() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(_make_feishu_event())
assert len(captured) == 1
meta = captured[0]["metadata"]
assert meta["parent_id"] is None
assert meta["root_id"] is None
@pytest.mark.asyncio
async def test_on_message_prepends_reply_context_when_parent_id_present() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
channel._client.im.v1.message.get.return_value = _make_get_message_response("original question")
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(
_make_feishu_event(
content='{"text": "my answer"}',
parent_id="om_parent",
)
)
assert len(captured) == 1
content = captured[0]["content"]
assert content.startswith("[Reply to: original question]")
assert "my answer" in content
@pytest.mark.asyncio
async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
channel = _make_feishu_channel()
channel._processed_message_ids.clear()
captured = []
async def _capture(**kwargs):
captured.append(kwargs)
channel._handle_message = _capture
with patch.object(channel, "_add_reaction", return_value=None):
await channel._on_message(_make_feishu_event())
channel._client.im.v1.message.get.assert_not_called()
assert len(captured) == 1

View File

@@ -0,0 +1,138 @@
"""Tests for FeishuChannel tool hint code block formatting."""
import json
from unittest.mock import MagicMock, patch
import pytest
from pytest import mark
from nanobot.bus.events import OutboundMessage
from nanobot.channels.feishu import FeishuChannel
@pytest.fixture
def mock_feishu_channel():
"""Create a FeishuChannel with mocked client."""
config = MagicMock()
config.app_id = "test_app_id"
config.app_secret = "test_app_secret"
config.encrypt_key = None
config.verification_token = None
bus = MagicMock()
channel = FeishuChannel(config, bus)
channel._client = MagicMock() # Simulate initialized client
return channel
@mark.asyncio
async def test_tool_hint_sends_code_message(mock_feishu_channel):
"""Tool hint messages should be sent as interactive cards with code blocks."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("test query")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Verify interactive message with card was sent
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
receive_id_type, receive_id, msg_type, content = call_args
assert receive_id_type == "chat_id"
assert receive_id == "oc_123456"
assert msg_type == "interactive"
# Parse content to verify card structure
card = json.loads(content)
assert card["config"]["wide_screen_mode"] is True
assert len(card["elements"]) == 1
assert card["elements"][0]["tag"] == "markdown"
# Check that code block is properly formatted with language hint
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
assert card["elements"][0]["content"] == expected_md
@mark.asyncio
async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
"""Empty tool hint messages should not be sent."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content=" ", # whitespace only
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should not send any message
mock_send.assert_not_called()
@mark.asyncio
async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
"""Regular messages without _tool_hint should use normal formatting."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content="Hello, world!",
metadata={}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should send as text message (detected format)
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
_, _, msg_type, content = call_args
assert msg_type == "text"
assert json.loads(content) == {"text": "Hello, world!"}
@mark.asyncio
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
"""Multiple tool calls should be displayed each on its own line in a code block."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("query"), read_file("/path/to/file")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
call_args = mock_send.call_args[0]
msg_type = call_args[2]
content = json.loads(call_args[3])
assert msg_type == "interactive"
# Each tool call should be on its own line
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
assert content["elements"][0]["content"] == expected_md
@mark.asyncio
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
"""Commas inside a single tool argument must not be split onto a new line."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='web_search("foo, bar"), read_file("/path/to/file")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
expected_md = (
"**Tool Calls**\n\n```text\n"
"web_search(\"foo, bar\"),\n"
"read_file(\"/path/to/file\")\n```"
)
assert content["elements"][0]["content"] == expected_md

View File

@@ -0,0 +1,364 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def sample_file(self, tmp_path):
f = tmp_path / "sample.txt"
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
return f
@pytest.mark.asyncio
async def test_basic_read_has_line_numbers(self, tool, sample_file):
result = await tool.execute(path=str(sample_file))
assert "1| line 1" in result
assert "20| line 20" in result
@pytest.mark.asyncio
async def test_offset_and_limit(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
assert "5| line 5" in result
assert "7| line 7" in result
assert "8| line 8" not in result
assert "Use offset=8 to continue" in result
@pytest.mark.asyncio
async def test_offset_beyond_end(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=999)
assert "Error" in result
assert "beyond end" in result
@pytest.mark.asyncio
async def test_end_of_file_marker(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
assert "End of file" in result
@pytest.mark.asyncio
async def test_empty_file(self, tool, tmp_path):
f = tmp_path / "empty.txt"
f.write_text("", encoding="utf-8")
result = await tool.execute(path=str(f))
assert "Empty file" in result
@pytest.mark.asyncio
async def test_file_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope.txt"))
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_char_budget_trims(self, tool, tmp_path):
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
f = tmp_path / "big.txt"
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
result = await tool.execute(path=str(f))
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
assert "Use offset=" in result
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
def test_exact_match(self):
match, count = _find_match("hello world", "world")
assert match == "world"
assert count == 1
def test_exact_no_match(self):
match, count = _find_match("hello world", "xyz")
assert match is None
assert count == 0
def test_crlf_normalisation(self):
# Caller normalises CRLF before calling _find_match, so test with
# pre-normalised content to verify exact match still works.
content = "line1\nline2\nline3"
old_text = "line1\nline2\nline3"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
def test_line_trim_fallback(self):
content = " def foo():\n pass\n"
old_text = "def foo():\n pass"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
# The returned match should be the *original* indented text
assert " def foo():" in match
def test_line_trim_multiple_candidates(self):
content = " a\n b\n a\n b\n"
old_text = "a\nb"
match, count = _find_match(content, old_text)
assert count == 2
def test_empty_old_text(self):
match, count = _find_match("hello", "")
# Empty string is always "in" any string via exact match
assert match == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_exact_match(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
assert f.read_text() == "hello earth"
@pytest.mark.asyncio
async def test_crlf_normalisation(self, tool, tmp_path):
f = tmp_path / "crlf.py"
f.write_bytes(b"line1\r\nline2\r\nline3")
result = await tool.execute(
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
)
assert "Successfully" in result
raw = f.read_bytes()
assert b"LINE1" in raw
# CRLF line endings should be preserved throughout the file
assert b"\r\n" in raw
@pytest.mark.asyncio
async def test_trim_fallback(self, tool, tmp_path):
f = tmp_path / "indent.py"
f.write_text(" def foo():\n pass\n", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
)
assert "Successfully" in result
assert "bar" in f.read_text()
@pytest.mark.asyncio
async def test_ambiguous_match(self, tool, tmp_path):
f = tmp_path / "dup.py"
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
assert "appears" in result.lower() or "Warning" in result
@pytest.mark.asyncio
async def test_replace_all(self, tool, tmp_path):
f = tmp_path / "multi.py"
f.write_text("foo bar foo bar foo", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="foo", new_text="baz", replace_all=True,
)
assert "Successfully" in result
assert f.read_text() == "baz bar baz bar baz"
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
f = tmp_path / "nf.py"
f.write_text("hello", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
assert "Error" in result
assert "not found" in result
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
@pytest.fixture()
def tool(self, tmp_path):
return ListDirTool(workspace=tmp_path)
@pytest.fixture()
def populated_dir(self, tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "main.py").write_text("pass")
(tmp_path / "src" / "utils.py").write_text("pass")
(tmp_path / "README.md").write_text("hi")
(tmp_path / ".git").mkdir()
(tmp_path / ".git" / "config").write_text("x")
(tmp_path / "node_modules").mkdir()
(tmp_path / "node_modules" / "pkg").mkdir()
return tmp_path
@pytest.mark.asyncio
async def test_basic_list(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir))
assert "README.md" in result
assert "src" in result
# .git and node_modules should be ignored
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_recursive(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir), recursive=True)
# Normalize path separators for cross-platform compatibility
normalized = result.replace("\\", "/")
assert "src/main.py" in normalized
assert "src/utils.py" in normalized
assert "README.md" in result
# Ignored dirs should not appear
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_max_entries_truncation(self, tool, tmp_path):
for i in range(10):
(tmp_path / f"file_{i}.txt").write_text("x")
result = await tool.execute(path=str(tmp_path), max_entries=3)
assert "truncated" in result
assert "3 of 10" in result
@pytest.mark.asyncio
async def test_empty_dir(self, tool, tmp_path):
d = tmp_path / "empty"
d.mkdir()
result = await tool.execute(path=str(d))
assert "empty" in result.lower()
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope"))
assert "Error" in result
assert "not found" in result
# ---------------------------------------------------------------------------
# Workspace restriction + extra_allowed_dirs
# ---------------------------------------------------------------------------
class TestWorkspaceRestriction:
@pytest.mark.asyncio
async def test_read_blocked_outside_workspace(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
secret = outside / "secret.txt"
secret.write_text("top secret")
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_allowed_with_extra_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "test_skill" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Test Skill\nDo something.")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(skill_file))
assert "Test Skill" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
from nanobot.agent.tools.filesystem import WriteFileTool
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
unrelated = tmp_path / "other"
unrelated.mkdir()
secret = unrelated / "secret.txt"
secret.write_text("nope")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
"""Adding extra_allowed_dirs must not break normal workspace reads."""
workspace = tmp_path / "ws"
workspace.mkdir()
ws_file = workspace / "README.md"
ws_file.write_text("hello from workspace")
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(ws_file))
assert "hello from workspace" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_edit_blocked_in_extra_dir(self, tmp_path):
"""edit_file must not be able to modify files in extra_allowed_dirs."""
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "weather" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Weather\nOriginal content.")
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(
path=str(skill_file),
old_text="Original content.",
new_text="Hacked content.",
)
assert "Error" in result
assert "outside" in result.lower()
assert skill_file.read_text() == "# Weather\nOriginal content."

View File

@@ -0,0 +1,53 @@
from types import SimpleNamespace
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.litellm_provider import LiteLLMProvider
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="call_123",
function=SimpleNamespace(
name="read_file",
arguments='{"path":"todo.md"}',
provider_specific_fields={"inner": "value"},
),
provider_specific_fields={"thought_signature": "signed-token"},
)
],
),
)
],
usage=None,
)
parsed = provider._parse_response(response)
assert len(parsed.tool_calls) == 1
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
def test_tool_call_request_serializes_provider_fields() -> None:
tool_call = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"thought_signature": "signed-token"},
function_provider_specific_fields={"inner": "value"},
)
message = tool_call.to_openai_tool_call()
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'

View File

@@ -3,18 +3,24 @@ import asyncio
import pytest
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider:
class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,169 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
)
assert await service.trigger_now() is None
@pytest.mark.asyncio
async def test_tick_notifies_when_evaluator_says_yes(tmp_path, monkeypatch) -> None:
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=notify -> on_notify called."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check deployments", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check deployments"},
)
],
),
])
executed: list[str] = []
notified: list[str] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "deployment failed on staging"
async def _on_notify(response: str) -> None:
notified.append(response)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
on_notify=_on_notify,
)
async def _eval_notify(*a, **kw):
return True
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_notify)
await service._tick()
assert executed == ["check deployments"]
assert notified == ["deployment failed on staging"]
@pytest.mark.asyncio
async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> None:
"""Phase 1 run -> Phase 2 execute -> Phase 3 evaluate=silent -> on_notify NOT called."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check status", encoding="utf-8")
provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check status"},
)
],
),
])
executed: list[str] = []
notified: list[str] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "everything is fine, no issues"
async def _on_notify(response: str) -> None:
notified.append(response)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
on_execute=_on_execute,
on_notify=_on_notify,
)
async def _eval_silent(*a, **kw):
return False
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_silent)
await service._tick()
assert executed == ["check status"]
assert notified == []
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "run"
assert tasks == "check open tasks"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_decide_prompt_includes_current_time(tmp_path) -> None:
"""Phase 1 user prompt must contain current time so the LLM can judge task urgency."""
captured_messages: list[dict] = []
class CapturingProvider(LLMProvider):
async def chat(self, *, messages=None, **kwargs) -> LLMResponse:
if messages:
captured_messages.extend(messages)
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1", name="heartbeat",
arguments={"action": "skip"},
)
],
)
def get_default_model(self) -> str:
return "test-model"
service = HeartbeatService(
workspace=tmp_path,
provider=CapturingProvider(),
model="test-model",
)
await service._decide("- [ ] check servers at 10:00 UTC")
user_msg = captured_messages[1]
assert user_msg["role"] == "user"
assert "Current Time:" in user_msg["content"]

View File

@@ -0,0 +1,161 @@
"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec.
Validates that:
- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing.
- The litellm_kwargs mechanism works correctly for providers that declare it.
- Non-gateway providers are unaffected.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
def _fake_response(content: str = "ok") -> SimpleNamespace:
"""Build a minimal acompletion-shaped response object."""
message = SimpleNamespace(
content=content,
tool_calls=None,
reasoning_content=None,
thinking_blocks=None,
)
choice = SimpleNamespace(message=message, finish_reason="stop")
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
return SimpleNamespace(choices=[choice], usage=usage)
def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None:
"""OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg.
LiteLLM internally adds a provider/ prefix when custom_llm_provider is set,
which double-prefixes models (openrouter/anthropic/model) and breaks the API.
"""
spec = find_by_name("openrouter")
assert spec is not None
assert spec.litellm_prefix == "openrouter"
assert "custom_llm_provider" not in spec.litellm_kwargs, (
"custom_llm_provider causes LiteLLM to double-prefix the model name"
)
@pytest.mark.asyncio
async def test_openrouter_prefixes_model_correctly() -> None:
"""OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing."""
mock_acompletion = AsyncMock(return_value=_fake_response())
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
provider = LiteLLMProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="anthropic/claude-sonnet-4-5",
provider_name="openrouter",
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="anthropic/claude-sonnet-4-5",
)
call_kwargs = mock_acompletion.call_args.kwargs
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
"LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call"
)
assert "custom_llm_provider" not in call_kwargs
@pytest.mark.asyncio
async def test_non_gateway_provider_no_extra_kwargs() -> None:
"""Standard (non-gateway) providers must NOT inject any litellm_kwargs."""
mock_acompletion = AsyncMock(return_value=_fake_response())
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
provider = LiteLLMProvider(
api_key="sk-ant-test-key",
default_model="claude-sonnet-4-5",
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="claude-sonnet-4-5",
)
call_kwargs = mock_acompletion.call_args.kwargs
assert "custom_llm_provider" not in call_kwargs, (
"Standard Anthropic provider should NOT inject custom_llm_provider"
)
@pytest.mark.asyncio
async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None:
"""Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys."""
mock_acompletion = AsyncMock(return_value=_fake_response())
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
provider = LiteLLMProvider(
api_key="sk-aihub-test-key",
api_base="https://aihubmix.com/v1",
default_model="claude-sonnet-4-5",
provider_name="aihubmix",
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="claude-sonnet-4-5",
)
call_kwargs = mock_acompletion.call_args.kwargs
assert "custom_llm_provider" not in call_kwargs
@pytest.mark.asyncio
async def test_openrouter_autodetect_by_key_prefix() -> None:
"""OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name."""
mock_acompletion = AsyncMock(return_value=_fake_response())
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
provider = LiteLLMProvider(
api_key="sk-or-auto-detect-key",
default_model="anthropic/claude-sonnet-4-5",
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="anthropic/claude-sonnet-4-5",
)
call_kwargs = mock_acompletion.call_args.kwargs
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
"Auto-detected OpenRouter should prefix model for LiteLLM routing"
)
@pytest.mark.asyncio
async def test_openrouter_native_model_id_gets_double_prefixed() -> None:
"""Models like openrouter/free must be double-prefixed so LiteLLM strips one layer.
openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first
openrouter/ for routing, so we must send openrouter/openrouter/free to ensure
the API receives openrouter/free.
"""
mock_acompletion = AsyncMock(return_value=_fake_response())
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
provider = LiteLLMProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="openrouter/free",
provider_name="openrouter",
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="openrouter/free",
)
call_kwargs = mock_acompletion.call_args.kwargs
assert call_kwargs["model"] == "openrouter/openrouter/free", (
"openrouter/free must become openrouter/openrouter/free — "
"LiteLLM strips one layer so the API receives openrouter/free"
)

View File

@@ -0,0 +1,190 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
)
loop.tools.get_definitions = MagicMock(return_value=[])
return loop
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
await loop.process_direct("hello", session_key="cli:test")
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
]
loop.sessions.save(session)
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
assert session.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (300, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
"""Once triggered, consolidation should continue until it drops below half threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (150, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
"""Verify preflight consolidation runs before the LLM call in process_direct."""
order: list[str] = []
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
order.append("consolidate")
return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
async def track_llm(*args, **kwargs):
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
return (1000 if call_count[0] <= 1 else 80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")

View File

@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
return loop
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
def test_save_turn_keeps_tool_results_under_16k() -> None:
loop = _mk_loop()
session = Session(key="test:tool-result")
content = "x" * 12_000
loop._save_turn(
session,
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
skip=0,
)
assert session.messages[0]["content"] == content

View File

@@ -12,7 +12,7 @@ from nanobot.channels.matrix import (
TYPING_NOTICE_TIMEOUT_MS,
MatrixChannel,
)
from nanobot.config.schema import MatrixConfig
from nanobot.channels.matrix import MatrixConfig
_ROOM_SEND_UNSET = object()

View File

@@ -1,12 +1,15 @@
from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
import sys
from types import ModuleType, SimpleNamespace
import pytest
from nanobot.agent.tools.mcp import MCPToolWrapper
from nanobot.agent.tools.mcp import MCPToolWrapper, connect_mcp_servers
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.config.schema import MCPServerConfig
class _FakeTextContent:
@@ -14,12 +17,63 @@ class _FakeTextContent:
self.text = text
@pytest.fixture
def fake_mcp_runtime() -> dict[str, object | None]:
return {"session": None}
@pytest.fixture(autouse=True)
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
def _fake_mcp_module(
monkeypatch: pytest.MonkeyPatch, fake_mcp_runtime: dict[str, object | None]
) -> None:
mod = ModuleType("mcp")
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
class _FakeStdioServerParameters:
def __init__(self, command: str, args: list[str], env: dict | None = None) -> None:
self.command = command
self.args = args
self.env = env
class _FakeClientSession:
def __init__(self, _read: object, _write: object) -> None:
self._session = fake_mcp_runtime["session"]
async def __aenter__(self) -> object:
return self._session
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
@asynccontextmanager
async def _fake_stdio_client(_params: object):
yield object(), object()
@asynccontextmanager
async def _fake_sse_client(_url: str, httpx_client_factory=None):
yield object(), object()
@asynccontextmanager
async def _fake_streamable_http_client(_url: str, http_client=None):
yield object(), object(), object()
mod.ClientSession = _FakeClientSession
mod.StdioServerParameters = _FakeStdioServerParameters
monkeypatch.setitem(sys.modules, "mcp", mod)
client_mod = ModuleType("mcp.client")
stdio_mod = ModuleType("mcp.client.stdio")
stdio_mod.stdio_client = _fake_stdio_client
sse_mod = ModuleType("mcp.client.sse")
sse_mod.sse_client = _fake_sse_client
streamable_http_mod = ModuleType("mcp.client.streamable_http")
streamable_http_mod.streamable_http_client = _fake_streamable_http_client
monkeypatch.setitem(sys.modules, "mcp.client", client_mod)
monkeypatch.setitem(sys.modules, "mcp.client.stdio", stdio_mod)
monkeypatch.setitem(sys.modules, "mcp.client.sse", sse_mod)
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", streamable_http_mod)
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
tool_def = SimpleNamespace(
@@ -97,3 +151,132 @@ async def test_execute_handles_generic_exception() -> None:
result = await wrapper.execute()
assert result == "(MCP tool call failed: RuntimeError)"
def _make_tool_def(name: str) -> SimpleNamespace:
return SimpleNamespace(
name=name,
description=f"{name} tool",
inputSchema={"type": "object", "properties": {}},
)
def _make_fake_session(tool_names: list[str]) -> SimpleNamespace:
async def initialize() -> None:
return None
async def list_tools() -> SimpleNamespace:
return SimpleNamespace(tools=[_make_tool_def(name) for name in tool_names])
return SimpleNamespace(initialize=initialize, list_tools=list_tools)
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == []
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
fake_mcp_runtime: dict[str, object | None], monkeypatch: pytest.MonkeyPatch
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo"])
registry = ToolRegistry()
warnings: list[str] = []
def _warning(message: str, *args: object) -> None:
warnings.append(message.format(*args))
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
stack,
)
finally:
await stack.aclose()
assert registry.tool_names == []
assert warnings
assert "enabledTools entries not found: unknown" in warnings[-1]
assert "Available raw names: demo" in warnings[-1]
assert "Available wrapped names: mcp_test_demo" in warnings[-1]

View File

@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import pytest
from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50):
"""Create a mock session with messages."""
session = MagicMock()
session.messages = [
def _make_messages(message_count: int = 30):
"""Create a list of mock messages."""
return [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update):
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
)
class ScriptedProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
class TestMemoryConsolidationTypeHandling:
"""Test that consolidation handles various argument types correctly."""
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.",
)
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
)
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -97,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse(
content=None,
tool_calls=[
@@ -112,9 +126,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -127,21 +142,23 @@ class TestMemoryConsolidationTypeHandling:
provider.chat = AsyncMock(
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count."""
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
session = _make_session(message_count=10)
provider.chat_with_retry = provider.chat
messages: list[dict] = []
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat.assert_not_called()
@@ -152,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse(
content=None,
tool_calls=[
@@ -167,9 +183,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -192,9 +209,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@@ -215,8 +233,246 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path)
provider = ScriptedProvider([
LLMResponse(content="503 server error", finish_reason="error"),
_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
),
])
messages = _make_messages(message_count=60)
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
"""Consolidation no longer passes generation params — the provider owns them."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat_with_retry.assert_awaited_once()
_, kwargs = provider.chat_with_retry.await_args
assert kwargs["model"] == "test-model"
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error calling LLM: litellm.BadRequestError: "
"The tool_choice parameter does not support being set to required or object",
finish_reason="error",
tool_calls=[],
)
ok_resp = _make_tool_response(
history_entry="[2026-01-01] Fallback worked.",
memory_update="# Memory\nFallback OK.",
)
call_log: list[dict] = []
async def _tracking_chat(**kwargs):
call_log.append(kwargs)
return error_resp if len(call_log) == 1 else ok_resp
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert len(call_log) == 2
assert isinstance(call_log[0]["tool_choice"], dict)
assert call_log[1]["tool_choice"] == "auto"
assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
"""Forced rejected, auto retry also produces no tool call -> return False."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error: tool_choice must be none or auto",
finish_reason="error",
tool_calls=[],
)
no_tool_resp = LLMResponse(
content="Here is a summary.",
finish_reason="stop",
tool_calls=[],
)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1

View File

@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic:
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")

View File

@@ -0,0 +1,209 @@
import asyncio
import pytest
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider):
def __init__(self, responses):
super().__init__()
self._responses = list(responses)
self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(content="ok"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "stop"
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "401 unauthorized"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit a", finish_reason="error"),
LLMResponse(content="429 rate limit b", finish_reason="error"),
LLMResponse(content="429 rate limit c", finish_reason="error"),
LLMResponse(content="503 final server error", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "503 final server error"
assert provider.calls == 4
assert delays == [1, 2, 4]
@pytest.mark.asyncio
async def test_chat_with_retry_preserves_cancelled_error() -> None:
provider = ScriptedProvider([asyncio.CancelledError()])
with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert provider.last_kwargs["temperature"] == 0.2
assert provider.last_kwargs["max_tokens"] == 321
assert provider.last_kwargs["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
"""Explicit kwargs should override provider.generation defaults."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
temperature=0.9,
max_tokens=9999,
reasoning_effort="low",
)
assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low"
# ---------------------------------------------------------------------------
# Image-unsupported fallback tests
# ---------------------------------------------------------------------------
_IMAGE_MSG = [
{"role": "user", "content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]},
]
@pytest.mark.asyncio
async def test_image_unsupported_error_retries_without_images() -> None:
"""If the model rejects image_url, retry once with images stripped."""
provider = ScriptedProvider([
LLMResponse(
content="Invalid content type. image_url is only supported by certain models",
finish_reason="error",
),
LLMResponse(content="ok, no image"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert response.content == "ok, no image"
assert provider.calls == 2
msgs_on_retry = provider.last_kwargs["messages"]
for msg in msgs_on_retry:
content = msg.get("content")
if isinstance(content, list):
assert all(b.get("type") != "image_url" for b in content)
assert any("[image omitted]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_image_unsupported_error_no_retry_without_image_content() -> None:
"""If messages don't contain image_url blocks, don't retry on image error."""
provider = ScriptedProvider([
LLMResponse(
content="image_url is only supported by certain models",
finish_reason="error",
),
])
response = await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
)
assert provider.calls == 1
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None:
"""If the image-stripped retry also fails, return that error."""
provider = ScriptedProvider([
LLMResponse(
content="does not support image input",
finish_reason="error",
),
LLMResponse(content="some other error", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert provider.calls == 2
assert response.content == "some other error"
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_non_image_error_does_not_trigger_image_fallback() -> None:
"""Regular non-transient errors must not trigger image stripping."""
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert provider.calls == 1
assert response.content == "401 unauthorized"

View File

@@ -5,7 +5,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel
from nanobot.config.schema import QQConfig
from nanobot.channels.qq import QQConfig
class _FakeApi:
@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
@@ -60,7 +60,66 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert call == {
"group_openid": "group123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls
@pytest.mark.asyncio
async def test_send_group_message_uses_markdown_when_configured() -> None:
channel = QQChannel(
QQConfig(app_id="app", secret="secret", allow_from=["*"], msg_format="markdown"),
MessageBus(),
)
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="**hello**",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call == {
"group_openid": "group123",
"msg_type": 2,
"markdown": {"content": "**hello**"},
"msg_id": "msg1",
"msg_seq": 2,
}

View File

@@ -0,0 +1,76 @@
"""Tests for /restart slash command."""
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from nanobot.bus.events import InboundMessage
def _make_loop():
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager"):
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
return loop, bus
class TestRestartCommand:
@pytest.mark.asyncio
async def test_restart_sends_message_and_calls_execv(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
with patch("nanobot.agent.loop.os.execv") as mock_execv:
await loop._handle_restart(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "Restarting" in out.content
await asyncio.sleep(1.5)
mock_execv.assert_called_once()
@pytest.mark.asyncio
async def test_restart_intercepted_in_run_loop(self):
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
with patch.object(loop, "_handle_restart") as mock_handle:
mock_handle.return_value = None
await bus.publish_inbound(msg)
loop._running = True
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
loop._running = False
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
mock_handle.assert_called_once()
@pytest.mark.asyncio
async def test_help_includes_restart(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
response = await loop._process_message(msg)
assert response is not None
assert "/restart" in response.content

View File

@@ -0,0 +1,101 @@
"""Tests for nanobot.security.network — SSRF protection and internal URL detection."""
from __future__ import annotations
import socket
from unittest.mock import patch
import pytest
from nanobot.security.network import contains_internal_url, validate_url_target
def _fake_resolve(host: str, results: list[str]):
"""Return a getaddrinfo mock that maps the given host to fake IP results."""
def _resolver(hostname, port, family=0, type_=0):
if hostname == host:
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (ip, 0)) for ip in results]
raise socket.gaierror(f"cannot resolve {hostname}")
return _resolver
# ---------------------------------------------------------------------------
# validate_url_target — scheme / domain basics
# ---------------------------------------------------------------------------
def test_rejects_non_http_scheme():
ok, err = validate_url_target("ftp://example.com/file")
assert not ok
assert "http" in err.lower()
def test_rejects_missing_domain():
ok, err = validate_url_target("http://")
assert not ok
# ---------------------------------------------------------------------------
# validate_url_target — blocked private/internal IPs
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("ip,label", [
("127.0.0.1", "loopback"),
("127.0.0.2", "loopback_alt"),
("10.0.0.1", "rfc1918_10"),
("172.16.5.1", "rfc1918_172"),
("192.168.1.1", "rfc1918_192"),
("169.254.169.254", "metadata"),
("0.0.0.0", "zero"),
])
def test_blocks_private_ipv4(ip: str, label: str):
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("evil.com", [ip])):
ok, err = validate_url_target(f"http://evil.com/path")
assert not ok, f"Should block {label} ({ip})"
assert "private" in err.lower() or "blocked" in err.lower()
def test_blocks_ipv6_loopback():
def _resolver(hostname, port, family=0, type_=0):
return [(socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("::1", 0, 0, 0))]
with patch("nanobot.security.network.socket.getaddrinfo", _resolver):
ok, err = validate_url_target("http://evil.com/")
assert not ok
# ---------------------------------------------------------------------------
# validate_url_target — allows public IPs
# ---------------------------------------------------------------------------
def test_allows_public_ip():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
ok, err = validate_url_target("http://example.com/page")
assert ok, f"Should allow public IP, got: {err}"
def test_allows_normal_https():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("github.com", ["140.82.121.3"])):
ok, err = validate_url_target("https://github.com/HKUDS/nanobot")
assert ok
# ---------------------------------------------------------------------------
# contains_internal_url — shell command scanning
# ---------------------------------------------------------------------------
def test_detects_curl_metadata():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("169.254.169.254", ["169.254.169.254"])):
assert contains_internal_url('curl -s http://169.254.169.254/computeMetadata/v1/')
def test_detects_wget_localhost():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("localhost", ["127.0.0.1"])):
assert contains_internal_url("wget http://localhost:8080/secret")
def test_allows_normal_curl():
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve("example.com", ["93.184.216.34"])):
assert not contains_internal_url("curl https://example.com/api/data")
def test_no_urls_returns_false():
assert not contains_internal_url("echo hello && ls -la")

View File

@@ -0,0 +1,146 @@
from nanobot.session.manager import Session
def _assert_no_orphans(history: list[dict]) -> None:
"""Assert every tool result in history has a matching assistant tool_call."""
declared = {
tc["id"]
for m in history if m.get("role") == "assistant"
for tc in (m.get("tool_calls") or [])
}
orphans = [
m.get("tool_call_id") for m in history
if m.get("role") == "tool" and m.get("tool_call_id") not in declared
]
assert orphans == [], f"orphan tool_call_ids: {orphans}"
def _tool_turn(prefix: str, idx: int) -> list[dict]:
"""Helper: one assistant with 2 tool_calls + 2 tool results."""
return [
{
"role": "assistant",
"content": None,
"tool_calls": [
{"id": f"{prefix}_{idx}_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
{"id": f"{prefix}_{idx}_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
],
},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_a", "name": "x", "content": "ok"},
{"role": "tool", "tool_call_id": f"{prefix}_{idx}_b", "name": "y", "content": "ok"},
]
# --- Original regression test (from PR 2075) ---
def test_get_history_drops_orphan_tool_results_when_window_cuts_tool_calls():
session = Session(key="telegram:test")
session.messages.append({"role": "user", "content": "old turn"})
for i in range(20):
session.messages.extend(_tool_turn("old", i))
session.messages.append({"role": "user", "content": "problem turn"})
for i in range(25):
session.messages.extend(_tool_turn("cur", i))
session.messages.append({"role": "user", "content": "new telegram question"})
history = session.get_history(max_messages=100)
_assert_no_orphans(history)
# --- Positive test: legitimate pairs survive trimming ---
def test_legitimate_tool_pairs_preserved_after_trim():
"""Complete tool-call groups within the window must not be dropped."""
session = Session(key="test:positive")
session.messages.append({"role": "user", "content": "hello"})
for i in range(5):
session.messages.extend(_tool_turn("ok", i))
session.messages.append({"role": "assistant", "content": "done"})
history = session.get_history(max_messages=500)
_assert_no_orphans(history)
tool_ids = [m["tool_call_id"] for m in history if m.get("role") == "tool"]
assert len(tool_ids) == 10
assert history[0]["role"] == "user"
# --- last_consolidated > 0 ---
def test_orphan_trim_with_last_consolidated():
"""Orphan trimming works correctly when session is partially consolidated."""
session = Session(key="test:consolidated")
for i in range(10):
session.messages.append({"role": "user", "content": f"old {i}"})
session.messages.extend(_tool_turn("cons", i))
session.last_consolidated = 30
session.messages.append({"role": "user", "content": "recent"})
for i in range(15):
session.messages.extend(_tool_turn("new", i))
session.messages.append({"role": "user", "content": "latest"})
history = session.get_history(max_messages=20)
_assert_no_orphans(history)
assert all(m.get("role") != "tool" or m["tool_call_id"].startswith("new_") for m in history)
# --- Edge: no tool messages at all ---
def test_no_tool_messages_unchanged():
session = Session(key="test:plain")
for i in range(5):
session.messages.append({"role": "user", "content": f"q{i}"})
session.messages.append({"role": "assistant", "content": f"a{i}"})
history = session.get_history(max_messages=6)
assert len(history) == 6
_assert_no_orphans(history)
# --- Edge: all leading messages are orphan tool results ---
def test_all_orphan_prefix_stripped():
"""If the window starts with orphan tool results and nothing else, they're all dropped."""
session = Session(key="test:all-orphan")
session.messages.append({"role": "tool", "tool_call_id": "gone_1", "name": "x", "content": "ok"})
session.messages.append({"role": "tool", "tool_call_id": "gone_2", "name": "y", "content": "ok"})
session.messages.append({"role": "user", "content": "fresh start"})
session.messages.append({"role": "assistant", "content": "hi"})
history = session.get_history(max_messages=500)
_assert_no_orphans(history)
assert history[0]["role"] == "user"
assert len(history) == 2
# --- Edge: empty session ---
def test_empty_session_history():
session = Session(key="test:empty")
history = session.get_history(max_messages=500)
assert history == []
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
def test_window_cuts_mid_tool_group():
"""If the window starts between an assistant's tool results, the partial group is trimmed."""
session = Session(key="test:mid-cut")
session.messages.append({"role": "user", "content": "setup"})
session.messages.append({
"role": "assistant", "content": None,
"tool_calls": [
{"id": "split_a", "type": "function", "function": {"name": "x", "arguments": "{}"}},
{"id": "split_b", "type": "function", "function": {"name": "y", "arguments": "{}"}},
],
})
session.messages.append({"role": "tool", "tool_call_id": "split_a", "name": "x", "content": "ok"})
session.messages.append({"role": "tool", "tool_call_id": "split_b", "name": "y", "content": "ok"})
session.messages.append({"role": "user", "content": "next"})
session.messages.extend(_tool_turn("intact", 0))
session.messages.append({"role": "assistant", "content": "final"})
# Window of 6 should cut off the "setup" user msg and the assistant with split_a/split_b,
# leaving orphan tool results for split_a at the front.
history = session.get_history(max_messages=6)
_assert_no_orphans(history)

View File

@@ -0,0 +1,127 @@
import importlib
import shutil
import sys
import zipfile
from pathlib import Path
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
if str(SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPT_DIR))
init_skill = importlib.import_module("init_skill")
package_skill = importlib.import_module("package_skill")
quick_validate = importlib.import_module("quick_validate")
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
skill_dir = init_skill.init_skill(
"demo-skill",
tmp_path,
["scripts", "references", "assets"],
include_examples=True,
)
assert skill_dir == tmp_path / "demo-skill"
assert (skill_dir / "SKILL.md").exists()
assert (skill_dir / "scripts" / "example.py").exists()
assert (skill_dir / "references" / "api_reference.md").exists()
assert (skill_dir / "assets" / "example_asset.txt").exists()
def test_validate_skill_accepts_existing_skill_creator() -> None:
valid, message = quick_validate.validate_skill(
Path("nanobot/skills/skill-creator").resolve()
)
assert valid, message
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
skill_dir = tmp_path / "placeholder-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: placeholder-skill\n"
'description: "[TODO: fill me in]"\n'
"---\n"
"# Placeholder\n",
encoding="utf-8",
)
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "TODO placeholder" in message
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
skill_dir = tmp_path / "bad-root-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: bad-root-skill\n"
"description: Valid description\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "Unexpected file or directory in skill root" in message
def test_package_skill_creates_archive(tmp_path: Path) -> None:
skill_dir = tmp_path / "package-me"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: package-me\n"
"description: Package this skill.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path == (tmp_path / "dist" / "package-me.skill")
assert archive_path.exists()
with zipfile.ZipFile(archive_path, "r") as archive:
names = set(archive.namelist())
assert "package-me/SKILL.md" in names
assert "package-me/scripts/helper.py" in names
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
skill_dir = tmp_path / "symlink-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: symlink-skill\n"
"description: Reject symlinks during packaging.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
target = tmp_path / "outside.txt"
target.write_text("secret\n", encoding="utf-8")
link = scripts_dir / "outside.txt"
try:
link.symlink_to(target)
except (OSError, NotImplementedError):
return
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path is None
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel
from nanobot.channels.slack import SlackConfig
class _FakeAsyncWebClient:
def __init__(self) -> None:
self.chat_post_calls: list[dict[str, object | None]] = []
self.file_upload_calls: list[dict[str, object | None]] = []
async def chat_postMessage(
self,
*,
channel: str,
text: str,
thread_ts: str | None = None,
) -> None:
self.chat_post_calls.append(
{
"channel": channel,
"text": text,
"thread_ts": thread_ts,
}
)
async def files_upload_v2(
self,
*,
channel: str,
file: str,
thread_ts: str | None = None,
) -> None:
self.file_upload_calls.append(
{
"channel": channel,
"file": file,
"thread_ts": thread_ts,
}
)
@pytest.mark.asyncio
async def test_send_uses_thread_for_channel_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "channel"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] == "1700000000.000100"
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] == "1700000000.000100"
@pytest.mark.asyncio
async def test_send_omits_thread_for_dm_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="D123",
content="hello",
media=["/tmp/demo.txt"],
metadata={"slack": {"thread_ts": "1700000000.000100", "channel_type": "im"}},
)
)
assert len(fake_web.chat_post_calls) == 1
assert fake_web.chat_post_calls[0]["text"] == "hello\n"
assert fake_web.chat_post_calls[0]["thread_ts"] is None
assert len(fake_web.file_upload_calls) == 1
assert fake_web.file_upload_calls[0]["thread_ts"] is None

View File

@@ -165,3 +165,46 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def scripted_chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]

View File

@@ -1,11 +1,14 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
from nanobot.channels.telegram import TelegramConfig
class _FakeHTTPXRequest:
@@ -27,9 +30,11 @@ class _FakeUpdater:
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
self.get_me_calls = 0
async def get_me(self):
return SimpleNamespace(username="nanobot_test")
self.get_me_calls += 1
return SimpleNamespace(id=999, username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
@@ -37,6 +42,15 @@ class _FakeBot:
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
async def send_chat_action(self, **kwargs) -> None:
pass
async def get_file(self, file_id: str):
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
async def _fake_download(path) -> None:
pass
return SimpleNamespace(download_to_drive=_fake_download)
class _FakeApp:
def __init__(self, on_start_polling) -> None:
@@ -87,6 +101,35 @@ class _FakeBuilder:
return self.app
def _make_telegram_update(
*,
chat_type: str = "group",
text: str | None = None,
caption: str | None = None,
entities=None,
caption_entities=None,
reply_to_message=None,
):
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
message = SimpleNamespace(
chat=SimpleNamespace(type=chat_type, is_forum=False),
chat_id=-100123,
text=text,
caption=caption,
entities=entities or [],
caption_entities=caption_entities or [],
reply_to_message=reply_to_message,
photo=None,
voice=None,
audio=None,
document=None,
media_group_id=None,
message_thread_id=None,
message_id=1,
)
return SimpleNamespace(message=message, effective_user=user)
@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
config = TelegramConfig(
@@ -131,6 +174,10 @@ def test_get_extension_falls_back_to_original_filename() -> None:
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_telegram_group_policy_defaults_to_mention() -> None:
assert TelegramConfig().group_policy == "mention"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
@@ -182,3 +229,437 @@ async def test_send_reply_infers_topic_from_message_id_cache() -> None:
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
assert channel._app.bot.sent_messages[0]["reply_parameters"].message_id == 10
@pytest.mark.asyncio
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
assert handled == []
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
assert len(handled) == 2
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_caption_mention() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
None,
)
assert len(handled) == 1
assert handled[0]["content"] == "@nanobot_test photo"
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
assert len(handled) == 1
@pytest.mark.asyncio
async def test_group_policy_open_accepts_plain_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello group"), None)
assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0
def test_extract_reply_context_no_reply() -> None:
"""When there is no reply_to_message, _extract_reply_context returns None."""
message = SimpleNamespace(reply_to_message=None)
assert TelegramChannel._extract_reply_context(message) is None
def test_extract_reply_context_with_text() -> None:
"""When reply has text, return prefixed string."""
reply = SimpleNamespace(text="Hello world", caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
def test_extract_reply_context_with_caption_only() -> None:
"""When reply has only caption (no text), caption is used."""
reply = SimpleNamespace(text=None, caption="Photo caption")
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
def test_extract_reply_context_truncation() -> None:
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
reply = SimpleNamespace(text=long_text, caption=None)
message = SimpleNamespace(reply_to_message=reply)
result = TelegramChannel._extract_reply_context(message)
assert result is not None
assert result.startswith("[Reply to: ")
assert result.endswith("...]")
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
def test_extract_reply_context_no_text_returns_none() -> None:
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
reply = SimpleNamespace(text=None, caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) is None
@pytest.mark.asyncio
async def test_on_message_includes_reply_context() -> None:
"""When user replies to a message, content passed to bus starts with reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="translate this", reply_to_message=reply)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: Hello]")
assert "translate this" in handled[0]["content"]
@pytest.mark.asyncio
async def test_download_message_media_returns_path_when_download_succeeds(
monkeypatch, tmp_path
) -> None:
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
msg = SimpleNamespace(
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert len(paths) == 1
assert len(parts) == 1
assert "fid123" in paths[0]
assert "[image:" in parts[0]
@pytest.mark.asyncio
async def test_download_message_media_uses_file_unique_id_when_available(
monkeypatch, tmp_path
) -> None:
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
downloaded: dict[str, str] = {}
async def _download_to_drive(path: str) -> None:
downloaded["path"] = path
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=_download_to_drive)
)
channel._app = app
msg = SimpleNamespace(
photo=[
SimpleNamespace(
file_id="file-id-that-should-not-be-used",
file_unique_id="stable-unique-id",
mime_type="image/jpeg",
file_name=None,
)
],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert downloaded["path"].endswith("stable-unique-id.jpg")
assert paths == [str(media_dir / "stable-unique-id.jpg")]
assert parts == [f"[image: {media_dir / 'stable-unique-id.jpg'}]"]
@pytest.mark.asyncio
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what is the image?",
reply_to_message=reply_with_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: [image:")
assert "what is the image?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "reply_photo_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
"""When reply has media but download fails, no media attached and no reply tag."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = None
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
await channel._on_message(update, None)
assert len(handled) == 1
assert "what is this?" in handled[0]["content"]
assert handled[0]["media"] == []
@pytest.mark.asyncio
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
"""When replying to a message with caption + photo, both text context and media are included."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_caption_and_photo = SimpleNamespace(
text=None,
caption="A cute cat",
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what breed is this?",
reply_to_message=reply_with_caption_and_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert "[Reply to: A cute cat]" in handled[0]["content"]
assert "what breed is this?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "cat_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_forward_command_does_not_inject_reply_context() -> None:
"""Slash commands forwarded via _forward_command must not include reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="/new", reply_to_message=reply)
await channel._forward_command(update, None)
assert len(handled) == 1
assert handled[0]["content"] == "/new"
@pytest.mark.asyncio
async def test_on_help_includes_restart_command() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
update = _make_telegram_update(text="/help", chat_type="private")
update.message.reply_text = AsyncMock()
await channel._on_help(update, None)
update.message.reply_text.assert_awaited_once()
help_text = update.message.reply_text.await_args.args[0]
assert "/restart" in help_text

View File

@@ -108,6 +108,32 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
assert "/tmp/out.txt" in paths
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "~/.nanobot/config.json" in paths
assert "~/out.txt" in paths
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "~/.nanobot/config.json" in paths
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---
@@ -337,3 +363,46 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
"""Exit code should appear in output even on success (exit 0)."""
tool = ExecTool()
result = await tool.execute(command="echo hello")
assert "Exit code: 0" in result
assert "hello" in result
async def test_exec_head_tail_truncation() -> None:
"""Long output should preserve both head and tail."""
tool = ExecTool()
# Generate output that exceeds _MAX_OUTPUT (10_000 chars)
# Use python to generate output to avoid command line length limits
result = await tool.execute(
command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
)
assert "chars truncated" in result
# Head portion should start with As
assert result.startswith("A")
# Tail portion should end with the exit code which comes after Bs
assert "Exit code:" in result
async def test_exec_timeout_parameter() -> None:
"""LLM-supplied timeout should override the constructor default."""
tool = ExecTool(timeout=60)
# A very short timeout should cause the command to be killed
result = await tool.execute(command="sleep 10", timeout=1)
assert "timed out" in result
assert "1 seconds" in result
async def test_exec_timeout_capped_at_max() -> None:
"""Timeout values above _MAX_TIMEOUT should be clamped."""
tool = ExecTool()
# Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result

View File

@@ -0,0 +1,69 @@
"""Tests for web_fetch SSRF protection and untrusted content marking."""
from __future__ import annotations
import json
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.web import WebFetchTool
def _fake_resolve_private(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
def _fake_resolve_public(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_ip():
tool = WebFetchTool()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
result = await tool.execute(url="http://169.254.169.254/computeMetadata/v1/")
data = json.loads(result)
assert "error" in data
assert "private" in data["error"].lower() or "blocked" in data["error"].lower()
@pytest.mark.asyncio
async def test_web_fetch_blocks_localhost():
tool = WebFetchTool()
def _resolve_localhost(hostname, port, family=0, type_=0):
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
with patch("nanobot.security.network.socket.getaddrinfo", _resolve_localhost):
result = await tool.execute(url="http://localhost/admin")
data = json.loads(result)
assert "error" in data
@pytest.mark.asyncio
async def test_web_fetch_result_contains_untrusted_flag():
"""When fetch succeeds, result JSON must include untrusted=True and the banner."""
tool = WebFetchTool()
fake_html = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"
import httpx
class FakeResponse:
status_code = 200
url = "https://example.com/page"
text = fake_html
headers = {"content-type": "text/html"}
def raise_for_status(self): pass
def json(self): return {}
async def _fake_get(self, url, **kwargs):
return FakeResponse()
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public), \
patch("httpx.AsyncClient.get", _fake_get):
result = await tool.execute(url="https://example.com/page")
data = json.loads(result)
assert data.get("untrusted") is True
assert "[External content" in data.get("text", "")

View File

@@ -0,0 +1,162 @@
"""Tests for multi-provider web search."""
import httpx
import pytest
from nanobot.agent.tools.web import WebSearchTool
from nanobot.config.schema import WebSearchConfig
def _tool(provider: str = "brave", api_key: str = "", base_url: str = "") -> WebSearchTool:
return WebSearchTool(config=WebSearchConfig(provider=provider, api_key=api_key, base_url=base_url))
def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
"""Build a mock httpx.Response with a dummy request attached."""
r = httpx.Response(status, json=json)
r._request = httpx.Request("GET", "https://mock")
return r
@pytest.mark.asyncio
async def test_brave_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
assert kw["headers"]["X-Subscription-Token"] == "brave-key"
return _response(json={
"web": {"results": [{"title": "NanoBot", "url": "https://example.com", "description": "AI assistant"}]}
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="brave", api_key="brave-key")
result = await tool.execute(query="nanobot", count=1)
assert "NanoBot" in result
assert "https://example.com" in result
@pytest.mark.asyncio
async def test_tavily_search(monkeypatch):
async def mock_post(self, url, **kw):
assert "tavily" in url
assert kw["headers"]["Authorization"] == "Bearer tavily-key"
return _response(json={
"results": [{"title": "OpenClaw", "url": "https://openclaw.io", "content": "Framework"}]
})
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
tool = _tool(provider="tavily", api_key="tavily-key")
result = await tool.execute(query="openclaw")
assert "OpenClaw" in result
assert "https://openclaw.io" in result
@pytest.mark.asyncio
async def test_searxng_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "searx.example" in url
return _response(json={
"results": [{"title": "Result", "url": "https://example.com", "content": "SearXNG result"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="searxng", base_url="https://searx.example")
result = await tool.execute(query="test")
assert "Result" in result
@pytest.mark.asyncio
async def test_duckduckgo_search(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "DDG Result", "href": "https://ddg.example", "body": "From DuckDuckGo"}]
monkeypatch.setattr("nanobot.agent.tools.web.DDGS", MockDDGS, raising=False)
import nanobot.agent.tools.web as web_mod
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
from ddgs import DDGS
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="duckduckgo")
result = await tool.execute(query="hello")
assert "DDG Result" in result
@pytest.mark.asyncio
async def test_brave_fallback_to_duckduckgo_when_no_key(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
tool = _tool(provider="brave", api_key="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_jina_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "s.jina.ai" in str(url)
assert kw["headers"]["Authorization"] == "Bearer jina-key"
return _response(json={
"data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="jina", api_key="jina-key")
result = await tool.execute(query="test")
assert "Jina Result" in result
assert "https://jina.ai" in result
@pytest.mark.asyncio
async def test_unknown_provider():
tool = _tool(provider="unknown")
result = await tool.execute(query="test")
assert "unknown" in result
assert "Error" in result
@pytest.mark.asyncio
async def test_default_provider_is_brave(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
return _response(json={"web": {"results": []}})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="", api_key="test-key")
result = await tool.execute(query="test")
assert "No results" in result
@pytest.mark.asyncio
async def test_searxng_no_base_url_falls_back(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("SEARXNG_BASE_URL", raising=False)
tool = _tool(provider="searxng", base_url="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_searxng_invalid_url():
tool = _tool(provider="searxng", base_url="not-a-url")
result = await tool.execute(query="test")
assert "Error" in result