上传文件至 /
This commit is contained in:
494
gemini_client.py
Normal file
494
gemini_client.py
Normal file
@ -0,0 +1,494 @@
|
||||
# gemini_client.py
|
||||
|
||||
import requests
|
||||
import json
|
||||
import logging
|
||||
from config import get_config # 导入config模块以获取API配置
|
||||
from typing import Union, Tuple, Any, Dict # 新增导入
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiClient:
    """
    Wrapper around the Google Gemini API.

    Talks to the API via direct HTTP requests (rather than an SDK) so that
    third-party proxies or custom endpoints can be used.
    """
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
|
||||
# 尝试从配置的URL中提取正确的base_url,确保它以 /v1beta 结尾且不包含 /models
|
||||
if base_url_from_config.endswith("/v1beta"):
|
||||
self.base_url = base_url_from_config
|
||||
elif base_url_from_config.endswith("/models"):
|
||||
# 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta
|
||||
self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
|
||||
elif base_url_from_config.endswith("/v1"): # 兼容旧的 /v1 路径
|
||||
self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
|
||||
else:
|
||||
# 假设是根URL,添加 /v1beta
|
||||
self.base_url = f"{base_url_from_config}/v1beta"
|
||||
|
||||
self.api_key = self.config["API"].get("GEMINI_API_KEY")
|
||||
# 确保API Key是字符串并清理可能的空白字符
|
||||
if self.api_key:
|
||||
self.api_key = str(self.api_key).strip()
|
||||
|
||||
# 设置请求头,包含Content-Type和x-goog-api-key
|
||||
if self.api_key:
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": self.api_key,
|
||||
}
|
||||
else:
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("GEMINI_API_KEY 未配置或为空,Gemini客户端无法初始化。")
|
||||
|
||||
logger.info(f"GeminiClient 初始化完成,基础URL: {self.base_url}")
|
||||
|
||||
def _make_request(
|
||||
self, endpoint: str, model_name: str, payload: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
内部方法:执行实际的HTTP POST请求到Gemini API。
|
||||
:param endpoint: API端点,例如 "generateContent" 或 "countTokens"。
|
||||
:param model_name: 要使用的模型名称,例如 "gemini-pro"。
|
||||
:param payload: 请求体(Python字典),将转换为JSON发送。
|
||||
:return: 响应JSON(Python字典)。
|
||||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||||
"""
|
||||
# 检查模型名称是否已包含"models/"前缀,避免重复
|
||||
if model_name.startswith("models/"):
|
||||
# 如果已包含前缀,直接使用
|
||||
url = f"{self.base_url}/{model_name}:{endpoint}"
|
||||
else:
|
||||
# 如果不包含前缀,添加"models/"
|
||||
url = f"{self.base_url}/models/{model_name}:{endpoint}"
|
||||
|
||||
# 记录请求信息,但隐藏API Key
|
||||
safe_url = url
|
||||
logger.debug(f"发送请求到 Gemini API: {safe_url}")
|
||||
logger.debug(
|
||||
f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers, # 使用初始化时设置的headers
|
||||
json=payload,
|
||||
timeout=90,
|
||||
) # 增加超时时间,应对大模型响应慢
|
||||
|
||||
# 先获取响应文本,避免多次访问响应体
|
||||
response_text = response.text
|
||||
|
||||
# 然后检查状态码
|
||||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||||
|
||||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||||
response_data = json.loads(response_text)
|
||||
return response_data
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else 500
|
||||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": status_code}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||||
except Exception as e:
|
||||
error_message = f"非请求库异常: {e}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
|
||||
    def _make_stream_request(
        self, endpoint: str, model_name: str, payload: Dict[str, Any]
    ) -> Union[Any, Dict[str, Any]]:  # generator on success, error dict on failure
        """
        Perform a streaming HTTP POST to the Gemini API.

        :param endpoint: API endpoint, e.g. "generateContent".
        :param model_name: Model name, e.g. "gemini-pro".
        :param payload: Request body (dict), sent as JSON.
        :return: A generator yielding text fragments on success, or an
                 ``{"error": ..., "status_code": ...}`` dict on failure.
                 This method never raises; request errors become error dicts.
        """
        # Avoid doubling the "models/" prefix when the caller already supplied it.
        # NOTE(review): the official Gemini REST API streams via
        # ":streamGenerateContent?alt=sse", not "?stream=true" on generateContent.
        # This query parameter may be specific to the configured proxy — confirm.
        if model_name.startswith("models/"):
            url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
        else:
            url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"

        # The API key travels in the headers, so the URL itself is safe to log.
        safe_url = url
        logger.debug(f"发送流式请求到 Gemini API: {safe_url}")
        logger.debug(
            f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
        )

        try:
            response = requests.post(
                url,
                headers=self.headers,  # headers prepared in __init__
                json=payload,
                stream=True,  # keep the body unread; iterate it lazily below
                timeout=90,
            )

            # Raises HTTPError for 4xx/5xx without consuming the streamed body.
            response.raise_for_status()

            # Generator that consumes the response incrementally, extracting the
            # text of each JSON chunk. Supports two framings: SSE ("data: ..."
            # lines) and a raw JSON body possibly split across several lines.
            def parse_stream():
                logger.info("开始处理流式响应...")
                full_response = ""     # everything yielded so far (for logging/fallback)
                accumulated_json = ""  # buffer for non-SSE JSON split across lines

                for line in response.iter_lines():
                    if not line:
                        continue

                    try:
                        # Lines arrive as bytes; decode to UTF-8 text.
                        decoded_line = line.decode("utf-8")
                        logger.debug(f"原始行数据: {decoded_line[:50]}...")

                        # SSE framing: payload lines are prefixed with "data: ".
                        if decoded_line.startswith("data: "):
                            json_str = decoded_line[6:].strip()
                            if not json_str:  # skip empty data lines
                                continue

                            try:
                                chunk_data = json.loads(json_str)
                                logger.debug(
                                    f"SSE格式JSON数据,键: {list(chunk_data.keys())}"
                                )

                                # Yield only chunks that actually carry text.
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]

                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"提取到文本: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part
                            except json.JSONDecodeError as e:
                                # Malformed SSE chunk: log and keep streaming.
                                logger.warning(
                                    f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..."
                                )
                        else:
                            # Non-SSE framing: accumulate lines until the buffer
                            # parses as a complete JSON document.
                            accumulated_json += decoded_line
                            try:
                                chunk_data = json.loads(accumulated_json)
                                logger.debug(
                                    f"解析完整JSON成功,键: {list(chunk_data.keys())}"
                                )

                                # Same text extraction as the SSE branch.
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]

                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"从完整JSON提取到文本: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part

                                # Parsed successfully: reset the buffer.
                                accumulated_json = ""
                            except json.JSONDecodeError:
                                # Incomplete JSON — keep accumulating.
                                pass
                    except Exception as e:
                        # A bad line must not kill the whole stream.
                        logger.error(
                            f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}"
                        )

                # Stream exhausted: log the outcome.
                if full_response:
                    logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应")
                else:
                    logger.warning("流式响应完成,但没有提取到任何有效文本内容")
                    # Fallback so callers never receive a completely empty stream.
                    yield "抱歉,处理响应时遇到了问题,请重试。"

            return parse_stream()
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"连接错误: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"请求超时: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}
        except requests.exceptions.RequestException as req_err:
            error_message = f"未知请求错误: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}
        except Exception as e:
            error_message = f"非请求库异常: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}
|
||||
|
||||
def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
|
||||
"""
|
||||
内部方法:执行实际的HTTP GET请求到Gemini API。
|
||||
:param endpoint: API端点,例如 "models"。
|
||||
:return: 响应JSON(Python字典)。
|
||||
:raises requests.exceptions.RequestException: 如果请求失败。
|
||||
"""
|
||||
if not self.api_key:
|
||||
return {
|
||||
"error": "API Key 未配置,无法发送请求。",
|
||||
"status_code": 401,
|
||||
} # Unauthorized
|
||||
|
||||
# API Key 通过请求头传递,URL中不再包含
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
|
||||
# 记录请求信息,但隐藏API Key
|
||||
safe_url = url
|
||||
logger.debug(f"发送GET请求到 Gemini API: {safe_url}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url, headers=self.headers, timeout=30 # 使用初始化时设置的headers
|
||||
)
|
||||
|
||||
# 先获取响应文本,避免多次访问响应体
|
||||
response_text = response.text
|
||||
|
||||
# 然后检查状态码
|
||||
response.raise_for_status() # 对 4xx 或 5xx 状态码抛出 HTTPError
|
||||
|
||||
# 使用已获取的文本解析JSON,而不是调用response.json()
|
||||
response_data = json.loads(response_text)
|
||||
return response_data
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else 500
|
||||
error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": status_code}
|
||||
except requests.exceptions.ConnectionError as conn_err:
|
||||
error_message = f"连接错误: {conn_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 503} # Service Unavailable
|
||||
except requests.exceptions.Timeout as timeout_err:
|
||||
error_message = f"请求超时: {timeout_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 408} # Request Timeout
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
error_message = f"未知请求错误: {req_err} - URL: {url}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500} # Internal Server Error
|
||||
except Exception as e:
|
||||
error_message = f"非请求库异常: {e}"
|
||||
logger.error(error_message)
|
||||
return {"error": error_message, "status_code": 500}
|
||||
|
||||
def generate_content(
|
||||
self,
|
||||
model_name: str,
|
||||
contents: list[Dict[str, Any]],
|
||||
system_instruction: str | None = None,
|
||||
stream: bool = False, # 新增 stream 参数
|
||||
) -> Union[str, Tuple[str, int], Any]: # 返回类型调整以适应流式响应
|
||||
"""
|
||||
与Gemini模型进行内容生成对话。
|
||||
:param model_name: 要使用的模型名称。
|
||||
:param contents: 聊天历史和当前用户消息组成的列表,符合Gemini API的"contents"格式。
|
||||
例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...]
|
||||
:param system_instruction: 系统的指示语,作为单独的参数传入。
|
||||
在Gemini API中,它是请求体的一个顶级字段。
|
||||
:param stream: 是否以流式方式获取响应。
|
||||
:return: 如果 stream 为 False,返回模型回复文本或错误信息。
|
||||
如果 stream 为 True,返回一个生成器,每次 yield 一个文本片段。
|
||||
"""
|
||||
# 根据Gemini API的预期格式构造payload
|
||||
# 如果contents是空的,我们构造一个带有role的内容数组
|
||||
if not contents:
|
||||
if system_instruction:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||||
]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||||
}
|
||||
else:
|
||||
# 使用提供的contents
|
||||
payload: Dict[str, Any] = {"contents": contents}
|
||||
|
||||
# 如果有系统指令,我们添加到请求中
|
||||
if system_instruction:
|
||||
# Gemini API的systemInstruction是一个顶级字段,值是一个Content对象
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||
|
||||
if stream:
|
||||
response_iterator = self._make_stream_request(
|
||||
"generateContent", model_name, payload
|
||||
)
|
||||
if isinstance(response_iterator, dict) and "error" in response_iterator:
|
||||
# 如果流式请求本身返回错误,则直接返回错误信息
|
||||
return str(response_iterator["error"]), int(
|
||||
response_iterator.get("status_code", 500)
|
||||
)
|
||||
|
||||
def stream_generator():
|
||||
for text_chunk in response_iterator:
|
||||
# _make_stream_request 已经处理了JSON解析并提取了文本
|
||||
# 所以这里直接yield文本块即可
|
||||
if text_chunk:
|
||||
yield text_chunk
|
||||
logger.debug("流式响应结束。")
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
response_json = self._make_request("generateContent", model_name, payload)
|
||||
|
||||
if "error" in response_json:
|
||||
# 如果 _make_request 返回了错误,直接返回错误信息和状态码
|
||||
return str(response_json["error"]), int(
|
||||
response_json.get("status_code", 500)
|
||||
)
|
||||
|
||||
if "candidates" in response_json and len(response_json["candidates"]) > 0:
|
||||
first_candidate = response_json["candidates"][0]
|
||||
if (
|
||||
"content" in first_candidate
|
||||
and "parts" in first_candidate["content"]
|
||||
and len(first_candidate["content"]["parts"]) > 0
|
||||
):
|
||||
# 返回第一个part的文本内容
|
||||
return first_candidate["content"]["parts"][0]["text"]
|
||||
|
||||
# 如果响应中包含拦截信息
|
||||
if (
|
||||
"promptFeedback" in response_json
|
||||
and "blockReason" in response_json["promptFeedback"]
|
||||
):
|
||||
reason = response_json["promptFeedback"]["blockReason"]
|
||||
logger.warning(f"Gemini API 阻止了响应,原因: {reason}")
|
||||
return (
|
||||
f"抱歉,你的请求被模型策略阻止了。原因: {reason}",
|
||||
400,
|
||||
) # 400 Bad Request for policy block
|
||||
|
||||
# 如果没有candidates且没有拦截信息,可能存在未知响应结构
|
||||
logger.warning(
|
||||
f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||||
)
|
||||
return "抱歉,AI未能生成有效回复,请稍后再试或检查日志。", 500
|
||||
|
||||
def count_tokens(
|
||||
self,
|
||||
model_name: str,
|
||||
contents: list[Dict[str, Any]],
|
||||
system_instruction: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计给定内容在特定模型中的token数量。
|
||||
:param model_name: 要使用的模型名称。
|
||||
:param contents: 要计算token的内容。
|
||||
:param system_instruction: 系统指令。
|
||||
:return: token数量(整数)或-1(如果发生错误)。
|
||||
"""
|
||||
# 根据Gemini API的预期格式构造payload,确保包含role字段
|
||||
if not contents:
|
||||
if system_instruction:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [
|
||||
{"role": "user", "parts": [{"text": system_instruction}]}
|
||||
]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {
|
||||
"contents": [{"role": "user", "parts": [{"text": ""}]}]
|
||||
}
|
||||
else:
|
||||
payload: Dict[str, Any] = {"contents": contents} # 明确声明payload类型
|
||||
if system_instruction:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
|
||||
|
||||
response_json = self._make_request("countTokens", model_name, payload)
|
||||
|
||||
if "error" in response_json:
|
||||
return -1 # 发生错误时返回-1
|
||||
|
||||
if "totalTokens" in response_json:
|
||||
return response_json["totalTokens"]
|
||||
logger.warning(
|
||||
f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
|
||||
)
|
||||
return -1
|
||||
|
||||
def get_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取Gemini API可用的模型列表。
|
||||
:return: 包含模型名称的列表,如果发生错误则为空列表。
|
||||
"""
|
||||
logger.info("从 Gemini API 获取模型列表。")
|
||||
response_json = self._make_get_request("models")
|
||||
|
||||
if "error" in response_json:
|
||||
logger.error(f"获取模型列表失败: {response_json['error']}")
|
||||
return {
|
||||
"models": [],
|
||||
"error": response_json["error"],
|
||||
"status_code": response_json["status_code"],
|
||||
}
|
||||
|
||||
available_models = []
|
||||
for m in response_json.get("models", []):
|
||||
# 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型
|
||||
# 'SUPPORTED' 状态表示模型可用
|
||||
if (
|
||||
"name" in m
|
||||
and "supportedGenerationMethods" in m
|
||||
and "generateContent" in m["supportedGenerationMethods"]
|
||||
and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE"
|
||||
): # 确保模型是活跃的
|
||||
model_name = m["name"].split("/")[-1] # 提取 'gemini-pro' 部分
|
||||
available_models.append(model_name)
|
||||
logger.info(f"成功获取到 {len(available_models)} 个可用模型。")
|
||||
return {"models": available_models}
|
Reference in New Issue
Block a user