diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 2cbffd0..400979b 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -36,6 +36,7 @@ class MCPToolWrapper(Tool): async def execute(self, **kwargs: Any) -> str: from mcp import types + try: result = await asyncio.wait_for( self._session.call_tool(self._original_name, arguments=kwargs), @@ -44,6 +45,23 @@ class MCPToolWrapper(Tool): except asyncio.TimeoutError: logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout) return f"(MCP tool call timed out after {self._tool_timeout}s)" + except asyncio.CancelledError: + # MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure. + # Re-raise only if our task was externally cancelled (e.g. /stop). + task = asyncio.current_task() + if task is not None and task.cancelling() > 0: + raise + logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name) + return "(MCP tool call was cancelled)" + except Exception as exc: + logger.exception( + "MCP tool '{}' failed: {}: {}", + self._name, + type(exc).__name__, + exc, + ) + return f"(MCP tool call failed: {type(exc).__name__})" + parts = [] for block in result.content: if isinstance(block, types.TextContent): diff --git a/tests/test_mcp_tool.py b/tests/test_mcp_tool.py new file mode 100644 index 0000000..bf68425 --- /dev/null +++ b/tests/test_mcp_tool.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +import sys +from types import ModuleType, SimpleNamespace + +import pytest + +from nanobot.agent.tools.mcp import MCPToolWrapper + + +class _FakeTextContent: + def __init__(self, text: str) -> None: + self.text = text + + +@pytest.fixture(autouse=True) +def _fake_mcp_module(monkeypatch: pytest.MonkeyPatch) -> None: + mod = ModuleType("mcp") + mod.types = SimpleNamespace(TextContent=_FakeTextContent) + monkeypatch.setitem(sys.modules, "mcp", mod) + + +def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper: + tool_def = SimpleNamespace( + name="demo", + description="demo tool", + inputSchema={"type": "object", "properties": {}}, + ) + return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout) + + +@pytest.mark.asyncio +async def test_execute_returns_text_blocks() -> None: + async def call_tool(_name: str, arguments: dict) -> object: + assert arguments == {"value": 1} + return SimpleNamespace(content=[_FakeTextContent("hello"), 42]) + + wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool)) + + result = await wrapper.execute(value=1) + + assert result == "hello\n42" + + +@pytest.mark.asyncio +async def test_execute_returns_timeout_message() -> None: + async def call_tool(_name: str, arguments: dict) -> object: + await asyncio.sleep(1) + return SimpleNamespace(content=[]) + + wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=0.01) + + result = await wrapper.execute() + + assert result == "(MCP tool call timed out after 0.01s)" + + +@pytest.mark.asyncio +async def test_execute_handles_server_cancelled_error() -> None: + async def call_tool(_name: str, arguments: dict) -> object: + raise asyncio.CancelledError() + + wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool)) + + result = await wrapper.execute() + + assert result == "(MCP tool call was cancelled)" + + +@pytest.mark.asyncio +async def test_execute_re_raises_external_cancellation() -> None: + started = asyncio.Event() + + async def call_tool(_name: str, arguments: dict) -> object: + started.set() + await asyncio.sleep(60) + return SimpleNamespace(content=[]) + + wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10) + task = asyncio.create_task(wrapper.execute()) + await started.wait() + + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +async def test_execute_handles_generic_exception() -> None: + async def call_tool(_name: str, arguments: dict) -> object: + raise RuntimeError("boom") + + wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool)) + + result = await wrapper.execute() + + assert result == "(MCP tool call failed: RuntimeError)"