Merge PR #1610: auto cast tool params to match schema
This commit is contained in:
@@ -52,6 +52,75 @@ class Tool(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Apply safe schema-driven casts before validation."""
|
||||||
|
schema = self.parameters or {}
|
||||||
|
if schema.get("type", "object") != "object":
|
||||||
|
return params
|
||||||
|
|
||||||
|
return self._cast_object(params, schema)
|
||||||
|
|
||||||
|
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Cast an object (dict) according to schema."""
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for key, value in obj.items():
|
||||||
|
if key in props:
|
||||||
|
result[key] = self._cast_value(value, props[key])
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||||
|
"""Cast a single value according to schema."""
|
||||||
|
target_type = schema.get("type")
|
||||||
|
|
||||||
|
if target_type == "boolean" and isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||||
|
return val
|
||||||
|
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||||
|
expected = self._TYPE_MAP[target_type]
|
||||||
|
if isinstance(val, expected):
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "integer" and isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return int(val)
|
||||||
|
except ValueError:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "number" and isinstance(val, str):
|
||||||
|
try:
|
||||||
|
return float(val)
|
||||||
|
except ValueError:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "string":
|
||||||
|
return val if val is None else str(val)
|
||||||
|
|
||||||
|
if target_type == "boolean" and isinstance(val, str):
|
||||||
|
val_lower = val.lower()
|
||||||
|
if val_lower in ("true", "1", "yes"):
|
||||||
|
return True
|
||||||
|
if val_lower in ("false", "0", "no"):
|
||||||
|
return False
|
||||||
|
return val
|
||||||
|
|
||||||
|
if target_type == "array" and isinstance(val, list):
|
||||||
|
item_schema = schema.get("items")
|
||||||
|
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||||
|
|
||||||
|
if target_type == "object" and isinstance(val, dict):
|
||||||
|
return self._cast_object(val, schema)
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||||
if not isinstance(params, dict):
|
if not isinstance(params, dict):
|
||||||
@@ -63,7 +132,13 @@ class Tool(ABC):
|
|||||||
|
|
||||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
t, label = schema.get("type"), path or "parameter"
|
t, label = schema.get("type"), path or "parameter"
|
||||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||||
|
return [f"{label} should be integer"]
|
||||||
|
if t == "number" and (
|
||||||
|
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||||
|
):
|
||||||
|
return [f"{label} should be number"]
|
||||||
|
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||||
return [f"{label} should be {t}"]
|
return [f"{label} should be {t}"]
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ class ToolRegistry:
|
|||||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Attempt to cast parameters to match schema types
|
||||||
|
params = tool.cast_params(params)
|
||||||
|
|
||||||
|
# Validate parameters
|
||||||
errors = tool.validate_params(params)
|
errors = tool.validate_params(params)
|
||||||
if errors:
|
if errors:
|
||||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||||
|
|||||||
@@ -106,3 +106,234 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
|||||||
paths = ExecTool._extract_absolute_paths(cmd)
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
assert "/tmp/data.txt" in paths
|
assert "/tmp/data.txt" in paths
|
||||||
assert "/tmp/out.txt" in paths
|
assert "/tmp/out.txt" in paths
|
||||||
|
|
||||||
|
|
||||||
|
# --- cast_params tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class CastTestTool(Tool):
|
||||||
|
"""Minimal tool for testing cast_params."""
|
||||||
|
|
||||||
|
def __init__(self, schema: dict[str, Any]) -> None:
|
||||||
|
self._schema = schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "cast_test"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "test tool for casting"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return self._schema
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_int() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": "42"})
|
||||||
|
assert result["count"] == 42
|
||||||
|
assert isinstance(result["count"], int)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_number() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"rate": "3.14"})
|
||||||
|
assert result["rate"] == 3.14
|
||||||
|
assert isinstance(result["rate"], float)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_string_to_bool() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"enabled": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||||
|
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||||
|
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_array_items() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||||
|
assert result["nums"] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_nested_object() -> None:
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"config": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"port": {"type": "integer"},
|
||||||
|
"debug": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||||
|
assert result["config"]["port"] == 8080
|
||||||
|
assert result["config"]["debug"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||||
|
"""Booleans should not be silently cast to integers."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": True})
|
||||||
|
assert result["count"] is True
|
||||||
|
errors = tool.validate_params(result)
|
||||||
|
assert any("count should be integer" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_preserves_empty_string() -> None:
|
||||||
|
"""Empty strings should be preserved for string type."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"type": "string"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"name": ""})
|
||||||
|
assert result["name"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_string_false() -> None:
|
||||||
|
"""Test that 'false', '0', 'no' strings convert to False."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"flag": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert tool.cast_params({"flag": "false"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "False"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "0"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "no"})["flag"] is False
|
||||||
|
assert tool.cast_params({"flag": "NO"})["flag"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_bool_string_invalid() -> None:
|
||||||
|
"""Invalid boolean strings should not be cast."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"flag": {"type": "boolean"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Invalid strings should be preserved (validation will catch them)
|
||||||
|
result = tool.cast_params({"flag": "random"})
|
||||||
|
assert result["flag"] == "random"
|
||||||
|
result = tool.cast_params({"flag": "maybe"})
|
||||||
|
assert result["flag"] == "maybe"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_invalid_string_to_int() -> None:
|
||||||
|
"""Invalid strings should not be cast to integer."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"count": {"type": "integer"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"count": "abc"})
|
||||||
|
assert result["count"] == "abc" # Original value preserved
|
||||||
|
result = tool.cast_params({"count": "12.5.7"})
|
||||||
|
assert result["count"] == "12.5.7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_invalid_string_to_number() -> None:
|
||||||
|
"""Invalid strings should not be cast to number."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params({"rate": "not_a_number"})
|
||||||
|
assert result["rate"] == "not_a_number"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||||
|
"""Booleans should not pass number validation."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
errors = tool.validate_params({"rate": False})
|
||||||
|
assert any("rate should be number" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_none_values() -> None:
|
||||||
|
"""Test None handling for different types."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"count": {"type": "integer"},
|
||||||
|
"items": {"type": "array"},
|
||||||
|
"config": {"type": "object"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
result = tool.cast_params(
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"count": None,
|
||||||
|
"items": None,
|
||||||
|
"config": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# None should be preserved for all types
|
||||||
|
assert result["name"] is None
|
||||||
|
assert result["count"] is None
|
||||||
|
assert result["items"] is None
|
||||||
|
assert result["config"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
||||||
|
"""Single values should NOT be automatically wrapped into arrays."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"items": {"type": "array"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Non-array values should be preserved (validation will catch them)
|
||||||
|
result = tool.cast_params({"items": 5})
|
||||||
|
assert result["items"] == 5 # Not wrapped to [5]
|
||||||
|
result = tool.cast_params({"items": "text"})
|
||||||
|
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||||
|
|||||||
Reference in New Issue
Block a user