fix: preserve provider-specific tool call metadata for Gemini
This commit is contained in:
@@ -208,14 +208,7 @@ class AgentLoop:
|
|||||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||||
|
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
{
|
self._build_tool_call_message(tc)
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.name,
|
|
||||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
]
|
]
|
||||||
messages = self.context.add_assistant_message(
|
messages = self.context.add_assistant_message(
|
||||||
@@ -256,6 +249,22 @@ class AgentLoop:
|
|||||||
|
|
||||||
return final_content, tools_used, messages
|
return final_content, tools_used, messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_call_message(tc: Any) -> dict[str, Any]:
|
||||||
|
tool_call = {
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if getattr(tc, "provider_specific_fields", None):
|
||||||
|
tool_call["provider_specific_fields"] = tc.provider_specific_fields
|
||||||
|
if getattr(tc, "function_provider_specific_fields", None):
|
||||||
|
tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
|
||||||
|
return tool_call
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|||||||
@@ -135,14 +135,7 @@ class SubagentManager:
|
|||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
# Add assistant message with tool calls
|
# Add assistant message with tool calls
|
||||||
tool_call_dicts = [
|
tool_call_dicts = [
|
||||||
{
|
self._build_tool_call_message(tc)
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.name,
|
|
||||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in response.tool_calls
|
for tc in response.tool_calls
|
||||||
]
|
]
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -231,6 +224,22 @@ Stay focused on the assigned task. Your final response will be reported back to
|
|||||||
|
|
||||||
return "\n\n".join(parts)
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_call_message(tc: Any) -> dict[str, Any]:
|
||||||
|
tool_call = {
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.name,
|
||||||
|
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if getattr(tc, "provider_specific_fields", None):
|
||||||
|
tool_call["provider_specific_fields"] = tc.provider_specific_fields
|
||||||
|
if getattr(tc, "function_provider_specific_fields", None):
|
||||||
|
tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
|
||||||
|
return tool_call
|
||||||
|
|
||||||
async def cancel_by_session(self, session_key: str) -> int:
|
async def cancel_by_session(self, session_key: str) -> int:
|
||||||
"""Cancel all subagents for the given session. Returns count cancelled."""
|
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||||
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ class ToolCallRequest:
|
|||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
arguments: dict[str, Any]
|
arguments: dict[str, Any]
|
||||||
|
provider_specific_fields: dict[str, Any] | None = None
|
||||||
|
function_provider_specific_fields: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -309,10 +309,17 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
args = json_repair.loads(args)
|
args = json_repair.loads(args)
|
||||||
|
|
||||||
|
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
|
||||||
|
function_provider_specific_fields = (
|
||||||
|
getattr(tc.function, "provider_specific_fields", None) or None
|
||||||
|
)
|
||||||
|
|
||||||
tool_calls.append(ToolCallRequest(
|
tool_calls.append(ToolCallRequest(
|
||||||
id=_short_tool_id(),
|
id=_short_tool_id(),
|
||||||
name=tc.function.name,
|
name=tc.function.name,
|
||||||
arguments=args,
|
arguments=args,
|
||||||
|
provider_specific_fields=provider_specific_fields,
|
||||||
|
function_provider_specific_fields=function_provider_specific_fields,
|
||||||
))
|
))
|
||||||
|
|
||||||
usage = {}
|
usage = {}
|
||||||
|
|||||||
54
tests/test_gemini_thought_signature.py
Normal file
54
tests/test_gemini_thought_signature.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.providers.base import ToolCallRequest
|
||||||
|
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
|
||||||
|
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
|
||||||
|
|
||||||
|
response = SimpleNamespace(
|
||||||
|
choices=[
|
||||||
|
SimpleNamespace(
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
message=SimpleNamespace(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
SimpleNamespace(
|
||||||
|
id="call_123",
|
||||||
|
function=SimpleNamespace(
|
||||||
|
name="read_file",
|
||||||
|
arguments='{"path":"todo.md"}',
|
||||||
|
provider_specific_fields={"inner": "value"},
|
||||||
|
),
|
||||||
|
provider_specific_fields={"thought_signature": "signed-token"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed = provider._parse_response(response)
|
||||||
|
|
||||||
|
assert len(parsed.tool_calls) == 1
|
||||||
|
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
|
||||||
|
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_loop_replays_tool_call_provider_fields() -> None:
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="abc123xyz",
|
||||||
|
name="read_file",
|
||||||
|
arguments={"path": "todo.md"},
|
||||||
|
provider_specific_fields={"thought_signature": "signed-token"},
|
||||||
|
function_provider_specific_fields={"inner": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
message = AgentLoop._build_tool_call_message(tool_call)
|
||||||
|
|
||||||
|
assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
|
||||||
|
assert message["function"]["provider_specific_fields"] == {"inner": "value"}
|
||||||
|
assert message["function"]["arguments"] == '{"path": "todo.md"}'
|
||||||
Reference in New Issue
Block a user