fix(tools): narrow parameter auto-casting
This commit is contained in:
@@ -3,8 +3,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
|
|
||||||
class Tool(ABC):
|
class Tool(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -55,11 +53,7 @@ class Tool(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""Apply safe schema-driven casts before validation."""
|
||||||
Attempt to cast parameters to match schema types.
|
|
||||||
Returns modified params dict. If casting fails, returns original value
|
|
||||||
and logs a debug message, allowing validation to catch the error.
|
|
||||||
"""
|
|
||||||
schema = self.parameters or {}
|
schema = self.parameters or {}
|
||||||
if schema.get("type", "object") != "object":
|
if schema.get("type", "object") != "object":
|
||||||
return params
|
return params
|
||||||
@@ -86,91 +80,44 @@ class Tool(ABC):
|
|||||||
"""Cast a single value according to schema."""
|
"""Cast a single value according to schema."""
|
||||||
target_type = schema.get("type")
|
target_type = schema.get("type")
|
||||||
|
|
||||||
# Already correct type
|
|
||||||
# Note: check bool before int since bool is subclass of int
|
|
||||||
if target_type == "boolean" and isinstance(val, bool):
|
if target_type == "boolean" and isinstance(val, bool):
|
||||||
return val
|
return val
|
||||||
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||||
return val
|
return val
|
||||||
# For array/object, don't early-return - we need to recurse into contents
|
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||||
if target_type in self._TYPE_MAP and target_type not in (
|
|
||||||
"boolean",
|
|
||||||
"integer",
|
|
||||||
"array",
|
|
||||||
"object",
|
|
||||||
):
|
|
||||||
expected = self._TYPE_MAP[target_type]
|
expected = self._TYPE_MAP[target_type]
|
||||||
if isinstance(val, expected):
|
if isinstance(val, expected):
|
||||||
return val
|
return val
|
||||||
|
|
||||||
# Attempt casting
|
if target_type == "integer" and isinstance(val, str):
|
||||||
try:
|
try:
|
||||||
if target_type == "integer":
|
return int(val)
|
||||||
if isinstance(val, bool):
|
except ValueError:
|
||||||
# Don't silently convert bool to int
|
return val
|
||||||
raise ValueError("Cannot cast bool to integer")
|
|
||||||
if isinstance(val, str):
|
|
||||||
return int(val)
|
|
||||||
if isinstance(val, (int, float)):
|
|
||||||
return int(val)
|
|
||||||
|
|
||||||
elif target_type == "number":
|
if target_type == "number" and isinstance(val, str):
|
||||||
if isinstance(val, bool):
|
try:
|
||||||
# Don't silently convert bool to number
|
return float(val)
|
||||||
raise ValueError("Cannot cast bool to number")
|
except ValueError:
|
||||||
if isinstance(val, str):
|
return val
|
||||||
return float(val)
|
|
||||||
if isinstance(val, (int, float)):
|
|
||||||
return float(val)
|
|
||||||
|
|
||||||
elif target_type == "string":
|
if target_type == "string":
|
||||||
# Preserve None vs empty string distinction
|
return val if val is None else str(val)
|
||||||
if val is None:
|
|
||||||
return val
|
|
||||||
return str(val)
|
|
||||||
|
|
||||||
elif target_type == "boolean":
|
if target_type == "boolean" and isinstance(val, str):
|
||||||
if isinstance(val, str):
|
val_lower = val.lower()
|
||||||
val_lower = val.lower()
|
if val_lower in ("true", "1", "yes"):
|
||||||
if val_lower in ("true", "1", "yes"):
|
return True
|
||||||
return True
|
if val_lower in ("false", "0", "no"):
|
||||||
elif val_lower in ("false", "0", "no"):
|
return False
|
||||||
return False
|
return val
|
||||||
# For other strings, raise error to let validation handle it
|
|
||||||
raise ValueError(f"Cannot convert string '{val}' to boolean")
|
|
||||||
return bool(val)
|
|
||||||
|
|
||||||
elif target_type == "array":
|
if target_type == "array" and isinstance(val, list):
|
||||||
if isinstance(val, list):
|
item_schema = schema.get("items")
|
||||||
# Recursively cast array items if schema defines items
|
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||||
if "items" in schema:
|
|
||||||
return [self._cast_value(item, schema["items"]) for item in val]
|
|
||||||
return val
|
|
||||||
# Preserve None vs empty array distinction
|
|
||||||
if val is None:
|
|
||||||
return val
|
|
||||||
# Empty string → empty array
|
|
||||||
if val == "":
|
|
||||||
return []
|
|
||||||
# Don't auto-wrap single values, let validation catch the error
|
|
||||||
raise ValueError(f"Cannot convert {type(val).__name__} to array")
|
|
||||||
|
|
||||||
elif target_type == "object":
|
if target_type == "object" and isinstance(val, dict):
|
||||||
if isinstance(val, dict):
|
return self._cast_object(val, schema)
|
||||||
return self._cast_object(val, schema)
|
|
||||||
# Preserve None vs empty object distinction
|
|
||||||
if val is None:
|
|
||||||
return val
|
|
||||||
# Empty string → empty object
|
|
||||||
if val == "":
|
|
||||||
return {}
|
|
||||||
# Cannot cast to object
|
|
||||||
raise ValueError(f"Cannot cast {type(val).__name__} to object")
|
|
||||||
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
# Log failed casts for debugging, return original value
|
|
||||||
# Let validation catch the error
|
|
||||||
logger.debug("Failed to cast value %r to %s: %s", val, target_type, e)
|
|
||||||
|
|
||||||
return val
|
return val
|
||||||
|
|
||||||
@@ -185,7 +132,13 @@ class Tool(ABC):
|
|||||||
|
|
||||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
t, label = schema.get("type"), path or "parameter"
|
t, label = schema.get("type"), path or "parameter"
|
||||||
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
|
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||||
|
return [f"{label} should be integer"]
|
||||||
|
if t == "number" and (
|
||||||
|
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||||
|
):
|
||||||
|
return [f"{label} should be number"]
|
||||||
|
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||||
return [f"{label} should be {t}"]
|
return [f"{label} should be {t}"]
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|||||||
@@ -210,9 +210,10 @@ def test_cast_params_bool_not_cast_to_int() -> None:
|
|||||||
"properties": {"count": {"type": "integer"}},
|
"properties": {"count": {"type": "integer"}},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# Bool input should remain bool (validation will catch it)
|
|
||||||
result = tool.cast_params({"count": True})
|
result = tool.cast_params({"count": True})
|
||||||
assert result["count"] is True # Not cast to 1
|
assert result["count"] is True
|
||||||
|
errors = tool.validate_params(result)
|
||||||
|
assert any("count should be integer" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
def test_cast_params_preserves_empty_string() -> None:
|
def test_cast_params_preserves_empty_string() -> None:
|
||||||
@@ -283,6 +284,18 @@ def test_cast_params_invalid_string_to_number() -> None:
|
|||||||
assert result["rate"] == "not_a_number"
|
assert result["rate"] == "not_a_number"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_bool_not_accepted_as_number() -> None:
|
||||||
|
"""Booleans should not pass number validation."""
|
||||||
|
tool = CastTestTool(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"rate": {"type": "number"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
errors = tool.validate_params({"rate": False})
|
||||||
|
assert any("rate should be number" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
def test_cast_params_none_values() -> None:
|
def test_cast_params_none_values() -> None:
|
||||||
"""Test None handling for different types."""
|
"""Test None handling for different types."""
|
||||||
tool = CastTestTool(
|
tool = CastTestTool(
|
||||||
@@ -324,40 +337,3 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
|
|||||||
assert result["items"] == 5 # Not wrapped to [5]
|
assert result["items"] == 5 # Not wrapped to [5]
|
||||||
result = tool.cast_params({"items": "text"})
|
result = tool.cast_params({"items": "text"})
|
||||||
assert result["items"] == "text" # Not wrapped to ["text"]
|
assert result["items"] == "text" # Not wrapped to ["text"]
|
||||||
|
|
||||||
|
|
||||||
def test_cast_params_empty_string_to_array() -> None:
|
|
||||||
"""Empty string should convert to empty array."""
|
|
||||||
tool = CastTestTool(
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"items": {"type": "array"}},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = tool.cast_params({"items": ""})
|
|
||||||
assert result["items"] == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_cast_params_empty_string_to_object() -> None:
|
|
||||||
"""Empty string should convert to empty object."""
|
|
||||||
tool = CastTestTool(
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"config": {"type": "object"}},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = tool.cast_params({"config": ""})
|
|
||||||
assert result["config"] == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_cast_params_float_to_int() -> None:
|
|
||||||
"""Float values should be cast to integers."""
|
|
||||||
tool = CastTestTool(
|
|
||||||
{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"count": {"type": "integer"}},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result = tool.cast_params({"count": 42.7})
|
|
||||||
assert result["count"] == 42
|
|
||||||
assert isinstance(result["count"], int)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user