// Listing metadata: 186 lines, 4.4 KiB, Go.

package ai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// AIClient is the interface for interacting with an AI service.
type AIClient interface {
	// Chat takes a system message and a user prompt and returns the
	// service's reply as a string, or an error.
	Chat(systemMsg string, prompt string) (string, error)
}
|
|
|
|
// Client implements the AIClient interface on top of an HTTP client.
type Client struct {
	// apiBase is the base URL that endpoint paths are appended to, apiKey is
	// sent as a bearer token on OpenAI-style requests, model is the model
	// name placed in request bodies, and aiType selects the backend
	// ("ollama" or anything else for the OpenAI-style API).
	apiBase, apiKey, model, aiType string

	stream      bool    // value of the "stream" field in Ollama requests
	temperature float64 // sampling temperature placed in request bodies

	client *http.Client // HTTP client used to perform the requests
}
|
|
|
|
// NewClient 创建新的 AI 客户端
|
|
func NewClient(apiBase, apiKey, model, aiType string, temperature float64) *Client {
|
|
return &Client{
|
|
apiBase: apiBase,
|
|
apiKey: apiKey,
|
|
model: model,
|
|
aiType: aiType,
|
|
temperature: temperature,
|
|
client: &http.Client{},
|
|
}
|
|
}
|
|
|
|
// Chat 发送聊天请求到 AI 服务
|
|
func (c *Client) Chat(systemMsg, prompt string) (string, error) {
|
|
// 根据配置的类型判断使用哪个服务
|
|
var response string
|
|
var err error
|
|
|
|
log.Printf("AImodel=%s, prompt=%s", c.model, prompt)
|
|
if c.aiType == "ollama" {
|
|
response, err = c.ollamaChat(systemMsg, prompt)
|
|
} else {
|
|
response, err = c.openAIChat(systemMsg, prompt)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Printf("AI 聊天请求失败: error=%v", err)
|
|
return "", fmt.Errorf("AI 聊天请求失败: %w", err)
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ollamaChat 发送请求到 Ollama API
|
|
func (c *Client) ollamaChat(systemMsg, prompt string) (string, error) {
|
|
// 组合系统提示词和用户提示词
|
|
fullPrompt := fmt.Sprintf("%s\n\n%s", systemMsg, prompt)
|
|
|
|
reqBody := map[string]interface{}{
|
|
"model": c.model,
|
|
"prompt": fullPrompt,
|
|
"stream": c.stream,
|
|
"temperature": c.temperature,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("序列化请求失败: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", c.apiBase+"/api/generate", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", fmt.Errorf("创建请求失败: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := c.client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("发送请求失败: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var result struct {
|
|
Response string `json:"response"`
|
|
Done bool `json:"done"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Printf("响应内容: %s", string(body))
|
|
return "", fmt.Errorf("解析响应失败: %w", err)
|
|
}
|
|
|
|
if !result.Done {
|
|
return "", fmt.Errorf("AI 响应未完成")
|
|
}
|
|
|
|
pattern := "(?s)<think>(.*?)</think>"
|
|
|
|
reg := regexp.MustCompile(pattern)
|
|
matches := reg.ReplaceAllString(result.Response, "")
|
|
|
|
return strings.TrimSpace(matches), nil
|
|
}
|
|
|
|
// openAIChat 发送请求到 OpenAI API
|
|
func (c *Client) openAIChat(systemMsg, prompt string) (string, error) {
|
|
reqBody := map[string]interface{}{
|
|
"model": c.model,
|
|
"messages": []map[string]string{
|
|
{
|
|
"role": "system",
|
|
"content": systemMsg,
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": prompt,
|
|
},
|
|
},
|
|
"stream": false,
|
|
"temperature": c.temperature,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("序列化请求失败: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", c.apiBase+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", fmt.Errorf("创建请求失败: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
|
|
|
resp, err := c.client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("发送请求失败: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var result struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return "", fmt.Errorf("解析响应失败: %w", err)
|
|
}
|
|
|
|
if len(result.Choices) == 0 {
|
|
return "", fmt.Errorf("AI 响应为空")
|
|
}
|
|
|
|
pattern := "(?s)<think>(.*?)</think>"
|
|
|
|
reg := regexp.MustCompile(pattern)
|
|
matches := reg.ReplaceAllString(result.Choices[0].Message.Content, "")
|
|
|
|
return strings.TrimSpace(matches), nil
|
|
}
|