package ai import ( "bytes" "encoding/json" "fmt" "io" "log" "net/http" "regexp" "strings" ) // AIClient 定义了与 AI 服务交互的接口 type AIClient interface { Chat(systemMsg, prompt string) (string, error) } // Client 实现了 AIClient 接口 type Client struct { apiBase string apiKey string model string aiType string stream bool temperature float64 client *http.Client } // NewClient 创建新的 AI 客户端 func NewClient(apiBase, apiKey, model, aiType string, temperature float64) *Client { return &Client{ apiBase: apiBase, apiKey: apiKey, model: model, aiType: aiType, temperature: temperature, client: &http.Client{}, } } // Chat 发送聊天请求到 AI 服务 func (c *Client) Chat(systemMsg, prompt string) (string, error) { // 根据配置的类型判断使用哪个服务 var response string var err error if c.aiType == "ollama" { response, err = c.ollamaChat(systemMsg, prompt) } else { response, err = c.openAIChat(systemMsg, prompt) } if err != nil { log.Printf("AI 聊天请求失败: error=%v", err) return "", fmt.Errorf("AI 聊天请求失败: %w", err) } return response, nil } // ollamaChat 发送请求到 Ollama API func (c *Client) ollamaChat(systemMsg, prompt string) (string, error) { // 组合系统提示词和用户提示词 fullPrompt := fmt.Sprintf("%s\n\n%s", systemMsg, prompt) reqBody := map[string]interface{}{ "model": c.model, "prompt": fullPrompt, "stream": c.stream, "temperature": c.temperature, } jsonData, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("序列化请求失败: %w", err) } req, err := http.NewRequest("POST", c.apiBase+"/api/generate", bytes.NewBuffer(jsonData)) if err != nil { return "", fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := c.client.Do(req) if err != nil { return "", fmt.Errorf("发送请求失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body)) } var result struct { Response string `json:"response"` Done bool `json:"done"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { body, _ := io.ReadAll(resp.Body) log.Printf("响应内容: %s", string(body)) return "", fmt.Errorf("解析响应失败: %w", err) } if !result.Done { return "", fmt.Errorf("AI 响应未完成") } pattern := "(?s)(.*?)" reg := regexp.MustCompile(pattern) matches := reg.ReplaceAllString(result.Response, "") return strings.TrimSpace(matches), nil } // openAIChat 发送请求到 OpenAI API func (c *Client) openAIChat(systemMsg, prompt string) (string, error) { reqBody := map[string]interface{}{ "model": c.model, "messages": []map[string]string{ { "role": "system", "content": systemMsg, }, { "role": "user", "content": prompt, }, }, "stream": false, "temperature": c.temperature, } jsonData, err := json.Marshal(reqBody) if err != nil { return "", fmt.Errorf("序列化请求失败: %w", err) } req, err := http.NewRequest("POST", c.apiBase+"/v1/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { return "", fmt.Errorf("创建请求失败: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.apiKey) resp, err := c.client.Do(req) if err != nil { return "", fmt.Errorf("发送请求失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("API 请求失败: status=%d, body=%s", resp.StatusCode, string(body)) } var result struct { Choices []struct { Message struct { Content string `json:"content"` } `json:"message"` } `json:"choices"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return "", fmt.Errorf("解析响应失败: %w", err) } if len(result.Choices) == 0 { return "", fmt.Errorf("AI 响应为空") } pattern := "(?s)(.*?)" reg := regexp.MustCompile(pattern) matches := reg.ReplaceAllString(result.Choices[0].Message.Content, "") return strings.TrimSpace(matches), nil }