Files
ai-code-review/services/ai/client.go

186 lines
4.4 KiB
Go

package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strings"
)
// AIClient 定义了与 AI 服务交互的接口
type AIClient interface {
Chat(systemMsg, prompt string) (string, error)
}
// Client 实现了 AIClient 接口
type Client struct {
apiBase string
apiKey string
model string
aiType string
stream bool
temperature float64
client *http.Client
}
// NewClient 创建新的 AI 客户端
func NewClient(apiBase, apiKey, model, aiType string, temperature float64) *Client {
return &Client{
apiBase: apiBase,
apiKey: apiKey,
model: model,
aiType: aiType,
temperature: temperature,
client: &http.Client{},
}
}
// Chat 发送聊天请求到 AI 服务
func (c *Client) Chat(systemMsg, prompt string) (string, error) {
// 根据配置的类型判断使用哪个服务
var response string
var err error
log.Printf("AImodel=%s, prompt=%s", c.model, prompt)
if c.aiType == "ollama" {
response, err = c.ollamaChat(systemMsg, prompt)
} else {
response, err = c.openAIChat(systemMsg, prompt)
}
if err != nil {
log.Printf("AI 聊天请求失败: error=%v", err)
return "", fmt.Errorf("AI 聊天请求失败: %w", err)
}
return response, nil
}
// ollamaChat 发送请求到 Ollama API
func (c *Client) ollamaChat(systemMsg, prompt string) (string, error) {
// 组合系统提示词和用户提示词
fullPrompt := fmt.Sprintf("%s\n\n%s", systemMsg, prompt)
reqBody := map[string]interface{}{
"model": c.model,
"prompt": fullPrompt,
"stream": c.stream,
"temperature": c.temperature,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequest("POST", c.apiBase+"/api/generate", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(req)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body))
}
var result struct {
Response string `json:"response"`
Done bool `json:"done"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
body, _ := io.ReadAll(resp.Body)
log.Printf("响应内容: %s", string(body))
return "", fmt.Errorf("解析响应失败: %w", err)
}
if !result.Done {
return "", fmt.Errorf("AI 响应未完成")
}
pattern := "(?s)<think>(.*?)</think>"
reg := regexp.MustCompile(pattern)
matches := reg.ReplaceAllString(result.Response, "")
return strings.TrimSpace(matches), nil
}
// openAIChat 发送请求到 OpenAI API
func (c *Client) openAIChat(systemMsg, prompt string) (string, error) {
reqBody := map[string]interface{}{
"model": c.model,
"messages": []map[string]string{
{
"role": "system",
"content": systemMsg,
},
{
"role": "user",
"content": prompt,
},
},
"stream": false,
"temperature": c.temperature,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequest("POST", c.apiBase+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.apiKey)
resp, err := c.client.Do(req)
if err != nil {
return "", fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body))
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("解析响应失败: %w", err)
}
if len(result.Choices) == 0 {
return "", fmt.Errorf("AI 响应为空")
}
pattern := "(?s)<think>(.*?)</think>"
reg := regexp.MustCompile(pattern)
matches := reg.ReplaceAllString(result.Choices[0].Message.Content, "")
return strings.TrimSpace(matches), nil
}