上传文件至 /

This commit is contained in:
2025-06-06 16:41:50 +08:00
commit ffdeafa791
5 changed files with 2146 additions and 0 deletions

1032
app.py Normal file

File diff suppressed because it is too large Load Diff

20
config.ini Normal file
View File

@ -0,0 +1,20 @@
[API]
gemini_api_base_url = https://g.shatang.me/v1beta
gemini_api_key =
default_gemini_model = models/gemini-2.0-flash
memory_update_model = models/gemini-2.0-flash
[Application]
context_window_size = 5
memory_retention_turns = 10
max_short_term_events = 15
features_dir = data/features
memories_dir = data/memories
chat_logs_dir = data/chat_logs
app_log_file_path = logs/app.log
[Session]
current_role_id = default_role
current_memory_id = default_memory
current_chat_log_id = default_role_default_memory_ce76c8a8-4072-44f4-8f1b-d942166b62e1

168
config.py Normal file
View File

@ -0,0 +1,168 @@
# config.py
import configparser
import os
import logging
# 获取根日志器确保日志配置已在utils.py中完成
logger = logging.getLogger(__name__)
CONFIG_FILE = "config.ini"
# 定义默认配置如果config.ini不存在或缺失将使用这些值
DEFAULT_CONFIG = {
"API": {
"GEMINI_API_BASE_URL": "https://generativelanguage.googleapis.com/v1beta",
"GEMINI_API_KEY": "YOUR_GEMINI_API_KEY_HERE", # 默认值,强烈建议用户手动修改
"DEFAULT_GEMINI_MODEL": "gemini-pro",
"MEMORY_UPDATE_MODEL": "gemini-pro",
},
"Application": {
"CONTEXT_WINDOW_SIZE": "6",
"MEMORY_RETENTION_TURNS": "18",
"MAX_SHORT_TERM_EVENTS": "20",
"FEATURES_DIR": "data/features",
"MEMORIES_DIR": "data/memories",
"CHAT_LOGS_DIR": "data/chat_logs",
"APP_LOG_FILE_PATH": "logs/app.log",
},
"Session": {
"CURRENT_ROLE_ID": "default_role",
"CURRENT_MEMORY_ID": "default_memory",
},
}
def load_config():
"""
从 config.ini 文件加载配置。
如果文件不存在或某些配置项缺失,则使用默认值并尝试创建/更新文件。
:return: 包含所有配置的字典。
"""
config = configparser.ConfigParser()
# 检查config.ini是否存在如果不存在则创建并写入默认配置
if not os.path.exists(CONFIG_FILE):
logger.info(f"config.ini 文件不存在,正在创建默认配置文件:{CONFIG_FILE}")
for section, options in DEFAULT_CONFIG.items():
config.add_section(section)
for key, value in options.items():
config.set(section, key, value)
try:
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
config.write(f)
logger.info("默认 config.ini 文件创建成功。")
except IOError as e:
logger.error(f"无法写入 config.ini 文件: {e}")
# 即使写入失败,也继续使用内存中的默认配置
else:
logger.info(f"{CONFIG_FILE} 加载配置...")
# 读取配置
try:
config.read(CONFIG_FILE, encoding="utf-8")
except Exception as e:
logger.error(f"读取 config.ini 文件时发生错误: {e},将使用默认配置。")
# 如果读取失败,清空 configparser 对象,让后续逻辑使用默认值
config = configparser.ConfigParser()
# 将配置转换为字典格式,并确保所有默认值都存在
app_config = {}
for section, options in DEFAULT_CONFIG.items():
app_config[section] = {}
for key, default_value in options.items():
# 使用 get() 方法获取值,如果不存在则使用默认值
# 对于整型或布尔型配置,需要手动转换
value = config.get(section, key, fallback=default_value)
# 尝试将特定配置项转换为数字类型
if section == "Application" and key in [
"CONTEXT_WINDOW_SIZE",
"MEMORY_RETENTION_TURNS",
"MAX_SHORT_TERM_EVENTS",
]:
try:
app_config[section][key] = int(value)
except ValueError:
logger.warning(
f"配置项 {section}.{key} 的值 '{value}' 不是有效数字,使用默认值 {default_value}"
)
app_config[section][key] = int(default_value)
elif section == "API" and key == "GEMINI_API_KEY":
# 优先从环境变量读取 GEMINI_API_KEY
env_key = os.getenv("GEMINI_API_KEY")
if env_key:
app_config[section][key] = env_key
logger.info("已从环境变量加载 GEMINI_API_KEY。")
else:
app_config[section][key] = value
else:
app_config[section][key] = value
# 如果存在config.ini但缺少某些默认值更新它
try:
updated = False
for section, options in DEFAULT_CONFIG.items():
if not config.has_section(section):
config.add_section(section)
updated = True
for key, default_value in options.items():
if not config.has_option(section, key):
config.set(section, key, str(default_value)) # 确保写入的是字符串
updated = True
if updated:
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
config.write(f)
logger.info("config.ini 文件已更新,补齐缺失配置项。")
except IOError as e:
logger.error(f"无法更新 config.ini 文件以补齐缺失配置项: {e}")
logger.info("配置加载完成。")
return app_config
def save_config(new_config):
"""
将给定的配置字典保存到 config.ini 文件。
它会读取现有配置,然后只更新 new_config 中提供的部分。
:param new_config: 包含要保存的所有配置的字典。
"""
config = configparser.ConfigParser()
# 首先读取现有配置,以保留未在 new_config 中提供的项
if os.path.exists(CONFIG_FILE):
try:
config.read(CONFIG_FILE, encoding="utf-8")
except Exception as e:
logger.error(f"读取现有 config.ini 文件时发生错误: {e},将只保存新配置。")
for section, options in new_config.items():
if not config.has_section(section):
config.add_section(section)
for key, value in options.items():
# 确保将值转换为字符串,因为 configparser 存储的是字符串
config.set(section, key, str(value))
try:
with open(CONFIG_FILE, "w", encoding="utf-8") as f:
config.write(f)
logger.info(f"配置已成功保存到 {CONFIG_FILE}")
except IOError as e:
logger.error(f"保存配置到 {CONFIG_FILE} 失败: {e}")
# 全局变量,存储当前加载的配置
app_config = load_config()
def get_config():
"""获取当前加载的配置"""
return app_config
def set_config(new_config):
"""更新内存中的配置并保存到文件"""
global app_config
app_config = new_config
save_config(app_config)
logger.info("内存配置已更新并保存。")

432
file_manager.py Normal file
View File

@ -0,0 +1,432 @@
# file_manager.py
import json
import os
import uuid
import shutil # 用于删除目录及其内容
import logging
from config import get_config # 导入config模块的获取配置函数
logger = logging.getLogger(__name__)
# --- 辅助函数JSON 和 JSONL 文件操作 ---
def _load_json_file(file_path, default_content=None):
"""
加载JSON文件。如果文件不存在则创建并写入默认内容如果提供
:param file_path: JSON文件路径。
:param default_content: 文件不存在时要写入的默认Python对象。
:return: 文件内容Python对象或None如果加载失败
"""
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
if default_content is not None:
logger.info(f"文件不存在,创建并写入默认内容到:{file_path}")
_save_json_file(file_path, default_content)
else:
logger.warning(f"尝试加载的JSON文件不存在且未提供默认内容{file_path}")
return {} # 返回空字典,以提高健壮性
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
logger.error(f"JSON文件解码失败{file_path} - {e}")
return None
except IOError as e:
logger.error(f"读取JSON文件失败{file_path} - {e}")
return None
def _save_json_file(file_path, data):
"""
保存Python对象到JSON文件。
:param file_path: JSON文件路径。
:param data: 要保存的Python对象。
"""
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
try:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.debug(f"JSON文件已保存{file_path}")
return True
except IOError as e:
logger.error(f"保存JSON文件失败{file_path} - {e}")
return False
def _append_jsonl_file(file_path, data):
"""
追加一条JSON记录到JSONL文件。
:param file_path: JSONL文件路径。
:param data: 要追加的Python对象。
"""
os.makedirs(os.path.dirname(file_path), exist_ok=True) # 确保目录存在
try:
with open(file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(data, ensure_ascii=False) + "\n")
logger.debug(f"JSONL记录已追加到{file_path}")
return True
except IOError as e:
logger.error(f"追加JSONL文件失败{file_path} - {e}")
return False
def _read_jsonl_file(file_path, limit=None):
"""
读取JSONL文件的所有记录。
:param file_path: JSONL文件路径。
:param limit: 可选限制读取的最后N条记录。
:return: 包含所有记录的列表。
"""
if not os.path.exists(file_path):
return []
records = []
try:
with open(file_path, "r", encoding="utf-8") as f:
lines = f.readlines()
if limit:
lines = lines[-limit:] # 只读取最后N行
for line in lines:
try:
records.append(json.loads(line.strip()))
except json.JSONDecodeError as e:
logger.warning(
f"跳过JSONL文件中的损坏行{file_path} - {line.strip()} - {e}"
)
return records
except IOError as e:
logger.error(f"读取JSONL文件失败{file_path} - {e}")
return []
# --- 默认文件内容 ---
# 默认特征文件内容
DEFAULT_FEATURE_CONTENT = {
"角色名称": "通用AI助手",
"角色描述": "你是一个有帮助的AI助手旨在提供清晰、准确、友好的回答。你没有预设的特定角色或个性能够灵活适应各种对话主题。",
"心理特征": {
"核心哲学与世界观": {"描述": "以提供信息、解决问题为核心,保持中立客观。"},
"价值观与道德观": {"描述": "遵循伦理原则,避免偏见,确保信息准确。"},
"决策风格": {"描述": "基于逻辑和可用信息进行决策。"},
"情绪反应与表达": {"描述": "保持平静、专业、不带个人情感。"},
"人际互动与关系处理": {"描述": "以帮助用户为目标,保持礼貌和尊重。"},
"动机与目标": {"描述": "提供有价值的信息和解决方案。"},
"自我认知": {"描述": "认知自己是AI不具备人类情感和意识。"},
"应对机制": {"描述": "遇到未知问题时,会请求更多信息或承认能力限制。"},
},
"语言特征": {
"词汇与措辞": {"描述": "清晰、简洁、准确,避免模糊不清的表达。"},
"句式结构与复杂度": {"描述": "常用清晰直接的句式,逻辑性强。"},
"语气与风格": {"描述": "专业、友善、乐于助人。"},
"修辞手法与模式": {"描述": "多使用直接陈述、解释说明。"},
"互动模式": {"描述": "响应式,根据用户输入提供信息或引导对话。"},
},
"重要人际关系": [],
"角色弧线总结": "作为一个持续学习和改进的AI不断提升服务质量。",
}
# 默认记忆文件内容
DEFAULT_MEMORY_CONTENT = {
"long_term_facts": ["此记忆集为默认记忆未包含特定用户或AI的长期事实。"],
"short_term_events": [],
}
# --- 路径管理函数 ---
def _get_dir_path(base_dir_config_key):
"""根据配置获取目录路径"""
config = get_config()
return config["Application"][base_dir_config_key]
def get_features_dir():
"""获取特征文件根目录"""
return _get_dir_path("FEATURES_DIR")
def get_memories_dir():
"""获取记忆文件根目录"""
return _get_dir_path("MEMORIES_DIR")
def get_chat_logs_dir():
"""获取聊天记录文件根目录"""
return _get_dir_path("CHAT_LOGS_DIR")
def get_role_features_file_path(role_id):
"""获取指定角色ID的特征文件路径"""
return os.path.join(get_features_dir(), role_id, "features.json")
def get_memory_data_file_path(memory_id):
"""获取指定记忆ID的记忆数据文件路径"""
return os.path.join(get_memories_dir(), memory_id, "memory.json")
def get_chat_log_file_path(chat_log_id):
"""获取指定聊天记录ID的聊天记录文件路径"""
# 聊天记录ID通常与角色ID和记忆ID关联或者直接就是session ID
# 这里简化为直接以chat_log_id为文件名
return os.path.join(get_chat_logs_dir(), f"{chat_log_id}.jsonl")
# --- 角色和记忆管理 ---
def list_roles():
"""
列出所有可用的角色ID和名称。
角色名称从其features.json中获取。
:return: 列表,每个元素为 {"id": "role_id", "name": "角色名称"}
"""
roles = []
features_dir = get_features_dir()
if not os.path.exists(features_dir):
os.makedirs(features_dir) # 确保目录存在
return []
for role_id in os.listdir(features_dir):
role_path = os.path.join(features_dir, role_id)
features_file = get_role_features_file_path(role_id)
if os.path.isdir(role_path) and os.path.exists(features_file):
features = _load_json_file(features_file)
if features:
roles.append({"id": role_id, "name": features.get("角色名称", role_id)})
else:
logger.warning(
f"无法加载角色 {role_id} 的 features.json可能文件损坏。"
)
else:
logger.warning(f"目录 {role_path} 或文件 {features_file} 不完整,跳过。")
return roles
def create_role(role_id, role_name=None):
"""
创建一个新的角色。
:param role_id: 新角色的唯一ID。
:param role_name: 新角色的名称如果未提供则默认为role_id。
:return: True如果成功False如果失败或ID已存在。
"""
role_path = os.path.join(get_features_dir(), role_id)
if os.path.exists(role_path):
logger.warning(f"角色ID '{role_id}' 已存在,无法创建。")
return False
os.makedirs(role_path)
features_file = get_role_features_file_path(role_id)
default_content = DEFAULT_FEATURE_CONTENT.copy()
default_content["角色名称"] = role_name if role_name else role_id
success = _save_json_file(features_file, default_content)
if success:
logger.info(f"角色 '{role_id}' 创建成功。")
return success
def delete_role(role_id):
"""
删除一个角色及其所有相关文件。
:param role_id: 要删除的角色ID。
:return: True如果成功False如果失败。
"""
role_path = os.path.join(get_features_dir(), role_id)
if not os.path.exists(role_path):
logger.warning(f"角色ID '{role_id}' 不存在,无法删除。")
return False
try:
shutil.rmtree(role_path)
logger.info(f"角色 '{role_id}' 已成功删除。")
return True
except Exception as e:
logger.error(f"删除角色 '{role_id}' 失败: {e}")
return False
def role_exists(role_id):
"""
检查指定角色ID的角色是否存在。
:param role_id: 角色ID。
:return: True如果存在False如果不存在。
"""
return os.path.exists(get_role_features_file_path(role_id))
def list_memories():
"""
列出所有可用的记忆ID和名称。
记忆名称从其memory.json中获取如果模型能生成name字段否则使用ID
:return: 列表,每个元素为 {"id": "memory_id", "name": "记忆名称"}
"""
memories = []
memories_dir = get_memories_dir()
if not os.path.exists(memories_dir):
os.makedirs(memories_dir) # 确保目录存在
return []
for memory_id in os.listdir(memories_dir):
memory_path = os.path.join(memories_dir, memory_id)
memory_file = get_memory_data_file_path(memory_id)
if os.path.isdir(memory_path) and os.path.exists(memory_file):
memory_data = _load_json_file(memory_file)
# 确保 memory_data 不是 None 且不是空字典
if memory_data is not None and memory_data != {}:
# 假设 memory.json 中可能有一个 'name' 字段来标识记忆集
memories.append(
{"id": memory_id, "name": memory_data.get("name", memory_id)}
)
else:
logger.warning(
f"无法加载记忆 {memory_id} 的 memory.json可能文件损或为空或内容为空。"
)
else:
logger.warning(f"目录 {memory_path} 或文件 {memory_file} 不完整,跳过。")
return memories
def create_memory(memory_id, memory_name=None):
"""
创建一个新的记忆集。
:param memory_id: 新记忆集的唯一ID。
:param memory_name: 新记忆集的名称如果未提供则默认为memory_id。
:return: True如果成功False如果失败或ID已存在。
"""
memory_path = os.path.join(get_memories_dir(), memory_id)
if os.path.exists(memory_path):
logger.warning(f"记忆ID '{memory_id}' 已存在,无法创建。")
return False
os.makedirs(memory_path)
memory_file = get_memory_data_file_path(memory_id)
default_content = DEFAULT_MEMORY_CONTENT.copy()
if memory_name:
default_content["name"] = memory_name # 记忆集本身可以有一个名称
# 更新 long_term_facts 中的默认描述,包含记忆名称
if default_content["long_term_facts"]:
default_content["long_term_facts"][
0
] = f"此记忆集名为 '{memory_name}'目前未包含特定用户或AI的长期事实。"
else:
default_content["long_term_facts"].append(
f"此记忆集名为 '{memory_name}'目前未包含特定用户或AI的长期事实。"
)
success = _save_json_file(memory_file, default_content)
if success:
logger.info(f"记忆集 '{memory_id}' 创建成功。")
return success
def delete_memory(memory_id):
"""
删除一个记忆集及其所有相关文件。
:param memory_id: 要删除的记忆ID。
:return: True如果成功False如果失败。
"""
memory_path = os.path.join(get_memories_dir(), memory_id)
if not os.path.exists(memory_path):
logger.warning(f"记忆ID '{memory_id}' 不存在,无法删除。")
return False
try:
shutil.rmtree(memory_path)
logger.info(f"记忆集 '{memory_id}' 已成功删除。")
return True
except Exception as e:
logger.error(f"删除记忆集 '{memory_id}' 失败: {e}")
return False
def memory_exists(memory_id):
"""
检查指定记忆ID的记忆集是否存在。
:param memory_id: 记忆ID。
:return: True如果存在False如果不存在。
"""
return os.path.exists(get_memory_data_file_path(memory_id))
# --- 加载/保存当前活动会话的文件 ---
def load_active_features(role_id):
"""加载当前激活角色的特征数据"""
file_path = get_role_features_file_path(role_id)
return _load_json_file(file_path, DEFAULT_FEATURE_CONTENT.copy())
def save_active_features(role_id, features_data):
"""保存当前激活角色的特征数据"""
file_path = get_role_features_file_path(role_id)
return _save_json_file(file_path, features_data)
def load_active_memory(memory_id):
"""加载当前激活记忆集的记忆数据"""
file_path = get_memory_data_file_path(memory_id)
return _load_json_file(file_path, DEFAULT_MEMORY_CONTENT.copy())
def save_active_memory(memory_id, memory_data):
"""保存当前激活记忆集的记忆数据"""
file_path = get_memory_data_file_path(memory_id)
return _save_json_file(file_path, memory_data)
def append_chat_log(chat_log_id, entry):
"""向当前激活聊天记录文件追加条目"""
file_path = get_chat_log_file_path(chat_log_id)
return _append_jsonl_file(file_path, entry)
def read_chat_log(chat_log_id, limit=None):
"""读取当前激活聊天记录的所有条目"""
file_path = get_chat_log_file_path(chat_log_id)
return _read_jsonl_file(file_path, limit)
def delete_chat_log(chat_log_id):
"""删除指定聊天记录文件"""
file_path = get_chat_log_file_path(chat_log_id)
if os.path.exists(file_path):
try:
os.remove(file_path)
logger.info(f"聊天记录文件 '{file_path}' 已删除。")
return True
except OSError as e:
logger.error(f"删除聊天记录文件 '{file_path}' 失败: {e}")
return False
logger.warning(f"聊天记录文件 '{file_path}' 不存在,无需删除。")
return False
# --- 确保默认角色和记忆存在 ---
def ensure_initial_data():
"""
确保 'default_role''default_memory' 存在,如果它们不存在。
这个函数会在应用启动时调用。
"""
config_data = get_config()["Session"]
default_role_id = config_data["CURRENT_ROLE_ID"]
default_memory_id = config_data["CURRENT_MEMORY_ID"]
if not os.path.exists(get_role_features_file_path(default_role_id)):
logger.info(f"默认角色 '{default_role_id}' 不存在,正在创建。")
create_role(default_role_id, config_data.get("DEFAULT_ROLE_NAME", "通用AI助手"))
if not os.path.exists(get_memory_data_file_path(default_memory_id)):
logger.info(f"默认记忆集 '{default_memory_id}' 不存在,正在创建。")
create_memory(
default_memory_id, config_data.get("DEFAULT_MEMORY_NAME", "通用记忆集")
)

494
gemini_client.py Normal file
View File

@ -0,0 +1,494 @@
# gemini_client.py
import requests
import json
import logging
from config import get_config # 导入config模块以获取API配置
from typing import Union, Tuple, Any, Dict # 新增导入
logger = logging.getLogger(__name__)
class GeminiClient:
"""
封装与Google Gemini API的交互。
支持直接HTTP请求以便于使用第三方代理或自定义端点。
"""
def __init__(self):
self.config = get_config()
base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
# 尝试从配置的URL中提取正确的base_url确保它以 /v1beta 结尾且不包含 /models
if base_url_from_config.endswith("/v1beta"):
self.base_url = base_url_from_config
elif base_url_from_config.endswith("/models"):
# 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta
self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
elif base_url_from_config.endswith("/v1"): # 兼容旧的 /v1 路径
self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
else:
# 假设是根URL添加 /v1beta
self.base_url = f"{base_url_from_config}/v1beta"
self.api_key = self.config["API"].get("GEMINI_API_KEY")
# 确保API Key是字符串并清理可能的空白字符
if self.api_key:
self.api_key = str(self.api_key).strip()
# 设置请求头包含Content-Type和x-goog-api-key
if self.api_key:
self.headers = {
"Content-Type": "application/json",
"x-goog-api-key": self.api_key,
}
else:
self.headers = {"Content-Type": "application/json"}
if not self.api_key:
raise ValueError("GEMINI_API_KEY 未配置或为空Gemini客户端无法初始化。")
logger.info(f"GeminiClient 初始化完成基础URL: {self.base_url}")
def _make_request(
self, endpoint: str, model_name: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
"""
内部方法执行实际的HTTP POST请求到Gemini API。
:param endpoint: API端点例如 "generateContent""countTokens"
:param model_name: 要使用的模型名称,例如 "gemini-pro"
:param payload: 请求体Python字典将转换为JSON发送。
:return: 响应JSONPython字典
:raises requests.exceptions.RequestException: 如果请求失败。
"""
# 检查模型名称是否已包含"models/"前缀,避免重复
if model_name.startswith("models/"):
# 如果已包含前缀,直接使用
url = f"{self.base_url}/{model_name}:{endpoint}"
else:
# 如果不包含前缀,添加"models/"
url = f"{self.base_url}/models/{model_name}:{endpoint}"
# 记录请求信息但隐藏API Key
safe_url = url
logger.debug(f"发送请求到 Gemini API: {safe_url}")
logger.debug(
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
)
try:
response = requests.post(
url,
headers=self.headers, # 使用初始化时设置的headers
json=payload,
timeout=90,
) # 增加超时时间,应对大模型响应慢
# 先获取响应文本,避免多次访问响应体
response_text = response.text
# 然后检查状态码
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
# 使用已获取的文本解析JSON而不是调用response.json()
response_data = json.loads(response_text)
return response_data
except requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else 500
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
logger.error(error_message)
return {"error": error_message, "status_code": status_code}
except requests.exceptions.ConnectionError as conn_err:
error_message = f"连接错误: {conn_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 503} # Service Unavailable
except requests.exceptions.Timeout as timeout_err:
error_message = f"请求超时: {timeout_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 408} # Request Timeout
except requests.exceptions.RequestException as req_err:
error_message = f"未知请求错误: {req_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 500} # Internal Server Error
except Exception as e:
error_message = f"非请求库异常: {e}"
logger.error(error_message)
return {"error": error_message, "status_code": 500}
def _make_stream_request(
self, endpoint: str, model_name: str, payload: Dict[str, Any]
) -> Union[Any, Dict[str, Any]]: # 返回类型调整为Dict[str, Any]以统一错误返回
"""
内部方法执行实际的HTTP POST流式请求到Gemini API。
:param endpoint: API端点例如 "generateContent"
:param model_name: 要使用的模型名称,例如 "gemini-pro"
:param payload: 请求体Python字典将转换为JSON发送。
:return: 响应迭代器或错误信息字典。
"""
# 检查模型名称是否已包含"models/"前缀,避免重复
if model_name.startswith("models/"):
# 如果已包含前缀,直接使用
url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
else:
# 如果不包含前缀,添加"models/"
url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"
# 记录请求信息但隐藏API Key
safe_url = url
logger.debug(f"发送流式请求到 Gemini API: {safe_url}")
logger.debug(
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
)
try:
response = requests.post(
url,
headers=self.headers, # 使用初始化时设置的headers
json=payload,
stream=True,
timeout=90,
)
# 检查状态码但不消费响应体
response.raise_for_status()
# 对于流式请求直接返回解析后的JSON对象迭代器
# 这不会立即消费响应体,而是在迭代时逐步消费
def parse_stream():
logger.info("开始处理流式响应...")
full_response = ""
accumulated_json = ""
for line in response.iter_lines():
if not line:
continue
try:
# 将二进制数据解码为UTF-8字符串
decoded_line = line.decode("utf-8")
logger.debug(f"原始行数据: {decoded_line[:50]}...")
# 处理SSE格式 (data: 开头)
if decoded_line.startswith("data: "):
json_str = decoded_line[6:].strip()
if not json_str: # 跳过空数据行
continue
try:
chunk_data = json.loads(json_str)
logger.debug(
f"SSE格式JSON数据键: {list(chunk_data.keys())}"
)
# 检查是否包含文本内容
if (
"candidates" in chunk_data
and chunk_data["candidates"]
and "content" in chunk_data["candidates"][0]
):
content = chunk_data["candidates"][0]["content"]
if "parts" in content and content["parts"]:
text_part = content["parts"][0].get("text", "")
if text_part:
logger.debug(
f"提取到文本: {text_part[:30]}..."
)
full_response += text_part
yield text_part
except json.JSONDecodeError as e:
logger.warning(
f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..."
)
else:
# 处理非SSE格式
# 尝试积累JSON直到有效
accumulated_json += decoded_line
try:
# 尝试解析完整JSON
chunk_data = json.loads(accumulated_json)
logger.debug(
f"解析完整JSON成功键: {list(chunk_data.keys())}"
)
# 检查是否包含文本内容
if (
"candidates" in chunk_data
and chunk_data["candidates"]
and "content" in chunk_data["candidates"][0]
):
content = chunk_data["candidates"][0]["content"]
if "parts" in content and content["parts"]:
text_part = content["parts"][0].get("text", "")
if text_part:
logger.debug(
f"从完整JSON提取到文本: {text_part[:30]}..."
)
full_response += text_part
yield text_part
# 成功解析后重置累积的JSON
accumulated_json = ""
except json.JSONDecodeError:
# 继续积累直到有效JSON
pass
except Exception as e:
logger.error(
f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}"
)
# 如果有积累的响应但还没有产生任何输出,尝试直接返回
if full_response:
logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应")
else:
logger.warning("流式响应完成,但没有提取到任何有效文本内容")
# 尝试返回一个默认响应,避免空响应
yield "抱歉,处理响应时遇到了问题,请重试。"
return parse_stream()
except requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else 500
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
logger.error(error_message)
return {"error": error_message, "status_code": status_code}
except requests.exceptions.ConnectionError as conn_err:
error_message = f"连接错误: {conn_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 503}
except requests.exceptions.Timeout as timeout_err:
error_message = f"请求超时: {timeout_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 408}
except requests.exceptions.RequestException as req_err:
error_message = f"未知请求错误: {req_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 500}
except Exception as e:
error_message = f"非请求库异常: {e}"
logger.error(error_message)
return {"error": error_message, "status_code": 500}
def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
"""
内部方法执行实际的HTTP GET请求到Gemini API。
:param endpoint: API端点例如 "models"
:return: 响应JSONPython字典
:raises requests.exceptions.RequestException: 如果请求失败。
"""
if not self.api_key:
return {
"error": "API Key 未配置,无法发送请求。",
"status_code": 401,
} # Unauthorized
# API Key 通过请求头传递URL中不再包含
url = f"{self.base_url}/{endpoint}"
# 记录请求信息但隐藏API Key
safe_url = url
logger.debug(f"发送GET请求到 Gemini API: {safe_url}")
try:
response = requests.get(
url, headers=self.headers, timeout=30 # 使用初始化时设置的headers
)
# 先获取响应文本,避免多次访问响应体
response_text = response.text
# 然后检查状态码
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
# 使用已获取的文本解析JSON而不是调用response.json()
response_data = json.loads(response_text)
return response_data
except requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else 500
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
logger.error(error_message)
return {"error": error_message, "status_code": status_code}
except requests.exceptions.ConnectionError as conn_err:
error_message = f"连接错误: {conn_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 503} # Service Unavailable
except requests.exceptions.Timeout as timeout_err:
error_message = f"请求超时: {timeout_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 408} # Request Timeout
except requests.exceptions.RequestException as req_err:
error_message = f"未知请求错误: {req_err} - URL: {url}"
logger.error(error_message)
return {"error": error_message, "status_code": 500} # Internal Server Error
except Exception as e:
error_message = f"非请求库异常: {e}"
logger.error(error_message)
return {"error": error_message, "status_code": 500}
def generate_content(
self,
model_name: str,
contents: list[Dict[str, Any]],
system_instruction: str | None = None,
stream: bool = False, # 新增 stream 参数
) -> Union[str, Tuple[str, int], Any]: # 返回类型调整以适应流式响应
"""
与Gemini模型进行内容生成对话。
:param model_name: 要使用的模型名称。
:param contents: 聊天历史和当前用户消息组成的列表符合Gemini API的"contents"格式。
例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...]
:param system_instruction: 系统的指示语,作为单独的参数传入。
在Gemini API中它是请求体的一个顶级字段。
:param stream: 是否以流式方式获取响应。
:return: 如果 stream 为 False返回模型回复文本或错误信息。
如果 stream 为 True返回一个生成器每次 yield 一个文本片段。
"""
# 根据Gemini API的预期格式构造payload
# 如果contents是空的我们构造一个带有role的内容数组
if not contents:
if system_instruction:
payload: Dict[str, Any] = {
"contents": [
{"role": "user", "parts": [{"text": system_instruction}]}
]
}
else:
payload: Dict[str, Any] = {
"contents": [{"role": "user", "parts": [{"text": ""}]}]
}
else:
# 使用提供的contents
payload: Dict[str, Any] = {"contents": contents}
# 如果有系统指令,我们添加到请求中
if system_instruction:
# Gemini API的systemInstruction是一个顶级字段值是一个Content对象
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
if stream:
response_iterator = self._make_stream_request(
"generateContent", model_name, payload
)
if isinstance(response_iterator, dict) and "error" in response_iterator:
# 如果流式请求本身返回错误,则直接返回错误信息
return str(response_iterator["error"]), int(
response_iterator.get("status_code", 500)
)
def stream_generator():
for text_chunk in response_iterator:
# _make_stream_request 已经处理了JSON解析并提取了文本
# 所以这里直接yield文本块即可
if text_chunk:
yield text_chunk
logger.debug("流式响应结束。")
return stream_generator()
else:
response_json = self._make_request("generateContent", model_name, payload)
if "error" in response_json:
# 如果 _make_request 返回了错误,直接返回错误信息和状态码
return str(response_json["error"]), int(
response_json.get("status_code", 500)
)
if "candidates" in response_json and len(response_json["candidates"]) > 0:
first_candidate = response_json["candidates"][0]
if (
"content" in first_candidate
and "parts" in first_candidate["content"]
and len(first_candidate["content"]["parts"]) > 0
):
# 返回第一个part的文本内容
return first_candidate["content"]["parts"][0]["text"]
# 如果响应中包含拦截信息
if (
"promptFeedback" in response_json
and "blockReason" in response_json["promptFeedback"]
):
reason = response_json["promptFeedback"]["blockReason"]
logger.warning(f"Gemini API 阻止了响应,原因: {reason}")
return (
f"抱歉,你的请求被模型策略阻止了。原因: {reason}",
400,
) # 400 Bad Request for policy block
# 如果没有candidates且没有拦截信息可能存在未知响应结构
logger.warning(
f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
)
return "抱歉AI未能生成有效回复请稍后再试或检查日志。", 500
def count_tokens(
self,
model_name: str,
contents: list[Dict[str, Any]],
system_instruction: str | None = None,
) -> int:
"""
统计给定内容在特定模型中的token数量。
:param model_name: 要使用的模型名称。
:param contents: 要计算token的内容。
:param system_instruction: 系统指令。
:return: token数量整数或-1如果发生错误
"""
# 根据Gemini API的预期格式构造payload确保包含role字段
if not contents:
if system_instruction:
payload: Dict[str, Any] = {
"contents": [
{"role": "user", "parts": [{"text": system_instruction}]}
]
}
else:
payload: Dict[str, Any] = {
"contents": [{"role": "user", "parts": [{"text": ""}]}]
}
else:
payload: Dict[str, Any] = {"contents": contents} # 明确声明payload类型
if system_instruction:
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
response_json = self._make_request("countTokens", model_name, payload)
if "error" in response_json:
return -1 # 发生错误时返回-1
if "totalTokens" in response_json:
return response_json["totalTokens"]
logger.warning(
f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
)
return -1
def get_models(self) -> Dict[str, Any]:
"""
获取Gemini API可用的模型列表。
:return: 包含模型名称的列表,如果发生错误则为空列表。
"""
logger.info("从 Gemini API 获取模型列表。")
response_json = self._make_get_request("models")
if "error" in response_json:
logger.error(f"获取模型列表失败: {response_json['error']}")
return {
"models": [],
"error": response_json["error"],
"status_code": response_json["status_code"],
}
available_models = []
for m in response_json.get("models", []):
# 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型
# 'SUPPORTED' 状态表示模型可用
if (
"name" in m
and "supportedGenerationMethods" in m
and "generateContent" in m["supportedGenerationMethods"]
and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE"
): # 确保模型是活跃的
model_name = m["name"].split("/")[-1] # 提取 'gemini-pro' 部分
available_models.append(model_name)
logger.info(f"成功获取到 {len(available_models)} 个可用模型。")
return {"models": available_models}