"""
Train the slider-captcha gap detection CNN (GapDetectorCNN).

Reuses the generic regression training pipeline in
``training.train_regression_utils``.

Usage:
    python -m training.train_slide
"""
from contextlib import ExitStack, contextmanager

from config import (
    SOLVER_CONFIG,
    SOLVER_TRAIN_CONFIG,
    SOLVER_REGRESSION_RANGE,
    SLIDE_DATA_DIR,
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    RANDOM_SEED,
    get_device,
)
from generators.slide_gen import SlideDataGenerator
from models.gap_detector import GapDetectorCNN

# The solver settings are injected into config.TRAIN_CONFIG / IMAGE_SIZE /
# REGRESSION_RANGE (see main) so that train_regression_utils can be reused
# without modification.
import config as _cfg

# Sentinel distinguishing "key was absent" from "key was mapped to None".
_MISSING = object()


@contextmanager
def _injected(mapping, key, value):
    """Set ``mapping[key] = value`` for the duration of the ``with`` block.

    On exit (normal or exceptional) the previous entry is restored, or the
    key is removed again if it was not present before.

    NOTE(review): assumes ``mapping`` is dict-like (supports ``get`` and
    ``pop``); the config tables are presumably plain dicts -- confirm in
    ``config``.
    """
    previous = mapping.get(key, _MISSING)
    mapping[key] = value
    try:
        yield
    finally:
        if previous is _MISSING:
            mapping.pop(key, None)
        else:
            mapping[key] = previous


def main():
    """Train GapDetectorCNN through ``train_regression_model``.

    Reads the input size from ``SOLVER_CONFIG["slide"]["cnn_input_size"]``
    (unpacked as ``(height, width)``) and the training settings from
    ``SOLVER_TRAIN_CONFIG["slide_cnn"]``. Those settings, together with
    ``SOLVER_REGRESSION_RANGE["slide"]``, are exposed to the shared
    pipeline under the ``"slide_cnn"`` key of ``config.TRAIN_CONFIG``,
    ``config.IMAGE_SIZE`` and ``config.REGRESSION_RANGE`` only while
    training runs; the previous entries are restored afterwards.

    Side effects: prints a banner to stdout and creates ``SLIDE_DATA_DIR``
    and ``SLIDE_DATA_DIR / "real"`` if they are missing.
    """
    # Key under which the shared pipeline looks up this model's settings.
    config_key = "slide_cnn"

    solver_cfg = SOLVER_CONFIG["slide"]
    train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
    img_h, img_w = solver_cfg["cnn_input_size"]

    model = GapDetectorCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("训练滑块缺口检测 CNN (GapDetectorCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(" 任务: 预测缺口 x 坐标百分比")
    print("=" * 60)

    with ExitStack() as stack:
        # Scoped injection: ExitStack undoes every entry on exit, including
        # after a partial failure here or an error raised during training.
        stack.enter_context(_injected(_cfg.TRAIN_CONFIG, config_key, train_cfg))
        stack.enter_context(_injected(_cfg.IMAGE_SIZE, config_key, (img_h, img_w)))
        stack.enter_context(
            _injected(
                _cfg.REGRESSION_RANGE, config_key, SOLVER_REGRESSION_RANGE["slide"]
            )
        )

        # Imported only after the injection, keeping the original ordering in
        # case that module reads the config tables at import time.
        from training.train_regression_utils import train_regression_model

        # Make sure the data directories exist before training starts.
        SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
        real_dir = SLIDE_DATA_DIR / "real"
        real_dir.mkdir(parents=True, exist_ok=True)

        train_regression_model(
            model_name="gap_detector",
            model=model,
            synthetic_dir=str(SLIDE_DATA_DIR),
            real_dir=str(real_dir),
            generator_cls=SlideDataGenerator,
            config_key=config_key,
        )


if __name__ == "__main__":
    main()