fix: preserve provider-specific tool call metadata for Gemini

This commit is contained in:
WhalerO
2026-03-11 09:56:18 +08:00
parent 947ed508ad
commit ed82f95f0c
5 changed files with 97 additions and 16 deletions

View File

@@ -208,14 +208,7 @@ class AgentLoop:
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
}
}
self._build_tool_call_message(tc)
for tc in response.tool_calls
]
messages = self.context.add_assistant_message(
@@ -256,6 +249,22 @@ class AgentLoop:
return final_content, tools_used, messages
@staticmethod
def _build_tool_call_message(tc: Any) -> dict[str, Any]:
tool_call = {
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
}
}
if getattr(tc, "provider_specific_fields", None):
tool_call["provider_specific_fields"] = tc.provider_specific_fields
if getattr(tc, "function_provider_specific_fields", None):
tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
return tool_call
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
self._running = True

View File

@@ -135,14 +135,7 @@ class SubagentManager:
if response.has_tool_calls:
# Add assistant message with tool calls
tool_call_dicts = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
self._build_tool_call_message(tc)
for tc in response.tool_calls
]
messages.append({
@@ -230,6 +223,22 @@ Stay focused on the assigned task. Your final response will be reported back to
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
return "\n\n".join(parts)
@staticmethod
def _build_tool_call_message(tc: Any) -> dict[str, Any]:
tool_call = {
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
if getattr(tc, "provider_specific_fields", None):
tool_call["provider_specific_fields"] = tc.provider_specific_fields
if getattr(tc, "function_provider_specific_fields", None):
tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
return tool_call
async def cancel_by_session(self, session_key: str) -> int:
"""Cancel all subagents for the given session. Returns count cancelled."""

View File

@@ -14,6 +14,8 @@ class ToolCallRequest:
id: str
name: str
arguments: dict[str, Any]
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
@dataclass

View File

@@ -309,10 +309,17 @@ class LiteLLMProvider(LLMProvider):
if isinstance(args, str):
args = json_repair.loads(args)
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
function_provider_specific_fields = (
getattr(tc.function, "provider_specific_fields", None) or None
)
tool_calls.append(ToolCallRequest(
id=_short_tool_id(),
name=tc.function.name,
arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
))
usage = {}

View File

@@ -0,0 +1,54 @@
from types import SimpleNamespace
from nanobot.agent.loop import AgentLoop
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.litellm_provider import LiteLLMProvider
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="call_123",
function=SimpleNamespace(
name="read_file",
arguments='{"path":"todo.md"}',
provider_specific_fields={"inner": "value"},
),
provider_specific_fields={"thought_signature": "signed-token"},
)
],
),
)
],
usage=None,
)
parsed = provider._parse_response(response)
assert len(parsed.tool_calls) == 1
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
def test_agent_loop_replays_tool_call_provider_fields() -> None:
tool_call = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"thought_signature": "signed-token"},
function_provider_specific_fields={"inner": "value"},
)
message = AgentLoop._build_tool_call_message(tool_call)
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'