"""
Train the rotation-captcha angle regression model (RotationRegressor).

Uses a custom training loop with a (sin, cos) target encoding and does not
reuse ``train_regression_utils``.

Usage:
    python -m training.train_rotate_solver
"""

import math
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

from config import (
    SOLVER_CONFIG,
    SOLVER_TRAIN_CONFIG,
    ROTATE_SOLVER_DATA_DIR,
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    AUGMENT_CONFIG,
    RANDOM_SEED,
    get_device,
)
from generators.rotate_solver_gen import RotateSolverDataGenerator
from models.rotation_regressor import RotationRegressor
from training.dataset import RotateSolverDataset


def _set_seed(seed: int = RANDOM_SEED):
    """Seed the Python, NumPy and PyTorch RNGs (plus all CUDA devices if present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _circular_diff_deg(pred_angles, gt_angles) -> np.ndarray:
    """Return the element-wise shortest angular distance in degrees, in [0, 180].

    Inputs may be any real-valued angles in degrees; the raw difference is
    wrapped with ``% 360`` before being folded onto the shorter arc. For
    inputs already in [0, 360) the wrap is a no-op.
    """
    pred = np.asarray(pred_angles, dtype=np.float64)
    gt = np.asarray(gt_angles, dtype=np.float64)
    diff = np.abs(pred - gt) % 360.0
    return np.minimum(diff, 360.0 - diff)


def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
    """Circular mean absolute error in degrees (NaN when there are no samples)."""
    diff = _circular_diff_deg(pred_angles, gt_angles)
    if diff.size == 0:
        return float("nan")
    return float(np.mean(diff))


def _sincos_to_deg(sincos) -> np.ndarray:
    """Convert rows of (sin, cos) into angles in degrees.

    Column 0 is treated as sin and column 1 as cos, i.e. ``atan2(sin, cos)``;
    negative results are shifted by +360 so angles are non-negative.
    """
    arr = np.asarray(sincos, dtype=np.float64)
    angles = np.degrees(np.arctan2(arr[:, 0], arr[:, 1]))
    return np.where(angles < 0, angles + 360.0, angles)


def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """RGB training augmentation (images are NOT converted to grayscale)."""
    aug = AUGMENT_CONFIG
    return transforms.Compose([
        transforms.Resize((img_h, img_w)),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """RGB validation transform: resize + normalize, no augmentation."""
    return transforms.Compose([
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


def _export_onnx(model: nn.Module, img_h: int, img_w: int):
    """Export *model* to ONNX with a 3-channel (RGB) input on CPU."""
    model.eval()
    ONNX_DIR.mkdir(parents=True, exist_ok=True)
    onnx_path = ONNX_DIR / "rotation_regressor.onnx"
    dummy = torch.randn(1, 3, img_h, img_w)
    dynamic_axes = (
        {"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None
    )
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def _prepare_synthetic_data(syn_path: Path, required: int) -> None:
    """Ensure *syn_path* exists and generate samples if fewer than *required* PNGs are present."""
    syn_path.mkdir(parents=True, exist_ok=True)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < required:
        print(f"[数据] 合成数据不足 ({len(existing)}/{required}),开始生成...")
        gen = RotateSolverDataGenerator()
        gen.generate_dataset(required, str(syn_path))
    else:
        print(f"[数据] 合成数据已就绪: {len(existing)} 张")


def _collect_data_dirs(syn_path: Path) -> list:
    """Return the data directories: the synthetic dir, plus ``<syn>/real`` if it holds PNGs."""
    data_dirs = [str(syn_path)]
    real_dir = syn_path / "real"
    real_dir.mkdir(parents=True, exist_ok=True)
    real_images = list(real_dir.glob("*.png"))  # glob once, reuse for the count
    if real_images:
        data_dirs.append(str(real_dir))
        print(f"[数据] 混合真实数据: {len(real_images)} 张")
    return data_dirs


def _build_loaders(data_dirs: list, img_h: int, img_w: int, cfg, pin_memory: bool):
    """Split the dataset into train/val and return (train_loader, val_loader).

    Raises:
        ValueError: if either split would be empty (validation metrics would be
            NaN and no checkpoint would ever be saved).
    """
    train_transform = _build_train_transform(img_h, img_w)
    val_transform = _build_val_transform(img_h, img_w)

    full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    if train_size <= 0 or val_size <= 0:
        raise ValueError(
            f"[数据] 数据集划分无效: 训练 {train_size} / 验证 {val_size} "
            f"(共 {total} 张, val_split={cfg['val_split']})"
        )
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation must not see training augmentation: build a second dataset
    # with the clean transform and restrict it to the validation indices.
    val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=pin_memory,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    return train_loader, val_loader


def _train_one_epoch(model, loader, optimizer, loss_fn, device, desc: str) -> float:
    """Run one optimisation pass over *loader* and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, targets in pbar:
        images = images.to(device)
        targets = targets.to(device)  # (B, 2) -> (sin, cos)
        preds = model(images)  # (B, 2)
        loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm at 5.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device, tolerance: float):
    """Return (circular MAE in degrees, fraction of samples within ±tolerance°)."""
    model.eval()
    pred_chunks = []
    gt_chunks = []
    with torch.no_grad():
        for images, targets in loader:
            preds = model(images.to(device)).cpu().numpy()  # (B, 2)
            pred_chunks.append(_sincos_to_deg(preds))
            gt_chunks.append(_sincos_to_deg(targets.numpy()))
    pred_arr = np.concatenate(pred_chunks)
    gt_arr = np.concatenate(gt_chunks)
    mae = _circular_mae_deg(pred_arr, gt_arr)
    tol_acc = float(np.mean(_circular_diff_deg(pred_arr, gt_arr) <= tolerance))
    return mae, tol_acc


def main():
    """Train RotationRegressor, keep the best checkpoint, and export it to ONNX.

    The best epoch is the one with the highest ±tolerance accuracy; ties are
    broken by lower circular MAE.

    Returns:
        The best validation accuracy within ±tolerance degrees.

    Raises:
        ValueError: if ``epochs < 1`` or the train/val split would be empty.
    """
    cfg = SOLVER_TRAIN_CONFIG["rotate"]
    solver_cfg = SOLVER_CONFIG["rotate"]
    img_h, img_w = solver_cfg["input_size"]
    tolerance = 5.0  # ±5°
    epochs = cfg["epochs"]
    if epochs < 1:
        # Without at least one epoch no checkpoint is written and the final
        # torch.load would fail (or silently pick up a stale file).
        raise ValueError(f"epochs 必须 >= 1, 当前为 {epochs}")

    _set_seed()
    device = get_device()
    print("=" * 60)
    print("训练旋转验证码角度回归模型 (RotationRegressor)")
    print(f" 输入尺寸: {img_h}×{img_w} RGB")
    print(f" 编码: sin/cos")
    print(f" 容差: ±{tolerance}°")
    print("=" * 60)

    # ---- 1. Check / generate synthetic data ----
    syn_path = ROTATE_SOLVER_DATA_DIR
    _prepare_synthetic_data(syn_path, cfg["synthetic_samples"])

    # ---- 2. Build datasets ----
    data_dirs = _collect_data_dirs(syn_path)
    # Pinned host memory only helps host->CUDA copies; elsewhere it just warns.
    pin_memory = str(device).startswith("cuda")
    train_loader, val_loader = _build_loaders(data_dirs, img_h, img_w, cfg, pin_memory)

    # ---- 3. Model / optimizer / scheduler ----
    model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.MSELoss()

    best_mae = float("inf")
    best_tol_acc = 0.0
    best_epoch = 0
    ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- 4. Training loop ----
    for epoch in range(1, epochs + 1):
        # Read the LR before scheduler.step() so the log shows the rate that
        # was actually used for this epoch, not the next one.
        lr = optimizer.param_groups[0]["lr"]
        avg_loss = _train_one_epoch(
            model, train_loader, optimizer, loss_fn, device,
            desc=f"Epoch {epoch}/{epochs}",
        )
        scheduler.step()

        # ---- 5. Validation ----
        mae, tol_acc = _evaluate(model, val_loader, device, tolerance)
        print(
            f"Epoch {epoch:3d}/{epochs} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f}° "
            f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        # The first epoch always saves so a checkpoint from this run exists;
        # afterwards a tie on tol_acc only wins with a strictly lower MAE.
        improved = (
            best_epoch == 0
            or tol_acc > best_tol_acc
            or (tol_acc == best_tol_acc and mae < best_mae)
        )
        if improved:
            best_tol_acc = tol_acc
            best_mae = mae
            best_epoch = epoch
            torch.save({
                "model_state_dict": model.state_dict(),
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, img_h, img_w)
    return best_tol_acc


if __name__ == "__main__":
    main()