349 lines
9.5 KiB
Go
349 lines
9.5 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"math/rand"
|
|
"os"
|
|
"sort"
|
|
"sync"
|
|
|
|
"github.com/spf13/viper"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
var (
|
|
cfg *Config
|
|
once sync.Once
|
|
mu sync.RWMutex
|
|
|
|
aiBalancer *AIBalancer
|
|
balancerMu sync.RWMutex
|
|
)
|
|
|
|
// Config is the top-level application configuration. viper decodes it from
// the YAML config file using the mapstructure tags below.
type Config struct {
	// Port number; init registers 53321 as the default. Presumably the port
	// the service listens on — TODO confirm with the server setup code.
	Port int `mapstructure:"port"`
	// Admin token. init registers the literal "token" as the default.
	// NOTE(review): weak default; deployments should always override it.
	AdminToken string `mapstructure:"admin_token"`
	// AI backend definitions (YAML key "ais").
	AIs []AIConfig `mapstructure:"ais"`
	// Git platform definitions (YAML key "git").
	Git []GitConfig `mapstructure:"git"`
	// Global auto-disable settings (YAML key "auto_disable"). The struct is
	// embedded, so its fields (Enabled, MaxFailures, ResetAfter) are also
	// reachable directly on Config.
	AutoDisableConfig `mapstructure:"auto_disable"`
}
|
|
|
|
// AutoDisableConfig holds the global auto-disable settings. AIConfig entries
// carry same-named fields that are documented as overriding these values.
// The defaults (enabled=true, max_failures=3, reset_after=30) are registered
// with viper in init.
type AutoDisableConfig struct {
	// Whether auto-disable is enabled.
	Enabled bool `mapstructure:"enabled"`
	// Maximum number of failures. Presumably the threshold before an AI entry
	// is disabled; the enforcing logic is not in this file — TODO confirm.
	MaxFailures int `mapstructure:"max_failures"`
	// Reset time, in minutes.
	ResetAfter int `mapstructure:"reset_after"`
}
|
|
|
|
// AIConfig describes one AI backend entry from the "ais" list of the config.
type AIConfig struct {
	// Name of this entry.
	Name string `mapstructure:"name"`
	// Backend type: "ollama" or "openai".
	Type string `mapstructure:"type"`
	// API key for the backend.
	APIKey string `mapstructure:"api_key"`
	// Base URL of the backend API. Note that the YAML key is "url", not
	// "api_base".
	APIBase string `mapstructure:"url"`
	// Model name to use.
	Model string `mapstructure:"model"`
	// System prompt.
	SystemMsg string `mapstructure:"system_msg"`
	// Sampling temperature.
	Temperature float64 `mapstructure:"temperature"`
	// Whether to use streaming responses.
	Stream bool `mapstructure:"stream"`
	// Relative weight that AIBalancer.Next uses for weighted random selection
	// among the entries sharing the highest Priority.
	Weight int `mapstructure:"weight"`
	// Priority; a larger number means higher priority. AIBalancer only picks
	// from entries with the highest priority present.
	Priority int `mapstructure:"priority"`
	// Whether this entry is enabled; NewAIBalancer ignores disabled entries.
	Enabled bool `mapstructure:"enabled"`
	// Per-entry auto-disable switch, documented as overriding the global
	// AutoDisableConfig.Enabled.
	// NOTE(review): this is a plain bool, so an absent key decodes as false and
	// cannot express "inherit the global setting" — confirm intended semantics.
	AutoDisable bool `mapstructure:"auto_disable"`
	// Per-entry override of the global max failure count. The pointer lets an
	// absent key decode as nil; presumably nil means "use the global value".
	MaxFailures *int `mapstructure:"max_failures"`
	// Per-entry override of the global reset time. Presumably in minutes like
	// AutoDisableConfig.ResetAfter and nil means "use the global value" —
	// TODO confirm.
	ResetAfter *int `mapstructure:"reset_after"`
}
|
|
|
|
// GitConfig describes one Git hosting platform from the "git" list of the
// config.
type GitConfig struct {
	Name            string `mapstructure:"name"`             // platform name
	Type            string `mapstructure:"type"`             // platform type
	Token           string `mapstructure:"token"`            // access token
	Secret          string `mapstructure:"webhook_secret"`   // webhook secret (YAML key "webhook_secret")
	APIBase         string `mapstructure:"api_base"`         // API base URL
	SignatureHeader string `mapstructure:"signature_header"` // name of the header carrying the webhook signature
	EventHeader     string `mapstructure:"event_header"`     // name of the header carrying the webhook event type
}
|
|
|
|
// AIBalancer chooses an AI backend from a fixed set of enabled AIConfig
// entries (load balancing). Build it with NewAIBalancer. It contains a
// sync.Mutex, so it must not be copied after first use; pass *AIBalancer.
type AIBalancer struct {
	// Enabled entries, sorted by Priority (highest first) in NewAIBalancer.
	ais []AIConfig
	// NOTE(review): set to 0 in NewAIBalancer but never read by Next; possibly
	// a leftover round-robin cursor — confirm before removing.
	current int
	// Serializes Next.
	mu sync.Mutex
}
|
|
|
|
func init() {
|
|
// 设置默认值
|
|
viper.SetDefault("port", 53321)
|
|
viper.SetDefault("admin_token", "token")
|
|
viper.SetDefault("ais", []interface{}{})
|
|
viper.SetDefault("git", []interface{}{})
|
|
viper.SetDefault("auto_disable", map[string]interface{}{
|
|
"enabled": true,
|
|
"max_failures": 3,
|
|
"reset_after": 30,
|
|
})
|
|
}
|
|
|
|
func Load() (*Config, error) {
|
|
var err error
|
|
once.Do(func() {
|
|
viper.SetConfigName("config")
|
|
viper.SetConfigType("yaml")
|
|
viper.AddConfigPath(".")
|
|
|
|
if err = viper.ReadInConfig(); err != nil {
|
|
return
|
|
}
|
|
|
|
cfg = &Config{}
|
|
if err = viper.Unmarshal(cfg); err != nil {
|
|
return
|
|
}
|
|
|
|
// 添加日志输出以便调试
|
|
log.Printf("已加载配置: %+v", cfg)
|
|
log.Printf("使用的配置文件: %s", viper.ConfigFileUsed())
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("加载配置失败: %w", err)
|
|
}
|
|
|
|
// 确保配置不为空
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("配置加载失败: 配置为空")
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func GetConfig() *Config {
|
|
if cfg == nil {
|
|
panic("配置未初始化")
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
// Save 保存配置到文件
|
|
func Save(newConfig *Config) error {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
// 创建备份
|
|
backupFile := "config.yaml.bak"
|
|
if _, err := os.Stat("config.yaml"); err == nil {
|
|
if err := os.Rename("config.yaml", backupFile); err != nil {
|
|
return fmt.Errorf("backup config failed: %w", err)
|
|
}
|
|
}
|
|
// 使用自定义的 YAML 编码器
|
|
file, err := os.Create("config.yaml")
|
|
if err != nil {
|
|
return fmt.Errorf("创建配置文件失败: %w", err)
|
|
}
|
|
yamlEncoder := yaml.NewEncoder(file)
|
|
yamlEncoder.SetIndent(2) // 设置缩进
|
|
|
|
// 将配置转换为 map 并保存
|
|
configMap := newConfig.ToMap()
|
|
if err := yamlEncoder.Encode(configMap); err != nil {
|
|
// 如果保存失败,恢复备份
|
|
if _, err := os.Stat(backupFile); err == nil {
|
|
os.Rename(backupFile, "config.yaml")
|
|
}
|
|
return fmt.Errorf("save config failed: %w", err)
|
|
}
|
|
|
|
// 更新内存中的配置
|
|
cfg = newConfig
|
|
|
|
return nil
|
|
}
|
|
|
|
// ToMap 将配置转换为 map
|
|
func (c *Config) ToMap() map[string]interface{} {
|
|
ais := make([]map[string]interface{}, len(c.AIs))
|
|
for i, ai := range c.AIs {
|
|
ais[i] = map[string]interface{}{
|
|
"name": ai.Name,
|
|
"type": ai.Type,
|
|
"api_key": ai.APIKey,
|
|
"url": ai.APIBase,
|
|
"model": ai.Model,
|
|
"system_msg": ai.SystemMsg,
|
|
"temperature": ai.Temperature,
|
|
"stream": ai.Stream,
|
|
"weight": ai.Weight,
|
|
"priority": ai.Priority,
|
|
"enabled": ai.Enabled,
|
|
"auto_disable": ai.AutoDisable,
|
|
"max_failures": ai.MaxFailures,
|
|
"reset_after": ai.ResetAfter,
|
|
}
|
|
}
|
|
|
|
// 修改 Git 平台配置的转换逻辑
|
|
platforms := make([]map[string]interface{}, len(c.Git))
|
|
for i, platform := range c.Git {
|
|
platforms[i] = map[string]interface{}{
|
|
"name": platform.Name,
|
|
"type": platform.Type,
|
|
"token": platform.Token,
|
|
"webhook_secret": platform.Secret,
|
|
"api_base": platform.APIBase,
|
|
"signature_header": platform.SignatureHeader,
|
|
"event_header": platform.EventHeader,
|
|
}
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"port": c.Port,
|
|
"admin_token": c.AdminToken,
|
|
"ais": ais,
|
|
"git": platforms,
|
|
"auto_disable": map[string]interface{}{
|
|
"enabled": c.AutoDisableConfig.Enabled,
|
|
"max_failures": c.AutoDisableConfig.MaxFailures,
|
|
"reset_after": c.AutoDisableConfig.ResetAfter,
|
|
},
|
|
}
|
|
}
|
|
|
|
// ToMapHtml 将配置转换为 map 页面展示
|
|
func (c *Config) ToMapHtml() map[string]interface{} {
|
|
ais := make([]map[string]interface{}, len(c.AIs))
|
|
for i, ai := range c.AIs {
|
|
ais[i] = map[string]interface{}{
|
|
"name": ai.Name,
|
|
"type": ai.Type,
|
|
"api_key": ai.APIKey,
|
|
"url": ai.APIBase,
|
|
"model": ai.Model,
|
|
"system_msg": ai.SystemMsg,
|
|
"temperature": ai.Temperature,
|
|
"stream": ai.Stream,
|
|
"weight": ai.Weight,
|
|
"priority": ai.Priority,
|
|
"enabled": ai.Enabled,
|
|
"auto_disable": ai.AutoDisable,
|
|
"max_failures": ai.MaxFailures,
|
|
"reset_after": ai.ResetAfter,
|
|
}
|
|
}
|
|
|
|
// 修改 Git 平台配置的转换逻辑
|
|
platforms := make([]map[string]interface{}, len(c.Git))
|
|
for i, platform := range c.Git {
|
|
platforms[i] = map[string]interface{}{
|
|
"name": platform.Name,
|
|
"type": platform.Type,
|
|
"token": platform.Token,
|
|
"webhook_secret": platform.Secret,
|
|
"api_base": platform.APIBase,
|
|
"signature_header": platform.SignatureHeader,
|
|
"event_header": platform.EventHeader,
|
|
}
|
|
}
|
|
|
|
return map[string]interface{}{
|
|
"ais": ais,
|
|
"git": platforms,
|
|
"auto_disable": map[string]interface{}{
|
|
"enabled": c.AutoDisableConfig.Enabled,
|
|
"max_failures": c.AutoDisableConfig.MaxFailures,
|
|
"reset_after": c.AutoDisableConfig.ResetAfter,
|
|
},
|
|
}
|
|
}
|
|
|
|
func LoadConfig(configFile string) error {
|
|
log.Printf("加载配置文件: file=%s", configFile)
|
|
|
|
viper.SetConfigFile(configFile)
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
log.Printf("读取配置文件失败: file=%s, error=%v", configFile, err)
|
|
return fmt.Errorf("读取配置文件失败: %w", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if err := viper.Unmarshal(&cfg); err != nil {
|
|
log.Printf("解析配置文件失败: file=%s, error=%v", configFile, err)
|
|
return fmt.Errorf("解析配置文件失败: %w", err)
|
|
}
|
|
|
|
log.Printf("配置文件加载成功: file=%s", configFile)
|
|
return nil
|
|
}
|
|
|
|
// 添加新的方法
|
|
func NewAIBalancer(ais []AIConfig) *AIBalancer {
|
|
// 过滤出已启用的 AI
|
|
enabledAIs := make([]AIConfig, 0)
|
|
for _, ai := range ais {
|
|
if ai.Enabled {
|
|
enabledAIs = append(enabledAIs, ai)
|
|
}
|
|
}
|
|
|
|
// 按优先级排序(从高到低)
|
|
sort.Slice(enabledAIs, func(i, j int) bool {
|
|
return enabledAIs[i].Priority > enabledAIs[j].Priority
|
|
})
|
|
|
|
return &AIBalancer{
|
|
ais: enabledAIs,
|
|
current: 0,
|
|
}
|
|
}
|
|
|
|
// 获取下一个可用的 AI 配置
|
|
func (b *AIBalancer) Next() (*AIConfig, error) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if len(b.ais) == 0 {
|
|
return nil, fmt.Errorf("没有可用的 AI 配置")
|
|
}
|
|
|
|
// 获取最高优先级
|
|
highestPriority := b.ais[0].Priority
|
|
|
|
// 收集具有最高优先级的 AI
|
|
highPriorityAIs := make([]AIConfig, 0)
|
|
totalWeight := 0
|
|
for _, ai := range b.ais {
|
|
if ai.Priority == highestPriority {
|
|
highPriorityAIs = append(highPriorityAIs, ai)
|
|
totalWeight += ai.Weight
|
|
}
|
|
}
|
|
|
|
// 在最高优先级的 AI 中按权重选择
|
|
target := rand.Intn(totalWeight)
|
|
currentWeight := 0
|
|
for i, ai := range highPriorityAIs {
|
|
currentWeight += ai.Weight
|
|
if currentWeight > target {
|
|
return &highPriorityAIs[i], nil
|
|
}
|
|
}
|
|
|
|
// 如果没有选中任何一个(不应该发生),返回第一个最高优先级的 AI
|
|
return &highPriorityAIs[0], nil
|
|
}
|
|
|
|
// ResetAIBalancer 重置 AI 负载均衡器
|
|
func ResetAIBalancer(ais []AIConfig) {
|
|
balancerMu.Lock()
|
|
defer balancerMu.Unlock()
|
|
aiBalancer = NewAIBalancer(ais)
|
|
}
|
|
|
|
// GetAIBalancer 获取当前的 AI 负载均衡器
|
|
func GetAIBalancer() *AIBalancer {
|
|
balancerMu.RLock()
|
|
defer balancerMu.RUnlock()
|
|
return aiBalancer
|
|
}
|