Files
CaptchBreaker/inference/fun_captcha.py

118 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
FunCaptcha 专项 ONNX 推理。
"""
from __future__ import annotations
import io
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import FUN_CAPTCHA_TASKS, INFERENCE_CONFIG
from inference.model_metadata import load_model_metadata
from inference.pipeline import _try_import_ort
class FunCaptchaRollballPipeline:
    """Dedicated inference pipeline for the `4_3d_rollball_animals` question.

    Takes a whole challenge image, crops the reference tile and the candidate
    tiles out of it, and then scores every candidate against the reference
    with a two-input (Siamese) ONNX model.
    """

    def __init__(self, question: str = "4_3d_rollball_animals", models_dir: str | None = None):
        """Load the ONNX session and the per-question configuration.

        Args:
            question: Key into ``FUN_CAPTCHA_TASKS`` selecting the task config.
            models_dir: Directory holding ``<artifact_name>.onnx``. Falls back
                to ``INFERENCE_CONFIG["default_models_dir"]`` when falsy.

        Raises:
            ValueError: If ``question`` is not a key of ``FUN_CAPTCHA_TASKS``.
            FileNotFoundError: If the ONNX model file does not exist.
        """
        if question not in FUN_CAPTCHA_TASKS:
            raise ValueError(f"不支持的 FunCaptcha question: {question}")

        ort = _try_import_ort()

        self.question = question
        self.task_cfg = FUN_CAPTCHA_TASKS[question]
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
        if not self.model_path.exists():
            raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

        # Metadata stored next to the model (if any) overrides the static
        # task configuration for the keys looked up below and in
        # `_split_challenge`.
        self.metadata = load_model_metadata(self.model_path) or {}
        self.mean = float(INFERENCE_CONFIG["normalize_mean"])
        self.std = float(INFERENCE_CONFIG["normalize_std"])
        self.answer_index_base = int(
            self.metadata.get("answer_index_base", self.task_cfg["answer_index_base"])
        )

    def solve(self, image) -> dict:
        """Pick the candidate tile that best matches the reference tile.

        Args:
            image: A ``PIL.Image.Image``, a filesystem path (``str`` or
                ``Path``), or the encoded image as a bytes-like object.

        Returns:
            A dict with the selected answer (``objects`` / ``raw`` /
            ``result``, offset by ``answer_index_base``), the per-candidate
            sigmoid ``scores`` rounded to 6 decimals, and the elapsed
            ``time_ms``.

        Raises:
            TypeError: If ``image`` is of an unsupported type.
            RuntimeError: If the ONNX model does not expose exactly 2 inputs.
        """
        t0 = time.perf_counter()
        challenge = self._load_image(image)
        candidates, reference = self._split_challenge(challenge)
        # Pair every candidate with its own copy of the reference tile.
        ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)

        input_names = [inp.name for inp in self.session.get_inputs()]
        if len(input_names) != 2:
            raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

        # First model input receives the candidates, second the reference.
        feeds = {
            input_names[0]: candidates,
            input_names[1]: ref_batch,
        }
        logits = self.session.run(None, feeds)[0].reshape(-1)

        # Strongly negative logits overflow exp() to inf, which still yields
        # the correct score of 0.0; only the RuntimeWarning is suppressed.
        with np.errstate(over="ignore"):
            scores = 1.0 / (1.0 + np.exp(-logits))

        answer_idx = int(np.argmax(logits))
        selected = answer_idx + self.answer_index_base
        elapsed = (time.perf_counter() - t0) * 1000
        return {
            "type": "funcaptcha",
            "question": self.question,
            "objects": [selected],
            "scores": [round(float(score), 6) for score in scores.tolist()],
            "raw": str(selected),
            "result": str(selected),
            "time_ms": round(elapsed, 2),
        }

    def _split_challenge(self, image: Image.Image) -> tuple[np.ndarray, np.ndarray]:
        """Crop the challenge image into candidate tiles and the reference tile.

        Candidates are cut as a left-to-right row of ``tile_w`` x ``tile_h``
        tiles along the top edge of the image, starting at x = 0. The
        reference tile is cut using ``reference_box``.

        Returns:
            ``(candidates, reference)`` where ``candidates`` stacks the
            preprocessed tiles along axis 0 and ``reference`` is a single
            preprocessed tile with a leading batch axis of 1.
        """
        tile_w, tile_h = self.metadata.get("tile_size", self.task_cfg["tile_size"])
        ref_box = tuple(self.metadata.get("reference_box", self.task_cfg["reference_box"]))
        num_candidates = int(self.metadata.get("num_candidates", self.task_cfg["num_candidates"]))
        # NOTE(review): input_size is read from the task config only, unlike
        # the other geometry values which prefer metadata - confirm intended.
        input_h, input_w = self.task_cfg["input_size"]
        target_size = (input_h, input_w)

        candidates = []
        for idx in range(num_candidates):
            left = idx * tile_w
            tile = image.crop((left, 0, left + tile_w, tile_h))
            candidates.append(self._preprocess(tile, target_size))

        reference = image.crop(ref_box)
        return (
            np.concatenate(candidates, axis=0),
            self._preprocess(reference, target_size),
        )

    def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """Convert a tile into a normalized model input.

        The tile is converted to RGB, resized bilinearly to ``target_size``
        (given as ``(height, width)``), scaled to [0, 1], normalized with
        ``self.mean`` / ``self.std`` and rearranged from HWC to CHW.

        Returns:
            An array of shape ``(1, 3, height, width)``.
        """
        img_h, img_w = target_size
        # PIL's resize takes (width, height), the reverse of target_size.
        image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
        arr = np.asarray(image, dtype=np.float32) / 255.0
        arr = (arr - self.mean) / self.std
        arr = np.transpose(arr, (2, 0, 1))
        return arr.reshape(1, 3, img_h, img_w)

    @staticmethod
    def _load_image(image) -> Image.Image:
        """Normalize the supported input types into a PIL image.

        Args:
            image: A ``PIL.Image.Image`` (returned unchanged), a ``str`` or
                ``Path`` pointing at an image file, or a bytes-like object
                (``bytes``, ``bytearray``, ``memoryview``) holding the
                encoded image.

        Returns:
            The input image itself, or an RGB copy decoded from the
            file / buffer.

        Raises:
            TypeError: If ``image`` is none of the supported types.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, Path)):
            # convert() returns a fully loaded image, so the underlying file
            # handle can be released as soon as the block exits.
            with Image.open(image) as img:
                return img.convert("RGB")
        if isinstance(image, (bytes, bytearray, memoryview)):
            with Image.open(io.BytesIO(image)) as img:
                return img.convert("RGB")
        raise TypeError(f"不支持的图片输入类型: {type(image)}")