Merge PR #1512: share transient LLM retry across agent paths

This commit is contained in: Re-bin
2026-03-10 10:10:40 +00:00
8 changed files with 274 additions and 7 deletions
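
Every agent path previously called provider.chat() directly; this change adds a shared chat_with_retry() helper to the LLMProvider base class and switches all four call sites (agent loop, memory consolidation, subagent manager, heartbeat) over to it. The helper retries responses whose finish_reason is "error" and whose content matches a transient marker (429/rate limit, 5xx, timeouts, connection errors), backing off 1s, 2s, then 4s, for at most four chat() calls in total; non-transient errors are returned to the caller immediately, and asyncio.CancelledError is always re-raised so task cancellation still propagates.

A minimal usage sketch (EchoProvider is hypothetical; as with the test doubles below, it implements just chat() and get_default_model()):

    import asyncio

    from nanobot.providers.base import LLMProvider, LLMResponse


    class EchoProvider(LLMProvider):
        """Hypothetical provider that succeeds on the first call."""

        async def chat(self, *args, **kwargs) -> LLMResponse:
            return LLMResponse(content="ok")

        def get_default_model(self) -> str:
            return "echo-model"


    async def main() -> None:
        provider = EchoProvider()
        # Inherited from LLMProvider: retries transient "error" responses
        # with 1s/2s/4s backoff before giving up.
        response = await provider.chat_with_retry(
            messages=[{"role": "user", "content": "hello"}],
        )
        print(response.content)  # -> "ok"


    asyncio.run(main())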

View File

@@ -191,7 +191,7 @@ class AgentLoop:
         while iteration < self.max_iterations:
             iteration += 1
-            response = await self.provider.chat(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=self.tools.get_definitions(),
                 model=self.model,

View File

@@ -111,7 +111,7 @@ class MemoryStore:
 {chr(10).join(lines)}"""
         try:
-            response = await provider.chat(
+            response = await provider.chat_with_retry(
                 messages=[
                     {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
                     {"role": "user", "content": prompt},

View File

@@ -123,7 +123,7 @@ class SubagentManager:
         while iteration < max_iterations:
             iteration += 1
-            response = await self.provider.chat(
+            response = await self.provider.chat_with_retry(
                 messages=messages,
                 tools=tools.get_definitions(),
                 model=self.model,

View File

@@ -87,7 +87,7 @@ class HeartbeatService:
         Returns (action, tasks) where action is 'skip' or 'run'.
         """
-        response = await self.provider.chat(
+        response = await self.provider.chat_with_retry(
             messages=[
                 {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
                 {"role": "user", "content": (

View File

@@ -1,9 +1,12 @@
 """Base LLM provider interface."""
+import asyncio
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any
+
+from loguru import logger


 @dataclass
 class ToolCallRequest:
@@ -37,6 +40,22 @@ class LLMProvider(ABC):
     while maintaining a consistent interface.
     """

+    _CHAT_RETRY_DELAYS = (1, 2, 4)
+    _TRANSIENT_ERROR_MARKERS = (
+        "429",
+        "rate limit",
+        "500",
+        "502",
+        "503",
+        "504",
+        "overloaded",
+        "timeout",
+        "timed out",
+        "connection",
+        "server error",
+        "temporarily unavailable",
+    )
+
     def __init__(self, api_key: str | None = None, api_base: str | None = None):
         self.api_key = api_key
         self.api_base = api_base
@@ -126,6 +145,71 @@ class LLMProvider(ABC):
         """
         pass

+    @classmethod
+    def _is_transient_error(cls, content: str | None) -> bool:
+        err = (content or "").lower()
+        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
+
+    async def chat_with_retry(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None = None,
+        model: str | None = None,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        reasoning_effort: str | None = None,
+    ) -> LLMResponse:
+        """Call chat() with retry on transient provider failures."""
+        for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
+            try:
+                response = await self.chat(
+                    messages=messages,
+                    tools=tools,
+                    model=model,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    reasoning_effort=reasoning_effort,
+                )
+            except asyncio.CancelledError:
+                raise
+            except Exception as exc:
+                response = LLMResponse(
+                    content=f"Error calling LLM: {exc}",
+                    finish_reason="error",
+                )
+            if response.finish_reason != "error":
+                return response
+            if not self._is_transient_error(response.content):
+                return response
+            err = (response.content or "").lower()
+            logger.warning(
+                "LLM transient error (attempt {}/{}), retrying in {}s: {}",
+                attempt,
+                len(self._CHAT_RETRY_DELAYS),
+                delay,
+                err[:120],
+            )
+            await asyncio.sleep(delay)
+        try:
+            return await self.chat(
+                messages=messages,
+                tools=tools,
+                model=model,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                reasoning_effort=reasoning_effort,
+            )
+        except asyncio.CancelledError:
+            raise
+        except Exception as exc:
+            return LLMResponse(
+                content=f"Error calling LLM: {exc}",
+                finish_reason="error",
+            )
+
     @abstractmethod
     def get_default_model(self) -> str:
         """Get the default model for this provider."""

View File

@@ -3,18 +3,24 @@ import asyncio
 import pytest

 from nanobot.heartbeat.service import HeartbeatService
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest


-class DummyProvider:
+class DummyProvider(LLMProvider):
     def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
         self._responses = list(responses)
+        self.calls = 0

     async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
         if self._responses:
             return self._responses.pop(0)
         return LLMResponse(content="", tool_calls=[])

+    def get_default_model(self) -> str:
+        return "test-model"
+

 @pytest.mark.asyncio
 async def test_start_is_idempotent(tmp_path) -> None:
@@ -115,3 +121,40 @@ async def test_trigger_now_returns_none_when_decision_is_skip(tmp_path) -> None:
     )
     assert await service.trigger_now() is None
+
+
+@pytest.mark.asyncio
+async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
+    provider = DummyProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(
+            content="",
+            tool_calls=[
+                ToolCallRequest(
+                    id="hb_1",
+                    name="heartbeat",
+                    arguments={"action": "run", "tasks": "check open tasks"},
+                )
+            ],
+        ),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr(asyncio, "sleep", _fake_sleep)
+
+    service = HeartbeatService(
+        workspace=tmp_path,
+        provider=provider,
+        model="openai/gpt-4o-mini",
+    )
+
+    action, tasks = await service._decide("heartbeat content")
+
+    assert action == "run"
+    assert tasks == "check open tasks"
+    assert provider.calls == 2
+    assert delays == [1]

View File

@@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock
 import pytest

 from nanobot.agent.memory import MemoryStore
-from nanobot.providers.base import LLMResponse, ToolCallRequest
+from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest


 def _make_session(message_count: int = 30, memory_window: int = 50):
@@ -43,6 +43,22 @@ def _make_tool_response(history_entry, memory_update):
     )


+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses: list[LLMResponse]):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        if self._responses:
+            return self._responses.pop(0)
+        return LLMResponse(content="", tool_calls=[])
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
 class TestMemoryConsolidationTypeHandling:
     """Test that consolidation handles various argument types correctly."""
@@ -57,6 +73,7 @@ class TestMemoryConsolidationTypeHandling:
                 memory_update="# Memory\nUser likes testing.",
             )
         )
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -77,6 +94,7 @@ class TestMemoryConsolidationTypeHandling:
                 memory_update={"facts": ["User likes testing"], "topics": ["testing"]},
            )
         )
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -112,6 +130,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -127,6 +146,7 @@ class TestMemoryConsolidationTypeHandling:
         provider.chat = AsyncMock(
             return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[])
         )
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -139,6 +159,7 @@ class TestMemoryConsolidationTypeHandling:
         """Consolidation should be a no-op when messages < keep_count."""
         store = MemoryStore(tmp_path)
         provider = AsyncMock()
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=10)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -167,6 +188,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -192,6 +214,7 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)
@@ -215,8 +238,33 @@ class TestMemoryConsolidationTypeHandling:
             ],
         )
         provider.chat = AsyncMock(return_value=response)
+        provider.chat_with_retry = provider.chat
         session = _make_session(message_count=60)

         result = await store.consolidate(session, provider, "test-model", memory_window=50)

         assert result is False
+
+    @pytest.mark.asyncio
+    async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None:
+        store = MemoryStore(tmp_path)
+        provider = ScriptedProvider([
+            LLMResponse(content="503 server error", finish_reason="error"),
+            _make_tool_response(
+                history_entry="[2026-01-01] User discussed testing.",
+                memory_update="# Memory\nUser likes testing.",
+            ),
+        ])
+        session = _make_session(message_count=60)
+
+        delays: list[int] = []
+
+        async def _fake_sleep(delay: int) -> None:
+            delays.append(delay)
+
+        monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is True
+        assert provider.calls == 2
+        assert delays == [1]

View File

@@ -0,0 +1,92 @@
+import asyncio
+
+import pytest
+
+from nanobot.providers.base import LLMProvider, LLMResponse
+
+
+class ScriptedProvider(LLMProvider):
+    def __init__(self, responses):
+        super().__init__()
+        self._responses = list(responses)
+        self.calls = 0
+
+    async def chat(self, *args, **kwargs) -> LLMResponse:
+        self.calls += 1
+        response = self._responses.pop(0)
+        if isinstance(response, BaseException):
+            raise response
+        return response
+
+    def get_default_model(self) -> str:
+        return "test-model"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_transient_error_then_succeeds(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error"),
+        LLMResponse(content="ok"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.finish_reason == "stop"
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_does_not_retry_non_transient_error(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="401 unauthorized", finish_reason="error"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "401 unauthorized"
+    assert provider.calls == 1
+    assert delays == []
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit a", finish_reason="error"),
+        LLMResponse(content="429 rate limit b", finish_reason="error"),
+        LLMResponse(content="429 rate limit c", finish_reason="error"),
+        LLMResponse(content="503 final server error", finish_reason="error"),
+    ])
+
+    delays: list[int] = []
+
+    async def _fake_sleep(delay: int) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "503 final server error"
+    assert provider.calls == 4
+    assert delays == [1, 2, 4]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_preserves_cancelled_error() -> None:
+    provider = ScriptedProvider([asyncio.CancelledError()])
+
+    with pytest.raises(asyncio.CancelledError):
+        await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])