diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 06f5bdd..b9bafe7 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -21,6 +21,20 @@ class Tool(ABC): "object": dict, } + @staticmethod + def _resolve_type(t: Any) -> str | None: + """Resolve JSON Schema type to a simple string. + + JSON Schema allows ``"type": ["string", "null"]`` (union types). + We extract the first non-null type so validation/casting works. + """ + if isinstance(t, list): + for item in t: + if item != "null": + return item + return None + return t + @property @abstractmethod def name(self) -> str: @@ -78,7 +92,7 @@ class Tool(ABC): def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any: """Cast a single value according to schema.""" - target_type = schema.get("type") + target_type = self._resolve_type(schema.get("type")) if target_type == "boolean" and isinstance(val, bool): return val @@ -131,7 +145,11 @@ class Tool(ABC): return self._validate(params, {**schema, "type": "object"}, "") def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]: - t, label = schema.get("type"), path or "parameter" + raw_type = schema.get("type") + nullable = isinstance(raw_type, list) and "null" in raw_type + t, label = self._resolve_type(raw_type), path or "parameter" + if nullable and val is None: + return [] if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)): return [f"{label} should be integer"] if t == "number" and ( diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index 1d822b3..e817f37 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -406,3 +406,64 @@ async def test_exec_timeout_capped_at_max() -> None: # Should not raise — just clamp to 600 result = await tool.execute(command="echo ok", timeout=9999) assert "Exit code: 0" in result + + +# --- _resolve_type and nullable param tests --- + + +def test_resolve_type_simple_string() -> None: + """Simple string type passes through unchanged.""" + assert Tool._resolve_type("string") == "string" + + +def test_resolve_type_union_with_null() -> None: + """Union type ['string', 'null'] resolves to 'string'.""" + assert Tool._resolve_type(["string", "null"]) == "string" + + +def test_resolve_type_only_null() -> None: + """Union type ['null'] resolves to None (no non-null type).""" + assert Tool._resolve_type(["null"]) is None + + +def test_resolve_type_none_input() -> None: + """None input passes through as None.""" + assert Tool._resolve_type(None) is None + + +def test_validate_nullable_param_accepts_string() -> None: + """Nullable string param should accept a string value.""" + tool = CastTestTool( + { + "type": "object", + "properties": {"name": {"type": ["string", "null"]}}, + } + ) + errors = tool.validate_params({"name": "hello"}) + assert errors == [] + + +def test_validate_nullable_param_accepts_none() -> None: + """Nullable string param should accept None.""" + tool = CastTestTool( + { + "type": "object", + "properties": {"name": {"type": ["string", "null"]}}, + } + ) + errors = tool.validate_params({"name": None}) + assert errors == [] + + +def test_cast_nullable_param_no_crash() -> None: + """cast_params should not crash on nullable type (the original bug).""" + tool = CastTestTool( + { + "type": "object", + "properties": {"name": {"type": ["string", "null"]}}, + } + ) + result = tool.cast_params({"name": "hello"}) + assert result["name"] == "hello" + result = tool.cast_params({"name": None}) + assert result["name"] is None