""" FunCaptcha 专项 ONNX 推理。 """ from __future__ import annotations import io import os import time from pathlib import Path import numpy as np from PIL import Image from config import FUN_CAPTCHA_TASKS, INFERENCE_CONFIG from inference.model_metadata import load_model_metadata from inference.pipeline import _try_import_ort class FunCaptchaRollballPipeline: """ `4_3d_rollball_animals` 专项推理器。 输入整张 challenge 图片,内部自动裁切 reference / candidates, 再使用 Siamese ONNX 模型逐个候选打分。 """ def __init__(self, question: str = "4_3d_rollball_animals", models_dir: str | None = None): if question not in FUN_CAPTCHA_TASKS: raise ValueError(f"不支持的 FunCaptcha question: {question}") ort = _try_import_ort() self.question = question self.task_cfg = FUN_CAPTCHA_TASKS[question] self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"]) self.model_path = self._resolve_model_path() opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 2 self.session = ort.InferenceSession( str(self.model_path), sess_options=opts, providers=["CPUExecutionProvider"], ) self.metadata = load_model_metadata(self.model_path) or {} self.preprocess_mode = self._resolve_preprocess_mode(self.metadata) self.mean = float(INFERENCE_CONFIG["normalize_mean"]) self.std = float(INFERENCE_CONFIG["normalize_std"]) self.answer_index_base = int( self.metadata.get("answer_index_base", self.task_cfg["answer_index_base"]) ) def solve(self, image) -> dict: t0 = time.perf_counter() challenge = self._load_image(image) candidates, reference = self._split_challenge(challenge) ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0) input_defs = self.session.get_inputs() input_names = [inp.name for inp in input_defs] if len(input_names) != 2: raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}") logits = self._run_model(input_defs, input_names, candidates, ref_batch) scores = 1.0 / (1.0 + np.exp(-logits)) answer_idx = int(np.argmax(logits)) selected = answer_idx + self.answer_index_base elapsed = (time.perf_counter() - t0) * 1000 return { "type": "funcaptcha", "question": self.question, "objects": [selected], "scores": [round(float(score), 6) for score in scores.tolist()], "raw": str(selected), "result": str(selected), "time_ms": round(elapsed, 2), } def _split_challenge(self, image: Image.Image) -> tuple[np.ndarray, np.ndarray]: tile_w, tile_h = self.metadata.get("tile_size", self.task_cfg["tile_size"]) ref_box = tuple(self.metadata.get("reference_box", self.task_cfg["reference_box"])) num_candidates = int(self.metadata.get("num_candidates", self.task_cfg["num_candidates"])) input_h, input_w = self.task_cfg["input_size"] candidates = [] for idx in range(num_candidates): left = idx * tile_w candidate = image.crop((left, 0, left + tile_w, tile_h)) candidates.append(self._preprocess(candidate, (input_h, input_w))) reference = image.crop(ref_box) return ( np.concatenate(candidates, axis=0), self._preprocess(reference, (input_h, input_w)), ) def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray: img_h, img_w = target_size image = image.convert("RGB") if self.preprocess_mode == "rgb_centered": image = image.resize((img_w, img_h), Image.BILINEAR) elif self.preprocess_mode == "rgb_255": # 对齐 funcaptcha-server 现有 ONNX 的预处理行为。 image = image.resize((img_w, img_h)) else: raise ValueError(f"不支持的 FunCaptcha 预处理模式: {self.preprocess_mode}") arr = np.asarray(image, dtype=np.float32) / 255.0 if self.preprocess_mode == "rgb_centered": arr = (arr - self.mean) / self.std arr = np.transpose(arr, (2, 0, 1)) return arr.reshape(1, 3, img_h, img_w) @staticmethod def _load_image(image) -> Image.Image: if isinstance(image, Image.Image): return image if isinstance(image, (str, Path)): return Image.open(image).convert("RGB") if isinstance(image, bytes): return Image.open(io.BytesIO(image)).convert("RGB") raise TypeError(f"不支持的图片输入类型: {type(image)}") def _resolve_model_path(self) -> Path: candidates = [self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"] env_name = self.task_cfg.get("external_model_env") env_value = os.getenv(env_name) if env_name else None if env_value: candidates.append(Path(env_value).expanduser()) for fallback in self.task_cfg.get("fallback_model_paths", []): candidates.append(Path(fallback).expanduser()) for candidate in candidates: if candidate.exists(): return candidate tried = ", ".join(str(path) for path in candidates) raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型,已尝试: {tried}") @staticmethod def _resolve_preprocess_mode(metadata: dict) -> str: preprocess = metadata.get("preprocess") if preprocess: return str(preprocess) if metadata.get("task") == "funcaptcha_siamese": return "rgb_centered" return "rgb_255" def _run_model( self, input_defs, input_names: list[str], candidates: np.ndarray, reference_batch: np.ndarray, ) -> np.ndarray: batch_size = candidates.shape[0] batch_axis = None for input_def in input_defs: shape = getattr(input_def, "shape", None) if isinstance(shape, (list, tuple)) and shape: batch_axis = shape[0] break if batch_axis in (None, "batch", "None", -1) or not isinstance(batch_axis, int): return self.session.run(None, { input_names[0]: candidates, input_names[1]: reference_batch, })[0].reshape(-1) if batch_axis == batch_size: return self.session.run(None, { input_names[0]: candidates, input_names[1]: reference_batch, })[0].reshape(-1) if batch_axis != 1: raise RuntimeError(f"专项模型不支持当前 batch 维度: expected={batch_axis} actual={batch_size}") outputs = [] for idx in range(batch_size): out = self.session.run(None, { input_names[0]: candidates[idx:idx + 1], input_names[1]: reference_batch[idx:idx + 1], })[0] outputs.append(np.asarray(out, dtype=np.float32).reshape(-1)) return np.concatenate(outputs, axis=0)