commit ffdeafa791f2f543a0c62cde143ae646df7697a9 Author: 用户已注销 Date: Fri Jun 6 16:41:50 2025 +0800 上传文件至 / diff --git a/app.py b/app.py new file mode 100644 index 0000000..a1bc55a --- /dev/null +++ b/app.py @@ -0,0 +1,1032 @@ +# app.py + +from flask import Flask, jsonify, request, send_from_directory, Response +from flask_cors import CORS # 导入 CORS +import os +import logging +import json +import threading +import time +import requests # 导入 requests 库 +from flask import g # 导入g对象,用于存储请求上下文中的数据 + +# 导入自定义模块 +from utils import ( + setup_logging, + get_current_timestamp, + generate_uuid, +) # 工具函数,如日志设置、时间戳、UUID生成 +from config import get_config, set_config # 配置管理,用于读取和写入config.ini +import file_manager # 文件管理模块,处理角色、记忆、聊天记录等文件的读写 +from gemini_client import GeminiClient # 导入Gemini客户端,用于与Gemini API交互 +import memory_manager # 导入记忆管理模块,负责构建Prompt和处理记忆更新逻辑 + +# --- Flask 应用初始化 --- + +# app_config 声明为全局变量,用于存储应用配置 +app_config = {} +# 日志配置将在 initialize_app_data 中进行 +logger = logging.getLogger(__name__) # 获取app.py的日志器 + +app = Flask(__name__, static_folder="static") +CORS(app) # 初始化 CORS,允许所有来源的跨域请求 + +# 全局会话状态变量 +global_current_role_id: str = "default_role" +global_current_memory_id: str = "default_memory" +global_current_chat_log_id: str = "" +global_turn_counter: int = 0 +global_current_chat_history: list = [] +global_memory_update_status: str = "idle" # 新增:记忆更新状态 + +# 初始化Gemini客户端 (全局唯一实例) +gemini_client: GeminiClient | None = None + + +# --- 辅助函数 --- +def error_response(message: str, status_code: int) -> Response: + """ + 标准化错误响应格式。 + :param message: 错误信息。 + :param status_code: HTTP状态码。 + :return: JSON格式的错误响应。 + """ + response = jsonify({"error": message}) + response.status_code = status_code + return response + + +@app.before_request +def load_session_context(): + """ + 在每个请求之前加载会话上下文到 Flask 的 g 对象。 + 这确保了每个请求都有独立的会话状态。 + """ + # 直接使用全局变量来设置 g 对象,确保 g 对象中的会话信息与全局变量保持同步 + g.current_role_id = global_current_role_id + g.current_memory_id = global_current_memory_id + g.current_chat_log_id = global_current_chat_log_id + g.turn_counter = global_turn_counter # g.turn_counter 应该反映当前的全局轮次 + + logger.debug( + f"请求上下文加载:角色ID={g.current_role_id}, 记忆ID={g.current_memory_id}, 聊天记录ID={g.current_chat_log_id}, 轮次={g.turn_counter}" + ) + + +# --- 静态文件路由 --- + + +@app.route("/") +def index(): + """根路由,返回前端的index.html文件""" + # 确保 static_folder 不为 None + static_folder_path = ( + app.static_folder if app.static_folder is not None else "static" + ) + return send_from_directory(static_folder_path, "index.html") + + +@app.route("/favicon.ico") +def favicon(): + """处理 favicon.ico 请求""" + static_folder_path = ( + app.static_folder if app.static_folder is not None else "static" + ) + return send_from_directory( + static_folder_path, "favicon.ico", mimetype="image/vnd.microsoft.icon" + ) + + +# --- API 路由定义 --- + + +@app.route("/api/config", methods=["GET"]) +def get_app_config(): + """ + 获取当前应用程序的所有配置。 + :return: JSON格式的配置数据。 + """ + current_config = get_config() + # 敏感信息不直接暴露给前端,例如API KEY + api_key = current_config["API"].get("GEMINI_API_KEY", "") + # 根据用户要求,完整显示API Key + display_config = { + "API": current_config["API"], + "Application": current_config["Application"], + "Session": current_config["Session"], + } + logger.debug("获取配置:成功") + return jsonify(display_config) + + +@app.route("/api/config", methods=["POST"]) +def update_app_config(): + """ + 更新应用程序配置。 + :return: 更新后的配置信息或错误信息。 + """ + data = request.json + if data is None: + logger.warning("更新配置:请求数据为空或格式不正确。") + return error_response("请求数据为空或格式不正确", 400) + + new_config = get_config() # 获取当前配置的副本 + try: + # 允许更新API和Application部分 + for section_name, section_data in data.items(): + if section_name == "API": + for key, value in section_data.items(): + if key == "GEMINI_API_KEY": + new_config["API"][key] = value + else: + new_config["API"][key] = value + elif section_name == "Application": + for key, value in section_data.items(): + if key in [ + "CONTEXT_WINDOW_SIZE", + "MEMORY_RETENTION_TURNS", + "MAX_SHORT_TERM_EVENTS", + ]: + try: + new_config["Application"][key] = int(value) + except ValueError: + logger.warning( + f"配置项 {key} 的值 '{value}' 不是有效数字。" + ) + return error_response( + f"配置项 {key} 的值 '{value}' 不是有效数字", 400 + ) + else: + new_config["Application"][key] = value + # Session 部分由 /api/active_session 控制,不在此处直接修改 + else: + logger.warning(f"尝试更新未知配置部分: {section_name}") + + set_config(new_config) # 保存到文件并更新内存中的全局配置 + logger.info("配置更新:成功。") + + # 重新初始化Gemini客户端以使用新的API配置 + global gemini_client + api_key = new_config["API"].get("GEMINI_API_KEY") + if api_key and api_key.strip(): + gemini_client = GeminiClient() + logger.info("Gemini客户端已重新初始化。") + else: + logger.warning("GEMINI_API_KEY 已被移除或为空,Gemini客户端未初始化。") + gemini_client = None + + return jsonify({"message": "配置更新成功", "config": new_config["Application"]}) + except json.JSONDecodeError: + logger.error("更新配置失败: 请求JSON格式不正确。", exc_info=True) + return error_response("请求JSON格式不正确。", 400) + except KeyError as e: + logger.error(f"更新配置失败: 缺少必要的配置键 {e}", exc_info=True) + return error_response(f"缺少必要的配置键: {e}", 400) + except Exception as e: + logger.error(f"更新配置失败: {e}", exc_info=True) + return error_response(f"更新配置失败: {str(e)}", 500) + + +@app.route("/api/roles", methods=["GET"]) +def get_roles(): + """ + 获取所有可用的角色列表。 + :return: 角色列表JSON。 + """ + roles = file_manager.list_roles() + logger.debug(f"获取角色列表:{len(roles)} 个角色。") + return jsonify(roles) + + +@app.route("/api/roles", methods=["POST"]) +def create_new_role(): + """ + 创建一个新角色。 + 需要提供 'id' 和 'name'。 + :return: 成功或失败信息。 + """ + data = request.json or {} # 确保 data 始终是字典 + if not data: # 检查 data 否为空字典或 None + return error_response("请求数据为空或格式不正确", 400) + role_id = data.get("id") + role_name = data.get("name") + if not role_id or not role_name: + logger.warning("创建角色失败: 缺少 'id' 或 'name'。") + return error_response("需要提供 'id' 和 'name' 来创建角色。", 400) + + try: + if file_manager.create_role(role_id, role_name): + logger.info(f"创建角色 '{role_name}' ({role_id}) 成功。") + return jsonify({"message": "角色创建成功", "role_id": role_id}) + else: + logger.error(f"创建角色 '{role_name}' ({role_id}) 失败,可能ID已存在。") + return error_response("创建角色失败,可能ID已存在。", 409) # Conflict + except Exception as e: + logger.error(f"创建角色时发生异常: {e}", exc_info=True) + return error_response(f"创建角色失败: {str(e)}", 500) + + +@app.route("/api/roles/", methods=["DELETE"]) +def delete_existing_role(role_id): + """ + 删除指定ID的角色。 + :return: 成功或失败信息。 + """ + # 检查是否是当前活跃角色,不允许删除 + if role_id == g.current_role_id: + logger.warning(f"尝试删除当前活跃角色 '{role_id}'。") + return error_response("不能删除当前活跃的角色。", 403) # Forbidden + + try: + if file_manager.delete_role(role_id): + logger.info(f"删除角色 '{role_id}' 成功。") + return jsonify({"message": "角色删除成功", "role_id": role_id}) + else: + logger.error(f"删除角色 '{role_id}' 失败,可能角色不存在或文件操作失败。") + return error_response("删除角色失败,可能角色不存在或文件操作失败。", 404) + except Exception as e: + logger.error(f"删除角色时发生异常: {e}", exc_info=True) + return error_response(f"删除角色失败: {str(e)}", 500) + + +@app.route("/api/memories", methods=["GET"]) +def get_memories(): + """ + 获取所有可用的记忆集列表。 + :return: 记忆集列表JSON。 + """ + memories = file_manager.list_memories() + logger.debug(f"获取记忆列表:{len(memories)} 个记忆集。") + return jsonify(memories) + + +@app.route("/api/memories", methods=["POST"]) +def create_new_memory(): + """ + 创建一个新记忆集。 + 需要提供 'id' 和 'name'。 + :return: 成功或失败信息。 + """ + data = request.json or {} # 确保 data 始终是字典 + if not data: # 检查 data 是否为空字典或 None + return error_response("请求数据为空或格式不正确", 400) + memory_id = data.get("id") + memory_name = data.get("name") + if not memory_id or not memory_name: + logger.warning("创建记忆集失败: 缺少 'id' 或 'name'。") + return error_response("需要提供 'id' 和 'name' 来创建记忆集。", 400) + + try: + if file_manager.create_memory(memory_id, memory_name): + logger.info(f"创建记忆集 '{memory_name}' ({memory_id}) 成功。") + return jsonify({"message": "记忆集创建成功", "memory_id": memory_id}) + else: + logger.error( + f"创建记忆集 '{memory_name}' ({memory_id}) 失败,可能ID已存在。" + ) + return error_response("创建记忆集失败,可能ID已存在。", 409) + except Exception as e: + logger.error(f"创建记忆集时发生异常: {e}", exc_info=True) + return error_response(f"创建记忆集失败: {str(e)}", 500) + + +@app.route("/api/memories/", methods=["DELETE"]) +def delete_existing_memory(memory_id): + """ + 删除指定ID的记忆集。 + :return: 成功或失败信息。 + """ + # 检查是否是当前活跃记忆集,不允许删除 + if memory_id == g.current_memory_id: + logger.warning(f"尝试删除当前活跃记忆集 '{memory_id}'。") + return error_response("不能删除当前活跃的记忆集。", 403) # Forbidden + + try: + if file_manager.delete_memory(memory_id): + logger.info(f"删除记忆集 '{memory_id}' 成功。") + return jsonify({"message": "记忆集删除成功", "memory_id": memory_id}) + else: + logger.error( + f"删除记忆集 '{memory_id}' 失败,可能记忆集不存在或文件操作失败。" + ) + return error_response("删记忆集失败,可能记忆集不存在或文件操作失败。", 404) + except Exception as e: + logger.error(f"删除记忆集时发生异常: {e}", exc_info=True) + return error_response(f"删除记忆集失败: {str(e)}", 500) + + +@app.route("/api/active_session", methods=["GET"]) +def get_active_session(): + """ + 获取当前活跃的会话信息,包括角色ID、记忆ID和聊天记录ID。 + :return: JSON格式的会话信息。 + """ + global global_current_role_id, global_current_memory_id, global_current_chat_log_id, global_turn_counter + logger.debug("获取当前活跃会话:成功。") + return jsonify( + { + "role_id": global_current_role_id, + "memory_id": global_current_memory_id, + "chat_log_id": global_current_chat_log_id, + "turn_counter": global_turn_counter, + "memory_status": global_memory_update_status, # 新增:返回记忆更新状态 + } + ) + + +@app.route("/api/active_session", methods=["POST"]) +def set_active_session(): + """ + 设置当前活跃的会话(角色和记忆)。 + 如果提供了新的role_id或memory_id,则更新并重置聊天上下文和轮次。 + :return: 更新后的会话信息。 + """ + global global_current_role_id, global_current_memory_id, global_current_chat_log_id, global_turn_counter, global_current_chat_history + data = request.json or {} # 确保 data 始终是字典 + + new_role_id = data.get("role_id", global_current_role_id) + new_memory_id = data.get("memory_id", global_current_memory_id) + + # 检查新ID是否存在 + if not file_manager.role_exists(new_role_id): + logger.warning(f"设置活跃会话失败: 角色ID '{new_role_id}' 不存在。") + return error_response(f"角色ID '{new_role_id}' 不存在。", 404) + if not file_manager.memory_exists(new_memory_id): + logger.warning(f"设置活跃会话失败: 记忆ID '{new_memory_id}' 不存在。") + return error_response(f"记忆ID '{new_memory_id}' 不存在。", 404) + + # 只有当角色或记忆发生变化时才重置会话 + if ( + new_role_id != global_current_role_id + or new_memory_id != global_current_memory_id + ): + old_chat_log_id = global_current_chat_log_id # 保存旧的聊天记录ID + + global_current_role_id = new_role_id + global_current_memory_id = new_memory_id + global_current_chat_log_id = ( + f"{global_current_role_id}_{global_current_memory_id}_{generate_uuid()}" + ) + + global_turn_counter = 0 + global_current_chat_history = [] # 清空聊天历史 + + # 更新config.ini中的当前激活会话 + current_app_config = get_config() + current_app_config["Session"]["CURRENT_ROLE_ID"] = global_current_role_id + current_app_config["Session"]["CURRENT_MEMORY_ID"] = global_current_memory_id + set_config(current_app_config) # 保存到文件 + + # 删除旧的聊天记录文件(如果存在且与新的不同) + if old_chat_log_id and old_chat_log_id != global_current_chat_log_id: + if file_manager.delete_chat_log(old_chat_log_id): + logger.info(f"旧聊天记录 '{old_chat_log_id}' 已成功删除。") + else: + logger.warning(f"删除旧聊天记录 '{old_chat_log_id}' 失败或文件不存在。") + + logger.info( + f"活跃会话已切换至角色 '{global_current_role_id}' 和记忆集 '{global_current_memory_id}'。新聊天记录ID: {global_current_chat_log_id}" + ) + else: + logger.info( + f"活跃会话保持不变 (角色: {global_current_role_id}, 记忆: {global_current_memory_id})。" + ) + + return jsonify( + { + "message": "活跃会话已更新", + "role_id": global_current_role_id, + "memory_id": global_current_memory_id, + "chat_log_id": global_current_chat_log_id, + "turn_counter": global_turn_counter, + } + ) + + +@app.route("/api/features_content", methods=["GET"]) +def get_features_content(): + """ + 获取当前活跃角色的特征文件内容。 + :return: JSON格式的特征内容。 + """ + features_data = file_manager.load_active_features(global_current_role_id) + if features_data: + logger.debug(f"获取角色 '{global_current_role_id}' 的特征内容:成功。") + return jsonify(features_data) + else: + logger.error(f"获取角色 '{global_current_role_id}' 的特征内容失败。") + return error_response("未能加载角色特征内容。", 500) + + +@app.route("/api/features_content", methods=["POST"]) +def update_features_content(): + """ + 更新当前活跃角色的特征文内容。 + :return: 成功或失败信息。 + """ + data = request.json + if data is None: + return error_response("请求数据为空或格式不正确", 400) + + # 确保保存的JSON中包含"角色名称"字段,并与实际角色名称一致 + # 从文件管理器获取所有角色,找到当前角色的名称 + roles = file_manager.list_roles() + current_role_name = next( + (r["name"] for r in roles if r["id"] == global_current_role_id), + global_current_role_id, + ) + data["角色名称"] = current_role_name + + if file_manager.save_active_features(global_current_role_id, data): + logger.info(f"角色 '{global_current_role_id}' 的特征内容更新成功。") + return jsonify({"message": "特征内容更新成功"}) + else: + logger.error(f"角色 '{global_current_role_id}' 的特征内容更新失败。") + return error_response("特征内容更新失败。", 500) + + +@app.route("/api/memory_content", methods=["GET"]) +def get_memory_content(): + """ + 获取当前活跃记忆集的记忆文件内容。 + :return: JSON格式的记忆内容。 + """ + memory_data = file_manager.load_active_memory(global_current_memory_id) + if memory_data: + logger.debug(f"获取记忆集 '{global_current_memory_id}' 的记忆内容:成功。") + return jsonify(memory_data) + else: + logger.error(f"获取记忆集 '{global_current_memory_id}' 的记忆内容失败。") + return error_response("未能加载记忆内容。", 500) + + +@app.route("/api/memory_content", methods=["POST"]) +def update_memory_content(): + """ + 更新当前活跃记忆集的记忆文件内容。 + :return: 成功或失败信息。 + """ + data = request.json + if data is None: + return error_response("请求数据为空或格式不正确", 400) + + # 确保保存的JSON中包含'long_term_facts'和'short_term_events'字段 + if "long_term_facts" not in data or "short_term_events" not in data: + return error_response( + "记忆内容需要包含 'long_term_facts' 和 'short_term_events' 字段。", 400 + ) + + if file_manager.save_active_memory(global_current_memory_id, data): + logger.info(f"记忆集 '{global_current_memory_id}' 的记忆内容更新成功。") + return jsonify({"message": "记忆内容更新成功"}) + else: + logger.error(f"记忆集 '{global_current_memory_id}' 的记忆内容更新失败。") + return error_response("记忆内容更新失败。", 500) + + +@app.route("/api/chat_log", methods=["GET"]) +def get_chat_log(): + """ + 获取当前活跃会话的聊天记录。 + :param limit: 可选,限制返回的记录数量。 + :return: JSON格式聊天记录列表。 + """ + limit = request.args.get("limit", type=int) + logs = file_manager.read_chat_log(global_current_chat_log_id, limit=limit) + logger.debug(f"获取聊天记录:{len(logs)} 条。") + return jsonify(logs) + + +@app.route("/api/chat_log", methods=["DELETE"]) +def delete_chat_log_route(): + """ + 删除当前活跃会话的聊天记录文件。 + :return: 成功或失败信息。 + """ + # 这里我们只删除当前 session 对应的聊天记录文件 + if file_manager.delete_chat_log(global_current_chat_log_id): + # 删除后,清空内存中的聊天历史 + global global_current_chat_history, global_turn_counter + global_current_chat_history = [] + global_turn_counter = 0 + logger.info( + f"聊天记录 '{global_current_chat_log_id}' 删除成功,内存历史已清空。" + ) + return jsonify({"message": "聊天记录删除成功,会话历史已清空"}) + else: + logger.error(f"聊天记录 '{global_current_chat_log_id}' 删除失败。") + return error_response("聊天记录删除失败。", 500) + + +@app.route("/api/proxy_models", methods=["GET"]) +def proxy_get_available_models(): + """ + 通过后端代理获取Gemini API可用的模型列表。 + :return: JSON格式的模型列表。 + """ + current_config = get_config() + gemini_api_base_url = current_config["API"].get("GEMINI_API_BASE_URL") + + if not gemini_api_base_url: + logger.warning("GEMINI_API_BASE_URL 未配置,无法获取模型列表。") + return jsonify({"models": [], "error": "GEMINI_API_BASE_URL 未配置。"}) + + # 确保 GEMINI_API_BASE_URL 以 /v1beta 结尾,并拼接 /models + # 移除末尾的斜杠,然后添加 /v1beta/models + base_url_cleaned = gemini_api_base_url.rstrip("/") + if not base_url_cleaned.endswith("/v1beta"): + base_url_cleaned += "/v1beta" + models_url = f"{base_url_cleaned}/models" + + headers = {} + # 如果有API Key,也传递给上游API + api_key = current_config["API"].get("GEMINI_API_KEY") + if api_key and api_key.strip(): + headers["x-goog-api-key"] = api_key # 根据Gemini API的要求添加API Key头 + + try: + response = requests.get(models_url, headers=headers, timeout=10) # 设置超时 + + response.raise_for_status() # 如果状态码不是2xx,则抛出HTTPError + + models_data = response.json() + logger.debug("通过代理成功获取模型列表。") + return jsonify(models_data) + + except requests.exceptions.Timeout: + logger.error(f"获取模型列表超时: {models_url}") + return error_response("获取模型列表超时。", 504) # Gateway Timeout + except requests.exceptions.RequestException as e: + logger.error(f"通过代理获取模型列表失败: {e}", exc_info=True) + return error_response(f"获取模型列表失败: {str(e)}", 500) + except json.JSONDecodeError: + logger.error(f"无法解析模型列表响应为JSON: {response.text}", exc_info=True) + return error_response("无法解析模型列表响应。", 500) + + +@app.route("/api/logs", methods=["GET"]) +def get_app_logs(): + """ + 获取应用程序的日志内容。 + :return: 日志文件内容或错误信息。 + """ + log_file_path = get_config()["Application"]["APP_LOG_FILE_PATH"] + try: + if not os.path.exists(log_file_path): + logger.warning(f"日志文件不存在: {log_file_path}") + return error_response("日志文件不存在。", 404) + with open(log_file_path, "r", encoding="utf-8") as f: + logs = f.read() + return Response(logs, mimetype="text/plain") + except Exception as e: + logger.error(f"读取日志文件失败: {e}", exc_info=True) + return error_response(f"读取日志文件失败: {str(e)}", 500) + + +# --- 主对话路由 --- +@app.route("/api/chat", methods=["POST"]) +def chat_with_gemini(): + """ + 与Gemini模型进行对话。 + :return: AI回复和当前的会话状态。 + """ + global global_turn_counter, global_current_chat_history, global_current_role_id, global_current_memory_id + + # 检查是否请求流式响应 + use_stream = request.args.get("stream", "false").lower() == "true" + + data = request.json + if data is None: + return error_response("请求数据为空或格式不正确", 400) + user_message = data.get("message") + if not user_message: + return error_response("消息内容不能为空。", 400) + + # 1. 记录用户消息并更新全局聊天历史 + user_entry = { + "timestamp": get_current_timestamp(), + "role": "user", + "content": user_message, + "id": f"user-{time.time()}-{generate_uuid()[:8]}", # 添加消息ID + } + file_manager.append_chat_log(g.current_chat_log_id, user_entry) + global_current_chat_history.append( + {"role": "user", "parts": [{"text": user_message}]} + ) + logger.info(f"用户消息已记录。当前轮次: {g.turn_counter}") + + # 2. 构建系统提示词 (包含特征和记忆) + system_instruction = memory_manager.build_system_prompt( + g.current_role_id, + g.current_memory_id, + g.turn_counter, # 传递当前轮次给Prompt + ) + + # 3. 管理上下文长度:只保留最近 CONTEXT_WINDOW_SIZE * 2 条消息 (用户+AI) + context_window_size = app_config["Application"]["CONTEXT_WINDOW_SIZE"] + # 从全局聊天历史中获取用于API调用的部分 + chat_history_for_api = global_current_chat_history[-context_window_size * 2 :] + + # 4. 调用 Gemini API 进行对话 + if gemini_client is None: + logger.error("Gemini客户端未初始化,无法进行对话。") + return error_response("AI助手未初始化,请检查API配置。", 500) + + model_name = app_config["API"]["DEFAULT_GEMINI_MODEL"] + + # 创建AI回复的消息ID + ai_message_id = f"assistant-{time.time()}-{generate_uuid()[:8]}" + + # 根据请求类型选择处理方式 + if use_stream: + # 流式响应处理 + def generate_stream(): + try: + # 确保gemini_client不为None + if gemini_client is None: + logger.error("Gemini客户端未初始化,无法进行流式对话。") + yield f"data: {json.dumps({'error': 'AI助手未初始化,请检查API配置。'})}\n\n" + return + + # 调用流式API + stream_response = gemini_client.generate_content( + model_name=model_name, + contents=chat_history_for_api, + system_instruction=system_instruction, + stream=True, # 启用流式响应 + ) + + full_response = "" + + if isinstance(stream_response, tuple) and len(stream_response) == 2: + # 如果返回错误信息 + error_message, status_code = stream_response + logger.error( + f"Gemini API 流式响应返回错误: {error_message} (状态码: {status_code})" + ) + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + # 处理流式响应 + for chunk in stream_response: + if chunk: + full_response += chunk + # 发送数据块 + yield f"data: {json.dumps({'chunk': chunk, 'id': ai_message_id})}\n\n" + + # 流结束后,保存完整回复到数据库 + # 获取当前请求中的聊天记录ID(避免使用g对象) + current_chat_log_id = global_current_chat_log_id + + assistant_entry = { + "timestamp": get_current_timestamp(), + "role": "assistant", + "content": full_response, + "id": ai_message_id, + } + # 使用全局变量而不是g对象 + file_manager.append_chat_log(current_chat_log_id, assistant_entry) + global_current_chat_history.append( + {"role": "assistant", "parts": [{"text": full_response}]} + ) + + # 更新全局对话轮次 + global global_turn_counter + global_turn_counter += 1 + + # 发送结束标记 + yield f"data: {json.dumps({'end': True, 'turn_counter': global_turn_counter})}\n\n" + + # 在另一个线程中检查并触发记忆更新,避免阻塞流式响应 + def delayed_memory_update(): + try: + with app.app_context(): # 创建应用上下文 + logger.info("在新线程中触发记忆更新") + check_and_trigger_memory_update() + except Exception as e: + logger.error(f"延迟记忆更新失败: {e}", exc_info=True) + + # 启动延迟更新线程 + threading.Thread(target=delayed_memory_update).start() + + except Exception as e: + logger.error(f"流式生成内容时发生错误: {e}", exc_info=True) + yield f"data: {json.dumps({'error': str(e)})}\n\n" + + return Response(generate_stream(), mimetype="text/event-stream") + else: + # 标准响应处理 + try: + ai_response = gemini_client.generate_content( + model_name=model_name, + contents=chat_history_for_api, + system_instruction=system_instruction, + ) + + # 检查 ai_response 是否为元组 (错误信息, 状态码) + if isinstance(ai_response, tuple) and len(ai_response) == 2: + error_message: str = ai_response[0] + status_code: int = ai_response[1] + logger.error( + f"Gemini API 返回错误: {error_message} (状态码: {status_code})" + ) + return error_response(error_message, status_code) + + # 如果不是元组,则预期它是一个字符串 + if not isinstance(ai_response, str): + logger.error( + f"Gemini API 返回了非字符串/元组类型结果: {type(ai_response)} - {ai_response}" + ) + return error_response("AI助手返回了非预期的结果类型。", 500) + + ai_response_text: str = ai_response + except Exception as e: + logger.error(f"调用Gemini API失败: {e}", exc_info=True) + return error_response(f"AI助手调用失败: {str(e)}", 500) + + logger.info(f"Gemini API 回复: {ai_response_text[:100]}...") + + # 5. 记录AI回复并更新全局聊天历史 + assistant_entry = { + "timestamp": get_current_timestamp(), + "role": "assistant", + "content": ai_response_text, + "id": ai_message_id, + } + file_manager.append_chat_log(g.current_chat_log_id, assistant_entry) + global_current_chat_history.append( + {"role": "assistant", "parts": [{"text": ai_response_text}]} + ) + + # 6. 更新全局对话轮次 + global global_turn_counter + global_turn_counter += 1 + + # 7. 检查并触发记忆更新 (异步) + check_and_trigger_memory_update() + + # 8. 返回AI回复和当前会话状态 + response_data = { + "success": True, + "response": ai_response_text, + "turn_counter": global_turn_counter, + "active_session": { + "role_id": g.current_role_id, + "memory_id": g.current_memory_id, + "chat_log_id": g.current_chat_log_id, + }, + "id": ai_message_id, # 返回消息ID,便于前端追踪 + } + # 如果触发记忆更新,则在响应中添加状态 + if global_turn_counter > 0 and global_turn_counter % context_window_size == 0: + response_data["memory_status"] = ( + "memory_updating" # 保持与 /api/active_session 一致 + ) + + return jsonify(response_data) + + +def check_and_trigger_memory_update(): + """检查是否需要触发记忆更新,如果需要则在后台线程中执行""" + global global_turn_counter, global_current_chat_history, global_current_role_id, global_current_memory_id + + context_window_size = app_config["Application"]["CONTEXT_WINDOW_SIZE"] + if global_turn_counter > 0 and global_turn_counter % context_window_size == 0: + logger.info(f"对话轮次达到 {global_turn_counter},触发异步记忆更新。") + # 使用全局变量而不是g对象 + try: + recent_chat_for_memory_update = global_current_chat_history[ + -context_window_size * 2 : + ] + + current_memory_data_for_update = file_manager.load_active_memory( + global_current_memory_id # 使用全局变量替代g.current_memory_id + ) + + if current_memory_data_for_update: + update_thread = threading.Thread( + target=_async_memory_update_task, + args=( + current_memory_data_for_update, + recent_chat_for_memory_update, + global_turn_counter, + global_current_role_id, # 使用全局变量替代g.current_role_id + global_current_memory_id, # 使用全局变量替代g.current_memory_id + ), + ) + update_thread.start() + logger.info( + f"记忆更新线程已启动,角色ID: {global_current_role_id}, 记忆ID: {global_current_memory_id}" + ) + else: + logger.error(f"无法加载记忆数据,记忆ID: {global_current_memory_id}") + except Exception as e: + logger.error(f"触发记忆更新时发生错误: {e}", exc_info=True) + + +def _async_memory_update_task( + initial_memory_data, + recent_chat_history, + current_trigger_turn_count, + role_id: str, + memory_id: str, +): + """ + 异步执行记忆更新任务。 + :param initial_memory_data: 触发更新时的原始记忆数据。 + :param recent_chat_history: 最近 N 轮对话历史。 + :param current_trigger_turn_count: 触发本次更新时的全局对话轮次。 + :param role_id: 当前活跃的角色ID。 + :param memory_id: 当前活跃的记忆ID。 + """ + global global_memory_update_status # 声明使用全局变量 + # 在新线程中手动创建应用上下文 + with app.app_context(): + logger.info( + f"异步记忆更新任务开始,角色: {role_id}, 记忆: {memory_id}, 触发轮次: {current_trigger_turn_count}" + ) + global_memory_update_status = "updating" # 设置状态为更新中 + try: + # 1. 构建记忆更新的Prompt + update_prompt = memory_manager.build_memory_update_prompt( + recent_chat_history, initial_memory_data, current_trigger_turn_count + ) + + # 2. 调用Gemini API进行记忆更新 + if gemini_client is None: + logger.error("Gemini客户端未初始化,无法执行记忆更新。请检查API配置。") + return + + # 确保 app_config 在上下文中可用,或者重新获取 + current_app_config = get_config() # 重新获取配置,确保在上下文中 + update_model_name = current_app_config["API"]["MEMORY_UPDATE_MODEL"] + try: + # 构造一个符合Gemini API预期的请求体 + # 使用单条消息,将prompt作为文本内容 + gemini_response = gemini_client.generate_content( + model_name=update_model_name, + contents=[{"role": "user", "parts": [{"text": update_prompt}]}], + system_instruction=None, # 不使用system_instruction,直接在contents中包含prompt + ) + + # 检查 gemini_response 是否为元组 (错误信息, 状态码) + if isinstance(gemini_response, tuple) and len(gemini_response) == 2: + error_message: str = gemini_response[0] + status_code: int = gemini_response[1] + logger.error( + f"Gemini模型在记忆更新时返回错误: {error_message} (状态码: {status_code})" + ) + return # 无法继续,直接返回 + + # 如果不是元组,则预期它是一个字符串 + if not isinstance(gemini_response, str): + logger.error( + f"Gemini模型在记忆更新时返回了非字符串/元组类型结果: {type(gemini_response)} - {gemini_response}" + ) + return # 无法继续,直接返回 + + gemini_response_text: str = gemini_response + except Exception as e: + logger.error(f"调用Gemini API进行记忆更新失败: {e}", exc_info=True) + return + + # 3. 处理模型返回的记忆更新JSON + updated_memory_data = memory_manager.process_memory_update_response( + gemini_response_text, + initial_memory_data, + current_trigger_turn_count, + ) + + # 4. 保存更新后的记忆数据 + if file_manager.save_active_memory(memory_id, updated_memory_data): + logger.info(f"记忆集 '{memory_id}' (角色 '{role_id}') 记忆更新成功。") + global_memory_update_status = "completed" # 更新成功 + else: + logger.error(f"记忆集 '{memory_id}' (角色 '{role_id}') 记忆保存失败。") + global_memory_update_status = "error" # 保存失败 + except Exception as e: + logger.error(f"异步记忆更新任务发生致命错误: {e}", exc_info=True) + global_memory_update_status = "error" # 发生错误 + finally: + # 无论成功或失败,最终都将状态设置为 idle,除非有其他更新正在进行 + # 这里需要更复杂的逻辑来判断是否真的空闲,暂时简化处理 + if global_memory_update_status != "updating": # 避免覆盖正在进行的更新 + global_memory_update_status = "idle" + + +# --- 主运行部分 --- +# --- 应用启动前处理 --- +def initialize_app_data(): + """ + 在Flask应用启动时执行。 + 用于确保必要的默认数据(如默认角色和记忆)存在,并初始化会话状态。 + """ + global app_config, gemini_client + global global_current_role_id, global_current_memory_id, global_current_chat_log_id, global_turn_counter, global_current_chat_history + + # 在所有模块加载和配置读取完成后配置日志 + app_config = get_config() # 确保获取到最新的配置 + setup_logging(app_config["Application"]["APP_LOG_FILE_PATH"]) + logger.info("应用数据初始化开始...") + + file_manager.ensure_initial_data() + + # 从 config.ini 加载初始会话状态 + session_config = app_config["Session"] + + # 检查并设置当前角色ID和记忆ID,并更新config.ini(如果需要) + configured_role_id = session_config.get("CURRENT_ROLE_ID", "default_role") + configured_memory_id = session_config.get("CURRENT_MEMORY_ID", "default_memory") + + # 确保配置中的角色和记忆ID是有效的,如果无效则回退到默认值 + if not file_manager.role_exists(configured_role_id): + logger.warning( + f"配置的角色ID '{configured_role_id}' 不存在,回退到 'default_role'。" + ) + global_current_role_id = "default_role" + else: + global_current_role_id = configured_role_id + + if not file_manager.memory_exists(configured_memory_id): + logger.warning( + f"配置的记忆ID '{configured_memory_id}' 不存在,回退到 'default_memory'。" + ) + global_current_memory_id = "default_memory" + else: + global_current_memory_id = configured_memory_id + + # 如果回退了,需要更新config.ini + if ( + configured_role_id != global_current_role_id + or configured_memory_id != global_current_memory_id + ): + current_app_config = get_config() + current_app_config["Session"]["CURRENT_ROLE_ID"] = global_current_role_id + current_app_config["Session"]["CURRENT_MEMORY_ID"] = global_current_memory_id + set_config(current_app_config) + logger.info("config.ini 中的会话配置已更新为有效值。") + app_config = get_config() # 重新加载配置到 app_config + + # 聊天记录ID:每次应用启动时生成一个新的唯一ID + global_current_chat_log_id = ( + f"{global_current_role_id}_{global_current_memory_id}_{generate_uuid()}" + ) + global_turn_counter = 0 + global_current_chat_history = [] # 清空聊天历史 + + logger.info(f"应用启动,已创建新的聊天记录ID: {global_current_chat_log_id}。") + logger.info("应用数据初始化完成。") + + # 初始化Gemini客户端 + api_key = app_config["API"].get("GEMINI_API_KEY") + if api_key and api_key.strip(): # 确保API Key不为空或只包含空格 + gemini_client = GeminiClient() + logger.info("Gemini客户端已初始化。") + else: + logger.warning("未配置或API KEY为空,Gemini客户端未初始化。") + gemini_client = None + + +@app.route("/api/memory/trigger_update", methods=["POST"]) +def trigger_memory_update_route(): + """ + 手动触发记忆更新的API接口。 + """ + global global_current_role_id, global_current_memory_id, global_turn_counter, global_current_chat_history + logger.info("收到手动触发记忆更新的请求。") + + # 获取当前记忆数据和最近的聊天历史 + current_memory_data_for_update = file_manager.load_active_memory( + global_current_memory_id + ) + # 记忆更新通常需要最近的对话历史,这里可以根据需要调整截取长度 + context_window_size = app_config["Application"]["CONTEXT_WINDOW_SIZE"] + recent_chat_for_memory_update = global_current_chat_history[ + -context_window_size * 2 : + ] + + if not current_memory_data_for_update: + return error_response("无法加载当前记忆数据,无法触发更新。", 500) + + # 异步执行记忆更新任务 + update_thread = threading.Thread( + target=_async_memory_update_task, + args=( + current_memory_data_for_update, + recent_chat_for_memory_update, + global_turn_counter, + global_current_role_id, # 传递当前角色ID + global_current_memory_id, # 传递当前记忆ID + ), + ) + update_thread.start() + + return jsonify({"message": "记忆更新任务已在后台触发。"}) + + +if __name__ == "__main__": + # 应用启动时初始化数据 + initialize_app_data() + # 运行Flask应用 + # debug=True 仅用于开发环境,生产环境请勿使用 + app.run(debug=True, host="0.0.0.0", port=5000) diff --git a/config.ini b/config.ini new file mode 100644 index 0000000..b4888d8 --- /dev/null +++ b/config.ini @@ -0,0 +1,20 @@ +[API] +gemini_api_base_url = https://g.shatang.me/v1beta +gemini_api_key = +default_gemini_model = models/gemini-2.0-flash +memory_update_model = models/gemini-2.0-flash + +[Application] +context_window_size = 5 +memory_retention_turns = 10 +max_short_term_events = 15 +features_dir = data/features +memories_dir = data/memories +chat_logs_dir = data/chat_logs +app_log_file_path = logs/app.log + +[Session] +current_role_id = default_role +current_memory_id = default_memory +current_chat_log_id = default_role_default_memory_ce76c8a8-4072-44f4-8f1b-d942166b62e1 + diff --git a/config.py b/config.py new file mode 100644 index 0000000..5f6c61f --- /dev/null +++ b/config.py @@ -0,0 +1,168 @@ +# config.py + +import configparser +import os +import logging + +# 获取根日志器,确保日志配置已在utils.py中完成 +logger = logging.getLogger(__name__) + +CONFIG_FILE = "config.ini" + +# 定义默认配置,如果config.ini不存在或缺失,将使用这些值 +DEFAULT_CONFIG = { + "API": { + "GEMINI_API_BASE_URL": "https://generativelanguage.googleapis.com/v1beta", + "GEMINI_API_KEY": "YOUR_GEMINI_API_KEY_HERE", # 默认值,强烈建议用户手动修改 + "DEFAULT_GEMINI_MODEL": "gemini-pro", + "MEMORY_UPDATE_MODEL": "gemini-pro", + }, + "Application": { + "CONTEXT_WINDOW_SIZE": "6", + "MEMORY_RETENTION_TURNS": "18", + "MAX_SHORT_TERM_EVENTS": "20", + "FEATURES_DIR": "data/features", + "MEMORIES_DIR": "data/memories", + "CHAT_LOGS_DIR": "data/chat_logs", + "APP_LOG_FILE_PATH": "logs/app.log", + }, + "Session": { + "CURRENT_ROLE_ID": "default_role", + "CURRENT_MEMORY_ID": "default_memory", + }, +} + + +def load_config(): + """ + 从 config.ini 文件加载配置。 + 如果文件不存在或某些配置项缺失,则使用默认值并尝试创建/更新文件。 + :return: 包含所有配置的字典。 + """ + config = configparser.ConfigParser() + + # 检查config.ini是否存在,如果不存在则创建并写入默认配置 + if not os.path.exists(CONFIG_FILE): + logger.info(f"config.ini 文件不存在,正在创建默认配置文件:{CONFIG_FILE}") + for section, options in DEFAULT_CONFIG.items(): + config.add_section(section) + for key, value in options.items(): + config.set(section, key, value) + try: + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + logger.info("默认 config.ini 文件创建成功。") + except IOError as e: + logger.error(f"无法写入 config.ini 文件: {e}") + # 即使写入失败,也继续使用内存中的默认配置 + else: + logger.info(f"从 {CONFIG_FILE} 加载配置...") + + # 读取配置 + try: + config.read(CONFIG_FILE, encoding="utf-8") + except Exception as e: + logger.error(f"读取 config.ini 文件时发生错误: {e},将使用默认配置。") + # 如果读取失败,清空 configparser 对象,让后续逻辑使用默认值 + config = configparser.ConfigParser() + + # 将配置转换为字典格式,并确保所有默认值都存在 + app_config = {} + for section, options in DEFAULT_CONFIG.items(): + app_config[section] = {} + for key, default_value in options.items(): + # 使用 get() 方法获取值,如果不存在则使用默认值 + # 对于整型或布尔型配置,需要手动转换 + value = config.get(section, key, fallback=default_value) + + # 尝试将特定配置项转换为数字类型 + if section == "Application" and key in [ + "CONTEXT_WINDOW_SIZE", + "MEMORY_RETENTION_TURNS", + "MAX_SHORT_TERM_EVENTS", + ]: + try: + app_config[section][key] = int(value) + except ValueError: + logger.warning( + f"配置项 {section}.{key} 的值 '{value}' 不是有效数字,使用默认值 {default_value}。" + ) + app_config[section][key] = int(default_value) + elif section == "API" and key == "GEMINI_API_KEY": + # 优先从环境变量读取 GEMINI_API_KEY + env_key = os.getenv("GEMINI_API_KEY") + if env_key: + app_config[section][key] = env_key + logger.info("已从环境变量加载 GEMINI_API_KEY。") + else: + app_config[section][key] = value + else: + app_config[section][key] = value + + # 如果存在config.ini但缺少某些默认值,更新它 + try: + updated = False + for section, options in DEFAULT_CONFIG.items(): + if not config.has_section(section): + config.add_section(section) + updated = True + for key, default_value in options.items(): + if not config.has_option(section, key): + config.set(section, key, str(default_value)) # 确保写入的是字符串 + updated = True + if updated: + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + logger.info("config.ini 文件已更新,补齐缺失配置项。") + except IOError as e: + logger.error(f"无法更新 config.ini 文件以补齐缺失配置项: {e}") + + logger.info("配置加载完成。") + return app_config + + +def save_config(new_config): + """ + 将给定的配置字典保存到 config.ini 文件。 + 它会读取现有配置,然后只更新 new_config 中提供的部分。 + :param new_config: 包含要保存的所有配置的字典。 + """ + config = configparser.ConfigParser() + + # 首先读取现有配置,以保留未在 new_config 中提供的项 + if os.path.exists(CONFIG_FILE): + try: + config.read(CONFIG_FILE, encoding="utf-8") + except Exception as e: + logger.error(f"读取现有 config.ini 文件时发生错误: {e},将只保存新配置。") + + for section, options in new_config.items(): + if not config.has_section(section): + config.add_section(section) + for key, value in options.items(): + # 确保将值转换为字符串,因为 configparser 存储的是字符串 + config.set(section, key, str(value)) + + try: + with open(CONFIG_FILE, "w", encoding="utf-8") as f: + config.write(f) + logger.info(f"配置已成功保存到 {CONFIG_FILE}") + except IOError as e: + logger.error(f"保存配置到 {CONFIG_FILE} 失败: {e}") + + +# 全局变量,存储当前加载的配置 +app_config = load_config() + + +def get_config(): + """获取当前加载的配置""" + return app_config + + +def set_config(new_config): + """更新内存中的配置并保存到文件""" + global app_config + app_config = new_config + save_config(app_config) + logger.info("内存配置已更新并保存。") diff --git a/file_manager.py b/file_manager.py new file mode 100644 index 0000000..7ad0fdf --- /dev/null +++ b/file_manager.py @@ -0,0 +1,432 @@ +# file_manager.py + +import json +import os +import uuid +import shutil # 用于删除目录及其内容 +import logging +from config import get_config # 导入config模块的获取配置函数 + +logger = logging.getLogger(__name__) + +# --- 辅助函数:JSON 和 JSONL 文件操作 --- + + +def _load_json_file(file_path, default_content=None): + """ + 加载JSON文件。如果文件不存在,则创建并写入默认内容(如果提供)。 + :param file_path: JSON文件路径。 + :param default_content: 文件不存在时要写入的默认Python对象。 + :return: 文件内容(Python对象)或None(如果加载失败)。 + """ + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在 + if default_content is not None: + logger.info(f"文件不存在,创建并写入默认内容到:{file_path}") + _save_json_file(file_path, default_content) + else: + logger.warning(f"尝试加载的JSON文件不存在且未提供默认内容:{file_path}") + return {} # 返回空字典,以提高健壮性 + + try: + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as e: + logger.error(f"JSON文件解码失败:{file_path} - {e}") + return None + except IOError as e: + logger.error(f"读取JSON文件失败:{file_path} - {e}") + return None + + +def _save_json_file(file_path, data): + """ + 保存Python对象到JSON文件。 + :param file_path: JSON文件路径。 + :param data: 要保存的Python对象。 + """ + os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在 + try: + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + logger.debug(f"JSON文件已保存:{file_path}") + return True + except IOError as e: + logger.error(f"保存JSON文件失败:{file_path} - {e}") + return False + + +def _append_jsonl_file(file_path, data): + """ + 追加一条JSON记录到JSONL文件。 + :param file_path: JSONL文件路径。 + :param data: 要追加的Python对象。 + """ + os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在 + try: + with open(file_path, "a", encoding="utf-8") as f: + f.write(json.dumps(data, ensure_ascii=False) + "\n") + logger.debug(f"JSONL记录已追加到:{file_path}") + return True + except IOError as e: + logger.error(f"追加JSONL文件失败:{file_path} - {e}") + return False + + +def _read_jsonl_file(file_path, limit=None): + """ + 读取JSONL文件的所有记录。 + :param file_path: JSONL文件路径。 + :param limit: 可选,限制读取的最后N条记录。 + :return: 包含所有记录的列表。 + """ + if not os.path.exists(file_path): + return [] + + records = [] + try: + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + if limit: + lines = lines[-limit:] # 只读取最后N行 + for line in lines: + try: + records.append(json.loads(line.strip())) + except json.JSONDecodeError as e: + logger.warning( + f"跳过JSONL文件中的损坏行:{file_path} - {line.strip()} - {e}" + ) + return records + except IOError as e: + logger.error(f"读取JSONL文件失败:{file_path} - {e}") + return [] + + +# --- 默认文件内容 --- + +# 默认特征文件内容 +DEFAULT_FEATURE_CONTENT = { + "角色名称": "通用AI助手", + "角色描述": "你是一个有帮助的AI助手,旨在提供清晰、准确、友好的回答。你没有预设的特定角色或个性,能够灵活适应各种对话主题。", + "心理特征": { + "核心哲学与世界观": {"描述": "以提供信息、解决问题为核心,保持中立客观。"}, + "价值观与道德观": {"描述": "遵循伦理原则,避免偏见,确保信息准确。"}, + "决策风格": {"描述": "基于逻辑和可用信息进行决策。"}, + "情绪反应与表达": {"描述": "保持平静、专业、不带个人情感。"}, + "人际互动与关系处理": {"描述": "以帮助用户为目标,保持礼貌和尊重。"}, + "动机与目标": {"描述": "提供有价值的信息和解决方案。"}, + "自我认知": {"描述": "认知自己是AI,不具备人类情感和意识。"}, + "应对机制": {"描述": "遇到未知问题时,会请求更多信息或承认能力限制。"}, + }, + "语言特征": { + "词汇与措辞": {"描述": "清晰、简洁、准确,避免模糊不清的表达。"}, + "句式结构与复杂度": {"描述": "常用清晰直接的句式,逻辑性强。"}, + "语气与风格": {"描述": "专业、友善、乐于助人。"}, + "修辞手法与模式": {"描述": "多使用直接陈述、解释说明。"}, + "互动模式": {"描述": "响应式,根据用户输入提供信息或引导对话。"}, + }, + "重要人际关系": [], + "角色弧线总结": "作为一个持续学习和改进的AI,不断提升服务质量。", +} + +# 默认记忆文件内容 +DEFAULT_MEMORY_CONTENT = { + "long_term_facts": ["此记忆集为默认记忆,未包含特定用户或AI的长期事实。"], + "short_term_events": [], +} + +# --- 路径管理函数 --- + + +def _get_dir_path(base_dir_config_key): + """根据配置获取目录路径""" + config = get_config() + return config["Application"][base_dir_config_key] + + +def get_features_dir(): + """获取特征文件根目录""" + return _get_dir_path("FEATURES_DIR") + + +def get_memories_dir(): + """获取记忆文件根目录""" + return _get_dir_path("MEMORIES_DIR") + + +def get_chat_logs_dir(): + """获取聊天记录文件根目录""" + return _get_dir_path("CHAT_LOGS_DIR") + + +def get_role_features_file_path(role_id): + """获取指定角色ID的特征文件路径""" + return os.path.join(get_features_dir(), role_id, "features.json") + + +def get_memory_data_file_path(memory_id): + """获取指定记忆ID的记忆数据文件路径""" + return os.path.join(get_memories_dir(), memory_id, "memory.json") + + +def get_chat_log_file_path(chat_log_id): + """获取指定聊天记录ID的聊天记录文件路径""" + # 聊天记录ID通常与角色ID和记忆ID关联,或者直接就是session ID + # 这里简化为直接以chat_log_id为文件名 + return os.path.join(get_chat_logs_dir(), f"{chat_log_id}.jsonl") + + +# --- 角色和记忆管理 --- + + +def list_roles(): + """ + 列出所有可用的角色ID和名称。 + 角色名称从其features.json中获取。 + :return: 列表,每个元素为 {"id": "role_id", "name": "角色名称"} + """ + roles = [] + features_dir = get_features_dir() + if not os.path.exists(features_dir): + os.makedirs(features_dir) # 确保目录存在 + return [] + + for role_id in os.listdir(features_dir): + role_path = os.path.join(features_dir, role_id) + features_file = get_role_features_file_path(role_id) + if os.path.isdir(role_path) and os.path.exists(features_file): + features = _load_json_file(features_file) + if features: + roles.append({"id": role_id, "name": features.get("角色名称", role_id)}) + else: + logger.warning( + f"无法加载角色 {role_id} 的 features.json,可能文件损坏。" + ) + else: + logger.warning(f"目录 {role_path} 或文件 {features_file} 不完整,跳过。") + return roles + + +def create_role(role_id, role_name=None): + """ + 创建一个新的角色。 + :param role_id: 新角色的唯一ID。 + :param role_name: 新角色的名称,如果未提供,则默认为role_id。 + :return: True如果成功,False如果失败或ID已存在。 + """ + role_path = os.path.join(get_features_dir(), role_id) + if os.path.exists(role_path): + logger.warning(f"角色ID '{role_id}' 已存在,无法创建。") + return False + + os.makedirs(role_path) + features_file = get_role_features_file_path(role_id) + + default_content = DEFAULT_FEATURE_CONTENT.copy() + default_content["角色名称"] = role_name if role_name else role_id + + success = _save_json_file(features_file, default_content) + if success: + logger.info(f"角色 '{role_id}' 创建成功。") + return success + + +def delete_role(role_id): + """ + 删除一个角色及其所有相关文件。 + :param role_id: 要删除的角色ID。 + :return: True如果成功,False如果失败。 + """ + role_path = os.path.join(get_features_dir(), role_id) + if not os.path.exists(role_path): + logger.warning(f"角色ID '{role_id}' 不存在,无法删除。") + return False + + try: + shutil.rmtree(role_path) + logger.info(f"角色 '{role_id}' 已成功删除。") + return True + except Exception as e: + logger.error(f"删除角色 '{role_id}' 失败: {e}") + return False + + +def role_exists(role_id): + """ + 检查指定角色ID的角色是否存在。 + :param role_id: 角色ID。 + :return: True如果存在,False如果不存在。 + """ + return os.path.exists(get_role_features_file_path(role_id)) + + +def list_memories(): + """ + 列出所有可用的记忆ID和名称。 + 记忆名称从其memory.json中获取(如果模型能生成name字段,否则使用ID)。 + :return: 列表,每个元素为 {"id": "memory_id", "name": "记忆名称"} + """ + memories = [] + memories_dir = get_memories_dir() + if not os.path.exists(memories_dir): + os.makedirs(memories_dir) # 确保目录存在 + return [] + + for memory_id in os.listdir(memories_dir): + memory_path = os.path.join(memories_dir, memory_id) + memory_file = get_memory_data_file_path(memory_id) + if os.path.isdir(memory_path) and os.path.exists(memory_file): + memory_data = _load_json_file(memory_file) + # 确保 memory_data 不是 None 且不是空字典 + if memory_data is not None and memory_data != {}: + # 假设 memory.json 中可能有一个 'name' 字段来标识记忆集 + memories.append( + {"id": memory_id, "name": memory_data.get("name", memory_id)} + ) + else: + logger.warning( + f"无法加载记忆 {memory_id} 的 memory.json,可能文件损或为空,或内容为空。" + ) + else: + logger.warning(f"目录 {memory_path} 或文件 {memory_file} 不完整,跳过。") + return memories + + +def create_memory(memory_id, memory_name=None): + """ + 创建一个新的记忆集。 + :param memory_id: 新记忆集的唯一ID。 + :param memory_name: 新记忆集的名称,如果未提供,则默认为memory_id。 + :return: True如果成功,False如果失败或ID已存在。 + """ + memory_path = os.path.join(get_memories_dir(), memory_id) + if os.path.exists(memory_path): + logger.warning(f"记忆ID '{memory_id}' 已存在,无法创建。") + return False + + os.makedirs(memory_path) + memory_file = get_memory_data_file_path(memory_id) + + default_content = DEFAULT_MEMORY_CONTENT.copy() + if memory_name: + default_content["name"] = memory_name # 记忆集本身可以有一个名称 + # 更新 long_term_facts 中的默认描述,包含记忆名称 + if default_content["long_term_facts"]: + default_content["long_term_facts"][ + 0 + ] = f"此记忆集名为 '{memory_name}',目前未包含特定用户或AI的长期事实。" + else: + default_content["long_term_facts"].append( + f"此记忆集名为 '{memory_name}',目前未包含特定用户或AI的长期事实。" + ) + + success = _save_json_file(memory_file, default_content) + if success: + logger.info(f"记忆集 '{memory_id}' 创建成功。") + return success + + +def delete_memory(memory_id): + """ + 删除一个记忆集及其所有相关文件。 + :param memory_id: 要删除的记忆ID。 + :return: True如果成功,False如果失败。 + """ + memory_path = os.path.join(get_memories_dir(), memory_id) + if not os.path.exists(memory_path): + logger.warning(f"记忆ID '{memory_id}' 不存在,无法删除。") + return False + + try: + shutil.rmtree(memory_path) + logger.info(f"记忆集 '{memory_id}' 已成功删除。") + return True + except Exception as e: + logger.error(f"删除记忆集 '{memory_id}' 失败: {e}") + return False + + +def memory_exists(memory_id): + """ + 检查指定记忆ID的记忆集是否存在。 + :param memory_id: 记忆ID。 + :return: True如果存在,False如果不存在。 + """ + return os.path.exists(get_memory_data_file_path(memory_id)) + + +# --- 加载/保存当前活动会话的文件 --- + + +def load_active_features(role_id): + """加载当前激活角色的特征数据""" + file_path = get_role_features_file_path(role_id) + return _load_json_file(file_path, DEFAULT_FEATURE_CONTENT.copy()) + + +def save_active_features(role_id, features_data): + """保存当前激活角色的特征数据""" + file_path = get_role_features_file_path(role_id) + return _save_json_file(file_path, features_data) + + +def load_active_memory(memory_id): + """加载当前激活记忆集的记忆数据""" + file_path = get_memory_data_file_path(memory_id) + return _load_json_file(file_path, DEFAULT_MEMORY_CONTENT.copy()) + + +def save_active_memory(memory_id, memory_data): + """保存当前激活记忆集的记忆数据""" + file_path = get_memory_data_file_path(memory_id) + return _save_json_file(file_path, memory_data) + + +def append_chat_log(chat_log_id, entry): + """向当前激活聊天记录文件追加条目""" + file_path = get_chat_log_file_path(chat_log_id) + return _append_jsonl_file(file_path, entry) + + +def read_chat_log(chat_log_id, limit=None): + """读取当前激活聊天记录的所有条目""" + file_path = get_chat_log_file_path(chat_log_id) + return _read_jsonl_file(file_path, limit) + + +def delete_chat_log(chat_log_id): + """删除指定聊天记录文件""" + file_path = get_chat_log_file_path(chat_log_id) + if os.path.exists(file_path): + try: + os.remove(file_path) + logger.info(f"聊天记录文件 '{file_path}' 已删除。") + return True + except OSError as e: + logger.error(f"删除聊天记录文件 '{file_path}' 失败: {e}") + return False + logger.warning(f"聊天记录文件 '{file_path}' 不存在,无需删除。") + return False + + +# --- 确保默认角色和记忆存在 --- + + +def ensure_initial_data(): + """ + 确保 'default_role' 和 'default_memory' 存在,如果它们不存在。 + 这个函数会在应用启动时调用。 + """ + config_data = get_config()["Session"] + default_role_id = config_data["CURRENT_ROLE_ID"] + default_memory_id = config_data["CURRENT_MEMORY_ID"] + + if not os.path.exists(get_role_features_file_path(default_role_id)): + logger.info(f"默认角色 '{default_role_id}' 不存在,正在创建。") + create_role(default_role_id, config_data.get("DEFAULT_ROLE_NAME", "通用AI助手")) + + if not os.path.exists(get_memory_data_file_path(default_memory_id)): + logger.info(f"默认记忆集 '{default_memory_id}' 不存在,正在创建。") + create_memory( + default_memory_id, config_data.get("DEFAULT_MEMORY_NAME", "通用记忆集") + ) diff --git a/gemini_client.py b/gemini_client.py new file mode 100644 index 0000000..9428908 --- /dev/null +++ b/gemini_client.py @@ -0,0 +1,494 @@ +# gemini_client.py + +import requests +import json +import logging +from config import get_config # 导入config模块以获取API配置 +from typing import Union, Tuple, Any, Dict # 新增导入 + +logger = logging.getLogger(__name__) + + +class GeminiClient: + """ + 封装与Google Gemini API的交互。 + 支持直接HTTP请求,以便于使用第三方代理或自定义端点。 + """ + + def __init__(self): + self.config = get_config() + base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/") + # 尝试从配置的URL中提取正确的base_url,确保它以 /v1beta 结尾且不包含 /models + if base_url_from_config.endswith("/v1beta"): + self.base_url = base_url_from_config + elif base_url_from_config.endswith("/models"): + # 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta + self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta" + elif base_url_from_config.endswith("/v1"): # 兼容旧的 /v1 路径 + self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta" + else: + # 假设是根URL,添加 /v1beta + self.base_url = f"{base_url_from_config}/v1beta" + + self.api_key = self.config["API"].get("GEMINI_API_KEY") + # 确保API Key是字符串并清理可能的空白字符 + if self.api_key: + self.api_key = str(self.api_key).strip() + + # 设置请求头,包含Content-Type和x-goog-api-key + if self.api_key: + self.headers = { + "Content-Type": "application/json", + "x-goog-api-key": self.api_key, + } + else: + self.headers = {"Content-Type": "application/json"} + + if not self.api_key: + raise ValueError("GEMINI_API_KEY 未配置或为空,Gemini客户端无法初始化。") + + logger.info(f"GeminiClient 初始化完成,基础URL: {self.base_url}") + + def _make_request( + self, endpoint: str, model_name: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + 内部方法:执行实际的HTTP POST请求到Gemini API。 + :param endpoint: API端点,例如 "generateContent" 或 "countTokens"。 + :param model_name: 要使用的模型名称,例如 "gemini-pro"。 + :param payload: 请求体(Python字典),将转换为JSON发送。 + :return: 响应JSON(Python字典)。 + :raises requests.exceptions.RequestException: 如果请求失败。 + """ + # 检查模型名称是否已包含"models/"前缀,避免重复 + if model_name.startswith("models/"): + # 如果已包含前缀,直接使用 + url = f"{self.base_url}/{model_name}:{endpoint}" + else: + # 如果不包含前缀,添加"models/" + url = f"{self.base_url}/models/{model_name}:{endpoint}" + + # 记录请求信息,但隐藏API Key + safe_url = url + logger.debug(f"发送请求到 Gemini API: {safe_url}") + logger.debug( + f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}" + ) + + try: + response = requests.post( + url, + headers=self.headers, # 使用初始化时设置的headers + json=payload, + timeout=90, + ) # 增加超时时间,应对大模型响应慢 + + # 先获取响应文本,避免多次访问响应体 + response_text = response.text + + # 然后检查状态码 + response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError + + # 使用已获取的文本解析JSON,而不是调用response.json() + response_data = json.loads(response_text) + return response_data + except requests.exceptions.HTTPError as http_err: + status_code = http_err.response.status_code if http_err.response else 500 + error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}" + logger.error(error_message) + return {"error": error_message, "status_code": status_code} + except requests.exceptions.ConnectionError as conn_err: + error_message = f"连接错误: {conn_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 503} # Service Unavailable + except requests.exceptions.Timeout as timeout_err: + error_message = f"请求超时: {timeout_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 408} # Request Timeout + except requests.exceptions.RequestException as req_err: + error_message = f"未知请求错误: {req_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} # Internal Server Error + except Exception as e: + error_message = f"非请求库异常: {e}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} + + def _make_stream_request( + self, endpoint: str, model_name: str, payload: Dict[str, Any] + ) -> Union[Any, Dict[str, Any]]: # 返回类型调整为Dict[str, Any]以统一错误返回 + """ + 内部方法:执行实际的HTTP POST流式请求到Gemini API。 + :param endpoint: API端点,例如 "generateContent"。 + :param model_name: 要使用的模型名称,例如 "gemini-pro"。 + :param payload: 请求体(Python字典),将转换为JSON发送。 + :return: 响应迭代器或错误信息字典。 + """ + # 检查模型名称是否已包含"models/"前缀,避免重复 + if model_name.startswith("models/"): + # 如果已包含前缀,直接使用 + url = f"{self.base_url}/{model_name}:{endpoint}?stream=true" + else: + # 如果不包含前缀,添加"models/" + url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true" + + # 记录请求信息,但隐藏API Key + safe_url = url + logger.debug(f"发送流式请求到 Gemini API: {safe_url}") + logger.debug( + f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}" + ) + + try: + response = requests.post( + url, + headers=self.headers, # 使用初始化时设置的headers + json=payload, + stream=True, + timeout=90, + ) + + # 检查状态码但不消费响应体 + response.raise_for_status() + + # 对于流式请求,直接返回解析后的JSON对象迭代器 + # 这不会立即消费响应体,而是在迭代时逐步消费 + def parse_stream(): + logger.info("开始处理流式响应...") + full_response = "" + accumulated_json = "" + + for line in response.iter_lines(): + if not line: + continue + + try: + # 将二进制数据解码为UTF-8字符串 + decoded_line = line.decode("utf-8") + logger.debug(f"原始行数据: {decoded_line[:50]}...") + + # 处理SSE格式 (data: 开头) + if decoded_line.startswith("data: "): + json_str = decoded_line[6:].strip() + if not json_str: # 跳过空数据行 + continue + + try: + chunk_data = json.loads(json_str) + logger.debug( + f"SSE格式JSON数据,键: {list(chunk_data.keys())}" + ) + + # 检查是否包含文本内容 + if ( + "candidates" in chunk_data + and chunk_data["candidates"] + and "content" in chunk_data["candidates"][0] + ): + content = chunk_data["candidates"][0]["content"] + + if "parts" in content and content["parts"]: + text_part = content["parts"][0].get("text", "") + if text_part: + logger.debug( + f"提取到文本: {text_part[:30]}..." + ) + full_response += text_part + yield text_part + except json.JSONDecodeError as e: + logger.warning( + f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..." + ) + else: + # 处理非SSE格式 + # 尝试积累JSON直到有效 + accumulated_json += decoded_line + try: + # 尝试解析完整JSON + chunk_data = json.loads(accumulated_json) + logger.debug( + f"解析完整JSON成功,键: {list(chunk_data.keys())}" + ) + + # 检查是否包含文本内容 + if ( + "candidates" in chunk_data + and chunk_data["candidates"] + and "content" in chunk_data["candidates"][0] + ): + content = chunk_data["candidates"][0]["content"] + + if "parts" in content and content["parts"]: + text_part = content["parts"][0].get("text", "") + if text_part: + logger.debug( + f"从完整JSON提取到文本: {text_part[:30]}..." + ) + full_response += text_part + yield text_part + + # 成功解析后重置累积的JSON + accumulated_json = "" + except json.JSONDecodeError: + # 继续积累直到有效JSON + pass + except Exception as e: + logger.error( + f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}" + ) + + # 如果有积累的响应但还没有产生任何输出,尝试直接返回 + if full_response: + logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应") + else: + logger.warning("流式响应完成,但没有提取到任何有效文本内容") + # 尝试返回一个默认响应,避免空响应 + yield "抱歉,处理响应时遇到了问题,请重试。" + + return parse_stream() + except requests.exceptions.HTTPError as http_err: + status_code = http_err.response.status_code if http_err.response else 500 + error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}" + logger.error(error_message) + return {"error": error_message, "status_code": status_code} + except requests.exceptions.ConnectionError as conn_err: + error_message = f"连接错误: {conn_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 503} + except requests.exceptions.Timeout as timeout_err: + error_message = f"请求超时: {timeout_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 408} + except requests.exceptions.RequestException as req_err: + error_message = f"未知请求错误: {req_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} + except Exception as e: + error_message = f"非请求库异常: {e}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} + + def _make_get_request(self, endpoint: str) -> Dict[str, Any]: + """ + 内部方法:执行实际的HTTP GET请求到Gemini API。 + :param endpoint: API端点,例如 "models"。 + :return: 响应JSON(Python字典)。 + :raises requests.exceptions.RequestException: 如果请求失败。 + """ + if not self.api_key: + return { + "error": "API Key 未配置,无法发送请求。", + "status_code": 401, + } # Unauthorized + + # API Key 通过请求头传递,URL中不再包含 + url = f"{self.base_url}/{endpoint}" + + # 记录请求信息,但隐藏API Key + safe_url = url + logger.debug(f"发送GET请求到 Gemini API: {safe_url}") + + try: + response = requests.get( + url, headers=self.headers, timeout=30 # 使用初始化时设置的headers + ) + + # 先获取响应文本,避免多次访问响应体 + response_text = response.text + + # 然后检查状态码 + response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError + + # 使用已获取的文本解析JSON,而不是调用response.json() + response_data = json.loads(response_text) + return response_data + except requests.exceptions.HTTPError as http_err: + status_code = http_err.response.status_code if http_err.response else 500 + error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}" + logger.error(error_message) + return {"error": error_message, "status_code": status_code} + except requests.exceptions.ConnectionError as conn_err: + error_message = f"连接错误: {conn_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 503} # Service Unavailable + except requests.exceptions.Timeout as timeout_err: + error_message = f"请求超时: {timeout_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 408} # Request Timeout + except requests.exceptions.RequestException as req_err: + error_message = f"未知请求错误: {req_err} - URL: {url}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} # Internal Server Error + except Exception as e: + error_message = f"非请求库异常: {e}" + logger.error(error_message) + return {"error": error_message, "status_code": 500} + + def generate_content( + self, + model_name: str, + contents: list[Dict[str, Any]], + system_instruction: str | None = None, + stream: bool = False, # 新增 stream 参数 + ) -> Union[str, Tuple[str, int], Any]: # 返回类型调整以适应流式响应 + """ + 与Gemini模型进行内容生成对话。 + :param model_name: 要使用的模型名称。 + :param contents: 聊天历史和当前用户消息组成的列表,符合Gemini API的"contents"格式。 + 例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...] + :param system_instruction: 系统的指示语,作为单独的参数传入。 + 在Gemini API中,它是请求体的一个顶级字段。 + :param stream: 是否以流式方式获取响应。 + :return: 如果 stream 为 False,返回模型回复文本或错误信息。 + 如果 stream 为 True,返回一个生成器,每次 yield 一个文本片段。 + """ + # 根据Gemini API的预期格式构造payload + # 如果contents是空的,我们构造一个带有role的内容数组 + if not contents: + if system_instruction: + payload: Dict[str, Any] = { + "contents": [ + {"role": "user", "parts": [{"text": system_instruction}]} + ] + } + else: + payload: Dict[str, Any] = { + "contents": [{"role": "user", "parts": [{"text": ""}]}] + } + else: + # 使用提供的contents + payload: Dict[str, Any] = {"contents": contents} + + # 如果有系统指令,我们添加到请求中 + if system_instruction: + # Gemini API的systemInstruction是一个顶级字段,值是一个Content对象 + payload["systemInstruction"] = {"parts": [{"text": system_instruction}]} + + if stream: + response_iterator = self._make_stream_request( + "generateContent", model_name, payload + ) + if isinstance(response_iterator, dict) and "error" in response_iterator: + # 如果流式请求本身返回错误,则直接返回错误信息 + return str(response_iterator["error"]), int( + response_iterator.get("status_code", 500) + ) + + def stream_generator(): + for text_chunk in response_iterator: + # _make_stream_request 已经处理了JSON解析并提取了文本 + # 所以这里直接yield文本块即可 + if text_chunk: + yield text_chunk + logger.debug("流式响应结束。") + + return stream_generator() + else: + response_json = self._make_request("generateContent", model_name, payload) + + if "error" in response_json: + # 如果 _make_request 返回了错误,直接返回错误信息和状态码 + return str(response_json["error"]), int( + response_json.get("status_code", 500) + ) + + if "candidates" in response_json and len(response_json["candidates"]) > 0: + first_candidate = response_json["candidates"][0] + if ( + "content" in first_candidate + and "parts" in first_candidate["content"] + and len(first_candidate["content"]["parts"]) > 0 + ): + # 返回第一个part的文本内容 + return first_candidate["content"]["parts"][0]["text"] + + # 如果响应中包含拦截信息 + if ( + "promptFeedback" in response_json + and "blockReason" in response_json["promptFeedback"] + ): + reason = response_json["promptFeedback"]["blockReason"] + logger.warning(f"Gemini API 阻止了响应,原因: {reason}") + return ( + f"抱歉,你的请求被模型策略阻止了。原因: {reason}", + 400, + ) # 400 Bad Request for policy block + + # 如果没有candidates且没有拦截信息,可能存在未知响应结构 + logger.warning( + f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}" + ) + return "抱歉,AI未能生成有效回复,请稍后再试或检查日志。", 500 + + def count_tokens( + self, + model_name: str, + contents: list[Dict[str, Any]], + system_instruction: str | None = None, + ) -> int: + """ + 统计给定内容在特定模型中的token数量。 + :param model_name: 要使用的模型名称。 + :param contents: 要计算token的内容。 + :param system_instruction: 系统指令。 + :return: token数量(整数)或-1(如果发生错误)。 + """ + # 根据Gemini API的预期格式构造payload,确保包含role字段 + if not contents: + if system_instruction: + payload: Dict[str, Any] = { + "contents": [ + {"role": "user", "parts": [{"text": system_instruction}]} + ] + } + else: + payload: Dict[str, Any] = { + "contents": [{"role": "user", "parts": [{"text": ""}]}] + } + else: + payload: Dict[str, Any] = {"contents": contents} # 明确声明payload类型 + if system_instruction: + payload["systemInstruction"] = {"parts": [{"text": system_instruction}]} + + response_json = self._make_request("countTokens", model_name, payload) + + if "error" in response_json: + return -1 # 发生错误时返回-1 + + if "totalTokens" in response_json: + return response_json["totalTokens"] + logger.warning( + f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}" + ) + return -1 + + def get_models(self) -> Dict[str, Any]: + """ + 获取Gemini API可用的模型列表。 + :return: 包含模型名称的列表,如果发生错误则为空列表。 + """ + logger.info("从 Gemini API 获取模型列表。") + response_json = self._make_get_request("models") + + if "error" in response_json: + logger.error(f"获取模型列表失败: {response_json['error']}") + return { + "models": [], + "error": response_json["error"], + "status_code": response_json["status_code"], + } + + available_models = [] + for m in response_json.get("models", []): + # 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型 + # 'SUPPORTED' 状态表示模型可用 + if ( + "name" in m + and "supportedGenerationMethods" in m + and "generateContent" in m["supportedGenerationMethods"] + and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE" + ): # 确保模型是活跃的 + model_name = m["name"].split("/")[-1] # 提取 'gemini-pro' 部分 + available_models.append(model_name) + logger.info(f"成功获取到 {len(available_models)} 个可用模型。") + return {"models": available_models}