# File: nanobot/nanobot/agent/tools/mcp.py (285 lines, 12 KiB, Python)
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
import asyncio
from contextlib import AsyncExitStack
from typing import Any
import httpx
from loguru import logger
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
def _normalize_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize JSON Schema for OpenAI-compatible providers.
OpenAI's API (and many compatible providers) only supports a subset of JSON Schema:
- Top-level type must be 'object'
- No oneOf/anyOf/allOf/enum/not at the top level
- Properties should have simple types
"""
if not isinstance(schema, dict):
return {"type": "object", "properties": {}}
# If schema has oneOf/anyOf/allOf at top level, try to extract the first option
for key in ["oneOf", "anyOf", "allOf"]:
if key in schema:
options = schema[key]
if isinstance(options, list) and len(options) > 0:
# Use the first option as the base schema
first_option = options[0]
if isinstance(first_option, dict):
# Merge with other schema properties, preferring the first option
normalized = dict(schema)
del normalized[key]
normalized.update(first_option)
return _normalize_schema_for_openai(normalized)
# Ensure top-level type is object
if schema.get("type") != "object":
# If no type specified or different type, default to object
schema = {"type": "object", **{k: v for k, v in schema.items() if k != "type"}}
# Clean up unsupported properties at top level
unsupported = ["enum", "not", "const"]
for key in unsupported:
schema.pop(key, None)
# Ensure properties and required exist
if "properties" not in schema:
schema["properties"] = {}
if "required" not in schema:
schema["required"] = []
# Recursively normalize nested property schemas
if "properties" in schema and isinstance(schema["properties"], dict):
for prop_name, prop_schema in schema["properties"].items():
if isinstance(prop_schema, dict):
schema["properties"][prop_name] = _normalize_property_schema(prop_schema)
return schema
def _normalize_property_schema(schema: dict[str, Any]) -> dict[str, Any]:
"""Normalize a property schema for OpenAI compatibility."""
if not isinstance(schema, dict):
return {"type": "string"}
# Handle oneOf/anyOf in properties
for key in ["oneOf", "anyOf"]:
if key in schema:
options = schema[key]
if isinstance(options, list) and len(options) > 0:
first_option = options[0]
if isinstance(first_option, dict):
# Replace the complex schema with the first option
result = {k: v for k, v in schema.items() if k not in [key, "allOf", "not"]}
result.update(first_option)
return _normalize_property_schema(result)
# Handle allOf by merging all subschemas
if "allOf" in schema:
subschemas = schema["allOf"]
if isinstance(subschemas, list):
merged = {}
for sub in subschemas:
if isinstance(sub, dict):
merged.update(sub)
# Remove allOf and merge with other properties
result = {k: v for k, v in schema.items() if k != "allOf"}
result.update(merged)
return _normalize_property_schema(result)
# Ensure type is simple
if "type" not in schema:
# Try to infer type from other properties
if "enum" in schema:
schema["type"] = "string"
elif "properties" in schema:
schema["type"] = "object"
elif "items" in schema:
schema["type"] = "array"
else:
schema["type"] = "string"
# Clean up not/const
schema.pop("not", None)
schema.pop("const", None)
return schema
class MCPToolWrapper(Tool):
    """Adapter that exposes one tool from an MCP server as a nanobot Tool.

    The tool is published under ``mcp_<server>_<tool>`` so that identically
    named tools from different servers cannot collide, and its input schema is
    normalized for OpenAI-compatible providers.
    """

    def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
        self._session = session
        self._tool_timeout = tool_timeout
        self._original_name = tool_def.name
        self._name = f"mcp_{server_name}_{tool_def.name}"
        self._description = tool_def.description or tool_def.name
        # Normalize the server-provided schema (defaulting to an empty object
        # schema) into the subset OpenAI-style providers accept.
        self._parameters = _normalize_schema_for_openai(
            tool_def.inputSchema or {"type": "object", "properties": {}}
        )

    @property
    def name(self) -> str:
        # Namespaced name used for registration and provider function calls.
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        return self._parameters

    async def execute(self, **kwargs: Any) -> str:
        """Invoke the remote tool and flatten its content blocks into text."""
        from mcp import types

        pending = self._session.call_tool(self._original_name, arguments=kwargs)
        try:
            result = await asyncio.wait_for(pending, timeout=self._tool_timeout)
        except asyncio.TimeoutError:
            logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
            return f"(MCP tool call timed out after {self._tool_timeout}s)"
        except asyncio.CancelledError:
            # MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure.
            # Re-raise only if our task was externally cancelled (e.g. /stop).
            current = asyncio.current_task()
            if current is not None and current.cancelling() > 0:
                raise
            logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
            return "(MCP tool call was cancelled)"
        except Exception as exc:
            logger.exception(
                "MCP tool '{}' failed: {}: {}",
                self._name,
                type(exc).__name__,
                exc,
            )
            return f"(MCP tool call failed: {type(exc).__name__})"
        rendered = [
            block.text if isinstance(block, types.TextContent) else str(block)
            for block in result.content
        ]
        return "\n".join(rendered) or "(no output)"
async def connect_mcp_servers(
    mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
) -> None:
    """Connect to configured MCP servers and register their tools.

    Args:
        mcp_servers: Mapping of server name -> server config (read attributes:
            ``type``, ``command``, ``args``, ``env``, ``url``, ``headers``,
            ``enabled_tools``, ``tool_timeout``).
        registry: Receives one MCPToolWrapper per enabled tool.
        stack: Owns every transport/session context, so the caller tears all
            connections down by closing the stack.

    A failure on one server is logged and does not stop the remaining servers
    from connecting.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.sse import sse_client
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client

    for name, cfg in mcp_servers.items():
        try:
            # Infer the transport when not configured explicitly.
            transport_type = cfg.type
            if not transport_type:
                if cfg.command:
                    transport_type = "stdio"
                elif cfg.url:
                    # Convention: URLs ending with /sse use SSE transport; others use streamableHttp
                    transport_type = (
                        "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
                    )
                else:
                    logger.warning("MCP server '{}': no command or url configured, skipping", name)
                    continue
            if transport_type == "stdio":
                params = StdioServerParameters(
                    command=cfg.command, args=cfg.args, env=cfg.env or None
                )
                read, write = await stack.enter_async_context(stdio_client(params))
            elif transport_type == "sse":

                def httpx_client_factory(
                    headers: dict[str, str] | None = None,
                    timeout: httpx.Timeout | None = None,
                    auth: httpx.Auth | None = None,
                    *,
                    _cfg_headers: dict[str, str] | None = cfg.headers,
                ) -> httpx.AsyncClient:
                    # _cfg_headers is bound as a default at definition time:
                    # a plain closure over the loop variable `cfg` would read a
                    # *later* iteration's config if the SDK calls this factory
                    # lazily (e.g. on reconnect).
                    merged_headers = {**(_cfg_headers or {}), **(headers or {})}
                    return httpx.AsyncClient(
                        headers=merged_headers or None,
                        follow_redirects=True,
                        timeout=timeout,
                        auth=auth,
                    )

                read, write = await stack.enter_async_context(
                    sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
                )
            elif transport_type == "streamableHttp":
                # Always provide an explicit httpx client so MCP HTTP transport does not
                # inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
                http_client = await stack.enter_async_context(
                    httpx.AsyncClient(
                        headers=cfg.headers or None,
                        follow_redirects=True,
                        timeout=None,
                    )
                )
                read, write, _ = await stack.enter_async_context(
                    streamable_http_client(cfg.url, http_client=http_client)
                )
            else:
                logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
                continue

            session = await stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            tools = await session.list_tools()

            # enabledTools entries may be raw MCP names or wrapped
            # "mcp_<server>_<tool>" names; "*" enables everything.
            enabled_tools = set(cfg.enabled_tools)
            allow_all_tools = "*" in enabled_tools
            registered_count = 0
            matched_enabled_tools: set[str] = set()
            available_raw_names = [tool_def.name for tool_def in tools.tools]
            available_wrapped_names = [f"mcp_{name}_{tool_def.name}" for tool_def in tools.tools]
            for tool_def in tools.tools:
                wrapped_name = f"mcp_{name}_{tool_def.name}"
                if (
                    not allow_all_tools
                    and tool_def.name not in enabled_tools
                    and wrapped_name not in enabled_tools
                ):
                    logger.debug(
                        "MCP: skipping tool '{}' from server '{}' (not in enabledTools)",
                        wrapped_name,
                        name,
                    )
                    continue
                wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
                registry.register(wrapper)
                logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
                registered_count += 1
                if enabled_tools:
                    if tool_def.name in enabled_tools:
                        matched_enabled_tools.add(tool_def.name)
                    if wrapped_name in enabled_tools:
                        matched_enabled_tools.add(wrapped_name)
            # Surface enabledTools entries that matched nothing, listing what the
            # server actually offers so config typos are easy to spot.
            if enabled_tools and not allow_all_tools:
                unmatched_enabled_tools = sorted(enabled_tools - matched_enabled_tools)
                if unmatched_enabled_tools:
                    logger.warning(
                        "MCP server '{}': enabledTools entries not found: {}. Available raw names: {}. "
                        "Available wrapped names: {}",
                        name,
                        ", ".join(unmatched_enabled_tools),
                        ", ".join(available_raw_names) or "(none)",
                        ", ".join(available_wrapped_names) or "(none)",
                    )
            logger.info("MCP server '{}': connected, {} tools registered", name, registered_count)
        except Exception as e:
            logger.error("MCP server '{}': failed to connect: {}", name, e)