Merge PR #1868: generation settings owned by provider, loop/memory/subagent agnostic
This commit is contained in:
@@ -52,9 +52,6 @@ class AgentLoop:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 40,
|
max_iterations: int = 40,
|
||||||
temperature: float = 0.1,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
context_window_tokens: int = 65_536,
|
context_window_tokens: int = 65_536,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
@@ -72,9 +69,6 @@ class AgentLoop:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.temperature = temperature
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.reasoning_effort = reasoning_effort
|
|
||||||
self.context_window_tokens = context_window_tokens
|
self.context_window_tokens = context_window_tokens
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
@@ -90,9 +84,6 @@ class AgentLoop:
|
|||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=reasoning_effort,
|
|
||||||
brave_api_key=brave_api_key,
|
brave_api_key=brave_api_key,
|
||||||
web_proxy=web_proxy,
|
web_proxy=web_proxy,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
@@ -202,9 +193,6 @@ class AgentLoop:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tool_defs,
|
tools=tool_defs,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=self.reasoning_effort,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
|
|||||||
return args[0] if args and isinstance(args[0], dict) else None
|
return args[0] if args and isinstance(args[0], dict) else None
|
||||||
return args if isinstance(args, dict) else None
|
return args if isinstance(args, dict) else None
|
||||||
|
|
||||||
|
|
||||||
class MemoryStore:
|
class MemoryStore:
|
||||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||||
|
|
||||||
|
|||||||
@@ -28,9 +28,6 @@ class SubagentManager:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
reasoning_effort: str | None = None,
|
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
@@ -41,9 +38,6 @@ class SubagentManager:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.temperature = temperature
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.reasoning_effort = reasoning_effort
|
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@@ -128,9 +122,6 @@ class SubagentManager:
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools.get_definitions(),
|
tools=tools.get_definitions(),
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
|
||||||
max_tokens=self.max_tokens,
|
|
||||||
reasoning_effort=self.reasoning_effort,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ def onboard():
|
|||||||
|
|
||||||
def _make_provider(config: Config):
|
def _make_provider(config: Config):
|
||||||
"""Create the appropriate LLM provider from config."""
|
"""Create the appropriate LLM provider from config."""
|
||||||
|
from nanobot.providers.base import GenerationSettings
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
@@ -224,31 +225,28 @@ def _make_provider(config: Config):
|
|||||||
|
|
||||||
# OpenAI Codex (OAuth)
|
# OpenAI Codex (OAuth)
|
||||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||||
return OpenAICodexProvider(default_model=model)
|
provider = OpenAICodexProvider(default_model=model)
|
||||||
|
|
||||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||||
|
elif provider_name == "custom":
|
||||||
from nanobot.providers.custom_provider import CustomProvider
|
from nanobot.providers.custom_provider import CustomProvider
|
||||||
if provider_name == "custom":
|
provider = CustomProvider(
|
||||||
return CustomProvider(
|
|
||||||
api_key=p.api_key if p else "no-key",
|
api_key=p.api_key if p else "no-key",
|
||||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||||
if provider_name == "azure_openai":
|
elif provider_name == "azure_openai":
|
||||||
if not p or not p.api_key or not p.api_base:
|
if not p or not p.api_key or not p.api_base:
|
||||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||||
console.print("Use the model field to specify the deployment name.")
|
console.print("Use the model field to specify the deployment name.")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
return AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key,
|
api_key=p.api_key,
|
||||||
api_base=p.api_base,
|
api_base=p.api_base,
|
||||||
default_model=model,
|
default_model=model,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
spec = find_by_name(provider_name)
|
spec = find_by_name(provider_name)
|
||||||
@@ -256,8 +254,7 @@ def _make_provider(config: Config):
|
|||||||
console.print("[red]Error: No API key configured.[/red]")
|
console.print("[red]Error: No API key configured.[/red]")
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||||
raise typer.Exit(1)
|
raise typer.Exit(1)
|
||||||
|
provider = LiteLLMProvider(
|
||||||
return LiteLLMProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
api_key=p.api_key if p else None,
|
||||||
api_base=config.get_api_base(model),
|
api_base=config.get_api_base(model),
|
||||||
default_model=model,
|
default_model=model,
|
||||||
@@ -265,6 +262,14 @@ def _make_provider(config: Config):
|
|||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
provider.generation = GenerationSettings(
|
||||||
|
temperature=defaults.temperature,
|
||||||
|
max_tokens=defaults.max_tokens,
|
||||||
|
reasoning_effort=defaults.reasoning_effort,
|
||||||
|
)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||||
"""Load config and optionally override the active workspace."""
|
"""Load config and optionally override the active workspace."""
|
||||||
@@ -341,10 +346,7 @@ def gateway(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
temperature=config.agents.defaults.temperature,
|
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
@@ -527,10 +529,7 @@ def agent(
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
temperature=config.agents.defaults.temperature,
|
|
||||||
max_tokens=config.agents.defaults.max_tokens,
|
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
brave_api_key=config.tools.web.search.api_key or None,
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
|
|||||||
@@ -32,6 +32,21 @@ class LLMResponse:
|
|||||||
return len(self.tool_calls) > 0
|
return len(self.tool_calls) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GenerationSettings:
|
||||||
|
"""Default generation parameters for LLM calls.
|
||||||
|
|
||||||
|
Stored on the provider so every call site inherits the same defaults
|
||||||
|
without having to pass temperature / max_tokens / reasoning_effort
|
||||||
|
through every layer. Individual call sites can still override by
|
||||||
|
passing explicit keyword arguments to chat() / chat_with_retry().
|
||||||
|
"""
|
||||||
|
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = 4096
|
||||||
|
reasoning_effort: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LLMProvider(ABC):
|
class LLMProvider(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for LLM providers.
|
Abstract base class for LLM providers.
|
||||||
@@ -56,9 +71,12 @@ class LLMProvider(ABC):
|
|||||||
"temporarily unavailable",
|
"temporarily unavailable",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
|
self.generation: GenerationSettings = GenerationSettings()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
@@ -155,11 +173,23 @@ class LLMProvider(ABC):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: object = _SENTINEL,
|
||||||
temperature: float = 0.7,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: str | None = None,
|
reasoning_effort: object = _SENTINEL,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures."""
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
|
Parameters default to ``self.generation`` when not explicitly passed,
|
||||||
|
so callers no longer need to thread temperature / max_tokens /
|
||||||
|
reasoning_effort through every layer.
|
||||||
|
"""
|
||||||
|
if max_tokens is self._SENTINEL:
|
||||||
|
max_tokens = self.generation.max_tokens
|
||||||
|
if temperature is self._SENTINEL:
|
||||||
|
temperature = self.generation.temperature
|
||||||
|
if reasoning_effort is self._SENTINEL:
|
||||||
|
reasoning_effort = self.generation.reasoning_effort
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||||
try:
|
try:
|
||||||
response = await self.chat(
|
response = await self.chat(
|
||||||
|
|||||||
@@ -265,3 +265,26 @@ class TestMemoryConsolidationTypeHandling:
|
|||||||
assert result is True
|
assert result is True
|
||||||
assert provider.calls == 2
|
assert provider.calls == 2
|
||||||
assert delays == [1]
|
assert delays == [1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
|
||||||
|
"""Consolidation no longer passes generation params — the provider owns them."""
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=_make_tool_response(
|
||||||
|
history_entry="[2026-01-01] User discussed testing.",
|
||||||
|
memory_update="# Memory\nUser likes testing.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
messages = _make_messages(message_count=60)
|
||||||
|
|
||||||
|
result = await store.consolidate(messages, provider, "test-model")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
provider.chat_with_retry.assert_awaited_once()
|
||||||
|
_, kwargs = provider.chat_with_retry.await_args
|
||||||
|
assert kwargs["model"] == "test-model"
|
||||||
|
assert "temperature" not in kwargs
|
||||||
|
assert "max_tokens" not in kwargs
|
||||||
|
assert "reasoning_effort" not in kwargs
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
class ScriptedProvider(LLMProvider):
|
class ScriptedProvider(LLMProvider):
|
||||||
@@ -10,9 +10,11 @@ class ScriptedProvider(LLMProvider):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._responses = list(responses)
|
self._responses = list(responses)
|
||||||
self.calls = 0
|
self.calls = 0
|
||||||
|
self.last_kwargs: dict = {}
|
||||||
|
|
||||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
self.calls += 1
|
self.calls += 1
|
||||||
|
self.last_kwargs = kwargs
|
||||||
response = self._responses.pop(0)
|
response = self._responses.pop(0)
|
||||||
if isinstance(response, BaseException):
|
if isinstance(response, BaseException):
|
||||||
raise response
|
raise response
|
||||||
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
|||||||
|
|
||||||
with pytest.raises(asyncio.CancelledError):
|
with pytest.raises(asyncio.CancelledError):
|
||||||
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||||
|
"""When callers omit generation params, provider.generation defaults are used."""
|
||||||
|
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||||
|
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||||
|
|
||||||
|
await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
|
||||||
|
|
||||||
|
assert provider.last_kwargs["temperature"] == 0.2
|
||||||
|
assert provider.last_kwargs["max_tokens"] == 321
|
||||||
|
assert provider.last_kwargs["reasoning_effort"] == "high"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
|
||||||
|
"""Explicit kwargs should override provider.generation defaults."""
|
||||||
|
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||||
|
provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
|
||||||
|
|
||||||
|
await provider.chat_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
temperature=0.9,
|
||||||
|
max_tokens=9999,
|
||||||
|
reasoning_effort="low",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert provider.last_kwargs["temperature"] == 0.9
|
||||||
|
assert provider.last_kwargs["max_tokens"] == 9999
|
||||||
|
assert provider.last_kwargs["reasoning_effort"] == "low"
|
||||||
|
|||||||
Reference in New Issue
Block a user