""" 核心推理流水线 加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。 推理流程: 输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出 对算式类型,解码后还会调用 math_eval 计算结果。 """ import io import time from pathlib import Path import numpy as np from PIL import Image from config import ( CAPTCHA_TYPES, IMAGE_SIZE, INFERENCE_CONFIG, NORMAL_CHARS, MATH_CHARS, THREED_CHARS, ) from inference.math_eval import eval_captcha_math def _try_import_ort(): """延迟导入 onnxruntime,给出友好错误提示。""" try: import onnxruntime as ort return ort except ImportError: raise ImportError( "推理需要 onnxruntime,请安装: uv pip install onnxruntime" ) class CaptchaPipeline: """ 核心推理流水线。 加载调度模型和所有专家模型 (ONNX 格式)。 提供统一的 solve(image) 接口。 """ def __init__(self, models_dir: str | None = None): """ 初始化加载所有 ONNX 模型。 Args: models_dir: ONNX 模型目录,默认使用 config 中的路径 """ ort = _try_import_ort() self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"]) self.mean = INFERENCE_CONFIG["normalize_mean"] self.std = INFERENCE_CONFIG["normalize_std"] # 字符集映射 self._chars = { "normal": NORMAL_CHARS, "math": MATH_CHARS, "3d": THREED_CHARS, } # 专家模型名 → ONNX 文件名 self._model_files = { "classifier": "classifier.onnx", "normal": "normal.onnx", "math": "math.onnx", "3d": "threed.onnx", } # 加载所有可用模型 opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 2 self._sessions: dict[str, "ort.InferenceSession"] = {} for name, fname in self._model_files.items(): path = self.models_dir / fname if path.exists(): self._sessions[name] = ort.InferenceSession( str(path), sess_options=opts, providers=["CPUExecutionProvider"], ) loaded = list(self._sessions.keys()) if not loaded: raise FileNotFoundError( f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}" ) # ---------------------------------------------------------- # 公共接口 # ---------------------------------------------------------- def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray: """ 图片预处理: resize, grayscale, normalize, 转 numpy。 Args: image: PIL Image target_size: (H, W) Returns: (1, 1, H, W) float32 ndarray """ h, w = target_size img = image.convert("L").resize((w, h), Image.BILINEAR) arr = np.array(img, dtype=np.float32) / 255.0 arr = (arr - self.mean) / self.std return arr.reshape(1, 1, h, w) def classify(self, image: Image.Image) -> str: """ 调度分类,返回类型名: 'normal' / 'math' / '3d'。 Raises: RuntimeError: 分类器模型未加载 """ if "classifier" not in self._sessions: raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx") inp = self.preprocess(image, IMAGE_SIZE["classifier"]) session = self._sessions["classifier"] input_name = session.get_inputs()[0].name logits = session.run(None, {input_name: inp})[0] # (1, num_types) idx = int(np.argmax(logits, axis=1)[0]) return CAPTCHA_TYPES[idx] def solve( self, image, captcha_type: str | None = None, ) -> dict: """ 完整识别流程。 Args: image: PIL.Image 或文件路径 (str/Path) 或 bytes captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d') Returns: dict: { "type": str, # 验证码类型 "raw": str, # OCR 原始识别结果 "result": str, # 最终答案 (算式型为计算结果) "time_ms": float, # 推理耗时 (毫秒) } """ t0 = time.perf_counter() # 1. 解析输入 img = self._load_image(image) # 2. 分类 if captcha_type is None: captcha_type = self.classify(img) # 3. 路由到专家模型 if captcha_type not in self._sessions: raise RuntimeError( f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型" ) size_key = captcha_type # "normal"/"math"/"3d" inp = self.preprocess(img, IMAGE_SIZE[size_key]) session = self._sessions[captcha_type] input_name = session.get_inputs()[0].name logits = session.run(None, {input_name: inp})[0] # (T, 1, C) # 4. CTC 贪心解码 chars = self._chars[captcha_type] raw_text = self._ctc_greedy_decode(logits, chars) # 5. 后处理 if captcha_type == "math": try: result = eval_captcha_math(raw_text) except ValueError: result = raw_text # 解析失败则返回原始文本 else: result = raw_text elapsed = (time.perf_counter() - t0) * 1000 return { "type": captcha_type, "raw": raw_text, "result": result, "time_ms": round(elapsed, 2), } # ---------------------------------------------------------- # 私有方法 # ---------------------------------------------------------- @staticmethod def _load_image(image) -> Image.Image: """将多种输入类型统一转为 PIL Image。""" if isinstance(image, Image.Image): return image if isinstance(image, (str, Path)): return Image.open(image).convert("RGB") if isinstance(image, bytes): return Image.open(io.BytesIO(image)).convert("RGB") raise TypeError(f"不支持的图片输入类型: {type(image)}") @staticmethod def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str: """ CTC 贪心解码 (numpy 版本)。 Args: logits: (T, B, C) ONNX 输出 chars: 字符集 (不含 blank, blank=index 0) Returns: 解码后的字符串 """ # 取 batch=0 preds = np.argmax(logits[:, 0, :], axis=1) # (T,) decoded = [] prev = -1 for idx in preds: if idx != 0 and idx != prev: decoded.append(chars[idx - 1]) prev = idx return "".join(decoded)