Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -26,6 +26,7 @@ from nanobot.agent.i18n import (
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
@@ -177,7 +178,9 @@ class AgentLoop:
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
@@ -97,7 +98,8 @@ class SubagentManager:
|
||||
# Build subagent tools (no message tool, no spawn tool)
|
||||
tools = ToolRegistry()
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
|
||||
@@ -8,7 +8,10 @@ from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
def _resolve_path(
|
||||
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
|
||||
path: str,
|
||||
workspace: Path | None = None,
|
||||
allowed_dir: Path | None = None,
|
||||
extra_allowed_dirs: list[Path] | None = None,
|
||||
) -> Path:
|
||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||
p = Path(path).expanduser()
|
||||
@@ -16,22 +19,35 @@ def _resolve_path(
|
||||
p = workspace / p
|
||||
resolved = p.resolve()
|
||||
if allowed_dir:
|
||||
try:
|
||||
resolved.relative_to(allowed_dir.resolve())
|
||||
except ValueError:
|
||||
all_dirs = [allowed_dir] + (extra_allowed_dirs or [])
|
||||
if not any(_is_under(resolved, d) for d in all_dirs):
|
||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_under(path: Path, directory: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(directory.resolve())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class _FsTool(Tool):
|
||||
"""Shared base for filesystem tools — common init and path resolution."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path | None = None,
|
||||
allowed_dir: Path | None = None,
|
||||
extra_allowed_dirs: list[Path] | None = None,
|
||||
):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
self._extra_allowed_dirs = extra_allowed_dirs
|
||||
|
||||
def _resolve(self, path: str) -> Path:
|
||||
return _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
return _resolve_path(path, self._workspace, self._allowed_dir, self._extra_allowed_dirs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -62,6 +62,49 @@ class NanobotDingTalkHandler(CallbackHandler):
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
# Handle file/image messages
|
||||
file_paths = []
|
||||
if chatbot_msg.message_type == "picture" and chatbot_msg.image_content:
|
||||
download_code = chatbot_msg.image_content.download_code
|
||||
if download_code:
|
||||
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
|
||||
fp = await self.channel._download_dingtalk_file(download_code, "image.jpg", sender_uid)
|
||||
if fp:
|
||||
file_paths.append(fp)
|
||||
content = content or "[Image]"
|
||||
|
||||
elif chatbot_msg.message_type == "file":
|
||||
download_code = message.data.get("content", {}).get("downloadCode") or message.data.get("downloadCode")
|
||||
fname = message.data.get("content", {}).get("fileName") or message.data.get("fileName") or "file"
|
||||
if download_code:
|
||||
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
|
||||
fp = await self.channel._download_dingtalk_file(download_code, fname, sender_uid)
|
||||
if fp:
|
||||
file_paths.append(fp)
|
||||
content = content or "[File]"
|
||||
|
||||
elif chatbot_msg.message_type == "richText" and chatbot_msg.rich_text_content:
|
||||
rich_list = chatbot_msg.rich_text_content.rich_text_list or []
|
||||
for item in rich_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
t = item.get("text", "").strip()
|
||||
if t:
|
||||
content = (content + " " + t).strip() if content else t
|
||||
elif item.get("downloadCode"):
|
||||
dc = item["downloadCode"]
|
||||
fname = item.get("fileName") or "file"
|
||||
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
|
||||
fp = await self.channel._download_dingtalk_file(dc, fname, sender_uid)
|
||||
if fp:
|
||||
file_paths.append(fp)
|
||||
content = content or "[File]"
|
||||
|
||||
if file_paths:
|
||||
file_list = "\n".join("- " + p for p in file_paths)
|
||||
content = content + "\n\nReceived files:\n" + file_list
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Received empty or unsupported message type: {}",
|
||||
@@ -472,3 +515,50 @@ class DingTalkChannel(BaseChannel):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
|
||||
async def _download_dingtalk_file(
|
||||
self,
|
||||
download_code: str,
|
||||
filename: str,
|
||||
sender_id: str,
|
||||
) -> str | None:
|
||||
"""Download a DingTalk file to the media directory, return local path."""
|
||||
from nanobot.config.paths import get_media_dir
|
||||
|
||||
try:
|
||||
token = await self._get_access_token()
|
||||
if not token or not self._http:
|
||||
logger.error("DingTalk file download: no token or http client")
|
||||
return None
|
||||
|
||||
# Step 1: Exchange downloadCode for a temporary download URL
|
||||
api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
|
||||
headers = {"x-acs-dingtalk-access-token": token, "Content-Type": "application/json"}
|
||||
payload = {"downloadCode": download_code, "robotCode": self.config.client_id}
|
||||
resp = await self._http.post(api_url, json=payload, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk get download URL failed: status={}, body={}", resp.status_code, resp.text)
|
||||
return None
|
||||
|
||||
result = resp.json()
|
||||
download_url = result.get("downloadUrl")
|
||||
if not download_url:
|
||||
logger.error("DingTalk download URL not found in response: {}", result)
|
||||
return None
|
||||
|
||||
# Step 2: Download the file content
|
||||
file_resp = await self._http.get(download_url, follow_redirects=True)
|
||||
if file_resp.status_code != 200:
|
||||
logger.error("DingTalk file download failed: status={}", file_resp.status_code)
|
||||
return None
|
||||
|
||||
# Save to media directory (accessible under workspace)
|
||||
download_dir = get_media_dir("dingtalk") / sender_id
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = download_dir / filename
|
||||
await asyncio.to_thread(file_path.write_bytes, file_resp.content)
|
||||
logger.info("DingTalk file saved: {}", file_path)
|
||||
return str(file_path)
|
||||
except Exception as e:
|
||||
logger.error("DingTalk file download error: {}", e)
|
||||
return None
|
||||
|
||||
@@ -89,6 +89,14 @@ class LLMProvider(ABC):
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
)
|
||||
_IMAGE_UNSUPPORTED_MARKERS = (
|
||||
"image_url is only supported",
|
||||
"does not support image",
|
||||
"images are not supported",
|
||||
"image input is not supported",
|
||||
"image_url is not supported",
|
||||
"unsupported image input",
|
||||
)
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
@@ -189,6 +197,40 @@ class LLMProvider(ABC):
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
@classmethod
|
||||
def _is_image_unsupported_error(cls, content: str | None) -> bool:
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS)
|
||||
|
||||
@staticmethod
|
||||
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
||||
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||
found = False
|
||||
result = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
new_content = []
|
||||
for b in content:
|
||||
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||
new_content.append({"type": "text", "text": "[image omitted]"})
|
||||
found = True
|
||||
else:
|
||||
new_content.append(b)
|
||||
result.append({**msg, "content": new_content})
|
||||
else:
|
||||
result.append(msg)
|
||||
return result if found else None
|
||||
|
||||
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
|
||||
"""Call chat() and convert unexpected exceptions to error responses."""
|
||||
try:
|
||||
return await self.chat(**kwargs)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -212,57 +254,34 @@ class LLMProvider(ABC):
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
|
||||
kw: dict[str, Any] = dict(
|
||||
messages=messages, tools=tools, model=model,
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
response = LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
response = await self._safe_chat(**kw)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
if self._is_image_unsupported_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Model does not support image input, retrying without images")
|
||||
return await self._safe_chat(**{**kw, "messages": stripped})
|
||||
return response
|
||||
|
||||
err = (response.content or "").lower()
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
len(self._CHAT_RETRY_DELAYS),
|
||||
delay,
|
||||
err[:120],
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {exc}",
|
||||
finish_reason="error",
|
||||
)
|
||||
return await self._safe_chat(**kw)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
|
||||
@@ -14,19 +14,31 @@ class _FakeResponse:
|
||||
self.status_code = status_code
|
||||
self._json_body = json_body or {}
|
||||
self.text = "{}"
|
||||
self.content = b""
|
||||
self.headers = {"content-type": "application/json"}
|
||||
|
||||
def json(self) -> dict:
|
||||
return self._json_body
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
|
||||
self.calls: list[dict] = []
|
||||
self._responses = list(responses) if responses else []
|
||||
|
||||
async def post(self, url: str, json=None, headers=None):
|
||||
self.calls.append({"url": url, "json": json, "headers": headers})
|
||||
def _next_response(self) -> _FakeResponse:
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return _FakeResponse()
|
||||
|
||||
async def post(self, url: str, json=None, headers=None, **kwargs):
|
||||
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
|
||||
return self._next_response()
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
self.calls.append({"method": "GET", "url": url})
|
||||
return self._next_response()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||
@@ -109,3 +121,93 @@ async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatc
|
||||
assert msg.content == "voice transcript"
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:conv123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_processes_file_message(monkeypatch) -> None:
|
||||
"""Test that file messages are handled and forwarded with downloaded path."""
|
||||
bus = MessageBus()
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
|
||||
bus,
|
||||
)
|
||||
handler = NanobotDingTalkHandler(channel)
|
||||
|
||||
class _FakeFileChatbotMessage:
|
||||
text = None
|
||||
extensions = {}
|
||||
image_content = None
|
||||
rich_text_content = None
|
||||
sender_staff_id = "user1"
|
||||
sender_id = "fallback-user"
|
||||
sender_nick = "Alice"
|
||||
message_type = "file"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(_data):
|
||||
return _FakeFileChatbotMessage()
|
||||
|
||||
async def fake_download(download_code, filename, sender_id):
|
||||
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
|
||||
|
||||
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
|
||||
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
|
||||
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
|
||||
|
||||
status, body = await handler.process(
|
||||
SimpleNamespace(
|
||||
data={
|
||||
"conversationType": "1",
|
||||
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
|
||||
"text": {"content": ""},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*list(channel._background_tasks))
|
||||
msg = await bus.consume_inbound()
|
||||
|
||||
assert (status, body) == ("OK", "OK")
|
||||
assert "[File]" in msg.content
|
||||
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
|
||||
"""Test the two-step file download flow (get URL then download content)."""
|
||||
channel = DingTalkChannel(
|
||||
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
# Mock access token
|
||||
async def fake_get_token():
|
||||
return "test-token"
|
||||
|
||||
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
|
||||
|
||||
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
|
||||
file_content = b"fake file content"
|
||||
channel._http = _FakeHttp(responses=[
|
||||
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
|
||||
_FakeResponse(200),
|
||||
])
|
||||
channel._http._responses[1].content = file_content
|
||||
|
||||
# Redirect media dir to tmp_path
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.paths.get_media_dir",
|
||||
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
|
||||
)
|
||||
|
||||
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
|
||||
|
||||
assert result is not None
|
||||
assert result.endswith("test.xlsx")
|
||||
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
|
||||
|
||||
# Verify API calls
|
||||
assert channel._http.calls[0]["method"] == "POST"
|
||||
assert "messageFiles/download" in channel._http.calls[0]["url"]
|
||||
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
|
||||
assert channel._http.calls[1]["method"] == "GET"
|
||||
|
||||
@@ -249,3 +249,114 @@ class TestListDirTool:
|
||||
result = await tool.execute(path=str(tmp_path / "nope"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace restriction + extra_allowed_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWorkspaceRestriction:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_blocked_outside_workspace(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
secret = outside / "secret.txt"
|
||||
secret.write_text("top secret")
|
||||
|
||||
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_allowed_with_extra_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "test_skill" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Test Skill\nDo something.")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(skill_file))
|
||||
assert "Test Skill" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
|
||||
from nanobot.agent.tools.filesystem import WriteFileTool
|
||||
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside"
|
||||
outside.mkdir()
|
||||
|
||||
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
unrelated = tmp_path / "other"
|
||||
unrelated.mkdir()
|
||||
secret = unrelated / "secret.txt"
|
||||
secret.write_text("nope")
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(secret))
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
|
||||
"""Adding extra_allowed_dirs must not break normal workspace reads."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
ws_file = workspace / "README.md"
|
||||
ws_file.write_text("hello from workspace")
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
tool = ReadFileTool(
|
||||
workspace=workspace, allowed_dir=workspace,
|
||||
extra_allowed_dirs=[skills_dir],
|
||||
)
|
||||
result = await tool.execute(path=str(ws_file))
|
||||
assert "hello from workspace" in result
|
||||
assert "Error" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_blocked_in_extra_dir(self, tmp_path):
|
||||
"""edit_file must not be able to modify files in extra_allowed_dirs."""
|
||||
workspace = tmp_path / "ws"
|
||||
workspace.mkdir()
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
skill_file = skills_dir / "weather" / "SKILL.md"
|
||||
skill_file.parent.mkdir()
|
||||
skill_file.write_text("# Weather\nOriginal content.")
|
||||
|
||||
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
|
||||
result = await tool.execute(
|
||||
path=str(skill_file),
|
||||
old_text="Original content.",
|
||||
new_text="Hacked content.",
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "outside" in result.lower()
|
||||
assert skill_file.read_text() == "# Weather\nOriginal content."
|
||||
|
||||
@@ -123,3 +123,87 @@ async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||
assert provider.last_kwargs["temperature"] == 0.9
|
||||
assert provider.last_kwargs["max_tokens"] == 9999
|
||||
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image-unsupported fallback tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMAGE_MSG = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
]},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_unsupported_error_retries_without_images() -> None:
|
||||
"""If the model rejects image_url, retry once with images stripped."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(
|
||||
content="Invalid content type. image_url is only supported by certain models",
|
||||
finish_reason="error",
|
||||
),
|
||||
LLMResponse(content="ok, no image"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||
|
||||
assert response.content == "ok, no image"
|
||||
assert provider.calls == 2
|
||||
msgs_on_retry = provider.last_kwargs["messages"]
|
||||
for msg in msgs_on_retry:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
assert all(b.get("type") != "image_url" for b in content)
|
||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_unsupported_error_no_retry_without_image_content() -> None:
|
||||
"""If messages don't contain image_url blocks, don't retry on image error."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(
|
||||
content="image_url is only supported by certain models",
|
||||
finish_reason="error",
|
||||
),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
)
|
||||
|
||||
assert provider.calls == 1
|
||||
assert response.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None:
|
||||
"""If the image-stripped retry also fails, return that error."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(
|
||||
content="does not support image input",
|
||||
finish_reason="error",
|
||||
),
|
||||
LLMResponse(content="some other error", finish_reason="error"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||
|
||||
assert provider.calls == 2
|
||||
assert response.content == "some other error"
|
||||
assert response.finish_reason == "error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_image_error_does_not_trigger_image_fallback() -> None:
|
||||
"""Regular non-transient errors must not trigger image stripping."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||
])
|
||||
|
||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||
|
||||
assert provider.calls == 1
|
||||
assert response.content == "401 unauthorized"
|
||||
|
||||
Reference in New Issue
Block a user