diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py
index 8605a09..b1bfd2f 100644
--- a/nanobot/agent/loop.py
+++ b/nanobot/agent/loop.py
@@ -52,9 +52,6 @@ class AgentLoop:
         workspace: Path,
         model: str | None = None,
         max_iterations: int = 40,
-        temperature: float = 0.1,
-        max_tokens: int = 4096,
-        reasoning_effort: str | None = None,
         context_window_tokens: int = 65_536,
         brave_api_key: str | None = None,
         web_proxy: str | None = None,
@@ -72,9 +69,6 @@ class AgentLoop:
         self.workspace = workspace
         self.model = model or provider.get_default_model()
         self.max_iterations = max_iterations
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.reasoning_effort = reasoning_effort
         self.context_window_tokens = context_window_tokens
         self.brave_api_key = brave_api_key
         self.web_proxy = web_proxy
@@ -90,9 +84,6 @@ class AgentLoop:
             workspace=workspace,
             bus=bus,
             model=self.model,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            reasoning_effort=reasoning_effort,
             brave_api_key=brave_api_key,
             web_proxy=web_proxy,
             exec_config=self.exec_config,
@@ -202,9 +193,6 @@ class AgentLoop:
                 messages=messages,
                 tools=tool_defs,
                 model=self.model,
-                temperature=self.temperature,
-                max_tokens=self.max_tokens,
-                reasoning_effort=self.reasoning_effort,
             )
 
             if response.has_tool_calls:
diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py
index cd5f54f..59ba40e 100644
--- a/nanobot/agent/memory.py
+++ b/nanobot/agent/memory.py
@@ -57,7 +57,6 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
         return args[0] if args and isinstance(args[0], dict) else None
     return args if isinstance(args, dict) else None
 
-
 class MemoryStore:
     """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
 
diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py
index eff0b4f..21b8b32 100644
--- a/nanobot/agent/subagent.py
+++ b/nanobot/agent/subagent.py
@@ -28,9 +28,6 @@ class SubagentManager:
         workspace: Path,
         bus: MessageBus,
         model: str | None = None,
-        temperature: float = 0.7,
-        max_tokens: int = 4096,
-        reasoning_effort: str | None = None,
         brave_api_key: str | None = None,
         web_proxy: str | None = None,
         exec_config: "ExecToolConfig | None" = None,
@@ -41,9 +38,6 @@ class SubagentManager:
         self.workspace = workspace
         self.bus = bus
         self.model = model or provider.get_default_model()
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.reasoning_effort = reasoning_effort
         self.brave_api_key = brave_api_key
         self.web_proxy = web_proxy
         self.exec_config = exec_config or ExecToolConfig()
@@ -128,9 +122,6 @@ class SubagentManager:
                 messages=messages,
                 tools=tools.get_definitions(),
                 model=self.model,
-                temperature=self.temperature,
-                max_tokens=self.max_tokens,
-                reasoning_effort=self.reasoning_effort,
             )
 
             if response.has_tool_calls:
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index 8387b28..f5ac859 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -215,6 +215,7 @@ def onboard():
 
 def _make_provider(config: Config):
     """Create the appropriate LLM provider from config."""
+    from nanobot.providers.base import GenerationSettings
     from nanobot.providers.openai_codex_provider import OpenAICodexProvider
     from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
 
@@ -224,46 +225,50 @@ def _make_provider(config: Config):
 
     # OpenAI Codex (OAuth)
     if provider_name == "openai_codex" or model.startswith("openai-codex/"):
-        return OpenAICodexProvider(default_model=model)
-
+        provider = OpenAICodexProvider(default_model=model)
     # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
-    from nanobot.providers.custom_provider import CustomProvider
-    if provider_name == "custom":
-        return CustomProvider(
+    elif provider_name == "custom":
+        from nanobot.providers.custom_provider import CustomProvider
+        provider = CustomProvider(
             api_key=p.api_key if p else "no-key",
             api_base=config.get_api_base(model) or "http://localhost:8000/v1",
             default_model=model,
         )
-
     # Azure OpenAI: direct Azure OpenAI endpoint with deployment name
-    if provider_name == "azure_openai":
+    elif provider_name == "azure_openai":
         if not p or not p.api_key or not p.api_base:
             console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
             console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
             console.print("Use the model field to specify the deployment name.")
             raise typer.Exit(1)
-
-        return AzureOpenAIProvider(
+        provider = AzureOpenAIProvider(
             api_key=p.api_key,
             api_base=p.api_base,
             default_model=model,
         )
+    else:
+        from nanobot.providers.litellm_provider import LiteLLMProvider
+        from nanobot.providers.registry import find_by_name
+        spec = find_by_name(provider_name)
+        if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
+            console.print("[red]Error: No API key configured.[/red]")
+            console.print("Set one in ~/.nanobot/config.json under providers section")
+            raise typer.Exit(1)
+        provider = LiteLLMProvider(
+            api_key=p.api_key if p else None,
+            api_base=config.get_api_base(model),
+            default_model=model,
+            extra_headers=p.extra_headers if p else None,
+            provider_name=provider_name,
+        )
 
-    from nanobot.providers.litellm_provider import LiteLLMProvider
-    from nanobot.providers.registry import find_by_name
-    spec = find_by_name(provider_name)
-    if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
-        console.print("[red]Error: No API key configured.[/red]")
-        console.print("Set one in ~/.nanobot/config.json under providers section")
-        raise typer.Exit(1)
-
-    return LiteLLMProvider(
-        api_key=p.api_key if p else None,
-        api_base=config.get_api_base(model),
-        default_model=model,
-        extra_headers=p.extra_headers if p else None,
-        provider_name=provider_name,
+    defaults = config.agents.defaults
+    provider.generation = GenerationSettings(
+        temperature=defaults.temperature,
+        max_tokens=defaults.max_tokens,
+        reasoning_effort=defaults.reasoning_effort,
     )
+    return provider
 
 
 def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
@@ -341,10 +346,7 @@ def gateway(
         provider=provider,
         workspace=config.workspace_path,
         model=config.agents.defaults.model,
-        temperature=config.agents.defaults.temperature,
-        max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
-        reasoning_effort=config.agents.defaults.reasoning_effort,
         context_window_tokens=config.agents.defaults.context_window_tokens,
         brave_api_key=config.tools.web.search.api_key or None,
         web_proxy=config.tools.web.proxy or None,
@@ -527,10 +529,7 @@ def agent(
         provider=provider,
         workspace=config.workspace_path,
         model=config.agents.defaults.model,
-        temperature=config.agents.defaults.temperature,
-        max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
-        reasoning_effort=config.agents.defaults.reasoning_effort,
         context_window_tokens=config.agents.defaults.context_window_tokens,
         brave_api_key=config.tools.web.search.api_key or None,
         web_proxy=config.tools.web.proxy or None,
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index a3b6c47..d4ea60d 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -32,6 +32,21 @@ class LLMResponse:
         return len(self.tool_calls) > 0
 
 
+@dataclass(frozen=True)
+class GenerationSettings:
+    """Default generation parameters for LLM calls.
+
+    Stored on the provider so every call site inherits the same defaults
+    without having to pass temperature / max_tokens / reasoning_effort
+    through every layer. Individual call sites can still override by
+    passing explicit keyword arguments to chat() / chat_with_retry().
+    """
+
+    temperature: float = 0.7
+    max_tokens: int = 4096
+    reasoning_effort: str | None = None
+
+
 class LLMProvider(ABC):
     """
     Abstract base class for LLM providers.
@@ -56,9 +71,12 @@
         "temporarily unavailable",
     )
 
+    _SENTINEL = object()
+
     def __init__(self, api_key: str | None = None, api_base: str | None = None):
         self.api_key = api_key
         self.api_base = api_base
+        self.generation: GenerationSettings = GenerationSettings()
 
     @staticmethod
     def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -155,11 +173,23 @@ class LLMProvider(ABC):
         messages: list[dict[str, Any]],
         tools: list[dict[str, Any]] | None = None,
         model: str | None = None,
-        max_tokens: int = 4096,
-        temperature: float = 0.7,
-        reasoning_effort: str | None = None,
+        max_tokens: object = _SENTINEL,
+        temperature: object = _SENTINEL,
+        reasoning_effort: object = _SENTINEL,
     ) -> LLMResponse:
-        """Call chat() with retry on transient provider failures."""
+        """Call chat() with retry on transient provider failures.
+
+        Parameters default to ``self.generation`` when not explicitly passed,
+        so callers no longer need to thread temperature / max_tokens /
+        reasoning_effort through every layer.
+        """
+        if max_tokens is self._SENTINEL:
+            max_tokens = self.generation.max_tokens
+        if temperature is self._SENTINEL:
+            temperature = self.generation.temperature
+        if reasoning_effort is self._SENTINEL:
+            reasoning_effort = self.generation.reasoning_effort
+
         for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
             try:
                 response = await self.chat(
diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py
index 0263f01..69be858 100644
--- a/tests/test_memory_consolidation_types.py
+++ b/tests/test_memory_consolidation_types.py
@@ -265,3 +265,26 @@ class TestMemoryConsolidationTypeHandling:
         assert result is True
         assert provider.calls == 2
         assert delays == [1]
+
+    @pytest.mark.asyncio
+    async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None:
+        """Consolidation no longer passes generation params — the provider owns them."""
+        store = MemoryStore(tmp_path)
+        provider = AsyncMock()
+        provider.chat_with_retry = AsyncMock(
+            return_value=_make_tool_response(
+                history_entry="[2026-01-01] User discussed testing.",
+                memory_update="# Memory\nUser likes testing.",
+            )
+        )
+        messages = _make_messages(message_count=60)
+
+        result = await store.consolidate(messages, provider, "test-model")
+
+        assert result is True
+        provider.chat_with_retry.assert_awaited_once()
+        _, kwargs = provider.chat_with_retry.await_args
+        assert kwargs["model"] == "test-model"
+        assert "temperature" not in kwargs
+        assert "max_tokens" not in kwargs
+        assert "reasoning_effort" not in kwargs
diff --git a/tests/test_provider_retry.py b/tests/test_provider_retry.py
index 751ecc3..2420399 100644
--- a/tests/test_provider_retry.py
+++ b/tests/test_provider_retry.py
@@ -2,7 +2,7 @@ import asyncio
 
 import pytest
 
-from nanobot.providers.base import LLMProvider, LLMResponse
+from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
 
 
 class ScriptedProvider(LLMProvider):
@@ -10,9 +10,11 @@
         super().__init__()
         self._responses = list(responses)
         self.calls = 0
+        self.last_kwargs: dict = {}
 
     async def chat(self, *args, **kwargs) -> LLMResponse:
         self.calls += 1
+        self.last_kwargs = kwargs
         response = self._responses.pop(0)
         if isinstance(response, BaseException):
             raise response
@@ -90,3 +92,34 @@ async def test_chat_with_retry_preserves_cancelled_error() -> None:
 
     with pytest.raises(asyncio.CancelledError):
         await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
+    """When callers omit generation params, provider.generation defaults are used."""
+    provider = ScriptedProvider([LLMResponse(content="ok")])
+    provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
+
+    await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert provider.last_kwargs["temperature"] == 0.2
+    assert provider.last_kwargs["max_tokens"] == 321
+    assert provider.last_kwargs["reasoning_effort"] == "high"
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_explicit_override_beats_defaults() -> None:
+    """Explicit kwargs should override provider.generation defaults."""
+    provider = ScriptedProvider([LLMResponse(content="ok")])
+    provider.generation = GenerationSettings(temperature=0.2, max_tokens=321, reasoning_effort="high")
+
+    await provider.chat_with_retry(
+        messages=[{"role": "user", "content": "hello"}],
+        temperature=0.9,
+        max_tokens=9999,
+        reasoning_effort="low",
+    )
+
+    assert provider.last_kwargs["temperature"] == 0.9
+    assert provider.last_kwargs["max_tokens"] == 9999
+    assert provider.last_kwargs["reasoning_effort"] == "low"
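
Illustrative usage sketch (not part of the patch): the snippet below shows the call pattern this change enables, setting GenerationSettings once on a provider and letting chat_with_retry() fill in temperature / max_tokens / reasoning_effort for callers that omit them. EchoProvider and the message text are hypothetical stand-ins; GenerationSettings, LLMProvider, LLMResponse, and chat_with_retry() come from the diff above, and the sketch assumes chat() is the only abstract method, as the test suite's ScriptedProvider implies.

# sketch_generation_defaults.py -- hypothetical example, not shipped in this patch
import asyncio

from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse


class EchoProvider(LLMProvider):
    """Hypothetical provider that records the generation kwargs it receives."""

    async def chat(self, messages, tools=None, model=None, **kwargs) -> LLMResponse:
        # By the time chat() runs, chat_with_retry() has already substituted
        # self.generation defaults for any omitted generation parameters.
        return LLMResponse(content=f"received kwargs: {kwargs}")


async def main() -> None:
    provider = EchoProvider()
    # Configure the defaults once, e.g. from config.agents.defaults as commands.py now does.
    provider.generation = GenerationSettings(temperature=0.1, max_tokens=2048)

    # No generation kwargs here: the stored defaults (0.1 / 2048) apply.
    r1 = await provider.chat_with_retry(messages=[{"role": "user", "content": "hi"}])

    # An explicit keyword still wins over the stored default for this one call.
    r2 = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.9,
    )
    print(r1.content)
    print(r2.content)


if __name__ == "__main__":
    asyncio.run(main())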