Files
CaptchBreaker/inference/fun_captcha.py

191 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Dedicated ONNX inference for FunCaptcha challenges.
"""
from __future__ import annotations
import io
import os
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import FUN_CAPTCHA_TASKS, INFERENCE_CONFIG
from inference.model_metadata import load_model_metadata
from inference.pipeline import _try_import_ort
class FunCaptchaRollballPipeline:
    """Dedicated solver for the `4_3d_rollball_animals` FunCaptcha task.

    Takes the whole challenge image, crops the reference tile and the
    candidate tiles internally, then scores every candidate against the
    reference with a two-input (Siamese) ONNX model.
    """

    def __init__(self, question: str = "4_3d_rollball_animals", models_dir: str | None = None):
        """Locate the ONNX model for *question* and open a CPU inference session.

        Args:
            question: Key into ``FUN_CAPTCHA_TASKS`` selecting the task config.
            models_dir: Directory searched first for the model file; falls
                back to ``INFERENCE_CONFIG["default_models_dir"]`` when falsy.

        Raises:
            ValueError: If *question* is not a key of ``FUN_CAPTCHA_TASKS``.
            FileNotFoundError: If no candidate model path is a file.
        """
        if question not in FUN_CAPTCHA_TASKS:
            raise ValueError(f"不支持的 FunCaptcha question: {question}")
        ort = _try_import_ort()
        self.question = question
        self.task_cfg = FUN_CAPTCHA_TASKS[question]
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.model_path = self._resolve_model_path()

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

        # Metadata stored next to the model overrides the static task config
        # wherever both define the same key.
        self.metadata = load_model_metadata(self.model_path) or {}
        self.preprocess_mode = self._resolve_preprocess_mode(self.metadata)
        self.mean = float(INFERENCE_CONFIG["normalize_mean"])
        self.std = float(INFERENCE_CONFIG["normalize_std"])
        self.answer_index_base = int(
            self.metadata.get("answer_index_base", self.task_cfg["answer_index_base"])
        )

    def solve(self, image) -> dict:
        """Solve one challenge image and report the selected candidate.

        Args:
            image: A ``PIL.Image.Image``, a filesystem path (``str`` or
                ``Path``), or the encoded image as a bytes-like object.

        Returns:
            A dict with the chosen candidate index (offset by
            ``answer_index_base``) under ``objects`` / ``raw`` / ``result``,
            the per-candidate sigmoid scores under ``scores`` and the elapsed
            wall-clock time in milliseconds under ``time_ms``.

        Raises:
            RuntimeError: If the model does not expose exactly two inputs, or
                declares a fixed batch size this pipeline cannot feed.
            TypeError: If *image* is of an unsupported type.
        """
        t0 = time.perf_counter()
        challenge = self._load_image(image)
        candidates, reference = self._split_challenge(challenge)
        # Pair every candidate with its own copy of the single reference tile.
        ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)

        input_defs = self.session.get_inputs()
        input_names = [inp.name for inp in input_defs]
        if len(input_names) != 2:
            raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

        logits = self._run_model(input_defs, input_names, candidates, ref_batch)
        scores = self._sigmoid(logits)
        # The winner is picked on raw logits; the scores are informational.
        answer_idx = int(np.argmax(logits))
        selected = answer_idx + self.answer_index_base
        elapsed_ms = (time.perf_counter() - t0) * 1000
        return {
            "type": "funcaptcha",
            "question": self.question,
            "objects": [selected],
            "scores": [round(float(score), 6) for score in scores.tolist()],
            "raw": str(selected),
            "result": str(selected),
            "time_ms": round(elapsed_ms, 2),
        }

    @staticmethod
    def _sigmoid(logits: np.ndarray) -> np.ndarray:
        """Return the element-wise logistic sigmoid of *logits* without overflow.

        ``1 / (1 + exp(-x))`` overflows for large-magnitude negative ``x``.
        Only ``exp(-|x|)`` is evaluated here, which always lies in ``[0, 1]``.
        """
        logits = np.asarray(logits)
        exp_neg_abs = np.exp(-np.abs(logits))
        return np.where(
            logits >= 0,
            1.0 / (1.0 + exp_neg_abs),
            exp_neg_abs / (1.0 + exp_neg_abs),
        )

    def _split_challenge(self, image: Image.Image) -> tuple[np.ndarray, np.ndarray]:
        """Crop the candidate tiles and the reference tile out of *image*.

        Candidates are taken left to right from the strip starting at the
        top-left corner, each ``tile_size`` wide and high. The reference is
        cropped with ``reference_box``. Layout values come from the model
        metadata when present, otherwise from the task config; the network
        input size always comes from the task config.

        Returns:
            ``(candidates, reference)``: the preprocessed candidate tiles
            concatenated along axis 0, and the preprocessed reference tile.
        """
        tile_w, tile_h = self.metadata.get("tile_size", self.task_cfg["tile_size"])
        ref_box = tuple(self.metadata.get("reference_box", self.task_cfg["reference_box"]))
        num_candidates = int(self.metadata.get("num_candidates", self.task_cfg["num_candidates"]))
        input_h, input_w = self.task_cfg["input_size"]
        target_size = (input_h, input_w)

        # NOTE(review): the crop boxes are not validated against the image
        # size here; confirm callers always pass a full-size challenge image.
        candidates = [
            self._preprocess(
                image.crop((idx * tile_w, 0, idx * tile_w + tile_w, tile_h)),
                target_size,
            )
            for idx in range(num_candidates)
        ]
        reference = image.crop(ref_box)
        return (
            np.concatenate(candidates, axis=0),
            self._preprocess(reference, target_size),
        )

    def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """Turn one tile into a ``(1, 3, H, W)`` float32 array.

        Args:
            image: Tile to convert; it is converted to RGB first.
            target_size: ``(height, width)`` of the network input.

        Raises:
            ValueError: If ``self.preprocess_mode`` is neither
                ``"rgb_centered"`` nor ``"rgb_255"``.
        """
        img_h, img_w = target_size
        image = image.convert("RGB")
        mode = self.preprocess_mode
        if mode == "rgb_centered":
            image = image.resize((img_w, img_h), Image.BILINEAR)
        elif mode == "rgb_255":
            # Match the preprocessing of the existing funcaptcha-server ONNX
            # model: resize with Pillow's default filter, scale to [0, 1] only.
            image = image.resize((img_w, img_h))
        else:
            raise ValueError(f"不支持的 FunCaptcha 预处理模式: {self.preprocess_mode}")

        arr = np.asarray(image, dtype=np.float32) / 255.0
        if mode == "rgb_centered":
            arr = (arr - self.mean) / self.std
        # HWC -> CHW, then add the leading batch axis.
        arr = np.transpose(arr, (2, 0, 1))
        return arr.reshape(1, 3, img_h, img_w)

    @staticmethod
    def _load_image(image) -> Image.Image:
        """Normalise the supported input kinds to a ``PIL.Image.Image``.

        A ``PIL.Image.Image`` is returned unchanged. Paths and encoded
        bytes-like objects are decoded and converted to RGB.

        Raises:
            TypeError: If *image* is none of the supported types.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, Path)):
            # The context manager closes the underlying file once the pixel
            # data has been copied into the converted image.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, (bytes, bytearray, memoryview)):
            with Image.open(io.BytesIO(image)) as opened:
                return opened.convert("RGB")
        raise TypeError(f"不支持的图片输入类型: {type(image)}")

    def _resolve_model_path(self) -> Path:
        """Return the first existing model file among the known locations.

        Search order: ``<models_dir>/<artifact_name>.onnx``, then the path in
        the environment variable named by ``external_model_env`` (if set),
        then every entry of ``fallback_model_paths``.

        Raises:
            FileNotFoundError: If none of the candidates is a file.
        """
        candidates = [self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"]
        env_name = self.task_cfg.get("external_model_env")
        env_value = os.getenv(env_name) if env_name else None
        if env_value:
            candidates.append(Path(env_value).expanduser())
        candidates.extend(
            Path(fallback).expanduser()
            for fallback in self.task_cfg.get("fallback_model_paths", [])
        )

        for candidate in candidates:
            # is_file() rather than exists(): a directory that happens to
            # carry the model's name cannot be loaded as a model.
            if candidate.is_file():
                return candidate
        tried = ", ".join(str(path) for path in candidates)
        raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型,已尝试: {tried}")

    @staticmethod
    def _resolve_preprocess_mode(metadata: dict) -> str:
        """Pick the preprocessing mode implied by the model *metadata*.

        An explicit, truthy ``preprocess`` entry wins. Otherwise models whose
        ``task`` is ``"funcaptcha_siamese"`` use ``"rgb_centered"`` and
        everything else uses ``"rgb_255"``.
        """
        preprocess = metadata.get("preprocess")
        if preprocess:
            return str(preprocess)
        if metadata.get("task") == "funcaptcha_siamese":
            return "rgb_centered"
        return "rgb_255"

    def _run_model(
        self,
        input_defs,
        input_names: list[str],
        candidates: np.ndarray,
        reference_batch: np.ndarray,
    ) -> np.ndarray:
        """Run the session over all candidate/reference pairs.

        The batch dimension is read from the first input definition that
        exposes a non-empty ``shape``. A dynamic batch (non-integer or ``-1``)
        or one equal to the number of candidates gets a single batched run;
        a fixed batch of 1 gets one run per pair.

        Returns:
            The first model output for every pair, flattened to a 1-D float32
            array.

        Raises:
            RuntimeError: If the model declares a fixed batch size other than
                1 that differs from the number of candidates.
        """
        batch_size = candidates.shape[0]

        declared_batch = None
        for input_def in input_defs:
            shape = getattr(input_def, "shape", None)
            if isinstance(shape, (list, tuple)) and shape:
                declared_batch = shape[0]
                break

        # Symbolic dimensions surface as None or strings; -1 is also treated
        # as "any size".
        is_dynamic = not isinstance(declared_batch, int) or declared_batch == -1
        if is_dynamic or declared_batch == batch_size:
            out = self.session.run(None, {
                input_names[0]: candidates,
                input_names[1]: reference_batch,
            })[0]
            return np.asarray(out, dtype=np.float32).reshape(-1)

        if declared_batch != 1:
            raise RuntimeError(f"专项模型不支持当前 batch 维度: expected={declared_batch} actual={batch_size}")

        # Fixed batch of one: feed the pairs one at a time.
        outputs = []
        for idx in range(batch_size):
            out = self.session.run(None, {
                input_names[0]: candidates[idx:idx + 1],
                input_names[1]: reference_batch[idx:idx + 1],
            })[0]
            outputs.append(np.asarray(out, dtype=np.float32).reshape(-1))
        return np.concatenate(outputs, axis=0)