fix(mcp): preserve schema semantics during normalization
Only normalize nullable MCP tool schemas for OpenAI-compatible providers so optional params still work without collapsing unrelated unions. Also teach local validation to honor nullable flags and add regression coverage for nullable and non-nullable schemas. Made-with: Cursor
This commit is contained in:
@@ -146,7 +146,9 @@ class Tool(ABC):
|
|||||||
|
|
||||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
raw_type = schema.get("type")
|
raw_type = schema.get("type")
|
||||||
nullable = isinstance(raw_type, list) and "null" in raw_type
|
nullable = (isinstance(raw_type, list) and "null" in raw_type) or schema.get(
|
||||||
|
"nullable", False
|
||||||
|
)
|
||||||
t, label = self._resolve_type(raw_type), path or "parameter"
|
t, label = self._resolve_type(raw_type), path or "parameter"
|
||||||
if nullable and val is None:
|
if nullable and val is None:
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -11,103 +11,67 @@ from nanobot.agent.tools.base import Tool
|
|||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
def _normalize_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]:
|
def _extract_nullable_branch(options: Any) -> tuple[dict[str, Any], bool] | None:
|
||||||
"""Normalize JSON Schema for OpenAI-compatible providers.
|
"""Return the single non-null branch for nullable unions."""
|
||||||
|
if not isinstance(options, list):
|
||||||
|
return None
|
||||||
|
|
||||||
OpenAI's API (and many compatible providers) only supports a subset of JSON Schema:
|
non_null: list[dict[str, Any]] = []
|
||||||
- Top-level type must be 'object'
|
saw_null = False
|
||||||
- No oneOf/anyOf/allOf/enum/not at the top level
|
for option in options:
|
||||||
- Properties should have simple types
|
if not isinstance(option, dict):
|
||||||
"""
|
return None
|
||||||
|
if option.get("type") == "null":
|
||||||
|
saw_null = True
|
||||||
|
continue
|
||||||
|
non_null.append(option)
|
||||||
|
|
||||||
|
if saw_null and len(non_null) == 1:
|
||||||
|
return non_null[0], True
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
||||||
|
"""Normalize only nullable JSON Schema patterns for tool definitions."""
|
||||||
if not isinstance(schema, dict):
|
if not isinstance(schema, dict):
|
||||||
return {"type": "object", "properties": {}}
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
# If schema has oneOf/anyOf/allOf at top level, try to extract the first option
|
normalized = dict(schema)
|
||||||
for key in ["oneOf", "anyOf", "allOf"]:
|
|
||||||
if key in schema:
|
|
||||||
options = schema[key]
|
|
||||||
if isinstance(options, list) and len(options) > 0:
|
|
||||||
# Use the first option as the base schema
|
|
||||||
first_option = options[0]
|
|
||||||
if isinstance(first_option, dict):
|
|
||||||
# Merge with other schema properties, preferring the first option
|
|
||||||
normalized = dict(schema)
|
|
||||||
del normalized[key]
|
|
||||||
normalized.update(first_option)
|
|
||||||
return _normalize_schema_for_openai(normalized)
|
|
||||||
|
|
||||||
# Ensure top-level type is object
|
raw_type = normalized.get("type")
|
||||||
if schema.get("type") != "object":
|
if isinstance(raw_type, list):
|
||||||
# If no type specified or different type, default to object
|
non_null = [item for item in raw_type if item != "null"]
|
||||||
schema = {"type": "object", **{k: v for k, v in schema.items() if k != "type"}}
|
if "null" in raw_type and len(non_null) == 1:
|
||||||
|
normalized["type"] = non_null[0]
|
||||||
|
normalized["nullable"] = True
|
||||||
|
|
||||||
# Clean up unsupported properties at top level
|
for key in ("oneOf", "anyOf"):
|
||||||
unsupported = ["enum", "not", "const"]
|
nullable_branch = _extract_nullable_branch(normalized.get(key))
|
||||||
for key in unsupported:
|
if nullable_branch is not None:
|
||||||
schema.pop(key, None)
|
branch, _ = nullable_branch
|
||||||
|
merged = {k: v for k, v in normalized.items() if k != key}
|
||||||
|
merged.update(branch)
|
||||||
|
normalized = merged
|
||||||
|
normalized["nullable"] = True
|
||||||
|
break
|
||||||
|
|
||||||
# Ensure properties and required exist
|
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||||
if "properties" not in schema:
|
normalized["properties"] = {
|
||||||
schema["properties"] = {}
|
name: _normalize_schema_for_openai(prop)
|
||||||
if "required" not in schema:
|
if isinstance(prop, dict)
|
||||||
schema["required"] = []
|
else prop
|
||||||
|
for name, prop in normalized["properties"].items()
|
||||||
|
}
|
||||||
|
|
||||||
# Recursively normalize nested property schemas
|
if "items" in normalized and isinstance(normalized["items"], dict):
|
||||||
if "properties" in schema and isinstance(schema["properties"], dict):
|
normalized["items"] = _normalize_schema_for_openai(normalized["items"])
|
||||||
for prop_name, prop_schema in schema["properties"].items():
|
|
||||||
if isinstance(prop_schema, dict):
|
|
||||||
schema["properties"][prop_name] = _normalize_property_schema(prop_schema)
|
|
||||||
|
|
||||||
return schema
|
if normalized.get("type") != "object":
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
normalized.setdefault("properties", {})
|
||||||
def _normalize_property_schema(schema: dict[str, Any]) -> dict[str, Any]:
|
normalized.setdefault("required", [])
|
||||||
"""Normalize a property schema for OpenAI compatibility."""
|
return normalized
|
||||||
if not isinstance(schema, dict):
|
|
||||||
return {"type": "string"}
|
|
||||||
|
|
||||||
# Handle oneOf/anyOf in properties
|
|
||||||
for key in ["oneOf", "anyOf"]:
|
|
||||||
if key in schema:
|
|
||||||
options = schema[key]
|
|
||||||
if isinstance(options, list) and len(options) > 0:
|
|
||||||
first_option = options[0]
|
|
||||||
if isinstance(first_option, dict):
|
|
||||||
# Replace the complex schema with the first option
|
|
||||||
result = {k: v for k, v in schema.items() if k not in [key, "allOf", "not"]}
|
|
||||||
result.update(first_option)
|
|
||||||
return _normalize_property_schema(result)
|
|
||||||
|
|
||||||
# Handle allOf by merging all subschemas
|
|
||||||
if "allOf" in schema:
|
|
||||||
subschemas = schema["allOf"]
|
|
||||||
if isinstance(subschemas, list):
|
|
||||||
merged = {}
|
|
||||||
for sub in subschemas:
|
|
||||||
if isinstance(sub, dict):
|
|
||||||
merged.update(sub)
|
|
||||||
# Remove allOf and merge with other properties
|
|
||||||
result = {k: v for k, v in schema.items() if k != "allOf"}
|
|
||||||
result.update(merged)
|
|
||||||
return _normalize_property_schema(result)
|
|
||||||
|
|
||||||
# Ensure type is simple
|
|
||||||
if "type" not in schema:
|
|
||||||
# Try to infer type from other properties
|
|
||||||
if "enum" in schema:
|
|
||||||
schema["type"] = "string"
|
|
||||||
elif "properties" in schema:
|
|
||||||
schema["type"] = "object"
|
|
||||||
elif "items" in schema:
|
|
||||||
schema["type"] = "array"
|
|
||||||
else:
|
|
||||||
schema["type"] = "string"
|
|
||||||
|
|
||||||
# Clean up not/const
|
|
||||||
schema.pop("not", None)
|
|
||||||
schema.pop("const", None)
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolWrapper(Tool):
|
class MCPToolWrapper(Tool):
|
||||||
|
|||||||
@@ -84,6 +84,69 @@ def _make_wrapper(session: object, *, timeout: float = 0.1) -> MCPToolWrapper:
|
|||||||
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
return MCPToolWrapper(session, "test", tool_def, tool_timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapper_preserves_non_nullable_unions() -> None:
|
||||||
|
tool_def = SimpleNamespace(
|
||||||
|
name="demo",
|
||||||
|
description="demo tool",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"value": {
|
||||||
|
"anyOf": [{"type": "string"}, {"type": "integer"}],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||||
|
|
||||||
|
assert wrapper.parameters["properties"]["value"]["anyOf"] == [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "integer"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapper_normalizes_nullable_property_type_union() -> None:
|
||||||
|
tool_def = SimpleNamespace(
|
||||||
|
name="demo",
|
||||||
|
description="demo tool",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": ["string", "null"]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||||
|
|
||||||
|
assert wrapper.parameters["properties"]["name"] == {"type": "string", "nullable": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrapper_normalizes_nullable_property_anyof() -> None:
|
||||||
|
tool_def = SimpleNamespace(
|
||||||
|
name="demo",
|
||||||
|
description="demo tool",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||||
|
"description": "optional name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "test", tool_def)
|
||||||
|
|
||||||
|
assert wrapper.parameters["properties"]["name"] == {
|
||||||
|
"type": "string",
|
||||||
|
"description": "optional name",
|
||||||
|
"nullable": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_returns_text_blocks() -> None:
|
async def test_execute_returns_text_blocks() -> None:
|
||||||
async def call_tool(_name: str, arguments: dict) -> object:
|
async def call_tool(_name: str, arguments: dict) -> object:
|
||||||
|
|||||||
@@ -455,6 +455,18 @@ def test_validate_nullable_param_accepts_none() -> None:
|
|||||||
assert errors == []
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_nullable_flag_accepts_none() -> None:
|
||||||
|
"""OpenAI-normalized nullable params should still accept None locally."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"type": "string", "nullable": True}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
errors = tool.validate_params({"name": None})
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
def test_cast_nullable_param_no_crash() -> None:
|
def test_cast_nullable_param_no_crash() -> None:
|
||||||
"""cast_params should not crash on nullable type (the original bug)."""
|
"""cast_params should not crash on nullable type (the original bug)."""
|
||||||
tool = CastTestTool(
|
tool = CastTestTool(
|
||||||
|
|||||||
Reference in New Issue
Block a user