# gemini_client.py

import requests
import json
import logging
from config import get_config  # 导入config模块以获取API配置
from typing import Union, Tuple, Any, Dict  # 新增导入

logger = logging.getLogger(__name__)


class GeminiClient:
    """
    封装与Google Gemini API的交互。
    支持直接HTTP请求,以便于使用第三方代理或自定义端点。
    """

    def __init__(self):
        self.config = get_config()
        base_url_from_config = self.config["API"]["GEMINI_API_BASE_URL"].rstrip("/")
        # 尝试从配置的URL中提取正确的base_url,确保它以 /v1beta 结尾且不包含 /models
        if base_url_from_config.endswith("/v1beta"):
            self.base_url = base_url_from_config
        elif base_url_from_config.endswith("/models"):
            # 如果配置的URL以 /models 结尾,则移除 /models 并添加 /v1beta
            self.base_url = base_url_from_config.rsplit("/models", 1)[0] + "/v1beta"
        elif base_url_from_config.endswith("/v1"):  # 兼容旧的 /v1 路径
            self.base_url = base_url_from_config.rsplit("/v1", 1)[0] + "/v1beta"
        else:
            # 假设是根URL,添加 /v1beta
            self.base_url = f"{base_url_from_config}/v1beta"

        self.api_key = self.config["API"].get("GEMINI_API_KEY")
        # 确保API Key是字符串并清理可能的空白字符
        if self.api_key:
            self.api_key = str(self.api_key).strip()

        # 设置请求头,包含Content-Type和x-goog-api-key
        if self.api_key:
            self.headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": self.api_key,
            }
        else:
            self.headers = {"Content-Type": "application/json"}

        if not self.api_key:
            raise ValueError("GEMINI_API_KEY 未配置或为空,Gemini客户端无法初始化。")

        logger.info(f"GeminiClient 初始化完成,基础URL: {self.base_url}")

    def _make_request(
        self, endpoint: str, model_name: str, payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        内部方法:执行实际的HTTP POST请求到Gemini API。
        :param endpoint: API端点,例如 "generateContent" 或 "countTokens"。
        :param model_name: 要使用的模型名称,例如 "gemini-pro"。
        :param payload: 请求体(Python字典),将转换为JSON发送。
        :return: 响应JSON(Python字典)。
        :raises requests.exceptions.RequestException: 如果请求失败。
        """
        # 检查模型名称是否已包含"models/"前缀,避免重复
        if model_name.startswith("models/"):
            # 如果已包含前缀,直接使用
            url = f"{self.base_url}/{model_name}:{endpoint}"
        else:
            # 如果不包含前缀,添加"models/"
            url = f"{self.base_url}/models/{model_name}:{endpoint}"

        # 记录请求信息,但隐藏API Key
        safe_url = url
        logger.debug(f"发送请求到 Gemini API: {safe_url}")
        logger.debug(
            f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
        )

        try:
            response = requests.post(
                url,
                headers=self.headers,  # 使用初始化时设置的headers
                json=payload,
                timeout=90,
            )  # 增加超时时间,应对大模型响应慢

            # 先获取响应文本,避免多次访问响应体
            response_text = response.text

            # 然后检查状态码
            response.raise_for_status()  # 对 4xx 或 5xx 状态码抛出 HTTPError

            # 使用已获取的文本解析JSON,而不是调用response.json()
            response_data = json.loads(response_text)
            return response_data
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"连接错误: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}  # Service Unavailable
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"请求超时: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}  # Request Timeout
        except requests.exceptions.RequestException as req_err:
            error_message = f"未知请求错误: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}  # Internal Server Error
        except Exception as e:
            error_message = f"非请求库异常: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def _make_stream_request(
        self, endpoint: str, model_name: str, payload: Dict[str, Any]
    ) -> Union[Any, Dict[str, Any]]:  # 返回类型调整为Dict[str, Any]以统一错误返回
        """
        内部方法:执行实际的HTTP POST流式请求到Gemini API。
        :param endpoint: API端点,例如 "generateContent"。
        :param model_name: 要使用的模型名称,例如 "gemini-pro"。
        :param payload: 请求体(Python字典),将转换为JSON发送。
        :return: 响应迭代器或错误信息字典。
        """
        # 检查模型名称是否已包含"models/"前缀,避免重复
        if model_name.startswith("models/"):
            # 如果已包含前缀,直接使用
            url = f"{self.base_url}/{model_name}:{endpoint}?stream=true"
        else:
            # 如果不包含前缀,添加"models/"
            url = f"{self.base_url}/models/{model_name}:{endpoint}?stream=true"

        # 记录请求信息,但隐藏API Key
        safe_url = url
        logger.debug(f"发送流式请求到 Gemini API: {safe_url}")
        logger.debug(
            f"请求Payload: {json.dumps(payload, ensure_ascii=False, indent=2)}"
        )

        try:
            response = requests.post(
                url,
                headers=self.headers,  # 使用初始化时设置的headers
                json=payload,
                stream=True,
                timeout=90,
            )

            # 检查状态码但不消费响应体
            response.raise_for_status()

            # 对于流式请求,直接返回解析后的JSON对象迭代器
            # 这不会立即消费响应体,而是在迭代时逐步消费
            def parse_stream():
                logger.info("开始处理流式响应...")
                full_response = ""
                accumulated_json = ""

                for line in response.iter_lines():
                    if not line:
                        continue

                    try:
                        # 将二进制数据解码为UTF-8字符串
                        decoded_line = line.decode("utf-8")
                        logger.debug(f"原始行数据: {decoded_line[:50]}...")

                        # 处理SSE格式 (data: 开头)
                        if decoded_line.startswith("data: "):
                            json_str = decoded_line[6:].strip()
                            if not json_str:  # 跳过空数据行
                                continue

                            try:
                                chunk_data = json.loads(json_str)
                                logger.debug(
                                    f"SSE格式JSON数据,键: {list(chunk_data.keys())}"
                                )

                                # 检查是否包含文本内容
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]

                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"提取到文本: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part
                            except json.JSONDecodeError as e:
                                logger.warning(
                                    f"SSE格式数据解析JSON失败: {e}, 数据: {json_str[:50]}..."
                                )
                        else:
                            # 处理非SSE格式
                            # 尝试积累JSON直到有效
                            accumulated_json += decoded_line
                            try:
                                # 尝试解析完整JSON
                                chunk_data = json.loads(accumulated_json)
                                logger.debug(
                                    f"解析完整JSON成功,键: {list(chunk_data.keys())}"
                                )

                                # 检查是否包含文本内容
                                if (
                                    "candidates" in chunk_data
                                    and chunk_data["candidates"]
                                    and "content" in chunk_data["candidates"][0]
                                ):
                                    content = chunk_data["candidates"][0]["content"]

                                    if "parts" in content and content["parts"]:
                                        text_part = content["parts"][0].get("text", "")
                                        if text_part:
                                            logger.debug(
                                                f"从完整JSON提取到文本: {text_part[:30]}..."
                                            )
                                            full_response += text_part
                                            yield text_part

                                # 成功解析后重置累积的JSON
                                accumulated_json = ""
                            except json.JSONDecodeError:
                                # 继续积累直到有效JSON
                                pass
                    except Exception as e:
                        logger.error(
                            f"处理流式响应行时出错: {str(e)}, line: {str(line)[:100]}"
                        )

                # 如果有积累的响应但还没有产生任何输出,尝试直接返回
                if full_response:
                    logger.info(f"流式响应完成,累积了 {len(full_response)} 字符的响应")
                else:
                    logger.warning("流式响应完成,但没有提取到任何有效文本内容")
                    # 尝试返回一个默认响应,避免空响应
                    yield "抱歉,处理响应时遇到了问题,请重试。"

            return parse_stream()
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"连接错误: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"请求超时: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}
        except requests.exceptions.RequestException as req_err:
            error_message = f"未知请求错误: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}
        except Exception as e:
            error_message = f"非请求库异常: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def _make_get_request(self, endpoint: str) -> Dict[str, Any]:
        """
        内部方法:执行实际的HTTP GET请求到Gemini API。
        :param endpoint: API端点,例如 "models"。
        :return: 响应JSON(Python字典)。
        :raises requests.exceptions.RequestException: 如果请求失败。
        """
        if not self.api_key:
            return {
                "error": "API Key 未配置,无法发送请求。",
                "status_code": 401,
            }  # Unauthorized

        # API Key 通过请求头传递,URL中不再包含
        url = f"{self.base_url}/{endpoint}"

        # 记录请求信息,但隐藏API Key
        safe_url = url
        logger.debug(f"发送GET请求到 Gemini API: {safe_url}")

        try:
            response = requests.get(
                url, headers=self.headers, timeout=30  # 使用初始化时设置的headers
            )

            # 先获取响应文本,避免多次访问响应体
            response_text = response.text

            # 然后检查状态码
            response.raise_for_status()  # 对 4xx 或 5xx 状态码抛出 HTTPError

            # 使用已获取的文本解析JSON,而不是调用response.json()
            response_data = json.loads(response_text)
            return response_data
        except requests.exceptions.HTTPError as http_err:
            status_code = http_err.response.status_code if http_err.response else 500
            error_message = f"HTTP错误: {http_err} - 响应文本: {http_err.response.text if http_err.response else '无响应文本'}"
            logger.error(error_message)
            return {"error": error_message, "status_code": status_code}
        except requests.exceptions.ConnectionError as conn_err:
            error_message = f"连接错误: {conn_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 503}  # Service Unavailable
        except requests.exceptions.Timeout as timeout_err:
            error_message = f"请求超时: {timeout_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 408}  # Request Timeout
        except requests.exceptions.RequestException as req_err:
            error_message = f"未知请求错误: {req_err} - URL: {url}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}  # Internal Server Error
        except Exception as e:
            error_message = f"非请求库异常: {e}"
            logger.error(error_message)
            return {"error": error_message, "status_code": 500}

    def generate_content(
        self,
        model_name: str,
        contents: list[Dict[str, Any]],
        system_instruction: str | None = None,
        stream: bool = False,  # 新增 stream 参数
    ) -> Union[str, Tuple[str, int], Any]:  # 返回类型调整以适应流式响应
        """
        与Gemini模型进行内容生成对话。
        :param model_name: 要使用的模型名称。
        :param contents: 聊天历史和当前用户消息组成的列表,符合Gemini API的"contents"格式。
                         例如:[{"role": "user", "parts": [{"text": "你好"}]}, ...]
        :param system_instruction: 系统的指示语,作为单独的参数传入。
                                   在Gemini API中,它是请求体的一个顶级字段。
        :param stream: 是否以流式方式获取响应。
        :return: 如果 stream 为 False,返回模型回复文本或错误信息。
                 如果 stream 为 True,返回一个生成器,每次 yield 一个文本片段。
        """
        # 根据Gemini API的预期格式构造payload
        # 如果contents是空的,我们构造一个带有role的内容数组
        if not contents:
            if system_instruction:
                payload: Dict[str, Any] = {
                    "contents": [
                        {"role": "user", "parts": [{"text": system_instruction}]}
                    ]
                }
            else:
                payload: Dict[str, Any] = {
                    "contents": [{"role": "user", "parts": [{"text": ""}]}]
                }
        else:
            # 使用提供的contents
            payload: Dict[str, Any] = {"contents": contents}

            # 如果有系统指令,我们添加到请求中
            if system_instruction:
                # Gemini API的systemInstruction是一个顶级字段,值是一个Content对象
                payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        if stream:
            response_iterator = self._make_stream_request(
                "generateContent", model_name, payload
            )
            if isinstance(response_iterator, dict) and "error" in response_iterator:
                # 如果流式请求本身返回错误,则直接返回错误信息
                return str(response_iterator["error"]), int(
                    response_iterator.get("status_code", 500)
                )

            def stream_generator():
                for text_chunk in response_iterator:
                    # _make_stream_request 已经处理了JSON解析并提取了文本
                    # 所以这里直接yield文本块即可
                    if text_chunk:
                        yield text_chunk
                logger.debug("流式响应结束。")

            return stream_generator()
        else:
            response_json = self._make_request("generateContent", model_name, payload)

            if "error" in response_json:
                # 如果 _make_request 返回了错误,直接返回错误信息和状态码
                return str(response_json["error"]), int(
                    response_json.get("status_code", 500)
                )

            if "candidates" in response_json and len(response_json["candidates"]) > 0:
                first_candidate = response_json["candidates"][0]
                if (
                    "content" in first_candidate
                    and "parts" in first_candidate["content"]
                    and len(first_candidate["content"]["parts"]) > 0
                ):
                    # 返回第一个part的文本内容
                    return first_candidate["content"]["parts"][0]["text"]

            # 如果响应中包含拦截信息
            if (
                "promptFeedback" in response_json
                and "blockReason" in response_json["promptFeedback"]
            ):
                reason = response_json["promptFeedback"]["blockReason"]
                logger.warning(f"Gemini API 阻止了响应,原因: {reason}")
                return (
                    f"抱歉,你的请求被模型策略阻止了。原因: {reason}",
                    400,
                )  # 400 Bad Request for policy block

            # 如果没有candidates且没有拦截信息,可能存在未知响应结构
            logger.warning(
                f"Gemini API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
            )
            return "抱歉,AI未能生成有效回复,请稍后再试或检查日志。", 500

    def count_tokens(
        self,
        model_name: str,
        contents: list[Dict[str, Any]],
        system_instruction: str | None = None,
    ) -> int:
        """
        统计给定内容在特定模型中的token数量。
        :param model_name: 要使用的模型名称。
        :param contents: 要计算token的内容。
        :param system_instruction: 系统指令。
        :return: token数量(整数)或-1(如果发生错误)。
        """
        # 根据Gemini API的预期格式构造payload,确保包含role字段
        if not contents:
            if system_instruction:
                payload: Dict[str, Any] = {
                    "contents": [
                        {"role": "user", "parts": [{"text": system_instruction}]}
                    ]
                }
            else:
                payload: Dict[str, Any] = {
                    "contents": [{"role": "user", "parts": [{"text": ""}]}]
                }
        else:
            payload: Dict[str, Any] = {"contents": contents}  # 明确声明payload类型
            if system_instruction:
                payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}

        response_json = self._make_request("countTokens", model_name, payload)

        if "error" in response_json:
            return -1  # 发生错误时返回-1

        if "totalTokens" in response_json:
            return response_json["totalTokens"]
        logger.warning(
            f"countTokens API 返回了非预期的响应结构: {json.dumps(response_json, ensure_ascii=False)}"
        )
        return -1

    def get_models(self) -> Dict[str, Any]:
        """
        获取Gemini API可用的模型列表。
        :return: 包含模型名称的列表,如果发生错误则为空列表。
        """
        logger.info("从 Gemini API 获取模型列表。")
        response_json = self._make_get_request("models")

        if "error" in response_json:
            logger.error(f"获取模型列表失败: {response_json['error']}")
            return {
                "models": [],
                "error": response_json["error"],
                "status_code": response_json["status_code"],
            }

        available_models = []
        for m in response_json.get("models", []):
            # 过滤出支持 'generateContent' 方法的模型,并去除被阻止的模型
            # 'SUPPORTED' 状态表示模型可用
            if (
                "name" in m
                and "supportedGenerationMethods" in m
                and "generateContent" in m["supportedGenerationMethods"]
                and m.get("lifecycleState", "UNSPECIFIED") == "ACTIVE"
            ):  # 确保模型是活跃的
                model_name = m["name"].split("/")[-1]  # 提取 'gemini-pro' 部分
                available_models.append(model_name)
        logger.info(f"成功获取到 {len(available_models)} 个可用模型。")
        return {"models": available_models}