refactor(llm): share transient retry across agent paths

This commit is contained in:
Re-bin
2026-03-10 10:10:37 +00:00
parent 46b31ce7e7
commit b0a5435b87
8 changed files with 274 additions and 34 deletions

View File

@@ -159,33 +159,6 @@ class AgentLoop:
         if hasattr(tool, "set_context"):
             tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
-
-    _RETRY_DELAYS = (1, 2, 4)  # seconds — exponential backoff for transient LLM errors
-
-    async def _chat_with_retry(self, **kwargs: Any) -> Any:
-        """Call provider.chat() with retry on transient errors (429, 5xx, network)."""
-        from nanobot.providers.base import LLMResponse
-
-        last_response: LLMResponse | None = None
-        for attempt, delay in enumerate(self._RETRY_DELAYS):
-            response = await self.provider.chat(**kwargs)
-            if response.finish_reason != "error":
-                return response
-            # Check if the error looks transient (rate limit, server error, network)
-            err = (response.content or "").lower()
-            is_transient = any(kw in err for kw in (
-                "429", "rate limit", "500", "502", "503", "504",
-                "overloaded", "timeout", "connection", "server error",
-            ))
-            if not is_transient:
-                return response  # permanent error (400, 401, etc.) — don't retry
-            last_response = response
-            logger.warning("LLM transient error (attempt {}/{}), retrying in {}s: {}",
-                           attempt + 1, len(self._RETRY_DELAYS), delay, err[:120])
-            await asyncio.sleep(delay)
-        # All retries exhausted — make one final attempt
-        response = await self.provider.chat(**kwargs)
-        return response if response.finish_reason != "error" else (last_response or response)
-
     @staticmethod
     def _strip_think(text: str | None) -> str | None:
         """Remove <think>…</think> blocks that some models embed in content."""
@@ -218,7 +191,7 @@ class AgentLoop:
         while iteration < self.max_iterations:
             iteration += 1
-            response = await self._chat_with_retry(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=self.tools.get_definitions(),
                 model=self.model,

View File

@@ -111,7 +111,7 @@ class MemoryStore:
 {chr(10).join(lines)}"""
 
         try:
-            response = await provider.chat(
+            response = await provider.chat_with_retry(
                 messages=[
                     {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
                     {"role": "user", "content": prompt},

View File

@@ -123,7 +123,7 @@ class SubagentManager:
         while iteration < max_iterations:
             iteration += 1
 
-            response = await self.provider.chat(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=tools.get_definitions(),
                 model=self.model,

View File

@@ -87,7 +87,7 @@ class HeartbeatService:
         Returns (action, tasks) where action is 'skip' or 'run'.
         """
 
-        response = await self.provider.chat(
+        response = await self.provider.chat_with_retry(
             messages=[
                 {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
                 {"role": "user", "content": (

View File

@@ -1,9 +1,12 @@
"""Base LLM provider interface."""
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
@dataclass
class ToolCallRequest:
@@ -37,6 +40,22 @@ class LLMProvider(ABC):
     while maintaining a consistent interface.
     """
 
+    _CHAT_RETRY_DELAYS = (1, 2, 4)  # seconds
+    _TRANSIENT_ERROR_MARKERS = (
+        "429",
+        "rate limit",
+        "500",
+        "502",
+        "503",
+        "504",
+        "overloaded",
+        "timeout",
+        "timed out",
+        "connection",
+        "server error",
+        "temporarily unavailable",
+    )
+
     def __init__(self, api_key: str | None = None, api_base: str | None = None):
         self.api_key = api_key
         self.api_base = api_base
@@ -126,6 +145,71 @@ class LLMProvider(ABC):
"""
pass
@classmethod
def _is_transient_error(cls, content: str | None) -> bool:
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
async def chat_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
reasoning_effort: str | None = None,
) -> LLMResponse:
"""Call chat() with retry on transient provider failures."""
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try:
response = await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
except asyncio.CancelledError:
raise
except Exception as exc:
response = LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
return response
err = (response.content or "").lower()
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt,
len(self._CHAT_RETRY_DELAYS),
delay,
err[:120],
)
await asyncio.sleep(delay)
try:
return await self.chat(
messages=messages,
tools=tools,
model=model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
except asyncio.CancelledError:
raise
except Exception as exc:
return LLMResponse(
content=f"Error calling LLM: {exc}",
finish_reason="error",
)
@abstractmethod
def get_default_model(self) -> str:
"""Get the default model for this provider."""

View File

@@ -3,18 +3,24 @@ import asyncio
 import pytest
 
 from nanobot.heartbeat.service import HeartbeatService
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 
 
-class DummyProvider:
+class DummyProvider(LLMProvider):
     def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
         self._responses = list(responses)
         self.calls = 0
 
     async def chat(self, *args, **kwargs) -> LLMResponse:
         self.calls += 1
         if self._responses:
             return self._responses.pop(0)
         return LLMResponse(content="", tool_calls=[])
 
+    def get_default_model(self) -> str:
+        return "test-model"
+
 
 @pytest.mark.asyncio
 async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
     )
 
     assert await service.trigger_now() is None
+
+
+@pytest.mark.asyncio
+async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
+    provider = DummyProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(
+            content="",
+            tool_calls=[
+                ToolCallRequest(
+                    id="hb_1",
+                    name="heartbeat",
+                    arguments={"action": "run", "tasks": "check open tasks"},
+                )
+            ],
+        ),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
+
+    service = HeartbeatService(
+        workspace=tmp_path,
+        provider=provider,
+        model="openai/gpt-4o-mini",
+    )
+
+    action, tasks = await service._decide("heartbeat content")
+
+    assert action == "run"
+    assert tasks == "check open tasks"
+    assert provider.calls == 2
+    assert delays == [1]

View File

@@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
 import pytest
 
 from nanobot.agent.memory import MemoryStore
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 
 
 def _make_session(message_count: int = 30, memory_window: int = 50):
@@ -43,6 +43,22 @@ def _make_tool_response(history_entry, memory_update):
     )
 
 
+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        if self._responses:
+            return self._responses.pop(0)
+        return LLMResponse(content="", tool_calls=[])
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
 class TestMemoryConsolidationTypeHandling:
     """Test that consolidation handles various argument types correctly."""
@@ -57,6 +73,7 @@ class TestMemoryConsolidationTypeHandling:
memory_update="# Memory\nUser likes testing.",
)
)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -77,6 +94,7 @@ class TestMemoryConsolidationTypeHandling:
memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
)
)
provider.chat_with_retry = provider.chat
session = _make_session(message_count=60)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -112,6 +130,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)
 
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -127,6 +146,7 @@ class TestMemoryConsolidationTypeHandling:
         provider.chat = AsyncMock(
             return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
         )
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)
 
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -139,6 +159,7 @@ class TestMemoryConsolidationTypeHandling:
"""Consolidation should be a no-op when messages < keep_count."""
store = MemoryStore(tmp_path)
provider = AsyncMock()
provider.chat_with_retry = provider.chat
session = _make_session(message_count=10)
result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -167,6 +188,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)
 
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -192,6 +214,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)
 
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -215,8 +238,33 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)
 
         result = await store.consolidate(session, provider, "test-model", memory_window=50)
 
         assert result is False
+
+    @pytest.mark.asyncio
+    async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
+        store = MemoryStore(tmp_path)
+        provider = ScriptedProvider([
+            LLMResponse(content="503 server error", finish_reason="error"),
+            _make_tool_response(
+                history_entry="[2026-01-01] User discussed testing.",
+                memory_update="# Memory\nUser likes testing.",
+            ),
+        ])
+        session = _make_session(message_count=60)
+
+        delays: list[int] = []
+
+        async def _fake_sleep(delay: int) -> None:
+            delays.append(delay)
+
+        monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is True
+        assert provider.calls == 2
+        assert delays == [1]

View File

@@ -0,0 +1,92 @@
+import asyncio
+
+import pytest
+
+from nanobot.providers.base import LLMProvider, LLMResponse
+
+
+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        response = self._responses.pop(0)
+        if isinstance(response, BaseException):
+            raise response
+        return response
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(content="ok"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.finish_reason == "stop"
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="401 unauthorized", finish_reason="error"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "401 unauthorized"
+    assert provider.calls == 1
+    assert delays == []
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit a", finish_reason="error"),
+        LLMResponse(content="429 rate limit b", finish_reason="error"),
+        LLMResponse(content="429 rate limit c", finish_reason="error"),
+        LLMResponse(content="503 final server error", finish_reason="error"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "503 final server error"
+    assert provider.calls == 4
+    assert delays == [1, 2, 4]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_preserves_cancelled_error() -> None:
+    provider = ScriptedProvider([asyncio.CancelledError()])
+
+    with pytest.raises(asyncio.CancelledError):
+        await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])