fix(tools): narrow parameter auto-casting

This commit is contained in:
Re-bin
2026-03-07 05:28:12 +00:00
parent 67e6d9639c
commit c3f2d1b01d
2 changed files with 48 additions and 119 deletions

View File

@@ -3,8 +3,6 @@
from abc import ABC, abstractmethod
from typing import Any
from loguru import logger
class Tool(ABC):
"""
@@ -55,11 +53,7 @@ class Tool(ABC):
pass
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
"""
Attempt to cast parameters to match schema types.
Returns modified params dict. If casting fails, returns original value
and logs a debug message, allowing validation to catch the error.
"""
"""Apply safe schema-driven casts before validation."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
return params
@@ -86,91 +80,44 @@ class Tool(ABC):
"""Cast a single value according to schema."""
target_type = schema.get("type")
# Already correct type
# Note: check bool before int since bool is subclass of int
if target_type == "boolean" and isinstance(val, bool):
return val
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
return val
# For array/object, don't early-return - we need to recurse into contents
if target_type in self._TYPE_MAP and target_type not in (
"boolean",
"integer",
"array",
"object",
):
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
expected = self._TYPE_MAP[target_type]
if isinstance(val, expected):
return val
# Attempt casting
try:
if target_type == "integer":
if isinstance(val, bool):
# Don't silently convert bool to int
raise ValueError("Cannot cast bool to integer")
if isinstance(val, str):
return int(val)
if isinstance(val, (int, float)):
return int(val)
if target_type == "integer" and isinstance(val, str):
try:
return int(val)
except ValueError:
return val
elif target_type == "number":
if isinstance(val, bool):
# Don't silently convert bool to number
raise ValueError("Cannot cast bool to number")
if isinstance(val, str):
return float(val)
if isinstance(val, (int, float)):
return float(val)
if target_type == "number" and isinstance(val, str):
try:
return float(val)
except ValueError:
return val
elif target_type == "string":
# Preserve None vs empty string distinction
if val is None:
return val
return str(val)
if target_type == "string":
return val if val is None else str(val)
elif target_type == "boolean":
if isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
elif val_lower in ("false", "0", "no"):
return False
# For other strings, raise error to let validation handle it
raise ValueError(f"Cannot convert string '{val}' to boolean")
return bool(val)
if target_type == "boolean" and isinstance(val, str):
val_lower = val.lower()
if val_lower in ("true", "1", "yes"):
return True
if val_lower in ("false", "0", "no"):
return False
return val
elif target_type == "array":
if isinstance(val, list):
# Recursively cast array items if schema defines items
if "items" in schema:
return [self._cast_value(item, schema["items"]) for item in val]
return val
# Preserve None vs empty array distinction
if val is None:
return val
# Empty string → empty array
if val == "":
return []
# Don't auto-wrap single values, let validation catch the error
raise ValueError(f"Cannot convert {type(val).__name__} to array")
if target_type == "array" and isinstance(val, list):
item_schema = schema.get("items")
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
elif target_type == "object":
if isinstance(val, dict):
return self._cast_object(val, schema)
# Preserve None vs empty object distinction
if val is None:
return val
# Empty string → empty object
if val == "":
return {}
# Cannot cast to object
raise ValueError(f"Cannot cast {type(val).__name__} to object")
except (ValueError, TypeError) as e:
# Log failed casts for debugging, return original value
# Let validation catch the error
logger.debug("Failed to cast value %r to %s: %s", val, target_type, e)
if target_type == "object" and isinstance(val, dict):
return self._cast_object(val, schema)
return val
@@ -185,7 +132,13 @@ class Tool(ABC):
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
return [f"{label} should be integer"]
if t == "number" and (
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
):
return [f"{label} should be number"]
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"]
errors = []