fix(mcp): preserve schema semantics during normalization

When preparing MCP tool schemas for OpenAI-compatible providers, normalize only the nullable patterns (a `type` list containing "null", or an `anyOf`/`oneOf` with exactly one non-null branch) so optional params still work without collapsing unrelated unions. Also teach local validation to honor the `nullable` flag, and add regression coverage for nullable and non-nullable schemas.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-21 06:21:26 +00:00
committed by Xubin Ren
parent b6cf7020ac
commit e87bb0a82d
4 changed files with 135 additions and 94 deletions

View File

@@ -146,7 +146,9 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
raw_type = schema.get("type") raw_type = schema.get("type")
nullable = isinstance(raw_type, list) and "null" in raw_type nullable = (isinstance(raw_type, list) and "null" in raw_type) or schema.get(
"nullable", False
)
t, label = self._resolve_type(raw_type), path or "parameter" t, label = self._resolve_type(raw_type), path or "parameter"
if nullable and val is None: if nullable and val is None:
return [] return []

View File

@@ -11,103 +11,67 @@ from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
def _normalize_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]: def _extract_nullable_branch(options: Any) -> tuple[dict[str, Any], bool] | None:
"""Normalize JSON Schema for OpenAI-compatible providers. """Return the single non-null branch for nullable unions."""
if not isinstance(options, list):
return None
OpenAI's API (and many compatible providers) only supports a subset of JSON Schema: non_null: list[dict[str, Any]] = []
- Top-level type must be 'object' saw_null = False
- No oneOf/anyOf/allOf/enum/not at the top level for option in options:
- Properties should have simple types if not isinstance(option, dict):
""" return None
if option.get("type") == "null":
saw_null = True
continue
non_null.append(option)
if saw_null and len(non_null) == 1:
return non_null[0], True
return None
def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
"""Normalize only nullable JSON Schema patterns for tool definitions."""
if not isinstance(schema, dict): if not isinstance(schema, dict):
return {"type": "object", "properties": {}} return {"type": "object", "properties": {}}
# If schema has oneOf/anyOf/allOf at top level, try to extract the first option normalized = dict(schema)
for key in ["oneOf", "anyOf", "allOf"]:
if key in schema:
options = schema[key]
if isinstance(options, list) and len(options) > 0:
# Use the first option as the base schema
first_option = options[0]
if isinstance(first_option, dict):
# Merge with other schema properties, preferring the first option
normalized = dict(schema)
del normalized[key]
normalized.update(first_option)
return _normalize_schema_for_openai(normalized)
# Ensure top-level type is object raw_type = normalized.get("type")
if schema.get("type") != "object": if isinstance(raw_type, list):
# If no type specified or different type, default to object non_null = [item for item in raw_type if item != "null"]
schema = {"type": "object", **{k: v for k, v in schema.items() if k != "type"}} if "null" in raw_type and len(non_null) == 1:
normalized["type"] = non_null[0]
normalized["nullable"] = True
# Clean up unsupported properties at top level for key in ("oneOf", "anyOf"):
unsupported = ["enum", "not", "const"] nullable_branch = _extract_nullable_branch(normalized.get(key))
for key in unsupported: if nullable_branch is not None:
schema.pop(key, None) branch, _ = nullable_branch
merged = {k: v for k, v in normalized.items() if k != key}
merged.update(branch)
normalized = merged
normalized["nullable"] = True
break
# Ensure properties and required exist if "properties" in normalized and isinstance(normalized["properties"], dict):
if "properties" not in schema: normalized["properties"] = {
schema["properties"] = {} name: _normalize_schema_for_openai(prop)
if "required" not in schema: if isinstance(prop, dict)
schema["required"] = [] else prop
for name, prop in normalized["properties"].items()
}
# Recursively normalize nested property schemas if "items" in normalized and isinstance(normalized["items"], dict):
if "properties" in schema and isinstance(schema["properties"], dict): normalized["items"] = _normalize_schema_for_openai(normalized["items"])
for prop_name, prop_schema in schema["properties"].items():
if isinstance(prop_schema, dict):
schema["properties"][prop_name] = _normalize_property_schema(prop_schema)
return schema if normalized.get("type") != "object":
return normalized
normalized.setdefault("properties", {})
def _normalize_property_schema(schema: dict[str, Any]) -> dict[str, Any]: normalized.setdefault("required", [])
"""Normalize a property schema for OpenAI compatibility.""" return normalized
if not isinstance(schema, dict):
return {"type": "string"}
# Handle oneOf/anyOf in properties
for key in ["oneOf", "anyOf"]:
if key in schema:
options = schema[key]
if isinstance(options, list) and len(options) > 0:
first_option = options[0]
if isinstance(first_option, dict):
# Replace the complex schema with the first option
result = {k: v for k, v in schema.items() if k not in [key, "allOf", "not"]}
result.update(first_option)
return _normalize_property_schema(result)
# Handle allOf by merging all subschemas
if "allOf" in schema:
subschemas = schema["allOf"]
if isinstance(subschemas, list):
merged = {}
for sub in subschemas:
if isinstance(sub, dict):
merged.update(sub)
# Remove allOf and merge with other properties
result = {k: v for k, v in schema.items() if k != "allOf"}
result.update(merged)
return _normalize_property_schema(result)
# Ensure type is simple
if "type" not in schema:
# Try to infer type from other properties
if "enum" in schema:
schema["type"] = "string"
elif "properties" in schema:
schema["type"] = "object"
elif "items" in schema:
schema["type"] = "array"
else:
schema["type"] = "string"
# Clean up not/const
schema.pop("not", None)
schema.pop("const", None)
return schema
class MCPToolWrapper(Tool): class MCPToolWrapper(Tool):

View File

@@ -84,6 +84,69 @@ def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout) return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
def test_wrapper_preserves_non_nullable_unions() -> None:
    """An ``anyOf`` with no null branch must still be present in the wrapper's parameters."""
    input_schema = {
        "type": "object",
        "properties": {
            "value": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
    }
    definition = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema=input_schema,
    )
    session = SimpleNamespace(call_tool=None)

    wrapped = MCPToolWrapper(session, "test", definition)

    # Fresh literal on purpose: never compare against the object handed to the wrapper.
    expected_branches = [{"type": "string"}, {"type": "integer"}]
    value_schema = wrapped.parameters["properties"]["value"]
    assert value_schema["anyOf"] == expected_branches
def test_wrapper_normalizes_nullable_property_type_union() -> None:
    """A ``["string", "null"]`` type list becomes ``string`` plus ``nullable: True``."""
    input_schema = {
        "type": "object",
        "properties": {"name": {"type": ["string", "null"]}},
    }
    definition = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema=input_schema,
    )
    session = SimpleNamespace(call_tool=None)

    wrapped = MCPToolWrapper(session, "test", definition)

    name_schema = wrapped.parameters["properties"]["name"]
    assert name_schema == {"type": "string", "nullable": True}
def test_wrapper_normalizes_nullable_property_anyof() -> None:
    """A string-or-null ``anyOf`` collapses to ``string`` + ``nullable`` and keeps sibling keys."""
    input_schema = {
        "type": "object",
        "properties": {
            "name": {
                "anyOf": [{"type": "string"}, {"type": "null"}],
                "description": "optional name",
            },
        },
    }
    definition = SimpleNamespace(
        name="demo",
        description="demo tool",
        inputSchema=input_schema,
    )
    session = SimpleNamespace(call_tool=None)

    wrapped = MCPToolWrapper(session, "test", definition)

    expected = {
        "type": "string",
        "description": "optional name",
        "nullable": True,
    }
    name_schema = wrapped.parameters["properties"]["name"]
    assert name_schema == expected
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_execute_returns_text_blocks() -> None: async def test_execute_returns_text_blocks() -> None:
async def call_tool(_name: str, arguments: dict) -> object: async def call_tool(_name: str, arguments: dict) -> object:

View File

@@ -455,6 +455,18 @@ def test_validate_nullable_param_accepts_none() -> None:
assert errors == [] assert errors == []
def test_validate_nullable_flag_accepts_none() -> None:
    """A property carrying ``nullable: True`` must validate without errors when given None."""
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string", "nullable": True}},
    }
    tool = CastTestTool(schema)

    problems = tool.validate_params({"name": None})

    assert problems == []
def test_cast_nullable_param_no_crash() -> None: def test_cast_nullable_param_no_crash() -> None:
"""cast_params should not crash on nullable type (the original bug).""" """cast_params should not crash on nullable type (the original bug)."""
tool = CastTestTool( tool = CastTestTool(