- support both raw and wrapped MCP tool names
- treat ["*"] as all tools and [] as no tools
- add warnings, tests, and README docs for enabledTools
283 lines
8.5 KiB
Python
283 lines
8.5 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from contextlib import AsyncExitStack, asynccontextmanager
|
|
import sys
|
|
from types import ModuleType, SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from nanobot.agent.tools.mcp import MCPToolWrapper, connect_mcp_servers
|
|
from nanobot.agent.tools.registry import ToolRegistry
|
|
from nanobot.config.schema import MCPServerConfig
|
|
|
|
|
|
class _FakeTextContent:
|
|
def __init__(self, text: str) -> None:
|
|
self.text = text
|
|
|
|
|
|
@pytest.fixture
def fake_mcp_runtime() -> dict[str, object | None]:
    """Return a fresh mutable holder for the fake MCP session.

    The ``"session"`` slot starts as ``None``; tests assign the object that the
    fake ``ClientSession`` should hand back from ``__aenter__``.
    """
    runtime: dict[str, object | None] = {"session": None}
    return runtime
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _fake_mcp_module(
    monkeypatch: pytest.MonkeyPatch, fake_mcp_runtime: dict[str, object | None]
) -> None:
    """Install a fake ``mcp`` package into ``sys.modules`` for every test.

    The fake exposes ``mcp.types.TextContent``, ``mcp.ClientSession``,
    ``mcp.StdioServerParameters`` and the ``mcp.client.stdio``,
    ``mcp.client.sse`` and ``mcp.client.streamable_http`` submodules with their
    client factories. All entries are registered through ``monkeypatch`` so
    they are undone when the test finishes.
    """

    class _FakeStdioServerParameters:
        """Records the launch parameters it is constructed with."""

        def __init__(self, command: str, args: list[str], env: dict | None = None) -> None:
            self.command = command
            self.args = args
            self.env = env

    class _FakeClientSession:
        """Async context manager that yields the session stored by the test."""

        def __init__(self, _read: object, _write: object) -> None:
            # The session is looked up at construction time, so a test must set
            # fake_mcp_runtime["session"] before the client session is created.
            self._session = fake_mcp_runtime["session"]

        async def __aenter__(self) -> object:
            return self._session

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            # Returning False means exceptions are never suppressed.
            return False

    @asynccontextmanager
    async def _fake_stdio_client(_params: object):
        read_stream, write_stream = object(), object()
        yield read_stream, write_stream

    @asynccontextmanager
    async def _fake_sse_client(_url: str, httpx_client_factory=None):
        read_stream, write_stream = object(), object()
        yield read_stream, write_stream

    @asynccontextmanager
    async def _fake_streamable_http_client(_url: str, http_client=None):
        # This transport yields a three-item tuple, unlike the two above.
        yield object(), object(), object()

    # Top-level package with the attributes read directly from ``mcp``.
    root = ModuleType("mcp")
    root.types = SimpleNamespace(TextContent=_FakeTextContent)
    root.ClientSession = _FakeClientSession
    root.StdioServerParameters = _FakeStdioServerParameters

    # One submodule per transport, each exposing its client factory.
    stdio = ModuleType("mcp.client.stdio")
    stdio.stdio_client = _fake_stdio_client
    sse = ModuleType("mcp.client.sse")
    sse.sse_client = _fake_sse_client
    streamable_http = ModuleType("mcp.client.streamable_http")
    streamable_http.streamable_http_client = _fake_streamable_http_client

    fake_modules = {
        "mcp": root,
        "mcp.client": ModuleType("mcp.client"),
        "mcp.client.stdio": stdio,
        "mcp.client.sse": sse,
        "mcp.client.streamable_http": streamable_http,
    }
    for dotted_name, module in fake_modules.items():
        monkeypatch.setitem(sys.modules, dotted_name, module)
|
|
|
|
|
|
def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
    """Build an ``MCPToolWrapper`` around a tool definition named ``"demo"``.

    Args:
        session: Object passed as the wrapper's session; the tests supply a
            namespace with an async ``call_tool`` attribute.
        timeout: Forwarded to the wrapper as ``tool_timeout``.

    Returns:
        The wrapper, constructed with ``"test"`` as its second positional
        argument (presumably the server name; confirm against
        ``MCPToolWrapper``).
    """
    empty_schema = {"type": "object", "properties": {}}
    demo_tool = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema=empty_schema,
    )
    return MCPToolWrapper(session, "test", demo_tool, tool_timeout=timeout)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None:
    """``execute`` forwards its kwargs and returns the content as text.

    The fake session returns one text block and one non-text item (``42``);
    the test expects them back as newline-separated text.
    """

    async def call_tool(_name: str, arguments: dict) -> object:
        # The keyword arguments given to execute() must arrive unchanged.
        assert arguments == {"value": 1}
        return SimpleNamespace(content=[_FakeTextContent("hello"), 42])

    session = SimpleNamespace(call_tool=call_tool)
    wrapper = _make_wrapper(session)

    output = await wrapper.execute(value=1)

    assert output == "hello\n42"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_returns_timeout_message() -> None:
    """A tool call that outlasts the timeout returns the timeout message."""

    async def call_tool(_name: str, arguments: dict) -> object:
        # Sleep far longer than the 0.01s timeout configured below.
        await asyncio.sleep(1)
        return SimpleNamespace(content=[])

    slow_session = SimpleNamespace(call_tool=call_tool)
    wrapper = _make_wrapper(slow_session, timeout=0.01)

    output = await wrapper.execute()

    assert output == "(MCP tool call timed out after 0.01s)"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_handles_server_cancelled_error() -> None:
    """A ``CancelledError`` raised by the tool call itself yields a message.

    Here the cancellation comes from ``call_tool`` rather than from cancelling
    the awaiting task, and the test expects a string result, not an exception.
    """

    async def call_tool(_name: str, arguments: dict) -> object:
        raise asyncio.CancelledError()

    cancelling_session = SimpleNamespace(call_tool=call_tool)
    wrapper = _make_wrapper(cancelling_session)

    output = await wrapper.execute()

    assert output == "(MCP tool call was cancelled)"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_re_raises_external_cancellation() -> None:
    """Cancelling the task that runs ``execute`` propagates ``CancelledError``."""
    entered = asyncio.Event()

    async def call_tool(_name: str, arguments: dict) -> object:
        # Signal that the call is in flight, then block long enough that only
        # an outside cancel() can end it (the wrapper timeout is 10s).
        entered.set()
        await asyncio.sleep(60)
        return SimpleNamespace(content=[])

    wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
    pending = asyncio.create_task(wrapper.execute())

    # Cancel only once the tool call has actually started.
    await entered.wait()
    pending.cancel()

    with pytest.raises(asyncio.CancelledError):
        await pending
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_handles_generic_exception() -> None:
    """An ordinary exception from the tool call is reported by type name.

    The expected message contains ``RuntimeError`` but not the exception's
    own text (``"boom"``).
    """

    async def call_tool(_name: str, arguments: dict) -> object:
        raise RuntimeError("boom")

    failing_session = SimpleNamespace(call_tool=call_tool)
    wrapper = _make_wrapper(failing_session)

    output = await wrapper.execute()

    assert output == "(MCP tool call failed: RuntimeError)"
|
|
|
|
|
|
def _make_tool_def(name: str) -> SimpleNamespace:
|
|
return SimpleNamespace(
|
|
name=name,
|
|
description=f"{name} tool",
|
|
inputSchema={"type": "object", "properties": {}},
|
|
)
|
|
|
|
|
|
def _make_fake_session(tool_names: list[str]) -> SimpleNamespace:
|
|
async def initialize() -> None:
|
|
return None
|
|
|
|
async def list_tools() -> SimpleNamespace:
|
|
return SimpleNamespace(tools=[_make_tool_def(name) for name in tool_names])
|
|
|
|
return SimpleNamespace(initialize=initialize, list_tools=list_tools)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
    fake_mcp_runtime: dict[str, object | None],
) -> None:
    """``enabled_tools`` given as a raw tool name registers only that tool.

    The fake server reports ``demo`` and ``other``; with
    ``enabled_tools=["demo"]`` only ``mcp_test_demo`` is expected in the
    registry.
    """
    fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
    registry = ToolRegistry()
    servers = {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])}

    # ``async with`` pairs entering and closing the stack, replacing the
    # manual __aenter__()/try/finally/aclose() sequence.
    async with AsyncExitStack() as stack:
        await connect_mcp_servers(servers, registry, stack)

    assert registry.tool_names == ["mcp_test_demo"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
    fake_mcp_runtime: dict[str, object | None],
) -> None:
    """Leaving ``enabled_tools`` unset registers every tool from the server.

    The fake server reports ``demo`` and ``other``; both wrapped names are
    expected in the registry, in that order.
    """
    fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
    registry = ToolRegistry()
    servers = {"test": MCPServerConfig(command="fake")}

    # ``async with`` pairs entering and closing the stack, replacing the
    # manual __aenter__()/try/finally/aclose() sequence.
    async with AsyncExitStack() as stack:
        await connect_mcp_servers(servers, registry, stack)

    assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
    fake_mcp_runtime: dict[str, object | None],
) -> None:
    """``enabled_tools`` given as a wrapped name registers only that tool.

    The fake server reports ``demo`` and ``other``; with
    ``enabled_tools=["mcp_test_demo"]`` only ``mcp_test_demo`` is expected in
    the registry.
    """
    fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
    registry = ToolRegistry()
    servers = {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])}

    # ``async with`` pairs entering and closing the stack, replacing the
    # manual __aenter__()/try/finally/aclose() sequence.
    async with AsyncExitStack() as stack:
        await connect_mcp_servers(servers, registry, stack)

    assert registry.tool_names == ["mcp_test_demo"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
    fake_mcp_runtime: dict[str, object | None],
) -> None:
    """An explicitly empty ``enabled_tools`` list registers no tools.

    The fake server reports ``demo`` and ``other``; with ``enabled_tools=[]``
    the registry is expected to stay empty.
    """
    fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
    registry = ToolRegistry()
    servers = {"test": MCPServerConfig(command="fake", enabled_tools=[])}

    # ``async with`` pairs entering and closing the stack, replacing the
    # manual __aenter__()/try/finally/aclose() sequence.
    async with AsyncExitStack() as stack:
        await connect_mcp_servers(servers, registry, stack)

    assert registry.tool_names == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
    fake_mcp_runtime: dict[str, object | None], monkeypatch: pytest.MonkeyPatch
) -> None:
    """An ``enabled_tools`` entry matching no tool registers nothing and warns.

    The fake server reports only ``demo`` while the config asks for
    ``unknown``. The last captured warning must name the missing entry and
    list the available raw and wrapped tool names.
    """
    fake_mcp_runtime["session"] = _make_fake_session(["demo"])
    registry = ToolRegistry()
    captured: list[str] = []

    def _warning(message: str, *args: object) -> None:
        # Render the "{}"-style template with its positional arguments so the
        # assertions below can inspect the final text.
        captured.append(message.format(*args))

    monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)

    servers = {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])}

    # ``async with`` pairs entering and closing the stack, replacing the
    # manual __aenter__()/try/finally/aclose() sequence.
    async with AsyncExitStack() as stack:
        await connect_mcp_servers(servers, registry, stack)

    assert registry.tool_names == []
    assert captured
    last_warning = captured[-1]
    assert "enabledTools entries not found: unknown" in last_warning
    assert "Available raw names: demo" in last_warning
    assert "Available wrapped names: mcp_test_demo" in last_warning
|