diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 114a948..8b6956c 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -89,6 +89,14 @@ class LLMProvider(ABC): "server error", "temporarily unavailable", ) + _IMAGE_UNSUPPORTED_MARKERS = ( + "image_url is only supported", + "does not support image", + "images are not supported", + "image input is not supported", + "image_url is not supported", + "unsupported image input", + ) _SENTINEL = object() @@ -189,6 +197,40 @@ class LLMProvider(ABC): err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) + @classmethod + def _is_image_unsupported_error(cls, content: str | None) -> bool: + err = (content or "").lower() + return any(marker in err for marker in cls._IMAGE_UNSUPPORTED_MARKERS) + + @staticmethod + def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + """Replace image_url blocks with text placeholder. Returns None if no images found.""" + found = False + result = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + new_content = [] + for b in content: + if isinstance(b, dict) and b.get("type") == "image_url": + new_content.append({"type": "text", "text": "[image omitted]"}) + found = True + else: + new_content.append(b) + result.append({**msg, "content": new_content}) + else: + result.append(msg) + return result if found else None + + async def _safe_chat(self, **kwargs: Any) -> LLMResponse: + """Call chat() and convert unexpected exceptions to error responses.""" + try: + return await self.chat(**kwargs) + except asyncio.CancelledError: + raise + except Exception as exc: + return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") + async def chat_with_retry( self, messages: list[dict[str, Any]], @@ -212,57 +254,34 @@ class LLMProvider(ABC): if reasoning_effort is self._SENTINEL: reasoning_effort = self.generation.reasoning_effort + kw: dict[str, Any] = dict( + messages=messages, tools=tools, model=model, + max_tokens=max_tokens, temperature=temperature, + reasoning_effort=reasoning_effort, tool_choice=tool_choice, + ) + for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): - try: - response = await self.chat( - messages=messages, - tools=tools, - model=model, - max_tokens=max_tokens, - temperature=temperature, - reasoning_effort=reasoning_effort, - tool_choice=tool_choice, - ) - except asyncio.CancelledError: - raise - except Exception as exc: - response = LLMResponse( - content=f"Error calling LLM: {exc}", - finish_reason="error", - ) + response = await self._safe_chat(**kw) if response.finish_reason != "error": return response + if not self._is_transient_error(response.content): + if self._is_image_unsupported_error(response.content): + stripped = self._strip_image_content(messages) + if stripped is not None: + logger.warning("Model does not support image input, retrying without images") + return await self._safe_chat(**{**kw, "messages": stripped}) return response - err = (response.content or "").lower() logger.warning( "LLM transient error (attempt {}/{}), retrying in {}s: {}", - attempt, - len(self._CHAT_RETRY_DELAYS), - delay, - err[:120], + attempt, len(self._CHAT_RETRY_DELAYS), delay, + (response.content or "")[:120].lower(), ) await asyncio.sleep(delay) - try: - return await self.chat( - messages=messages, - tools=tools, - model=model, - max_tokens=max_tokens, - temperature=temperature, - reasoning_effort=reasoning_effort, - tool_choice=tool_choice, - ) - except asyncio.CancelledError: - raise - except Exception as exc: - return LLMResponse( - content=f"Error calling LLM: {exc}", - finish_reason="error", - ) + return await self._safe_chat(**kw) @abstractmethod def get_default_model(self) -> str: diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 3dece89..d14e4c0 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -124,32 +124,6 @@ class LiteLLMProvider(LLMProvider): spec = find_by_model(model) return spec is not None and spec.supports_prompt_caching - def _supports_vision(self, model: str) -> bool: - """Return True when the provider supports vision/image inputs.""" - if self._gateway is not None: - return self._gateway.supports_vision - spec = find_by_model(model) - return spec is None or spec.supports_vision # default True for unknown providers - - @staticmethod - def _filter_image_url(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Replace image_url content blocks with [image] placeholder for non-vision models.""" - filtered = [] - for msg in messages: - content = msg.get("content") - if isinstance(content, list): - new_content = [] - for block in content: - if isinstance(block, dict) and block.get("type") == "image_url": - # Replace image with placeholder text - new_content.append({"type": "text", "text": "[image]"}) - else: - new_content.append(block) - filtered.append({**msg, "content": new_content}) - else: - filtered.append(msg) - return filtered - def _apply_cache_control( self, messages: list[dict[str, Any]], @@ -260,10 +234,6 @@ class LiteLLMProvider(LLMProvider): model = self._resolve_model(original_model) extra_msg_keys = self._extra_msg_keys(original_model, model) - # Filter image_url for non-vision models - if not self._supports_vision(original_model): - messages = self._filter_image_url(messages) - if self._supports_cache_control(original_model): messages, tools = self._apply_cache_control(messages, tools) diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index a45f14a..42c1d24 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -61,9 +61,6 @@ class ProviderSpec: # Provider supports cache_control on content blocks (e.g. Anthropic prompt caching) supports_prompt_caching: bool = False - # Provider supports vision/image inputs (most modern models do) - supports_vision: bool = True - @property def label(self) -> str: return self.display_name or self.name.title() diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py index 2420399..6f2c165 100644 --- a/tests/test_provider_retry.py +++ b/tests/test_provider_retry.py @@ -123,3 +123,87 @@ async def test_chat_with_retry_explicit_override_beats_defaults() -> None: assert provider.last_kwargs["temperature"] == 0.9 assert provider.last_kwargs["max_tokens"] == 9999 assert provider.last_kwargs["reasoning_effort"] == "low" + + +# --------------------------------------------------------------------------- +# Image-unsupported fallback tests +# --------------------------------------------------------------------------- + +_IMAGE_MSG = [ + {"role": "user", "content": [ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ]}, +] + + +@pytest.mark.asyncio +async def test_image_unsupported_error_retries_without_images() -> None: + """If the model rejects image_url, retry once with images stripped.""" + provider = ScriptedProvider([ + LLMResponse( + content="Invalid content type. image_url is only supported by certain models", + finish_reason="error", + ), + LLMResponse(content="ok, no image"), + ]) + + response = await provider.chat_with_retry(messages=_IMAGE_MSG) + + assert response.content == "ok, no image" + assert provider.calls == 2 + msgs_on_retry = provider.last_kwargs["messages"] + for msg in msgs_on_retry: + content = msg.get("content") + if isinstance(content, list): + assert all(b.get("type") != "image_url" for b in content) + assert any("[image omitted]" in (b.get("text") or "") for b in content) + + +@pytest.mark.asyncio +async def test_image_unsupported_error_no_retry_without_image_content() -> None: + """If messages don't contain image_url blocks, don't retry on image error.""" + provider = ScriptedProvider([ + LLMResponse( + content="image_url is only supported by certain models", + finish_reason="error", + ), + ]) + + response = await provider.chat_with_retry( + messages=[{"role": "user", "content": "hello"}], + ) + + assert provider.calls == 1 + assert response.finish_reason == "error" + + +@pytest.mark.asyncio +async def test_image_unsupported_fallback_returns_error_on_second_failure() -> None: + """If the image-stripped retry also fails, return that error.""" + provider = ScriptedProvider([ + LLMResponse( + content="does not support image input", + finish_reason="error", + ), + LLMResponse(content="some other error", finish_reason="error"), + ]) + + response = await provider.chat_with_retry(messages=_IMAGE_MSG) + + assert provider.calls == 2 + assert response.content == "some other error" + assert response.finish_reason == "error" + + +@pytest.mark.asyncio +async def test_non_image_error_does_not_trigger_image_fallback() -> None: + """Regular non-transient errors must not trigger image stripping.""" + provider = ScriptedProvider([ + LLMResponse(content="401 unauthorized", finish_reason="error"), + ]) + + response = await provider.chat_with_retry(messages=_IMAGE_MSG) + + assert provider.calls == 1 + assert response.content == "401 unauthorized"