merge: resolve PR #1796 conflicts with main

Merge the latest main branch into the Telegram media filename fix and keep the file_unique_id-based download path on top of the refactored media handling and newer Telegram tests.

Made-with: Cursor
This commit is contained in:
Protocol Zero
2026-03-14 08:33:56 +00:00
64 changed files with 6213 additions and 1300 deletions

View File

@@ -0,0 +1,228 @@
"""Tests for channel plugin discovery, merging, and config compatibility."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.channels.manager import ChannelManager
from nanobot.config.schema import ChannelsConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakePlugin(BaseChannel):
name = "fakeplugin"
display_name = "Fake Plugin"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
class _FakeTelegram(BaseChannel):
"""Plugin that tries to shadow built-in telegram."""
name = "telegram"
display_name = "Fake Telegram"
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def send(self, msg: OutboundMessage) -> None:
pass
def _make_entry_point(name: str, cls: type):
"""Create a mock entry point that returns *cls* on load()."""
ep = SimpleNamespace(name=name, load=lambda _cls=cls: _cls)
return ep
# ---------------------------------------------------------------------------
# ChannelsConfig extra="allow"
# ---------------------------------------------------------------------------
def test_channels_config_accepts_unknown_keys():
cfg = ChannelsConfig.model_validate({
"myplugin": {"enabled": True, "token": "abc"},
})
extra = cfg.model_extra
assert extra is not None
assert extra["myplugin"]["enabled"] is True
assert extra["myplugin"]["token"] == "abc"
def test_channels_config_getattr_returns_extra():
cfg = ChannelsConfig.model_validate({"myplugin": {"enabled": True}})
section = getattr(cfg, "myplugin", None)
assert isinstance(section, dict)
assert section["enabled"] is True
def test_channels_config_builtin_fields_removed():
"""After decoupling, ChannelsConfig has no explicit channel fields."""
cfg = ChannelsConfig()
assert not hasattr(cfg, "telegram")
assert cfg.send_progress is True
assert cfg.send_tool_hints is False
# ---------------------------------------------------------------------------
# discover_plugins
# ---------------------------------------------------------------------------
_EP_TARGET = "importlib.metadata.entry_points"
def test_discover_plugins_loads_entry_points():
from nanobot.channels.registry import discover_plugins
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_plugins_handles_load_error():
from nanobot.channels.registry import discover_plugins
def _boom():
raise RuntimeError("broken")
ep = SimpleNamespace(name="broken", load=_boom)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_plugins()
assert "broken" not in result
# ---------------------------------------------------------------------------
# discover_all — merge & priority
# ---------------------------------------------------------------------------
def test_discover_all_includes_builtins():
from nanobot.channels.registry import discover_all, discover_channel_names
with patch(_EP_TARGET, return_value=[]):
result = discover_all()
# discover_all() only returns channels that are actually available (dependencies installed)
# discover_channel_names() returns all built-in channel names
# So we check that all actually loaded channels are in the result
for name in result:
assert name in discover_channel_names()
def test_discover_all_includes_external_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("line", _FakePlugin)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "line" in result
assert result["line"] is _FakePlugin
def test_discover_all_builtin_shadows_plugin():
from nanobot.channels.registry import discover_all
ep = _make_entry_point("telegram", _FakeTelegram)
with patch(_EP_TARGET, return_value=[ep]):
result = discover_all()
assert "telegram" in result
assert result["telegram"] is not _FakeTelegram
# ---------------------------------------------------------------------------
# Manager _init_channels with dict config (plugin scenario)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_manager_loads_plugin_from_dict_config():
"""ChannelManager should instantiate a plugin channel from a raw dict config."""
from nanobot.channels.manager import ChannelManager
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": True, "allowFrom": ["*"]},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" in mgr.channels
assert isinstance(mgr.channels["fakeplugin"], _FakePlugin)
@pytest.mark.asyncio
async def test_manager_skips_disabled_plugin():
fake_config = SimpleNamespace(
channels=ChannelsConfig.model_validate({
"fakeplugin": {"enabled": False},
}),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
)
with patch(
"nanobot.channels.registry.discover_all",
return_value={"fakeplugin": _FakePlugin},
):
mgr = ChannelManager.__new__(ChannelManager)
mgr.config = fake_config
mgr.bus = MessageBus()
mgr.channels = {}
mgr._dispatch_task = None
mgr._init_channels()
assert "fakeplugin" not in mgr.channels
# ---------------------------------------------------------------------------
# Built-in channel default_config() and dict->Pydantic conversion
# ---------------------------------------------------------------------------
def test_builtin_channel_default_config():
"""Built-in channels expose default_config() returning a dict with 'enabled': False."""
from nanobot.channels.telegram import TelegramChannel
cfg = TelegramChannel.default_config()
assert isinstance(cfg, dict)
assert cfg["enabled"] is False
assert "token" in cfg
def test_builtin_channel_init_from_dict():
"""Built-in channels accept a raw dict and convert to Pydantic internally."""
from nanobot.channels.telegram import TelegramChannel
bus = MessageBus()
ch = TelegramChannel({"enabled": False, "token": "test-tok", "allowFrom": ["*"]}, bus)
assert ch.config.token == "test-tok"
assert ch.config.allow_from == ["*"]

View File

@@ -1,3 +1,4 @@
import re
import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -11,6 +12,12 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_model
def _strip_ansi(text):
"""Remove ANSI escape codes from text."""
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
return ansi_escape.sub('', text)
runner = CliRunner()
@@ -114,6 +121,64 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
assert config.get_provider_name() == "openai_codex"
def test_config_matches_explicit_ollama_prefix_without_api_key():
config = Config()
config.agents.defaults.model = "ollama/llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
config = Config()
config.agents.defaults.provider = "ollama"
config.agents.defaults.model = "llama3.2"
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_auto_detects_ollama_from_local_api_base():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
"ollama": {"apiBase": "http://localhost:11434"},
},
}
)
assert config.get_provider_name() == "ollama"
assert config.get_api_base() == "http://localhost:11434"
def test_config_falls_back_to_vllm_when_ollama_not_configured():
config = Config.model_validate(
{
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
"providers": {
"vllm": {"apiBase": "http://localhost:8000"},
},
}
)
assert config.get_provider_name() == "vllm"
assert config.get_api_base() == "http://localhost:8000"
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
spec = find_by_model("github-copilot/gpt-5.3-codex")
@@ -170,10 +235,11 @@ def test_agent_help_shows_workspace_and_config_options():
result = runner.invoke(app, ["agent", "--help"])
assert result.exit_code == 0
assert "--workspace" in result.stdout
assert "-w" in result.stdout
assert "--config" in result.stdout
assert "-c" in result.stdout
stripped_output = _strip_ansi(result.stdout)
assert "--workspace" in stripped_output
assert "-w" in stripped_output
assert "--config" in stripped_output
assert "-c" in stripped_output
def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_runtime):
@@ -267,6 +333,16 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path
def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
mock_agent_runtime["config"].agents.defaults.memory_window = 100
result = runner.invoke(app, ["agent", "-m", "hello"])
assert result.exit_code == 0
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -328,6 +404,28 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
assert config.workspace_path == override
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.memory_window = 100
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "memoryWindow" in result.stdout
assert "contextWindowTokens" in result.stdout
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
@@ -356,3 +454,47 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
assert isinstance(result.exception, _StopGateway)
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGateway)
assert "port 18791" in result.stdout
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
assert isinstance(result.exception, _StopGateway)
assert "port 18792" in result.stdout

View File

@@ -0,0 +1,132 @@
import json
from types import SimpleNamespace
from typer.testing import CliRunner
from nanobot.cli.commands import app
from nanobot.config.loader import load_config, save_config
runner = CliRunner()
def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 1234,
"memoryWindow": 42,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.agents.defaults.max_tokens == 1234
assert config.agents.defaults.context_window_tokens == 65_536
assert config.agents.defaults.should_warn_deprecated_memory_window is True
def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 2222,
"memoryWindow": 30,
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 2222
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"agents": {
"defaults": {
"maxTokens": 3333,
"memoryWindow": 50,
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
assert "contextWindowTokens" in result.stdout
saved = json.loads(config_path.read_text(encoding="utf-8"))
defaults = saved["agents"]["defaults"]
assert defaults["maxTokens"] == 3333
assert defaults["contextWindowTokens"] == 65_536
assert "memoryWindow" not in defaults
def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
workspace = tmp_path / "workspace"
config_path.write_text(
json.dumps(
{
"channels": {
"qq": {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
}
}
}
),
encoding="utf-8",
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda: workspace)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
"qq": SimpleNamespace(
default_config=lambda: {
"enabled": False,
"appId": "",
"secret": "",
"allowFrom": [],
"msgFormat": "plain",
}
)
},
)
result = runner.invoke(app, ["onboard"], input="n\n")
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain"

View File

@@ -480,226 +480,35 @@ class TestEmptyAndBoundarySessions:
assert_messages_content(old_messages, 10, 34)
class TestConsolidationDeduplicationGuard:
"""Test that consolidation tasks are deduplicated and serialized."""
class TestNewCommandArchival:
"""Test /new archival behavior with the simplified consolidation flow."""
@pytest.mark.asyncio
async def test_consolidation_guard_prevents_duplicate_tasks(self, tmp_path: Path) -> None:
"""Concurrent messages above memory_window spawn only one consolidation task."""
@staticmethod
def _make_loop(tmp_path: Path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=1,
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await loop._process_message(msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 1, (
f"Expected exactly 1 consolidation, got {consolidation_calls}"
)
@pytest.mark.asyncio
async def test_new_command_guard_prevents_concurrent_consolidation(
self, tmp_path: Path
) -> None:
"""/new command does not run consolidation concurrently with in-flight consolidation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
consolidation_calls = 0
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
max_active = max(max_active, active)
await asyncio.sleep(0.05)
active -= 1
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
await loop._process_message(new_msg)
await asyncio.sleep(0.1)
assert consolidation_calls == 2, (
f"Expected normal + /new consolidations, got {consolidation_calls}"
)
assert max_active == 1, (
f"Expected serialized consolidation, observed concurrency={max_active}"
)
@pytest.mark.asyncio
async def test_consolidation_tasks_are_referenced(self, tmp_path: Path) -> None:
"""create_task results are tracked in _consolidation_tasks while in flight."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
started.set()
await asyncio.sleep(0.1)
loop._consolidate_memory = _slow_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
assert len(loop._consolidation_tasks) == 1, "Task must be referenced while in-flight"
await asyncio.sleep(0.15)
assert len(loop._consolidation_tasks) == 0, (
"Task reference must be removed after completion"
)
@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_preserves_messages(
self, tmp_path: Path
) -> None:
"""/new waits for in-flight consolidation and archives before clear."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done(), "/new should wait while consolidation is in-flight"
release.set()
response = await pending_new
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count > 0, "Expected /new archival to process a non-empty snapshot"
session_after = loop.sessions.get_or_create("cli:test")
assert session_after.messages == [], "Session should be cleared after successful archival"
return loop
@pytest.mark.asyncio
async def test_new_does_not_clear_session_when_archive_fails(self, tmp_path: Path) -> None:
"""/new must keep session data if archive step reports failure."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
@@ -707,111 +516,61 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session)
before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
if archive_all:
return False
return True
async def _failing_consolidate(_messages) -> bool:
return False
loop._consolidate_memory = _failing_consolidate # type: ignore[method-assign]
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)
assert response is not None
assert "failed" in response.content.lower()
session_after = loop.sessions.get_or_create("cli:test")
assert len(session_after.messages) == before_count, (
"Session must remain intact when /new archival fails"
)
assert len(loop.sessions.get_or_create("cli:test").messages) == before_count
@pytest.mark.asyncio
async def test_new_archives_only_unconsolidated_messages_after_inflight_task(
self, tmp_path: Path
) -> None:
"""/new should archive only messages not yet consolidated by prior task."""
from nanobot.agent.loop import AgentLoop
async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None:
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(15):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
session.last_consolidated = len(session.messages) - 3
loop.sessions.save(session)
started = asyncio.Event()
release = asyncio.Event()
archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
async def _fake_consolidate(messages) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
return True
started.set()
await release.wait()
sess.last_consolidated = len(sess.messages) - 3
archived_count = len(messages)
return True
loop._consolidate_memory = _fake_consolidate # type: ignore[method-assign]
msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="hello")
await loop._process_message(msg)
await started.wait()
loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
pending_new = asyncio.create_task(loop._process_message(new_msg))
await asyncio.sleep(0.02)
assert not pending_new.done()
release.set()
response = await pending_new
response = await loop._process_message(new_msg)
assert response is not None
assert "new session started" in response.content.lower()
assert archived_count == 3, (
f"Expected only unconsolidated tail to archive, got {archived_count}"
)
assert archived_count == 3
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
"""/new clears session and returns confirmation."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
async def _ok_consolidate(_messages) -> bool:
return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg)

View File

@@ -1,10 +1,12 @@
import asyncio
from types import SimpleNamespace
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.dingtalk import DingTalkChannel
from nanobot.config.schema import DingTalkConfig
import nanobot.channels.dingtalk as dingtalk_module
from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
from nanobot.channels.dingtalk import DingTalkConfig
class _FakeResponse:
@@ -64,3 +66,46 @@ async def test_group_send_uses_group_messages_api() -> None:
assert call["url"] == "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
assert call["json"]["openConversationId"] == "conv123"
assert call["json"]["msgKey"] == "sampleMarkdown"
@pytest.mark.asyncio
async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatch) -> None:
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeChatbotMessage:
text = None
extensions = {"content": {"recognition": "voice transcript"}}
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "audio"
@staticmethod
def from_dict(_data):
return _FakeChatbotMessage()
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "2",
"conversationId": "conv123",
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert msg.content == "voice transcript"
assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123"

View File

@@ -6,7 +6,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.email import EmailChannel
from nanobot.config.schema import EmailConfig
from nanobot.channels.email import EmailConfig
def _make_config() -> EmailConfig:

View File

@@ -0,0 +1,253 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def sample_file(self, tmp_path):
f = tmp_path / "sample.txt"
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
return f
@pytest.mark.asyncio
async def test_basic_read_has_line_numbers(self, tool, sample_file):
result = await tool.execute(path=str(sample_file))
assert "1| line 1" in result
assert "20| line 20" in result
@pytest.mark.asyncio
async def test_offset_and_limit(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
assert "5| line 5" in result
assert "7| line 7" in result
assert "8| line 8" not in result
assert "Use offset=8 to continue" in result
@pytest.mark.asyncio
async def test_offset_beyond_end(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=999)
assert "Error" in result
assert "beyond end" in result
@pytest.mark.asyncio
async def test_end_of_file_marker(self, tool, sample_file):
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
assert "End of file" in result
@pytest.mark.asyncio
async def test_empty_file(self, tool, tmp_path):
f = tmp_path / "empty.txt"
f.write_text("", encoding="utf-8")
result = await tool.execute(path=str(f))
assert "Empty file" in result
@pytest.mark.asyncio
async def test_file_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope.txt"))
assert "Error" in result
assert "not found" in result
@pytest.mark.asyncio
async def test_char_budget_trims(self, tool, tmp_path):
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
f = tmp_path / "big.txt"
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
result = await tool.execute(path=str(f))
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
assert "Use offset=" in result
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
def test_exact_match(self):
match, count = _find_match("hello world", "world")
assert match == "world"
assert count == 1
def test_exact_no_match(self):
match, count = _find_match("hello world", "xyz")
assert match is None
assert count == 0
def test_crlf_normalisation(self):
# Caller normalises CRLF before calling _find_match, so test with
# pre-normalised content to verify exact match still works.
content = "line1\nline2\nline3"
old_text = "line1\nline2\nline3"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
def test_line_trim_fallback(self):
content = " def foo():\n pass\n"
old_text = "def foo():\n pass"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
# The returned match should be the *original* indented text
assert " def foo():" in match
def test_line_trim_multiple_candidates(self):
content = " a\n b\n a\n b\n"
old_text = "a\nb"
match, count = _find_match(content, old_text)
assert count == 2
def test_empty_old_text(self):
match, count = _find_match("hello", "")
# Empty string is always "in" any string via exact match
assert match == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_exact_match(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
assert f.read_text() == "hello earth"
@pytest.mark.asyncio
async def test_crlf_normalisation(self, tool, tmp_path):
f = tmp_path / "crlf.py"
f.write_bytes(b"line1\r\nline2\r\nline3")
result = await tool.execute(
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
)
assert "Successfully" in result
raw = f.read_bytes()
assert b"LINE1" in raw
# CRLF line endings should be preserved throughout the file
assert b"\r\n" in raw
@pytest.mark.asyncio
async def test_trim_fallback(self, tool, tmp_path):
f = tmp_path / "indent.py"
f.write_text(" def foo():\n pass\n", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
)
assert "Successfully" in result
assert "bar" in f.read_text()
@pytest.mark.asyncio
async def test_ambiguous_match(self, tool, tmp_path):
f = tmp_path / "dup.py"
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
assert "appears" in result.lower() or "Warning" in result
@pytest.mark.asyncio
async def test_replace_all(self, tool, tmp_path):
f = tmp_path / "multi.py"
f.write_text("foo bar foo bar foo", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="foo", new_text="baz", replace_all=True,
)
assert "Successfully" in result
assert f.read_text() == "baz bar baz bar baz"
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
f = tmp_path / "nf.py"
f.write_text("hello", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
assert "Error" in result
assert "not found" in result
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
@pytest.fixture()
def tool(self, tmp_path):
return ListDirTool(workspace=tmp_path)
@pytest.fixture()
def populated_dir(self, tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "main.py").write_text("pass")
(tmp_path / "src" / "utils.py").write_text("pass")
(tmp_path / "README.md").write_text("hi")
(tmp_path / ".git").mkdir()
(tmp_path / ".git" / "config").write_text("x")
(tmp_path / "node_modules").mkdir()
(tmp_path / "node_modules" / "pkg").mkdir()
return tmp_path
@pytest.mark.asyncio
async def test_basic_list(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir))
assert "README.md" in result
assert "src" in result
# .git and node_modules should be ignored
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_recursive(self, tool, populated_dir):
result = await tool.execute(path=str(populated_dir), recursive=True)
# Normalize path separators for cross-platform compatibility
normalized = result.replace("\\", "/")
assert "src/main.py" in normalized
assert "src/utils.py" in normalized
assert "README.md" in result
# Ignored dirs should not appear
assert ".git" not in result
assert "node_modules" not in result
@pytest.mark.asyncio
async def test_max_entries_truncation(self, tool, tmp_path):
for i in range(10):
(tmp_path / f"file_{i}.txt").write_text("x")
result = await tool.execute(path=str(tmp_path), max_entries=3)
assert "truncated" in result
assert "3 of 10" in result
@pytest.mark.asyncio
async def test_empty_dir(self, tool, tmp_path):
d = tmp_path / "empty"
d.mkdir()
result = await tool.execute(path=str(d))
assert "empty" in result.lower()
@pytest.mark.asyncio
async def test_not_found(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope"))
assert "Error" in result
assert "not found" in result

View File

@@ -0,0 +1,53 @@
from types import SimpleNamespace
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.litellm_provider import LiteLLMProvider
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="call_123",
function=SimpleNamespace(
name="read_file",
arguments='{"path":"todo.md"}',
provider_specific_fields={"inner": "value"},
),
provider_specific_fields={"thought_signature": "signed-token"},
)
],
),
)
],
usage=None,
)
parsed = provider._parse_response(response)
assert len(parsed.tool_calls) == 1
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
def test_tool_call_request_serializes_provider_fields() -> None:
tool_call = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"thought_signature": "signed-token"},
function_provider_specific_fields={"inner": "value"},
)
message = tool_call.to_openai_tool_call()
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'

View File

@@ -3,18 +3,24 @@ import asyncio
import pytest
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
class DummyProvider:
class DummyProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
)
assert await service.trigger_now() is None
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check open tasks"},
)
],
),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
service = HeartbeatService(
workspace=tmp_path,
provider=provider,
model="openai/gpt-4o-mini",
)
action, tasks = await service._decide("heartbeat content")
assert action == "run"
assert tasks == "check open tasks"
assert provider.calls == 2
assert delays == [1]

View File

@@ -0,0 +1,190 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
import nanobot.agent.memory as memory_module
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
)
loop.tools.get_definitions = MagicMock(return_value=[])
return loop
@pytest.mark.asyncio
async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None:
loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
loop.memory_consolidator.consolidate_messages.assert_not_awaited()
@pytest.mark.asyncio
async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _message: 500)
await loop.process_direct("hello", session_key="cli:test")
assert loop.memory_consolidator.consolidate_messages.await_count >= 1
@pytest.mark.asyncio
async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None:
loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
]
loop.sessions.save(session)
token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120}
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]])
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0]
assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"]
assert session.last_consolidated == 4
@pytest.mark.asyncio
async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None:
"""Verify maybe_consolidate_by_tokens keeps looping until under threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (300, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None:
"""Once triggered, consolidation should continue until it drops below half threshold."""
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign]
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
{"role": "assistant", "content": "a2", "timestamp": "2026-01-01T00:00:03"},
{"role": "user", "content": "u3", "timestamp": "2026-01-01T00:00:04"},
{"role": "assistant", "content": "a3", "timestamp": "2026-01-01T00:00:05"},
{"role": "user", "content": "u4", "timestamp": "2026-01-01T00:00:06"},
]
loop.sessions.save(session)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
if call_count[0] == 1:
return (500, "test")
if call_count[0] == 2:
return (150, "test")
return (80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.memory_consolidator.maybe_consolidate_by_tokens(session)
assert loop.memory_consolidator.consolidate_messages.await_count == 2
assert session.last_consolidated == 6
@pytest.mark.asyncio
async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None:
"""Verify preflight consolidation runs before the LLM call in process_direct."""
order: list[str] = []
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
order.append("consolidate")
return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
async def track_llm(*args, **kwargs):
order.append("llm")
return LLMResponse(content="ok", tool_calls=[])
loop.provider.chat_with_retry = track_llm
session = loop.sessions.get_or_create("cli:test")
session.messages = [
{"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"},
{"role": "assistant", "content": "a1", "timestamp": "2026-01-01T00:00:01"},
{"role": "user", "content": "u2", "timestamp": "2026-01-01T00:00:02"},
]
loop.sessions.save(session)
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500)
call_count = [0]
def mock_estimate(_session):
call_count[0] += 1
return (1000 if call_count[0] <= 1 else 80, "test")
loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
await loop.process_direct("hello", session_key="cli:test")
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")

View File

@@ -5,7 +5,7 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = 500
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
return loop
@@ -39,3 +39,17 @@ def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
skip=0,
)
assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
def test_save_turn_keeps_tool_results_under_16k() -> None:
loop = _mk_loop()
session = Session(key="test:tool-result")
content = "x" * 12_000
loop._save_turn(
session,
[{"role": "tool", "tool_call_id": "call_1", "name": "read_file", "content": content}],
skip=0,
)
assert session.messages[0]["content"] == content

View File

@@ -12,7 +12,7 @@ from nanobot.channels.matrix import (
TYPING_NOTICE_TIMEOUT_MS,
MatrixChannel,
)
from nanobot.config.schema import MatrixConfig
from nanobot.channels.matrix import MatrixConfig
_ROOM_SEND_UNSET = object()

View File

@@ -7,23 +7,20 @@ tool call response, it should serialize them to JSON instead of raising TypeErro
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
import pytest
from nanobot.agent.memory import MemoryStore
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
def _make_session(message_count: int = 30, memory_window: int = 50):
"""Create a mock session with messages."""
session = MagicMock()
session.messages = [
def _make_messages(message_count: int = 30):
"""Create a list of mock messages."""
return [
{"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"}
for i in range(message_count)
]
session.last_consolidated = 0
return session
def _make_tool_response(history_entry, memory_update):
@@ -43,6 +40,22 @@ def _make_tool_response(history_entry, memory_update):
)
class ScriptedProvider(LLMProvider):
def __init__(self, responses: list[LLMResponse]):
super().__init__()
self._responses = list(responses)
self.calls = 0
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
def get_default_model(self) -> str:
return "test-model"
class TestMemoryConsolidationTypeHandling:
"""Test that consolidation handles various argument types correctly."""
@@ -57,9 +70,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.",
)
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -77,9 +91,10 @@ class TestMemoryConsolidationTypeHandling:
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
)
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert store.history_file.exists()
@@ -97,7 +112,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a JSON string (not yet parsed)
response = LLMResponse(
content=None,
tool_calls=[
@@ -112,9 +126,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -127,21 +142,23 @@ class TestMemoryConsolidationTypeHandling:
provider.chat = AsyncMock(
return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_skips_when_few_messages(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when messages < keep_count."""
async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None:
"""Consolidation should be a no-op when the selected chunk is empty."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
session = _make_session(message_count=10)
provider.chat_with_retry = provider.chat
messages: list[dict] = []
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat.assert_not_called()
@@ -152,7 +169,6 @@ class TestMemoryConsolidationTypeHandling:
store = MemoryStore(tmp_path)
provider = AsyncMock()
# Simulate arguments being a list containing a dict
response = LLMResponse(
content=None,
tool_calls=[
@@ -167,9 +183,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert "User discussed testing." in store.history_file.read_text()
@@ -192,9 +209,10 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@@ -215,8 +233,246 @@ class TestMemoryConsolidationTypeHandling:
],
)
provider.chat = AsyncMock(return_value=response)
session = _make_session(message_count=60)
provider.chat_with_retry = provider.chat
messages = _make_messages(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
@pytest.mark.asyncio
async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not persist partial results when required fields are missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"memory_update": "# Memory\nOnly memory update"},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Do not append history if memory_update is missing."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(
content=None,
tool_calls=[
ToolCallRequest(
id="call_1",
name="save_memory",
arguments={"history_entry": "[2026-01-01] Partial output."},
)
],
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Null required fields should be rejected before persistence."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=None,
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None:
"""Empty history entries should be rejected to avoid blank archival records."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry=" ",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
store = MemoryStore(tmp_path)
provider = ScriptedProvider([
LLMResponse(content="503 server error", finish_reason="error"),
_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
),
])
messages = _make_messages(message_count=60)
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
"""Consolidation no longer passes generation params — the provider owns them."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(
return_value=_make_tool_response(
history_entry="[2026-01-01] User discussed testing.",
memory_update="# Memory\nUser likes testing.",
)
)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
provider.chat_with_retry.assert_awaited_once()
_, kwargs = provider.chat_with_retry.await_args
assert kwargs["model"] == "test-model"
assert "temperature" not in kwargs
assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error calling LLM: litellm.BadRequestError: "
"The tool_choice parameter does not support being set to required or object",
finish_reason="error",
tool_calls=[],
)
ok_resp = _make_tool_response(
history_entry="[2026-01-01] Fallback worked.",
memory_update="# Memory\nFallback OK.",
)
call_log: list[dict] = []
async def _tracking_chat(**kwargs):
call_log.append(kwargs)
return error_resp if len(call_log) == 1 else ok_resp
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat)
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is True
assert len(call_log) == 2
assert isinstance(call_log[0]["tool_choice"], dict)
assert call_log[1]["tool_choice"] == "auto"
assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
"""Forced rejected, auto retry also produces no tool call -> return False."""
store = MemoryStore(tmp_path)
error_resp = LLMResponse(
content="Error: tool_choice must be none or auto",
finish_reason="error",
tool_calls=[],
)
no_tool_resp = LLMResponse(
content="Here is a summary.",
finish_reason="stop",
tool_calls=[],
)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
messages = _make_messages(message_count=60)
result = await store.consolidate(messages, provider, "test-model")
assert result is False
assert not store.history_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None:
"""After 3 consecutive failures, raw-archive messages and return True."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[])
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
messages = _make_messages(message_count=10)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is True
assert store.history_file.exists()
content = store.history_file.read_text()
assert "[RAW]" in content
assert "10 messages" in content
assert "msg0" in content
assert not store.memory_file.exists()
@pytest.mark.asyncio
async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None:
"""A successful consolidation resets the failure counter."""
store = MemoryStore(tmp_path)
no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[])
ok_resp = _make_tool_response(
history_entry="[2026-01-01] OK.",
memory_update="# Memory\nOK.",
)
messages = _make_messages(message_count=10)
provider = AsyncMock()
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 2
provider.chat_with_retry = AsyncMock(return_value=ok_resp)
assert await store.consolidate(messages, provider, "m") is True
assert store._consecutive_failures == 0
provider.chat_with_retry = AsyncMock(return_value=no_tool)
assert await store.consolidate(messages, provider, "m") is False
assert store._consecutive_failures == 1

View File

@@ -16,7 +16,7 @@ def _make_loop(tmp_path: Path) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
class TestMessageToolSuppressLogic:
@@ -33,7 +33,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -58,7 +58,7 @@ class TestMessageToolSuppressLogic:
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="I've sent the email.", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
@@ -77,7 +77,7 @@ class TestMessageToolSuppressLogic:
@pytest.mark.asyncio
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
loop.provider.chat = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
loop.tools.get_definitions = MagicMock(return_value=[])
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
@@ -98,7 +98,7 @@ class TestMessageToolSuppressLogic:
),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")

View File

@@ -0,0 +1,125 @@
import asyncio
import pytest
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider):
def __init__(self, responses):
super().__init__()
self._responses = list(responses)
self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
def get_default_model(self) -> str:
return "test-model"
@pytest.mark.asyncio
async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit", finish_reason="error"),
LLMResponse(content="ok"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "stop"
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "401 unauthorized"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(content="429 rate limit a", finish_reason="error"),
LLMResponse(content="429 rate limit b", finish_reason="error"),
LLMResponse(content="429 rate limit c", finish_reason="error"),
LLMResponse(content="503 final server error", finish_reason="error"),
])
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "503 final server error"
assert provider.calls == 4
assert delays == [1, 2, 4]
@pytest.mark.asyncio
async def test_chat_with_retry_preserves_cancelled_error() -> None:
provider = ScriptedProvider([asyncio.CancelledError()])
with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert provider.last_kwargs["temperature"] == 0.2
assert provider.last_kwargs["max_tokens"] == 321
assert provider.last_kwargs["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
"""Explicit kwargs should override provider.generation defaults."""
provider = ScriptedProvider([LLMResponse(content="ok")])
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
temperature=0.9,
max_tokens=9999,
reasoning_effort="low",
)
assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low"

View File

@@ -5,7 +5,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel
from nanobot.config.schema import QQConfig
from nanobot.channels.qq import QQConfig
class _FakeApi:
@@ -44,7 +44,7 @@ async def test_on_group_message_routes_to_group_chat_id() -> None:
@pytest.mark.asyncio
async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
async def test_send_group_message_uses_plain_text_group_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
@@ -60,7 +60,66 @@ async def test_send_group_message_uses_group_api_with_msg_seq() -> None:
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call["group_openid"] == "group123"
assert call["msg_id"] == "msg1"
assert call["msg_seq"] == 2
assert call == {
"group_openid": "group123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.c2c_calls
@pytest.mark.asyncio
async def test_send_c2c_message_uses_plain_text_c2c_api_with_msg_seq() -> None:
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user123",
content="hello",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.c2c_calls) == 1
call = channel._client.api.c2c_calls[0]
assert call == {
"openid": "user123",
"msg_type": 0,
"content": "hello",
"msg_id": "msg1",
"msg_seq": 2,
}
assert not channel._client.api.group_calls
@pytest.mark.asyncio
async def test_send_group_message_uses_markdown_when_configured() -> None:
channel = QQChannel(
QQConfig(app_id="app", secret="secret", allow_from=["*"], msg_format="markdown"),
MessageBus(),
)
channel._client = _FakeClient()
channel._chat_type_cache["group123"] = "group"
await channel.send(
OutboundMessage(
channel="qq",
chat_id="group123",
content="**hello**",
metadata={"message_id": "msg1"},
)
)
assert len(channel._client.api.group_calls) == 1
call = channel._client.api.group_calls[0]
assert call == {
"group_openid": "group123",
"msg_type": 2,
"markdown": {"content": "**hello**"},
"msg_id": "msg1",
"msg_seq": 2,
}

View File

@@ -0,0 +1,76 @@
"""Tests for /restart slash command."""
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from nanobot.bus.events import InboundMessage
def _make_loop():
"""Create a minimal AgentLoop with mocked dependencies."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
workspace = MagicMock()
workspace.__truediv__ = MagicMock(return_value=MagicMock())
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager"):
loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
return loop, bus
class TestRestartCommand:
@pytest.mark.asyncio
async def test_restart_sends_message_and_calls_execv(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="/restart")
with patch("nanobot.agent.loop.os.execv") as mock_execv:
await loop._handle_restart(msg)
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert "Restarting" in out.content
await asyncio.sleep(1.5)
mock_execv.assert_called_once()
@pytest.mark.asyncio
async def test_restart_intercepted_in_run_loop(self):
"""Verify /restart is handled at the run-loop level, not inside _dispatch."""
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/restart")
with patch.object(loop, "_handle_restart") as mock_handle:
mock_handle.return_value = None
await bus.publish_inbound(msg)
loop._running = True
run_task = asyncio.create_task(loop.run())
await asyncio.sleep(0.1)
loop._running = False
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
mock_handle.assert_called_once()
@pytest.mark.asyncio
async def test_help_includes_restart(self):
loop, bus = _make_loop()
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/help")
response = await loop._process_message(msg)
assert response is not None
assert "/restart" in response.content

View File

@@ -0,0 +1,127 @@
import importlib
import shutil
import sys
import zipfile
from pathlib import Path
SCRIPT_DIR = Path("nanobot/skills/skill-creator/scripts").resolve()
if str(SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPT_DIR))
init_skill = importlib.import_module("init_skill")
package_skill = importlib.import_module("package_skill")
quick_validate = importlib.import_module("quick_validate")
def test_init_skill_creates_expected_files(tmp_path: Path) -> None:
skill_dir = init_skill.init_skill(
"demo-skill",
tmp_path,
["scripts", "references", "assets"],
include_examples=True,
)
assert skill_dir == tmp_path / "demo-skill"
assert (skill_dir / "SKILL.md").exists()
assert (skill_dir / "scripts" / "example.py").exists()
assert (skill_dir / "references" / "api_reference.md").exists()
assert (skill_dir / "assets" / "example_asset.txt").exists()
def test_validate_skill_accepts_existing_skill_creator() -> None:
valid, message = quick_validate.validate_skill(
Path("nanobot/skills/skill-creator").resolve()
)
assert valid, message
def test_validate_skill_rejects_placeholder_description(tmp_path: Path) -> None:
skill_dir = tmp_path / "placeholder-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: placeholder-skill\n"
'description: "[TODO: fill me in]"\n'
"---\n"
"# Placeholder\n",
encoding="utf-8",
)
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "TODO placeholder" in message
def test_validate_skill_rejects_root_files_outside_allowed_dirs(tmp_path: Path) -> None:
skill_dir = tmp_path / "bad-root-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: bad-root-skill\n"
"description: Valid description\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
(skill_dir / "README.md").write_text("extra\n", encoding="utf-8")
valid, message = quick_validate.validate_skill(skill_dir)
assert not valid
assert "Unexpected file or directory in skill root" in message
def test_package_skill_creates_archive(tmp_path: Path) -> None:
skill_dir = tmp_path / "package-me"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: package-me\n"
"description: Package this skill.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
(scripts_dir / "helper.py").write_text("print('ok')\n", encoding="utf-8")
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path == (tmp_path / "dist" / "package-me.skill")
assert archive_path.exists()
with zipfile.ZipFile(archive_path, "r") as archive:
names = set(archive.namelist())
assert "package-me/SKILL.md" in names
assert "package-me/scripts/helper.py" in names
def test_package_skill_rejects_symlink(tmp_path: Path) -> None:
skill_dir = tmp_path / "symlink-skill"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: symlink-skill\n"
"description: Reject symlinks during packaging.\n"
"---\n"
"# Skill\n",
encoding="utf-8",
)
scripts_dir = skill_dir / "scripts"
scripts_dir.mkdir()
target = tmp_path / "outside.txt"
target.write_text("secret\n", encoding="utf-8")
link = scripts_dir / "outside.txt"
try:
link.symlink_to(target)
except (OSError, NotImplementedError):
return
archive_path = package_skill.package_skill(skill_dir, tmp_path / "dist")
assert archive_path is None
assert not (tmp_path / "dist" / "symlink-skill.skill").exists()

View File

@@ -5,7 +5,7 @@ import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel
from nanobot.config.schema import SlackConfig
from nanobot.channels.slack import SlackConfig
class _FakeAsyncWebClient:

View File

@@ -165,3 +165,46 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
async def test_subagent_preserves_reasoning_fields_in_tool_turn(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def scripted_chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
async def fake_execute(self, name, arguments):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
assistant_messages = [
msg for msg in captured_second_call
if msg.get("role") == "assistant" and msg.get("tool_calls")
]
assert len(assistant_messages) == 1
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]

View File

@@ -1,11 +1,14 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.telegram import TelegramChannel
from nanobot.config.schema import TelegramConfig
from nanobot.channels.telegram import TELEGRAM_REPLY_CONTEXT_MAX_LEN, TelegramChannel
from nanobot.channels.telegram import TelegramConfig
class _FakeHTTPXRequest:
@@ -27,10 +30,11 @@ class _FakeUpdater:
class _FakeBot:
def __init__(self) -> None:
self.sent_messages: list[dict] = []
self.file = None
self.get_me_calls = 0
async def get_me(self):
return SimpleNamespace(username="nanobot_test")
self.get_me_calls += 1
return SimpleNamespace(id=999, username="nanobot_test")
async def set_my_commands(self, commands) -> None:
self.commands = commands
@@ -38,8 +42,14 @@ class _FakeBot:
async def send_message(self, **kwargs) -> None:
self.sent_messages.append(kwargs)
async def get_file(self, _file_id):
return self.file
async def send_chat_action(self, **kwargs) -> None:
pass
async def get_file(self, file_id: str):
"""Return a fake file that 'downloads' to a path (for reply-to-media tests)."""
async def _fake_download(path) -> None:
pass
return SimpleNamespace(download_to_drive=_fake_download)
class _FakeApp:
@@ -91,6 +101,35 @@ class _FakeBuilder:
return self.app
def _make_telegram_update(
*,
chat_type: str = "group",
text: str | None = None,
caption: str | None = None,
entities=None,
caption_entities=None,
reply_to_message=None,
):
user = SimpleNamespace(id=12345, username="alice", first_name="Alice")
message = SimpleNamespace(
chat=SimpleNamespace(type=chat_type, is_forum=False),
chat_id=-100123,
text=text,
caption=caption,
entities=entities or [],
caption_entities=caption_entities or [],
reply_to_message=reply_to_message,
photo=None,
voice=None,
audio=None,
document=None,
media_group_id=None,
message_thread_id=None,
message_id=1,
)
return SimpleNamespace(message=message, effective_user=user)
@pytest.mark.asyncio
async def test_start_uses_request_proxy_without_builder_proxy(monkeypatch) -> None:
config = TelegramConfig(
@@ -135,6 +174,10 @@ def test_get_extension_falls_back_to_original_filename() -> None:
assert channel._get_extension("file", None, "archive.tar.gz") == ".tar.gz"
def test_telegram_group_policy_defaults_to_mention() -> None:
assert TelegramConfig().group_policy == "mention"
def test_is_allowed_accepts_legacy_telegram_id_username_formats() -> None:
channel = TelegramChannel(TelegramConfig(allow_from=["12345", "alice", "67890|bob"]), MessageBus())
@@ -189,53 +232,418 @@ async def test_send_reply_infers_topic_from_message_id_cache() -> None:
@pytest.mark.asyncio
async def test_on_message_uses_file_unique_id_for_downloaded_media(monkeypatch, tmp_path) -> None:
config = TelegramConfig(enabled=True, token="123:abc", allow_from=["*"])
channel = TelegramChannel(config, MessageBus())
async def test_group_policy_mention_ignores_unmentioned_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello everyone"), None)
assert handled == []
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_text_mention_and_caches_bot_identity() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(_make_telegram_update(text="@nanobot_test hi", entities=[mention]), None)
await channel._on_message(_make_telegram_update(text="@nanobot_test again", entities=[mention]), None)
assert len(handled) == 2
assert channel._app.bot.get_me_calls == 1
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_caption_mention() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
mention = SimpleNamespace(type="mention", offset=0, length=13)
await channel._on_message(
_make_telegram_update(caption="@nanobot_test photo", caption_entities=[mention]),
None,
)
assert len(handled) == 1
assert handled[0]["content"] == "@nanobot_test photo"
@pytest.mark.asyncio
async def test_group_policy_mention_accepts_reply_to_bot() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(from_user=SimpleNamespace(id=999))
await channel._on_message(_make_telegram_update(text="reply", reply_to_message=reply), None)
assert len(handled) == 1
@pytest.mark.asyncio
async def test_group_policy_open_accepts_plain_group_message() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
await channel._on_message(_make_telegram_update(text="hello group"), None)
assert len(handled) == 1
assert channel._app.bot.get_me_calls == 0
def test_extract_reply_context_no_reply() -> None:
"""When there is no reply_to_message, _extract_reply_context returns None."""
message = SimpleNamespace(reply_to_message=None)
assert TelegramChannel._extract_reply_context(message) is None
def test_extract_reply_context_with_text() -> None:
"""When reply has text, return prefixed string."""
reply = SimpleNamespace(text="Hello world", caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]"
def test_extract_reply_context_with_caption_only() -> None:
"""When reply has only caption (no text), caption is used."""
reply = SimpleNamespace(text=None, caption="Photo caption")
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]"
def test_extract_reply_context_truncation() -> None:
"""Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN."""
long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100)
reply = SimpleNamespace(text=long_text, caption=None)
message = SimpleNamespace(reply_to_message=reply)
result = TelegramChannel._extract_reply_context(message)
assert result is not None
assert result.startswith("[Reply to: ")
assert result.endswith("...]")
assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...")
def test_extract_reply_context_no_text_returns_none() -> None:
"""When reply has no text/caption, _extract_reply_context returns None (media handled separately)."""
reply = SimpleNamespace(text=None, caption=None)
message = SimpleNamespace(reply_to_message=reply)
assert TelegramChannel._extract_reply_context(message) is None
@pytest.mark.asyncio
async def test_on_message_includes_reply_context() -> None:
"""When user replies to a message, content passed to bus starts with reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply = SimpleNamespace(text="Hello", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="translate this", reply_to_message=reply)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: Hello]")
assert "translate this" in handled[0]["content"]
@pytest.mark.asyncio
async def test_download_message_media_returns_path_when_download_succeeds(
monkeypatch, tmp_path
) -> None:
"""_download_message_media returns (paths, content_parts) when bot.get_file and download succeed."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
msg = SimpleNamespace(
photo=[SimpleNamespace(file_id="fid123", mime_type="image/jpeg")],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
paths, parts = await channel._download_message_media(msg)
assert len(paths) == 1
assert len(parts) == 1
assert "fid123" in paths[0]
assert "[image:" in parts[0]
@pytest.mark.asyncio
async def test_download_message_media_uses_file_unique_id_when_available(
monkeypatch, tmp_path
) -> None:
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
downloaded: dict[str, str] = {}
class _FakeDownloadedFile:
async def download_to_drive(self, path: str) -> None:
downloaded["path"] = path
async def _download_to_drive(path: str) -> None:
downloaded["path"] = path
channel._app.bot.file = _FakeDownloadedFile()
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=_download_to_drive)
)
channel._app = app
captured: dict[str, object] = {}
async def _capture_message(**kwargs) -> None:
captured.update(kwargs)
monkeypatch.setattr(channel, "_handle_message", _capture_message)
monkeypatch.setattr(channel, "_start_typing", lambda _chat_id: None)
monkeypatch.setattr("nanobot.channels.telegram.get_media_dir", lambda _name=None: tmp_path)
update = SimpleNamespace(
effective_user=SimpleNamespace(id=123, username="alice", first_name="Alice"),
message=SimpleNamespace(
message_id=1,
chat=SimpleNamespace(type="private", is_forum=False),
chat_id=456,
text=None,
caption=None,
photo=[
SimpleNamespace(
file_id="file-id-that-should-not-be-used",
file_unique_id="stable-unique-id",
mime_type="image/jpeg",
file_name=None,
)
],
voice=None,
audio=None,
document=None,
media_group_id=None,
message_thread_id=None,
),
msg = SimpleNamespace(
photo=[
SimpleNamespace(
file_id="file-id-that-should-not-be-used",
file_unique_id="stable-unique-id",
mime_type="image/jpeg",
file_name=None,
)
],
voice=None,
audio=None,
document=None,
video=None,
video_note=None,
animation=None,
)
await channel._on_message(update, None)
paths, parts = await channel._download_message_media(msg)
assert downloaded["path"].endswith("stable-unique-id.jpg")
assert captured["media"] == [str(tmp_path / "stable-unique-id.jpg")]
assert paths == [str(media_dir / "stable-unique-id.jpg")]
assert parts == [f"[image: {media_dir / 'stable-unique-id.jpg'}]"]
@pytest.mark.asyncio
async def test_on_message_attaches_reply_to_media_when_available(monkeypatch, tmp_path) -> None:
"""When user replies to a message with media, that media is downloaded and attached to the turn."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="reply_photo_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what is the image?",
reply_to_message=reply_with_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert handled[0]["content"].startswith("[Reply to: [image:")
assert "what is the image?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "reply_photo_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_on_message_reply_to_media_fallback_when_download_fails() -> None:
"""When reply has media but download fails, no media attached and no reply tag."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.get_file = None
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_photo = SimpleNamespace(
text=None,
caption=None,
photo=[SimpleNamespace(file_id="x", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(text="what is this?", reply_to_message=reply_with_photo)
await channel._on_message(update, None)
assert len(handled) == 1
assert "what is this?" in handled[0]["content"]
assert handled[0]["media"] == []
@pytest.mark.asyncio
async def test_on_message_reply_to_caption_and_media(monkeypatch, tmp_path) -> None:
"""When replying to a message with caption + photo, both text context and media are included."""
media_dir = tmp_path / "media" / "telegram"
media_dir.mkdir(parents=True)
monkeypatch.setattr(
"nanobot.channels.telegram.get_media_dir",
lambda channel=None: media_dir if channel else tmp_path / "media",
)
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
app = _FakeApp(lambda: None)
app.bot.get_file = AsyncMock(
return_value=SimpleNamespace(download_to_drive=AsyncMock(return_value=None))
)
channel._app = app
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._start_typing = lambda _chat_id: None
reply_with_caption_and_photo = SimpleNamespace(
text=None,
caption="A cute cat",
photo=[SimpleNamespace(file_id="cat_fid", mime_type="image/jpeg")],
document=None,
voice=None,
audio=None,
video=None,
video_note=None,
animation=None,
)
update = _make_telegram_update(
text="what breed is this?",
reply_to_message=reply_with_caption_and_photo,
)
await channel._on_message(update, None)
assert len(handled) == 1
assert "[Reply to: A cute cat]" in handled[0]["content"]
assert "what breed is this?" in handled[0]["content"]
assert len(handled[0]["media"]) == 1
assert "cat_fid" in handled[0]["media"][0]
@pytest.mark.asyncio
async def test_forward_command_does_not_inject_reply_context() -> None:
"""Slash commands forwarded via _forward_command must not include reply context."""
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
handled = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
reply = SimpleNamespace(text="some old message", message_id=2, from_user=SimpleNamespace(id=1))
update = _make_telegram_update(text="/new", reply_to_message=reply)
await channel._forward_command(update, None)
assert len(handled) == 1
assert handled[0]["content"] == "/new"

View File

@@ -108,6 +108,32 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
assert "/tmp/out.txt" in paths
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
assert "~/.nanobot/config.json" in paths
assert "~/out.txt" in paths
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
paths = ExecTool._extract_absolute_paths(cmd)
assert "/tmp/data.txt" in paths
assert "~/.nanobot/config.json" in paths
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
tool = ExecTool(restrict_to_workspace=True)
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---
@@ -337,3 +363,46 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
"""Exit code should appear in output even on success (exit 0)."""
tool = ExecTool()
result = await tool.execute(command="echo hello")
assert "Exit code: 0" in result
assert "hello" in result
async def test_exec_head_tail_truncation() -> None:
"""Long output should preserve both head and tail."""
tool = ExecTool()
# Generate output that exceeds _MAX_OUTPUT (10_000 chars)
# Use python to generate output to avoid command line length limits
result = await tool.execute(
command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
)
assert "chars truncated" in result
# Head portion should start with As
assert result.startswith("A")
# Tail portion should end with the exit code which comes after Bs
assert "Exit code:" in result
async def test_exec_timeout_parameter() -> None:
"""LLM-supplied timeout should override the constructor default."""
tool = ExecTool(timeout=60)
# A very short timeout should cause the command to be killed
result = await tool.execute(command="sleep 10", timeout=1)
assert "timed out" in result
assert "1 seconds" in result
async def test_exec_timeout_capped_at_max() -> None:
"""Timeout values above _MAX_TIMEOUT should be clamped."""
tool = ExecTool()
# Should not raise — just clamp to 600
result = await tool.execute(command="echo ok", timeout=9999)
assert "Exit code: 0" in result

View File

@@ -0,0 +1,162 @@
"""Tests for multi-provider web search."""
import httpx
import pytest
from nanobot.agent.tools.web import WebSearchTool
from nanobot.config.schema import WebSearchConfig
def _tool(provider: str = "brave", api_key: str = "", base_url: str = "") -> WebSearchTool:
return WebSearchTool(config=WebSearchConfig(provider=provider, api_key=api_key, base_url=base_url))
def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
"""Build a mock httpx.Response with a dummy request attached."""
r = httpx.Response(status, json=json)
r._request = httpx.Request("GET", "https://mock")
return r
@pytest.mark.asyncio
async def test_brave_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
assert kw["headers"]["X-Subscription-Token"] == "brave-key"
return _response(json={
"web": {"results": [{"title": "NanoBot", "url": "https://example.com", "description": "AI assistant"}]}
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="brave", api_key="brave-key")
result = await tool.execute(query="nanobot", count=1)
assert "NanoBot" in result
assert "https://example.com" in result
@pytest.mark.asyncio
async def test_tavily_search(monkeypatch):
async def mock_post(self, url, **kw):
assert "tavily" in url
assert kw["headers"]["Authorization"] == "Bearer tavily-key"
return _response(json={
"results": [{"title": "OpenClaw", "url": "https://openclaw.io", "content": "Framework"}]
})
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
tool = _tool(provider="tavily", api_key="tavily-key")
result = await tool.execute(query="openclaw")
assert "OpenClaw" in result
assert "https://openclaw.io" in result
@pytest.mark.asyncio
async def test_searxng_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "searx.example" in url
return _response(json={
"results": [{"title": "Result", "url": "https://example.com", "content": "SearXNG result"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="searxng", base_url="https://searx.example")
result = await tool.execute(query="test")
assert "Result" in result
@pytest.mark.asyncio
async def test_duckduckgo_search(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "DDG Result", "href": "https://ddg.example", "body": "From DuckDuckGo"}]
monkeypatch.setattr("nanobot.agent.tools.web.DDGS", MockDDGS, raising=False)
import nanobot.agent.tools.web as web_mod
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
from ddgs import DDGS
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="duckduckgo")
result = await tool.execute(query="hello")
assert "DDG Result" in result
@pytest.mark.asyncio
async def test_brave_fallback_to_duckduckgo_when_no_key(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
tool = _tool(provider="brave", api_key="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_jina_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "s.jina.ai" in str(url)
assert kw["headers"]["Authorization"] == "Bearer jina-key"
return _response(json={
"data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="jina", api_key="jina-key")
result = await tool.execute(query="test")
assert "Jina Result" in result
assert "https://jina.ai" in result
@pytest.mark.asyncio
async def test_unknown_provider():
tool = _tool(provider="unknown")
result = await tool.execute(query="test")
assert "unknown" in result
assert "Error" in result
@pytest.mark.asyncio
async def test_default_provider_is_brave(monkeypatch):
async def mock_get(self, url, **kw):
assert "brave" in url
return _response(json={"web": {"results": []}})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="", api_key="test-key")
result = await tool.execute(query="test")
assert "No results" in result
@pytest.mark.asyncio
async def test_searxng_no_base_url_falls_back(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("SEARXNG_BASE_URL", raising=False)
tool = _tool(provider="searxng", base_url="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_searxng_invalid_url():
tool = _tool(provider="searxng", base_url="not-a-url")
result = await tool.execute(query="test")
assert "Error" in result