上传文件至 /
This commit is contained in:
20
config.ini
Normal file
20
config.ini
Normal file
@ -0,0 +1,20 @@
|
||||
[API]
|
||||
gemini_api_base_url = https://g.shatang.me/v1beta
|
||||
gemini_api_key =
|
||||
default_gemini_model = models/gemini-2.0-flash
|
||||
memory_update_model = models/gemini-2.0-flash
|
||||
|
||||
[Application]
|
||||
context_window_size = 5
|
||||
memory_retention_turns = 10
|
||||
max_short_term_events = 15
|
||||
features_dir = data/features
|
||||
memories_dir = data/memories
|
||||
chat_logs_dir = data/chat_logs
|
||||
app_log_file_path = logs/app.log
|
||||
|
||||
[Session]
|
||||
current_role_id = default_role
|
||||
current_memory_id = default_memory
|
||||
current_chat_log_id = default_role_default_memory_ce76c8a8-4072-44f4-8f1b-d942166b62e1
|
||||
|
168
config.py
Normal file
168
config.py
Normal file
@ -0,0 +1,168 @@
|
||||
# config.py
|
||||
|
||||
import configparser
|
||||
import os
|
||||
import logging
|
||||
|
||||
# 获取根日志器,确保日志配置已在utils.py中完成
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_FILE = "config.ini"
|
||||
|
||||
# 定义默认配置,如果config.ini不存在或缺失,将使用这些值
|
||||
DEFAULT_CONFIG = {
|
||||
"API": {
|
||||
"GEMINI_API_BASE_URL": "https://generativelanguage.googleapis.com/v1beta",
|
||||
"GEMINI_API_KEY": "YOUR_GEMINI_API_KEY_HERE", # 默认值,强烈建议用户手动修改
|
||||
"DEFAULT_GEMINI_MODEL": "gemini-pro",
|
||||
"MEMORY_UPDATE_MODEL": "gemini-pro",
|
||||
},
|
||||
"Application": {
|
||||
"CONTEXT_WINDOW_SIZE": "6",
|
||||
"MEMORY_RETENTION_TURNS": "18",
|
||||
"MAX_SHORT_TERM_EVENTS": "20",
|
||||
"FEATURES_DIR": "data/features",
|
||||
"MEMORIES_DIR": "data/memories",
|
||||
"CHAT_LOGS_DIR": "data/chat_logs",
|
||||
"APP_LOG_FILE_PATH": "logs/app.log",
|
||||
},
|
||||
"Session": {
|
||||
"CURRENT_ROLE_ID": "default_role",
|
||||
"CURRENT_MEMORY_ID": "default_memory",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_config():
|
||||
"""
|
||||
从 config.ini 文件加载配置。
|
||||
如果文件不存在或某些配置项缺失,则使用默认值并尝试创建/更新文件。
|
||||
:return: 包含所有配置的字典。
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
# 检查config.ini是否存在,如果不存在则创建并写入默认配置
|
||||
if not os.path.exists(CONFIG_FILE):
|
||||
logger.info(f"config.ini 文件不存在,正在创建默认配置文件:{CONFIG_FILE}")
|
||||
for section, options in DEFAULT_CONFIG.items():
|
||||
config.add_section(section)
|
||||
for key, value in options.items():
|
||||
config.set(section, key, value)
|
||||
try:
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
logger.info("默认 config.ini 文件创建成功。")
|
||||
except IOError as e:
|
||||
logger.error(f"无法写入 config.ini 文件: {e}")
|
||||
# 即使写入失败,也继续使用内存中的默认配置
|
||||
else:
|
||||
logger.info(f"从 {CONFIG_FILE} 加载配置...")
|
||||
|
||||
# 读取配置
|
||||
try:
|
||||
config.read(CONFIG_FILE, encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"读取 config.ini 文件时发生错误: {e},将使用默认配置。")
|
||||
# 如果读取失败,清空 configparser 对象,让后续逻辑使用默认值
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
# 将配置转换为字典格式,并确保所有默认值都存在
|
||||
app_config = {}
|
||||
for section, options in DEFAULT_CONFIG.items():
|
||||
app_config[section] = {}
|
||||
for key, default_value in options.items():
|
||||
# 使用 get() 方法获取值,如果不存在则使用默认值
|
||||
# 对于整型或布尔型配置,需要手动转换
|
||||
value = config.get(section, key, fallback=default_value)
|
||||
|
||||
# 尝试将特定配置项转换为数字类型
|
||||
if section == "Application" and key in [
|
||||
"CONTEXT_WINDOW_SIZE",
|
||||
"MEMORY_RETENTION_TURNS",
|
||||
"MAX_SHORT_TERM_EVENTS",
|
||||
]:
|
||||
try:
|
||||
app_config[section][key] = int(value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"配置项 {section}.{key} 的值 '{value}' 不是有效数字,使用默认值 {default_value}。"
|
||||
)
|
||||
app_config[section][key] = int(default_value)
|
||||
elif section == "API" and key == "GEMINI_API_KEY":
|
||||
# 优先从环境变量读取 GEMINI_API_KEY
|
||||
env_key = os.getenv("GEMINI_API_KEY")
|
||||
if env_key:
|
||||
app_config[section][key] = env_key
|
||||
logger.info("已从环境变量加载 GEMINI_API_KEY。")
|
||||
else:
|
||||
app_config[section][key] = value
|
||||
else:
|
||||
app_config[section][key] = value
|
||||
|
||||
# 如果存在config.ini但缺少某些默认值,更新它
|
||||
try:
|
||||
updated = False
|
||||
for section, options in DEFAULT_CONFIG.items():
|
||||
if not config.has_section(section):
|
||||
config.add_section(section)
|
||||
updated = True
|
||||
for key, default_value in options.items():
|
||||
if not config.has_option(section, key):
|
||||
config.set(section, key, str(default_value)) # 确保写入的是字符串
|
||||
updated = True
|
||||
if updated:
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
logger.info("config.ini 文件已更新,补齐缺失配置项。")
|
||||
except IOError as e:
|
||||
logger.error(f"无法更新 config.ini 文件以补齐缺失配置项: {e}")
|
||||
|
||||
logger.info("配置加载完成。")
|
||||
return app_config
|
||||
|
||||
|
||||
def save_config(new_config):
|
||||
"""
|
||||
将给定的配置字典保存到 config.ini 文件。
|
||||
它会读取现有配置,然后只更新 new_config 中提供的部分。
|
||||
:param new_config: 包含要保存的所有配置的字典。
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
# 首先读取现有配置,以保留未在 new_config 中提供的项
|
||||
if os.path.exists(CONFIG_FILE):
|
||||
try:
|
||||
config.read(CONFIG_FILE, encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"读取现有 config.ini 文件时发生错误: {e},将只保存新配置。")
|
||||
|
||||
for section, options in new_config.items():
|
||||
if not config.has_section(section):
|
||||
config.add_section(section)
|
||||
for key, value in options.items():
|
||||
# 确保将值转换为字符串,因为 configparser 存储的是字符串
|
||||
config.set(section, key, str(value))
|
||||
|
||||
try:
|
||||
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
|
||||
config.write(f)
|
||||
logger.info(f"配置已成功保存到 {CONFIG_FILE}")
|
||||
except IOError as e:
|
||||
logger.error(f"保存配置到 {CONFIG_FILE} 失败: {e}")
|
||||
|
||||
|
||||
# 全局变量,存储当前加载的配置
|
||||
app_config = load_config()
|
||||
|
||||
|
||||
def get_config():
|
||||
"""获取当前加载的配置"""
|
||||
return app_config
|
||||
|
||||
|
||||
def set_config(new_config):
|
||||
"""更新内存中的配置并保存到文件"""
|
||||
global app_config
|
||||
app_config = new_config
|
||||
save_config(app_config)
|
||||
logger.info("内存配置已更新并保存。")
|
432
file_manager.py
Normal file
432
file_manager.py
Normal file
@ -0,0 +1,432 @@
|
||||
# file_manager.py
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import shutil # 用于删除目录及其内容
|
||||
import logging
|
||||
from config import get_config # 导入config模块的获取配置函数
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- 辅助函数:JSON 和 JSONL 文件操作 ---
|
||||
|
||||
|
||||
def _load_json_file(file_path, default_content=None):
|
||||
"""
|
||||
加载JSON文件。如果文件不存在,则创建并写入默认内容(如果提供)。
|
||||
:param file_path: JSON文件路径。
|
||||
:param default_content: 文件不存在时要写入的默认Python对象。
|
||||
:return: 文件内容(Python对象)或None(如果加载失败)。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
|
||||
if default_content is not None:
|
||||
logger.info(f"文件不存在,创建并写入默认内容到:{file_path}")
|
||||
_save_json_file(file_path, default_content)
|
||||
else:
|
||||
logger.warning(f"尝试加载的JSON文件不存在且未提供默认内容:{file_path}")
|
||||
return {} # 返回空字典,以提高健壮性
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON文件解码失败:{file_path} - {e}")
|
||||
return None
|
||||
except IOError as e:
|
||||
logger.error(f"读取JSON文件失败:{file_path} - {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _save_json_file(file_path, data):
|
||||
"""
|
||||
保存Python对象到JSON文件。
|
||||
:param file_path: JSON文件路径。
|
||||
:param data: 要保存的Python对象。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"JSON文件已保存:{file_path}")
|
||||
return True
|
||||
except IOError as e:
|
||||
logger.error(f"保存JSON文件失败:{file_path} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _append_jsonl_file(file_path, data):
|
||||
"""
|
||||
追加一条JSON记录到JSONL文件。
|
||||
:param file_path: JSONL文件路径。
|
||||
:param data: 要追加的Python对象。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
|
||||
try:
|
||||
with open(file_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(data, ensure_ascii=False) + "\n")
|
||||
logger.debug(f"JSONL记录已追加到:{file_path}")
|
||||
return True
|
||||
except IOError as e:
|
||||
logger.error(f"追加JSONL文件失败:{file_path} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _read_jsonl_file(file_path, limit=None):
|
||||
"""
|
||||
读取JSONL文件的所有记录。
|
||||
:param file_path: JSONL文件路径。
|
||||
:param limit: 可选,限制读取的最后N条记录。
|
||||
:return: 包含所有记录的列表。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
return []
|
||||
|
||||
records = []
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
if limit:
|
||||
lines = lines[-limit:] # 只读取最后N行
|
||||
for line in lines:
|
||||
try:
|
||||
records.append(json.loads(line.strip()))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"跳过JSONL文件中的损坏行:{file_path} - {line.strip()} - {e}"
|
||||
)
|
||||
return records
|
||||
except IOError as e:
|
||||
logger.error(f"读取JSONL文件失败:{file_path} - {e}")
|
||||
return []
|
||||
|
||||
|
||||
# --- 默认文件内容 ---
|
||||
|
||||
# 默认特征文件内容
|
||||
DEFAULT_FEATURE_CONTENT = {
|
||||
"角色名称": "通用AI助手",
|
||||
"角色描述": "你是一个有帮助的AI助手,旨在提供清晰、准确、友好的回答。你没有预设的特定角色或个性,能够灵活适应各种对话主题。",
|
||||
"心理特征": {
|
||||
"核心哲学与世界观": {"描述": "以提供信息、解决问题为核心,保持中立客观。"},
|
||||
"价值观与道德观": {"描述": "遵循伦理原则,避免偏见,确保信息准确。"},
|
||||
"决策风格": {"描述": "基于逻辑和可用信息进行决策。"},
|
||||
"情绪反应与表达": {"描述": "保持平静、专业、不带个人情感。"},
|
||||
"人际互动与关系处理": {"描述": "以帮助用户为目标,保持礼貌和尊重。"},
|
||||
"动机与目标": {"描述": "提供有价值的信息和解决方案。"},
|
||||
"自我认知": {"描述": "认知自己是AI,不具备人类情感和意识。"},
|
||||
"应对机制": {"描述": "遇到未知问题时,会请求更多信息或承认能力限制。"},
|
||||
},
|
||||
"语言特征": {
|
||||
"词汇与措辞": {"描述": "清晰、简洁、准确,避免模糊不清的表达。"},
|
||||
"句式结构与复杂度": {"描述": "常用清晰直接的句式,逻辑性强。"},
|
||||
"语气与风格": {"描述": "专业、友善、乐于助人。"},
|
||||
"修辞手法与模式": {"描述": "多使用直接陈述、解释说明。"},
|
||||
"互动模式": {"描述": "响应式,根据用户输入提供信息或引导对话。"},
|
||||
},
|
||||
"重要人际关系": [],
|
||||
"角色弧线总结": "作为一个持续学习和改进的AI,不断提升服务质量。",
|
||||
}
|
||||
|
||||
# 默认记忆文件内容
|
||||
DEFAULT_MEMORY_CONTENT = {
|
||||
"long_term_facts": ["此记忆集为默认记忆,未包含特定用户或AI的长期事实。"],
|
||||
"short_term_events": [],
|
||||
}
|
||||
|
||||
# --- 路径管理函数 ---
|
||||
|
||||
|
||||
def _get_dir_path(base_dir_config_key):
|
||||
"""根据配置获取目录路径"""
|
||||
config = get_config()
|
||||
return config["Application"][base_dir_config_key]
|
||||
|
||||
|
||||
def get_features_dir():
|
||||
"""获取特征文件根目录"""
|
||||
return _get_dir_path("FEATURES_DIR")
|
||||
|
||||
|
||||
def get_memories_dir():
|
||||
"""获取记忆文件根目录"""
|
||||
return _get_dir_path("MEMORIES_DIR")
|
||||
|
||||
|
||||
def get_chat_logs_dir():
|
||||
"""获取聊天记录文件根目录"""
|
||||
return _get_dir_path("CHAT_LOGS_DIR")
|
||||
|
||||
|
||||
def get_role_features_file_path(role_id):
|
||||
"""获取指定角色ID的特征文件路径"""
|
||||
return os.path.join(get_features_dir(), role_id, "features.json")
|
||||
|
||||
|
||||
def get_memory_data_file_path(memory_id):
|
||||
"""获取指定记忆ID的记忆数据文件路径"""
|
||||
return os.path.join(get_memories_dir(), memory_id, "memory.json")
|
||||
|
||||
|
||||
def get_chat_log_file_path(chat_log_id):
|
||||
"""获取指定聊天记录ID的聊天记录文件路径"""
|
||||
# 聊天记录ID通常与角色ID和记忆ID关联,或者直接就是session ID
|
||||
# 这里简化为直接以chat_log_id为文件名
|
||||
return os.path.join(get_chat_logs_dir(), f"{chat_log_id}.jsonl")
|
||||
|
||||
|
||||
# --- 角色和记忆管理 ---
|
||||
|
||||
|
||||
def list_roles():
|
||||
"""
|
||||
列出所有可用的角色ID和名称。
|
||||
角色名称从其features.json中获取。
|
||||
:return: 列表,每个元素为 {"id": "role_id", "name": "角色名称"}
|
||||
"""
|
||||
roles = []
|
||||
features_dir = get_features_dir()
|
||||
if not os.path.exists(features_dir):
|
||||
os.makedirs(features_dir) # 确保目录存在
|
||||
return []
|
||||
|
||||
for role_id in os.listdir(features_dir):
|
||||
role_path = os.path.join(features_dir, role_id)
|
||||
features_file = get_role_features_file_path(role_id)
|
||||
if os.path.isdir(role_path) and os.path.exists(features_file):
|
||||
features = _load_json_file(features_file)
|
||||
if features:
|
||||
roles.append({"id": role_id, "name": features.get("角色名称", role_id)})
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法加载角色 {role_id} 的 features.json,可能文件损坏。"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"目录 {role_path} 或文件 {features_file} 不完整,跳过。")
|
||||
return roles
|
||||
|
||||
|
||||
def create_role(role_id, role_name=None):
|
||||
"""
|
||||
创建一个新的角色。
|
||||
:param role_id: 新角色的唯一ID。
|
||||
:param role_name: 新角色的名称,如果未提供,则默认为role_id。
|
||||
:return: True如果成功,False如果失败或ID已存在。
|
||||
"""
|
||||
role_path = os.path.join(get_features_dir(), role_id)
|
||||
if os.path.exists(role_path):
|
||||
logger.warning(f"角色ID '{role_id}' 已存在,无法创建。")
|
||||
return False
|
||||
|
||||
os.makedirs(role_path)
|
||||
features_file = get_role_features_file_path(role_id)
|
||||
|
||||
default_content = DEFAULT_FEATURE_CONTENT.copy()
|
||||
default_content["角色名称"] = role_name if role_name else role_id
|
||||
|
||||
success = _save_json_file(features_file, default_content)
|
||||
if success:
|
||||
logger.info(f"角色 '{role_id}' 创建成功。")
|
||||
return success
|
||||
|
||||
|
||||
def delete_role(role_id):
|
||||
"""
|
||||
删除一个角色及其所有相关文件。
|
||||
:param role_id: 要删除的角色ID。
|
||||
:return: True如果成功,False如果失败。
|
||||
"""
|
||||
role_path = os.path.join(get_features_dir(), role_id)
|
||||
if not os.path.exists(role_path):
|
||||
logger.warning(f"角色ID '{role_id}' 不存在,无法删除。")
|
||||
return False
|
||||
|
||||
try:
|
||||
shutil.rmtree(role_path)
|
||||
logger.info(f"角色 '{role_id}' 已成功删除。")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"删除角色 '{role_id}' 失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def role_exists(role_id):
|
||||
"""
|
||||
检查指定角色ID的角色是否存在。
|
||||
:param role_id: 角色ID。
|
||||
:return: True如果存在,False如果不存在。
|
||||
"""
|
||||
return os.path.exists(get_role_features_file_path(role_id))
|
||||
|
||||
|
||||
def list_memories():
|
||||
"""
|
||||
列出所有可用的记忆ID和名称。
|
||||
记忆名称从其memory.json中获取(如果模型能生成name字段,否则使用ID)。
|
||||
:return: 列表,每个元素为 {"id": "memory_id", "name": "记忆名称"}
|
||||
"""
|
||||
memories = []
|
||||
memories_dir = get_memories_dir()
|
||||
if not os.path.exists(memories_dir):
|
||||
os.makedirs(memories_dir) # 确保目录存在
|
||||
return []
|
||||
|
||||
for memory_id in os.listdir(memories_dir):
|
||||
memory_path = os.path.join(memories_dir, memory_id)
|
||||
memory_file = get_memory_data_file_path(memory_id)
|
||||
if os.path.isdir(memory_path) and os.path.exists(memory_file):
|
||||
memory_data = _load_json_file(memory_file)
|
||||
# 确保 memory_data 不是 None 且不是空字典
|
||||
if memory_data is not None and memory_data != {}:
|
||||
# 假设 memory.json 中可能有一个 'name' 字段来标识记忆集
|
||||
memories.append(
|
||||
{"id": memory_id, "name": memory_data.get("name", memory_id)}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"无法加载记忆 {memory_id} 的 memory.json,可能文件损或为空,或内容为空。"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"目录 {memory_path} 或文件 {memory_file} 不完整,跳过。")
|
||||
return memories
|
||||
|
||||
|
||||
def create_memory(memory_id, memory_name=None):
|
||||
"""
|
||||
创建一个新的记忆集。
|
||||
:param memory_id: 新记忆集的唯一ID。
|
||||
:param memory_name: 新记忆集的名称,如果未提供,则默认为memory_id。
|
||||
:return: True如果成功,False如果失败或ID已存在。
|
||||
"""
|
||||
memory_path = os.path.join(get_memories_dir(), memory_id)
|
||||
if os.path.exists(memory_path):
|
||||
logger.warning(f"记忆ID '{memory_id}' 已存在,无法创建。")
|
||||
return False
|
||||
|
||||
os.makedirs(memory_path)
|
||||
memory_file = get_memory_data_file_path(memory_id)
|
||||
|
||||
default_content = DEFAULT_MEMORY_CONTENT.copy()
|
||||
if memory_name:
|
||||
default_content["name"] = memory_name # 记忆集本身可以有一个名称
|
||||
# 更新 long_term_facts 中的默认描述,包含记忆名称
|
||||
if default_content["long_term_facts"]:
|
||||
default_content["long_term_facts"][
|
||||
0
|
||||
] = f"此记忆集名为 '{memory_name}',目前未包含特定用户或AI的长期事实。"
|
||||
else:
|
||||
default_content["long_term_facts"].append(
|
||||
f"此记忆集名为 '{memory_name}',目前未包含特定用户或AI的长期事实。"
|
||||
)
|
||||
|
||||
success = _save_json_file(memory_file, default_content)
|
||||
if success:
|
||||
logger.info(f"记忆集 '{memory_id}' 创建成功。")
|
||||
return success
|
||||
|
||||
|
||||
def delete_memory(memory_id):
|
||||
"""
|
||||
删除一个记忆集及其所有相关文件。
|
||||
:param memory_id: 要删除的记忆ID。
|
||||
:return: True如果成功,False如果失败。
|
||||
"""
|
||||
memory_path = os.path.join(get_memories_dir(), memory_id)
|
||||
if not os.path.exists(memory_path):
|
||||
logger.warning(f"记忆ID '{memory_id}' 不存在,无法删除。")
|
||||
return False
|
||||
|
||||
try:
|
||||
shutil.rmtree(memory_path)
|
||||
logger.info(f"记忆集 '{memory_id}' 已成功删除。")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"删除记忆集 '{memory_id}' 失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def memory_exists(memory_id):
|
||||
"""
|
||||
检查指定记忆ID的记忆集是否存在。
|
||||
:param memory_id: 记忆ID。
|
||||
:return: True如果存在,False如果不存在。
|
||||
"""
|
||||
return os.path.exists(get_memory_data_file_path(memory_id))
|
||||
|
||||
|
||||
# --- 加载/保存当前活动会话的文件 ---
|
||||
|
||||
|
||||
def load_active_features(role_id):
|
||||
"""加载当前激活角色的特征数据"""
|
||||
file_path = get_role_features_file_path(role_id)
|
||||
return _load_json_file(file_path, DEFAULT_FEATURE_CONTENT.copy())
|
||||
|
||||
|
||||
def save_active_features(role_id, features_data):
|
||||
"""保存当前激活角色的特征数据"""
|
||||
file_path = get_role_features_file_path(role_id)
|
||||
return _save_json_file(file_path, features_data)
|
||||
|
||||
|
||||
def load_active_memory(memory_id):
|
||||
"""加载当前激活记忆集的记忆数据"""
|
||||
file_path = get_memory_data_file_path(memory_id)
|
||||
return _load_json_file(file_path, DEFAULT_MEMORY_CONTENT.copy())
|
||||
|
||||
|
||||
def save_active_memory(memory_id, memory_data):
|
||||
"""保存当前激活记忆集的记忆数据"""
|
||||
file_path = get_memory_data_file_path(memory_id)
|
||||
return _save_json_file(file_path, memory_data)
|
||||
|
||||
|
||||
def append_chat_log(chat_log_id, entry):
|
||||
"""向当前激活聊天记录文件追加条目"""
|
||||
file_path = get_chat_log_file_path(chat_log_id)
|
||||
return _append_jsonl_file(file_path, entry)
|
||||
|
||||
|
||||
def read_chat_log(chat_log_id, limit=None):
|
||||
"""读取当前激活聊天记录的所有条目"""
|
||||
file_path = get_chat_log_file_path(chat_log_id)
|
||||
return _read_jsonl_file(file_path, limit)
|
||||
|
||||
|
||||
def delete_chat_log(chat_log_id):
|
||||
"""删除指定聊天记录文件"""
|
||||
file_path = get_chat_log_file_path(chat_log_id)
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
logger.info(f"聊天记录文件 '{file_path}' 已删除。")
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.error(f"删除聊天记录文件 '{file_path}' 失败: {e}")
|
||||
return False
|
||||
logger.warning(f"聊天记录文件 '{file_path}' 不存在,无需删除。")
|
||||
return False
|
||||
|
||||
|
||||
# --- 确保默认角色和记忆存在 ---
|
||||
|
||||
|
||||
def ensure_initial_data():
|
||||
"""
|
||||
确保 'default_role' 和 'default_memory' 存在,如果它们不存在。
|
||||
这个函数会在应用启动时调用。
|
||||
"""
|
||||
config_data = get_config()["Session"]
|
||||
default_role_id = config_data["CURRENT_ROLE_ID"]
|
||||
default_memory_id = config_data["CURRENT_MEMORY_ID"]
|
||||
|
||||
if not os.path.exists(get_role_features_file_path(default_role_id)):
|
||||
logger.info(f"默认角色 '{default_role_id}' 不存在,正在创建。")
|
||||
create_role(default_role_id, config_data.get("DEFAULT_ROLE_NAME", "通用AI助手"))
|
||||
|
||||
if not os.path.exists(get_memory_data_file_path(default_memory_id)):
|
||||
logger.info(f"默认记忆集 '{default_memory_id}' 不存在,正在创建。")
|
||||
create_memory(
|
||||
default_memory_id, config_data.get("DEFAULT_MEMORY_NAME", "通用记忆集")
|
||||
)
|
494
gemini_client.py
Normal file
494
gemini_client.py
Normal file
@ -0,0 +1,494 @@
|
||||
# gemini_client.py
|
||||
|
||||
import requests
|
||||
import json
|
||||
import logging
|
||||
from config import get_config # 导入config模块以获取API配置
|
||||
from typing import Union, Tuple, Any, Dict # 新增导入
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""
|
||||
封装与Google Gemini API的交互。
|
||||
支持直接HTTP请求,以便于使用第三方代理或自定义端点。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
|
||||
# 尝试从配置的URL中提取正确的base_url,确保它以 /v1beta 结尾且不包含 /models
|
||||
if base_url_from_config.endswith("/v1beta"):
|
||||
self.base_url = base_url_from_config
|
||||
elif base_url_from_config.endswith("/models"):
|
||||
# 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta
|
||||
self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
|
||||
elif base_url_from_config.endswith("/v1"): # 兼容旧的 /v1 路径
|
||||
self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
|
||||
else:
|
||||
# 假设是根URL,添加 /v1beta
|
||||
self.base_url = f"{base_url_from_config}/v1beta"
|
||||
|
||||
self.api_key = self.config["API"].get("GEMINI_API_KEY")
|
||||
# 确保API Key是字符串并清理可能的空白字符
|
||||
if self.api_key:
|
||||
self.api_key = str(self.api_key).strip()
|
||||
|
||||
# 设置请求头,包含Content-Type和x-goog-api-key
|
||||
if self.api_key:
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": self.api_key,
|
||||
}
|
||||
else:
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("GEMINI_API_KEY 未配置或为空,Gemini客户端无法初始化。")
|
||||
|
||||
logger.info(f"GeminiClient 初始化完成,基础URL: {self.base_url}")
|
||||
|
||||
def _make_request(
|
||||
self, endpoint: str, model_name: str, payload: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
内部方法:执行实际的HTTP POST请求到Gemini API。
|
||||
:param endpoint: API端点,例如 "generateContent" 或 "countTokens"。
|
||||
:param model_name: 要使用的模型名称,例如 "gemini-pro"。
|
||||
:param payload: 请求体(Python字典),将转换为JSON发送。
|
||||
:return: 响应JSON(Python字典)。
|
||||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||||
"""
|
||||
# 检查模型名称是否已包含"models/"前缀,避免重复
|
||||
if model_name.startswith("models/"):
|
||||
# 如果已包含前缀,直接使用
|
||||
url = f"{self.base_url}/{model_name}:{endpoint}"
|
||||
else:
|
||||
# 如果不包含前缀,添加"models/"
|
||||
url = f"{self.base_url}/models/{model_name}:{endpoint}"
|
||||
|
||||
# 记录请求信息,但隐藏API Key
|
||||
safe_url = url
|
||||
logger.debug(f"发送请求到 Gemini API: {safe_url}")
|
||||
logger.debug(
|
||||
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers, # 使用初始化时设置的headers
|
||||
json=payload,
|
||||
timeout=90,
|
||||
) # 增加超时时间,应对大模型响应慢
|
||||
|
||||
# 先获取响应文本,避免多次访问响应体
|
||||
response_text = response.text
|
||||
|
||||
# 然后检查状态码
|
||||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||||
|
||||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||||
response_data = json.loads(response_text)
|
||||
return response_data
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else 500
|
||||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": status_code}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||||
except Exception as e:
|
||||
error_message = f"非请求库异常: {e}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
|
||||
def _make_stream_request(
|
||||
self, endpoint: str, model_name: str, payload: Dict[str, Any]
|
||||
) -> Union[Any, Dict[str, Any]]: # 返回类型调整为Dict[str, Any]以统一错误返回
|
||||
"""
|
||||
内部方法:执行实际的HTTP POST流式请求到Gemini API。
|
||||
:param endpoint: API端点,例如 "generateContent"。
|
||||
:param model_name: 要使用的模型名称,例如 "gemini-pro"。
|
||||
:param payload: 请求体(Python字典),将转换为JSON发送。
|
||||
:return: 响应迭代器或错误信息字典。
|
||||
"""
|
||||
# 检查模型名称是否已包含"models/"前缀,避免重复
|
||||
if model_name.startswith("models/"):
|
||||
# 如果已包含前缀,直接使用
|
||||
url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
|
||||
else:
|
||||
# 如果不包含前缀,添加"models/"
|
||||
url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"
|
||||
|
||||
# 记录请求信息,但隐藏API Key
|
||||
safe_url = url
|
||||
logger.debug(f"发送流式请求到 Gemini API: {safe_url}")
|
||||
logger.debug(
|
||||
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers, # 使用初始化时设置的headers
|
||||
json=payload,
|
||||
stream=True,
|
||||
timeout=90,
|
||||
)
|
||||
|
||||
# 检查状态码但不消费响应体
|
||||
response.raise_for_status()
|
||||
|
||||
# 对于流式请求,直接返回解析后的JSON对象迭代器
|
||||
# 这不会立即消费响应体,而是在迭代时逐步消费
|
||||
def parse_stream():
|
||||
logger.info("开始处理流式响应...")
|
||||
full_response = ""
|
||||
accumulated_json = ""
|
||||
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 将二进制数据解码为UTF-8字符串
|
||||
decoded_line = line.decode("utf-8")
|
||||
logger.debug(f"原始行数据: {decoded_line[:50]}...")
|
||||
|
||||
# 处理SSE格式 (data: 开头)
|
||||
if decoded_line.startswith("data: "):
|
||||
json_str = decoded_line[6:].strip()
|
||||
if not json_str: # 跳过空数据行
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(json_str)
|
||||
logger.debug(
|
||||
f"SSE格式JSON数据,键: {list(chunk_data.keys())}"
|
||||
)
|
||||
|
||||
# 检查是否包含文本内容
|
||||
if (
|
||||
"candidates" in chunk_data
|
||||
and chunk_data["candidates"]
|
||||
and "content" in chunk_data["candidates"][0]
|
||||
):
|
||||
content = chunk_data["candidates"][0]["content"]
|
||||
|
||||
if "parts" in content and content["parts"]:
|
||||
text_part = content["parts"][0].get("text", "")
|
||||
if text_part:
|
||||
logger.debug(
|
||||
f"提取到文本: {text_part[:30]}..."
|
||||
)
|
||||
full_response += text_part
|
||||
yield text_part
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..."
|
||||
)
|
||||
else:
|
||||
# 处理非SSE格式
|
||||
# 尝试积累JSON直到有效
|
||||
accumulated_json += decoded_line
|
||||
try:
|
||||
# 尝试解析完整JSON
|
||||
chunk_data = json.loads(accumulated_json)
|
||||
logger.debug(
|
||||
f"解析完整JSON成功,键: {list(chunk_data.keys())}"
|
||||
)
|
||||
|
||||
# 检查是否包含文本内容
|
||||
if (
|
||||
"candidates" in chunk_data
|
||||
and chunk_data["candidates"]
|
||||
and "content" in chunk_data["candidates"][0]
|
||||
):
|
||||
content = chunk_data["candidates"][0]["content"]
|
||||
|
||||
if "parts" in content and content["parts"]:
|
||||
text_part = content["parts"][0].get("text", "")
|
||||
if text_part:
|
||||
logger.debug(
|
||||
f"从完整JSON提取到文本: {text_part[:30]}..."
|
||||
)
|
||||
full_response += text_part
|
||||
yield text_part
|
||||
|
||||
# 成功解析后重置累积的JSON
|
||||
accumulated_json = ""
|
||||
except json.JSONDecodeError:
|
||||
# 继续积累直到有效JSON
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}"
|
||||
)
|
||||
|
||||
# 如果有积累的响应但还没有产生任何输出,尝试直接返回
|
||||
if full_response:
|
||||
logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应")
|
||||
else:
|
||||
logger.warning("流式响应完成,但没有提取到任何有效文本内容")
|
||||
# 尝试返回一个默认响应,避免空响应
|
||||
yield "抱歉,处理响应时遇到了问题,请重试。"
|
||||
|
||||
return parse_stream()
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else 500
|
||||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": status_code}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 503}
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 408}
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
except Exception as e:
|
||||
error_message = f"非请求库异常: {e}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
|
||||
def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
|
||||
"""
|
||||
内部方法:执行实际的HTTP GET请求到Gemini API。
|
||||
:param endpoint: API端点,例如 "models"。
|
||||
:return: 响应JSON(Python字典)。
|
||||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||||
"""
|
||||
if not self.api_key:
|
||||
return {
|
||||
"error": "API Key 未配置,无法发送请求。",
|
||||
"status_code": 401,
|
||||
} # Unauthorized
|
||||
|
||||
# API Key 通过请求头传递,URL中不再包含
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
|
||||
# 记录请求信息,但隐藏API Key
|
||||
safe_url = url
|
||||
logger.debug(f"发送GET请求到 Gemini API: {safe_url}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url, headers=self.headers, timeout=30 # 使用初始化时设置的headers
|
||||
)
|
||||
|
||||
# 先获取响应文本,避免多次访问响应体
|
||||
response_text = response.text
|
||||
|
||||
# 然后检查状态码
|
||||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||||
|
||||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||||
response_data = json.loads(response_text)
|
||||
return response_data
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else 500
|
||||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": status_code}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||||
except Exception as e:
|
||||
error_message = f"非请求库异常: {e}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
model_name: str,
|
||||
contents: list[Dict[str, Any]],
|
||||
system_instruction: str | None = None,
|
||||
stream: bool = False, # 新增 stream 参数
|
||||
) -> Union[str, Tuple[str, int], Any]: # 返回类型调整以适应流式响应
|
||||
"""
|
||||
与Gemini模型进行内容生成对话。
|
||||
:param model_name: 要使用的模型名称。
|
||||
:param contents: 聊天历史和当前用户消息组成的列表,符合Gemini API的"contents"格式。
|
||||
例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...]
|
||||
:param system_instruction: 系统的指示语,作为单独的参数传入。
|
||||
在Gemini API中,它是请求体的一个顶级字段。
|
||||
:param stream: 是否以流式方式获取响应。
|
||||
:return: 如果 stream 为 False,返回模型回复文本或错误信息。
|
||||
如果 stream 为 True,返回一个生成器,每次 yield 一个文本片段。
|
||||
"""
|
||||
# 根据Gemini API的预期格式构造payload
|
||||
# 如果contents是空的,我们构造一个带有role的内容数组
|
||||
if not contents:
|
||||
if system_instruction:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||||
]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||||
}
|
||||
else:
|
||||
# 使用提供的contents
|
||||
payload: Dict[str, Any] = {"contents": contents}
|
||||
|
||||
# 如果有系统指令,我们添加到请求中
|
||||
if system_instruction:
|
||||
# Gemini API的systemInstruction是一个顶级字段,值是一个Content对象
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||
|
||||
if stream:
|
||||
response_iterator = self._make_stream_request(
|
||||
"generateContent", model_name, payload
|
||||
)
|
||||
if isinstance(response_iterator, dict) and "error" in response_iterator:
|
||||
# 如果流式请求本身返回错误,则直接返回错误信息
|
||||
return str(response_iterator["error"]), int(
|
||||
response_iterator.get("status_code", 500)
|
||||
)
|
||||
|
||||
def stream_generator():
|
||||
for text_chunk in response_iterator:
|
||||
# _make_stream_request 已经处理了JSON解析并提取了文本
|
||||
# 所以这里直接yield文本块即可
|
||||
if text_chunk:
|
||||
yield text_chunk
|
||||
logger.debug("流式响应结束。")
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response_json = self._make_request("generateContent", model_name, payload)
|
||||
|
||||
if "error" in response_json:
|
||||
# 如果 _make_request 返回了错误,直接返回错误信息和状态码
|
||||
return str(response_json["error"]), int(
|
||||
response_json.get("status_code", 500)
|
||||
)
|
||||
|
||||
if "candidates" in response_json and len(response_json["candidates"]) > 0:
|
||||
first_candidate = response_json["candidates"][0]
|
||||
if (
|
||||
"content" in first_candidate
|
||||
and "parts" in first_candidate["content"]
|
||||
and len(first_candidate["content"]["parts"]) > 0
|
||||
):
|
||||
# 返回第一个part的文本内容
|
||||
return first_candidate["content"]["parts"][0]["text"]
|
||||
|
||||
# 如果响应中包含拦截信息
|
||||
if (
|
||||
"promptFeedback" in response_json
|
||||
and "blockReason" in response_json["promptFeedback"]
|
||||
):
|
||||
reason = response_json["promptFeedback"]["blockReason"]
|
||||
logger.warning(f"Gemini API 阻止了响应,原因: {reason}")
|
||||
return (
|
||||
f"抱歉,你的请求被模型策略阻止了。原因: {reason}",
|
||||
400,
|
||||
) # 400 Bad Request for policy block
|
||||
|
||||
# 如果没有candidates且没有拦截信息,可能存在未知响应结构
|
||||
logger.warning(
|
||||
f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||||
)
|
||||
return "抱歉,AI未能生成有效回复,请稍后再试或检查日志。", 500
|
||||
|
||||
def count_tokens(
|
||||
self,
|
||||
model_name: str,
|
||||
contents: list[Dict[str, Any]],
|
||||
system_instruction: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计给定内容在特定模型中的token数量。
|
||||
:param model_name: 要使用的模型名称。
|
||||
:param contents: 要计算token的内容。
|
||||
:param system_instruction: 系统指令。
|
||||
:return: token数量(整数)或-1(如果发生错误)。
|
||||
"""
|
||||
# 根据Gemini API的预期格式构造payload,确保包含role字段
|
||||
if not contents:
|
||||
if system_instruction:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||||
]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {"contents": contents} # 明确声明payload类型
|
||||
if system_instruction:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||
|
||||
response_json = self._make_request("countTokens", model_name, payload)
|
||||
|
||||
if "error" in response_json:
|
||||
return -1 # 发生错误时返回-1
|
||||
|
||||
if "totalTokens" in response_json:
|
||||
return response_json["totalTokens"]
|
||||
logger.warning(
|
||||
f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||||
)
|
||||
return -1
|
||||
|
||||
def get_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取Gemini API可用的模型列表。
|
||||
:return: 包含模型名称的列表,如果发生错误则为空列表。
|
||||
"""
|
||||
logger.info("从 Gemini API 获取模型列表。")
|
||||
response_json = self._make_get_request("models")
|
||||
|
||||
if "error" in response_json:
|
||||
logger.error(f"获取模型列表失败: {response_json['error']}")
|
||||
return {
|
||||
"models": [],
|
||||
"error": response_json["error"],
|
||||
"status_code": response_json["status_code"],
|
||||
}
|
||||
|
||||
available_models = []
|
||||
for m in response_json.get("models", []):
|
||||
# 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型
|
||||
# 'SUPPORTED' 状态表示模型可用
|
||||
if (
|
||||
"name" in m
|
||||
and "supportedGenerationMethods" in m
|
||||
and "generateContent" in m["supportedGenerationMethods"]
|
||||
and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE"
|
||||
): # 确保模型是活跃的
|
||||
model_name = m["name"].split("/")[-1] # 提取 'gemini-pro' 部分
|
||||
available_models.append(model_name)
|
||||
logger.info(f"成功获取到 {len(available_models)} 个可用模型。")
|
||||
return {"models": available_models}
|
Reference in New Issue
Block a user