Merge PR #1868: generation settings owned by provider, loop/memory/subagent agnostic

This commit is contained in:
Re-bin
2026-03-11 09:47:04 +00:00
7 changed files with 119 additions and 56 deletions

View File

@@ -52,9 +52,6 @@ class AgentLoop:
workspace: Path, workspace: Path,
model: str | None = None, model: str | None = None,
max_iterations: int = 40, max_iterations: int = 40,
temperature: float = 0.1,
max_tokens: int = 4096,
reasoning_effort: str | None = None,
context_window_tokens: int = 65_536, context_window_tokens: int = 65_536,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
@@ -72,9 +69,6 @@ class AgentLoop:
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.temperature = temperature
self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.context_window_tokens = context_window_tokens self.context_window_tokens = context_window_tokens
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
@@ -90,9 +84,6 @@ class AgentLoop:
workspace=workspace, workspace=workspace,
bus=bus, bus=bus,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=reasoning_effort,
brave_api_key=brave_api_key, brave_api_key=brave_api_key,
web_proxy=web_proxy, web_proxy=web_proxy,
exec_config=self.exec_config, exec_config=self.exec_config,
@@ -202,9 +193,6 @@ class AgentLoop:
messages=messages, messages=messages,
tools=tool_defs, tools=tool_defs,
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:

View File

@@ -57,7 +57,6 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
return args[0] if args and isinstance(args[0], dict) else None return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None return args if isinstance(args, dict) else None
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""

View File

@@ -28,9 +28,6 @@ class SubagentManager:
workspace: Path, workspace: Path,
bus: MessageBus, bus: MessageBus,
model: str | None = None, model: str | None = None,
temperature: float = 0.7,
max_tokens: int = 4096,
reasoning_effort: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: "ExecToolConfig | None" = None, exec_config: "ExecToolConfig | None" = None,
@@ -41,9 +38,6 @@ class SubagentManager:
self.workspace = workspace self.workspace = workspace
self.bus = bus self.bus = bus
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.temperature = temperature
self.max_tokens = max_tokens
self.reasoning_effort = reasoning_effort
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@@ -128,9 +122,6 @@ class SubagentManager:
messages=messages, messages=messages,
tools=tools.get_definitions(), tools=tools.get_definitions(),
model=self.model, model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
reasoning_effort=self.reasoning_effort,
) )
if response.has_tool_calls: if response.has_tool_calls:

View File

@@ -215,6 +215,7 @@ def onboard():
def _make_provider(config: Config): def _make_provider(config: Config):
"""Create the appropriate LLM provider from config.""" """Create the appropriate LLM provider from config."""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.openai_codex_provider import OpenAICodexProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
@@ -224,46 +225,50 @@ def _make_provider(config: Config):
# OpenAI Codex (OAuth) # OpenAI Codex (OAuth)
if provider_name == "openai_codex" or model.startswith("openai-codex/"): if provider_name == "openai_codex" or model.startswith("openai-codex/"):
return OpenAICodexProvider(default_model=model) provider = OpenAICodexProvider(default_model=model)
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
from nanobot.providers.custom_provider import CustomProvider elif provider_name == "custom":
if provider_name == "custom": from nanobot.providers.custom_provider import CustomProvider
return CustomProvider( provider = CustomProvider(
api_key=p.api_key if p else "no-key", api_key=p.api_key if p else "no-key",
api_base=config.get_api_base(model) or "http://localhost:8000/v1", api_base=config.get_api_base(model) or "http://localhost:8000/v1",
default_model=model, default_model=model,
) )
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
if provider_name == "azure_openai": elif provider_name == "azure_openai":
if not p or not p.api_key or not p.api_base: if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.") console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1) raise typer.Exit(1)
provider = AzureOpenAIProvider(
return AzureOpenAIProvider(
api_key=p.api_key, api_key=p.api_key,
api_base=p.api_base, api_base=p.api_base,
default_model=model, default_model=model,
) )
else:
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.registry import find_by_name
spec = find_by_name(provider_name)
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
provider = LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
)
from nanobot.providers.litellm_provider import LiteLLMProvider defaults = config.agents.defaults
from nanobot.providers.registry import find_by_name provider.generation = GenerationSettings(
spec = find_by_name(provider_name) temperature=defaults.temperature,
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)): max_tokens=defaults.max_tokens,
console.print("[red]Error: No API key configured.[/red]") reasoning_effort=defaults.reasoning_effort,
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
return LiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=provider_name,
) )
return provider
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config: def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
@@ -341,10 +346,7 @@ def gateway(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
reasoning_effort=config.agents.defaults.reasoning_effort,
context_window_tokens=config.agents.defaults.context_window_tokens, context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
@@ -527,10 +529,7 @@ def agent(
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
temperature=config.agents.defaults.temperature,
max_tokens=config.agents.defaults.max_tokens,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
reasoning_effort=config.agents.defaults.reasoning_effort,
context_window_tokens=config.agents.defaults.context_window_tokens, context_window_tokens=config.agents.defaults.context_window_tokens,
brave_api_key=config.tools.web.search.api_key or None, brave_api_key=config.tools.web.search.api_key or None,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,

View File

@@ -32,6 +32,21 @@ class LLMResponse:
return len(self.tool_calls) > 0 return len(self.tool_calls) > 0
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
class LLMProvider(ABC): class LLMProvider(ABC):
""" """
Abstract base class for LLM providers. Abstract base class for LLM providers.
@@ -56,9 +71,12 @@ class LLMProvider(ABC):
"temporarily unavailable", "temporarily unavailable",
) )
_SENTINEL = object()
def __init__(self, api_key: str | None = None, api_base: str | None = None): def __init__(self, api_key: str | None = None, api_base: str | None = None):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
self.generation: GenerationSettings = GenerationSettings()
@staticmethod @staticmethod
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -155,11 +173,23 @@ class LLMProvider(ABC):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
model: str | None = None, model: str | None = None,
max_tokens: int = 4096, max_tokens: object = _SENTINEL,
temperature: float = 0.7, temperature: object = _SENTINEL,
reasoning_effort: str | None = None, reasoning_effort: object = _SENTINEL,
) -> LLMResponse: ) -> LLMResponse:
"""Call chat() with retry on transient provider failures.""" """Call chat() with retry on transient provider failures.
Parameters default to ``self.generation`` when not explicitly passed,
so callers no longer need to thread temperature / max_tokens /
reasoning_effort through every layer.
"""
if max_tokens is self._SENTINEL:
max_tokens = self.generation.max_tokens
if temperature is self._SENTINEL:
temperature = self.generation.temperature
if reasoning_effort is self._SENTINEL:
reasoning_effort = self.generation.reasoning_effort
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
try: try:
response = await self.chat( response = await self.chat(

View File

@@ -265,3 +265,26 @@ class TestMemoryConsolidationTypeHandling:
assert result is True assert result is True
assert provider.calls == 2 assert provider.calls == 2
assert delays == [1] assert delays == [1]
@pytest.mark.asyncio
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
    """Consolidation forwards no generation params; the provider supplies them."""
    store = MemoryStore(tmp_path)

    # Canned reply returned by the mocked chat_with_retry.
    canned_reply = _make_tool_response(
        history_entry="[2026-01-01] User discussed testing.",
        memory_update="# Memory\nUser likes testing.",
    )
    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(return_value=canned_reply)

    messages = _make_messages(message_count=60)
    outcome = await store.consolidate(messages, provider, "test-model")

    assert outcome is True
    provider.chat_with_retry.assert_awaited_once()

    # Only the keyword arguments of the single awaited call matter here.
    passed_kwargs = provider.chat_with_retry.await_args.kwargs
    assert passed_kwargs["model"] == "test-model"
    for param in ("temperature", "max_tokens", "reasoning_effort"):
        assert param not in passed_kwargs

View File

@@ -2,7 +2,7 @@ import asyncio
import pytest import pytest
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
class ScriptedProvider(LLMProvider): class ScriptedProvider(LLMProvider):
@@ -10,9 +10,11 @@ class ScriptedProvider(LLMProvider):
super().__init__() super().__init__()
self._responses = list(responses) self._responses = list(responses)
self.calls = 0 self.calls = 0
self.last_kwargs: dict = {}
async def chat(self, *args, **kwargs) -> LLMResponse: async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1 self.calls += 1
self.last_kwargs = kwargs
response = self._responses.pop(0) response = self._responses.pop(0)
if isinstance(response, BaseException): if isinstance(response, BaseException):
raise response raise response
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
@pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
    """Generation params omitted by the caller come from provider.generation."""
    provider = ScriptedProvider([LLMResponse(content="ok")])
    provider.generation = GenerationSettings(
        temperature=0.2,
        max_tokens=321,
        reasoning_effort="high",
    )

    # No temperature / max_tokens / reasoning_effort passed on purpose.
    await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    forwarded = provider.last_kwargs
    assert forwarded["temperature"] == 0.2
    assert forwarded["max_tokens"] == 321
    assert forwarded["reasoning_effort"] == "high"
@pytest.mark.asyncio
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
    """Keyword arguments given by the caller take precedence over provider.generation."""
    provider = ScriptedProvider([LLMResponse(content="ok")])
    provider.generation = GenerationSettings(
        temperature=0.2,
        max_tokens=321,
        reasoning_effort="high",
    )

    # Every generation parameter is supplied explicitly and differs from the default.
    explicit = {"temperature": 0.9, "max_tokens": 9999, "reasoning_effort": "low"}
    await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        **explicit,
    )

    forwarded = provider.last_kwargs
    assert forwarded["temperature"] == 0.9
    assert forwarded["max_tokens"] == 9999
    assert forwarded["reasoning_effort"] == "low"