Align task API and add FunCaptcha support
This commit is contained in:
@@ -2,16 +2,19 @@
|
||||
推理包
|
||||
|
||||
- pipeline.py: CaptchaPipeline 核心推理流水线
|
||||
- fun_captcha.py: FunCaptcha 专项推理
|
||||
- export_onnx.py: PyTorch → ONNX 导出
|
||||
- math_eval.py: 算式计算模块
|
||||
"""
|
||||
|
||||
from inference.pipeline import CaptchaPipeline
|
||||
from inference.fun_captcha import FunCaptchaRollballPipeline
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.export_onnx import export_model, export_all
|
||||
|
||||
__all__ = [
|
||||
"CaptchaPipeline",
|
||||
"FunCaptchaRollballPipeline",
|
||||
"eval_captcha_math",
|
||||
"export_model",
|
||||
"export_all",
|
||||
|
||||
@@ -9,7 +9,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from config import (
|
||||
CAPTCHA_TYPES,
|
||||
CHECKPOINTS_DIR,
|
||||
FUN_CAPTCHA_TASKS,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
@@ -19,20 +21,28 @@ from config import (
|
||||
NUM_CAPTCHA_TYPES,
|
||||
REGRESSION_RANGE,
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
)
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
|
||||
|
||||
def export_model(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
input_shape: tuple,
|
||||
input_shape: tuple | None = None,
|
||||
onnx_dir: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
dummy_inputs: tuple[torch.Tensor, ...] | None = None,
|
||||
input_names: list[str] | None = None,
|
||||
output_names: list[str] | None = None,
|
||||
dynamic_axes: dict | None = None,
|
||||
):
|
||||
"""
|
||||
导出单个模型为 ONNX。
|
||||
@@ -52,25 +62,41 @@ def export_model(
|
||||
model.eval()
|
||||
model.cpu()
|
||||
|
||||
dummy = torch.randn(1, *input_shape)
|
||||
if dummy_inputs is None:
|
||||
if input_shape is None:
|
||||
raise ValueError("input_shape 和 dummy_inputs 不能同时为空")
|
||||
dummy_inputs = (torch.randn(1, *input_shape),)
|
||||
if input_names is None:
|
||||
input_names = ["input"] if len(dummy_inputs) == 1 else [f"input_{i}" for i in range(len(dummy_inputs))]
|
||||
if output_names is None:
|
||||
output_names = ["output"]
|
||||
|
||||
# 分类器和识别器的 dynamic_axes 不同
|
||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider", "gap_detector", "rotation_regressor"):
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
else:
|
||||
# CTC 模型: output shape = (T, B, C)
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}
|
||||
if dynamic_axes is None:
|
||||
if len(dummy_inputs) > 1:
|
||||
dynamic_axes = {name: {0: "batch"} for name in input_names}
|
||||
dynamic_axes.update({name: {0: "batch"} for name in output_names})
|
||||
elif model_name == "classifier" or model_name in (
|
||||
"threed_rotate", "threed_slider", "gap_detector", "rotation_regressor",
|
||||
"funcaptcha_rollball_animals",
|
||||
):
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
else:
|
||||
# CTC 模型: output shape = (T, B, C)
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
dummy,
|
||||
dummy_inputs[0] if len(dummy_inputs) == 1 else dummy_inputs,
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_CONFIG["opset_version"],
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
|
||||
)
|
||||
|
||||
if metadata is not None:
|
||||
write_model_metadata(onnx_path, metadata)
|
||||
|
||||
size_kb = onnx_path.stat().st_size / 1024
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({size_kb:.1f} KB)")
|
||||
|
||||
@@ -86,47 +112,126 @@ def _load_and_export(model_name: str):
|
||||
acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
|
||||
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")
|
||||
|
||||
metadata = None
|
||||
|
||||
if model_name == "classifier":
|
||||
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
|
||||
h, w = IMAGE_SIZE["classifier"]
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "classifier",
|
||||
"class_names": list(ckpt.get("class_names", CAPTCHA_TYPES)),
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "normal":
|
||||
chars = ckpt.get("chars", NORMAL_CHARS)
|
||||
h, w = IMAGE_SIZE["normal"]
|
||||
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "ctc",
|
||||
"chars": chars,
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "math":
|
||||
chars = ckpt.get("chars", MATH_CHARS)
|
||||
h, w = IMAGE_SIZE["math"]
|
||||
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "ctc",
|
||||
"chars": chars,
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "threed_text":
|
||||
chars = ckpt.get("chars", THREED_CHARS)
|
||||
h, w = IMAGE_SIZE["3d_text"]
|
||||
model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "ctc",
|
||||
"chars": chars,
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "threed_rotate":
|
||||
h, w = IMAGE_SIZE["3d_rotate"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "regression",
|
||||
"label_range": list(ckpt.get("label_range", REGRESSION_RANGE["3d_rotate"])),
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "threed_slider":
|
||||
h, w = IMAGE_SIZE["3d_slider"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "regression",
|
||||
"label_range": list(ckpt.get("label_range", REGRESSION_RANGE["3d_slider"])),
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "gap_detector":
|
||||
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
|
||||
model = GapDetectorCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "regression",
|
||||
"label_range": list(ckpt.get("label_range", SOLVER_REGRESSION_RANGE["slide"])),
|
||||
"input_shape": [1, h, w],
|
||||
}
|
||||
elif model_name == "rotation_regressor":
|
||||
h, w = SOLVER_CONFIG["rotate"]["input_size"]
|
||||
model = RotationRegressor(img_h=h, img_w=w)
|
||||
input_shape = (3, h, w)
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "rotation_solver",
|
||||
"output_encoding": "sin_cos",
|
||||
"input_shape": [3, h, w],
|
||||
}
|
||||
elif model_name == "funcaptcha_rollball_animals":
|
||||
question = "4_3d_rollball_animals"
|
||||
task_cfg = FUN_CAPTCHA_TASKS[question]
|
||||
h, w = task_cfg["input_size"]
|
||||
model = FunCaptchaSiamese(in_channels=task_cfg["channels"])
|
||||
metadata = {
|
||||
"model_name": model_name,
|
||||
"task": "funcaptcha_siamese",
|
||||
"question": question,
|
||||
"num_candidates": int(ckpt.get("num_candidates", task_cfg["num_candidates"])),
|
||||
"tile_size": list(ckpt.get("tile_size", task_cfg["tile_size"])),
|
||||
"reference_box": list(ckpt.get("reference_box", task_cfg["reference_box"])),
|
||||
"answer_index_base": int(ckpt.get("answer_index_base", task_cfg["answer_index_base"])),
|
||||
"input_shape": list(ckpt.get("input_shape", [task_cfg["channels"], h, w])),
|
||||
}
|
||||
else:
|
||||
print(f"[错误] 未知模型: {model_name}")
|
||||
return
|
||||
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
export_model(model, model_name, input_shape)
|
||||
if model_name == "funcaptcha_rollball_animals":
|
||||
channels, h, w = metadata["input_shape"]
|
||||
export_model(
|
||||
model,
|
||||
model_name,
|
||||
metadata=metadata,
|
||||
dummy_inputs=(
|
||||
torch.randn(1, channels, h, w),
|
||||
torch.randn(1, channels, h, w),
|
||||
),
|
||||
input_names=["candidate", "reference"],
|
||||
output_names=["output"],
|
||||
)
|
||||
else:
|
||||
export_model(model, model_name, input_shape, metadata=metadata)
|
||||
|
||||
|
||||
def export_all():
|
||||
@@ -138,6 +243,7 @@ def export_all():
|
||||
"classifier", "normal", "math", "threed_text",
|
||||
"threed_rotate", "threed_slider",
|
||||
"gap_detector", "rotation_regressor",
|
||||
"funcaptcha_rollball_animals",
|
||||
]:
|
||||
_load_and_export(name)
|
||||
print("\n全部导出完成。")
|
||||
|
||||
117
inference/fun_captcha.py
Normal file
117
inference/fun_captcha.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
FunCaptcha 专项 ONNX 推理。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import FUN_CAPTCHA_TASKS, INFERENCE_CONFIG
|
||||
from inference.model_metadata import load_model_metadata
|
||||
from inference.pipeline import _try_import_ort
|
||||
|
||||
|
||||
class FunCaptchaRollballPipeline:
    """Dedicated ONNX solver for the `4_3d_rollball_animals` FunCaptcha task.

    Accepts the full challenge image; internally crops out the reference
    tile and the candidate tiles, then scores each candidate against the
    reference with a Siamese ONNX model and picks the best match.
    """

    def __init__(self, question: str = "4_3d_rollball_animals", models_dir: str | None = None):
        """Load the ONNX session and optional sidecar metadata for *question*.

        Args:
            question: FunCaptcha task key; must be present in FUN_CAPTCHA_TASKS.
            models_dir: Directory containing the exported ONNX artifacts;
                defaults to INFERENCE_CONFIG["default_models_dir"].

        Raises:
            ValueError: If *question* is not a supported task.
            FileNotFoundError: If the expected ONNX artifact is missing.
        """
        if question not in FUN_CAPTCHA_TASKS:
            raise ValueError(f"不支持的 FunCaptcha question: {question}")

        ort = _try_import_ort()
        self.question = question
        self.task_cfg = FUN_CAPTCHA_TASKS[question]
        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.model_path = self.models_dir / f"{self.task_cfg['artifact_name']}.onnx"
        if not self.model_path.exists():
            raise FileNotFoundError(f"未找到 FunCaptcha ONNX 模型: {self.model_path}")

        # CPU-only session with small thread counts for predictable latency.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2
        self.session = ort.InferenceSession(
            str(self.model_path),
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )
        # Sidecar metadata, when present, overrides the static task config
        # (tile size, reference box, answer index base, ...).
        self.metadata = load_model_metadata(self.model_path) or {}
        self.mean = float(INFERENCE_CONFIG["normalize_mean"])
        self.std = float(INFERENCE_CONFIG["normalize_std"])
        self.answer_index_base = int(
            self.metadata.get("answer_index_base", self.task_cfg["answer_index_base"])
        )

    def solve(self, image) -> dict:
        """Score every candidate tile against the reference and pick the best.

        Args:
            image: PIL Image, filesystem path (str/Path), or raw image bytes.

        Returns:
            Result dict with the selected answer index (offset by
            ``answer_index_base``), per-candidate sigmoid scores, and
            wall-clock time in milliseconds.
        """
        t0 = time.perf_counter()
        challenge = self._load_image(image)
        candidates, reference = self._split_challenge(challenge)

        # Broadcast the single reference tensor to one row per candidate
        # so both Siamese inputs share the same batch dimension.
        ref_batch = np.repeat(reference, repeats=candidates.shape[0], axis=0)
        input_names = [inp.name for inp in self.session.get_inputs()]
        if len(input_names) != 2:
            raise RuntimeError(f"专项模型输入数量异常: expected=2 got={len(input_names)}")

        logits = self.session.run(None, {
            input_names[0]: candidates,
            input_names[1]: ref_batch,
        })[0].reshape(-1)
        # Sigmoid is monotonic, so argmax over raw logits equals argmax
        # over scores; scores are only reported for transparency.
        scores = 1.0 / (1.0 + np.exp(-logits))
        answer_idx = int(np.argmax(logits))
        selected = answer_idx + self.answer_index_base
        elapsed = (time.perf_counter() - t0) * 1000

        return {
            "type": "funcaptcha",
            "question": self.question,
            "objects": [selected],
            "scores": [round(float(score), 6) for score in scores.tolist()],
            "raw": str(selected),
            "result": str(selected),
            "time_ms": round(elapsed, 2),
        }

    def _split_challenge(self, image: Image.Image) -> tuple[np.ndarray, np.ndarray]:
        """Crop the candidate tiles and the reference tile from the challenge.

        Candidate tiles are assumed to sit in a horizontal strip along the
        top edge of the image (tiles cropped at y=0, stepped by tile width).

        Returns:
            (candidates, reference): candidates is a (num_candidates, C, H, W)
            batch; reference is a single (1, C, H, W) tensor.
        """
        tile_w, tile_h = self.metadata.get("tile_size", self.task_cfg["tile_size"])
        ref_box = tuple(self.metadata.get("reference_box", self.task_cfg["reference_box"]))
        num_candidates = int(self.metadata.get("num_candidates", self.task_cfg["num_candidates"]))
        input_h, input_w = self.task_cfg["input_size"]

        candidates = []
        for idx in range(num_candidates):
            left = idx * tile_w
            candidate = image.crop((left, 0, left + tile_w, tile_h))
            candidates.append(self._preprocess(candidate, (input_h, input_w)))

        reference = image.crop(ref_box)
        return (
            np.concatenate(candidates, axis=0),
            self._preprocess(reference, (input_h, input_w)),
        )

    def _preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """Resize, normalize, and convert a tile to a (1, 3, H, W) float32 array."""
        img_h, img_w = target_size
        # convert("RGB") guarantees exactly 3 channels for the reshape below.
        image = image.convert("RGB").resize((img_w, img_h), Image.BILINEAR)
        arr = np.asarray(image, dtype=np.float32) / 255.0
        arr = (arr - self.mean) / self.std
        arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
        return arr.reshape(1, 3, img_h, img_w)

    @staticmethod
    def _load_image(image) -> Image.Image:
        """Coerce a PIL Image, path, or raw bytes into a PIL Image.

        Raises:
            TypeError: For any unsupported input type.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, Path)):
            return Image.open(image).convert("RGB")
        if isinstance(image, bytes):
            return Image.open(io.BytesIO(image)).convert("RGB")
        raise TypeError(f"不支持的图片输入类型: {type(image)}")
|
||||
33
inference/model_metadata.py
Normal file
33
inference/model_metadata.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
ONNX 模型 sidecar metadata 辅助工具。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def model_metadata_path(model_path: str | Path) -> Path:
|
||||
return Path(model_path).with_suffix(".meta.json")
|
||||
|
||||
|
||||
def write_model_metadata(model_path: str | Path, metadata: dict) -> Path:
    """Write *metadata* as a sidecar JSON file next to the model.

    A ``"version": 1`` field is added unless *metadata* already carries
    its own ``version``. Parent directories are created as needed.
    Returns the path of the written sidecar file.
    """
    meta_path = model_metadata_path(model_path)
    payload = {"version": 1}
    payload.update(metadata)  # caller-supplied keys win, including "version"
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, ensure_ascii=True, indent=2, sort_keys=True)
    meta_path.write_text(serialized + "\n", encoding="utf-8")
    return meta_path
||||
|
||||
|
||||
def load_model_metadata(model_path: str | Path) -> dict | None:
    """Read the sidecar metadata JSON for *model_path*.

    Returns the parsed dict, or ``None`` when no sidecar file exists.
    """
    meta_path = model_metadata_path(model_path)
    if meta_path.exists():
        return json.loads(meta_path.read_text(encoding="utf-8"))
    return None
|
||||
@@ -26,6 +26,7 @@ from config import (
|
||||
REGRESSION_RANGE,
|
||||
)
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.model_metadata import load_model_metadata
|
||||
|
||||
|
||||
def _try_import_ort():
|
||||
@@ -66,6 +67,11 @@ class CaptchaPipeline:
|
||||
"math": MATH_CHARS,
|
||||
"3d_text": THREED_CHARS,
|
||||
}
|
||||
self._classifier_class_names = tuple(CAPTCHA_TYPES)
|
||||
self._regression_ranges = {
|
||||
"3d_rotate": REGRESSION_RANGE["3d_rotate"],
|
||||
"3d_slider": REGRESSION_RANGE["3d_slider"],
|
||||
}
|
||||
|
||||
# 回归模型类型
|
||||
self._regression_types = {"3d_rotate", "3d_slider"}
|
||||
@@ -86,6 +92,7 @@ class CaptchaPipeline:
|
||||
opts.intra_op_num_threads = 2
|
||||
|
||||
self._sessions: dict[str, "ort.InferenceSession"] = {}
|
||||
self._metadata: dict[str, dict] = {}
|
||||
for name, fname in self._model_files.items():
|
||||
path = self.models_dir / fname
|
||||
if path.exists():
|
||||
@@ -93,6 +100,7 @@ class CaptchaPipeline:
|
||||
str(path), sess_options=opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self._metadata[name] = load_model_metadata(path) or {}
|
||||
|
||||
loaded = list(self._sessions.keys())
|
||||
if not loaded:
|
||||
@@ -135,7 +143,14 @@ class CaptchaPipeline:
|
||||
input_name = session.get_inputs()[0].name
|
||||
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
|
||||
idx = int(np.argmax(logits, axis=1)[0])
|
||||
return CAPTCHA_TYPES[idx]
|
||||
class_names = tuple(
|
||||
self._metadata.get("classifier", {}).get("class_names", self._classifier_class_names)
|
||||
)
|
||||
if idx >= len(class_names):
|
||||
raise RuntimeError(
|
||||
f"分类器输出索引越界: idx={idx}, classes={len(class_names)}"
|
||||
)
|
||||
return class_names[idx]
|
||||
|
||||
def solve(
|
||||
self,
|
||||
@@ -182,14 +197,17 @@ class CaptchaPipeline:
|
||||
# 回归模型: 输出 (batch, 1) sigmoid 值
|
||||
output = session.run(None, {input_name: inp})[0] # (1, 1)
|
||||
sigmoid_val = float(output[0, 0])
|
||||
lo, hi = REGRESSION_RANGE[captcha_type]
|
||||
lo, hi = self._metadata.get(captcha_type, {}).get(
|
||||
"label_range",
|
||||
self._regression_ranges[captcha_type],
|
||||
)
|
||||
real_val = sigmoid_val * (hi - lo) + lo
|
||||
raw_text = f"{real_val:.1f}"
|
||||
result = str(int(round(real_val)))
|
||||
else:
|
||||
# CTC 模型
|
||||
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
|
||||
chars = self._chars[captcha_type]
|
||||
chars = self._metadata.get(captcha_type, {}).get("chars", self._chars[captcha_type])
|
||||
raw_text = self._ctc_greedy_decode(logits, chars)
|
||||
|
||||
# 后处理
|
||||
|
||||
Reference in New Issue
Block a user