"""
Shared regression-training logic.

Provides ``train_regression_model()``, used by both train_3d_rotate and
train_3d_slider. Responsibilities:

1. Check the synthetic dataset and let ``ensure_synthetic_dataset`` generate /
   refresh it with the given generator when needed.
2. Build the RegressionDataset / DataLoaders (mixing in real samples when the
   real-data directory contains PNG files).
3. Run the regression training loop with a cosine-annealing LR scheduler.
4. Log epoch, loss, MAE and tolerance accuracy.
5. Save the best model (by tolerance accuracy) to checkpoints/.
6. Export the best model to ONNX when training finishes.
"""

import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    TRAIN_CONFIG,
    IMAGE_SIZE,
    GENERATE_CONFIG,
    REGRESSION_RANGE,
    SOLVER_CONFIG,
    SOLVER_REGRESSION_RANGE,
    RANDOM_SEED,
    get_device,
)
from inference.model_metadata import write_model_metadata
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset

# SmoothL1 transition point for the circular loss, in normalized label units:
# 1/360 of a full period (about 1 degree when the label range spans 360).
_CIRCULAR_BETA = 1.0 / 360.0


def _set_seed(seed: int = RANDOM_SEED):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean SmoothL1 loss on the circular distance between ``pred`` and ``target``.

    Labels are expected in a normalized space where one full period is 1.0, so
    the 0/360 degree boundary maps to 0/1 and values near 0 and near 1 count as
    close. The distance is ``min(d, 1 - d)`` where ``d`` is the difference
    wrapped into [0, 1).
    """
    # remainder() wraps the signed difference into [0, 1). For inputs inside
    # [0, 1] this equals the plain min(|p - t|, 1 - |p - t|), but it also stays
    # a valid (non-negative) distance if the model emits out-of-range values.
    diff = torch.remainder(pred - target, 1.0)
    diff = torch.min(diff, 1.0 - diff)
    beta = _CIRCULAR_BETA
    return torch.where(
        diff < beta,
        0.5 * diff * diff / beta,
        diff - 0.5 * beta,
    ).mean()


def _circular_distance(pred: np.ndarray, target: np.ndarray, period: float = 1.0) -> np.ndarray:
    """Element-wise shortest distance between ``pred`` and ``target`` on a circle of length ``period``."""
    # np.mod takes the sign of the divisor, so ``diff`` lies in [0, period).
    diff = np.mod(pred - target, period)
    return np.minimum(diff, period - diff)


def _circular_mae(pred: np.ndarray, target: np.ndarray, period: float = 1.0) -> float:
    """Circular mean absolute error; ``period`` defaults to the normalized label space (1.0)."""
    return float(np.mean(_circular_distance(pred, target, period)))


def _export_onnx(
    model: nn.Module,
    model_name: str,
    img_h: int,
    img_w: int,
    *,
    label_range: tuple[int, int] | tuple[float, float],
):
    """Export ``model`` to ``ONNX_DIR/<model_name>.onnx`` and write its metadata.

    The model is put in eval mode and moved to CPU in place. The trace input is
    a random ``(1, 1, img_h, img_w)`` tensor; the batch axis of input and
    output is dynamic when ``ONNX_CONFIG["dynamic_batch"]`` is truthy.
    """
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    dummy = torch.randn(1, 1, img_h, img_w)
    dynamic_axes = (
        {"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None
    )
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    write_model_metadata(
        onnx_path,
        {
            "model_name": model_name,
            "task": "regression",
            "label_range": list(label_range),
            "input_shape": [1, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def _prepare_synthetic_data(syn_path: Path, generator_cls, config_key: str, cfg: dict, label_range) -> dict:
    """Build the dataset spec for ``config_key`` and ensure the synthetic data matches it.

    Returns the state dict from ``ensure_synthetic_dataset``; the caller reads
    its "refreshed", "adopted", "sample_count" and "manifest" entries.
    """
    # The snapshot is part of the dataset fingerprint: keep its keys and
    # insertion order stable, or existing data will be treated as stale.
    config_snapshot = {
        "image_size": IMAGE_SIZE[config_key],
        "label_range": label_range,
    }
    if config_key in GENERATE_CONFIG:
        config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
    elif config_key == "slide_cnn":
        config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
        config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
    else:
        config_snapshot["train_config"] = cfg

    dataset_spec = build_dataset_spec(
        generator_cls,
        config_key=config_key,
        config_snapshot=config_snapshot,
    )
    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=generator_cls,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
    )

    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
    return dataset_state


def _build_loaders(
    syn_path: Path,
    real_path: Path,
    cfg: dict,
    img_h: int,
    img_w: int,
    label_range,
    device,
) -> tuple[DataLoader, DataLoader]:
    """Build train/val DataLoaders over synthetic (+ optional real) data.

    The training subset uses augmenting transforms; the validation subset is
    re-wrapped in a second RegressionDataset with the non-augmenting transform.

    Raises:
        ValueError: If the split leaves the train or validation subset empty.
    """
    data_dirs = [str(syn_path)]
    # Glob once and reuse the result for both the check and the log line.
    real_pngs = list(real_path.glob("*.png")) if real_path.exists() else []
    if real_pngs:
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(real_pngs)} 张")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = RegressionDataset(
        dirs=data_dirs,
        label_range=label_range,
        transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    # Fail fast: an empty validation set would otherwise only surface after a
    # full training epoch, as an opaque np.concatenate error.
    if val_size == 0 or train_size == 0:
        raise ValueError(
            f"[数据] 训练集或验证集为空 (total={total}, val_split={cfg['val_split']}),"
            f"请检查数据量或 val_split 配置"
        )
    # NOTE(review): the split draws from the global torch RNG seeded by
    # _set_seed(). If synthetic-data generation consumes torch RNG, a run that
    # generated data and a later resuming run may split differently; a
    # dedicated torch.Generator would decouple this (it would also change the
    # split for existing checkpoints).
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation uses the non-augmenting transform over the same sample subset.
    val_ds_clean = RegressionDataset(
        dirs=data_dirs,
        label_range=label_range,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    # Pinned memory only helps host-to-CUDA copies; elsewhere it just warns.
    pin_memory = torch.device(device).type == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    return train_loader, val_loader


def _resume_from_checkpoint(
    ckpt_path: Path,
    model: nn.Module,
    scheduler,
    device,
    *,
    data_refreshed: bool,
    current_data_spec_hash,
) -> tuple[int, float, float]:
    """Restore weights and best metrics from ``ckpt_path`` when it is compatible.

    The checkpoint is ignored when the synthetic data was refreshed in this run
    or its recorded data fingerprint differs from the current one. Checkpoints
    without a fingerprint are reused. On resume, the scheduler is stepped
    forward to the saved epoch. Optimizer state is not restored (the
    checkpoint does not store it).

    Returns:
        ``(start_epoch, best_tol_acc, best_mae)``.
    """
    start_epoch, best_tol_acc, best_mae = 1, 0.0, float("inf")
    if not ckpt_path.exists():
        return start_epoch, best_tol_acc, best_mae

    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
    if data_refreshed:
        print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
        return start_epoch, best_tol_acc, best_mae
    if ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
        print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
        return start_epoch, best_tol_acc, best_mae

    if ckpt_data_spec_hash is None:
        print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
    model.load_state_dict(ckpt["model_state_dict"])
    best_tol_acc = ckpt.get("best_tol_acc", 0.0)
    best_mae = ckpt.get("best_mae", float("inf"))
    start_epoch = ckpt.get("epoch", 0) + 1
    # Fast-forward the cosine schedule to the resumed epoch.
    for _ in range(start_epoch - 1):
        scheduler.step()
    print(
        f"[续训] 从 epoch {start_epoch} 继续, "
        f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
    )
    return start_epoch, best_tol_acc, best_mae


def _train_one_epoch(model: nn.Module, loader: DataLoader, optimizer, loss_fn, device, desc: str) -> float:
    """Run one training epoch and return the mean per-batch loss (0.0 if no batches)."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, targets in pbar:
        images = images.to(device)
        targets = targets.to(device)

        # Align prediction shape with the targets so a (B, 1) vs (B,) mismatch
        # cannot silently broadcast the loss to (B, B).
        preds = model(images).reshape(targets.shape)
        loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    return total_loss / max(num_batches, 1)


def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    device,
    *,
    lo: float,
    hi: float,
    is_circular: bool,
    tolerance: float,
) -> tuple[float, float]:
    """Return ``(mae, tol_acc)`` on ``loader`` in original label units.

    Predictions and targets are mapped from [0, 1] back to ``[lo, hi]``; for
    circular labels the error is the shortest distance on a circle of length
    ``hi - lo``. ``tol_acc`` is the fraction of samples with error <= tolerance.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for images, targets in loader:
            preds = model(images.to(device))
            pred_chunks.append(preds.cpu().numpy())
            target_chunks.append(targets.numpy())

    span = hi - lo
    preds_real = np.concatenate(pred_chunks, axis=0).flatten() * span + lo
    targets_real = np.concatenate(target_chunks, axis=0).flatten() * span + lo

    if is_circular:
        abs_err = _circular_distance(preds_real, targets_real, span)
    else:
        abs_err = np.abs(preds_real - targets_real)
    mae = float(np.mean(abs_err))
    tol_acc = float(np.mean(abs_err <= tolerance))
    return mae, tol_acc


def train_regression_model(
    model_name: str,
    model: nn.Module,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
    *,
    tolerance: float | None = None,
):
    """Run the shared regression training pipeline and export the best model.

    Steps: ensure the synthetic dataset, build train/val loaders (mixing in
    real PNGs when ``real_dir`` has any), train with Adam + cosine annealing
    (resuming from a compatible checkpoint), save the checkpoint with the best
    tolerance accuracy, then reload it and export to ONNX.

    Args:
        model_name: Name used for the checkpoint / ONNX files
            (threed_rotate / threed_slider).
        model: PyTorch model instance to train (RegressionCNN).
        synthetic_dir: Synthetic data directory.
        real_dir: Real data directory; used only when it contains ``*.png``.
        generator_cls: Generator class used to (re)generate synthetic data.
        config_key: Key into TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
            (3d_rotate / 3d_slider). "3d_rotate" enables circular loss/metrics.
        tolerance: Absolute error (label units) counted as a hit for the
            tolerance accuracy. Defaults to 5.0 (±5°) for "3d_rotate" and
            3.0 (±3px) otherwise.

    Returns:
        The best tolerance accuracy (fraction in [0, 1]), including a value
        restored from a resumed checkpoint.

    Raises:
        ValueError: If the train/val split leaves either subset empty.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    label_range = REGRESSION_RANGE[config_key]
    lo, hi = label_range
    is_circular = config_key == "3d_rotate"
    device = get_device()

    if tolerance is None:
        tolerance = 5.0 if config_key == "3d_rotate" else 3.0  # ±5° / ±3px

    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    dataset_state = _prepare_synthetic_data(syn_path, generator_cls, config_key, cfg, label_range)
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]

    # ---- 2. Build datasets / loaders ----
    train_loader, val_loader = _build_loaders(
        syn_path, Path(real_dir), cfg, img_h, img_w, label_range, device
    )

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    loss_fn = _circular_smooth_l1 if is_circular else nn.SmoothL1Loss()

    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- 3.5 Resume from checkpoint ----
    start_epoch, best_tol_acc, best_mae = _resume_from_checkpoint(
        ckpt_path,
        model,
        scheduler,
        device,
        data_refreshed=dataset_state["refreshed"],
        current_data_spec_hash=current_data_spec_hash,
    )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        avg_loss = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            loss_fn,
            device,
            desc=f"Epoch {epoch}/{cfg['epochs']}",
        )
        scheduler.step()

        # ---- 5. Validation ----
        mae, tol_acc = _evaluate(
            model,
            val_loader,
            device,
            lo=lo,
            hi=hi,
            is_circular=is_circular,
            tolerance=tolerance,
        )
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f} "
            f"tol_acc(±{tolerance:g})={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save best model (ranked by tolerance accuracy) ----
        # ">=" prefers later epochs on ties and guarantees the first epoch of
        # a fresh run overwrites any stale checkpoint.
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae
            torch.save({
                "model_state_dict": model.state_dict(),
                "label_range": label_range,
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
                "synthetic_data_spec_hash": current_data_spec_hash,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")

    # ---- 7. Export ONNX from the best checkpoint ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w, label_range=label_range)

    return best_tol_acc