Files
CaptchBreaker/inference/export_onnx.py
Hua 9b5f29083e Add slide and rotate interactive captcha solvers
New solver subsystem with independent models:
- GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection
- RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction
- SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback
- RotateSolver with ONNX sin/cos → atan2 inference
- Generators, training scripts, CLI commands, and slide track utility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 18:07:06 +08:00

148 lines
4.6 KiB
Python

"""
ONNX 导出脚本
从 checkpoints/ 加载训练好的 PyTorch 模型,导出为 ONNX 格式到 onnx_models/。
支持逐个导出或一次导出全部。
"""
import torch
import torch.nn as nn
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
IMAGE_SIZE,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
NUM_CAPTCHA_TYPES,
REGRESSION_RANGE,
SOLVER_CONFIG,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
def export_model(
    model: nn.Module,
    model_name: str,
    input_shape: tuple,
    onnx_dir: str | None = None,
):
    """
    Export a single model to ONNX.

    The model is switched to eval mode and moved to CPU (in place). It is then
    traced with a random dummy input of shape ``(1, *input_shape)``. The result
    is written to ``<onnx_dir>/<model_name>.onnx``.

    Args:
        model: PyTorch model with its weights already loaded.
        model_name: Model name, also used as the output file stem
            (classifier / normal / math / threed_text / threed_rotate /
            threed_slider / gap_detector / rotation_regressor).
        input_shape: Input shape (C, H, W), without the batch dimension.
        onnx_dir: Output directory (defaults to config.ONNX_DIR).

    Returns:
        Path of the written ``.onnx`` file.
    """
    from pathlib import Path

    # These models produce batch-first outputs. Every other name is treated as
    # a CTC recognizer whose output is time-major, (T, B, C).
    batch_first_outputs = {
        "classifier",
        "threed_rotate",
        "threed_slider",
        "gap_detector",
        "rotation_regressor",
    }

    out_dir = Path(onnx_dir) if onnx_dir else ONNX_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = out_dir / f"{model_name}.onnx"

    model.eval()
    model.cpu()
    dummy = torch.randn(1, *input_shape)

    if model_name in batch_first_outputs:
        dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    else:
        # CTC models: output shape = (T, B, C), so batch is output axis 1.
        dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}

    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        # A dynamic batch axis is only declared when enabled in the config.
        dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
    )
    size_kb = onnx_path.stat().st_size / 1024
    print(f"[ONNX] 导出完成: {onnx_path} ({size_kb:.1f} KB)")
    return onnx_path
def _checkpoint_accuracy(ckpt: dict):
    """Return the accuracy recorded in a checkpoint, for logging only.

    ``best_acc`` is preferred, then ``best_tol_acc``; ``"?"`` is returned if
    neither is present. Keys are tested with ``is not None`` rather than
    truthiness, so a recorded accuracy of 0.0 is reported as-is instead of
    being silently replaced by the fallback.
    """
    for key in ("best_acc", "best_tol_acc"):
        value = ckpt.get(key)
        if value is not None:
            return value
    return "?"


def _build_model(model_name: str, ckpt: dict):
    """Instantiate the architecture for *model_name*.

    Returns ``(model, input_shape)``, where ``input_shape`` is (C, H, W).
    Returns ``None`` for an unknown model name. The character set of a CTC
    recognizer is taken from the checkpoint's ``"chars"`` entry when present,
    otherwise from the config default.
    """
    # CTC recognizers: model name -> (IMAGE_SIZE key, model class, default charset)
    ctc_specs = {
        "normal": ("normal", LiteCRNN, NORMAL_CHARS),
        "math": ("math", LiteCRNN, MATH_CHARS),
        "threed_text": ("3d_text", ThreeDCNN, THREED_CHARS),
    }
    # RegressionCNN models: model name -> IMAGE_SIZE key
    regression_specs = {
        "threed_rotate": "3d_rotate",
        "threed_slider": "3d_slider",
    }

    if model_name == "classifier":
        h, w = IMAGE_SIZE["classifier"]
        return CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES), (1, h, w)
    if model_name in ctc_specs:
        size_key, model_cls, default_chars = ctc_specs[model_name]
        h, w = IMAGE_SIZE[size_key]
        chars = ckpt.get("chars", default_chars)
        return model_cls(chars=chars, img_h=h, img_w=w), (1, h, w)
    if model_name in regression_specs:
        h, w = IMAGE_SIZE[regression_specs[model_name]]
        return RegressionCNN(img_h=h, img_w=w), (1, h, w)
    if model_name == "gap_detector":
        h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
        return GapDetectorCNN(img_h=h, img_w=w), (1, h, w)
    if model_name == "rotation_regressor":
        h, w = SOLVER_CONFIG["rotate"]["input_size"]
        # The only model here that takes a 3-channel input.
        return RotationRegressor(img_h=h, img_w=w), (3, h, w)
    return None


def _load_and_export(model_name: str):
    """Load ``<CHECKPOINTS_DIR>/<model_name>.pth`` and export it to ONNX.

    If the checkpoint file is missing, a skip message is printed and nothing
    is exported. An unknown model name also prints an error and exports
    nothing. The checkpoint is expected to hold the weights under
    ``"model_state_dict"``.
    """
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    if not ckpt_path.exists():
        print(f"[跳过] {model_name}: checkpoint 不存在 ({ckpt_path})")
        return
    # weights_only=True limits unpickling to tensors and plain Python types.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    acc_info = _checkpoint_accuracy(ckpt)
    print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")

    built = _build_model(model_name, ckpt)
    if built is None:
        print(f"[错误] 未知模型: {model_name}")
        return
    model, input_shape = built
    model.load_state_dict(ckpt["model_state_dict"])
    export_model(model, model_name, input_shape)
def export_all():
    """Export every known model to ONNX, one after another.

    This covers the captcha classifier, the CTC recognizers, the 3D regression
    models and the slide/rotate solver models. A model whose checkpoint is
    missing is reported and skipped; the remaining models are still exported.
    """
    banner = "=" * 50
    print(banner)
    print("导出全部 ONNX 模型")
    print(banner)

    captcha_models = ("classifier", "normal", "math", "threed_text")
    threed_regressors = ("threed_rotate", "threed_slider")
    solver_models = ("gap_detector", "rotation_regressor")
    for model_name in (*captcha_models, *threed_regressors, *solver_models):
        _load_and_export(model_name)

    print("\n全部导出完成。")
if __name__ == "__main__":
    # Script entry point: running this file directly exports every model.
    export_all()