Merge remote-tracking branch 'origin/main'
# Conflicts: # nanobot/agent/tools/shell.py # tests/agent/test_evaluator.py # tests/channels/test_feishu_tool_hint_code_block.py # tests/providers/test_litellm_kwargs.py # tests/tools/test_web_search_tool.py
This commit is contained in:
69
tests/tools/test_exec_security.py
Normal file
69
tests/tools/test_exec_security.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for exec tool internal URL blocking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
def _fake_resolve_private(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_localhost(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_public(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_curl_metadata():
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(
|
||||
command='curl -s -H "Metadata-Flavor: Google" http://169.254.169.254/computeMetadata/v1/'
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "internal" in result.lower() or "private" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_wget_localhost():
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_localhost):
|
||||
result = await tool.execute(command="wget http://localhost:8080/secret -O /tmp/out")
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_allows_normal_commands():
|
||||
tool = ExecTool(timeout=5)
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "hello" in result
|
||||
assert "Error" not in result.split("\n")[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_allows_curl_to_public_url():
|
||||
"""Commands with public URLs should not be blocked by the internal URL check."""
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
|
||||
guard_result = tool._guard_command("curl https://example.com/api", "/tmp")
|
||||
assert guard_result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_blocks_chained_internal_url():
|
||||
"""Internal URLs buried in chained commands should still be caught."""
|
||||
tool = ExecTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(
|
||||
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
|
||||
)
|
||||
assert "Error" in result
|
||||
392
tests/tools/test_filesystem_tools.py
Normal file
392
tests/tools/test_filesystem_tools.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import (
|
||||
EditFileTool,
|
||||
ListDirTool,
|
||||
ReadFileTool,
|
||||
_find_match,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReadFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_file(self, tmp_path):
|
||||
f = tmp_path / "sample.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
return f
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_read_has_line_numbers(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file))
|
||||
assert "1| line 1" in result
|
||||
assert "20| line 20" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_and_limit(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=5, limit=3)
|
||||
assert "5| line 5" in result
|
||||
assert "7| line 7" in result
|
||||
assert "8| line 8" not in result
|
||||
assert "Use offset=8 to continue" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_offset_beyond_end(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=999)
|
||||
assert "Error" in result
|
||||
assert "beyond end" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_of_file_marker(self, tool, sample_file):
|
||||
result = await tool.execute(path=str(sample_file), offset=1, limit=9999)
|
||||
assert "End of file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.txt"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert "Empty file" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_file_returns_multimodal_blocks(self, tool, tmp_path):
|
||||
f = tmp_path / "pixel.png"
|
||||
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
|
||||
|
||||
result = await tool.execute(path=str(f))
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["type"] == "image_url"
|
||||
assert result[0]["image_url"]["url"].startswith("data:image/png;base64,")
|
||||
assert result[0]["_meta"]["path"] == str(f)
|
||||
assert result[1] == {"type": "text", "text": f"(Image file: {f})"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope.txt"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_path_returns_clear_error(self, tool):
|
||||
result = await tool.execute()
|
||||
assert result == "Error reading file: Unknown path"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_char_budget_trims(self, tool, tmp_path):
|
||||
"""When the selected slice exceeds _MAX_CHARS the output is trimmed."""
|
||||
f = tmp_path / "big.txt"
|
||||
# Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
|
||||
f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8")
|
||||
result = await tool.execute(path=str(f))
|
||||
assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer
|
||||
assert "Use offset=" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_match (unit tests for the helper)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFindMatch:
|
||||
|
||||
def test_exact_match(self):
|
||||
match, count = _find_match("hello world", "world")
|
||||
assert match == "world"
|
||||
assert count == 1
|
||||
|
||||
def test_exact_no_match(self):
|
||||
match, count = _find_match("hello world", "xyz")
|
||||
assert match is None
|
||||
assert count == 0
|
||||
|
||||
def test_crlf_normalisation(self):
|
||||
# Caller normalises CRLF before calling _find_match, so test with
|
||||
# pre-normalised content to verify exact match still works.
|
||||
content = "line1\nline2\nline3"
|
||||
old_text = "line1\nline2\nline3"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_line_trim_fallback(self):
|
||||
content = " def foo():\n pass\n"
|
||||
old_text = "def foo():\n pass"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# The returned match should be the *original* indented text
|
||||
assert " def foo():" in match
|
||||
|
||||
def test_line_trim_multiple_candidates(self):
|
||||
content = " a\n b\n a\n b\n"
|
||||
old_text = "a\nb"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert count == 2
|
||||
|
||||
def test_empty_old_text(self):
|
||||
match, count = _find_match("hello", "")
|
||||
# Empty string is always "in" any string via exact match
|
||||
assert match == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EditFileTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditFileTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exact_match(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crlf_normalisation(self, tool, tmp_path):
|
||||
f = tmp_path / "crlf.py"
|
||||
f.write_bytes(b"line1\r\nline2\r\nline3")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
raw = f.read_bytes()
|
||||
assert b"LINE1" in raw
|
||||
# CRLF line endings should be preserved throughout the file
|
||||
assert b"\r\n" in raw
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
f.write_text(" def foo():\n pass\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert "bar" in f.read_text()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears" in result.lower() or "Warning" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all(self, tool, tmp_path):
|
||||
f = tmp_path / "multi.py"
|
||||
f.write_text("foo bar foo bar foo", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="foo", new_text="baz", replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "baz bar baz bar baz"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
f = tmp_path / "nf.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="xyz", new_text="abc")
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_new_text_returns_clear_error(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="hello")
|
||||
assert result == "Error editing file: Unknown new_text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ListDirTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListDirTool:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ListDirTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def populated_dir(self, tmp_path):
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "main.py").write_text("pass")
|
||||
(tmp_path / "src" / "utils.py").write_text("pass")
|
||||
(tmp_path / "README.md").write_text("hi")
|
||||
(tmp_path / ".git").mkdir()
|
||||
(tmp_path / ".git" / "config").write_text("x")
|
||||
(tmp_path / "node_modules").mkdir()
|
||||
(tmp_path / "node_modules" / "pkg").mkdir()
|
||||
return tmp_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_list(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir))
|
||||
assert "README.md" in result
|
||||
assert "src" in result
|
||||
# .git and node_modules should be ignored
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recursive(self, tool, populated_dir):
|
||||
result = await tool.execute(path=str(populated_dir), recursive=True)
|
||||
assert "src/main.py" in result
|
||||
assert "src/utils.py" in result
|
||||
assert "README.md" in result
|
||||
# Ignored dirs should not appear
|
||||
assert ".git" not in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_entries_truncation(self, tool, tmp_path):
|
||||
for i in range(10):
|
||||
(tmp_path / f"file_{i}.txt").write_text("x")
|
||||
result = await tool.execute(path=str(tmp_path), max_entries=3)
|
||||
assert "truncated" in result
|
||||
assert "3 of 10" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_dir(self, tool, tmp_path):
|
||||
d = tmp_path / "empty"
|
||||
d.mkdir()
|
||||
result = await tool.execute(path=str(d))
|
||||
assert "empty" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_path_returns_clear_error(self, tool):
|
||||
result = await tool.execute()
|
||||
assert result == "Error listing directory: Unknown path"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace restriction + extra_allowed_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWorkspaceRestriction:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_blocked_outside_workspace(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
secret = outside / "secret.txt"
|
||||
secret.write_text("top secret")
|
||||
|
||||
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_allowed_with_extra_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "test_skill" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Test Skill\nDo something.")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(skill_file))
|
||||
assert "Test Skill" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
|
||||
from nanobot.agent.tools.filesystem import WriteFileTool
|
||||
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
|
||||
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
unrelated = tmp_path / "other"
|
||||
unrelated.mkdir()
|
||||
secret = unrelated / "secret.txt"
|
||||
secret.write_text("nope")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
|
||||
"""Adding extra_allowed_dirs must not break normal workspace reads."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
ws_file = workspace / "README.md"
|
||||
ws_file.write_text("hello from workspace")
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(ws_file))
|
||||
assert "hello from workspace" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_blocked_in_extra_dir(self, tmp_path):
|
||||
"""edit_file must not be able to modify files in extra_allowed_dirs."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "weather" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Weather\nOriginal content.")
|
||||
|
||||
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(
|
||||
path=str(skill_file),
|
||||
old_text="Original content.",
|
||||
new_text="Hacked content.",
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
assert skill_file.read_text() == "# Weather\nOriginal content."
|
||||
162
tests/tools/test_mcp_tool.py
Normal file
162
tests/tools/test_mcp_tool.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.mcp import MCPToolWrapper
|
||||
|
||||
|
||||
class _FakeTextContent:
|
||||
def __init__(self, text: str) -> None:
|
||||
self.text = text
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
mod = ModuleType("mcp")
|
||||
mod.types = SimpleNamespace(TextContent=_FakeTextContent)
|
||||
monkeypatch.setitem(sys.modules, "mcp", mod)
|
||||
|
||||
|
||||
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
||||
|
||||
|
||||
def test_wrapper_preserves_non_nullable_unions() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["value"]["anyOf"] == [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
]
|
||||
|
||||
|
||||
def test_wrapper_normalizes_nullable_property_type_union() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": ["string", "null"]},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["name"] == {"type": "string", "nullable": True}
|
||||
|
||||
|
||||
def test_wrapper_normalizes_nullable_property_anyof() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="demo",
|
||||
description="demo tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"description": "optional name",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||
|
||||
assert wrapper.parameters["properties"]["name"] == {
|
||||
"type": "string",
|
||||
"description": "optional name",
|
||||
"nullable": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_text_blocks() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
assert arguments == {"value": 1}
|
||||
return SimpleNamespace(content=[_FakeTextContent("hello"), 42])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute(value=1)
|
||||
|
||||
assert result == "hello\n42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_returns_timeout_message() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
await asyncio.sleep(1)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01)
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call timed out after 0.01s)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_server_cancelled_error() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call was cancelled)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_re_raises_external_cancellation() -> None:
|
||||
started = asyncio.Event()
|
||||
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
started.set()
|
||||
await asyncio.sleep(60)
|
||||
return SimpleNamespace(content=[])
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
|
||||
task.cancel()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_handles_generic_exception() -> None:
|
||||
async def call_tool(_name: str, arguments: dict) -> object:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool))
|
||||
|
||||
result = await wrapper.execute()
|
||||
|
||||
assert result == "(MCP tool call failed: RuntimeError)"
|
||||
10
tests/tools/test_message_tool.py
Normal file
10
tests/tools/test_message_tool.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_returns_error_when_no_target_context() -> None:
|
||||
tool = MessageTool()
|
||||
result = await tool.execute(content="test")
|
||||
assert result == "Error: No target channel/chat specified"
|
||||
132
tests/tools/test_message_tool_suppress.py
Normal file
132
tests/tools/test_message_tool_suppress.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Test message tool suppress logic for final replies."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
class TestMessageToolSuppressLogic:
|
||||
"""Final reply suppressed only when message tool sends to the same target."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suppress_when_sent_to_same_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Hello", "channel": "feishu", "chat_id": "chat123"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert result is None # suppressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_sent_to_different_target(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Email content", "channel": "email", "chat_id": "user@example.com"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="I've sent the email.", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Send email")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert sent[0].channel == "email"
|
||||
assert result is not None # not suppressed
|
||||
assert result.channel == "feishu"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_suppress_when_no_message_tool_used(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello!", tool_calls=[]))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Hi")
|
||||
result = await loop._process_message(msg)
|
||||
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
calls = iter([
|
||||
LLMResponse(
|
||||
content="Visible<think>hidden</think>",
|
||||
tool_calls=[tool_call],
|
||||
reasoning_content="secret reasoning",
|
||||
thinking_blocks=[{"signature": "sig", "thought": "secret thought"}],
|
||||
),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
progress: list[tuple[str, bool]] = []
|
||||
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
("Visible", False),
|
||||
('read_file("foo.txt")', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool.set_context("feishu", "chat1")
|
||||
assert not tool._sent_in_turn
|
||||
tool._sent_in_turn = True
|
||||
assert tool._sent_in_turn
|
||||
|
||||
def test_start_turn_resets(self) -> None:
|
||||
tool = MessageTool()
|
||||
tool._sent_in_turn = True
|
||||
tool.start_turn()
|
||||
assert not tool._sent_in_turn
|
||||
479
tests/tools/test_tool_validation.py
Normal file
479
tests/tools/test_tool_validation.py
Normal file
@@ -0,0 +1,479 @@
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
class SampleTool(Tool):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "sample tool"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "minLength": 2},
|
||||
"count": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
"mode": {"type": "string", "enum": ["fast", "full"]},
|
||||
"meta": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tag": {"type": "string"},
|
||||
"flags": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["tag"],
|
||||
},
|
||||
},
|
||||
"required": ["query", "count"],
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_validate_params_missing_required() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi"})
|
||||
assert "missing required count" in "; ".join(errors)
|
||||
|
||||
|
||||
def test_validate_params_type_and_range() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 0})
|
||||
assert any("count must be >= 1" in e for e in errors)
|
||||
|
||||
errors = tool.validate_params({"query": "hi", "count": "2"})
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_enum_and_min_length() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
|
||||
assert any("query must be at least 2 chars" in e for e in errors)
|
||||
assert any("mode must be one of" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_nested_object_and_array() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params(
|
||||
{
|
||||
"query": "hi",
|
||||
"count": 2,
|
||||
"meta": {"flags": [1, "ok"]},
|
||||
}
|
||||
)
|
||||
assert any("missing required meta.tag" in e for e in errors)
|
||||
assert any("meta.flags[0] should be string" in e for e in errors)
|
||||
|
||||
|
||||
def test_validate_params_ignores_unknown_fields() -> None:
|
||||
tool = SampleTool()
|
||||
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
|
||||
assert errors == []
|
||||
|
||||
|
||||
async def test_registry_returns_validation_error() -> None:
|
||||
reg = ToolRegistry()
|
||||
reg.register(SampleTool())
|
||||
result = await reg.execute("sample", {"query": "hi"})
|
||||
assert "Invalid parameters" in result
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||
cmd = r"type C:\user\workspace\txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert paths == [r"C:\user\workspace\txt"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
cmd = ".venv/bin/python script.py"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/bin/python" not in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_home_paths() -> None:
|
||||
cmd = "cat ~/.nanobot/config.json > ~/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
assert "~/out.txt" in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_quoted_paths() -> None:
|
||||
cmd = 'cat "/tmp/data.txt" "~/.nanobot/config.json"'
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "~/.nanobot/config.json" in paths
|
||||
|
||||
|
||||
def test_exec_guard_blocks_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command("cat ~/.nanobot/config.json", str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command('cat "~/.nanobot/config.json"', str(tmp_path))
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True
|
||||
errors = tool.validate_params(result)
|
||||
assert any("count should be integer" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
|
||||
def test_cast_params_bool_string_false() -> None:
|
||||
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||
|
||||
|
||||
def test_cast_params_bool_string_invalid() -> None:
|
||||
"""Invalid boolean strings should not be cast."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"flag": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
# Invalid strings should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"flag": "random"})
|
||||
assert result["flag"] == "random"
|
||||
result = tool.cast_params({"flag": "maybe"})
|
||||
assert result["flag"] == "maybe"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_int() -> None:
|
||||
"""Invalid strings should not be cast to integer."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "abc"})
|
||||
assert result["count"] == "abc" # Original value preserved
|
||||
result = tool.cast_params({"count": "12.5.7"})
|
||||
assert result["count"] == "12.5.7"
|
||||
|
||||
|
||||
def test_cast_params_invalid_string_to_number() -> None:
|
||||
"""Invalid strings should not be cast to number."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "not_a_number"})
|
||||
assert result["rate"] == "not_a_number"
|
||||
|
||||
|
||||
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||
"""Booleans should not pass number validation."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"rate": False})
|
||||
assert any("rate should be number" in e for e in errors)
|
||||
|
||||
|
||||
def test_cast_params_none_values() -> None:
|
||||
"""Test None handling for different types."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "integer"},
|
||||
"items": {"type": "array"},
|
||||
"config": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params(
|
||||
{
|
||||
"name": None,
|
||||
"count": None,
|
||||
"items": None,
|
||||
"config": None,
|
||||
}
|
||||
)
|
||||
# None should be preserved for all types
|
||||
assert result["name"] is None
|
||||
assert result["count"] is None
|
||||
assert result["items"] is None
|
||||
assert result["config"] is None
|
||||
|
||||
|
||||
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||
"""Single values should NOT be automatically wrapped into arrays."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"items": {"type": "array"}},
|
||||
}
|
||||
)
|
||||
# Non-array values should be preserved (validation will catch them)
|
||||
result = tool.cast_params({"items": 5})
|
||||
assert result["items"] == 5 # Not wrapped to [5]
|
||||
result = tool.cast_params({"items": "text"})
|
||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||
|
||||
|
||||
# --- ExecTool enhancement tests ---
|
||||
|
||||
|
||||
async def test_exec_always_returns_exit_code() -> None:
|
||||
"""Exit code should appear in output even on success (exit 0)."""
|
||||
tool = ExecTool()
|
||||
result = await tool.execute(command="echo hello")
|
||||
assert "Exit code: 0" in result
|
||||
assert "hello" in result
|
||||
|
||||
|
||||
async def test_exec_head_tail_truncation() -> None:
|
||||
"""Long output should preserve both head and tail."""
|
||||
tool = ExecTool()
|
||||
# Generate output that exceeds _MAX_OUTPUT
|
||||
big = "A" * 6000 + "\n" + "B" * 6000
|
||||
result = await tool.execute(command=f"echo '{big}'")
|
||||
assert "chars truncated" in result
|
||||
# Head portion should start with As
|
||||
assert result.startswith("A")
|
||||
# Tail portion should end with the exit code which comes after Bs
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_parameter() -> None:
|
||||
"""LLM-supplied timeout should override the constructor default."""
|
||||
tool = ExecTool(timeout=60)
|
||||
# A very short timeout should cause the command to be killed
|
||||
result = await tool.execute(command="sleep 10", timeout=1)
|
||||
assert "timed out" in result
|
||||
assert "1 seconds" in result
|
||||
|
||||
|
||||
async def test_exec_timeout_capped_at_max() -> None:
|
||||
"""Timeout values above _MAX_TIMEOUT should be clamped."""
|
||||
tool = ExecTool()
|
||||
# Should not raise — just clamp to 600
|
||||
result = await tool.execute(command="echo ok", timeout=9999)
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
# --- _resolve_type and nullable param tests ---
|
||||
|
||||
|
||||
def test_resolve_type_simple_string() -> None:
|
||||
"""Simple string type passes through unchanged."""
|
||||
assert Tool._resolve_type("string") == "string"
|
||||
|
||||
|
||||
def test_resolve_type_union_with_null() -> None:
|
||||
"""Union type ['string', 'null'] resolves to 'string'."""
|
||||
assert Tool._resolve_type(["string", "null"]) == "string"
|
||||
|
||||
|
||||
def test_resolve_type_only_null() -> None:
|
||||
"""Union type ['null'] resolves to None (no non-null type)."""
|
||||
assert Tool._resolve_type(["null"]) is None
|
||||
|
||||
|
||||
def test_resolve_type_none_input() -> None:
|
||||
"""None input passes through as None."""
|
||||
assert Tool._resolve_type(None) is None
|
||||
|
||||
|
||||
def test_validate_nullable_param_accepts_string() -> None:
|
||||
"""Nullable string param should accept a string value."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": "hello"})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_nullable_param_accepts_none() -> None:
|
||||
"""Nullable string param should accept None."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": None})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_nullable_flag_accepts_none() -> None:
|
||||
"""OpenAI-normalized nullable params should still accept None locally."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string", "nullable": True}},
|
||||
}
|
||||
)
|
||||
errors = tool.validate_params({"name": None})
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_cast_nullable_param_no_crash() -> None:
|
||||
"""cast_params should not crash on nullable type (the original bug)."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": ["string", "null"]}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": "hello"})
|
||||
assert result["name"] == "hello"
|
||||
result = tool.cast_params({"name": None})
|
||||
assert result["name"] is None
|
||||
113
tests/tools/test_web_fetch_security.py
Normal file
113
tests/tools/test_web_fetch_security.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Tests for web_fetch SSRF protection and untrusted content marking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.web import WebFetchTool
|
||||
|
||||
|
||||
def _fake_resolve_private(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("169.254.169.254", 0))]
|
||||
|
||||
|
||||
def _fake_resolve_public(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 0))]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_private_ip():
|
||||
tool = WebFetchTool()
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_private):
|
||||
result = await tool.execute(url="http://169.254.169.254/computeMetadata/v1/")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "private" in data["error"].lower() or "blocked" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_localhost():
|
||||
tool = WebFetchTool()
|
||||
def _resolve_localhost(hostname, port, family=0, type_=0):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("127.0.0.1", 0))]
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _resolve_localhost):
|
||||
result = await tool.execute(url="http://localhost/admin")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_result_contains_untrusted_flag():
|
||||
"""When fetch succeeds, result JSON must include untrusted=True and the banner."""
|
||||
tool = WebFetchTool()
|
||||
|
||||
fake_html = "<html><head><title>Test</title></head><body><p>Hello world</p></body></html>"
|
||||
|
||||
import httpx
|
||||
|
||||
class FakeResponse:
|
||||
status_code = 200
|
||||
url = "https://example.com/page"
|
||||
text = fake_html
|
||||
headers = {"content-type": "text/html"}
|
||||
def raise_for_status(self): pass
|
||||
def json(self): return {}
|
||||
|
||||
async def _fake_get(self, url, **kwargs):
|
||||
return FakeResponse()
|
||||
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public), \
|
||||
patch("httpx.AsyncClient.get", _fake_get):
|
||||
result = await tool.execute(url="https://example.com/page")
|
||||
|
||||
data = json.loads(result)
|
||||
assert data.get("untrusted") is True
|
||||
assert "[External content" in data.get("text", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_fetch_blocks_private_redirect_before_returning_image(monkeypatch):
|
||||
tool = WebFetchTool()
|
||||
|
||||
class FakeStreamResponse:
|
||||
headers = {"content-type": "image/png"}
|
||||
url = "http://127.0.0.1/secret.png"
|
||||
content = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def aread(self):
|
||||
return self.content
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def stream(self, method, url, headers=None):
|
||||
return FakeStreamResponse()
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.web.httpx.AsyncClient", FakeClient)
|
||||
|
||||
with patch("nanobot.security.network.socket.getaddrinfo", _fake_resolve_public):
|
||||
result = await tool.execute(url="https://example.com/image.png")
|
||||
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "redirect blocked" in data["error"].lower()
|
||||
Reference in New Issue
Block a user