Merge remote-tracking branch 'origin/main'
Some checks failed
Test Suite / test (3.11) (push) Failing after 1m0s
Test Suite / test (3.12) (push) Failing after 19s
Test Suite / test (3.13) (push) Failing after 18s

# Conflicts:
#	nanobot/agent/tools/shell.py
#	tests/agent/test_evaluator.py
#	tests/channels/test_feishu_tool_hint_code_block.py
#	tests/providers/test_litellm_kwargs.py
#	tests/tools/test_web_search_tool.py
This commit is contained in:
Hua
2026-03-24 16:38:50 +08:00
49 changed files with 263 additions and 34 deletions

View File

@@ -0,0 +1,69 @@
"""Tests for exec tool internal URL blocking."""
from __future__ import annotations
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.shell import ExecTool
def _fake_resolve_private(hostname, port, family=0, type_=0, proto=0, flags=0):
    """Stand-in for ``socket.getaddrinfo`` that always answers 169.254.169.254.

    The arguments are accepted only so the stub can be called like the real
    resolver; ``proto`` and ``flags`` are included so a caller that passes the
    full ``getaddrinfo`` argument list does not hit a ``TypeError``.

    Returns:
        A single-entry address list in ``getaddrinfo`` result format.
    """
    return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
def _fake_resolve_localhost(hostname, port, family=0, type_=0, proto=0, flags=0):
    """Stand-in for ``socket.getaddrinfo`` that always answers 127.0.0.1.

    The arguments are accepted only so the stub can be called like the real
    resolver; ``proto`` and ``flags`` are included so a caller that passes the
    full ``getaddrinfo`` argument list does not hit a ``TypeError``.

    Returns:
        A single-entry address list in ``getaddrinfo`` result format.
    """
    return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
def _fake_resolve_public(hostname, port, family=0, type_=0, proto=0, flags=0):
    """Stand-in for ``socket.getaddrinfo`` that always answers 93.184.216.34.

    The arguments are accepted only so the stub can be called like the real
    resolver; ``proto`` and ``flags`` are included so a caller that passes the
    full ``getaddrinfo`` argument list does not hit a ``TypeError``.

    Returns:
        A single-entry address list in ``getaddrinfo`` result format.
    """
    return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
@pytest.mark.asyncio
async def test_exec_blocks_curl_metadata():
    """A curl request aimed at 169.254.169.254 must come back as an error."""
    metadata_cmd = 'curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
    exec_tool = ExecTool()
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
        output = await exec_tool.execute(command=metadata_cmd)
    assert "Error" in output
    # The message has to mention why the command was refused.
    assert "internal" in output.lower() or "private" in output.lower()
@pytest.mark.asyncio
async def test_exec_blocks_wget_localhost():
    """A wget call against a host resolving to 127.0.0.1 must come back as an error."""
    wget_cmd = "wget http://localhost:8080/secret -O /tmp/out"
    exec_tool = ExecTool()
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
        output = await exec_tool.execute(command=wget_cmd)
    assert "Error" in output
@pytest.mark.asyncio
async def test_exec_allows_normal_commands():
    """A plain ``echo`` runs and its first output line carries no error marker."""
    output = await ExecTool(timeout=5).execute(command="echo hello")
    assert "hello" in output
    first_line = output.split("\n")[0]
    assert "Error" not in first_line
@pytest.mark.asyncio
async def test_exec_allows_curl_to_public_url():
    """Commands with public URLs should not be blocked by the internal URL check."""
    exec_tool = ExecTool()
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
        verdict = exec_tool._guard_command("curl https://example.com/api", "/tmp")
    # ``None`` from the guard means the command was not rejected.
    assert verdict is None
@pytest.mark.asyncio
async def test_exec_blocks_chained_internal_url():
    """Internal URLs buried in chained commands should still be caught."""
    chained_cmd = "echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
    exec_tool = ExecTool()
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
        output = await exec_tool.execute(command=chained_cmd)
    assert "Error" in output

View File

@@ -0,0 +1,392 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
    """Tests for ``ReadFileTool``: line numbering, paging, images and error paths."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """ReadFileTool whose workspace is the per-test temporary directory."""
        reader = ReadFileTool(workspace=tmp_path)
        return reader

    @pytest.fixture()
    def sample_file(self, tmp_path):
        """Text file holding ``line 1`` through ``line 20``, newline separated."""
        body = "\n".join([f"line {n}" for n in range(1, 21)])
        target = tmp_path / "sample.txt"
        target.write_text(body, encoding="utf-8")
        return target

    @pytest.mark.asyncio
    async def test_basic_read_has_line_numbers(self, tool, sample_file):
        """A default read prefixes each line with its number."""
        output = await tool.execute(path=str(sample_file))
        for expected in ("1| line 1", "20| line 20"):
            assert expected in output

    @pytest.mark.asyncio
    async def test_offset_and_limit(self, tool, sample_file):
        """``offset``/``limit`` select a window and the footer names the next offset."""
        output = await tool.execute(path=str(sample_file), offset=5, limit=3)
        assert "5| line 5" in output
        assert "7| line 7" in output
        assert "8| line 8" not in output
        assert "Use offset=8 to continue" in output

    @pytest.mark.asyncio
    async def test_offset_beyond_end(self, tool, sample_file):
        """An offset past the last line is reported as an error."""
        output = await tool.execute(path=str(sample_file), offset=999)
        assert "Error" in output
        assert "beyond end" in output

    @pytest.mark.asyncio
    async def test_end_of_file_marker(self, tool, sample_file):
        """Reading through the last line yields the end-of-file marker."""
        output = await tool.execute(path=str(sample_file), offset=1, limit=9999)
        assert "End of file" in output

    @pytest.mark.asyncio
    async def test_empty_file(self, tool, tmp_path):
        """A zero-length file is reported as empty."""
        blank = tmp_path / "empty.txt"
        blank.write_text("", encoding="utf-8")
        output = await tool.execute(path=str(blank))
        assert "Empty file" in output

    @pytest.mark.asyncio
    async def test_image_file_returns_multimodal_blocks(self, tool, tmp_path):
        """A PNG is returned as an image block followed by a text block."""
        png = tmp_path / "pixel.png"
        png.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
        blocks = await tool.execute(path=str(png))
        assert isinstance(blocks, list)
        image_block = blocks[0]
        assert image_block["type"] == "image_url"
        assert image_block["image_url"]["url"].startswith("data:image/png;base64,")
        assert image_block["_meta"]["path"] == str(png)
        assert blocks[1] == {"type": "text", "text": f"(Image file: {png})"}

    @pytest.mark.asyncio
    async def test_file_not_found(self, tool, tmp_path):
        """A missing file produces a "not found" error."""
        missing = tmp_path / "nope.txt"
        output = await tool.execute(path=str(missing))
        assert "Error" in output
        assert "not found" in output

    @pytest.mark.asyncio
    async def test_missing_path_returns_clear_error(self, tool):
        """Calling without ``path`` returns a fixed error string."""
        output = await tool.execute()
        assert output == "Error reading file: Unknown path"

    @pytest.mark.asyncio
    async def test_char_budget_trims(self, tool, tmp_path):
        """When the selected slice exceeds _MAX_CHARS the output is trimmed."""
        large = tmp_path / "big.txt"
        # Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
        large.write_text("\n".join(["x" * 110] * 2000), encoding="utf-8")
        output = await tool.execute(path=str(large))
        ceiling = ReadFileTool._MAX_CHARS + 500  # small margin for footer
        assert len(output) <= ceiling
        assert "Use offset=" in output
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
    """Unit tests for the ``_find_match`` helper, which returns ``(match, count)``."""

    def test_exact_match(self):
        """A substring that occurs once is returned verbatim with count 1."""
        found, hits = _find_match("hello world", "world")
        assert (found, hits) == ("world", 1)

    def test_exact_no_match(self):
        """An absent substring yields ``None`` and count 0."""
        found, hits = _find_match("hello world", "xyz")
        assert found is None
        assert hits == 0

    def test_crlf_normalisation(self):
        """Pre-normalised multi-line content still matches exactly."""
        # Caller normalises CRLF before calling _find_match, so test with
        # pre-normalised content to verify exact match still works.
        text = "line1\nline2\nline3"
        found, hits = _find_match(text, "line1\nline2\nline3")
        assert found is not None
        assert hits == 1

    def test_line_trim_fallback(self):
        """Indentation differences are tolerated and the original text is returned."""
        indented = "    def foo():\n        pass\n"
        unindented = "def foo():\n    pass"
        found, hits = _find_match(indented, unindented)
        assert found is not None
        assert hits == 1
        # The returned match should be the *original* indented text
        assert "    def foo():" in found

    def test_line_trim_multiple_candidates(self):
        """Two candidate regions are both counted."""
        _found, hits = _find_match("  a\n  b\n  a\n  b\n", "a\nb")
        assert hits == 2

    def test_empty_old_text(self):
        """An empty search string matches as the empty string."""
        found, _hits = _find_match("hello", "")
        # Empty string is always "in" any string via exact match
        assert found == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
    """Tests for ``EditFileTool`` replacements, fallbacks and error strings."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """EditFileTool whose workspace is the per-test temporary directory."""
        editor = EditFileTool(workspace=tmp_path)
        return editor

    @pytest.mark.asyncio
    async def test_exact_match(self, tool, tmp_path):
        """A unique exact match is replaced in the file."""
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="world", new_text="earth")
        assert "Successfully" in outcome
        assert target.read_text() == "hello earth"

    @pytest.mark.asyncio
    async def test_crlf_normalisation(self, tool, tmp_path):
        """LF-only ``old_text`` matches a CRLF file and CRLF survives the edit."""
        target = tmp_path / "crlf.py"
        target.write_bytes(b"line1\r\nline2\r\nline3")
        outcome = await tool.execute(
            path=str(target),
            old_text="line1\nline2",
            new_text="LINE1\nLINE2",
        )
        assert "Successfully" in outcome
        on_disk = target.read_bytes()
        assert b"LINE1" in on_disk
        # CRLF line endings should be preserved throughout the file
        assert b"\r\n" in on_disk

    @pytest.mark.asyncio
    async def test_trim_fallback(self, tool, tmp_path):
        """``old_text`` with different indentation still gets replaced."""
        target = tmp_path / "indent.py"
        target.write_text("    def foo():\n        pass\n", encoding="utf-8")
        outcome = await tool.execute(
            path=str(target),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
        )
        assert "Successfully" in outcome
        assert "bar" in target.read_text()

    @pytest.mark.asyncio
    async def test_ambiguous_match(self, tool, tmp_path):
        """A repeated ``old_text`` is reported rather than silently edited."""
        target = tmp_path / "dup.py"
        target.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="aaa\nbbb", new_text="xxx")
        assert "appears" in outcome.lower() or "Warning" in outcome

    @pytest.mark.asyncio
    async def test_replace_all(self, tool, tmp_path):
        """``replace_all=True`` rewrites every occurrence."""
        target = tmp_path / "multi.py"
        target.write_text("foo bar foo bar foo", encoding="utf-8")
        outcome = await tool.execute(
            path=str(target),
            old_text="foo",
            new_text="baz",
            replace_all=True,
        )
        assert "Successfully" in outcome
        assert target.read_text() == "baz bar baz bar baz"

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        """An ``old_text`` that is absent produces a "not found" error."""
        target = tmp_path / "nf.py"
        target.write_text("hello", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="xyz", new_text="abc")
        assert "Error" in outcome
        assert "not found" in outcome

    @pytest.mark.asyncio
    async def test_missing_new_text_returns_clear_error(self, tool, tmp_path):
        """Omitting ``new_text`` returns a fixed error string."""
        target = tmp_path / "a.py"
        target.write_text("hello", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="hello")
        assert outcome == "Error editing file: Unknown new_text"
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
    """Tests for ``ListDirTool`` listing, recursion, truncation and error strings."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """ListDirTool whose workspace is the per-test temporary directory."""
        lister = ListDirTool(workspace=tmp_path)
        return lister

    @pytest.fixture()
    def populated_dir(self, tmp_path):
        """Directory with source files, a README, and ``.git`` / ``node_modules`` trees."""
        src = tmp_path / "src"
        src.mkdir()
        for module in ("main.py", "utils.py"):
            (src / module).write_text("pass")
        (tmp_path / "README.md").write_text("hi")
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("x")
        modules_dir = tmp_path / "node_modules"
        modules_dir.mkdir()
        (modules_dir / "pkg").mkdir()
        return tmp_path

    @pytest.mark.asyncio
    async def test_basic_list(self, tool, populated_dir):
        """A flat listing shows regular entries and omits the ignored directories."""
        listing = await tool.execute(path=str(populated_dir))
        for shown in ("README.md", "src"):
            assert shown in listing
        # .git and node_modules should be ignored
        for hidden in (".git", "node_modules"):
            assert hidden not in listing

    @pytest.mark.asyncio
    async def test_recursive(self, tool, populated_dir):
        """A recursive listing includes nested files and omits the ignored directories."""
        listing = await tool.execute(path=str(populated_dir), recursive=True)
        for shown in ("src/main.py", "src/utils.py", "README.md"):
            assert shown in listing
        # Ignored dirs should not appear
        for hidden in (".git", "node_modules"):
            assert hidden not in listing

    @pytest.mark.asyncio
    async def test_max_entries_truncation(self, tool, tmp_path):
        """``max_entries`` caps the listing and the output says how many were shown."""
        for index in range(10):
            entry = tmp_path / f"file_{index}.txt"
            entry.write_text("x")
        listing = await tool.execute(path=str(tmp_path), max_entries=3)
        assert "truncated" in listing
        assert "3 of 10" in listing

    @pytest.mark.asyncio
    async def test_empty_dir(self, tool, tmp_path):
        """An empty directory is reported as empty."""
        hollow = tmp_path / "empty"
        hollow.mkdir()
        listing = await tool.execute(path=str(hollow))
        assert "empty" in listing.lower()

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        """A missing directory produces a "not found" error."""
        missing = tmp_path / "nope"
        listing = await tool.execute(path=str(missing))
        assert "Error" in listing
        assert "not found" in listing

    @pytest.mark.asyncio
    async def test_missing_path_returns_clear_error(self, tool):
        """Calling without ``path`` returns a fixed error string."""
        listing = await tool.execute()
        assert listing == "Error listing directory: Unknown path"
# ---------------------------------------------------------------------------
# Workspace restriction + extra_allowed_dirs
# ---------------------------------------------------------------------------
class TestWorkspaceRestriction:
    """Tests for ``allowed_dir`` confinement and the ``extra_allowed_dirs`` option."""

    @pytest.mark.asyncio
    async def test_read_blocked_outside_workspace(self, tmp_path):
        """Reading a file outside ``allowed_dir`` is refused."""
        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        foreign = tmp_path / "outside"
        foreign.mkdir()
        secret_file = foreign / "secret.txt"
        secret_file.write_text("top secret")

        reader = ReadFileTool(workspace=ws_root, allowed_dir=ws_root)
        output = await reader.execute(path=str(secret_file))
        assert "Error" in output
        assert "outside" in output.lower()

    @pytest.mark.asyncio
    async def test_read_allowed_with_extra_dir(self, tmp_path):
        """A file under a directory listed in ``extra_allowed_dirs`` can be read."""
        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        skills_root = tmp_path / "skills"
        skills_root.mkdir()
        skill_doc = skills_root / "test_skill" / "SKILL.md"
        skill_doc.parent.mkdir()
        skill_doc.write_text("# Test Skill\nDo something.")

        reader = ReadFileTool(
            workspace=ws_root,
            allowed_dir=ws_root,
            extra_allowed_dirs=[skills_root],
        )
        output = await reader.execute(path=str(skill_doc))
        assert "Test Skill" in output
        assert "Error" not in output

    @pytest.mark.asyncio
    async def test_extra_dirs_does_not_widen_write(self, tmp_path):
        """Writing outside ``allowed_dir`` is refused."""
        from nanobot.agent.tools.filesystem import WriteFileTool

        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        foreign = tmp_path / "outside"
        foreign.mkdir()

        writer = WriteFileTool(workspace=ws_root, allowed_dir=ws_root)
        destination = foreign / "hack.txt"
        output = await writer.execute(path=str(destination), content="pwned")
        assert "Error" in output
        assert "outside" in output.lower()

    @pytest.mark.asyncio
    async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
        """A directory that is neither workspace nor extra-allowed stays unreadable."""
        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        skills_root = tmp_path / "skills"
        skills_root.mkdir()
        other_root = tmp_path / "other"
        other_root.mkdir()
        secret_file = other_root / "secret.txt"
        secret_file.write_text("nope")

        reader = ReadFileTool(
            workspace=ws_root,
            allowed_dir=ws_root,
            extra_allowed_dirs=[skills_root],
        )
        output = await reader.execute(path=str(secret_file))
        assert "Error" in output
        assert "outside" in output.lower()

    @pytest.mark.asyncio
    async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
        """Adding extra_allowed_dirs must not break normal workspace reads."""
        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        readme = ws_root / "README.md"
        readme.write_text("hello from workspace")
        skills_root = tmp_path / "skills"
        skills_root.mkdir()

        reader = ReadFileTool(
            workspace=ws_root,
            allowed_dir=ws_root,
            extra_allowed_dirs=[skills_root],
        )
        output = await reader.execute(path=str(readme))
        assert "hello from workspace" in output
        assert "Error" not in output

    @pytest.mark.asyncio
    async def test_edit_blocked_in_extra_dir(self, tmp_path):
        """edit_file must not be able to modify files in extra_allowed_dirs."""
        original_text = "# Weather\nOriginal content."
        ws_root = tmp_path / "ws"
        ws_root.mkdir()
        skills_root = tmp_path / "skills"
        skills_root.mkdir()
        skill_doc = skills_root / "weather" / "SKILL.md"
        skill_doc.parent.mkdir()
        skill_doc.write_text(original_text)

        editor = EditFileTool(workspace=ws_root, allowed_dir=ws_root)
        output = await editor.execute(
            path=str(skill_doc),
            old_text="Original content.",
            new_text="Hacked content.",
        )
        assert "Error" in output
        assert "outside" in output.lower()
        # The file on disk must be untouched after the refused edit.
        assert skill_doc.read_text() == original_text

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
import asyncio
import sys
from types import ModuleType, SimpleNamespace
import pytest
from nanobot.agent.tools.mcp import MCPToolWrapper
class _FakeTextContent:
    """Minimal text-content object: stores the given string on ``.text``."""

    def __init__(self, text: str) -> None:
        # Keep the value exactly as passed; nothing else is recorded.
        self.text = text
@pytest.fixture(autouse=True)
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
    """Autouse fixture that registers a stub ``mcp`` module in ``sys.modules``.

    The stub only exposes ``types.TextContent``, pointing at ``_FakeTextContent``.
    """
    stub = ModuleType("mcp")
    stub.types = SimpleNamespace(TextContent=_FakeTextContent)
    # monkeypatch restores the previous ``sys.modules`` entry after the test.
    monkeypatch.setitem(sys.modules, "mcp", stub)
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
    """Build an ``MCPToolWrapper`` around ``session`` for a parameterless "demo" tool.

    Args:
        session: Object passed to the wrapper as its session.
        timeout: Value forwarded as ``tool_timeout``.
    """
    empty_schema = {"type": "object", "properties": {}}
    definition = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema=empty_schema,
    )
    return MCPToolWrapper(session, "test", definition, tool_timeout=timeout)
def test_wrapper_preserves_non_nullable_unions() -> None:
    """An ``anyOf`` without a null branch is left as it was given."""
    schema = {
        "type": "object",
        "properties": {
            "value": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
    }
    definition = SimpleNamespace(name="demo", description="demo tool", inputSchema=schema)

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", definition)

    branches = wrapper.parameters["properties"]["value"]["anyOf"]
    assert branches == [
        {"type": "string"},
        {"type": "integer"},
    ]
def test_wrapper_normalizes_nullable_property_type_union() -> None:
    """``"type": ["string", "null"]`` becomes a string type flagged ``nullable``."""
    schema = {
        "type": "object",
        "properties": {"name": {"type": ["string", "null"]}},
    }
    definition = SimpleNamespace(name="demo", description="demo tool", inputSchema=schema)

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", definition)

    normalized = wrapper.parameters["properties"]["name"]
    assert normalized == {"type": "string", "nullable": True}
def test_wrapper_normalizes_nullable_property_anyof() -> None:
    """``anyOf`` of a type plus null collapses to that type with ``nullable`` set."""
    name_property = {
        "anyOf": [{"type": "string"}, {"type": "null"}],
        "description": "optional name",
    }
    schema = {"type": "object", "properties": {"name": name_property}}
    definition = SimpleNamespace(name="demo", description="demo tool", inputSchema=schema)

    wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", definition)

    normalized = wrapper.parameters["properties"]["name"]
    assert normalized == {
        "type": "string",
        "description": "optional name",
        "nullable": True,
    }
@pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None:
    """Text content and other items are joined into one newline-separated string."""

    async def fake_call(_name: str, arguments: dict) -> object:
        # The wrapper must forward the keyword arguments it was executed with.
        assert arguments == {"value": 1}
        return SimpleNamespace(content=[_FakeTextContent("hello"), 42])

    session = SimpleNamespace(call_tool=fake_call)
    output = await _make_wrapper(session).execute(value=1)

    assert output == "hello\n42"
@pytest.mark.asyncio
async def test_execute_returns_timeout_message() -> None:
    """A call that outlasts ``tool_timeout`` yields the timeout message."""

    async def slow_call(_name: str, arguments: dict) -> object:
        # Sleeps far longer than the 0.01s timeout configured below.
        await asyncio.sleep(1)
        return SimpleNamespace(content=[])

    session = SimpleNamespace(call_tool=slow_call)
    output = await _make_wrapper(session, timeout=0.01).execute()

    assert output == "(MCP tool call timed out after 0.01s)"
@pytest.mark.asyncio
async def test_execute_handles_server_cancelled_error() -> None:
    """A ``CancelledError`` raised inside the call is turned into a message."""

    async def cancelling_call(_name: str, arguments: dict) -> object:
        raise asyncio.CancelledError()

    session = SimpleNamespace(call_tool=cancelling_call)
    output = await _make_wrapper(session).execute()

    assert output == "(MCP tool call was cancelled)"
@pytest.mark.asyncio
async def test_execute_re_raises_external_cancellation() -> None:
    """Cancelling the task that runs ``execute`` propagates ``CancelledError``."""
    entered = asyncio.Event()

    async def blocking_call(_name: str, arguments: dict) -> object:
        # Signal that the call is in flight, then block until cancelled.
        entered.set()
        await asyncio.sleep(60)
        return SimpleNamespace(content=[])

    session = SimpleNamespace(call_tool=blocking_call)
    wrapper = _make_wrapper(session, timeout=10)
    running = asyncio.create_task(wrapper.execute())
    await entered.wait()

    running.cancel()

    with pytest.raises(asyncio.CancelledError):
        await running
@pytest.mark.asyncio
async def test_execute_handles_generic_exception() -> None:
    """Any other exception is reported by its class name in the result string."""

    async def failing_call(_name: str, arguments: dict) -> object:
        raise RuntimeError("boom")

    session = SimpleNamespace(call_tool=failing_call)
    output = await _make_wrapper(session).execute()

    assert output == "(MCP tool call failed: RuntimeError)"

View File

@@ -0,0 +1,10 @@
import pytest
from nanobot.agent.tools.message import MessageTool
@pytest.mark.asyncio
async def test_message_tool_returns_error_when_no_target_context() -> None:
    """Executing a fresh ``MessageTool`` with only ``content`` returns the no-target error."""
    outcome = await MessageTool().execute(content="test")
    assert outcome == "Error: No target channel/chat specified"

View File

@@ -0,0 +1,132 @@
"""Test message tool suppress logic for final replies."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
def _make_loop(tmp_path: Path) -> AgentLoop:
    """Create an ``AgentLoop`` on a fresh bus with a mocked provider.

    The provider mock reports "test-model" as its default model and the loop
    uses ``tmp_path`` as its workspace.
    """
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"
    return AgentLoop(
        bus=MessageBus(),
        provider=fake_provider,
        workspace=tmp_path,
        model="test-model",
    )
class TestMessageToolSuppressLogic:
    """Final reply suppressed only when message tool sends to the same target."""

    @pytest.mark.asyncio
    async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
        """The final reply is dropped when the message tool already wrote to the inbound chat."""
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
        )
        # First provider turn requests the tool call, second turn ends the loop.
        calls = iter([
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        # Fail here, with a clear reason, instead of silently skipping the
        # callback wiring and tripping over the ``len(sent)`` assertion below.
        assert isinstance(mt, MessageTool)
        mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
        result = await loop._process_message(msg)

        assert len(sent) == 1
        assert result is None  # suppressed

    @pytest.mark.asyncio
    async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
        """The final reply is kept when the message tool wrote to another channel."""
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
        )
        calls = iter([
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="I've sent the email.", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        # Same explicit check as above: the test is meaningless without the tool.
        assert isinstance(mt, MessageTool)
        mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
        result = await loop._process_message(msg)

        assert len(sent) == 1
        assert sent[0].channel == "email"
        assert result is not None  # not suppressed
        assert result.channel == "feishu"

    @pytest.mark.asyncio
    async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
        """A plain reply with no tool calls is returned to the caller."""
        loop = _make_loop(tmp_path)
        loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
        loop.tools.get_definitions = MagicMock(return_value=[])

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
        result = await loop._process_message(msg)

        assert result is not None
        assert "Hello" in result.content

    # Marked like the other coroutine tests in this class so it does not
    # depend on pytest-asyncio's auto mode being enabled.
    @pytest.mark.asyncio
    async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
        """Progress callbacks receive visible text and tool hints, not reasoning content."""
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
        calls = iter([
            LLMResponse(
                content="Visible<think>hidden</think>",
                tool_calls=[tool_call],
                reasoning_content="secret reasoning",
                thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
            ),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])
        loop.tools.execute = AsyncMock(return_value="ok")

        progress: list[tuple[str, bool]] = []

        async def on_progress(content: str, *, tool_hint: bool = False) -> None:
            progress.append((content, tool_hint))

        final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)

        assert final_content == "Done"
        assert progress == [
            ("Visible", False),
            ('read_file("foo.txt")', True),
        ]
class TestMessageToolTurnTracking:
    """Checks on the ``_sent_in_turn`` flag of ``MessageTool``."""

    def test_sent_in_turn_tracks_same_target(self) -> None:
        """The flag is falsy after ``set_context`` and truthy once set."""
        message_tool = MessageTool()
        message_tool.set_context("feishu", "chat1")
        assert not message_tool._sent_in_turn
        message_tool._sent_in_turn = True
        assert message_tool._sent_in_turn

    def test_start_turn_resets(self) -> None:
        """``start_turn`` clears a previously set flag."""
        message_tool = MessageTool()
        message_tool._sent_in_turn = True
        message_tool.start_turn()
        assert not message_tool._sent_in_turn

View File

@@ -0,0 +1,479 @@
from typing import Any
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
class SampleTool(Tool):
    """Tool stub with a nested parameter schema, used by the validation tests."""

    @property
    def name(self) -> str:
        return "sample"

    @property
    def description(self) -> str:
        return "sample tool"

    @property
    def parameters(self) -> dict[str, Any]:
        """Schema with string/integer/enum properties plus a nested ``meta`` object."""
        meta_schema: dict[str, Any] = {
            "type": "object",
            "properties": {
                "tag": {"type": "string"},
                "flags": {"type": "array", "items": {"type": "string"}},
            },
            "required": ["tag"],
        }
        properties: dict[str, Any] = {
            "query": {"type": "string", "minLength": 2},
            "count": {"type": "integer", "minimum": 1, "maximum": 10},
            "mode": {"type": "string", "enum": ["fast", "full"]},
            "meta": meta_schema,
        }
        return {
            "type": "object",
            "properties": properties,
            "required": ["query", "count"],
        }

    async def execute(self, **kwargs: Any) -> str:
        return "ok"
def test_validate_params_missing_required() -> None:
    """Omitting the required ``count`` is reported."""
    problems = SampleTool().validate_params({"query": "hi"})
    combined = "; ".join(problems)
    assert "missing required count" in combined
def test_validate_params_type_and_range() -> None:
    """Out-of-range and wrongly typed ``count`` values are each reported."""
    sample = SampleTool()

    below_minimum = sample.validate_params({"query": "hi", "count": 0})
    assert any("count must be >= 1" in message for message in below_minimum)

    wrong_type = sample.validate_params({"query": "hi", "count": "2"})
    assert any("count should be integer" in message for message in wrong_type)
def test_validate_params_enum_and_min_length() -> None:
    """A too-short ``query`` and an unknown ``mode`` are both reported."""
    payload = {"query": "h", "count": 2, "mode": "slow"}
    problems = SampleTool().validate_params(payload)
    for fragment in ("query must be at least 2 chars", "mode must be one of"):
        assert any(fragment in message for message in problems)
def test_validate_params_nested_object_and_array() -> None:
    """Errors inside nested objects and arrays use dotted / indexed paths."""
    payload = {
        "query": "hi",
        "count": 2,
        "meta": {"flags": [1, "ok"]},
    }
    problems = SampleTool().validate_params(payload)
    for fragment in ("missing required meta.tag", "meta.flags[0] should be string"):
        assert any(fragment in message for message in problems)
def test_validate_params_ignores_unknown_fields() -> None:
    """A key that is not in the schema does not produce an error."""
    payload = {"query": "hi", "count": 2, "extra": "x"}
    problems = SampleTool().validate_params(payload)
    assert problems == []
async def test_registry_returns_validation_error() -> None:
    """Executing through the registry with bad params returns a validation message."""
    registry = ToolRegistry()
    registry.register(SampleTool())
    outcome = await registry.execute("sample", {"query": "hi"})
    assert "Invalid parameters" in outcome
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
    """A drive-letter path is extracted whole."""
    extracted = ExecTool._extract_absolute_paths(r"type C:\user\workspace\txt")
    assert extracted == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
    """The tail of a relative path is not mistaken for an absolute path."""
    extracted = ExecTool._extract_absolute_paths(".venv/bin/python script.py")
    assert "/bin/python" not in extracted
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
    """Both absolute POSIX paths in a redirect command are extracted."""
    extracted = ExecTool._extract_absolute_paths("cat /tmp/data.txt > /tmp/out.txt")
    for expected in ("/tmp/data.txt", "/tmp/out.txt"):
        assert expected in extracted
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
    """Paths starting with ``~/`` are extracted."""
    extracted = ExecTool._extract_absolute_paths("cat ~/.nanobot/config.json > ~/out.txt")
    for expected in ("~/.nanobot/config.json", "~/out.txt"):
        assert expected in extracted
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
    """Double-quoted absolute and home paths are extracted without the quotes."""
    extracted = ExecTool._extract_absolute_paths('cat "/tmp/data.txt" "~/.nanobot/config.json"')
    for expected in ("/tmp/data.txt", "~/.nanobot/config.json"):
        assert expected in extracted
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
    """With workspace restriction on, a ``~/`` path outside the working dir is blocked."""
    restricted = ExecTool(restrict_to_workspace=True)
    verdict = restricted._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
    assert verdict == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
    """Quoting the ``~/`` path does not get it past the workspace guard."""
    restricted = ExecTool(restrict_to_workspace=True)
    verdict = restricted._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
    assert verdict == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---
class CastTestTool(Tool):
    """Minimal tool for testing cast_params.

    The schema passed to the constructor is returned unchanged by ``parameters``.
    """

    def __init__(self, schema: dict[str, Any]) -> None:
        self._schema = schema

    @property
    def name(self) -> str:
        return "cast_test"

    @property
    def description(self) -> str:
        return "test tool for casting"

    @property
    def parameters(self) -> dict[str, Any]:
        # Hand back the very object supplied at construction time.
        return self._schema

    async def execute(self, **kwargs: Any) -> str:
        return "ok"
def test_cast_params_string_to_int() -> None:
    """A numeric string given for an integer property is converted to ``int``."""
    schema = {"type": "object", "properties": {"count": {"type": "integer"}}}
    casted = CastTestTool(schema).cast_params({"count": "42"})
    assert casted["count"] == 42
    assert isinstance(casted["count"], int)
def test_cast_params_string_to_number() -> None:
    """A decimal string given for a number property is converted to ``float``."""
    schema = {"type": "object", "properties": {"rate": {"type": "number"}}}
    casted = CastTestTool(schema).cast_params({"rate": "3.14"})
    assert casted["rate"] == 3.14
    assert isinstance(casted["rate"], float)
def test_cast_params_string_to_bool() -> None:
    """"true", "false" and "1" are converted to the matching booleans."""
    schema = {"type": "object", "properties": {"enabled": {"type": "boolean"}}}
    caster = CastTestTool(schema)
    for raw, expected in (("true", True), ("false", False), ("1", True)):
        assert caster.cast_params({"enabled": raw})["enabled"] is expected
def test_cast_params_array_items() -> None:
    """Each element of an array is cast according to the ``items`` schema."""
    nums_schema = {"type": "array", "items": {"type": "integer"}}
    schema = {"type": "object", "properties": {"nums": nums_schema}}
    casted = CastTestTool(schema).cast_params({"nums": ["1", "2", "3"]})
    assert casted["nums"] == [1, 2, 3]
def test_cast_params_nested_object() -> None:
    """Properties of a nested object are cast by the nested schema."""
    config_schema = {
        "type": "object",
        "properties": {
            "port": {"type": "integer"},
            "debug": {"type": "boolean"},
        },
    }
    schema = {"type": "object", "properties": {"config": config_schema}}
    casted = CastTestTool(schema).cast_params({"config": {"port": "8080", "debug": "true"}})
    assert casted["config"]["port"] == 8080
    assert casted["config"]["debug"] is True
def test_cast_params_bool_not_cast_to_int() -> None:
    """Booleans should not be silently cast to integers."""
    schema = {"type": "object", "properties": {"count": {"type": "integer"}}}
    caster = CastTestTool(schema)
    casted = caster.cast_params({"count": True})
    assert casted["count"] is True
    # Validation must then reject the untouched boolean.
    problems = caster.validate_params(casted)
    assert any("count should be integer" in message for message in problems)
def test_cast_params_preserves_empty_string() -> None:
    """Empty strings should be preserved for string type."""
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    casted = CastTestTool(schema).cast_params({"name": ""})
    assert casted["name"] == ""
def test_cast_params_bool_string_false() -> None:
    """Test that 'false', '0', 'no' strings convert to False."""
    schema = {"type": "object", "properties": {"flag": {"type": "boolean"}}}
    caster = CastTestTool(schema)
    for raw in ("false", "False", "0", "no", "NO"):
        assert caster.cast_params({"flag": raw})["flag"] is False
def test_cast_params_bool_string_invalid() -> None:
    """Invalid boolean strings should not be cast."""
    schema = {"type": "object", "properties": {"flag": {"type": "boolean"}}}
    caster = CastTestTool(schema)
    # Invalid strings should be preserved (validation will catch them)
    for raw in ("random", "maybe"):
        casted = caster.cast_params({"flag": raw})
        assert casted["flag"] == raw
def test_cast_params_invalid_string_to_int() -> None:
    """Invalid strings should not be cast to integer."""
    schema = {"type": "object", "properties": {"count": {"type": "integer"}}}
    caster = CastTestTool(schema)
    # Original value preserved in both cases.
    for raw in ("abc", "12.5.7"):
        casted = caster.cast_params({"count": raw})
        assert casted["count"] == raw
def test_cast_params_invalid_string_to_number() -> None:
    """Invalid strings should not be cast to number."""
    schema = {"type": "object", "properties": {"rate": {"type": "number"}}}
    casted = CastTestTool(schema).cast_params({"rate": "not_a_number"})
    assert casted["rate"] == "not_a_number"
def test_validate_params_bool_not_accepted_as_number() -> None:
    """Booleans should not pass number validation."""
    schema = {"type": "object", "properties": {"rate": {"type": "number"}}}
    problems = CastTestTool(schema).validate_params({"rate": False})
    assert any("rate should be number" in message for message in problems)
def test_cast_params_none_values() -> None:
    """None stays None for string, integer, array and object fields."""
    field_types = {
        "name": "string",
        "count": "integer",
        "items": "array",
        "config": "object",
    }
    schema = {
        "type": "object",
        "properties": {field: {"type": kind} for field, kind in field_types.items()},
    }
    tool = CastTestTool(schema)
    # dict.fromkeys maps every declared field to None.
    result = tool.cast_params(dict.fromkeys(field_types))
    # None should be preserved for all types
    for field in field_types:
        assert result[field] is None
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
    """A scalar given for an array field is not wrapped into a one-item list."""
    schema = {
        "type": "object",
        "properties": {"items": {"type": "array"}},
    }
    tool = CastTestTool(schema)
    # Scalars pass through untouched; rejecting them is left to validation.
    for scalar in (5, "text"):
        cast = tool.cast_params({"items": scalar})
        assert cast["items"] == scalar  # i.e. not [scalar]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
    """A successful command still reports its exit code (0) beside its output."""
    output = await ExecTool().execute(command="echo hello")
    for expected in ("Exit code: 0", "hello"):
        assert expected in output
async def test_exec_head_tail_truncation() -> None:
    """Truncated long output keeps its beginning and its trailing exit code."""
    # 6000 As, a newline, 6000 Bs: intended to exceed the tool's output cap
    # (_MAX_OUTPUT, defined outside this test).
    payload = "\n".join(("A" * 6000, "B" * 6000))
    command = f"echo '{payload}'"
    output = await ExecTool().execute(command=command)
    assert "chars truncated" in output
    # The head survives: the result still opens with the run of As.
    assert output.startswith("A")
    # The tail survives: the exit-code line, emitted after the Bs, is present.
    assert "Exit code:" in output
async def test_exec_timeout_parameter() -> None:
    """A per-call timeout takes precedence over the constructor's default."""
    slow_tool = ExecTool(timeout=60)
    # The 1 s call-level timeout should stop the 10 s sleep; the 60 s default must not apply.
    output = await slow_tool.execute(command="sleep 10", timeout=1)
    for fragment in ("timed out", "1 seconds"):
        assert fragment in output
async def test_exec_timeout_capped_at_max() -> None:
    """An oversized timeout is accepted rather than rejected.

    Only successful execution is asserted here; the clamp to _MAX_TIMEOUT
    (600 per the original comment) is not observed directly.
    """
    oversized = 9999
    output = await ExecTool().execute(command="echo ok", timeout=oversized)
    assert "Exit code: 0" in output
# --- _resolve_type and nullable param tests ---
def test_resolve_type_simple_string() -> None:
    """A plain type name is returned unchanged."""
    resolved = Tool._resolve_type("string")
    assert resolved == "string"
def test_resolve_type_union_with_null() -> None:
    """The union ['string', 'null'] collapses to its non-null member."""
    union = ["string", "null"]
    resolved = Tool._resolve_type(union)
    assert resolved == "string"
def test_resolve_type_only_null() -> None:
    """A union holding only 'null' has no concrete type, so it yields None."""
    resolved = Tool._resolve_type(["null"])
    assert resolved is None
def test_resolve_type_none_input() -> None:
    """Resolving None gives None back."""
    resolved = Tool._resolve_type(None)
    assert resolved is None
def test_validate_nullable_param_accepts_string() -> None:
    """A ['string', 'null'] parameter validates cleanly with a string value."""
    nullable_string = {"type": ["string", "null"]}
    tool = CastTestTool({"type": "object", "properties": {"name": nullable_string}})
    assert tool.validate_params({"name": "hello"}) == []
def test_validate_nullable_param_accepts_none() -> None:
    """A ['string', 'null'] parameter validates cleanly with None."""
    nullable_string = {"type": ["string", "null"]}
    tool = CastTestTool({"type": "object", "properties": {"name": nullable_string}})
    assert tool.validate_params({"name": None}) == []
def test_validate_nullable_flag_accepts_none() -> None:
    """A string parameter flagged ``nullable: True`` validates cleanly with None.

    This is the flag form of nullability (described by the original test as
    the OpenAI-normalized shape), as opposed to a ['string', 'null'] union.
    """
    flagged_string = {"type": "string", "nullable": True}
    tool = CastTestTool({"type": "object", "properties": {"name": flagged_string}})
    assert tool.validate_params({"name": None}) == []
def test_cast_nullable_param_no_crash() -> None:
    """Regression check: cast_params copes with a ['string', 'null'] type.

    Both a real string and None must come back unchanged without raising.
    """
    schema = {
        "type": "object",
        "properties": {"name": {"type": ["string", "null"]}},
    }
    tool = CastTestTool(schema)

    from_string = tool.cast_params({"name": "hello"})
    assert from_string["name"] == "hello"

    from_none = tool.cast_params({"name": None})
    assert from_none["name"] is None

View File

@@ -0,0 +1,113 @@
"""Tests for web_fetch SSRF protection and untrusted content marking."""
from __future__ import annotations
import json
import socket
from unittest.mock import patch
import pytest
from nanobot.agent.tools.web import WebFetchTool
def _fake_resolve_private(hostname, port, family=0, type_=0):
    """Stand-in for ``socket.getaddrinfo`` that always answers 169.254.169.254.

    All arguments are accepted and ignored; the result is a single
    IPv4/stream record in getaddrinfo's 5-tuple layout.
    """
    sockaddr = ("169.254.169.254", 0)
    record = (socket.AF_INET, socket.SOCK_STREAM, 0, "", sockaddr)
    return [record]
def _fake_resolve_public(hostname, port, family=0, type_=0):
    """Stand-in for ``socket.getaddrinfo`` that always answers 93.184.216.34.

    All arguments are accepted and ignored; the result is a single
    IPv4/stream record in getaddrinfo's 5-tuple layout.
    """
    sockaddr = ("93.184.216.34", 0)
    record = (socket.AF_INET, socket.SOCK_STREAM, 0, "", sockaddr)
    return [record]
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_ip():
    """Fetching a URL that resolves to 169.254.169.254 yields a JSON error."""
    target = "http://169.254.169.254/computeMetadata/v1/"
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
        raw = await WebFetchTool().execute(url=target)
    payload = json.loads(raw)
    assert "error" in payload
    # Either wording is accepted for the refusal message.
    message = payload["error"].lower()
    assert any(word in message for word in ("private", "blocked"))
@pytest.mark.asyncio
async def test_web_fetch_blocks_localhost():
    """Fetching a URL whose host resolves to 127.0.0.1 yields a JSON error."""
    loopback_record = (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))

    def _loopback_getaddrinfo(hostname, port, family=0, type_=0):
        # Every lookup lands on the loopback address, whatever was asked.
        return [loopback_record]

    fetcher = WebFetchTool()
    with patch("nanobot.security.network.socket.getaddrinfo", _loopback_getaddrinfo):
        raw = await fetcher.execute(url="http://localhost/admin")
    payload = json.loads(raw)
    assert "error" in payload
@pytest.mark.asyncio
async def test_web_fetch_result_contains_untrusted_flag():
    """A successful fetch is marked ``untrusted`` and its text carries the banner."""
    import httpx

    page_html = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"

    class _CannedResponse:
        """Minimal response double: 200 OK, HTML body."""

        status_code = 200
        url = "https://example.com/page"
        text = page_html
        headers = {"content-type": "text/html"}

        def raise_for_status(self):
            pass

        def json(self):
            return {}

    async def _canned_get(self, url, **kwargs):
        # Replaces AsyncClient.get, so no network traffic occurs.
        return _CannedResponse()

    fetcher = WebFetchTool()
    # DNS is faked to a public address so the request is allowed.
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
        with patch.object(httpx.AsyncClient, "get", _canned_get):
            raw = await fetcher.execute(url="https://example.com/page")

    payload = json.loads(raw)
    assert payload.get("untrusted") is True
    assert "[External content" in payload.get("text", "")
@pytest.mark.asyncio
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
    """An image response whose final URL is a loopback address is refused.

    The request goes to a host that resolves publicly, but the stubbed
    response reports ``http://127.0.0.1/secret.png`` as its URL; the tool is
    expected to answer with a "redirect blocked" error.
    """

    class _LoopbackImageResponse:
        """Streamed-response double: a PNG that claims to come from 127.0.0.1."""

        headers = {"content-type": "image/png"}
        url = "http://127.0.0.1/secret.png"
        content = b"\x89PNG\r\n\x1a\n"

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def aread(self):
            return self.content

        def raise_for_status(self):
            return None

    class _StubAsyncClient:
        """Drop-in for httpx.AsyncClient that always streams the response above."""

        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, headers=None):
            return _LoopbackImageResponse()

    fetcher = WebFetchTool()
    monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", _StubAsyncClient)
    # The requested host resolves to a public address, so only the URL
    # reported by the stubbed response is private.
    with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
        raw = await fetcher.execute(url="https://example.com/image.png")

    payload = json.loads(raw)
    assert "error" in payload
    assert "redirect blocked" in payload["error"].lower()