feat: auto casting tool params to match schema type
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
@@ -52,6 +54,118 @@ class Tool(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Attempt to cast parameters to match schema types.
|
||||
Returns modified params dict. Raises ValueError if casting is impossible.
|
||||
"""
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
return params
|
||||
|
||||
return self._cast_object(params, schema)
|
||||
|
||||
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cast an object (dict) according to schema."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for key, value in obj.items():
|
||||
if key in props:
|
||||
result[key] = self._cast_value(value, props[key])
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||
"""Cast a single value according to schema."""
|
||||
target_type = schema.get("type")
|
||||
|
||||
# Already correct type
|
||||
# Note: check bool before int since bool is subclass of int
|
||||
if target_type == "boolean" and isinstance(val, bool):
|
||||
return val
|
||||
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||
return val
|
||||
# For array/object, don't early-return - we need to recurse into contents
|
||||
if target_type in self._TYPE_MAP and target_type not in (
|
||||
"boolean",
|
||||
"integer",
|
||||
"array",
|
||||
"object",
|
||||
):
|
||||
expected = self._TYPE_MAP[target_type]
|
||||
if isinstance(val, expected):
|
||||
return val
|
||||
|
||||
# Attempt casting
|
||||
try:
|
||||
if target_type == "integer":
|
||||
if isinstance(val, bool):
|
||||
# Don't silently convert bool to int
|
||||
raise ValueError(f"Cannot cast bool to integer")
|
||||
if isinstance(val, str):
|
||||
return int(val)
|
||||
if isinstance(val, (int, float)):
|
||||
return int(val)
|
||||
|
||||
elif target_type == "number":
|
||||
if isinstance(val, bool):
|
||||
# Don't silently convert bool to number
|
||||
raise ValueError(f"Cannot cast bool to number")
|
||||
if isinstance(val, str):
|
||||
return float(val)
|
||||
if isinstance(val, (int, float)):
|
||||
return float(val)
|
||||
|
||||
elif target_type == "string":
|
||||
# Preserve None vs empty string distinction
|
||||
if val is None:
|
||||
return val
|
||||
return str(val)
|
||||
|
||||
elif target_type == "boolean":
|
||||
if isinstance(val, str):
|
||||
return val.lower() in ("true", "1", "yes")
|
||||
return bool(val)
|
||||
|
||||
elif target_type == "array":
|
||||
if isinstance(val, list):
|
||||
# Recursively cast array items if schema defines items
|
||||
if "items" in schema:
|
||||
return [self._cast_value(item, schema["items"]) for item in val]
|
||||
return val
|
||||
# Preserve None vs empty array distinction
|
||||
if val is None:
|
||||
return val
|
||||
# Try to convert single value to array
|
||||
if val == "":
|
||||
return []
|
||||
return [val]
|
||||
|
||||
elif target_type == "object":
|
||||
if isinstance(val, dict):
|
||||
return self._cast_object(val, schema)
|
||||
# Preserve None vs empty object distinction
|
||||
if val is None:
|
||||
return val
|
||||
# Empty string → empty object
|
||||
if val == "":
|
||||
return {}
|
||||
# Cannot cast to object
|
||||
raise ValueError(f"Cannot cast {type(val).__name__} to object")
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
# Log failed casts for debugging, return original value
|
||||
# Let validation catch the error
|
||||
logger.debug("Failed to cast value %r to %s: %s", val, target_type, e)
|
||||
|
||||
return val
|
||||
|
||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||
if not isinstance(params, dict):
|
||||
|
||||
@@ -44,6 +44,10 @@ class ToolRegistry:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
|
||||
@@ -106,3 +106,122 @@ def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert "/tmp/data.txt" in paths
|
||||
assert "/tmp/out.txt" in paths
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
class CastTestTool(Tool):
|
||||
"""Minimal tool for testing cast_params."""
|
||||
|
||||
def __init__(self, schema: dict[str, Any]) -> None:
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cast_test"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "test tool for casting"
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_cast_params_string_to_int() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"count": "42"})
|
||||
assert result["count"] == 42
|
||||
assert isinstance(result["count"], int)
|
||||
|
||||
|
||||
def test_cast_params_string_to_number() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"rate": {"type": "number"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"rate": "3.14"})
|
||||
assert result["rate"] == 3.14
|
||||
assert isinstance(result["rate"], float)
|
||||
|
||||
|
||||
def test_cast_params_string_to_bool() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"enabled": {"type": "boolean"}},
|
||||
}
|
||||
)
|
||||
assert tool.cast_params({"enabled": "true"})["enabled"] is True
|
||||
assert tool.cast_params({"enabled": "false"})["enabled"] is False
|
||||
assert tool.cast_params({"enabled": "1"})["enabled"] is True
|
||||
|
||||
|
||||
def test_cast_params_array_items() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nums": {"type": "array", "items": {"type": "integer"}},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"nums": ["1", "2", "3"]})
|
||||
assert result["nums"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_cast_params_nested_object() -> None:
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"port": {"type": "integer"},
|
||||
"debug": {"type": "boolean"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"config": {"port": "8080", "debug": "true"}})
|
||||
assert result["config"]["port"] == 8080
|
||||
assert result["config"]["debug"] is True
|
||||
|
||||
|
||||
def test_cast_params_bool_not_cast_to_int() -> None:
|
||||
"""Booleans should not be silently cast to integers."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"count": {"type": "integer"}},
|
||||
}
|
||||
)
|
||||
# Bool input should remain bool (validation will catch it)
|
||||
result = tool.cast_params({"count": True})
|
||||
assert result["count"] is True # Not cast to 1
|
||||
|
||||
|
||||
def test_cast_params_preserves_empty_string() -> None:
|
||||
"""Empty strings should be preserved for string type."""
|
||||
tool = CastTestTool(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
)
|
||||
result = tool.cast_params({"name": ""})
|
||||
assert result["name"] == ""
|
||||
|
||||
Reference in New Issue
Block a user