init
This commit is contained in:
338
config/config.go
Normal file
338
config/config.go
Normal file
@ -0,0 +1,338 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
cfg *Config
|
||||
once sync.Once
|
||||
mu sync.RWMutex
|
||||
|
||||
aiBalancer *AIBalancer
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port int `mapstructure:"port"`
|
||||
AdminToken string `mapstructure:"admin_token"` // 管理令牌
|
||||
AIs []AIConfig `mapstructure:"ais"`
|
||||
Git []GitConfig `mapstructure:"git"`
|
||||
AutoDisableConfig `mapstructure:"auto_disable"` // 全局自动禁用配置
|
||||
}
|
||||
|
||||
// AutoDisableConfig 自动禁用配置
|
||||
type AutoDisableConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"` // 是否启用自动禁用
|
||||
MaxFailures int `mapstructure:"max_failures"` // 最大失败次数
|
||||
ResetAfter int `mapstructure:"reset_after"` // 重置时间(分钟)
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Type string `mapstructure:"type"` // "ollama" 或 "openai"
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
APIBase string `mapstructure:"url"`
|
||||
Model string `mapstructure:"model"`
|
||||
SystemMsg string `mapstructure:"system_msg"` // 系统提示词
|
||||
Temperature float64 `mapstructure:"temperature"` // 温度
|
||||
Stream bool `mapstructure:"stream"` // 是否使用流式响应
|
||||
Weight int `mapstructure:"weight"`
|
||||
Priority int `mapstructure:"priority"` // 优先级,数字越大优先级越高
|
||||
Enabled bool `mapstructure:"enabled"` // 是否启用
|
||||
AutoDisable bool `mapstructure:"auto_disable"` // 是否启用自动禁用(覆盖全局配置)
|
||||
MaxFailures *int `mapstructure:"max_failures"` // 最大失败次数(覆盖全局配置)
|
||||
ResetAfter *int `mapstructure:"reset_after"` // 重置时间(覆盖全局配置)
|
||||
}
|
||||
|
||||
type GitConfig struct {
|
||||
Name string `mapstructure:"name"` // 平台名称
|
||||
Type string `mapstructure:"type"` // 平台类型
|
||||
Token string `mapstructure:"token"` // 访问令牌
|
||||
Secret string `mapstructure:"webhook_secret"` // webhook 密钥
|
||||
APIBase string `mapstructure:"api_base"` // API 基础 URL
|
||||
SignatureHeader string `mapstructure:"signature_header"` // webhook 签名的 header 名称
|
||||
EventHeader string `mapstructure:"event_header"` // webhook 事件类型的 header 名称
|
||||
}
|
||||
|
||||
// 添加一个新的结构体用于负载均衡
|
||||
type AIBalancer struct {
|
||||
ais []AIConfig
|
||||
current int
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 设置默认值
|
||||
viper.SetDefault("port", 53321)
|
||||
viper.SetDefault("admin_token", "token")
|
||||
viper.SetDefault("ais", []interface{}{})
|
||||
viper.SetDefault("git", []interface{}{})
|
||||
viper.SetDefault("auto_disable", map[string]interface{}{
|
||||
"enabled": true,
|
||||
"max_failures": 3,
|
||||
"reset_after": 30,
|
||||
})
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
if err = viper.ReadInConfig(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg = &Config{}
|
||||
if err = viper.Unmarshal(cfg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加日志输出以便调试
|
||||
log.Printf("已加载配置: %+v", cfg)
|
||||
log.Printf("使用的配置文件: %s", viper.ConfigFileUsed())
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 确保配置不为空
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("配置加载失败: 配置为空")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func GetConfig() *Config {
|
||||
if cfg == nil {
|
||||
panic("配置未初始化")
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Save 保存配置到文件
|
||||
func Save(newConfig *Config) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// 创建备份
|
||||
backupFile := "config.yaml.bak"
|
||||
if _, err := os.Stat("config.yaml"); err == nil {
|
||||
if err := os.Rename("config.yaml", backupFile); err != nil {
|
||||
return fmt.Errorf("backup config failed: %w", err)
|
||||
}
|
||||
}
|
||||
// 使用自定义的 YAML 编码器
|
||||
file, err := os.Create("config.yaml")
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建配置文件失败: %w", err)
|
||||
}
|
||||
yamlEncoder := yaml.NewEncoder(file)
|
||||
yamlEncoder.SetIndent(2) // 设置缩进
|
||||
|
||||
// 将配置转换为 map 并保存
|
||||
configMap := newConfig.ToMap()
|
||||
if err := yamlEncoder.Encode(configMap); err != nil {
|
||||
// 如果保存失败,恢复备份
|
||||
if _, err := os.Stat(backupFile); err == nil {
|
||||
os.Rename(backupFile, "config.yaml")
|
||||
}
|
||||
return fmt.Errorf("save config failed: %w", err)
|
||||
}
|
||||
|
||||
// 更新内存中的配置
|
||||
cfg = newConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToMap 将配置转换为 map
|
||||
func (c *Config) ToMap() map[string]interface{} {
|
||||
ais := make([]map[string]interface{}, len(c.AIs))
|
||||
for i, ai := range c.AIs {
|
||||
ais[i] = map[string]interface{}{
|
||||
"name": ai.Name,
|
||||
"type": ai.Type,
|
||||
"api_key": ai.APIKey,
|
||||
"url": ai.APIBase,
|
||||
"model": ai.Model,
|
||||
"system_msg": ai.SystemMsg,
|
||||
"temperature": ai.Temperature,
|
||||
"stream": ai.Stream,
|
||||
"weight": ai.Weight,
|
||||
"priority": ai.Priority,
|
||||
"enabled": ai.Enabled,
|
||||
"auto_disable": ai.AutoDisable,
|
||||
"max_failures": ai.MaxFailures,
|
||||
"reset_after": ai.ResetAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// 修改 Git 平台配置的转换逻辑
|
||||
platforms := make([]map[string]interface{}, len(c.Git))
|
||||
for i, platform := range c.Git {
|
||||
platforms[i] = map[string]interface{}{
|
||||
"name": platform.Name,
|
||||
"type": platform.Type,
|
||||
"token": platform.Token,
|
||||
"webhook_secret": platform.Secret,
|
||||
"api_base": platform.APIBase,
|
||||
"signature_header": platform.SignatureHeader,
|
||||
"event_header": platform.EventHeader,
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"port": c.Port,
|
||||
"admin_token": c.AdminToken,
|
||||
"ais": ais,
|
||||
"git": platforms,
|
||||
"auto_disable": map[string]interface{}{
|
||||
"enabled": c.AutoDisableConfig.Enabled,
|
||||
"max_failures": c.AutoDisableConfig.MaxFailures,
|
||||
"reset_after": c.AutoDisableConfig.ResetAfter,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToMapHtml 将配置转换为 map 页面展示
|
||||
func (c *Config) ToMapHtml() map[string]interface{} {
|
||||
ais := make([]map[string]interface{}, len(c.AIs))
|
||||
for i, ai := range c.AIs {
|
||||
ais[i] = map[string]interface{}{
|
||||
"name": ai.Name,
|
||||
"type": ai.Type,
|
||||
"api_key": ai.APIKey,
|
||||
"url": ai.APIBase,
|
||||
"model": ai.Model,
|
||||
"system_msg": ai.SystemMsg,
|
||||
"temperature": ai.Temperature,
|
||||
"stream": ai.Stream,
|
||||
"weight": ai.Weight,
|
||||
"priority": ai.Priority,
|
||||
"enabled": ai.Enabled,
|
||||
"auto_disable": ai.AutoDisable,
|
||||
"max_failures": ai.MaxFailures,
|
||||
"reset_after": ai.ResetAfter,
|
||||
}
|
||||
}
|
||||
|
||||
// 修改 Git 平台配置的转换逻辑
|
||||
platforms := make([]map[string]interface{}, len(c.Git))
|
||||
for i, platform := range c.Git {
|
||||
platforms[i] = map[string]interface{}{
|
||||
"name": platform.Name,
|
||||
"type": platform.Type,
|
||||
"token": platform.Token,
|
||||
"webhook_secret": platform.Secret,
|
||||
"api_base": platform.APIBase,
|
||||
"signature_header": platform.SignatureHeader,
|
||||
"event_header": platform.EventHeader,
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"ais": ais,
|
||||
"git": platforms,
|
||||
"auto_disable": map[string]interface{}{
|
||||
"enabled": c.AutoDisableConfig.Enabled,
|
||||
"max_failures": c.AutoDisableConfig.MaxFailures,
|
||||
"reset_after": c.AutoDisableConfig.ResetAfter,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(configFile string) error {
|
||||
log.Printf("加载配置文件: file=%s", configFile)
|
||||
|
||||
viper.SetConfigFile(configFile)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
log.Printf("读取配置文件失败: file=%s, error=%v", configFile, err)
|
||||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err := viper.Unmarshal(&cfg); err != nil {
|
||||
log.Printf("解析配置文件失败: file=%s, error=%v", configFile, err)
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("配置文件加载成功: file=%s", configFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加新的方法
|
||||
func NewAIBalancer(ais []AIConfig) *AIBalancer {
|
||||
// 过滤出已启用的 AI
|
||||
enabledAIs := make([]AIConfig, 0)
|
||||
for _, ai := range ais {
|
||||
if ai.Enabled {
|
||||
enabledAIs = append(enabledAIs, ai)
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级排序(从高到低)
|
||||
sort.Slice(enabledAIs, func(i, j int) bool {
|
||||
return enabledAIs[i].Priority > enabledAIs[j].Priority
|
||||
})
|
||||
|
||||
return &AIBalancer{
|
||||
ais: enabledAIs,
|
||||
current: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取下一个可用的 AI 配置
|
||||
func (b *AIBalancer) Next() (*AIConfig, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if len(b.ais) == 0 {
|
||||
return nil, fmt.Errorf("没有可用的 AI 配置")
|
||||
}
|
||||
|
||||
// 获取最高优先级
|
||||
highestPriority := b.ais[0].Priority
|
||||
|
||||
// 收集具有最高优先级的 AI
|
||||
highPriorityAIs := make([]AIConfig, 0)
|
||||
totalWeight := 0
|
||||
for _, ai := range b.ais {
|
||||
if ai.Priority == highestPriority {
|
||||
highPriorityAIs = append(highPriorityAIs, ai)
|
||||
totalWeight += ai.Weight
|
||||
}
|
||||
}
|
||||
|
||||
// 在最高优先级的 AI 中按权重选择
|
||||
target := rand.Intn(totalWeight)
|
||||
currentWeight := 0
|
||||
for i, ai := range highPriorityAIs {
|
||||
currentWeight += ai.Weight
|
||||
if currentWeight > target {
|
||||
return &highPriorityAIs[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有选中任何一个(不应该发生),返回第一个最高优先级的 AI
|
||||
return &highPriorityAIs[0], nil
|
||||
}
|
||||
|
||||
// 添加获取 AIBalancer 的方法
|
||||
func GetAIBalancer() *AIBalancer {
|
||||
return aiBalancer
|
||||
}
|
Reference in New Issue
Block a user