diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 1301d47..e7eac88 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -57,6 +57,20 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None: return args[0] if args and isinstance(args[0], dict) else None return args if isinstance(args, dict) else None +_TOOL_CHOICE_ERROR_MARKERS = ( + "tool_choice", + "toolchoice", + "does not support", + 'should be ["none", "auto"]', +) + + +def _is_tool_choice_unsupported(content: str | None) -> bool: + """Detect provider errors caused by forced tool_choice being unsupported.""" + text = (content or "").lower() + return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS) + + class MemoryStore: """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" @@ -118,15 +132,33 @@ class MemoryStore: ] try: + forced = {"type": "function", "function": {"name": "save_memory"}} response = await provider.chat_with_retry( messages=chat_messages, tools=_SAVE_MEMORY_TOOL, model=model, - tool_choice={"type": "function", "function": {"name": "save_memory"}}, + tool_choice=forced, ) + if response.finish_reason == "error" and _is_tool_choice_unsupported( + response.content + ): + logger.warning("Forced tool_choice unsupported, retrying with auto") + response = await provider.chat_with_retry( + messages=chat_messages, + tools=_SAVE_MEMORY_TOOL, + model=model, + tool_choice="auto", + ) + if not response.has_tool_calls: - logger.warning("Memory consolidation: LLM did not call save_memory, skipping") + logger.warning( + "Memory consolidation: LLM did not call save_memory " + "(finish_reason={}, content_len={}, content_preview={})", + response.finish_reason, + len(response.content or ""), + (response.content or "")[:200], + ) return False args = _normalize_save_memory_args(response.tool_calls[0].arguments) diff --git a/tests/test_memory_consolidation_types.py b/tests/test_memory_consolidation_types.py index 69be858..f1280fc 100644 --- a/tests/test_memory_consolidation_types.py +++ b/tests/test_memory_consolidation_types.py @@ -288,3 +288,60 @@ class TestMemoryConsolidationTypeHandling: assert "temperature" not in kwargs assert "max_tokens" not in kwargs assert "reasoning_effort" not in kwargs + + @pytest.mark.asyncio + async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None: + """Forced tool_choice rejected by provider -> retry with auto and succeed.""" + store = MemoryStore(tmp_path) + error_resp = LLMResponse( + content="Error calling LLM: litellm.BadRequestError: " + "The tool_choice parameter does not support being set to required or object", + finish_reason="error", + tool_calls=[], + ) + ok_resp = _make_tool_response( + history_entry="[2026-01-01] Fallback worked.", + memory_update="# Memory\nFallback OK.", + ) + + call_log: list[dict] = [] + + async def _tracking_chat(**kwargs): + call_log.append(kwargs) + return error_resp if len(call_log) == 1 else ok_resp + + provider = AsyncMock() + provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat) + messages = _make_messages(message_count=60) + + result = await store.consolidate(messages, provider, "test-model") + + assert result is True + assert len(call_log) == 2 + assert isinstance(call_log[0]["tool_choice"], dict) + assert call_log[1]["tool_choice"] == "auto" + assert "Fallback worked." in store.history_file.read_text() + + @pytest.mark.asyncio + async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None: + """Forced rejected, auto retry also produces no tool call -> return False.""" + store = MemoryStore(tmp_path) + error_resp = LLMResponse( + content="Error: tool_choice must be none or auto", + finish_reason="error", + tool_calls=[], + ) + no_tool_resp = LLMResponse( + content="Here is a summary.", + finish_reason="stop", + tool_calls=[], + ) + + provider = AsyncMock() + provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp]) + messages = _make_messages(message_count=60) + + result = await store.consolidate(messages, provider, "test-model") + + assert result is False + assert not store.history_file.exists()