Align task API and add FunCaptcha support
This commit is contained in:
117
inference/fun_captcha.py
Normal file
117
inference/fun_captcha.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
FunCaptcha 专项 ONNX 推理。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import FUN_CAPTCHA_TASKS, INFERENCE_CONFIG
|
||||
from inference.model_metadata import load_model_metadata
|
||||
from inference.pipeline import _try_import_ort
|
||||
|
||||
|
||||
class FunCaptchaRollballPipeline:
    """
    Dedicated inference pipeline for the `4_3d_rollball_animals` task.

    Takes the full challenge image, crops the reference tile and the
    candidate tiles internally, then scores every candidate against the
    reference with a Siamese ONNX model.
    """

    def __init__(self, question: str = "4_3d_rollball_animals", models_dir: str | None = None):
        """
        Load the ONNX model and its metadata for one FunCaptcha question.

        Args:
            question: Key into ``FUN_CAPTCHA_TASKS`` selecting the task config.
            models_dir: Directory holding ``<artifact_name>.onnx``. Falls back
                to ``INFERENCE_CONFIG["default_models_dir"]`` when omitted.

        Raises:
            ValueError: ``question`` is not a key of ``FUN_CAPTCHA_TASKS``.
            FileNotFoundError: The expected ONNX file does not exist.
        """
        if question not in FUN_CAPTCHA_TASKS:
            raise ValueError(f"不支持的 FunCaptcha question: {question}")

        ort = _try_import_ort()
        self.question = question
        self.task_cfg = FUN_CAPTCHA_TASKS[question]
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
        if not self.model_path.exists():
            raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")

        # CPU-only session with a small, fixed thread budget.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )
        # A falsy metadata result is normalised to an empty dict so the
        # ``.get(..., task_cfg[...])`` fallbacks below always work.
        self.metadata = load_model_metadata(self.model_path) or {}
        self.mean = float(INFERENCE_CONFIG["normalize_mean"])
        self.std = float(INFERENCE_CONFIG["normalize_std"])
        # Metadata shipped with the model takes precedence over the task config.
        self.answer_index_base = int(
            self.metadata.get("answer_index_base", self.task_cfg["answer_index_base"])
        )

    def solve(self, image) -> dict:
        """
        Pick the candidate tile that best matches the reference tile.

        Args:
            image: A ``PIL.Image.Image``, a file path (``str`` / ``Path``) or
                the encoded image as a bytes-like object.

        Returns:
            A dict with the keys ``type``, ``question``, ``objects`` (the
            selected index, offset by ``answer_index_base``), ``scores`` (one
            sigmoid score per candidate, rounded to 6 decimals), ``raw``,
            ``result`` (both the selected index as a string) and ``time_ms``.

        Raises:
            RuntimeError: The model does not expose exactly two inputs, or it
                does not return exactly one logit per candidate.
            TypeError: ``image`` is of an unsupported type.
        """
        t0 = time.perf_counter()
        challenge = self._load_image(image)
        candidates, reference = self._split_challenge(challenge)

        input_names = [inp.name for inp in self.session.get_inputs()]
        if len(input_names) != 2:
            raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

        # The single reference tile is repeated so that it pairs up with
        # every candidate in one batched run.
        num_candidates = candidates.shape[0]
        ref_batch = np.repeat(reference, repeats=num_candidates, axis=0)

        logits = self.session.run(None, {
            input_names[0]: candidates,
            input_names[1]: ref_batch,
        })[0].reshape(-1)
        # Without this guard a model emitting several values per pair would
        # silently produce an answer index outside the candidate range.
        if logits.size != num_candidates:
            raise RuntimeError(
                f"专项模型输出数量异常: expected={num_candidates} got={logits.size}"
            )

        scores = self._sigmoid(logits)
        answer_idx = int(np.argmax(logits))
        selected = answer_idx + self.answer_index_base
        elapsed = (time.perf_counter() - t0) * 1000

        return {
            "type": "funcaptcha",
            "question": self.question,
            "objects": [selected],
            "scores": [round(float(score), 6) for score in scores.tolist()],
            "raw": str(selected),
            "result": str(selected),
            "time_ms": round(elapsed, 2),
        }

    @staticmethod
    def _sigmoid(logits: np.ndarray) -> np.ndarray:
        """
        Return the element-wise logistic function of ``logits``.

        Only ``exp`` of non-positive values is evaluated, so strongly
        negative logits cannot overflow the way ``1 / (1 + exp(-x))`` does.
        """
        decay = np.exp(-np.abs(logits))
        return np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))

    def _split_challenge(self, image: Image.Image) -> tuple[np.ndarray, np.ndarray]:
        """
        Crop the challenge image into candidate tiles and the reference tile.

        Candidates are taken left to right from the top row, each
        ``tile_size`` wide/high; the reference is cropped from
        ``reference_box``. Tile geometry prefers the model metadata and falls
        back to the task config; the network input size is read from the task
        config only.

        Returns:
            ``(candidates, reference)`` where ``candidates`` stacks the
            preprocessed tiles along axis 0 and ``reference`` is a single
            preprocessed tile with a leading batch axis of 1.
        """
        tile_w, tile_h = self.metadata.get("tile_size", self.task_cfg["tile_size"])
        ref_box = tuple(self.metadata.get("reference_box", self.task_cfg["reference_box"]))
        num_candidates = int(self.metadata.get("num_candidates", self.task_cfg["num_candidates"]))
        input_h, input_w = self.task_cfg["input_size"]
        target_size = (input_h, input_w)

        candidates = [
            self._preprocess(
                image.crop((idx * tile_w, 0, idx * tile_w + tile_w, tile_h)),
                target_size,
            )
            for idx in range(num_candidates)
        ]

        reference = image.crop(ref_box)
        return (
            np.concatenate(candidates, axis=0),
            self._preprocess(reference, target_size),
        )

    def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """
        Convert one tile into a normalised ``(1, 3, H, W)`` float32 array.

        Args:
            image: Tile to convert; it is forced to RGB first.
            target_size: ``(height, width)`` the tile is resized to.
        """
        img_h, img_w = target_size
        # PIL's resize takes (width, height), the reverse of target_size.
        image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
        arr = np.asarray(image, dtype=np.float32) / 255.0
        arr = (arr - self.mean) / self.std
        # HWC -> CHW, then add the batch axis.
        arr = np.transpose(arr, (2, 0, 1))
        return arr.reshape(1, 3, img_h, img_w)

    @staticmethod
    def _load_image(image) -> Image.Image:
        """
        Normalise the supported input kinds to a PIL image.

        An existing ``Image.Image`` is returned unchanged. Paths and
        bytes-like payloads are decoded and converted to RGB; the underlying
        file handle is closed before returning.

        Raises:
            TypeError: ``image`` is none of the supported types.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, Path)):
            source = image
        elif isinstance(image, (bytes, bytearray, memoryview)):
            source = io.BytesIO(image)
        else:
            raise TypeError(f"不支持的图片输入类型: {type(image)}")
        # convert() returns a fully loaded copy, so the opened image (and the
        # file handle behind it) can be released right away.
        with Image.open(source) as opened:
            return opened.convert("RGB")
|
||||
Reference in New Issue
Block a user