"""
Rotation CAPTCHA solver.

Pipeline: ONNX inference -> (sin, cos) -> atan2 -> angle in degrees.
"""

import math
from pathlib import Path

import numpy as np
from PIL import Image

from config import ONNX_DIR, SOLVER_CONFIG
from solvers.base import BaseSolver


class RotateSolver(BaseSolver):
    """Rotation CAPTCHA solver backed by an ONNX model that predicts (sin, cos)."""

    def __init__(self, onnx_path: str | Path | None = None):
        """
        Args:
            onnx_path: Path to the ONNX model. When omitted (or falsy), defaults
                to ``ONNX_DIR / "rotation_regressor.onnx"``.
        """
        self.cfg = SOLVER_CONFIG["rotate"]
        # The inference session is created lazily on the first solve() call.
        self._onnx_session = None
        self._onnx_path = (
            Path(onnx_path) if onnx_path else ONNX_DIR / "rotation_regressor.onnx"
        )

    def _load_onnx(self):
        """Create the ONNX Runtime session on first use; no-op if already loaded.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if self._onnx_session is not None:
            return
        if not self._onnx_path.exists():
            raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
        # Imported here so onnxruntime is only needed once a solve is attempted.
        import onnxruntime as ort

        self._onnx_session = ort.InferenceSession(
            str(self._onnx_path), providers=["CPUExecutionProvider"]
        )

    def _preprocess(self, image: Image.Image) -> np.ndarray:
        """Resize and normalize an RGB image into a contiguous float32 NCHW batch of one."""
        h, w = self.cfg["input_size"]
        resized = image.resize((w, h))  # PIL expects (width, height)
        arr = np.array(resized, dtype=np.float32) / 255.0
        # Per-channel normalization (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
        arr = (arr - 0.5) / 0.5
        # HWC -> CHW -> NCHW. transpose() returns a strided view, so copy into a
        # C-contiguous buffer before handing it to ONNX Runtime.
        return np.ascontiguousarray(arr.transpose(2, 0, 1)[np.newaxis, :, :, :])

    @staticmethod
    def _vector_to_result(sin_val: float, cos_val: float) -> dict:
        """Convert a predicted (sin, cos) pair into the dict returned by solve()."""
        angle_deg = math.degrees(math.atan2(sin_val, cos_val))
        if angle_deg < 0:
            angle_deg += 360.0
        # Rounding can push values just below 360 (e.g. 359.96) up to 360.0.
        # Wrap after rounding so the reported angle stays in [0, 360). For
        # values already in range the modulo is exact. It also turns -0.0
        # into 0.0.
        angle = round(angle_deg, 1) % 360.0

        # Confidence: a vector norm close to 1 suggests a stable prediction.
        # NOTE(review): norms above 1 are clamped to 1.0, so overshoot counts
        # as full confidence; confirm this is intended.
        magnitude = math.hypot(sin_val, cos_val)
        confidence = min(magnitude, 1.0)

        return {
            "angle": angle,
            "confidence": round(confidence, 3),
        }

    def solve(self, image: Image.Image | str | Path, **kwargs) -> dict:
        """
        Solve a rotation CAPTCHA.

        Args:
            image: A PIL image, or a path to an image file. It is converted to
                RGB before preprocessing.
            **kwargs: Accepted but not used by this solver.

        Returns:
            {"angle": float, "confidence": float}. ``angle`` is in degrees within
            [0, 360), rounded to 0.1. ``confidence`` is at most 1.0, rounded to
            3 decimals.

        Raises:
            FileNotFoundError: If the ONNX model file does not exist.
        """
        if isinstance(image, (str, Path)):
            # The context manager closes the file handle. convert() returns a
            # loaded copy that stays valid after the file is closed.
            with Image.open(str(image)) as src:
                image = src.convert("RGB")
        else:
            image = image.convert("RGB")

        self._load_onnx()

        batch = self._preprocess(image)
        outputs = self._onnx_session.run(None, {"input": batch})
        # Assumes the first output is shaped (1, 2) as [[sin, cos]]; TODO confirm
        # against the model export.
        pred = outputs[0][0]
        return self._vector_to_result(float(pred[0]), float(pred[1]))