"""
ONNX export script.

Loads trained PyTorch models from ``checkpoints/`` and exports them in ONNX
format to ``onnx_models/``. Models can be exported one at a time or all at once.
"""

from pathlib import Path

import torch
import torch.nn as nn

from config import (
    CAPTCHA_TYPES,
    CHECKPOINTS_DIR,
    FUN_CAPTCHA_TASKS,
    ONNX_DIR,
    ONNX_CONFIG,
    IMAGE_SIZE,
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
    NUM_CAPTCHA_TYPES,
    REGRESSION_RANGE,
    SOLVER_CONFIG,
    SOLVER_REGRESSION_RANGE,
)
from inference.model_metadata import write_model_metadata
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
from models.fun_captcha_siamese import FunCaptchaSiamese


# Single-input models whose output tensor carries the batch on axis 0.
# Any other single-input model is treated as a CTC model whose output is
# shaped (T, B, C), i.e. batch on axis 1.
_BATCH_FIRST_OUTPUT_MODELS = frozenset({
    "classifier",
    "threed_rotate",
    "threed_slider",
    "gap_detector",
    "rotation_regressor",
    "funcaptcha_rollball_animals",
})

# CTC recognizers: model name -> (default charset, IMAGE_SIZE key, model class).
_CTC_MODELS = {
    "normal": (NORMAL_CHARS, "normal", LiteCRNN),
    "math": (MATH_CHARS, "math", LiteCRNN),
    "threed_text": (THREED_CHARS, "3d_text", ThreeDCNN),
}

# RegressionCNN models: model name -> key shared by IMAGE_SIZE and REGRESSION_RANGE.
_REGRESSION_MODELS = {
    "threed_rotate": "3d_rotate",
    "threed_slider": "3d_slider",
}

_FUNCAPTCHA_MODEL = "funcaptcha_rollball_animals"
_FUNCAPTCHA_QUESTION = "4_3d_rollball_animals"

# Order in which export_all() processes models (includes the solver models).
_ALL_MODELS = (
    "classifier",
    "normal",
    "math",
    "threed_text",
    "threed_rotate",
    "threed_slider",
    "gap_detector",
    "rotation_regressor",
    "funcaptcha_rollball_animals",
)


def export_model(
    model: nn.Module,
    model_name: str,
    input_shape: tuple | None = None,
    onnx_dir: str | Path | None = None,
    metadata: dict | None = None,
    dummy_inputs: tuple[torch.Tensor, ...] | None = None,
    input_names: list[str] | None = None,
    output_names: list[str] | None = None,
    dynamic_axes: dict | None = None,
) -> Path:
    """Export a single model to ONNX.

    Args:
        model: PyTorch model with weights already loaded. It is switched to
            eval mode and moved to CPU in place before tracing.
        model_name: Model name (classifier / normal / math / threed_text /
            threed_rotate / threed_slider / ...). Used as the file stem
            (``<model_name>.onnx``) and to pick the default output batch axis
            for single-input models.
        input_shape: Per-sample input shape ``(C, H, W)``. Used to build a
            random dummy input with batch size 1 when ``dummy_inputs`` is None.
        onnx_dir: Output directory (defaults to ``config.ONNX_DIR``).
        metadata: If given, passed to ``write_model_metadata`` together with
            the exported ONNX path.
        dummy_inputs: Explicit example inputs; needed for multi-input models.
        input_names: ONNX graph input names. Defaults to ``["input"]`` for a
            single input, otherwise ``["input_0", "input_1", ...]``.
        output_names: ONNX graph output names (default ``["output"]``).
        dynamic_axes: Explicit dynamic-axes mapping. When None, every input
            gets a dynamic batch on axis 0, and every output gets a dynamic
            batch on axis 0. The exception is single-input CTC models, whose
            outputs use axis 1. Dynamic axes are only applied when
            ``ONNX_CONFIG["dynamic_batch"]`` is truthy.

    Returns:
        Path of the written ``.onnx`` file.

    Raises:
        ValueError: If both ``input_shape`` and ``dummy_inputs`` are None, or
            if ``dummy_inputs`` is empty.
    """
    out_dir = Path(onnx_dir) if onnx_dir else ONNX_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = out_dir / f"{model_name}.onnx"

    model.eval()
    model.cpu()

    if dummy_inputs is None:
        if input_shape is None:
            raise ValueError("input_shape 和 dummy_inputs 不能同时为空")
        dummy_inputs = (torch.randn(1, *input_shape),)
    elif len(dummy_inputs) == 0:
        # An empty tuple would otherwise surface as an opaque tracing error.
        raise ValueError("dummy_inputs 不能为空")
    single_input = len(dummy_inputs) == 1

    if input_names is None:
        input_names = (
            ["input"] if single_input
            else [f"input_{i}" for i in range(len(dummy_inputs))]
        )
    if output_names is None:
        output_names = ["output"]

    if dynamic_axes is None:
        if not single_input or model_name in _BATCH_FIRST_OUTPUT_MODELS:
            output_batch_axis = 0
        else:
            # CTC model: output shape = (T, B, C), so batch is axis 1.
            output_batch_axis = 1
        # Key on the actual graph names so custom input/output names still
        # get a dynamic batch dimension.
        dynamic_axes = {name: {0: "batch"} for name in input_names}
        dynamic_axes.update(
            {name: {output_batch_axis: "batch"} for name in output_names}
        )

    torch.onnx.export(
        model,
        dummy_inputs[0] if single_input else dummy_inputs,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
    )

    if metadata is not None:
        write_model_metadata(onnx_path, metadata)

    size_kb = onnx_path.stat().st_size / 1024
    print(f"[ONNX] 导出完成: {onnx_path} ({size_kb:.1f} KB)")
    return onnx_path


def _build_model(
    model_name: str, ckpt: dict
) -> tuple[nn.Module, tuple[int, ...], dict] | None:
    """Build the un-loaded model, per-sample input shape and metadata.

    Values stored in the checkpoint (charset, class names, label range, ...)
    take precedence over the defaults from ``config``.

    Returns:
        ``(model, input_shape, metadata)``, or None if ``model_name`` is unknown.
    """
    if model_name == "classifier":
        h, w = IMAGE_SIZE["classifier"]
        model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
        input_shape = (1, h, w)
        metadata = {
            "model_name": model_name,
            "task": "classifier",
            "class_names": list(ckpt.get("class_names", CAPTCHA_TYPES)),
            "input_shape": [1, h, w],
        }
    elif model_name in _CTC_MODELS:
        default_chars, size_key, model_cls = _CTC_MODELS[model_name]
        chars = ckpt.get("chars", default_chars)
        h, w = IMAGE_SIZE[size_key]
        model = model_cls(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
        metadata = {
            "model_name": model_name,
            "task": "ctc",
            "chars": chars,
            "input_shape": [1, h, w],
        }
    elif model_name in _REGRESSION_MODELS:
        key = _REGRESSION_MODELS[model_name]
        h, w = IMAGE_SIZE[key]
        model = RegressionCNN(img_h=h, img_w=w)
        input_shape = (1, h, w)
        metadata = {
            "model_name": model_name,
            "task": "regression",
            "label_range": list(ckpt.get("label_range", REGRESSION_RANGE[key])),
            "input_shape": [1, h, w],
        }
    elif model_name == "gap_detector":
        h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
        model = GapDetectorCNN(img_h=h, img_w=w)
        input_shape = (1, h, w)
        metadata = {
            "model_name": model_name,
            "task": "regression",
            "label_range": list(
                ckpt.get("label_range", SOLVER_REGRESSION_RANGE["slide"])
            ),
            "input_shape": [1, h, w],
        }
    elif model_name == "rotation_regressor":
        h, w = SOLVER_CONFIG["rotate"]["input_size"]
        model = RotationRegressor(img_h=h, img_w=w)
        input_shape = (3, h, w)
        metadata = {
            "model_name": model_name,
            "task": "rotation_solver",
            "output_encoding": "sin_cos",
            "input_shape": [3, h, w],
        }
    elif model_name == _FUNCAPTCHA_MODEL:
        task_cfg = FUN_CAPTCHA_TASKS[_FUNCAPTCHA_QUESTION]
        h, w = task_cfg["input_size"]
        model = FunCaptchaSiamese(in_channels=task_cfg["channels"])
        metadata = {
            "model_name": model_name,
            "task": "funcaptcha_siamese",
            "question": _FUNCAPTCHA_QUESTION,
            "num_candidates": int(
                ckpt.get("num_candidates", task_cfg["num_candidates"])
            ),
            "tile_size": list(ckpt.get("tile_size", task_cfg["tile_size"])),
            "reference_box": list(
                ckpt.get("reference_box", task_cfg["reference_box"])
            ),
            "answer_index_base": int(
                ckpt.get("answer_index_base", task_cfg["answer_index_base"])
            ),
            "input_shape": list(
                ckpt.get("input_shape", [task_cfg["channels"], h, w])
            ),
        }
        input_shape = tuple(metadata["input_shape"])
    else:
        return None
    return model, input_shape, metadata


def _load_and_export(model_name: str):
    """Load ``model_name`` from its checkpoint and export it to ONNX.

    Missing checkpoints and unknown model names are reported and skipped.
    """
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    if not ckpt_path.exists():
        print(f"[跳过] {model_name}: checkpoint 不存在 ({ckpt_path})")
        return

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Compare with None rather than using `or`: a real best_acc of 0.0 must
    # not be mistaken for "missing".
    acc_info = ckpt.get("best_acc")
    if acc_info is None:
        acc_info = ckpt.get("best_tol_acc", "?")
    print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")

    built = _build_model(model_name, ckpt)
    if built is None:
        print(f"[错误] 未知模型: {model_name}")
        return
    model, input_shape, metadata = built

    model.load_state_dict(ckpt["model_state_dict"])

    if model_name == _FUNCAPTCHA_MODEL:
        # Siamese model: two inputs (candidate tile, reference tile) of the
        # same shape.
        channels, h, w = metadata["input_shape"]
        export_model(
            model,
            model_name,
            metadata=metadata,
            dummy_inputs=(
                torch.randn(1, channels, h, w),
                torch.randn(1, channels, h, w),
            ),
            input_names=["candidate", "reference"],
            output_names=["output"],
        )
    else:
        export_model(model, model_name, input_shape, metadata=metadata)


def export_all():
    """Export every known model in turn (including the solver models).

    Models whose checkpoint is missing are skipped with a message.
    """
    print("=" * 50)
    print("导出全部 ONNX 模型")
    print("=" * 50)
    for name in _ALL_MODELS:
        _load_and_export(name)
    print("\n全部导出完成。")


if __name__ == "__main__":
    export_all()