fix(memory): fallback to tool_choice=auto when provider rejects forced function call

Some providers (e.g. Dashscope in thinking mode) reject object-style
tool_choice with "does not support being set to required or object".
Retry once with tool_choice="auto" instead of failing silently.

Made-with: Cursor
This commit is contained in:
Xubin Ren
2026-03-13 03:18:08 +00:00
parent e30d19e94d
commit 4f77b9385c
2 changed files with 91 additions and 2 deletions

View File

@@ -57,6 +57,20 @@ def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None:
return args[0] if args and isinstance(args[0], dict) else None return args[0] if args and isinstance(args[0], dict) else None
return args if isinstance(args, dict) else None return args if isinstance(args, dict) else None
_TOOL_CHOICE_ERROR_MARKERS = (
"tool_choice",
"toolchoice",
"does not support",
'should be ["none", "auto"]',
)
def _is_tool_choice_unsupported(content: str | None) -> bool:
"""Detect provider errors caused by forced tool_choice being unsupported."""
text = (content or "").lower()
return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS)
class MemoryStore: class MemoryStore:
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
@@ -118,15 +132,33 @@ class MemoryStore:
] ]
try: try:
forced = {"type": "function", "function": {"name": "save_memory"}}
response = await provider.chat_with_retry( response = await provider.chat_with_retry(
messages=chat_messages, messages=chat_messages,
tools=_SAVE_MEMORY_TOOL, tools=_SAVE_MEMORY_TOOL,
model=model, model=model,
tool_choice={"type": "function", "function": {"name": "save_memory"}}, tool_choice=forced,
) )
if response.finish_reason == "error" and _is_tool_choice_unsupported(
response.content
):
logger.warning("Forced tool_choice unsupported, retrying with auto")
response = await provider.chat_with_retry(
messages=chat_messages,
tools=_SAVE_MEMORY_TOOL,
model=model,
tool_choice="auto",
)
if not response.has_tool_calls: if not response.has_tool_calls:
logger.warning("Memory consolidation: LLM did not call save_memory, skipping") logger.warning(
"Memory consolidation: LLM did not call save_memory "
"(finish_reason={}, content_len={}, content_preview={})",
response.finish_reason,
len(response.content or ""),
(response.content or "")[:200],
)
return False return False
args = _normalize_save_memory_args(response.tool_calls[0].arguments) args = _normalize_save_memory_args(response.tool_calls[0].arguments)

View File

@@ -288,3 +288,60 @@ class TestMemoryConsolidationTypeHandling:
assert "temperature" not in kwargs assert "temperature" not in kwargs
assert "max_tokens" not in kwargs assert "max_tokens" not in kwargs
assert "reasoning_effort" not in kwargs assert "reasoning_effort" not in kwargs
@pytest.mark.asyncio
async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None:
    """A rejected forced tool_choice is retried with "auto" and then succeeds.

    The fake provider answers the first call with a tool_choice error and
    every later call with a valid save_memory tool response. The test checks
    that exactly two calls were made, that the first used an object-style
    tool_choice and the second used "auto", and that the history entry from
    the successful response was written.
    """
    store = MemoryStore(tmp_path)

    rejected = LLMResponse(
        content="Error calling LLM: litellm.BadRequestError: "
        "The tool_choice parameter does not support being set to required or object",
        finish_reason="error",
        tool_calls=[],
    )
    accepted = _make_tool_response(
        history_entry="[2026-01-01] Fallback worked.",
        memory_update="# Memory\nFallback OK.",
    )

    # Keyword arguments of every provider call, in call order.
    seen_calls: list[dict] = []

    async def _fake_chat(**kwargs):
        seen_calls.append(kwargs)
        if len(seen_calls) == 1:
            # Only the very first call is rejected.
            return rejected
        return accepted

    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(side_effect=_fake_chat)

    outcome = await store.consolidate(
        _make_messages(message_count=60), provider, "test-model"
    )

    assert outcome is True
    assert len(seen_calls) == 2
    forced_call, auto_call = seen_calls
    assert isinstance(forced_call["tool_choice"], dict)
    assert auto_call["tool_choice"] == "auto"
    assert "Fallback worked." in store.history_file.read_text()
@pytest.mark.asyncio
async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None:
    """Forced rejected, auto retry also produces no tool call -> return False.

    The provider first returns a tool_choice error, then a plain text answer
    without tool calls. Consolidation must report failure and write no
    history, and the retry with tool_choice="auto" must actually have been
    attempted.
    """
    store = MemoryStore(tmp_path)
    error_resp = LLMResponse(
        content="Error: tool_choice must be none or auto",
        finish_reason="error",
        tool_calls=[],
    )
    no_tool_resp = LLMResponse(
        content="Here is a summary.",
        finish_reason="stop",
        tool_calls=[],
    )

    provider = AsyncMock()
    provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp])
    messages = _make_messages(message_count=60)

    result = await store.consolidate(messages, provider, "test-model")

    assert result is False
    # The error response itself carries no tool calls, so the outcome checks
    # alone would pass even if the fallback never ran. Pin the retry itself.
    assert provider.chat_with_retry.await_count == 2
    forced_call, auto_call = provider.chat_with_retry.await_args_list
    assert isinstance(forced_call.kwargs["tool_choice"], dict)
    assert auto_call.kwargs["tool_choice"] == "auto"
    assert not store.history_file.exists()