Merge remote-tracking branch 'origin/main'
# Conflicts:
#	nanobot/agent/context.py
#	nanobot/agent/loop.py
#	nanobot/agent/tools/web.py
#	nanobot/channels/telegram.py
#	nanobot/cli/commands.py
#	tests/test_commands.py
#	tests/test_config_migration.py
#	tests/test_telegram_channel.py
@@ -111,3 +111,33 @@ async def test_print_interactive_progress_line_pauses_spinner_before_printing():
    await commands._print_interactive_progress_line("tool running", thinking)

    assert order == ["start", "stop", "print", "start", "stop"]


def test_response_renderable_uses_text_for_explicit_plain_rendering():
    status = (
        "🐈 nanobot v0.1.4.post5\n"
        "🧠 Model: MiniMax-M2.7\n"
        "📊 Tokens: 20639 in / 29 out"
    )

    renderable = commands._response_renderable(
        status,
        render_markdown=True,
        metadata={"render_as": "text"},
    )

    assert renderable.__class__.__name__ == "Text"


def test_response_renderable_preserves_normal_markdown_rendering():
    renderable = commands._response_renderable("**bold**", render_markdown=True)

    assert renderable.__class__.__name__ == "Markdown"


def test_response_renderable_without_metadata_keeps_markdown_path():
    help_text = "🐈 nanobot commands:\n/status — Show bot status\n/help — Show available commands"

    renderable = commands._response_renderable(help_text, render_markdown=True)

    assert renderable.__class__.__name__ == "Markdown"
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from typer.testing import CliRunner

from nanobot.bus.events import OutboundMessage
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.providers.litellm_provider import LiteLLMProvider
@@ -117,7 +118,6 @@ def test_onboard_existing_workspace_safe_create(mock_paths):
    assert "Created AGENTS.md" in result.stdout
    assert (workspace_dir / "AGENTS.md").exists()


def test_onboard_help_shows_workspace_and_config_options():
    result = runner.invoke(app, ["onboard", "--help"])
@@ -127,9 +127,28 @@ def test_onboard_help_shows_workspace_and_config_options():
    assert "-w" in stripped_output
    assert "--config" in stripped_output
    assert "-c" in stripped_output
    assert "--wizard" in stripped_output
    assert "--dir" not in stripped_output


def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
    config_file, workspace_dir, _ = mock_paths

    from nanobot.cli.onboard_wizard import OnboardResult

    monkeypatch.setattr(
        "nanobot.cli.onboard_wizard.run_onboard",
        lambda initial_config: OnboardResult(config=initial_config, should_save=False),
    )

    result = runner.invoke(app, ["onboard", "--wizard"])

    assert result.exit_code == 0
    assert "No changes were saved" in result.stdout
    assert not config_file.exists()
    assert not workspace_dir.exists()


def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
    config_path = tmp_path / "instance" / "config.json"
    workspace_path = tmp_path / "workspace"
@@ -152,6 +171,31 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch)
    assert f"--config {resolved_config}" in compact_output


def test_onboard_wizard_preserves_explicit_config_in_next_steps(tmp_path, monkeypatch):
    config_path = tmp_path / "instance" / "config.json"
    workspace_path = tmp_path / "workspace"

    from nanobot.cli.onboard_wizard import OnboardResult

    monkeypatch.setattr(
        "nanobot.cli.onboard_wizard.run_onboard",
        lambda initial_config: OnboardResult(config=initial_config, should_save=True),
    )
    monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {})

    result = runner.invoke(
        app,
        ["onboard", "--wizard", "--config", str(config_path), "--workspace", str(workspace_path)],
    )

    assert result.exit_code == 0
    stripped_output = _strip_ansi(result.stdout)
    compact_output = stripped_output.replace("\n", "")
    resolved_config = str(config_path.resolve())
    assert f'nanobot agent -m "Hello!" --config {resolved_config}' in compact_output
    assert f"nanobot gateway --config {resolved_config}" in compact_output


def test_config_matches_github_copilot_codex_with_hyphen_prefix():
    config = Config()
    config.agents.defaults.model = "github-copilot/gpt-5.3-codex"
@@ -166,6 +210,15 @@ def test_config_matches_openai_codex_with_hyphen_prefix():
    assert config.get_provider_name() == "openai_codex"


def test_config_dump_excludes_oauth_provider_blocks():
    config = Config()

    providers = config.model_dump(by_alias=True)["providers"]

    assert "openaiCodex" not in providers
    assert "githubCopilot" not in providers


def test_config_matches_explicit_ollama_prefix_without_api_key():
    config = Config()
    config.agents.defaults.model = "ollama/llama3.2"
@@ -289,7 +342,9 @@ def mock_agent_runtime(tmp_path):

        agent_loop = MagicMock()
        agent_loop.channels_config = None
        agent_loop.process_direct = AsyncMock(return_value="mock-response")
        agent_loop.process_direct = AsyncMock(
            return_value=OutboundMessage(channel="cli", chat_id="direct", content="mock-response"),
        )
        agent_loop.close_mcp = AsyncMock(return_value=None)
        mock_agent_loop_cls.return_value = agent_loop
@@ -325,7 +380,9 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_
        mock_agent_runtime["config"].workspace_path
    )
    mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
    mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
    mock_agent_runtime["print_response"].assert_called_once_with(
        "mock-response", render_markdown=True, metadata={},
    )


def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
@@ -361,8 +418,8 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
        def __init__(self, *args, **kwargs) -> None:
            pass

        async def process_direct(self, *_args, **_kwargs) -> str:
            return "ok"
        async def process_direct(self, *_args, **_kwargs):
            return OutboundMessage(channel="cli", chat_id="direct", content="ok")

        async def close_mcp(self) -> None:
            return None
@@ -404,14 +461,15 @@ def test_agent_workspace_override_wins_over_config_workspace(mock_agent_runtime,
    assert mock_agent_runtime["agent_loop_cls"].call_args.kwargs["workspace"] == workspace_path


def test_agent_warns_about_deprecated_memory_window(mock_agent_runtime):
    mock_agent_runtime["config"].agents.defaults.memory_window = 100
def test_agent_hints_about_deprecated_memory_window(mock_agent_runtime, tmp_path):
    config_file = tmp_path / "config.json"
    config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))

    result = runner.invoke(app, ["agent", "-m", "hello"])
    result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])

    assert result.exit_code == 0
    assert "memoryWindow" in result.stdout
    assert "contextWindowTokens" in result.stdout
    assert "no longer used" in result.stdout


def test_agent_passes_web_search_config_to_agent_loop(mock_agent_runtime) -> None:
@@ -492,10 +550,9 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Path) -> None:
    config_file = tmp_path / "instance" / "config.json"
    config_file.parent.mkdir(parents=True)
    config_file.write_text("{}")
    config_file.write_text(json.dumps({"agents": {"defaults": {"memoryWindow": 42}}}))

    config = Config()
    config.agents.defaults.memory_window = 100

    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
@@ -510,7 +567,6 @@ def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Pat
    assert isinstance(result.exception, _StopGatewayError)
    assert "memoryWindow" in result.stdout
    assert "contextWindowTokens" in result.stdout

def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
    config_file = tmp_path / "instance" / "config.json"
    config_file.parent.mkdir(parents=True)
@@ -10,7 +10,7 @@ from nanobot.config.loader import load_config, save_config
runner = CliRunner()


def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path) -> None:
def test_load_config_keeps_max_tokens_and_ignores_legacy_memory_window(tmp_path) -> None:
    config_path = tmp_path / "config.json"
    config_path.write_text(
        json.dumps(
@@ -30,7 +30,7 @@ def test_load_config_keeps_max_tokens_and_warns_on_legacy_memory_window(tmp_path

    assert config.agents.defaults.max_tokens == 1234
    assert config.agents.defaults.context_window_tokens == 65_536
    assert config.agents.defaults.should_warn_deprecated_memory_window is True
    assert not hasattr(config.agents.defaults, "memory_window")


def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path) -> None:
@@ -59,7 +59,7 @@ def test_save_config_writes_context_window_tokens_but_not_memory_window(tmp_path
    assert "memoryWindow" not in defaults


def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) -> None:
def test_onboard_does_not_crash_with_legacy_memory_window(tmp_path, monkeypatch) -> None:
    config_path = tmp_path / "config.json"
    workspace = tmp_path / "workspace"
    config_path.write_text(
@@ -82,15 +82,11 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
    result = runner.invoke(app, ["onboard"], input="n\n")

    assert result.exit_code == 0
    assert "contextWindowTokens" in result.stdout
    saved = json.loads(config_path.read_text(encoding="utf-8"))
    defaults = saved["agents"]["defaults"]
    assert defaults["maxTokens"] == 3333
    assert defaults["contextWindowTokens"] == 65_536
    assert "memoryWindow" not in defaults


def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) -> None:
    from types import SimpleNamespace

    config_path = tmp_path / "config.json"
    workspace = tmp_path / "workspace"
    config_path.write_text(
@@ -182,7 +182,7 @@ class TestConsolidationTriggerConditions:
    """Test consolidation trigger conditions and logic."""

    def test_consolidation_needed_when_messages_exceed_window(self):
        """Test consolidation logic: should trigger when messages > memory_window."""
        """Test consolidation logic: should trigger when messages exceed the window."""
        session = create_session_with_messages("test:trigger", 60)

        total_messages = len(session.messages)
@@ -1,4 +1,5 @@
import asyncio
import json

import pytest
@@ -32,6 +33,87 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
    assert job.state.next_run_at_ms is not None


@pytest.mark.asyncio
async def test_execute_job_records_run_history(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"
    service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
    job = service.add_job(
        name="hist",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    await service.run_job(job.id)

    loaded = service.get_job(job.id)
    assert loaded is not None
    assert len(loaded.state.run_history) == 1
    rec = loaded.state.run_history[0]
    assert rec.status == "ok"
    assert rec.duration_ms >= 0
    assert rec.error is None


@pytest.mark.asyncio
async def test_run_history_records_errors(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"

    async def fail(_):
        raise RuntimeError("boom")

    service = CronService(store_path, on_job=fail)
    job = service.add_job(
        name="fail",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    await service.run_job(job.id)

    loaded = service.get_job(job.id)
    assert len(loaded.state.run_history) == 1
    assert loaded.state.run_history[0].status == "error"
    assert loaded.state.run_history[0].error == "boom"


@pytest.mark.asyncio
async def test_run_history_trimmed_to_max(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"
    service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
    job = service.add_job(
        name="trim",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    for _ in range(25):
        await service.run_job(job.id)

    loaded = service.get_job(job.id)
    assert len(loaded.state.run_history) == CronService._MAX_RUN_HISTORY


@pytest.mark.asyncio
async def test_run_history_persisted_to_disk(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"
    service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
    job = service.add_job(
        name="persist",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    await service.run_job(job.id)

    raw = json.loads(store_path.read_text())
    history = raw["jobs"][0]["state"]["runHistory"]
    assert len(history) == 1
    assert history[0]["status"] == "ok"
    assert "runAtMs" in history[0]
    assert "durationMs" in history[0]

    fresh = CronService(store_path)
    loaded = fresh.get_job(job.id)
    assert len(loaded.state.run_history) == 1
    assert loaded.state.run_history[0].status == "ok"


@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
    store_path = tmp_path / "cron" / "jobs.json"
@@ -1,5 +1,6 @@
from email.message import EmailMessage
from datetime import date
import imaplib

import pytest
@@ -82,6 +83,120 @@ def test_fetch_new_messages_parses_unseen_and_marks_seen(monkeypatch) -> None:
    assert items_again == []


def test_fetch_new_messages_retries_once_when_imap_connection_goes_stale(monkeypatch) -> None:
    raw = _make_raw_email(subject="Invoice", body="Please pay")
    fail_once = {"pending": True}

    class FlakyIMAP:
        def __init__(self) -> None:
            self.store_calls: list[tuple[bytes, str, str]] = []
            self.search_calls = 0

        def login(self, _user: str, _pw: str):
            return "OK", [b"logged in"]

        def select(self, _mailbox: str):
            return "OK", [b"1"]

        def search(self, *_args):
            self.search_calls += 1
            if fail_once["pending"]:
                fail_once["pending"] = False
                raise imaplib.IMAP4.abort("socket error")
            return "OK", [b"1"]

        def fetch(self, _imap_id: bytes, _parts: str):
            return "OK", [(b"1 (UID 123 BODY[] {200})", raw), b")"]

        def store(self, imap_id: bytes, op: str, flags: str):
            self.store_calls.append((imap_id, op, flags))
            return "OK", [b""]

        def logout(self):
            return "BYE", [b""]

    fake_instances: list[FlakyIMAP] = []

    def _factory(_host: str, _port: int):
        instance = FlakyIMAP()
        fake_instances.append(instance)
        return instance

    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", _factory)

    channel = EmailChannel(_make_config(), MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert len(fake_instances) == 2
    assert fake_instances[0].search_calls == 1
    assert fake_instances[1].search_calls == 1


def test_fetch_new_messages_keeps_messages_collected_before_stale_retry(monkeypatch) -> None:
    raw_first = _make_raw_email(subject="First", body="First body")
    raw_second = _make_raw_email(subject="Second", body="Second body")
    mailbox_state = {
        b"1": {"uid": b"123", "raw": raw_first, "seen": False},
        b"2": {"uid": b"124", "raw": raw_second, "seen": False},
    }
    fail_once = {"pending": True}

    class FlakyIMAP:
        def login(self, _user: str, _pw: str):
            return "OK", [b"logged in"]

        def select(self, _mailbox: str):
            return "OK", [b"2"]

        def search(self, *_args):
            unseen_ids = [imap_id for imap_id, item in mailbox_state.items() if not item["seen"]]
            return "OK", [b" ".join(unseen_ids)]

        def fetch(self, imap_id: bytes, _parts: str):
            if imap_id == b"2" and fail_once["pending"]:
                fail_once["pending"] = False
                raise imaplib.IMAP4.abort("socket error")
            item = mailbox_state[imap_id]
            header = b"%s (UID %s BODY[] {200})" % (imap_id, item["uid"])
            return "OK", [(header, item["raw"]), b")"]

        def store(self, imap_id: bytes, _op: str, _flags: str):
            mailbox_state[imap_id]["seen"] = True
            return "OK", [b""]

        def logout(self):
            return "BYE", [b""]

    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: FlakyIMAP())

    channel = EmailChannel(_make_config(), MessageBus())
    items = channel._fetch_new_messages()

    assert [item["subject"] for item in items] == ["First", "Second"]


def test_fetch_new_messages_skips_missing_mailbox(monkeypatch) -> None:
    class MissingMailboxIMAP:
        def login(self, _user: str, _pw: str):
            return "OK", [b"logged in"]

        def select(self, _mailbox: str):
            raise imaplib.IMAP4.error("Mailbox doesn't exist")

        def logout(self):
            return "BYE", [b""]

    monkeypatch.setattr(
        "nanobot.channels.email.imaplib.IMAP4_SSL",
        lambda _h, _p: MissingMailboxIMAP(),
    )

    channel = EmailChannel(_make_config(), MessageBus())

    assert channel._fetch_new_messages() == []


def test_extract_text_body_falls_back_to_html() -> None:
    msg = EmailMessage()
    msg["From"] = "alice@example.com"
@@ -58,6 +58,19 @@ class TestReadFileTool:
        result = await tool.execute(path=str(f))
        assert "Empty file" in result

    @pytest.mark.asyncio
    async def test_image_file_returns_multimodal_blocks(self, tool, tmp_path):
        f = tmp_path / "pixel.png"
        f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")

        result = await tool.execute(path=str(f))

        assert isinstance(result, list)
        assert result[0]["type"] == "image_url"
        assert result[0]["image_url"]["url"].startswith("data:image/png;base64,")
        assert result[0]["_meta"]["path"] == str(f)
        assert result[1] == {"type": "text", "text": f"(Image file: {f})"}

    @pytest.mark.asyncio
    async def test_file_not_found(self, tool, tmp_path):
        result = await tool.execute(path=str(tmp_path / "nope.txt"))
@@ -30,6 +30,69 @@ def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
    return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)


def test_wrapper_preserves_non_nullable_unions() -> None:
    tool_def = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema={
            "type": "object",
            "properties": {
                "value": {
                    "anyOf": [{"type": "string"}, {"type": "integer"}],
                }
            },
        },
    )

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)

    assert wrapper.parameters["properties"]["value"]["anyOf"] == [
        {"type": "string"},
        {"type": "integer"},
    ]


def test_wrapper_normalizes_nullable_property_type_union() -> None:
    tool_def = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema={
            "type": "object",
            "properties": {
                "name": {"type": ["string", "null"]},
            },
        },
    )

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)

    assert wrapper.parameters["properties"]["name"] == {"type": "string", "nullable": True}


def test_wrapper_normalizes_nullable_property_anyof() -> None:
    tool_def = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema={
            "type": "object",
            "properties": {
                "name": {
                    "anyOf": [{"type": "string"}, {"type": "null"}],
                    "description": "optional name",
                },
            },
        },
    )

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)

    assert wrapper.parameters["properties"]["name"] == {
        "type": "string",
        "description": "optional name",
        "nullable": True,
    }


@pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None:
    async def call_tool(_name: str, arguments: dict) -> object:
tests/test_onboard_logic.py (new file, 495 lines)
@@ -0,0 +1,495 @@
"""Unit tests for onboard core logic functions.

These tests focus on the business logic behind the onboard wizard,
without testing the interactive UI components.
"""

import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, cast

import pytest
from pydantic import BaseModel, Field

from nanobot.cli import onboard_wizard

# Import functions to test
from nanobot.cli.commands import _merge_missing_defaults
from nanobot.cli.onboard_wizard import (
    _BACK_PRESSED,
    _configure_pydantic_model,
    _format_value,
    _get_field_display_name,
    _get_field_type_info,
    run_onboard,
)
from nanobot.config.schema import Config
from nanobot.utils.helpers import sync_workspace_templates


class TestMergeMissingDefaults:
    """Tests for _merge_missing_defaults recursive config merging."""

    def test_adds_missing_top_level_keys(self):
        existing = {"a": 1}
        defaults = {"a": 1, "b": 2, "c": 3}

        result = _merge_missing_defaults(existing, defaults)

        assert result == {"a": 1, "b": 2, "c": 3}

    def test_preserves_existing_values(self):
        existing = {"a": "custom_value"}
        defaults = {"a": "default_value"}

        result = _merge_missing_defaults(existing, defaults)

        assert result == {"a": "custom_value"}

    def test_merges_nested_dicts_recursively(self):
        existing = {
            "level1": {
                "level2": {
                    "existing": "kept",
                }
            }
        }
        defaults = {
            "level1": {
                "level2": {
                    "existing": "replaced",
                    "added": "new",
                },
                "level2b": "also_new",
            }
        }

        result = _merge_missing_defaults(existing, defaults)

        assert result == {
            "level1": {
                "level2": {
                    "existing": "kept",
                    "added": "new",
                },
                "level2b": "also_new",
            }
        }

    def test_returns_existing_if_not_dict(self):
        assert _merge_missing_defaults("string", {"a": 1}) == "string"
        assert _merge_missing_defaults([1, 2, 3], {"a": 1}) == [1, 2, 3]
        assert _merge_missing_defaults(None, {"a": 1}) is None
        assert _merge_missing_defaults(42, {"a": 1}) == 42

    def test_returns_existing_if_defaults_not_dict(self):
        assert _merge_missing_defaults({"a": 1}, "string") == {"a": 1}
        assert _merge_missing_defaults({"a": 1}, None) == {"a": 1}

    def test_handles_empty_dicts(self):
        assert _merge_missing_defaults({}, {"a": 1}) == {"a": 1}
        assert _merge_missing_defaults({"a": 1}, {}) == {"a": 1}
        assert _merge_missing_defaults({}, {}) == {}

    def test_backfills_channel_config(self):
        """Real-world scenario: backfill missing channel fields."""
        existing_channel = {
            "enabled": False,
            "appId": "",
            "secret": "",
        }
        default_channel = {
            "enabled": False,
            "appId": "",
            "secret": "",
            "msgFormat": "plain",
            "allowFrom": [],
        }

        result = _merge_missing_defaults(existing_channel, default_channel)

        assert result["msgFormat"] == "plain"
        assert result["allowFrom"] == []


class TestGetFieldTypeInfo:
    """Tests for _get_field_type_info type extraction."""

    def test_extracts_str_type(self):
        class Model(BaseModel):
            field: str

        type_name, inner = _get_field_type_info(Model.model_fields["field"])
        assert type_name == "str"
        assert inner is None

    def test_extracts_int_type(self):
        class Model(BaseModel):
            count: int

        type_name, inner = _get_field_type_info(Model.model_fields["count"])
        assert type_name == "int"
        assert inner is None

    def test_extracts_bool_type(self):
        class Model(BaseModel):
            enabled: bool

        type_name, inner = _get_field_type_info(Model.model_fields["enabled"])
        assert type_name == "bool"
        assert inner is None

    def test_extracts_float_type(self):
        class Model(BaseModel):
            ratio: float

        type_name, inner = _get_field_type_info(Model.model_fields["ratio"])
        assert type_name == "float"
        assert inner is None

    def test_extracts_list_type_with_item_type(self):
        class Model(BaseModel):
            items: list[str]

        type_name, inner = _get_field_type_info(Model.model_fields["items"])
        assert type_name == "list"
        assert inner is str

    def test_extracts_list_type_without_item_type(self):
        # Plain list without type param falls back to str
        class Model(BaseModel):
            items: list  # type: ignore

        # Plain list annotation doesn't match list check, returns str
        type_name, inner = _get_field_type_info(Model.model_fields["items"])
        assert type_name == "str"  # Falls back to str for untyped list
        assert inner is None

    def test_extracts_dict_type(self):
        # Plain dict without type param falls back to str
        class Model(BaseModel):
            data: dict  # type: ignore

        # Plain dict annotation doesn't match dict check, returns str
        type_name, inner = _get_field_type_info(Model.model_fields["data"])
        assert type_name == "str"  # Falls back to str for untyped dict
        assert inner is None

    def test_extracts_optional_type(self):
        class Model(BaseModel):
            optional: str | None = None

        type_name, inner = _get_field_type_info(Model.model_fields["optional"])
        # Should unwrap Optional and get str
        assert type_name == "str"
        assert inner is None

    def test_extracts_nested_model_type(self):
        class Inner(BaseModel):
            x: int

        class Outer(BaseModel):
            nested: Inner

        type_name, inner = _get_field_type_info(Outer.model_fields["nested"])
        assert type_name == "model"
        assert inner is Inner

    def test_handles_none_annotation(self):
        """Field with None annotation defaults to str."""
        class Model(BaseModel):
            field: Any = None

        # Create a mock field_info with None annotation
        field_info = SimpleNamespace(annotation=None)
        type_name, inner = _get_field_type_info(field_info)
        assert type_name == "str"
        assert inner is None


class TestGetFieldDisplayName:
    """Tests for _get_field_display_name human-readable name generation."""

    def test_uses_description_if_present(self):
        class Model(BaseModel):
            api_key: str = Field(description="API Key for authentication")

        name = _get_field_display_name("api_key", Model.model_fields["api_key"])
        assert name == "API Key for authentication"

    def test_converts_snake_case_to_title(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("user_name", field_info)
        assert name == "User Name"

    def test_adds_url_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("api_url", field_info)
        # Title case: "Api Url"
        assert "Url" in name and "Api" in name

    def test_adds_path_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("file_path", field_info)
        assert "Path" in name and "File" in name

    def test_adds_id_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("user_id", field_info)
        # Title case: "User Id"
        assert "Id" in name and "User" in name

    def test_adds_key_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("api_key", field_info)
        assert "Key" in name and "Api" in name

    def test_adds_token_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("auth_token", field_info)
        assert "Token" in name and "Auth" in name

    def test_adds_seconds_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("timeout_s", field_info)
        # Contains "(Seconds)" with title case
        assert "(Seconds)" in name or "(seconds)" in name

    def test_adds_ms_suffix(self):
        field_info = SimpleNamespace(description=None)
        name = _get_field_display_name("delay_ms", field_info)
        # Contains "(Ms)" or "(ms)"
        assert "(Ms)" in name or "(ms)" in name


class TestFormatValue:
    """Tests for _format_value display formatting."""

    def test_formats_none_as_not_set(self):
        assert "not set" in _format_value(None)

    def test_formats_empty_string_as_not_set(self):
        assert "not set" in _format_value("")

    def test_formats_empty_dict_as_not_set(self):
        assert "not set" in _format_value({})

    def test_formats_empty_list_as_not_set(self):
        assert "not set" in _format_value([])

    def test_formats_string_value(self):
        result = _format_value("hello")
        assert "hello" in result

    def test_formats_list_value(self):
        result = _format_value(["a", "b"])
        assert "a" in result or "b" in result

    def test_formats_dict_value(self):
        result = _format_value({"key": "value"})
        assert "key" in result or "value" in result

    def test_formats_int_value(self):
        result = _format_value(42)
        assert "42" in result

    def test_formats_bool_true(self):
        result = _format_value(True)
        assert "true" in result.lower() or "✓" in result

    def test_formats_bool_false(self):
        result = _format_value(False)
        assert "false" in result.lower() or "✗" in result


class TestSyncWorkspaceTemplates:
    """Tests for sync_workspace_templates file synchronization."""

    def test_creates_missing_files(self, tmp_path):
        """Should create template files that don't exist."""
        workspace = tmp_path / "workspace"

        added = sync_workspace_templates(workspace, silent=True)

        # Check that some files were created
        assert isinstance(added, list)
        # The actual files depend on the templates directory

    def test_does_not_overwrite_existing_files(self, tmp_path):
        """Should not overwrite files that already exist."""
        workspace = tmp_path / "workspace"
        workspace.mkdir(parents=True)
        (workspace / "AGENTS.md").write_text("existing content")

        sync_workspace_templates(workspace, silent=True)

        # Existing file should not be changed
        content = (workspace / "AGENTS.md").read_text()
        assert content == "existing content"

    def test_creates_memory_directory(self, tmp_path):
        """Should create memory directory structure."""
        workspace = tmp_path / "workspace"

        sync_workspace_templates(workspace, silent=True)

        assert (workspace / "memory").exists() or (workspace / "skills").exists()

    def test_returns_list_of_added_files(self, tmp_path):
        """Should return list of relative paths for added files."""
        workspace = tmp_path / "workspace"

        added = sync_workspace_templates(workspace, silent=True)

        assert isinstance(added, list)
        # All paths should be relative to workspace
        for path in added:
            assert not Path(path).is_absolute()


class TestProviderChannelInfo:
    """Tests for provider and channel info retrieval."""

    def test_get_provider_names_returns_dict(self):
        from nanobot.cli.onboard_wizard import _get_provider_names

        names = _get_provider_names()
        assert isinstance(names, dict)
        assert len(names) > 0
        # Should include common providers
        assert "openai" in names or "anthropic" in names
        assert "openai_codex" not in names
        assert "github_copilot" not in names

    def test_get_channel_names_returns_dict(self):
        from nanobot.cli.onboard_wizard import _get_channel_names

        names = _get_channel_names()
        assert isinstance(names, dict)
        # Should include at least some channels
        assert len(names) >= 0

    def test_get_provider_info_returns_valid_structure(self):
        from nanobot.cli.onboard_wizard import _get_provider_info

        info = _get_provider_info()
        assert isinstance(info, dict)
        # Each value should be a tuple with expected structure
        for provider_name, value in info.items():
            assert isinstance(value, tuple)
            assert len(value) == 4  # (display_name, needs_api_key, needs_api_base, env_var)


class _SimpleDraftModel(BaseModel):
    api_key: str = ""


class _NestedDraftModel(BaseModel):
    api_key: str = ""


class _OuterDraftModel(BaseModel):
    nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)


class TestConfigurePydanticModelDrafts:
    @staticmethod
    def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
        sequence = iter(tokens)

        def fake_select(_prompt, choices, default=None):
            token = next(sequence)
            if token == "first":
                return choices[0]
            if token == "done":
                return "[Done]"
            if token == "back":
                return _BACK_PRESSED
            return token

        monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
        monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
        monkeypatch.setattr(
            onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
        )

    def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
        model = _SimpleDraftModel()
        self._patch_prompt_helpers(monkeypatch, ["first", "back"])

        result = _configure_pydantic_model(model, "Simple")

        assert result is None
        assert model.api_key == ""

    def test_completing_section_returns_updated_draft(self, monkeypatch):
        model = _SimpleDraftModel()
        self._patch_prompt_helpers(monkeypatch, ["first", "done"])

        result = _configure_pydantic_model(model, "Simple")

        assert result is not None
        updated = cast(_SimpleDraftModel, result)
        assert updated.api_key == "secret"
        assert model.api_key == ""

    def test_nested_section_back_discards_nested_edits(self, monkeypatch):
        model = _OuterDraftModel()
        self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])

        result = _configure_pydantic_model(model, "Outer")

        assert result is not None
        updated = cast(_OuterDraftModel, result)
        assert updated.nested.api_key == ""
        assert model.nested.api_key == ""

    def test_nested_section_done_commits_nested_edits(self, monkeypatch):
        model = _OuterDraftModel()
        self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])

        result = _configure_pydantic_model(model, "Outer")

        assert result is not None
        updated = cast(_OuterDraftModel, result)
        assert updated.nested.api_key == "secret"
        assert model.nested.api_key == ""


class TestRunOnboardExitBehavior:
    def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
        initial_config = Config()

        responses = iter(
            [
                "[A] Agent Settings",
                KeyboardInterrupt(),
                "[X] Exit Without Saving",
            ]
        )

        class FakePrompt:
            def __init__(self, response):
                self.response = response

            def ask(self):
                if isinstance(self.response, BaseException):
                    raise self.response
                return self.response

        def fake_select(*_args, **_kwargs):
            return FakePrompt(next(responses))

        def fake_configure_general_settings(config, section):
            if section == "Agent Settings":
                config.agents.defaults.model = "test/provider-model"

        monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
        monkeypatch.setattr(onboard_wizard, "questionary", SimpleNamespace(select=fake_select))
        monkeypatch.setattr(onboard_wizard, "_configure_general_settings", fake_configure_general_settings)

        result = run_onboard(initial_config=initial_config)

        assert result.should_save is False
        assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)
@@ -3,11 +3,13 @@
from __future__ import annotations

import asyncio
from unittest.mock import MagicMock, patch
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from nanobot.bus.events import InboundMessage
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.providers.base import LLMResponse


def _make_loop():
@@ -65,6 +67,44 @@ class TestRestartCommand:

        mock_handle.assert_called_once()

    @pytest.mark.asyncio
    async def test_status_intercepted_in_run_loop(self):
        """Verify /status is handled at the run-loop level for immediate replies."""
        loop, bus = _make_loop()
        msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")

        with patch.object(loop, "_status_response") as mock_status:
            mock_status.return_value = OutboundMessage(
                channel="telegram", chat_id="c1", content="status ok"
            )
            await bus.publish_inbound(msg)

            loop._running = True
            run_task = asyncio.create_task(loop.run())
            await asyncio.sleep(0.1)
            loop._running = False
            run_task.cancel()
            try:
                await run_task
            except asyncio.CancelledError:
                pass

            mock_status.assert_called_once()
            out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
            assert out.content == "status ok"

    @pytest.mark.asyncio
    async def test_run_propagates_external_cancellation(self):
        """External task cancellation should not be swallowed by the inbound wait loop."""
        loop, _bus = _make_loop()

        run_task = asyncio.create_task(loop.run())
        await asyncio.sleep(0.1)
        run_task.cancel()

        with pytest.raises(asyncio.CancelledError):
            await asyncio.wait_for(run_task, timeout=1.0)

    @pytest.mark.asyncio
    async def test_help_includes_restart(self):
        loop, bus = _make_loop()
@@ -74,3 +114,75 @@ class TestRestartCommand:

        assert response is not None
        assert "/restart" in response.content
        assert "/status" in response.content
        assert response.metadata == {"render_as": "text"}

    @pytest.mark.asyncio
    async def test_status_reports_runtime_info(self):
        loop, _bus = _make_loop()
        session = MagicMock()
        session.get_history.return_value = [{"role": "user"}] * 3
        loop.sessions.get_or_create.return_value = session
        loop._start_time = time.time() - 125
        loop._last_usage = {"prompt_tokens": 0, "completion_tokens": 0}
        loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
            return_value=(20500, "tiktoken")
        )

        msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")

        response = await loop._process_message(msg)

        assert response is not None
        assert "Model: test-model" in response.content
        assert "Tokens: 0 in / 0 out" in response.content
        assert "Context: 20k/64k (31%)" in response.content
        assert "Session: 3 messages" in response.content
        assert "Uptime: 2m 5s" in response.content
        assert response.metadata == {"render_as": "text"}

    @pytest.mark.asyncio
    async def test_run_agent_loop_resets_usage_when_provider_omits_it(self):
        loop, _bus = _make_loop()
        loop.provider.chat_with_retry = AsyncMock(side_effect=[
            LLMResponse(content="first", usage={"prompt_tokens": 9, "completion_tokens": 4}),
            LLMResponse(content="second", usage={}),
        ])

        await loop._run_agent_loop([])
        assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}

        await loop._run_agent_loop([])
        assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}

    @pytest.mark.asyncio
    async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
        loop, _bus = _make_loop()
        session = MagicMock()
        session.get_history.return_value = [{"role": "user"}]
        loop.sessions.get_or_create.return_value = session
        loop._last_usage = {"prompt_tokens": 1200, "completion_tokens": 34}
        loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock(
            return_value=(0, "none")
        )

        response = await loop._process_message(
            InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
        )

        assert response is not None
        assert "Tokens: 1200 in / 34 out" in response.content
        assert "Context: 1k/64k (1%)" in response.content

    @pytest.mark.asyncio
    async def test_process_direct_preserves_render_metadata(self):
        loop, _bus = _make_loop()
        session = MagicMock()
        session.get_history.return_value = []
        loop.sessions.get_or_create.return_value = session
        loop.subagents.get_running_count.return_value = 0

        response = await loop.process_direct("/status", session_key="cli:test")

        assert response is not None
        assert response.metadata == {"render_as": "text"}
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest


def _make_loop():
def _make_loop(*, exec_config=None):
    """Create a minimal AgentLoop with mocked dependencies."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus
@@ -23,7 +23,7 @@ def _make_loop():
         patch("nanobot.agent.loop.SessionManager"), \
         patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
        MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
        loop = AgentLoop(bus=bus, provider=provider, workspace=workspace)
        loop = AgentLoop(bus=bus, provider=provider, workspace=workspace, exec_config=exec_config)
    return loop, bus
@@ -90,6 +90,13 @@ class TestHandleStop:


class TestDispatch:
    def test_exec_tool_not_registered_when_disabled(self):
        from nanobot.config.schema import ExecToolConfig

        loop, _bus = _make_loop(exec_config=ExecToolConfig(enable=False))

        assert loop.tools.get("exec") is None

    @pytest.mark.asyncio
    async def test_dispatch_processes_and_publishes(self):
        from nanobot.bus.events import InboundMessage, OutboundMessage
@@ -39,7 +39,7 @@ class _FakeBot:
        self.get_me_calls += 1
        return SimpleNamespace(id=999, username="nanobot_test")

    async def set_my_commands(self, commands) -> None:
    async def set_my_commands(self, commands, language_code=None) -> None:
        self.commands = commands

    async def send_message(self, **kwargs) -> None:
@@ -175,6 +175,7 @@ async def test_start_creates_separate_pools_with_proxy(monkeypatch) -> None:
    assert poll_req.kwargs["connection_pool_size"] == 4
    assert builder.request_value is api_req
    assert builder.get_updates_request_value is poll_req
    assert any(cmd.command == "status" for cmd in app.bot.commands)


@pytest.mark.asyncio
@@ -775,3 +776,20 @@ async def test_forward_command_does_not_inject_reply_context() -> None:

    assert len(handled) == 1
    assert handled[0]["content"] == "/new"


@pytest.mark.asyncio
async def test_on_help_includes_restart_command() -> None:
    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"], group_policy="open"),
        MessageBus(),
    )
    update = _make_telegram_update(text="/help", chat_type="private")
    update.message.reply_text = AsyncMock()

    await channel._on_help(update, None)

    update.message.reply_text.assert_awaited_once()
    help_text = update.message.reply_text.await_args.args[0]
    assert "/restart" in help_text
    assert "/status" in help_text
@@ -453,6 +453,18 @@ def test_validate_nullable_param_accepts_none() -> None:
    assert errors == []


def test_validate_nullable_flag_accepts_none() -> None:
    """OpenAI-normalized nullable params should still accept None locally."""
    tool = CastTestTool(
        {
            "type": "object",
            "properties": {"name": {"type": "string", "nullable": True}},
        }
    )
    errors = tool.validate_params({"name": None})
    assert errors == []


def test_cast_nullable_param_no_crash() -> None:
    """cast_params should not crash on nullable type (the original bug)."""
    tool = CastTestTool(
@@ -67,3 +67,47 @@ async def test_web_fetch_result_contains_untrusted_flag():
    data = json.loads(result)
    assert data.get("untrusted") is True
    assert "[External content" in data.get("text", "")


@pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
    tool = WebFetchTool()

    class FakeStreamResponse:
        headers = {"content-type": "image/png"}
        url = "http://127.0.0.1/secret.png"
        content = b"\x89PNG\r\n\x1a\n"

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def aread(self):
            return self.content

        def raise_for_status(self):
            return None

    class FakeClient:
        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, headers=None):
            return FakeStreamResponse()

    monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", FakeClient)

    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
        result = await tool.execute(url="https://example.com/image.png")

    data = json.loads(result)
    assert "error" in data
    assert "redirect blocked" in data["error"].lower()