Merge remote-tracking branch 'origin/main'

This commit is contained in:
Hua
2026-03-16 09:43:17 +08:00
8 changed files with 477 additions and 50 deletions

View File

@@ -26,6 +26,7 @@ from nanobot.agent.i18n import (
from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
@@ -177,7 +178,9 @@ class AgentLoop:
def _register_default_tools(self) -> None: def _register_default_tools(self) -> None:
"""Register the default set of tools.""" """Register the default set of tools."""
allowed_dir = self.workspace if self.restrict_to_workspace else None allowed_dir = self.workspace if self.restrict_to_workspace else None
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool): extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(ExecTool( self.tools.register(ExecTool(
working_dir=str(self.workspace), working_dir=str(self.workspace),

View File

@@ -8,6 +8,7 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool from nanobot.agent.tools.shell import ExecTool
@@ -97,7 +98,8 @@ class SubagentManager:
# Build subagent tools (no message tool, no spawn tool) # Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry() tools = ToolRegistry()
allowed_dir = self.workspace if self.restrict_to_workspace else None allowed_dir = self.workspace if self.restrict_to_workspace else None
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))

View File

@@ -8,7 +8,10 @@ from nanobot.agent.tools.base import Tool
def _resolve_path( def _resolve_path(
path: str, workspace: Path | None = None, allowed_dir: Path | None = None path: str,
workspace: Path | None = None,
allowed_dir: Path | None = None,
extra_allowed_dirs: list[Path] | None = None,
) -> Path: ) -> Path:
"""Resolve path against workspace (if relative) and enforce directory restriction.""" """Resolve path against workspace (if relative) and enforce directory restriction."""
p = Path(path).expanduser() p = Path(path).expanduser()
@@ -16,22 +19,35 @@ def _resolve_path(
p = workspace / p p = workspace / p
resolved = p.resolve() resolved = p.resolve()
if allowed_dir: if allowed_dir:
try: all_dirs = [allowed_dir] + (extra_allowed_dirs or [])
resolved.relative_to(allowed_dir.resolve()) if not any(_is_under(resolved, d) for d in all_dirs):
except ValueError:
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}") raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
return resolved return resolved
def _is_under(path: Path, directory: Path) -> bool:
try:
path.relative_to(directory.resolve())
return True
except ValueError:
return False
class _FsTool(Tool): class _FsTool(Tool):
"""Shared base for filesystem tools — common init and path resolution.""" """Shared base for filesystem tools — common init and path resolution."""
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None): def __init__(
self,
workspace: Path | None = None,
allowed_dir: Path | None = None,
extra_allowed_dirs: list[Path] | None = None,
):
self._workspace = workspace self._workspace = workspace
self._allowed_dir = allowed_dir self._allowed_dir = allowed_dir
self._extra_allowed_dirs = extra_allowed_dirs
def _resolve(self, path: str) -> Path: def _resolve(self, path: str) -> Path:
return _resolve_path(path, self._workspace, self._allowed_dir) return _resolve_path(path, self._workspace, self._allowed_dir, self._extra_allowed_dirs)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -62,6 +62,49 @@ class NanobotDingTalkHandler(CallbackHandler):
if not content: if not content:
content = message.data.get("text", {}).get("content", "").strip() content = message.data.get("text", {}).get("content", "").strip()
# Handle file/image messages
file_paths = []
if chatbot_msg.message_type == "picture" and chatbot_msg.image_content:
download_code = chatbot_msg.image_content.download_code
if download_code:
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
fp = await self.channel._download_dingtalk_file(download_code, "image.jpg", sender_uid)
if fp:
file_paths.append(fp)
content = content or "[Image]"
elif chatbot_msg.message_type == "file":
download_code = message.data.get("content", {}).get("downloadCode") or message.data.get("downloadCode")
fname = message.data.get("content", {}).get("fileName") or message.data.get("fileName") or "file"
if download_code:
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
fp = await self.channel._download_dingtalk_file(download_code, fname, sender_uid)
if fp:
file_paths.append(fp)
content = content or "[File]"
elif chatbot_msg.message_type == "richText" and chatbot_msg.rich_text_content:
rich_list = chatbot_msg.rich_text_content.rich_text_list or []
for item in rich_list:
if not isinstance(item, dict):
continue
if item.get("type") == "text":
t = item.get("text", "").strip()
if t:
content = (content + " " + t).strip() if content else t
elif item.get("downloadCode"):
dc = item["downloadCode"]
fname = item.get("fileName") or "file"
sender_uid = chatbot_msg.sender_staff_id or chatbot_msg.sender_id or "unknown"
fp = await self.channel._download_dingtalk_file(dc, fname, sender_uid)
if fp:
file_paths.append(fp)
content = content or "[File]"
if file_paths:
file_list = "\n".join("- " + p for p in file_paths)
content = content + "\n\nReceived files:\n" + file_list
if not content: if not content:
logger.warning( logger.warning(
"Received empty or unsupported message type: {}", "Received empty or unsupported message type: {}",
@@ -472,3 +515,50 @@ class DingTalkChannel(BaseChannel):
) )
except Exception as e: except Exception as e:
logger.error("Error publishing DingTalk message: {}", e) logger.error("Error publishing DingTalk message: {}", e)
async def _download_dingtalk_file(
self,
download_code: str,
filename: str,
sender_id: str,
) -> str | None:
"""Download a DingTalk file to the media directory, return local path."""
from nanobot.config.paths import get_media_dir
try:
token = await self._get_access_token()
if not token or not self._http:
logger.error("DingTalk file download: no token or http client")
return None
# Step 1: Exchange downloadCode for a temporary download URL
api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
headers = {"x-acs-dingtalk-access-token": token, "Content-Type": "application/json"}
payload = {"downloadCode": download_code, "robotCode": self.config.client_id}
resp = await self._http.post(api_url, json=payload, headers=headers)
if resp.status_code != 200:
logger.error("DingTalk get download URL failed: status={}, body={}", resp.status_code, resp.text)
return None
result = resp.json()
download_url = result.get("downloadUrl")
if not download_url:
logger.error("DingTalk download URL not found in response: {}", result)
return None
# Step 2: Download the file content
file_resp = await self._http.get(download_url, follow_redirects=True)
if file_resp.status_code != 200:
logger.error("DingTalk file download failed: status={}", file_resp.status_code)
return None
# Save to media directory (accessible under workspace)
download_dir = get_media_dir("dingtalk") / sender_id
download_dir.mkdir(parents=True, exist_ok=True)
file_path = download_dir / filename
await asyncio.to_thread(file_path.write_bytes, file_resp.content)
logger.info("DingTalk file saved: {}", file_path)
return str(file_path)
except Exception as e:
logger.error("DingTalk file download error: {}", e)
return None

View File

@@ -89,6 +89,14 @@ class LLMProvider(ABC):
"server error", "server error",
"temporarily unavailable", "temporarily unavailable",
) )
_IMAGE_UNSUPPORTED_MARKERS = (
"image_url is only supported",
"does not support image",
"images are not supported",
"image input is not supported",
"image_url is not supported",
"unsupported image input",
)
_SENTINEL = object() _SENTINEL = object()
@@ -189,6 +197,40 @@ class LLMProvider(ABC):
err = (content or "").lower() err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
@classmethod
def _is_image_unsupported_error(cls, content: str | None) -> bool:
err = (content or "").lower()
return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS)
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
found = False
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
new_content = []
for b in content:
if isinstance(b, dict) and b.get("type") == "image_url":
new_content.append({"type": "text", "text": "[image omitted]"})
found = True
else:
new_content.append(b)
result.append({**msg, "content": new_content})
else:
result.append(msg)
return result if found else None
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
"""Call chat() and convert unexpected exceptions to error responses."""
try:
return await self.chat(**kwargs)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
async def chat_with_retry( async def chat_with_retry(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
@@ -212,57 +254,34 @@ class LLMProvider(ABC):
if reasoning_effort is self._SENTINEL: if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort reasoning_effort = self.generation.reasoning_effort
kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try: response = await self._safe_chat(**kw)
response = await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
)
except asyncio.CancelledError:
raise
except Exception as exc:
response = LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
if response.finish_reason != "error": if response.finish_reason != "error":
return response return response
if not self._is_transient_error(response.content): if not self._is_transient_error(response.content):
if self._is_image_unsupported_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Model does not support image input, retrying without images")
return await self._safe_chat(**{**kw, "messages": stripped})
return response return response
err = (response.content or "").lower()
logger.warning( logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}", "LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, attempt, len(self._CHAT_RETRY_DELAYS), delay,
len(self._CHAT_RETRY_DELAYS), (response.content or "")[:120].lower(),
delay,
err[:120],
) )
await asyncio.sleep(delay) await asyncio.sleep(delay)
try: return await self._safe_chat(**kw)
return await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
@abstractmethod @abstractmethod
def get_default_model(self) -> str: def get_default_model(self) -> str:

View File

@@ -14,19 +14,31 @@ class _FakeResponse:
self.status_code = status_code self.status_code = status_code
self._json_body = json_body or {} self._json_body = json_body or {}
self.text = "{}" self.text = "{}"
self.content = b""
self.headers = {"content-type": "application/json"}
def json(self) -> dict: def json(self) -> dict:
return self._json_body return self._json_body
class _FakeHttp: class _FakeHttp:
def __init__(self) -> None: def __init__(self, responses: list[_FakeResponse] | None = None) -> None:
self.calls: list[dict] = [] self.calls: list[dict] = []
self._responses = list(responses) if responses else []
async def post(self, url: str, json=None, headers=None): def _next_response(self) -> _FakeResponse:
self.calls.append({"url": url, "json": json, "headers": headers}) if self._responses:
return self._responses.pop(0)
return _FakeResponse() return _FakeResponse()
async def post(self, url: str, json=None, headers=None, **kwargs):
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
return self._next_response()
async def get(self, url: str, **kwargs):
self.calls.append({"method": "GET", "url": url})
return self._next_response()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None: async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
@@ -109,3 +121,93 @@ async def test_handler_uses_voice_recognition_text_when_text_is_empty(monkeypatc
assert msg.content == "voice transcript" assert msg.content == "voice transcript"
assert msg.sender_id == "user1" assert msg.sender_id == "user1"
assert msg.chat_id == "group:conv123" assert msg.chat_id == "group:conv123"
@pytest.mark.asyncio
async def test_handler_processes_file_message(monkeypatch) -> None:
"""Test that file messages are handled and forwarded with downloaded path."""
bus = MessageBus()
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"]),
bus,
)
handler = NanobotDingTalkHandler(channel)
class _FakeFileChatbotMessage:
text = None
extensions = {}
image_content = None
rich_text_content = None
sender_staff_id = "user1"
sender_id = "fallback-user"
sender_nick = "Alice"
message_type = "file"
@staticmethod
def from_dict(_data):
return _FakeFileChatbotMessage()
async def fake_download(download_code, filename, sender_id):
return f"/tmp/nanobot_dingtalk/{sender_id}/{filename}"
monkeypatch.setattr(dingtalk_module, "ChatbotMessage", _FakeFileChatbotMessage)
monkeypatch.setattr(dingtalk_module, "AckMessage", SimpleNamespace(STATUS_OK="OK"))
monkeypatch.setattr(channel, "_download_dingtalk_file", fake_download)
status, body = await handler.process(
SimpleNamespace(
data={
"conversationType": "1",
"content": {"downloadCode": "abc123", "fileName": "report.xlsx"},
"text": {"content": ""},
}
)
)
await asyncio.gather(*list(channel._background_tasks))
msg = await bus.consume_inbound()
assert (status, body) == ("OK", "OK")
assert "[File]" in msg.content
assert "/tmp/nanobot_dingtalk/user1/report.xlsx" in msg.content
@pytest.mark.asyncio
async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
"""Test the two-step file download flow (get URL then download content)."""
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
MessageBus(),
)
# Mock access token
async def fake_get_token():
return "test-token"
monkeypatch.setattr(channel, "_get_access_token", fake_get_token)
# Mock HTTP: first POST returns downloadUrl, then GET returns file bytes
file_content = b"fake file content"
channel._http = _FakeHttp(responses=[
_FakeResponse(200, {"downloadUrl": "https://example.com/tmpfile"}),
_FakeResponse(200),
])
channel._http._responses[1].content = file_content
# Redirect media dir to tmp_path
monkeypatch.setattr(
"nanobot.config.paths.get_media_dir",
lambda channel_name=None: tmp_path / channel_name if channel_name else tmp_path,
)
result = await channel._download_dingtalk_file("code123", "test.xlsx", "user1")
assert result is not None
assert result.endswith("test.xlsx")
assert (tmp_path / "dingtalk" / "user1" / "test.xlsx").read_bytes() == file_content
# Verify API calls
assert channel._http.calls[0]["method"] == "POST"
assert "messageFiles/download" in channel._http.calls[0]["url"]
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
assert channel._http.calls[1]["method"] == "GET"

View File

@@ -249,3 +249,114 @@ class TestListDirTool:
result = await tool.execute(path=str(tmp_path / "nope")) result = await tool.execute(path=str(tmp_path / "nope"))
assert "Error" in result assert "Error" in result
assert "not found" in result assert "not found" in result
# ---------------------------------------------------------------------------
# Workspace restriction + extra_allowed_dirs
# ---------------------------------------------------------------------------
class TestWorkspaceRestriction:
@pytest.mark.asyncio
async def test_read_blocked_outside_workspace(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
secret = outside / "secret.txt"
secret.write_text("top secret")
tool = ReadFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_allowed_with_extra_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "test_skill" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Test Skill\nDo something.")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(skill_file))
assert "Test Skill" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_extra_dirs_does_not_widen_write(self, tmp_path):
from nanobot.agent.tools.filesystem import WriteFileTool
workspace = tmp_path / "ws"
workspace.mkdir()
outside = tmp_path / "outside"
outside.mkdir()
tool = WriteFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(path=str(outside / "hack.txt"), content="pwned")
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_read_still_blocked_for_unrelated_dir(self, tmp_path):
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
unrelated = tmp_path / "other"
unrelated.mkdir()
secret = unrelated / "secret.txt"
secret.write_text("nope")
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(secret))
assert "Error" in result
assert "outside" in result.lower()
@pytest.mark.asyncio
async def test_workspace_file_still_readable_with_extra_dirs(self, tmp_path):
"""Adding extra_allowed_dirs must not break normal workspace reads."""
workspace = tmp_path / "ws"
workspace.mkdir()
ws_file = workspace / "README.md"
ws_file.write_text("hello from workspace")
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
tool = ReadFileTool(
workspace=workspace, allowed_dir=workspace,
extra_allowed_dirs=[skills_dir],
)
result = await tool.execute(path=str(ws_file))
assert "hello from workspace" in result
assert "Error" not in result
@pytest.mark.asyncio
async def test_edit_blocked_in_extra_dir(self, tmp_path):
"""edit_file must not be able to modify files in extra_allowed_dirs."""
workspace = tmp_path / "ws"
workspace.mkdir()
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
skill_file = skills_dir / "weather" / "SKILL.md"
skill_file.parent.mkdir()
skill_file.write_text("# Weather\nOriginal content.")
tool = EditFileTool(workspace=workspace, allowed_dir=workspace)
result = await tool.execute(
path=str(skill_file),
old_text="Original content.",
new_text="Hacked content.",
)
assert "Error" in result
assert "outside" in result.lower()
assert skill_file.read_text() == "# Weather\nOriginal content."

View File

@@ -123,3 +123,87 @@ async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
assert provider.last_kwargs["temperature"] == 0.9 assert provider.last_kwargs["temperature"] == 0.9
assert provider.last_kwargs["max_tokens"] == 9999 assert provider.last_kwargs["max_tokens"] == 9999
assert provider.last_kwargs["reasoning_effort"] == "low" assert provider.last_kwargs["reasoning_effort"] == "low"
# ---------------------------------------------------------------------------
# Image-unsupported fallback tests
# ---------------------------------------------------------------------------
_IMAGE_MSG = [
{"role": "user", "content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]},
]
@pytest.mark.asyncio
async def test_image_unsupported_error_retries_without_images() -> None:
"""If the model rejects image_url, retry once with images stripped."""
provider = ScriptedProvider([
LLMResponse(
content="Invalid content type. image_url is only supported by certain models",
finish_reason="error",
),
LLMResponse(content="ok, no image"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert response.content == "ok, no image"
assert provider.calls == 2
msgs_on_retry = provider.last_kwargs["messages"]
for msg in msgs_on_retry:
content = msg.get("content")
if isinstance(content, list):
assert all(b.get("type") != "image_url" for b in content)
assert any("[image omitted]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_image_unsupported_error_no_retry_without_image_content() -> None:
"""If messages don't contain image_url blocks, don't retry on image error."""
provider = ScriptedProvider([
LLMResponse(
content="image_url is only supported by certain models",
finish_reason="error",
),
])
response = await provider.chat_with_retry(
messages=[{"role": "user", "content": "hello"}],
)
assert provider.calls == 1
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None:
"""If the image-stripped retry also fails, return that error."""
provider = ScriptedProvider([
LLMResponse(
content="does not support image input",
finish_reason="error",
),
LLMResponse(content="some other error", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert provider.calls == 2
assert response.content == "some other error"
assert response.finish_reason == "error"
@pytest.mark.asyncio
async def test_non_image_error_does_not_trigger_image_fallback() -> None:
"""Regular non-transient errors must not trigger image stripping."""
provider = ScriptedProvider([
LLMResponse(content="401 unauthorized", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
assert provider.calls == 1
assert response.content == "401 unauthorized"