refactor: centralize tool call serialization in ToolCallRequest

WhalerO
2026-03-11 15:01:18 +08:00
parent ed82f95f0c
commit 6ef7ab53d0
4 changed files with 21 additions and 37 deletions


@@ -208,7 +208,7 @@ class AgentLoop:
             await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
         tool_call_dicts = [
-            self._build_tool_call_message(tc)
+            tc.to_openai_tool_call()
             for tc in response.tool_calls
         ]
         messages = self.context.add_assistant_message(
@@ -249,22 +249,6 @@ class AgentLoop:
         return final_content, tools_used, messages

-    @staticmethod
-    def _build_tool_call_message(tc: Any) -> dict[str, Any]:
-        tool_call = {
-            "id": tc.id,
-            "type": "function",
-            "function": {
-                "name": tc.name,
-                "arguments": json.dumps(tc.arguments, ensure_ascii=False)
-            }
-        }
-        if getattr(tc, "provider_specific_fields", None):
-            tool_call["provider_specific_fields"] = tc.provider_specific_fields
-        if getattr(tc, "function_provider_specific_fields", None):
-            tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
-        return tool_call
-
     async def run(self) -> None:
         """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
        self._running = True
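One detail worth noting in the serialization that moved: arguments are dumped with ensure_ascii=False, so non-ASCII argument values stay literal instead of being \u-escaped. A quick standalone illustration (the file name is made up):

import json

args = {"path": "docs/メモ.txt"}
print(json.dumps(args))                      # {"path": "docs/\u30e1\u30e2.txt"}
print(json.dumps(args, ensure_ascii=False))  # {"path": "docs/メモ.txt"}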


@@ -135,7 +135,7 @@ class SubagentManager:
         if response.has_tool_calls:
             # Add assistant message with tool calls
             tool_call_dicts = [
-                self._build_tool_call_message(tc)
+                tc.to_openai_tool_call()
                 for tc in response.tool_calls
             ]
             messages.append({
@@ -224,22 +224,6 @@ Stay focused on the assigned task. Your final response will be reported back to
         return "\n\n".join(parts)

-    @staticmethod
-    def _build_tool_call_message(tc: Any) -> dict[str, Any]:
-        tool_call = {
-            "id": tc.id,
-            "type": "function",
-            "function": {
-                "name": tc.name,
-                "arguments": json.dumps(tc.arguments, ensure_ascii=False),
-            },
-        }
-        if getattr(tc, "provider_specific_fields", None):
-            tool_call["provider_specific_fields"] = tc.provider_specific_fields
-        if getattr(tc, "function_provider_specific_fields", None):
-            tool_call["function"]["provider_specific_fields"] = tc.function_provider_specific_fields
-        return tool_call
-
     async def cancel_by_session(self, session_key: str) -> int:
         """Cancel all subagents for the given session. Returns count cancelled."""
         tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
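For context, these serialized dicts become the tool_calls of the assistant message appended above. A minimal sketch of that message's shape, assuming ToolCallRequest accepts these fields as keywords (as the test below suggests) and that the message follows the standard OpenAI chat format rather than anything specific to this repo:

from nanobot.providers.base import ToolCallRequest

calls = [ToolCallRequest(id="call_0", name="read_file", arguments={"path": "README.md"})]
assistant_message = {
    "role": "assistant",
    "content": None,  # assistant turns that only issue tool calls typically carry no text
    "tool_calls": [tc.to_openai_tool_call() for tc in calls],
}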


@@ -1,6 +1,7 @@
 """Base LLM provider interface."""

 import asyncio
+import json
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any
@@ -17,6 +18,22 @@ class ToolCallRequest:
     provider_specific_fields: dict[str, Any] | None = None
     function_provider_specific_fields: dict[str, Any] | None = None

+    def to_openai_tool_call(self) -> dict[str, Any]:
+        """Serialize to an OpenAI-style tool_call payload."""
+        tool_call = {
+            "id": self.id,
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "arguments": json.dumps(self.arguments, ensure_ascii=False),
+            },
+        }
+        if self.provider_specific_fields:
+            tool_call["provider_specific_fields"] = self.provider_specific_fields
+        if self.function_provider_specific_fields:
+            tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
+        return tool_call
+

 @dataclass
 class LLMResponse:
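A usage sketch of the new method. The constructor keywords mirror the test below; the asserted output follows directly from the method body:

from nanobot.providers.base import ToolCallRequest

tc = ToolCallRequest(
    id="call_42",
    name="read_file",
    arguments={"path": "notes.txt"},
    provider_specific_fields={"thought_signature": "signed-token"},
)
payload = tc.to_openai_tool_call()

assert payload["id"] == "call_42"
assert payload["type"] == "function"
assert payload["function"]["arguments"] == '{"path": "notes.txt"}'
assert payload["provider_specific_fields"] == {"thought_signature": "signed-token"}
assert "provider_specific_fields" not in payload["function"]  # only attached when set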


@@ -1,6 +1,5 @@
 from types import SimpleNamespace

-from nanobot.agent.loop import AgentLoop
 from nanobot.providers.base import ToolCallRequest
 from nanobot.providers.litellm_provider import LiteLLMProvider
@@ -38,7 +37,7 @@ def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
     assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}


-def test_agent_loop_replays_tool_call_provider_fields() -> None:
+def test_tool_call_request_serializes_provider_fields() -> None:
     tool_call = ToolCallRequest(
         id="abc123xyz",
         name="read_file",
@@ -47,7 +46,7 @@ def test_agent_loop_replays_tool_call_provider_fields() -> None:
         function_provider_specific_fields={"inner": "value"},
     )

-    message = AgentLoop._build_tool_call_message(tool_call)
+    message = tool_call.to_openai_tool_call()

     assert message["provider_specific_fields"] == {"thought_signature": "signed-token"}
     assert message["function"]["provider_specific_fields"] == {"inner": "value"}
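Net effect: two near-identical static helpers in AgentLoop and SubagentManager collapse into one method on the dataclass that owns the data, and the test now exercises the provider layer directly instead of importing AgentLoop. One property the test leaves implicit is that arguments survive the JSON round-trip; a quick check, again assuming arguments is a plain dict keyword:

import json
from nanobot.providers.base import ToolCallRequest

tc = ToolCallRequest(id="abc123xyz", name="read_file", arguments={"path": "a.txt"})
assert json.loads(tc.to_openai_tool_call()["function"]["arguments"]) == {"path": "a.txt"}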