Merge PR #1895: enhance: improve filesystem & shell tools with pagination, fallback matching, and smarter output

enhance: improve filesystem & shell tools with pagination, fallback matching, and smarter output
This commit is contained in:
Xubin Ren
2026-03-12 00:22:32 +08:00
committed by GitHub
4 changed files with 549 additions and 111 deletions

View File

@@ -1,4 +1,4 @@
"""File system tools: read, write, edit.""" """File system tools: read, write, edit, list."""
import difflib import difflib
from pathlib import Path from pathlib import Path
@@ -23,62 +23,108 @@ def _resolve_path(
return resolved return resolved
class _FsTool(Tool):
    """Shared base for filesystem tools — common init and path resolution.

    Subclasses inherit the stored workspace / allowed-directory settings and
    the ``_resolve`` helper so each tool does not repeat the same plumbing.
    """

    def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
        # Both settings are kept verbatim; they are only consumed by _resolve().
        self._workspace, self._allowed_dir = workspace, allowed_dir

    def _resolve(self, path: str) -> Path:
        """Turn a caller-supplied path string into a ``Path``.

        Delegates to the module-level ``_resolve_path`` with this tool's
        workspace and allowed directory.
        NOTE(review): callers in this file catch PermissionError around this
        call, so ``_resolve_path`` presumably raises it for disallowed
        paths — confirm against its definition.
        """
        resolved = _resolve_path(path, self._workspace, self._allowed_dir)
        return resolved
# ---------------------------------------------------------------------------
# read_file
# ---------------------------------------------------------------------------
class ReadFileTool(_FsTool):
    """Read file contents with optional line-based pagination.

    Every returned line is prefixed with its 1-indexed line number
    (``N| text``). A footer either names the offset to request next or states
    that the end of the file was reached.
    """

    _MAX_CHARS = 128_000  # character budget for the numbered lines of one page
    _DEFAULT_LIMIT = 2000  # lines per page when no usable limit is supplied

    @property
    def name(self) -> str:
        return "read_file"

    @property
    def description(self) -> str:
        return (
            "Read the contents of a file. Returns numbered lines. "
            "Use offset and limit to paginate through large files."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "The file path to read"},
                "offset": {
                    "type": "integer",
                    "description": "Line number to start reading from (1-indexed, default 1)",
                    "minimum": 1,
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of lines to read (default 2000)",
                    "minimum": 1,
                },
            },
            "required": ["path"],
        }

    async def execute(self, path: str, offset: int = 1, limit: int | None = None, **kwargs: Any) -> str:
        """Return one page of numbered lines from the file at ``path``.

        Args:
            path: File to read; resolved through ``self._resolve``.
            offset: 1-indexed first line of the page. Values below 1 are
                treated as 1.
            limit: Maximum number of lines in the page. ``None`` or any value
                below 1 falls back to ``_DEFAULT_LIMIT``.

        Returns:
            The numbered lines followed by a pagination footer, or a string
            starting with ``Error`` when the file is missing, is not a regular
            file, the offset is past the last line, or reading fails.
        """
        try:
            fp = self._resolve(path)
            if not fp.exists():
                return f"Error: File not found: {path}"
            if not fp.is_file():
                return f"Error: Not a file: {path}"

            all_lines = fp.read_text(encoding="utf-8").splitlines()
            total = len(all_lines)

            if offset < 1:
                offset = 1
            if total == 0:
                return f"(Empty file: {path})"
            if offset > total:
                return f"Error: offset {offset} is beyond end of file ({total} lines)"

            # A missing, zero or negative limit would yield an empty (or
            # reversed) slice, so fall back to the default page size.
            page_size = limit if limit is not None and limit >= 1 else self._DEFAULT_LIMIT

            start = offset - 1
            end = min(start + page_size, total)
            numbered = [f"{start + i + 1}| {line}" for i, line in enumerate(all_lines[start:end])]
            result = "\n".join(numbered)

            if len(result) > self._MAX_CHARS:
                trimmed, chars = [], 0
                for line in numbered:
                    chars += len(line) + 1  # +1 accounts for the joining newline
                    if chars > self._MAX_CHARS:
                        break
                    trimmed.append(line)
                if not trimmed:
                    # The first line alone exceeds the budget. Emit a clipped
                    # copy so the page is non-empty and the continuation
                    # offset moves forward instead of repeating this offset.
                    marker = f" ... (line truncated, {len(all_lines[start]):,} chars total)"
                    keep = max(self._MAX_CHARS - len(marker), 0)
                    trimmed.append(numbered[0][:keep] + marker)
                end = start + len(trimmed)
                result = "\n".join(trimmed)

            if end < total:
                result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
            else:
                result += f"\n\n(End of file — {total} lines total)"
            return result
        except PermissionError as e:
            return f"Error: {e}"
        except Exception as e:
            return f"Error reading file: {e}"
class WriteFileTool(Tool): # ---------------------------------------------------------------------------
"""Tool to write content to a file.""" # write_file
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): class WriteFileTool(_FsTool):
self._workspace = workspace """Write content to a file."""
self._allowed_dir = allowed_dir
@property @property
def name(self) -> str: def name(self) -> str:
@@ -101,22 +147,48 @@ class WriteFileTool(Tool):
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
    """Write ``content`` to the file at ``path`` as UTF-8 text.

    Missing parent directories are created first. Returns a success message
    containing the encoded size and the resolved path, or a string starting
    with ``Error`` if resolution or writing fails.
    """
    try:
        fp = self._resolve(path)
        fp.parent.mkdir(parents=True, exist_ok=True)
        fp.write_text(content, encoding="utf-8")
        # len(content) counts characters; report the UTF-8 encoded size so the
        # "bytes" figure is also right for non-ASCII text.
        byte_count = len(content.encode("utf-8"))
        return f"Successfully wrote {byte_count} bytes to {fp}"
    except PermissionError as e:
        return f"Error: {e}"
    except Exception as e:
        return f"Error writing file: {e}"
class EditFileTool(Tool): # ---------------------------------------------------------------------------
"""Tool to edit a file by replacing text.""" # edit_file
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
self._workspace = workspace """Locate old_text in content: exact first, then line-trimmed sliding window.
self._allowed_dir = allowed_dir
Both inputs should use LF line endings (caller normalises CRLF).
Returns (matched_fragment, count) or (None, 0).
"""
if old_text in content:
return old_text, content.count(old_text)
old_lines = old_text.splitlines()
if not old_lines:
return None, 0
stripped_old = [l.strip() for l in old_lines]
content_lines = content.splitlines()
candidates = []
for i in range(len(content_lines) - len(stripped_old) + 1):
window = content_lines[i : i + len(stripped_old)]
if [l.strip() for l in window] == stripped_old:
candidates.append("\n".join(window))
if candidates:
return candidates[0], len(candidates)
return None, 0
class EditFileTool(_FsTool):
"""Edit a file by replacing text with fallback matching."""
@property @property
def name(self) -> str: def name(self) -> str:
@@ -124,7 +196,11 @@ class EditFileTool(Tool):
@property
def description(self) -> str:
    """Human-readable summary of the tool, assembled from three sentences."""
    sentences = (
        "Edit a file by replacing old_text with new_text. ",
        "Supports minor whitespace/line-ending differences. ",
        "Set replace_all=true to replace every occurrence.",
    )
    return "".join(sentences)
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
@@ -132,40 +208,52 @@ class EditFileTool(Tool):
"type": "object", "type": "object",
"properties": { "properties": {
"path": {"type": "string", "description": "The file path to edit"}, "path": {"type": "string", "description": "The file path to edit"},
"old_text": {"type": "string", "description": "The exact text to find and replace"}, "old_text": {"type": "string", "description": "The text to find and replace"},
"new_text": {"type": "string", "description": "The text to replace with"}, "new_text": {"type": "string", "description": "The text to replace with"},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default false)",
},
}, },
"required": ["path", "old_text", "new_text"], "required": ["path", "old_text", "new_text"],
} }
async def execute(
    self, path: str, old_text: str, new_text: str,
    replace_all: bool = False, **kwargs: Any,
) -> str:
    """Replace ``old_text`` with ``new_text`` in the file at ``path``.

    The file is read as bytes, decoded as UTF-8 and normalised to LF before
    matching. If the file contained any CRLF sequence, every LF in the result
    is written back as CRLF.

    Args:
        path: File to edit; resolved through ``self._resolve``.
        old_text: Text to locate (exact, or matching line-by-line after
            stripping whitespace, as implemented by ``_find_match``).
        new_text: Replacement text.
        replace_all: Replace every occurrence of the matched fragment instead
            of requiring the match to be unique.

    Returns:
        A success message, a ``Warning`` string when the match is ambiguous
        and ``replace_all`` is false, or a string starting with ``Error``.
    """
    try:
        fp = self._resolve(path)
        if not fp.exists():
            return f"Error: File not found: {path}"
        if not fp.is_file():
            return f"Error: Not a file: {path}"

        raw = fp.read_bytes()
        uses_crlf = b"\r\n" in raw
        content = raw.decode("utf-8").replace("\r\n", "\n")

        # An empty old_text "matches" between every character. With
        # replace_all it would splice new_text throughout the file. It is
        # only meaningful for an empty file, where it inserts new_text.
        if not old_text and content:
            return "Error: old_text is empty. Provide the text to replace."

        match, count = _find_match(content, old_text.replace("\r\n", "\n"))

        if match is None:
            return self._not_found_msg(old_text, content, path)
        if count > 1 and not replace_all:
            return (
                f"Warning: old_text appears {count} times. "
                "Provide more context to make it unique, or set replace_all=true."
            )

        norm_new = new_text.replace("\r\n", "\n")
        # str.replace treats a negative count as "replace all".
        new_content = content.replace(match, norm_new, -1 if replace_all else 1)
        if uses_crlf:
            new_content = new_content.replace("\n", "\r\n")

        fp.write_bytes(new_content.encode("utf-8"))
        return f"Successfully edited {fp}"
    except PermissionError as e:
        return f"Error: {e}"
    except Exception as e:
        return f"Error editing file: {e}"
@staticmethod @staticmethod
def _not_found_message(old_text: str, content: str, path: str) -> str: def _not_found_msg(old_text: str, content: str, path: str) -> str:
"""Build a helpful error when old_text is not found."""
lines = content.splitlines(keepends=True) lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True) old_lines = old_text.splitlines(keepends=True)
window = len(old_lines) window = len(old_lines)
@@ -177,27 +265,29 @@ class EditFileTool(Tool):
best_ratio, best_start = ratio, i best_ratio, best_start = ratio, i
if best_ratio > 0.5: if best_ratio > 0.5:
diff = "\n".join( diff = "\n".join(difflib.unified_diff(
difflib.unified_diff( old_lines, lines[best_start : best_start + window],
old_lines, fromfile="old_text (provided)",
lines[best_start : best_start + window], tofile=f"{path} (actual, line {best_start + 1})",
fromfile="old_text (provided)", lineterm="",
tofile=f"{path} (actual, line {best_start + 1})", ))
lineterm="",
)
)
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
return ( return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
)
class ListDirTool(Tool): # ---------------------------------------------------------------------------
"""Tool to list directory contents.""" # list_dir
# ---------------------------------------------------------------------------
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): class ListDirTool(_FsTool):
self._workspace = workspace """List directory contents with optional recursion."""
self._allowed_dir = allowed_dir
_DEFAULT_MAX = 200
_IGNORE_DIRS = {
".git", "node_modules", "__pycache__", ".venv", "venv",
"dist", "build", ".tox", ".mypy_cache", ".pytest_cache",
".ruff_cache", ".coverage", "htmlcov",
}
@property @property
def name(self) -> str: def name(self) -> str:
@@ -205,34 +295,71 @@ class ListDirTool(Tool):
@property
def description(self) -> str:
    """Human-readable summary of the tool, assembled from three sentences."""
    sentences = (
        "List the contents of a directory. ",
        "Set recursive=true to explore nested structure. ",
        "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored.",
    )
    return "".join(sentences)
@property
def parameters(self) -> dict[str, Any]:
    """JSON-schema style description of the arguments ``execute`` accepts."""
    path_spec = {"type": "string", "description": "The directory path to list"}
    recursive_spec = {
        "type": "boolean",
        "description": "Recursively list all files (default false)",
    }
    max_entries_spec = {
        "type": "integer",
        "description": "Maximum entries to return (default 200)",
        "minimum": 1,
    }
    schema: dict[str, Any] = {
        "type": "object",
        "properties": {
            "path": path_spec,
            "recursive": recursive_spec,
            "max_entries": max_entries_spec,
        },
        "required": ["path"],
    }
    return schema
async def execute(
    self, path: str, recursive: bool = False,
    max_entries: int | None = None, **kwargs: Any,
) -> str:
    """List the entries of the directory at ``path``.

    Args:
        path: Directory to list; resolved through ``self._resolve``.
        recursive: When true, walk the whole tree and print paths relative to
            the listed directory (directories get a trailing ``/``). When
            false, list only direct children with a folder/file emoji prefix.
        max_entries: Maximum number of entries to print. ``None`` or any
            value below 1 falls back to ``_DEFAULT_MAX``.

    Returns:
        Newline-separated entries, plus a truncation note when more entries
        exist than were printed; an "is empty" message when nothing is
        listed; or a string starting with ``Error``.
    """
    try:
        dp = self._resolve(path)
        if not dp.exists():
            return f"Error: Directory not found: {path}"
        if not dp.is_dir():
            return f"Error: Not a directory: {path}"

        # A zero or negative cap would hide every entry and produce a
        # nonsensical truncation note, so treat it like "not provided".
        cap = max_entries if max_entries and max_entries > 0 else self._DEFAULT_MAX
        items: list[str] = []
        total = 0

        if recursive:
            for item in sorted(dp.rglob("*")):
                rel = item.relative_to(dp)
                # Test only the components *below* the listed directory. The
                # full path also contains the directory's own ancestors, so a
                # workspace under e.g. ".../build/" would otherwise be
                # filtered out entirely.
                if any(part in self._IGNORE_DIRS for part in rel.parts):
                    continue
                total += 1
                if len(items) < cap:
                    items.append(f"{rel}/" if item.is_dir() else str(rel))
        else:
            for item in sorted(dp.iterdir()):
                if item.name in self._IGNORE_DIRS:
                    continue
                total += 1
                if len(items) < cap:
                    pfx = "📁 " if item.is_dir() else "📄 "
                    items.append(f"{pfx}{item.name}")

        if total == 0:
            return f"Directory {path} is empty"

        result = "\n".join(items)
        if total > cap:
            result += f"\n\n(truncated, showing first {cap} of {total} entries)"
        return result
    except PermissionError as e:
        return f"Error: {e}"
    except Exception as e:
        return f"Error listing directory: {e}"

View File

@@ -42,6 +42,9 @@ class ExecTool(Tool):
def name(self) -> str: def name(self) -> str:
return "exec" return "exec"
_MAX_TIMEOUT = 600
_MAX_OUTPUT = 10_000
@property
def description(self) -> str:
    """Human-readable summary of the tool."""
    summary = "Execute a shell command and return its output. Use with caution."
    return summary
@@ -53,22 +56,36 @@ class ExecTool(Tool):
"properties": { "properties": {
"command": { "command": {
"type": "string", "type": "string",
"description": "The shell command to execute" "description": "The shell command to execute",
}, },
"working_dir": { "working_dir": {
"type": "string", "type": "string",
"description": "Optional working directory for the command" "description": "Optional working directory for the command",
} },
"timeout": {
"type": "integer",
"description": (
"Timeout in seconds. Increase for long-running commands "
"like compilation or installation (default 60, max 600)."
),
"minimum": 1,
"maximum": 600,
},
}, },
"required": ["command"] "required": ["command"],
} }
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str: async def execute(
self, command: str, working_dir: str | None = None,
timeout: int | None = None, **kwargs: Any,
) -> str:
cwd = working_dir or self.working_dir or os.getcwd() cwd = working_dir or self.working_dir or os.getcwd()
guard_error = self._guard_command(command, cwd) guard_error = self._guard_command(command, cwd)
if guard_error: if guard_error:
return guard_error return guard_error
effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
env = os.environ.copy() env = os.environ.copy()
if self.path_append: if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
@@ -81,44 +98,46 @@ class ExecTool(Tool):
cwd=cwd, cwd=cwd,
env=env, env=env,
) )
try: try:
stdout, stderr = await asyncio.wait_for( stdout, stderr = await asyncio.wait_for(
process.communicate(), process.communicate(),
timeout=self.timeout timeout=effective_timeout,
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
process.kill() process.kill()
# Wait for the process to fully terminate so pipes are
# drained and file descriptors are released.
try: try:
await asyncio.wait_for(process.wait(), timeout=5.0) await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
return f"Error: Command timed out after {self.timeout} seconds" return f"Error: Command timed out after {effective_timeout} seconds"
output_parts = [] output_parts = []
if stdout: if stdout:
output_parts.append(stdout.decode("utf-8", errors="replace")) output_parts.append(stdout.decode("utf-8", errors="replace"))
if stderr: if stderr:
stderr_text = stderr.decode("utf-8", errors="replace") stderr_text = stderr.decode("utf-8", errors="replace")
if stderr_text.strip(): if stderr_text.strip():
output_parts.append(f"STDERR:\n{stderr_text}") output_parts.append(f"STDERR:\n{stderr_text}")
if process.returncode != 0: output_parts.append(f"\nExit code: {process.returncode}")
output_parts.append(f"\nExit code: {process.returncode}")
result = "\n".join(output_parts) if output_parts else "(no output)" result = "\n".join(output_parts) if output_parts else "(no output)"
# Truncate very long output # Head + tail truncation to preserve both start and end of output
max_len = 10000 max_len = self._MAX_OUTPUT
if len(result) > max_len: if len(result) > max_len:
result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)" half = max_len // 2
result = (
result[:half]
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ result[-half:]
)
return result return result
except Exception as e: except Exception as e:
return f"Error executing command: {str(e)}" return f"Error executing command: {str(e)}"

View File

@@ -0,0 +1,251 @@
"""Tests for enhanced filesystem tools: ReadFileTool, EditFileTool, ListDirTool."""
import pytest
from nanobot.agent.tools.filesystem import (
EditFileTool,
ListDirTool,
ReadFileTool,
_find_match,
)
# ---------------------------------------------------------------------------
# ReadFileTool
# ---------------------------------------------------------------------------
class TestReadFileTool:
    """Tests for ReadFileTool: line numbering, pagination and error paths."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """ReadFileTool whose workspace is the per-test temporary directory."""
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def sample_file(self, tmp_path):
        """Twenty-line file reading "line 1" through "line 20"."""
        target = tmp_path / "sample.txt"
        body = "\n".join(f"line {n}" for n in range(1, 21))
        target.write_text(body, encoding="utf-8")
        return target

    @pytest.mark.asyncio
    async def test_basic_read_has_line_numbers(self, tool, sample_file):
        output = await tool.execute(path=str(sample_file))
        for expected in ("1| line 1", "20| line 20"):
            assert expected in output

    @pytest.mark.asyncio
    async def test_offset_and_limit(self, tool, sample_file):
        output = await tool.execute(path=str(sample_file), offset=5, limit=3)
        # Lines 5..7 are inside the requested window; line 8 is the next page.
        assert "5| line 5" in output
        assert "7| line 7" in output
        assert "8| line 8" not in output
        assert "Use offset=8 to continue" in output

    @pytest.mark.asyncio
    async def test_offset_beyond_end(self, tool, sample_file):
        output = await tool.execute(path=str(sample_file), offset=999)
        for fragment in ("Error", "beyond end"):
            assert fragment in output

    @pytest.mark.asyncio
    async def test_end_of_file_marker(self, tool, sample_file):
        output = await tool.execute(path=str(sample_file), offset=1, limit=9999)
        assert "End of file" in output

    @pytest.mark.asyncio
    async def test_empty_file(self, tool, tmp_path):
        blank = tmp_path / "empty.txt"
        blank.write_text("", encoding="utf-8")
        output = await tool.execute(path=str(blank))
        assert "Empty file" in output

    @pytest.mark.asyncio
    async def test_file_not_found(self, tool, tmp_path):
        missing = tmp_path / "nope.txt"
        output = await tool.execute(path=str(missing))
        for fragment in ("Error", "not found"):
            assert fragment in output

    @pytest.mark.asyncio
    async def test_char_budget_trims(self, tool, tmp_path):
        """When the selected slice exceeds _MAX_CHARS the output is trimmed."""
        big = tmp_path / "big.txt"
        # Each line is ~110 chars, 2000 lines ≈ 220 KB > 128 KB limit
        row = "x" * 110
        big.write_text("\n".join([row] * 2000), encoding="utf-8")
        output = await tool.execute(path=str(big))
        ceiling = ReadFileTool._MAX_CHARS + 500  # small margin for footer
        assert len(output) <= ceiling
        assert "Use offset=" in output
# ---------------------------------------------------------------------------
# _find_match (unit tests for the helper)
# ---------------------------------------------------------------------------
class TestFindMatch:
    """Unit tests for the ``_find_match`` helper used by the edit tool."""

    def test_exact_match(self):
        found, occurrences = _find_match("hello world", "world")
        assert found == "world"
        assert occurrences == 1

    def test_exact_no_match(self):
        found, occurrences = _find_match("hello world", "xyz")
        assert found is None
        assert occurrences == 0

    def test_crlf_normalisation(self):
        # Caller normalises CRLF before calling _find_match, so test with
        # pre-normalised content to verify exact match still works.
        text = "line1\nline2\nline3"
        found, occurrences = _find_match(text, "line1\nline2\nline3")
        assert found is not None
        assert occurrences == 1

    def test_line_trim_fallback(self):
        source = "    def foo():\n        pass\n"
        needle = "def foo():\n    pass"
        found, occurrences = _find_match(source, needle)
        assert found is not None
        assert occurrences == 1
        # The returned match should be the *original* indented text
        assert "    def foo():" in found

    def test_line_trim_multiple_candidates(self):
        source = "  a\n  b\n  a\n  b\n"
        _, occurrences = _find_match(source, "a\nb")
        assert occurrences == 2

    def test_empty_old_text(self):
        found, _ = _find_match("hello", "")
        # Empty string is always "in" any string via exact match
        assert found == ""
# ---------------------------------------------------------------------------
# EditFileTool
# ---------------------------------------------------------------------------
class TestEditFileTool:
    """Tests for EditFileTool: exact, CRLF, trimmed, ambiguous and bulk edits."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """EditFileTool whose workspace is the per-test temporary directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_exact_match(self, tool, tmp_path):
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="world", new_text="earth")
        assert "Successfully" in outcome
        assert target.read_text() == "hello earth"

    @pytest.mark.asyncio
    async def test_crlf_normalisation(self, tool, tmp_path):
        target = tmp_path / "crlf.py"
        target.write_bytes(b"line1\r\nline2\r\nline3")
        outcome = await tool.execute(
            path=str(target),
            old_text="line1\nline2",
            new_text="LINE1\nLINE2",
        )
        assert "Successfully" in outcome
        written = target.read_bytes()
        assert b"LINE1" in written
        # CRLF line endings should be preserved throughout the file
        assert b"\r\n" in written

    @pytest.mark.asyncio
    async def test_trim_fallback(self, tool, tmp_path):
        target = tmp_path / "indent.py"
        target.write_text("    def foo():\n        pass\n", encoding="utf-8")
        outcome = await tool.execute(
            path=str(target),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
        )
        assert "Successfully" in outcome
        assert "bar" in target.read_text()

    @pytest.mark.asyncio
    async def test_ambiguous_match(self, tool, tmp_path):
        target = tmp_path / "dup.py"
        target.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="aaa\nbbb", new_text="xxx")
        assert "appears" in outcome.lower() or "Warning" in outcome

    @pytest.mark.asyncio
    async def test_replace_all(self, tool, tmp_path):
        target = tmp_path / "multi.py"
        target.write_text("foo bar foo bar foo", encoding="utf-8")
        outcome = await tool.execute(
            path=str(target),
            old_text="foo",
            new_text="baz",
            replace_all=True,
        )
        assert "Successfully" in outcome
        assert target.read_text() == "baz bar baz bar baz"

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        target = tmp_path / "nf.py"
        target.write_text("hello", encoding="utf-8")
        outcome = await tool.execute(path=str(target), old_text="xyz", new_text="abc")
        for fragment in ("Error", "not found"):
            assert fragment in outcome
# ---------------------------------------------------------------------------
# ListDirTool
# ---------------------------------------------------------------------------
class TestListDirTool:
    """Tests for ListDirTool: flat and recursive listing, ignores, truncation."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """ListDirTool whose workspace is the per-test temporary directory."""
        return ListDirTool(workspace=tmp_path)

    @pytest.fixture()
    def populated_dir(self, tmp_path):
        """Small tree with real content plus ``.git`` and ``node_modules`` noise."""
        src = tmp_path / "src"
        src.mkdir()
        for module in ("main.py", "utils.py"):
            (src / module).write_text("pass")
        (tmp_path / "README.md").write_text("hi")

        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("x")

        (tmp_path / "node_modules" / "pkg").mkdir(parents=True)
        return tmp_path

    @pytest.mark.asyncio
    async def test_basic_list(self, tool, populated_dir):
        listing = await tool.execute(path=str(populated_dir))
        assert "README.md" in listing
        assert "src" in listing
        # .git and node_modules should be ignored
        for noise in (".git", "node_modules"):
            assert noise not in listing

    @pytest.mark.asyncio
    async def test_recursive(self, tool, populated_dir):
        listing = await tool.execute(path=str(populated_dir), recursive=True)
        for expected in ("src/main.py", "src/utils.py", "README.md"):
            assert expected in listing
        # Ignored dirs should not appear
        for noise in (".git", "node_modules"):
            assert noise not in listing

    @pytest.mark.asyncio
    async def test_max_entries_truncation(self, tool, tmp_path):
        for index in range(10):
            (tmp_path / f"file_{index}.txt").write_text("x")
        listing = await tool.execute(path=str(tmp_path), max_entries=3)
        assert "truncated" in listing
        assert "3 of 10" in listing

    @pytest.mark.asyncio
    async def test_empty_dir(self, tool, tmp_path):
        hollow = tmp_path / "empty"
        hollow.mkdir()
        listing = await tool.execute(path=str(hollow))
        assert "empty" in listing.lower()

    @pytest.mark.asyncio
    async def test_not_found(self, tool, tmp_path):
        listing = await tool.execute(path=str(tmp_path / "nope"))
        for fragment in ("Error", "not found"):
            assert fragment in listing

View File

@@ -363,3 +363,44 @@ def test_cast_params_single_value_not_auto_wrapped_to_array() -> None:
assert result["items"] == 5 # Not wrapped to [5] assert result["items"] == 5 # Not wrapped to [5]
result = tool.cast_params({"items": "text"}) result = tool.cast_params({"items": "text"})
assert result["items"] == "text" # Not wrapped to ["text"] assert result["items"] == "text" # Not wrapped to ["text"]
# --- ExecTool enhancement tests ---
async def test_exec_always_returns_exit_code() -> None:
    """Exit code should appear in output even on success (exit 0)."""
    output = await ExecTool().execute(command="echo hello")
    for fragment in ("Exit code: 0", "hello"):
        assert fragment in output
async def test_exec_head_tail_truncation() -> None:
    """Long output should preserve both head and tail."""
    tool = ExecTool()
    # Generate output that exceeds _MAX_OUTPUT
    big = "A" * 6000 + "\n" + "B" * 6000
    result = await tool.execute(command=f"echo '{big}'")
    assert "chars truncated" in result
    # Head portion should start with As
    assert result.startswith("A")
    # The exit-code line is appended after all command output, so a kept tail
    # must *end* with it. A plain substring check would also pass if the line
    # only survived somewhere in the head.
    assert result.rstrip().endswith("Exit code: 0")
async def test_exec_timeout_parameter() -> None:
    """LLM-supplied timeout should override the constructor default."""
    slow_tool = ExecTool(timeout=60)
    # A very short timeout should cause the command to be killed
    output = await slow_tool.execute(command="sleep 10", timeout=1)
    for fragment in ("timed out", "1 seconds"):
        assert fragment in output
async def test_exec_timeout_capped_at_max() -> None:
    """Timeout values above _MAX_TIMEOUT should be clamped."""
    # Should not raise — just clamp to 600
    output = await ExecTool().execute(command="echo ok", timeout=9999)
    assert "Exit code: 0" in output