Merge PR #1874: preserve provider-specific tool-call fields

This commit is contained in:
Re-bin
2026-03-11 15:30:33 +00:00
5 changed files with 82 additions and 17 deletions

View File

@@ -203,14 +203,7 @@ class AgentLoop:
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True) await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
tool_call_dicts = [ tool_call_dicts = [
{ tc.to_openai_tool_call()
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
}
}
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages = self.context.add_assistant_message( messages = self.context.add_assistant_message(

View File

@@ -126,14 +126,7 @@ class SubagentManager:
if response.has_tool_calls: if response.has_tool_calls:
tool_call_dicts = [ tool_call_dicts = [
{ tc.to_openai_tool_call()
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
for tc in response.tool_calls for tc in response.tool_calls
] ]
messages.append(build_assistant_message( messages.append(build_assistant_message(

View File

@@ -1,6 +1,7 @@
"""Base LLM provider interface.""" """Base LLM provider interface."""
import asyncio import asyncio
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -14,6 +15,24 @@ class ToolCallRequest:
id: str id: str
name: str name: str
arguments: dict[str, Any] arguments: dict[str, Any]
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.provider_specific_fields:
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
return tool_call
@dataclass @dataclass

View File

@@ -309,10 +309,17 @@ class LiteLLMProvider(LLMProvider):
if isinstance(args, str): if isinstance(args, str):
args = json_repair.loads(args) args = json_repair.loads(args)
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
function_provider_specific_fields = (
getattr(tc.function, "provider_specific_fields", None) or None
)
tool_calls.append(ToolCallRequest( tool_calls.append(ToolCallRequest(
id=_short_tool_id(), id=_short_tool_id(),
name=tc.function.name, name=tc.function.name,
arguments=args, arguments=args,
provider_specific_fields=provider_specific_fields,
function_provider_specific_fields=function_provider_specific_fields,
)) ))
usage = {} usage = {}

View File

@@ -0,0 +1,53 @@
from types import SimpleNamespace
from nanobot.providers.base import ToolCallRequest
from nanobot.providers.litellm_provider import LiteLLMProvider
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
response = SimpleNamespace(
choices=[
SimpleNamespace(
finish_reason="tool_calls",
message=SimpleNamespace(
content=None,
tool_calls=[
SimpleNamespace(
id="call_123",
function=SimpleNamespace(
name="read_file",
arguments='{"path":"todo.md"}',
provider_specific_fields={"inner": "value"},
),
provider_specific_fields={"thought_signature": "signed-token"},
)
],
),
)
],
usage=None,
)
parsed = provider._parse_response(response)
assert len(parsed.tool_calls) == 1
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
def test_tool_call_request_serializes_provider_fields() -> None:
tool_call = ToolCallRequest(
id="abc123xyz",
name="read_file",
arguments={"path": "todo.md"},
provider_specific_fields={"thought_signature": "signed-token"},
function_provider_specific_fields={"inner": "value"},
)
message = tool_call.to_openai_tool_call()
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
assert message["function"]["arguments"] == '{"path": "todo.md"}'