fix: handle image_url rejection by retrying without images
Replace the static provider-level supports_vision check with a reactive fallback: when a model returns an image-unsupported error, strip image_url blocks from messages and retry once. This avoids maintaining an inaccurate vision capability table and correctly handles gateway/unknown model scenarios. Also extract _safe_chat() to deduplicate try/except boilerplate in chat_with_retry().
This commit is contained in:
@@ -89,6 +89,14 @@ class LLMProvider(ABC):
|
|||||||
"server error",
|
"server error",
|
||||||
"temporarily unavailable",
|
"temporarily unavailable",
|
||||||
)
|
)
|
||||||
|
_IMAGE_UNSUPPORTED_MARKERS = (
|
||||||
|
"image_url is only supported",
|
||||||
|
"does not support image",
|
||||||
|
"images are not supported",
|
||||||
|
"image input is not supported",
|
||||||
|
"image_url is not supported",
|
||||||
|
"unsupported image input",
|
||||||
|
)
|
||||||
|
|
||||||
_SENTINEL = object()
|
_SENTINEL = object()
|
||||||
|
|
||||||
@@ -189,6 +197,40 @@ class LLMProvider(ABC):
|
|||||||
err = (content or "").lower()
|
err = (content or "").lower()
|
||||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _is_image_unsupported_error(cls, content: str | None) -> bool:
|
||||||
|
err = (content or "").lower()
|
||||||
|
return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
||||||
|
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||||
|
found = False
|
||||||
|
result = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
new_content = []
|
||||||
|
for b in content:
|
||||||
|
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||||
|
new_content.append({"type": "text", "text": "[image omitted]"})
|
||||||
|
found = True
|
||||||
|
else:
|
||||||
|
new_content.append(b)
|
||||||
|
result.append({**msg, "content": new_content})
|
||||||
|
else:
|
||||||
|
result.append(msg)
|
||||||
|
return result if found else None
|
||||||
|
|
||||||
|
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
|
||||||
|
"""Call chat() and convert unexpected exceptions to error responses."""
|
||||||
|
try:
|
||||||
|
return await self.chat(**kwargs)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
@@ -212,57 +254,34 @@ class LLMProvider(ABC):
|
|||||||
if reasoning_effort is self._SENTINEL:
|
if reasoning_effort is self._SENTINEL:
|
||||||
reasoning_effort = self.generation.reasoning_effort
|
reasoning_effort = self.generation.reasoning_effort
|
||||||
|
|
||||||
|
kw: dict[str, Any] = dict(
|
||||||
|
messages=messages, tools=tools, model=model,
|
||||||
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
|
)
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||||
try:
|
response = await self._safe_chat(**kw)
|
||||||
response = await self.chat(
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
model=model,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
response = LLMResponse(
|
|
||||||
content=f"Error calling LLM: {exc}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
return response
|
return response
|
||||||
|
|
||||||
if not self._is_transient_error(response.content):
|
if not self._is_transient_error(response.content):
|
||||||
|
if self._is_image_unsupported_error(response.content):
|
||||||
|
stripped = self._strip_image_content(messages)
|
||||||
|
if stripped is not None:
|
||||||
|
logger.warning("Model does not support image input, retrying without images")
|
||||||
|
return await self._safe_chat(**{**kw, "messages": stripped})
|
||||||
return response
|
return response
|
||||||
|
|
||||||
err = (response.content or "").lower()
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||||
attempt,
|
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||||
len(self._CHAT_RETRY_DELAYS),
|
(response.content or "")[:120].lower(),
|
||||||
delay,
|
|
||||||
err[:120],
|
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
try:
|
return await self._safe_chat(**kw)
|
||||||
return await self.chat(
|
|
||||||
messages=messages,
|
|
||||||
tools=tools,
|
|
||||||
model=model,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception as exc:
|
|
||||||
return LLMResponse(
|
|
||||||
content=f"Error calling LLM: {exc}",
|
|
||||||
finish_reason="error",
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
|
|||||||
@@ -124,32 +124,6 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
spec = find_by_model(model)
|
spec = find_by_model(model)
|
||||||
return spec is not None and spec.supports_prompt_caching
|
return spec is not None and spec.supports_prompt_caching
|
||||||
|
|
||||||
def _supports_vision(self, model: str) -> bool:
|
|
||||||
"""Return True when the provider supports vision/image inputs."""
|
|
||||||
if self._gateway is not None:
|
|
||||||
return self._gateway.supports_vision
|
|
||||||
spec = find_by_model(model)
|
|
||||||
return spec is None or spec.supports_vision # default True for unknown providers
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _filter_image_url(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
"""Replace image_url content blocks with [image] placeholder for non-vision models."""
|
|
||||||
filtered = []
|
|
||||||
for msg in messages:
|
|
||||||
content = msg.get("content")
|
|
||||||
if isinstance(content, list):
|
|
||||||
new_content = []
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "image_url":
|
|
||||||
# Replace image with placeholder text
|
|
||||||
new_content.append({"type": "text", "text": "[image]"})
|
|
||||||
else:
|
|
||||||
new_content.append(block)
|
|
||||||
filtered.append({**msg, "content": new_content})
|
|
||||||
else:
|
|
||||||
filtered.append(msg)
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
def _apply_cache_control(
|
def _apply_cache_control(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
@@ -260,10 +234,6 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
model = self._resolve_model(original_model)
|
model = self._resolve_model(original_model)
|
||||||
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
||||||
|
|
||||||
# Filter image_url for non-vision models
|
|
||||||
if not self._supports_vision(original_model):
|
|
||||||
messages = self._filter_image_url(messages)
|
|
||||||
|
|
||||||
if self._supports_cache_control(original_model):
|
if self._supports_cache_control(original_model):
|
||||||
messages, tools = self._apply_cache_control(messages, tools)
|
messages, tools = self._apply_cache_control(messages, tools)
|
||||||
|
|
||||||
|
|||||||
@@ -61,9 +61,6 @@ class ProviderSpec:
|
|||||||
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||||
supports_prompt_caching: bool = False
|
supports_prompt_caching: bool = False
|
||||||
|
|
||||||
# Provider supports vision/image inputs (most modern models do)
|
|
||||||
supports_vision: bool = True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def label(self) -> str:
|
def label(self) -> str:
|
||||||
return self.display_name or self.name.title()
|
return self.display_name or self.name.title()
|
||||||
|
|||||||
@@ -123,3 +123,87 @@ async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
|||||||
assert provider.last_kwargs["temperature"] == 0.9
|
assert provider.last_kwargs["temperature"] == 0.9
|
||||||
assert provider.last_kwargs["max_tokens"] == 9999
|
assert provider.last_kwargs["max_tokens"] == 9999
|
||||||
assert provider.last_kwargs["reasoning_effort"] == "low"
|
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Image-unsupported fallback tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_IMAGE_MSG = [
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"type": "text", "text": "describe this"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||||
|
]},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_image_unsupported_error_retries_without_images() -> None:
|
||||||
|
"""If the model rejects image_url, retry once with images stripped."""
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="Invalid content type. image_url is only supported by certain models",
|
||||||
|
finish_reason="error",
|
||||||
|
),
|
||||||
|
LLMResponse(content="ok, no image"),
|
||||||
|
])
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||||
|
|
||||||
|
assert response.content == "ok, no image"
|
||||||
|
assert provider.calls == 2
|
||||||
|
msgs_on_retry = provider.last_kwargs["messages"]
|
||||||
|
for msg in msgs_on_retry:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
assert all(b.get("type") != "image_url" for b in content)
|
||||||
|
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_image_unsupported_error_no_retry_without_image_content() -> None:
|
||||||
|
"""If messages don't contain image_url blocks, don't retry on image error."""
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="image_url is only supported by certain models",
|
||||||
|
finish_reason="error",
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert response.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None:
|
||||||
|
"""If the image-stripped retry also fails, return that error."""
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="does not support image input",
|
||||||
|
finish_reason="error",
|
||||||
|
),
|
||||||
|
LLMResponse(content="some other error", finish_reason="error"),
|
||||||
|
])
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||||
|
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert response.content == "some other error"
|
||||||
|
assert response.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_non_image_error_does_not_trigger_image_fallback() -> None:
|
||||||
|
"""Regular non-transient errors must not trigger image stripping."""
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
LLMResponse(content="401 unauthorized", finish_reason="error"),
|
||||||
|
])
|
||||||
|
|
||||||
|
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
||||||
|
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert response.content == "401 unauthorized"
|
||||||
|
|||||||
Reference in New Issue
Block a user