# gemini_client.py
import requests
import json
import logging
from config import get_config  # Import the config module to read API settings
from typing import Union, Tuple, Any, Dict

logger = logging.getLogger(__name__)


class GeminiClient:
    """
    Wraps interaction with the Google Gemini API.
    Uses direct HTTP requests so third-party proxies or custom endpoints can be used.
    """

    def __init__(self):
        self.config = get_config()
        base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
        # Normalize the configured URL so the base URL ends with /v1beta and does not contain /models.
        if base_url_from_config.endswith("/v1beta"):
            self.base_url = base_url_from_config
        elif base_url_from_config.endswith("/models"):
            # If the configured URL ends with /models, strip it and append /v1beta.
            self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
        elif base_url_from_config.endswith("/v1"):  # Backwards compatibility with the old /v1 path
            self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
        else:
            # Assume a bare root URL and append /v1beta.
            self.base_url = f"{base_url_from_config}/v1beta"
        self.api_key = self.config["API"].get("GEMINI_API_KEY")
        # Make sure the API key is a string and strip any surrounding whitespace.
        if self.api_key:
            self.api_key = str(self.api_key).strip()
        # Build request headers: Content-Type, plus x-goog-api-key when a key is available.
        if self.api_key:
            self.headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": self.api_key,
            }
        else:
            self.headers = {"Content-Type": "application/json"}
        if not self.api_key:
            raise ValueError(
                "GEMINI_API_KEY is not configured or empty; the Gemini client cannot be initialized."
            )
        logger.info(f"GeminiClient initialized, base URL: {self.base_url}")

    def _make_request(
        self, endpoint: str, model_name: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Internal helper that performs the actual HTTP POST request to the Gemini API.
        :param endpoint: API endpoint, e.g. "generateContent" or "countTokens".
        :param model_name: Name of the model to use, e.g. "gemini-pro".
        :param payload: Request body as a Python dict; it is sent as JSON.
        :return: Response JSON as a Python dict, or a dict with "error" and
                 "status_code" keys if the request failed (no exception is raised).
        """
        # Avoid duplicating the "models/" prefix if the model name already contains it.
        if model_name.startswith("models/"):
            # Prefix already present; use the name as-is.
            url = f"{self.base_url}/{model_name}:{endpoint}"
        else:
            # Prefix missing; prepend "models/".
            url = f"{self.base_url}/models/{model_name}:{endpoint}"
        # Log the request. The API key is sent via headers, so the URL is safe to log.
        safe_url = url
        logger.debug(f"Sending request to Gemini API: {safe_url}")
        logger.debug(
            f"Request payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
        )
        try:
            response = requests.post(
                url,
                headers=self.headers,  # Use the headers built during initialization.
                json=payload,
                timeout=90,  # Generous timeout for slow large-model responses.
            )
            # Read the response text first to avoid touching the body twice.
            response_text = response.text
            # Then check the status code.
            response.raise_for_status()  # Raise HTTPError for 4xx / 5xx status codes.
            # Parse JSON from the already-read text instead of calling response.json().
            response_data = json.loads(response_text)
            return response_data
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = (
                f"HTTP error: {http_err} - response text: "
                f"{http_err.response.text if http_err.response else 'no response text'}"
            )
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"Connection error: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}  # Service Unavailable
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"Request timeout: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}  # Request Timeout
        except requests.exceptions.RequestException as req_err:
            error_message = f"Unexpected request error: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}  # Internal Server Error
        except Exception as e:
            error_message = f"Non-requests exception: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def _make_stream_request(
        self, endpoint: str, model_name: str, payload: Dict[str, Any]
    ) -> Union[Any, Dict[str, Any]]:  # Returns an iterator, or Dict[str, Any] for errors.
        """
        Internal helper that performs the actual streaming HTTP POST request to the Gemini API.
        :param endpoint: API endpoint, e.g. "generateContent".
        :param model_name: Name of the model to use, e.g. "gemini-pro".
        :param payload: Request body as a Python dict; it is sent as JSON.
        :return: An iterator over response text chunks, or an error dict.
        """
        # Avoid duplicating the "models/" prefix if the model name already contains it.
        if model_name.startswith("models/"):
            # Prefix already present; use the name as-is.
            url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
        else:
            # Prefix missing; prepend "models/".
            url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"
        # Log the request. The API key is sent via headers, so the URL is safe to log.
        safe_url = url
        logger.debug(f"Sending streaming request to Gemini API: {safe_url}")
        logger.debug(
            f"Request payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
        )
        try:
            response = requests.post(
                url,
                headers=self.headers,  # Use the headers built during initialization.
                json=payload,
                stream=True,
                timeout=90,
            )
            # Check the status code without consuming the response body.
            response.raise_for_status()
            # For streaming requests, return an iterator over parsed chunks;
            # the body is consumed lazily as the iterator is advanced.
            def parse_stream():
                logger.info("Starting to process the streaming response...")
                full_response = ""
                accumulated_json = ""
                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        # Decode the raw bytes as UTF-8.
                        decoded_line = line.decode("utf-8")
                        logger.debug(f"Raw line data: {decoded_line[:50]}...")
                        # Handle SSE format (lines starting with "data: ").
                        if decoded_line.startswith("data: "):
                            json_str = decoded_line[6:].strip()
                            if not json_str:  # Skip empty data lines.
                                continue
                            try:
                                chunk_data = json.loads(json_str)
                                logger.debug(
                                    f"SSE JSON data keys: {list(chunk_data.keys())}"
                                )
                                # Check whether the chunk contains text content.
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]
                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"Extracted text: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part
                            except json.JSONDecodeError as e:
                                logger.warning(
                                    f"Failed to parse SSE JSON data: {e}, data: {json_str[:50]}..."
                                )
                        else:
                            # Handle non-SSE format:
                            # accumulate lines until they form a valid JSON document.
                            accumulated_json += decoded_line
                            try:
                                # Try to parse the accumulated JSON.
                                chunk_data = json.loads(accumulated_json)
                                logger.debug(
                                    f"Parsed complete JSON, keys: {list(chunk_data.keys())}"
                                )
                                # Check whether the chunk contains text content.
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]
                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"Extracted text from complete JSON: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part
                                # Reset the accumulator after a successful parse.
                                accumulated_json = ""
                            except json.JSONDecodeError:
                                # Keep accumulating until the JSON becomes valid.
                                pass
                    except Exception as e:
                        logger.error(
                            f"Error while processing a streaming response line: {str(e)}, line: {str(line)[:100]}"
                        )
                # Report completion; if nothing was extracted, yield a fallback reply
                # so the caller does not receive an empty response.
                if full_response:
                    logger.info(f"Streaming response finished; accumulated {len(full_response)} characters.")
                else:
                    logger.warning("Streaming response finished, but no valid text content was extracted.")
                    yield "Sorry, something went wrong while processing the response. Please try again."

            return parse_stream()
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = (
                f"HTTP error: {http_err} - response text: "
                f"{http_err.response.text if http_err.response else 'no response text'}"
            )
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"Connection error: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"Request timeout: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}
        except requests.exceptions.RequestException as req_err:
            error_message = f"Unexpected request error: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}
        except Exception as e:
            error_message = f"Non-requests exception: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
        """
        Internal helper that performs the actual HTTP GET request to the Gemini API.
        :param endpoint: API endpoint, e.g. "models".
        :return: Response JSON as a Python dict, or a dict with "error" and
                 "status_code" keys if the request failed (no exception is raised).
        """
        if not self.api_key:
            return {
                "error": "API key is not configured; the request cannot be sent.",
                "status_code": 401,
            }  # Unauthorized
        # The API key is passed via the request headers; it is no longer part of the URL.
        url = f"{self.base_url}/{endpoint}"
        # Log the request. The URL does not contain the API key, so it is safe to log.
        safe_url = url
        logger.debug(f"Sending GET request to Gemini API: {safe_url}")
        try:
            response = requests.get(
                url, headers=self.headers, timeout=30  # Use the headers built during initialization.
            )
            # Read the response text first to avoid touching the body twice.
            response_text = response.text
            # Then check the status code.
            response.raise_for_status()  # Raise HTTPError for 4xx / 5xx status codes.
            # Parse JSON from the already-read text instead of calling response.json().
            response_data = json.loads(response_text)
            return response_data
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = (
                f"HTTP error: {http_err} - response text: "
                f"{http_err.response.text if http_err.response else 'no response text'}"
            )
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"Connection error: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}  # Service Unavailable
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"Request timeout: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}  # Request Timeout
        except requests.exceptions.RequestException as req_err:
            error_message = f"Unexpected request error: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}  # Internal Server Error
        except Exception as e:
            error_message = f"Non-requests exception: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def generate_content(
        self,
        model_name: str,
        contents: list[Dict[str, Any]],
        system_instruction: str | None = None,
        stream: bool = False,  # Whether to stream the response.
    ) -> Union[str, Tuple[str, int], Any]:  # Return type widened to cover streaming responses.
        """
        Run a content-generation (chat) request against a Gemini model.
        :param model_name: Name of the model to use.
        :param contents: Chat history plus the current user message, in the Gemini API
                         "contents" format, e.g. [{"role": "user", "parts": [{"text": "Hello"}]}, ...]
        :param system_instruction: System instruction, passed as a separate parameter.
                                   In the Gemini API it is a top-level field of the request body.
        :param stream: Whether to fetch the response as a stream.
        :return: If stream is False, the model's reply text, or an (error message, status code) tuple.
                 If stream is True, a generator that yields text chunks.
        """
        # Build the payload in the format the Gemini API expects.
        # If contents is empty, construct a contents array that still carries a role.
        if not contents:
            if system_instruction:
                payload: Dict[str, Any] = {
                    "contents": [
                        {"role": "user", "parts": [{"text": system_instruction}]}
                    ]
                }
            else:
                payload: Dict[str, Any] = {
                    "contents": [{"role": "user", "parts": [{"text": ""}]}]
                }
        else:
            # Use the provided contents.
            payload: Dict[str, Any] = {"contents": contents}
        # If a system instruction was given, add it to the request.
        if system_instruction:
            # systemInstruction is a top-level Gemini API field whose value is a Content object.
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        if stream:
            response_iterator = self._make_stream_request(
                "generateContent", model_name, payload
            )
            if isinstance(response_iterator, dict) and "error" in response_iterator:
                # The streaming request itself failed; return the error message and status code.
                return str(response_iterator["error"]), int(
                    response_iterator.get("status_code", 500)
                )

            def stream_generator():
                for text_chunk in response_iterator:
                    # _make_stream_request already parsed the JSON and extracted the text,
                    # so each chunk can be yielded directly.
                    if text_chunk:
                        yield text_chunk
                logger.debug("Streaming response finished.")

            return stream_generator()
        else:
            response_json = self._make_request("generateContent", model_name, payload)
            if "error" in response_json:
                # _make_request reported an error; return the error message and status code.
                return str(response_json["error"]), int(
                    response_json.get("status_code", 500)
                )
            if "candidates" in response_json and len(response_json["candidates"]) > 0:
                first_candidate = response_json["candidates"][0]
                if (
                    "content" in first_candidate
                    and "parts" in first_candidate["content"]
                    and len(first_candidate["content"]["parts"]) > 0
                ):
                    # Return the text of the first part.
                    return first_candidate["content"]["parts"][0]["text"]
            # The response may contain blocking information instead of candidates.
            if (
                "promptFeedback" in response_json
                and "blockReason" in response_json["promptFeedback"]
            ):
                reason = response_json["promptFeedback"]["blockReason"]
                logger.warning(f"The Gemini API blocked the response, reason: {reason}")
                return (
                    f"Sorry, your request was blocked by the model's policy. Reason: {reason}",
                    400,
                )  # 400 Bad Request for a policy block.
            # No candidates and no blocking information: the response structure is unexpected.
            logger.warning(
                f"The Gemini API returned an unexpected response structure: {json.dumps(response_json, ensure_ascii=False)}"
            )
            return "Sorry, the AI did not generate a valid reply. Please try again later or check the logs.", 500

    def count_tokens(
        self,
        model_name: str,
        contents: list[Dict[str, Any]],
        system_instruction: str | None = None,
    ) -> int:
        """
        Count how many tokens the given content uses for a specific model.
        :param model_name: Name of the model to use.
        :param contents: Content to count tokens for.
        :param system_instruction: System instruction.
        :return: Token count (int), or -1 if an error occurred.
        """
        # Build the payload in the format the Gemini API expects, making sure a role field is present.
        if not contents:
            if system_instruction:
                payload: Dict[str, Any] = {
                    "contents": [
                        {"role": "user", "parts": [{"text": system_instruction}]}
                    ]
                }
            else:
                payload: Dict[str, Any] = {
                    "contents": [{"role": "user", "parts": [{"text": ""}]}]
                }
        else:
            payload: Dict[str, Any] = {"contents": contents}  # Explicitly annotate the payload type.
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        response_json = self._make_request("countTokens", model_name, payload)
        if "error" in response_json:
            return -1  # Return -1 on error.
        if "totalTokens" in response_json:
            return response_json["totalTokens"]
        logger.warning(
            f"The countTokens API returned an unexpected response structure: {json.dumps(response_json, ensure_ascii=False)}"
        )
        return -1
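
    # A successful countTokens response is expected to look roughly like:
    #   {"totalTokens": 42}   # illustrative value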

    def get_models(self) -> Dict[str, Any]:
        """
        Fetch the list of models available from the Gemini API.
        :return: A dict with a "models" list of model names; on error the list is empty
                 and the dict also carries "error" and "status_code" keys.
        """
        logger.info("Fetching the model list from the Gemini API.")
        response_json = self._make_get_request("models")
        if "error" in response_json:
            logger.error(f"Failed to fetch the model list: {response_json['error']}")
            return {
                "models": [],
                "error": response_json["error"],
                "status_code": response_json["status_code"],
            }
        available_models = []
        for m in response_json.get("models", []):
            # Keep only models that support the 'generateContent' method and are not
            # explicitly marked as inactive.
            if (
                "name" in m
                and "supportedGenerationMethods" in m
                and "generateContent" in m["supportedGenerationMethods"]
                # Treat a missing lifecycleState field as active; the public ListModels
                # response does not appear to include this field.
                and m.get("lifecycleState", "ACTIVE") == "ACTIVE"
            ):
                model_name = m["name"].split("/")[-1]  # Extract the 'gemini-pro' part.
                available_models.append(model_name)
        logger.info(f"Successfully fetched {len(available_models)} available models.")
        return {"models": available_models}