495 lines
23 KiB
Python
495 lines
23 KiB
Python
# gemini_client.py
|
||
|
||
import requests
|
||
import json
|
||
import logging
|
||
from config import get_config # 导入config模块以获取API配置
|
||
from typing import Union, Tuple, Any, Dict # 新增导入
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class GeminiClient:
|
||
"""
|
||
封装与Google Gemini API的交互。
|
||
支持直接HTTP请求,以便于使用第三方代理或自定义端点。
|
||
"""
|
||
|
||
def __init__(self):
|
||
self.config = get_config()
|
||
base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
|
||
# 尝试从配置的URL中提取正确的base_url,确保它以 /v1beta 结尾且不包含 /models
|
||
if base_url_from_config.endswith("/v1beta"):
|
||
self.base_url = base_url_from_config
|
||
elif base_url_from_config.endswith("/models"):
|
||
# 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta
|
||
self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
|
||
elif base_url_from_config.endswith("/v1"): # 兼容旧的 /v1 路径
|
||
self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
|
||
else:
|
||
# 假设是根URL,添加 /v1beta
|
||
self.base_url = f"{base_url_from_config}/v1beta"
|
||
|
||
self.api_key = self.config["API"].get("GEMINI_API_KEY")
|
||
# 确保API Key是字符串并清理可能的空白字符
|
||
if self.api_key:
|
||
self.api_key = str(self.api_key).strip()
|
||
|
||
# 设置请求头,包含Content-Type和x-goog-api-key
|
||
if self.api_key:
|
||
self.headers = {
|
||
"Content-Type": "application/json",
|
||
"x-goog-api-key": self.api_key,
|
||
}
|
||
else:
|
||
self.headers = {"Content-Type": "application/json"}
|
||
|
||
if not self.api_key:
|
||
raise ValueError("GEMINI_API_KEY 未配置或为空,Gemini客户端无法初始化。")
|
||
|
||
logger.info(f"GeminiClient 初始化完成,基础URL: {self.base_url}")
|
||
|
||
def _make_request(
|
||
self, endpoint: str, model_name: str, payload: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
内部方法:执行实际的HTTP POST请求到Gemini API。
|
||
:param endpoint: API端点,例如 "generateContent" 或 "countTokens"。
|
||
:param model_name: 要使用的模型名称,例如 "gemini-pro"。
|
||
:param payload: 请求体(Python字典),将转换为JSON发送。
|
||
:return: 响应JSON(Python字典)。
|
||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||
"""
|
||
# 检查模型名称是否已包含"models/"前缀,避免重复
|
||
if model_name.startswith("models/"):
|
||
# 如果已包含前缀,直接使用
|
||
url = f"{self.base_url}/{model_name}:{endpoint}"
|
||
else:
|
||
# 如果不包含前缀,添加"models/"
|
||
url = f"{self.base_url}/models/{model_name}:{endpoint}"
|
||
|
||
# 记录请求信息,但隐藏API Key
|
||
safe_url = url
|
||
logger.debug(f"发送请求到 Gemini API: {safe_url}")
|
||
logger.debug(
|
||
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||
)
|
||
|
||
try:
|
||
response = requests.post(
|
||
url,
|
||
headers=self.headers, # 使用初始化时设置的headers
|
||
json=payload,
|
||
timeout=90,
|
||
) # 增加超时时间,应对大模型响应慢
|
||
|
||
# 先获取响应文本,避免多次访问响应体
|
||
response_text = response.text
|
||
|
||
# 然后检查状态码
|
||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||
|
||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||
response_data = json.loads(response_text)
|
||
return response_data
|
||
except requests.exceptions.HTTPError as http_err:
|
||
status_code = http_err.response.status_code if http_err.response else 500
|
||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": status_code}
|
||
except requests.exceptions.ConnectionError as conn_err:
|
||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||
except requests.exceptions.Timeout as timeout_err:
|
||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||
except requests.exceptions.RequestException as req_err:
|
||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||
except Exception as e:
|
||
error_message = f"非请求库异常: {e}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500}
|
||
|
||
def _make_stream_request(
|
||
self, endpoint: str, model_name: str, payload: Dict[str, Any]
|
||
) -> Union[Any, Dict[str, Any]]: # 返回类型调整为Dict[str, Any]以统一错误返回
|
||
"""
|
||
内部方法:执行实际的HTTP POST流式请求到Gemini API。
|
||
:param endpoint: API端点,例如 "generateContent"。
|
||
:param model_name: 要使用的模型名称,例如 "gemini-pro"。
|
||
:param payload: 请求体(Python字典),将转换为JSON发送。
|
||
:return: 响应迭代器或错误信息字典。
|
||
"""
|
||
# 检查模型名称是否已包含"models/"前缀,避免重复
|
||
if model_name.startswith("models/"):
|
||
# 如果已包含前缀,直接使用
|
||
url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
|
||
else:
|
||
# 如果不包含前缀,添加"models/"
|
||
url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"
|
||
|
||
# 记录请求信息,但隐藏API Key
|
||
safe_url = url
|
||
logger.debug(f"发送流式请求到 Gemini API: {safe_url}")
|
||
logger.debug(
|
||
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||
)
|
||
|
||
try:
|
||
response = requests.post(
|
||
url,
|
||
headers=self.headers, # 使用初始化时设置的headers
|
||
json=payload,
|
||
stream=True,
|
||
timeout=90,
|
||
)
|
||
|
||
# 检查状态码但不消费响应体
|
||
response.raise_for_status()
|
||
|
||
# 对于流式请求,直接返回解析后的JSON对象迭代器
|
||
# 这不会立即消费响应体,而是在迭代时逐步消费
|
||
def parse_stream():
|
||
logger.info("开始处理流式响应...")
|
||
full_response = ""
|
||
accumulated_json = ""
|
||
|
||
for line in response.iter_lines():
|
||
if not line:
|
||
continue
|
||
|
||
try:
|
||
# 将二进制数据解码为UTF-8字符串
|
||
decoded_line = line.decode("utf-8")
|
||
logger.debug(f"原始行数据: {decoded_line[:50]}...")
|
||
|
||
# 处理SSE格式 (data: 开头)
|
||
if decoded_line.startswith("data: "):
|
||
json_str = decoded_line[6:].strip()
|
||
if not json_str: # 跳过空数据行
|
||
continue
|
||
|
||
try:
|
||
chunk_data = json.loads(json_str)
|
||
logger.debug(
|
||
f"SSE格式JSON数据,键: {list(chunk_data.keys())}"
|
||
)
|
||
|
||
# 检查是否包含文本内容
|
||
if (
|
||
"candidates" in chunk_data
|
||
and chunk_data["candidates"]
|
||
and "content" in chunk_data["candidates"][0]
|
||
):
|
||
content = chunk_data["candidates"][0]["content"]
|
||
|
||
if "parts" in content and content["parts"]:
|
||
text_part = content["parts"][0].get("text", "")
|
||
if text_part:
|
||
logger.debug(
|
||
f"提取到文本: {text_part[:30]}..."
|
||
)
|
||
full_response += text_part
|
||
yield text_part
|
||
except json.JSONDecodeError as e:
|
||
logger.warning(
|
||
f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..."
|
||
)
|
||
else:
|
||
# 处理非SSE格式
|
||
# 尝试积累JSON直到有效
|
||
accumulated_json += decoded_line
|
||
try:
|
||
# 尝试解析完整JSON
|
||
chunk_data = json.loads(accumulated_json)
|
||
logger.debug(
|
||
f"解析完整JSON成功,键: {list(chunk_data.keys())}"
|
||
)
|
||
|
||
# 检查是否包含文本内容
|
||
if (
|
||
"candidates" in chunk_data
|
||
and chunk_data["candidates"]
|
||
and "content" in chunk_data["candidates"][0]
|
||
):
|
||
content = chunk_data["candidates"][0]["content"]
|
||
|
||
if "parts" in content and content["parts"]:
|
||
text_part = content["parts"][0].get("text", "")
|
||
if text_part:
|
||
logger.debug(
|
||
f"从完整JSON提取到文本: {text_part[:30]}..."
|
||
)
|
||
full_response += text_part
|
||
yield text_part
|
||
|
||
# 成功解析后重置累积的JSON
|
||
accumulated_json = ""
|
||
except json.JSONDecodeError:
|
||
# 继续积累直到有效JSON
|
||
pass
|
||
except Exception as e:
|
||
logger.error(
|
||
f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}"
|
||
)
|
||
|
||
# 如果有积累的响应但还没有产生任何输出,尝试直接返回
|
||
if full_response:
|
||
logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应")
|
||
else:
|
||
logger.warning("流式响应完成,但没有提取到任何有效文本内容")
|
||
# 尝试返回一个默认响应,避免空响应
|
||
yield "抱歉,处理响应时遇到了问题,请重试。"
|
||
|
||
return parse_stream()
|
||
except requests.exceptions.HTTPError as http_err:
|
||
status_code = http_err.response.status_code if http_err.response else 500
|
||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": status_code}
|
||
except requests.exceptions.ConnectionError as conn_err:
|
||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 503}
|
||
except requests.exceptions.Timeout as timeout_err:
|
||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 408}
|
||
except requests.exceptions.RequestException as req_err:
|
||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500}
|
||
except Exception as e:
|
||
error_message = f"非请求库异常: {e}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500}
|
||
|
||
def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
|
||
"""
|
||
内部方法:执行实际的HTTP GET请求到Gemini API。
|
||
:param endpoint: API端点,例如 "models"。
|
||
:return: 响应JSON(Python字典)。
|
||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||
"""
|
||
if not self.api_key:
|
||
return {
|
||
"error": "API Key 未配置,无法发送请求。",
|
||
"status_code": 401,
|
||
} # Unauthorized
|
||
|
||
# API Key 通过请求头传递,URL中不再包含
|
||
url = f"{self.base_url}/{endpoint}"
|
||
|
||
# 记录请求信息,但隐藏API Key
|
||
safe_url = url
|
||
logger.debug(f"发送GET请求到 Gemini API: {safe_url}")
|
||
|
||
try:
|
||
response = requests.get(
|
||
url, headers=self.headers, timeout=30 # 使用初始化时设置的headers
|
||
)
|
||
|
||
# 先获取响应文本,避免多次访问响应体
|
||
response_text = response.text
|
||
|
||
# 然后检查状态码
|
||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||
|
||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||
response_data = json.loads(response_text)
|
||
return response_data
|
||
except requests.exceptions.HTTPError as http_err:
|
||
status_code = http_err.response.status_code if http_err.response else 500
|
||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": status_code}
|
||
except requests.exceptions.ConnectionError as conn_err:
|
||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||
except requests.exceptions.Timeout as timeout_err:
|
||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||
except requests.exceptions.RequestException as req_err:
|
||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||
except Exception as e:
|
||
error_message = f"非请求库异常: {e}"
|
||
logger.error(error_message)
|
||
return {"error": error_message, "status_code": 500}
|
||
|
||
def generate_content(
|
||
self,
|
||
model_name: str,
|
||
contents: list[Dict[str, Any]],
|
||
system_instruction: str | None = None,
|
||
stream: bool = False, # 新增 stream 参数
|
||
) -> Union[str, Tuple[str, int], Any]: # 返回类型调整以适应流式响应
|
||
"""
|
||
与Gemini模型进行内容生成对话。
|
||
:param model_name: 要使用的模型名称。
|
||
:param contents: 聊天历史和当前用户消息组成的列表,符合Gemini API的"contents"格式。
|
||
例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...]
|
||
:param system_instruction: 系统的指示语,作为单独的参数传入。
|
||
在Gemini API中,它是请求体的一个顶级字段。
|
||
:param stream: 是否以流式方式获取响应。
|
||
:return: 如果 stream 为 False,返回模型回复文本或错误信息。
|
||
如果 stream 为 True,返回一个生成器,每次 yield 一个文本片段。
|
||
"""
|
||
# 根据Gemini API的预期格式构造payload
|
||
# 如果contents是空的,我们构造一个带有role的内容数组
|
||
if not contents:
|
||
if system_instruction:
|
||
payload: Dict[str, Any] = {
|
||
"contents": [
|
||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||
]
|
||
}
|
||
else:
|
||
payload: Dict[str, Any] = {
|
||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||
}
|
||
else:
|
||
# 使用提供的contents
|
||
payload: Dict[str, Any] = {"contents": contents}
|
||
|
||
# 如果有系统指令,我们添加到请求中
|
||
if system_instruction:
|
||
# Gemini API的systemInstruction是一个顶级字段,值是一个Content对象
|
||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||
|
||
if stream:
|
||
response_iterator = self._make_stream_request(
|
||
"generateContent", model_name, payload
|
||
)
|
||
if isinstance(response_iterator, dict) and "error" in response_iterator:
|
||
# 如果流式请求本身返回错误,则直接返回错误信息
|
||
return str(response_iterator["error"]), int(
|
||
response_iterator.get("status_code", 500)
|
||
)
|
||
|
||
def stream_generator():
|
||
for text_chunk in response_iterator:
|
||
# _make_stream_request 已经处理了JSON解析并提取了文本
|
||
# 所以这里直接yield文本块即可
|
||
if text_chunk:
|
||
yield text_chunk
|
||
logger.debug("流式响应结束。")
|
||
|
||
return stream_generator()
|
||
else:
|
||
response_json = self._make_request("generateContent", model_name, payload)
|
||
|
||
if "error" in response_json:
|
||
# 如果 _make_request 返回了错误,直接返回错误信息和状态码
|
||
return str(response_json["error"]), int(
|
||
response_json.get("status_code", 500)
|
||
)
|
||
|
||
if "candidates" in response_json and len(response_json["candidates"]) > 0:
|
||
first_candidate = response_json["candidates"][0]
|
||
if (
|
||
"content" in first_candidate
|
||
and "parts" in first_candidate["content"]
|
||
and len(first_candidate["content"]["parts"]) > 0
|
||
):
|
||
# 返回第一个part的文本内容
|
||
return first_candidate["content"]["parts"][0]["text"]
|
||
|
||
# 如果响应中包含拦截信息
|
||
if (
|
||
"promptFeedback" in response_json
|
||
and "blockReason" in response_json["promptFeedback"]
|
||
):
|
||
reason = response_json["promptFeedback"]["blockReason"]
|
||
logger.warning(f"Gemini API 阻止了响应,原因: {reason}")
|
||
return (
|
||
f"抱歉,你的请求被模型策略阻止了。原因: {reason}",
|
||
400,
|
||
) # 400 Bad Request for policy block
|
||
|
||
# 如果没有candidates且没有拦截信息,可能存在未知响应结构
|
||
logger.warning(
|
||
f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||
)
|
||
return "抱歉,AI未能生成有效回复,请稍后再试或检查日志。", 500
|
||
|
||
def count_tokens(
|
||
self,
|
||
model_name: str,
|
||
contents: list[Dict[str, Any]],
|
||
system_instruction: str | None = None,
|
||
) -> int:
|
||
"""
|
||
统计给定内容在特定模型中的token数量。
|
||
:param model_name: 要使用的模型名称。
|
||
:param contents: 要计算token的内容。
|
||
:param system_instruction: 系统指令。
|
||
:return: token数量(整数)或-1(如果发生错误)。
|
||
"""
|
||
# 根据Gemini API的预期格式构造payload,确保包含role字段
|
||
if not contents:
|
||
if system_instruction:
|
||
payload: Dict[str, Any] = {
|
||
"contents": [
|
||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||
]
|
||
}
|
||
else:
|
||
payload: Dict[str, Any] = {
|
||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||
}
|
||
else:
|
||
payload: Dict[str, Any] = {"contents": contents} # 明确声明payload类型
|
||
if system_instruction:
|
||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||
|
||
response_json = self._make_request("countTokens", model_name, payload)
|
||
|
||
if "error" in response_json:
|
||
return -1 # 发生错误时返回-1
|
||
|
||
if "totalTokens" in response_json:
|
||
return response_json["totalTokens"]
|
||
logger.warning(
|
||
f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||
)
|
||
return -1
|
||
|
||
def get_models(self) -> Dict[str, Any]:
|
||
"""
|
||
获取Gemini API可用的模型列表。
|
||
:return: 包含模型名称的列表,如果发生错误则为空列表。
|
||
"""
|
||
logger.info("从 Gemini API 获取模型列表。")
|
||
response_json = self._make_get_request("models")
|
||
|
||
if "error" in response_json:
|
||
logger.error(f"获取模型列表失败: {response_json['error']}")
|
||
return {
|
||
"models": [],
|
||
"error": response_json["error"],
|
||
"status_code": response_json["status_code"],
|
||
}
|
||
|
||
available_models = []
|
||
for m in response_json.get("models", []):
|
||
# 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型
|
||
# 'SUPPORTED' 状态表示模型可用
|
||
if (
|
||
"name" in m
|
||
and "supportedGenerationMethods" in m
|
||
and "generateContent" in m["supportedGenerationMethods"]
|
||
and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE"
|
||
): # 确保模型是活跃的
|
||
model_name = m["name"].split("/")[-1] # 提取 'gemini-pro' 部分
|
||
available_models.append(model_name)
|
||
logger.info(f"成功获取到 {len(available_models)} 个可用模型。")
|
||
return {"models": available_models}
|