From b0a5435b8720a5968e683ce5aa82a8b16e614452 Mon Sep 17 00:00:00 2001
From: Re-bin
Date: Tue, 10 Mar 2026 10:10:37 +0000
Subject: [PATCH] refactor(llm): share transient retry across agent paths

---
 nanobot/agent/loop.py                    | 29 +-------
 nanobot/agent/memory.py                  |  2 +-
 nanobot/agent/subagent.py                |  2 +-
 nanobot/heartbeat/service.py             |  2 +-
 nanobot/providers/base.py                | 84 ++++++++++++++++++++++
 tests/test_heartbeat_service.py          | 47 +++++++++++-
 tests/test_memory_consolidation_types.py | 50 ++++++++++++-
 tests/test_provider_retry.py             | 92 ++++++++++++++++++++++++
 8 files changed, 274 insertions(+), 34 deletions(-)
 create mode 100644 tests/test_provider_retry.py

diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index b67baae..fcbc880 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -159,33 +159,6 @@ class AgentLoop:
             if hasattr(tool, "set_context"):
                 tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
 
-    _RETRY_DELAYS = (1, 2, 4)  # seconds — exponential backoff for transient LLM errors
-
-    async def _chat_with_retry(self, **kwargs: Any) -> Any:
-        """Call provider.chat() with retry on transient errors (429, 5xx, network)."""
-        from nanobot.providers.base import LLMResponse
-
-        last_response: LLMResponse | None = None
-        for attempt, delay in enumerate(self._RETRY_DELAYS):
-            response = await self.provider.chat(**kwargs)
-            if response.finish_reason != "error":
-                return response
-            # Check if the error looks transient (rate limit, server error, network)
-            err = (response.content or "").lower()
-            is_transient = any(kw in err for kw in (
-                "429", "rate limit", "500", "502", "503", "504",
-                "overloaded", "timeout", "connection", "server error",
-            ))
-            if not is_transient:
-                return response  # permanent error (400, 401, etc.) — don't retry
-            last_response = response
-            logger.warning("LLM transient error (attempt {}/{}), retrying in {}s: {}",
-                           attempt + 1, len(self._RETRY_DELAYS), delay, err[:120])
-            await asyncio.sleep(delay)
-        # All retries exhausted — make one final attempt
-        response = await self.provider.chat(**kwargs)
-        return response if response.finish_reason != "error" else (last_response or response)
-
     @staticmethod
     def _strip_think(text: str | None) -> str | None:
         """Remove blocks that some models embed in content."""
@@ -218,7 +191,7 @@ class AgentLoop:
         while iteration < self.max_iterations:
             iteration += 1
 
-            response = await self._chat_with_retry(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=self.tools.get_definitions(),
                 model=self.model,
diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py
index 21fe77d..66efec2 100644
--- a/nanobot/agent/memory.py
+++ b/nanobot/agent/memory.py
@@ -111,7 +111,7 @@ class MemoryStore:
 {chr(10).join(lines)}"""
 
         try:
-            response = await provider.chat(
+            response = await provider.chat_with_retry(
                 messages=[
                     {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
                     {"role": "user", "content": prompt},
diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py
index f2d6ee5..f9eda1f 100644
--- a/nanobot/agent/subagent.py
+++ b/nanobot/agent/subagent.py
@@ -123,7 +123,7 @@ class SubagentManager:
         while iteration < max_iterations:
             iteration += 1
 
-            response = await self.provider.chat(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=tools.get_definitions(),
                 model=self.model,
diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py
index e534017..831ae85 100644
--- a/nanobot/heartbeat/service.py
+++ b/nanobot/heartbeat/service.py
@@ -87,7 +87,7 @@ class HeartbeatService:
 
         Returns (action, tasks) where action is 'skip' or 'run'.
         """
-        response = await self.provider.chat(
+        response = await self.provider.chat_with_retry(
             messages=[
                 {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
                 {"role": "user", "content": (
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 0f73544..a3b6c47 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -1,9 +1,12 @@
 """Base LLM provider interface."""
 
+import asyncio
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any
 
+from loguru import logger
+
 
 @dataclass
 class ToolCallRequest:
@@ -37,6 +40,22 @@ class LLMProvider(ABC):
     while maintaining a consistent interface.
     """
 
+    _CHAT_RETRY_DELAYS = (1, 2, 4)
+    _TRANSIENT_ERROR_MARKERS = (
+        "429",
+        "rate limit",
+        "500",
+        "502",
+        "503",
+        "504",
+        "overloaded",
+        "timeout",
+        "timed out",
+        "connection",
+        "server error",
+        "temporarily unavailable",
+    )
+
     def __init__(self, api_key: str | None = None, api_base: str | None = None):
         self.api_key = api_key
         self.api_base = api_base
@@ -126,6 +145,71 @@ class LLMProvider(ABC):
         """
         pass
 
+    @classmethod
+    def _is_transient_error(cls, content: str | None) -> bool:
+        err = (content or "").lower()
+        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
+
+    async def chat_with_retry(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None = None,
+        model: str | None = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        reasoning_effort: str | None = None,
+    ) -> LLMResponse:
+        """Call chat() with retry on transient provider failures."""
+        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
+            try:
+                response = await self.chat(
+                    messages=messages,
+                    tools=tools,
+                    model=model,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    reasoning_effort=reasoning_effort,
+                )
+            except asyncio.CancelledError:
+                raise
+            except Exception as exc:
+                response = LLMResponse(
+                    content=f"Error calling LLM: {exc}",
+                    finish_reason="error",
+                )
+
+            if response.finish_reason != "error":
+                return response
+            if not self._is_transient_error(response.content):
+                return response
+
+            err = (response.content or "").lower()
+            logger.warning(
+                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
+                attempt,
+                len(self._CHAT_RETRY_DELAYS),
+                delay,
+                err[:120],
+            )
+            await asyncio.sleep(delay)
+
+        try:
+            return await self.chat(
+                messages=messages,
+                tools=tools,
+                model=model,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                reasoning_effort=reasoning_effort,
+            )
+        except asyncio.CancelledError:
+            raise
+        except Exception as exc:
+            return LLMResponse(
+                content=f"Error calling LLM: {exc}",
+                finish_reason="error",
+            )
+
     @abstractmethod
     def get_default_model(self) -> str:
         """Get the default model for this provider."""
diff --git a/tests/test_heartbeat_service.py b/tests/test_heartbeat_service.py
index c5478af..9ce8912 100644
--- a/tests/test_heartbeat_service.py
+++ b/tests/test_heartbeat_service.py
@@ -3,18 +3,24 @@ import asyncio
 import pytest
 
 from nanobot.heartbeat.service import HeartbeatService
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 
 
-class DummyProvider:
+class DummyProvider(LLMProvider):
     def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
         self._responses = list(responses)
+        self.calls = 0
 
     async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
         if self._responses:
             return self._responses.pop(0)
         return LLMResponse(content="", tool_calls=[])
 
+    def get_default_model(self) -> str:
+        return "test-model"
+
 
 @pytest.mark.asyncio
 async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
     )
 
     assert await service.trigger_now() is None
+
+
+@pytest.mark.asyncio
+async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
+    provider = DummyProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(
+            content="",
+            tool_calls=[
+                ToolCallRequest(
+                    id="hb_1",
+                    name="heartbeat",
+                    arguments={"action": "run", "tasks": "check open tasks"},
+                )
+            ],
+        ),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
+
+    service = HeartbeatService(
+        workspace=tmp_path,
+        provider=provider,
+        model="openai/gpt-4o-mini",
+    )
+
+    action, tasks = await service._decide("heartbeat content")
+
+    assert action == "run"
+    assert tasks == "check open tasks"
+    assert provider.calls == 2
+    assert delays == [1]
diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py
index ff15584..2605bf7 100644
--- a/tests/test_memory_consolidation_types.py
+++ b/tests/test_memory_consolidation_types.py
@@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
 import pytest
 
 from nanobot.agent.memory import MemoryStore
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 
 
 def _make_session(message_count: int = 30, memory_window: int = 50):
@@ -43,6 +43,22 @@ def _make_tool_response(history_entry, memory_update):
     )
 
 
+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        if self._responses:
+            return self._responses.pop(0)
+        return LLMResponse(content="", tool_calls=[])
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
 class TestMemoryConsolidationTypeHandling:
     """Test that consolidation handles various argument types correctly."""
 
@@ -57,6 +73,7 @@ class TestMemoryConsolidationTypeHandling:
                 memory_update="# Memory\nUser likes testing.",
             )
         )
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -77,6 +94,7 @@ class TestMemoryConsolidationTypeHandling:
                 memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
             )
         )
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -112,6 +130,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -127,6 +146,7 @@ class TestMemoryConsolidationTypeHandling:
         provider.chat = AsyncMock(
             return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
         )
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -139,6 +159,7 @@ class TestMemoryConsolidationTypeHandling:
         """Consolidation should be a no-op when messages < keep_count."""
         store = MemoryStore(tmp_path)
         provider = AsyncMock()
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=10)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -167,6 +188,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -192,6 +214,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -215,8 +238,33 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
 
         session = _make_session(message_count=60)
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
 
         assert result is False
+
+    @pytest.mark.asyncio
+    async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
+        store = MemoryStore(tmp_path)
+        provider = ScriptedProvider([
+            LLMResponse(content="503 server error", finish_reason="error"),
+            _make_tool_response(
+                history_entry="[2026-01-01] User discussed testing.",
+                memory_update="# Memory\nUser likes testing.",
+            ),
+        ])
+        session = _make_session(message_count=60)
+        delays: list[int] = []
+
+        async def _fake_sleep(delay: int) -> None:
+            delays.append(delay)
+
+        monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is True
+        assert provider.calls == 2
+        assert delays == [1]
diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py
new file mode 100644
index 0000000..751ecc3
--- /dev/null
+++ b/tests/test_provider_retry.py
@@ -0,0 +1,92 @@
+import asyncio
+
+import pytest
+
+from nanobot.providers.base import LLMProvider, LLMResponse
+
+
+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        response = self._responses.pop(0)
+        if isinstance(response, BaseException):
+            raise response
+        return response
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.finish_reason == "stop"
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="401 unauthorized", finish_reason="error"),
+    ])
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "401 unauthorized"
+    assert provider.calls == 1
+    assert delays == []
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit a", finish_reason="error"),
+        LLMResponse(content="429 rate limit b", finish_reason="error"),
+        LLMResponse(content="429 rate limit c", finish_reason="error"),
+        LLMResponse(content="503 final server error", finish_reason="error"),
+    ])
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "503 final server error"
+    assert provider.calls == 4
+    assert delays == [1, 2, 4]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_preserves_cancelled_error() -> None:
+    provider = ScriptedProvider([asyncio.CancelledError()])
+
+    with pytest.raises(asyncio.CancelledError):
+        await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
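
Usage note (commentary, not part of the patch): every LLMProvider subclass now inherits
chat_with_retry(), so call sites swap provider.chat(...) for provider.chat_with_retry(...)
with the same keyword arguments. Errors matching _TRANSIENT_ERROR_MARKERS (429, 5xx,
timeouts, network drops) are retried with 1s/2s/4s backoff plus one final attempt;
anything else is returned immediately. A minimal runnable sketch of that behavior,
mirroring the scripted providers in the new tests. FlakyProvider and the demo below
are illustrative and not part of nanobot:

    import asyncio

    from nanobot.providers.base import LLMProvider, LLMResponse


    class FlakyProvider(LLMProvider):
        """Hypothetical provider: one transient-looking error, then success."""

        def __init__(self) -> None:
            super().__init__()
            self._calls = 0

        async def chat(self, *args, **kwargs) -> LLMResponse:
            self._calls += 1
            if self._calls == 1:
                # Matches the "429" / "rate limit" markers, so it gets retried.
                return LLMResponse(content="429 rate limit", finish_reason="error")
            return LLMResponse(content="ok")

        def get_default_model(self) -> str:
            return "demo-model"


    async def main() -> None:
        provider = FlakyProvider()
        # First chat() reports a transient error; chat_with_retry() logs a
        # warning, sleeps 1s (a real sleep here, unlike the patched tests),
        # then returns the successful second response.
        response = await provider.chat_with_retry(
            messages=[{"role": "user", "content": "hello"}],
        )
        print(response.content)  # -> "ok"


    asyncio.run(main())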