diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 7b0b867..02c8331 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -1,4 +1,4 @@ -"""File system tools: read, write, edit.""" +"""File system tools: read, write, edit, list.""" import difflib from pathlib import Path @@ -23,62 +23,108 @@ def _resolve_path( return resolved -class ReadFileTool(Tool): - """Tool to read file contents.""" - - _MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context +class _FsTool(Tool): + """Shared base for filesystem tools — common init and path resolution.""" def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): self._workspace = workspace self._allowed_dir = allowed_dir + def _resolve(self, path: str) -> Path: + return _resolve_path(path, self._workspace, self._allowed_dir) + + +# --------------------------------------------------------------------------- +# read_file +# --------------------------------------------------------------------------- + +class ReadFileTool(_FsTool): + """Read file contents with optional line-based pagination.""" + + _MAX_CHARS = 128_000 + _DEFAULT_LIMIT = 2000 + @property def name(self) -> str: return "read_file" @property def description(self) -> str: - return "Read the contents of a file at the given path." + return ( + "Read the contents of a file. Returns numbered lines. " + "Use offset and limit to paginate through large files." + ) @property def parameters(self) -> dict[str, Any]: return { "type": "object", - "properties": {"path": {"type": "string", "description": "The file path to read"}}, + "properties": { + "path": {"type": "string", "description": "The file path to read"}, + "offset": { + "type": "integer", + "description": "Line number to start reading from (1-indexed, default 1)", + "minimum": 1, + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read (default 2000)", + "minimum": 1, + }, + }, "required": ["path"], } - async def execute(self, path: str, **kwargs: Any) -> str: + async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str: try: - file_path = _resolve_path(path, self._workspace, self._allowed_dir) - if not file_path.exists(): + fp = self._resolve(path) + if not fp.exists(): return f"Error: File not found: {path}" - if not file_path.is_file(): + if not fp.is_file(): return f"Error: Not a file: {path}" - size = file_path.stat().st_size - if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes) - return ( - f"Error: File too large ({size:,} bytes). " - f"Use exec tool with head/tail/grep to read portions." - ) + all_lines = fp.read_text(encoding="utf-8").splitlines() + total = len(all_lines) - content = file_path.read_text(encoding="utf-8") - if len(content) > self._MAX_CHARS: - return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})" - return content + if offset < 1: + offset = 1 + if total == 0: + return f"(Empty file: {path})" + if offset > total: + return f"Error: offset {offset} is beyond end of file ({total} lines)" + + start = offset - 1 + end = min(start + (limit or self._DEFAULT_LIMIT), total) + numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])] + result = "\n".join(numbered) + + if len(result) > self._MAX_CHARS: + trimmed, chars = [], 0 + for line in numbered: + chars += len(line) + 1 + if chars > self._MAX_CHARS: + break + trimmed.append(line) + end = start + len(trimmed) + result = "\n".join(trimmed) + + if end < total: + result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)" + else: + result += f"\n\n(End of file — {total} lines total)" + return result except PermissionError as e: return f"Error: {e}" except Exception as e: - return f"Error reading file: {str(e)}" + return f"Error reading file: {e}" -class WriteFileTool(Tool): - """Tool to write content to a file.""" +# --------------------------------------------------------------------------- +# write_file +# --------------------------------------------------------------------------- - def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): - self._workspace = workspace - self._allowed_dir = allowed_dir +class WriteFileTool(_FsTool): + """Write content to a file.""" @property def name(self) -> str: @@ -101,22 +147,48 @@ class WriteFileTool(Tool): async def execute(self, path: str, content: str, **kwargs: Any) -> str: try: - file_path = _resolve_path(path, self._workspace, self._allowed_dir) - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(content, encoding="utf-8") - return f"Successfully wrote {len(content)} bytes to {file_path}" + fp = self._resolve(path) + fp.parent.mkdir(parents=True, exist_ok=True) + fp.write_text(content, encoding="utf-8") + return f"Successfully wrote {len(content)} bytes to {fp}" except PermissionError as e: return f"Error: {e}" except Exception as e: - return f"Error writing file: {str(e)}" + return f"Error writing file: {e}" -class EditFileTool(Tool): - """Tool to edit a file by replacing text.""" +# --------------------------------------------------------------------------- +# edit_file +# --------------------------------------------------------------------------- - def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): - self._workspace = workspace - self._allowed_dir = allowed_dir +def _find_match(content: str, old_text: str) -> tuple[str | None, int]: + """Locate old_text in content: exact first, then line-trimmed sliding window. + + Both inputs should use LF line endings (caller normalises CRLF). + Returns (matched_fragment, count) or (None, 0). + """ + if old_text in content: + return old_text, content.count(old_text) + + old_lines = old_text.splitlines() + if not old_lines: + return None, 0 + stripped_old = [l.strip() for l in old_lines] + content_lines = content.splitlines() + + candidates = [] + for i in range(len(content_lines) - len(stripped_old) + 1): + window = content_lines[i : i + len(stripped_old)] + if [l.strip() for l in window] == stripped_old: + candidates.append("\n".join(window)) + + if candidates: + return candidates[0], len(candidates) + return None, 0 + + +class EditFileTool(_FsTool): + """Edit a file by replacing text with fallback matching.""" @property def name(self) -> str: @@ -124,7 +196,11 @@ class EditFileTool(Tool): @property def description(self) -> str: - return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file." + return ( + "Edit a file by replacing old_text with new_text. " + "Supports minor whitespace/line-ending differences. " + "Set replace_all=true to replace every occurrence." + ) @property def parameters(self) -> dict[str, Any]: @@ -132,40 +208,52 @@ class EditFileTool(Tool): "type": "object", "properties": { "path": {"type": "string", "description": "The file path to edit"}, - "old_text": {"type": "string", "description": "The exact text to find and replace"}, + "old_text": {"type": "string", "description": "The text to find and replace"}, "new_text": {"type": "string", "description": "The text to replace with"}, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences (default false)", + }, }, "required": ["path", "old_text", "new_text"], } - async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str: + async def execute( + self, path: str, old_text: str, new_text: str, + replace_all: bool = False, **kwargs: Any, + ) -> str: try: - file_path = _resolve_path(path, self._workspace, self._allowed_dir) - if not file_path.exists(): + fp = self._resolve(path) + if not fp.exists(): return f"Error: File not found: {path}" - content = file_path.read_text(encoding="utf-8") + raw = fp.read_bytes() + uses_crlf = b"\r\n" in raw + content = raw.decode("utf-8").replace("\r\n", "\n") + match, count = _find_match(content, old_text.replace("\r\n", "\n")) - if old_text not in content: - return self._not_found_message(old_text, content, path) + if match is None: + return self._not_found_msg(old_text, content, path) + if count > 1 and not replace_all: + return ( + f"Warning: old_text appears {count} times. " + "Provide more context to make it unique, or set replace_all=true." + ) - # Count occurrences - count = content.count(old_text) - if count > 1: - return f"Warning: old_text appears {count} times. Please provide more context to make it unique." + norm_new = new_text.replace("\r\n", "\n") + new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1) + if uses_crlf: + new_content = new_content.replace("\n", "\r\n") - new_content = content.replace(old_text, new_text, 1) - file_path.write_text(new_content, encoding="utf-8") - - return f"Successfully edited {file_path}" + fp.write_bytes(new_content.encode("utf-8")) + return f"Successfully edited {fp}" except PermissionError as e: return f"Error: {e}" except Exception as e: - return f"Error editing file: {str(e)}" + return f"Error editing file: {e}" @staticmethod - def _not_found_message(old_text: str, content: str, path: str) -> str: - """Build a helpful error when old_text is not found.""" + def _not_found_msg(old_text: str, content: str, path: str) -> str: lines = content.splitlines(keepends=True) old_lines = old_text.splitlines(keepends=True) window = len(old_lines) @@ -177,27 +265,29 @@ class EditFileTool(Tool): best_ratio, best_start = ratio, i if best_ratio > 0.5: - diff = "\n".join( - difflib.unified_diff( - old_lines, - lines[best_start : best_start + window], - fromfile="old_text (provided)", - tofile=f"{path} (actual, line {best_start + 1})", - lineterm="", - ) - ) + diff = "\n".join(difflib.unified_diff( + old_lines, lines[best_start : best_start + window], + fromfile="old_text (provided)", + tofile=f"{path} (actual, line {best_start + 1})", + lineterm="", + )) return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" - return ( - f"Error: old_text not found in {path}. No similar text found. Verify the file content." - ) + return f"Error: old_text not found in {path}. No similar text found. Verify the file content." -class ListDirTool(Tool): - """Tool to list directory contents.""" +# --------------------------------------------------------------------------- +# list_dir +# --------------------------------------------------------------------------- - def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): - self._workspace = workspace - self._allowed_dir = allowed_dir +class ListDirTool(_FsTool): + """List directory contents with optional recursion.""" + + _DEFAULT_MAX = 200 + _IGNORE_DIRS = { + ".git", "node_modules", "__pycache__", ".venv", "venv", + "dist", "build", ".tox", ".mypy_cache", ".pytest_cache", + ".ruff_cache", ".coverage", "htmlcov", + } @property def name(self) -> str: @@ -205,34 +295,71 @@ class ListDirTool(Tool): @property def description(self) -> str: - return "List the contents of a directory." + return ( + "List the contents of a directory. " + "Set recursive=true to explore nested structure. " + "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored." + ) @property def parameters(self) -> dict[str, Any]: return { "type": "object", - "properties": {"path": {"type": "string", "description": "The directory path to list"}}, + "properties": { + "path": {"type": "string", "description": "The directory path to list"}, + "recursive": { + "type": "boolean", + "description": "Recursively list all files (default false)", + }, + "max_entries": { + "type": "integer", + "description": "Maximum entries to return (default 200)", + "minimum": 1, + }, + }, "required": ["path"], } - async def execute(self, path: str, **kwargs: Any) -> str: + async def execute( + self, path: str, recursive: bool = False, + max_entries: int | None = None, **kwargs: Any, + ) -> str: try: - dir_path = _resolve_path(path, self._workspace, self._allowed_dir) - if not dir_path.exists(): + dp = self._resolve(path) + if not dp.exists(): return f"Error: Directory not found: {path}" - if not dir_path.is_dir(): + if not dp.is_dir(): return f"Error: Not a directory: {path}" - items = [] - for item in sorted(dir_path.iterdir()): - prefix = "📁 " if item.is_dir() else "📄 " - items.append(f"{prefix}{item.name}") + cap = max_entries or self._DEFAULT_MAX + items: list[str] = [] + total = 0 - if not items: + if recursive: + for item in sorted(dp.rglob("*")): + if any(p in self._IGNORE_DIRS for p in item.parts): + continue + total += 1 + if len(items) < cap: + rel = item.relative_to(dp) + items.append(f"{rel}/" if item.is_dir() else str(rel)) + else: + for item in sorted(dp.iterdir()): + if item.name in self._IGNORE_DIRS: + continue + total += 1 + if len(items) < cap: + pfx = "📁 " if item.is_dir() else "📄 " + items.append(f"{pfx}{item.name}") + + if not items and total == 0: return f"Directory {path} is empty" - return "\n".join(items) + result = "\n".join(items) + if total > cap: + result += f"\n\n(truncated, showing first {cap} of {total} entries)" + return result except PermissionError as e: return f"Error: {e}" except Exception as e: - return f"Error listing directory: {str(e)}" + return f"Error listing directory: {e}" diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index b650930..bf1b082 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -42,6 +42,9 @@ class ExecTool(Tool): def name(self) -> str: return "exec" + _MAX_TIMEOUT = 600 + _MAX_OUTPUT = 10_000 + @property def description(self) -> str: return "Execute a shell command and return its output. Use with caution." @@ -53,22 +56,36 @@ class ExecTool(Tool): "properties": { "command": { "type": "string", - "description": "The shell command to execute" + "description": "The shell command to execute", }, "working_dir": { "type": "string", - "description": "Optional working directory for the command" - } + "description": "Optional working directory for the command", + }, + "timeout": { + "type": "integer", + "description": ( + "Timeout in seconds. Increase for long-running commands " + "like compilation or installation (default 60, max 600)." + ), + "minimum": 1, + "maximum": 600, + }, }, - "required": ["command"] + "required": ["command"], } - - async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str: + + async def execute( + self, command: str, working_dir: str | None = None, + timeout: int | None = None, **kwargs: Any, + ) -> str: cwd = working_dir or self.working_dir or os.getcwd() guard_error = self._guard_command(command, cwd) if guard_error: return guard_error - + + effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT) + env = os.environ.copy() if self.path_append: env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append @@ -81,44 +98,46 @@ class ExecTool(Tool): cwd=cwd, env=env, ) - + try: stdout, stderr = await asyncio.wait_for( process.communicate(), - timeout=self.timeout + timeout=effective_timeout, ) except asyncio.TimeoutError: process.kill() - # Wait for the process to fully terminate so pipes are - # drained and file descriptors are released. try: await asyncio.wait_for(process.wait(), timeout=5.0) except asyncio.TimeoutError: pass - return f"Error: Command timed out after {self.timeout} seconds" - + return f"Error: Command timed out after {effective_timeout} seconds" + output_parts = [] - + if stdout: output_parts.append(stdout.decode("utf-8", errors="replace")) - + if stderr: stderr_text = stderr.decode("utf-8", errors="replace") if stderr_text.strip(): output_parts.append(f"STDERR:\n{stderr_text}") - - if process.returncode != 0: - output_parts.append(f"\nExit code: {process.returncode}") - + + output_parts.append(f"\nExit code: {process.returncode}") + result = "\n".join(output_parts) if output_parts else "(no output)" - - # Truncate very long output - max_len = 10000 + + # Head + tail truncation to preserve both start and end of output + max_len = self._MAX_OUTPUT if len(result) > max_len: - result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)" - + half = max_len // 2 + result = ( + result[:half] + + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" + + result[-half:] + ) + return result - + except Exception as e: return f"Error executing command: {str(e)}" diff --git a/tests/test_filesystem_tools.py b/tests/test_filesystem_tools.py new file mode 100644 index 0000000..db8f256 --- /dev/null +++ b/tests/test_filesystem_tools.py @@ -0,0 +1,251 @@ +"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool.""" + +import pytest + +from nanobot.agent.tools.filesystem import ( + EditFileTool, + ListDirTool, + ReadFileTool, + _find_match, +) + + +# --------------------------------------------------------------------------- +# ReadFileTool +# --------------------------------------------------------------------------- + +class TestReadFileTool: + + @pytest.fixture() + def tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def sample_file(self, tmp_path): + f = tmp_path / "sample.txt" + f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8") + return f + + @pytest.mark.asyncio + async def test_basic_read_has_line_numbers(self, tool, sample_file): + result = await tool.execute(path=str(sample_file)) + assert "1| line 1" in result + assert "20| line 20" in result + + @pytest.mark.asyncio + async def test_offset_and_limit(self, tool, sample_file): + result = await tool.execute(path=str(sample_file), offset=5, limit=3) + assert "5| line 5" in result + assert "7| line 7" in result + assert "8| line 8" not in result + assert "Use offset=8 to continue" in result + + @pytest.mark.asyncio + async def test_offset_beyond_end(self, tool, sample_file): + result = await tool.execute(path=str(sample_file), offset=999) + assert "Error" in result + assert "beyond end" in result + + @pytest.mark.asyncio + async def test_end_of_file_marker(self, tool, sample_file): + result = await tool.execute(path=str(sample_file), offset=1, limit=9999) + assert "End of file" in result + + @pytest.mark.asyncio + async def test_empty_file(self, tool, tmp_path): + f = tmp_path / "empty.txt" + f.write_text("", encoding="utf-8") + result = await tool.execute(path=str(f)) + assert "Empty file" in result + + @pytest.mark.asyncio + async def test_file_not_found(self, tool, tmp_path): + result = await tool.execute(path=str(tmp_path / "nope.txt")) + assert "Error" in result + assert "not found" in result + + @pytest.mark.asyncio + async def test_char_budget_trims(self, tool, tmp_path): + """When the selected slice exceeds _MAX_CHARS the output is trimmed.""" + f = tmp_path / "big.txt" + # Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit + f.write_text("\n".join("x" * 110 for _ in range(2000)), encoding="utf-8") + result = await tool.execute(path=str(f)) + assert len(result) <= ReadFileTool._MAX_CHARS + 500 # small margin for footer + assert "Use offset=" in result + + +# --------------------------------------------------------------------------- +# _find_match (unit tests for the helper) +# --------------------------------------------------------------------------- + +class TestFindMatch: + + def test_exact_match(self): + match, count = _find_match("hello world", "world") + assert match == "world" + assert count == 1 + + def test_exact_no_match(self): + match, count = _find_match("hello world", "xyz") + assert match is None + assert count == 0 + + def test_crlf_normalisation(self): + # Caller normalises CRLF before calling _find_match, so test with + # pre-normalised content to verify exact match still works. + content = "line1\nline2\nline3" + old_text = "line1\nline2\nline3" + match, count = _find_match(content, old_text) + assert match is not None + assert count == 1 + + def test_line_trim_fallback(self): + content = " def foo():\n pass\n" + old_text = "def foo():\n pass" + match, count = _find_match(content, old_text) + assert match is not None + assert count == 1 + # The returned match should be the *original* indented text + assert " def foo():" in match + + def test_line_trim_multiple_candidates(self): + content = " a\n b\n a\n b\n" + old_text = "a\nb" + match, count = _find_match(content, old_text) + assert count == 2 + + def test_empty_old_text(self): + match, count = _find_match("hello", "") + # Empty string is always "in" any string via exact match + assert match == "" + + +# --------------------------------------------------------------------------- +# EditFileTool +# --------------------------------------------------------------------------- + +class TestEditFileTool: + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_exact_match(self, tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="world", new_text="earth") + assert "Successfully" in result + assert f.read_text() == "hello earth" + + @pytest.mark.asyncio + async def test_crlf_normalisation(self, tool, tmp_path): + f = tmp_path / "crlf.py" + f.write_bytes(b"line1\r\nline2\r\nline3") + result = await tool.execute( + path=str(f), old_text="line1\nline2", new_text="LINE1\nLINE2", + ) + assert "Successfully" in result + raw = f.read_bytes() + assert b"LINE1" in raw + # CRLF line endings should be preserved throughout the file + assert b"\r\n" in raw + + @pytest.mark.asyncio + async def test_trim_fallback(self, tool, tmp_path): + f = tmp_path / "indent.py" + f.write_text(" def foo():\n pass\n", encoding="utf-8") + result = await tool.execute( + path=str(f), old_text="def foo():\n pass", new_text="def bar():\n return 1", + ) + assert "Successfully" in result + assert "bar" in f.read_text() + + @pytest.mark.asyncio + async def test_ambiguous_match(self, tool, tmp_path): + f = tmp_path / "dup.py" + f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx") + assert "appears" in result.lower() or "Warning" in result + + @pytest.mark.asyncio + async def test_replace_all(self, tool, tmp_path): + f = tmp_path / "multi.py" + f.write_text("foo bar foo bar foo", encoding="utf-8") + result = await tool.execute( + path=str(f), old_text="foo", new_text="baz", replace_all=True, + ) + assert "Successfully" in result + assert f.read_text() == "baz bar baz bar baz" + + @pytest.mark.asyncio + async def test_not_found(self, tool, tmp_path): + f = tmp_path / "nf.py" + f.write_text("hello", encoding="utf-8") + result = await tool.execute(path=str(f), old_text="xyz", new_text="abc") + assert "Error" in result + assert "not found" in result + + +# --------------------------------------------------------------------------- +# ListDirTool +# --------------------------------------------------------------------------- + +class TestListDirTool: + + @pytest.fixture() + def tool(self, tmp_path): + return ListDirTool(workspace=tmp_path) + + @pytest.fixture() + def populated_dir(self, tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.py").write_text("pass") + (tmp_path / "src" / "utils.py").write_text("pass") + (tmp_path / "README.md").write_text("hi") + (tmp_path / ".git").mkdir() + (tmp_path / ".git" / "config").write_text("x") + (tmp_path / "node_modules").mkdir() + (tmp_path / "node_modules" / "pkg").mkdir() + return tmp_path + + @pytest.mark.asyncio + async def test_basic_list(self, tool, populated_dir): + result = await tool.execute(path=str(populated_dir)) + assert "README.md" in result + assert "src" in result + # .git and node_modules should be ignored + assert ".git" not in result + assert "node_modules" not in result + + @pytest.mark.asyncio + async def test_recursive(self, tool, populated_dir): + result = await tool.execute(path=str(populated_dir), recursive=True) + assert "src/main.py" in result + assert "src/utils.py" in result + assert "README.md" in result + # Ignored dirs should not appear + assert ".git" not in result + assert "node_modules" not in result + + @pytest.mark.asyncio + async def test_max_entries_truncation(self, tool, tmp_path): + for i in range(10): + (tmp_path / f"file_{i}.txt").write_text("x") + result = await tool.execute(path=str(tmp_path), max_entries=3) + assert "truncated" in result + assert "3 of 10" in result + + @pytest.mark.asyncio + async def test_empty_dir(self, tool, tmp_path): + d = tmp_path / "empty" + d.mkdir() + result = await tool.execute(path=str(d)) + assert "empty" in result.lower() + + @pytest.mark.asyncio + async def test_not_found(self, tool, tmp_path): + result = await tool.execute(path=str(tmp_path / "nope")) + assert "Error" in result + assert "not found" in result diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py index e67acbf..095c041 100644 --- a/tests/test_tool_validation.py +++ b/tests/test_tool_validation.py @@ -363,3 +363,44 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None: assert result["items"] == 5 # Not wrapped to [5] result = tool.cast_params({"items": "text"}) assert result["items"] == "text" # Not wrapped to ["text"] + + +# --- ExecTool enhancement tests --- + + +async def test_exec_always_returns_exit_code() -> None: + """Exit code should appear in output even on success (exit 0).""" + tool = ExecTool() + result = await tool.execute(command="echo hello") + assert "Exit code: 0" in result + assert "hello" in result + + +async def test_exec_head_tail_truncation() -> None: + """Long output should preserve both head and tail.""" + tool = ExecTool() + # Generate output that exceeds _MAX_OUTPUT + big = "A" * 6000 + "\n" + "B" * 6000 + result = await tool.execute(command=f"echo '{big}'") + assert "chars truncated" in result + # Head portion should start with As + assert result.startswith("A") + # Tail portion should end with the exit code which comes after Bs + assert "Exit code:" in result + + +async def test_exec_timeout_parameter() -> None: + """LLM-supplied timeout should override the constructor default.""" + tool = ExecTool(timeout=60) + # A very short timeout should cause the command to be killed + result = await tool.execute(command="sleep 10", timeout=1) + assert "timed out" in result + assert "1 seconds" in result + + +async def test_exec_timeout_capped_at_max() -> None: + """Timeout values above _MAX_TIMEOUT should be clamped.""" + tool = ExecTool() + # Should not raise — just clamp to 600 + result = await tool.execute(command="echo ok", timeout=9999) + assert "Exit code: 0" in result