"""
ONNX export script.

Loads trained PyTorch models from checkpoints/ and exports them in ONNX format
to onnx_models/. Models can be exported one at a time or all at once.
"""

import argparse
from pathlib import Path

import torch
import torch.nn as nn

from config import (
    CAPTCHA_TYPES,
    CHECKPOINTS_DIR,
    FUN_CAPTCHA_TASKS,
    ONNX_DIR,
    ONNX_CONFIG,
    IMAGE_SIZE,
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
    NUM_CAPTCHA_TYPES,
    REGRESSION_RANGE,
    SOLVER_CONFIG,
    SOLVER_REGRESSION_RANGE,
)
from inference.model_metadata import write_model_metadata
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
from models.fun_captcha_siamese import FunCaptchaSiamese


def export_model(
    model: nn.Module,
    model_name: str,
    input_shape: tuple | None = None,
    onnx_dir: str | None = None,
    metadata: dict | None = None,
    dummy_inputs: tuple[torch.Tensor, ...] | None = None,
    input_names: list[str] | None = None,
    output_names: list[str] | None = None,
    dynamic_axes: dict | None = None,
):
    """Export a single PyTorch model to ``<out_dir>/<model_name>.onnx``.

    The model is put into eval mode and moved to CPU (in place) before export.

    Args:
        model: PyTorch model with its weights already loaded.
        model_name: Model name, also used as the output file stem (e.g.
            classifier / normal / math / threed_text / threed_rotate /
            threed_slider / gap_detector / rotation_regressor /
            funcaptcha_rollball_animals).
        input_shape: Per-sample input shape ``(C, H, W)``. A batch of one
            random tensor of this shape is used as the example input. Ignored
            when ``dummy_inputs`` is given.
        onnx_dir: Output directory (defaults to ``config.ONNX_DIR``).
        metadata: If given, written next to the model via
            ``write_model_metadata`` after the export.
        dummy_inputs: Explicit example inputs, one tensor per model input
            (required for multi-input models).
        input_names: ONNX input names. Defaults to ``["input"]`` for a single
            input, otherwise ``["input_0", "input_1", ...]``.
        output_names: ONNX output names. Defaults to ``["output"]``.
        dynamic_axes: Explicit ``dynamic_axes`` for ``torch.onnx.export``.
            When omitted, axis 0 of every input is marked as the dynamic
            batch axis. For outputs it is axis 0 as well, except for
            single-input models not in the known batch-first list below;
            those are treated as CTC models whose output is ``(T, B, C)``, so
            their batch axis is 1. Dynamic axes are only applied when
            ``ONNX_CONFIG["dynamic_batch"]`` is truthy.

    Raises:
        ValueError: If neither ``input_shape`` nor ``dummy_inputs`` is given,
            or if ``dummy_inputs`` is empty.
    """
    out_dir = Path(onnx_dir) if onnx_dir else Path(ONNX_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = out_dir / f"{model_name}.onnx"

    model.eval()
    model.cpu()

    if dummy_inputs is None:
        if input_shape is None:
            raise ValueError("input_shape 和 dummy_inputs 不能同时为空")
        dummy_inputs = (torch.randn(1, *input_shape),)
    elif len(dummy_inputs) == 0:
        raise ValueError("dummy_inputs 不能为空")

    if input_names is None:
        if len(dummy_inputs) == 1:
            input_names = ["input"]
        else:
            input_names = [f"input_{i}" for i in range(len(dummy_inputs))]
    if output_names is None:
        output_names = ["output"]

    if dynamic_axes is None:
        # Single-input models whose outputs are batch-first; any other
        # single-input model is assumed to be CTC with output (T, B, C).
        batch_first_models = {
            "classifier",
            "threed_rotate",
            "threed_slider",
            "gap_detector",
            "rotation_regressor",
            "funcaptcha_rollball_animals",
        }
        is_ctc = len(dummy_inputs) == 1 and model_name not in batch_first_models
        out_batch_axis = 1 if is_ctc else 0
        # Key the axes on the names actually passed to the exporter; keys that
        # do not match an input/output name would be ignored by torch.onnx.
        dynamic_axes = {name: {0: "batch"} for name in input_names}
        dynamic_axes.update({name: {out_batch_axis: "batch"} for name in output_names})

    torch.onnx.export(
        model,
        dummy_inputs[0] if len(dummy_inputs) == 1 else dummy_inputs,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
    )

    if metadata is not None:
        write_model_metadata(onnx_path, metadata)

    size_kb = onnx_path.stat().st_size / 1024
    print(f"[ONNX] 导出完成: {onnx_path} ({size_kb:.1f} KB)")


def _load_and_export(model_name: str):
    """Load ``<CHECKPOINTS_DIR>/<model_name>.pth`` and export the model to ONNX.

    If the checkpoint file is missing, a skip message is printed and nothing
    happens. If ``model_name`` is not a known model, an error message is
    printed. Otherwise a fresh model is built, the checkpoint's
    ``model_state_dict`` is loaded into it, and the model is passed to
    ``export_model`` together with its metadata.

    Args:
        model_name: One of classifier / normal / math / threed_text /
            threed_rotate / threed_slider / gap_detector /
            rotation_regressor / funcaptcha_rollball_animals.
    """
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    if not ckpt_path.exists():
        print(f"[跳过] {model_name}: checkpoint 不存在 ({ckpt_path})")
        return

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # A best_acc of 0.0 is a real value; only fall back when it is absent.
    acc_info = ckpt.get("best_acc")
    if acc_info is None:
        acc_info = ckpt.get("best_tol_acc", "?")
    print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")

    spec = _build_export_spec(model_name, ckpt)
    if spec is None:
        print(f"[错误] 未知模型: {model_name}")
        return
    model, input_shape, metadata = spec

    model.load_state_dict(ckpt["model_state_dict"])

    if model_name == "funcaptcha_rollball_animals":
        # Two same-shaped inputs: the candidate image and the reference image.
        channels, h, w = input_shape
        export_model(
            model,
            model_name,
            metadata=metadata,
            dummy_inputs=(
                torch.randn(1, channels, h, w),
                torch.randn(1, channels, h, w),
            ),
            input_names=["candidate", "reference"],
            output_names=["output"],
        )
    else:
        export_model(model, model_name, input_shape, metadata=metadata)


def _build_export_spec(model_name: str, ckpt: dict):
    """Return ``(model, input_shape, metadata)`` for ``model_name``, or ``None`` if unknown.

    ``input_shape`` is the per-sample ``(C, H, W)`` shape. Values stored in
    the checkpoint (chars, class names, label ranges, ...) take precedence
    over the defaults from ``config``.
    """
    if model_name == "classifier":
        model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
        h, w = IMAGE_SIZE["classifier"]
        metadata = {
            "model_name": model_name,
            "task": "classifier",
            "class_names": list(ckpt.get("class_names", CAPTCHA_TYPES)),
            "input_shape": [1, h, w],
        }
        return model, (1, h, w), metadata
    if model_name == "normal":
        return _ctc_spec(model_name, ckpt, LiteCRNN, NORMAL_CHARS, "normal")
    if model_name == "math":
        return _ctc_spec(model_name, ckpt, LiteCRNN, MATH_CHARS, "math")
    if model_name == "threed_text":
        return _ctc_spec(model_name, ckpt, ThreeDCNN, THREED_CHARS, "3d_text")
    if model_name == "threed_rotate":
        h, w = IMAGE_SIZE["3d_rotate"]
        return _regression_spec(
            model_name, ckpt, RegressionCNN, h, w, REGRESSION_RANGE["3d_rotate"]
        )
    if model_name == "threed_slider":
        h, w = IMAGE_SIZE["3d_slider"]
        return _regression_spec(
            model_name, ckpt, RegressionCNN, h, w, REGRESSION_RANGE["3d_slider"]
        )
    if model_name == "gap_detector":
        h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
        return _regression_spec(
            model_name, ckpt, GapDetectorCNN, h, w, SOLVER_REGRESSION_RANGE["slide"]
        )
    if model_name == "rotation_regressor":
        h, w = SOLVER_CONFIG["rotate"]["input_size"]
        model = RotationRegressor(img_h=h, img_w=w)
        metadata = {
            "model_name": model_name,
            "task": "rotation_solver",
            "output_encoding": "sin_cos",
            "input_shape": [3, h, w],
        }
        return model, (3, h, w), metadata
    if model_name == "funcaptcha_rollball_animals":
        return _funcaptcha_spec(model_name, ckpt, "4_3d_rollball_animals")
    return None


def _ctc_spec(model_name: str, ckpt: dict, model_cls, default_chars, size_key: str):
    """Build a single-channel "ctc" task model sized by ``IMAGE_SIZE[size_key]``."""
    chars = ckpt.get("chars", default_chars)
    h, w = IMAGE_SIZE[size_key]
    model = model_cls(chars=chars, img_h=h, img_w=w)
    metadata = {
        "model_name": model_name,
        "task": "ctc",
        "chars": chars,
        "input_shape": [1, h, w],
    }
    return model, (1, h, w), metadata


def _regression_spec(model_name: str, ckpt: dict, model_cls, h: int, w: int, default_range):
    """Build a single-channel "regression" task model of input size ``(h, w)``."""
    model = model_cls(img_h=h, img_w=w)
    metadata = {
        "model_name": model_name,
        "task": "regression",
        "label_range": list(ckpt.get("label_range", default_range)),
        "input_shape": [1, h, w],
    }
    return model, (1, h, w), metadata


def _funcaptcha_spec(model_name: str, ckpt: dict, question: str):
    """Build the FunCaptcha Siamese model; checkpoint fields override FUN_CAPTCHA_TASKS[question]."""
    task_cfg = FUN_CAPTCHA_TASKS[question]
    h, w = task_cfg["input_size"]
    input_shape = list(ckpt.get("input_shape", [task_cfg["channels"], h, w]))
    # Build the network with the same channel count the dummy inputs will use,
    # so the model and the exported input shape cannot disagree.
    model = FunCaptchaSiamese(in_channels=int(input_shape[0]))
    metadata = {
        "model_name": model_name,
        "task": "funcaptcha_siamese",
        "question": question,
        "num_candidates": int(ckpt.get("num_candidates", task_cfg["num_candidates"])),
        "tile_size": list(ckpt.get("tile_size", task_cfg["tile_size"])),
        "reference_box": list(ckpt.get("reference_box", task_cfg["reference_box"])),
        "answer_index_base": int(ckpt.get("answer_index_base", task_cfg["answer_index_base"])),
        "input_shape": input_shape,
    }
    return model, tuple(input_shape), metadata


def export_all():
    """Export every known model, including the solver models, in a fixed order.

    Each name is handled by ``_load_and_export``; the summary line is printed
    once all names have been processed.
    """
    banner = "=" * 50
    print(banner)
    print("导出全部 ONNX 模型")
    print(banner)

    model_names = (
        "classifier",
        "normal",
        "math",
        "threed_text",
        "threed_rotate",
        "threed_slider",
        "gap_detector",
        "rotation_regressor",
        "funcaptcha_rollball_animals",
    )
    for model_name in model_names:
        _load_and_export(model_name)

    print("\n全部导出完成。")


if __name__ == "__main__":
    # With no arguments every model is exported (same as before); otherwise
    # only the named models are exported, one at a time.
    parser = argparse.ArgumentParser(description="Export trained checkpoints to ONNX.")
    parser.add_argument(
        "models",
        nargs="*",
        help="model names to export (default: all models)",
    )
    args = parser.parse_args()
    if args.models:
        for name in args.models:
            _load_and_export(name)
    else:
        export_all()