MCP servers (e.g. Zapier) return JSON Schema union types like `"type": ["string", "null"]` for nullable parameters. The existing `validate_params()` and `cast_params()` methods expected only simple strings as `type`, causing `TypeError: unhashable type: 'list'` on every MCP tool call with nullable parameters. Add `_resolve_type()` helper that extracts the first non-null type from union types, and use it in `_cast_value()` and `_validate()`. Also handle `None` values correctly when the schema declares a nullable type. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
200 lines
7.0 KiB
Python
200 lines
7.0 KiB
Python
"""Base class for agent tools."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any
|
|
|
|
|
|
class Tool(ABC):
|
|
"""
|
|
Abstract base class for agent tools.
|
|
|
|
Tools are capabilities that the agent can use to interact with
|
|
the environment, such as reading files, executing commands, etc.
|
|
"""
|
|
|
|
_TYPE_MAP = {
|
|
"string": str,
|
|
"integer": int,
|
|
"number": (int, float),
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
}
|
|
|
|
@staticmethod
|
|
def _resolve_type(t: Any) -> str | None:
|
|
"""Resolve JSON Schema type to a simple string.
|
|
|
|
JSON Schema allows ``"type": ["string", "null"]`` (union types).
|
|
We extract the first non-null type so validation/casting works.
|
|
"""
|
|
if isinstance(t, list):
|
|
for item in t:
|
|
if item != "null":
|
|
return item
|
|
return None
|
|
return t
|
|
|
|
@property
|
|
@abstractmethod
|
|
def name(self) -> str:
|
|
"""Tool name used in function calls."""
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def description(self) -> str:
|
|
"""Description of what the tool does."""
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def parameters(self) -> dict[str, Any]:
|
|
"""JSON Schema for tool parameters."""
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def execute(self, **kwargs: Any) -> str:
|
|
"""
|
|
Execute the tool with given parameters.
|
|
|
|
Args:
|
|
**kwargs: Tool-specific parameters.
|
|
|
|
Returns:
|
|
String result of the tool execution.
|
|
"""
|
|
pass
|
|
|
|
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
|
"""Apply safe schema-driven casts before validation."""
|
|
schema = self.parameters or {}
|
|
if schema.get("type", "object") != "object":
|
|
return params
|
|
|
|
return self._cast_object(params, schema)
|
|
|
|
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
|
"""Cast an object (dict) according to schema."""
|
|
if not isinstance(obj, dict):
|
|
return obj
|
|
|
|
props = schema.get("properties", {})
|
|
result = {}
|
|
|
|
for key, value in obj.items():
|
|
if key in props:
|
|
result[key] = self._cast_value(value, props[key])
|
|
else:
|
|
result[key] = value
|
|
|
|
return result
|
|
|
|
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
|
"""Cast a single value according to schema."""
|
|
target_type = self._resolve_type(schema.get("type"))
|
|
|
|
if target_type == "boolean" and isinstance(val, bool):
|
|
return val
|
|
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
|
return val
|
|
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
|
expected = self._TYPE_MAP[target_type]
|
|
if isinstance(val, expected):
|
|
return val
|
|
|
|
if target_type == "integer" and isinstance(val, str):
|
|
try:
|
|
return int(val)
|
|
except ValueError:
|
|
return val
|
|
|
|
if target_type == "number" and isinstance(val, str):
|
|
try:
|
|
return float(val)
|
|
except ValueError:
|
|
return val
|
|
|
|
if target_type == "string":
|
|
return val if val is None else str(val)
|
|
|
|
if target_type == "boolean" and isinstance(val, str):
|
|
val_lower = val.lower()
|
|
if val_lower in ("true", "1", "yes"):
|
|
return True
|
|
if val_lower in ("false", "0", "no"):
|
|
return False
|
|
return val
|
|
|
|
if target_type == "array" and isinstance(val, list):
|
|
item_schema = schema.get("items")
|
|
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
|
|
|
if target_type == "object" and isinstance(val, dict):
|
|
return self._cast_object(val, schema)
|
|
|
|
return val
|
|
|
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
|
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
|
if not isinstance(params, dict):
|
|
return [f"parameters must be an object, got {type(params).__name__}"]
|
|
schema = self.parameters or {}
|
|
if schema.get("type", "object") != "object":
|
|
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
|
|
return self._validate(params, {**schema, "type": "object"}, "")
|
|
|
|
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
|
raw_type = schema.get("type")
|
|
nullable = isinstance(raw_type, list) and "null" in raw_type
|
|
t, label = self._resolve_type(raw_type), path or "parameter"
|
|
if nullable and val is None:
|
|
return []
|
|
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
|
return [f"{label} should be integer"]
|
|
if t == "number" and (
|
|
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
|
):
|
|
return [f"{label} should be number"]
|
|
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
|
return [f"{label} should be {t}"]
|
|
|
|
errors = []
|
|
if "enum" in schema and val not in schema["enum"]:
|
|
errors.append(f"{label} must be one of {schema['enum']}")
|
|
if t in ("integer", "number"):
|
|
if "minimum" in schema and val < schema["minimum"]:
|
|
errors.append(f"{label} must be >= {schema['minimum']}")
|
|
if "maximum" in schema and val > schema["maximum"]:
|
|
errors.append(f"{label} must be <= {schema['maximum']}")
|
|
if t == "string":
|
|
if "minLength" in schema and len(val) < schema["minLength"]:
|
|
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
|
if "maxLength" in schema and len(val) > schema["maxLength"]:
|
|
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
|
if t == "object":
|
|
props = schema.get("properties", {})
|
|
for k in schema.get("required", []):
|
|
if k not in val:
|
|
errors.append(f"missing required {path + '.' + k if path else k}")
|
|
for k, v in val.items():
|
|
if k in props:
|
|
errors.extend(self._validate(v, props[k], path + "." + k if path else k))
|
|
if t == "array" and "items" in schema:
|
|
for i, item in enumerate(val):
|
|
errors.extend(
|
|
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
|
|
)
|
|
return errors
|
|
|
|
def to_schema(self) -> dict[str, Any]:
|
|
"""Convert tool to OpenAI function schema format."""
|
|
return {
|
|
"type": "function",
|
|
"function": {
|
|
"name": self.name,
|
|
"description": self.description,
|
|
"parameters": self.parameters,
|
|
},
|
|
}
|