Files
CaptchBreaker/training/train_regression_utils.py

336 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回归训练通用逻辑
提供 train_regression_model() 函数,被 train_3d_rotate / train_3d_slider 共用。
职责:
1. 检查合成数据,不存在则自动调用生成器
2. 构建 RegressionDataset / DataLoader含真实数据混合
3. 回归训练循环 + cosine scheduler
4. 输出日志: epoch, loss, MAE, tolerance 准确率
5. 保存最佳模型到 checkpoints/
6. 训练结束导出 ONNX
"""
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
REGRESSION_RANGE,
SOLVER_CONFIG,
SOLVER_REGRESSION_RANGE,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
def _set_seed(seed: int = RANDOM_SEED):
    """Seed the Python, NumPy and PyTorch RNGs so training runs are reproducible.

    Args:
        seed: Seed value applied to every generator; defaults to RANDOM_SEED.
    """
    # The CPU-side generators are always seeded.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a CUDA device is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
循环距离 SmoothL1 loss用于角度回归 (处理 0°/360° 边界)。
pred 和 target 都在 [0, 1] 范围。
"""
diff = torch.abs(pred - target)
# 循环距离: min(|d|, 1-|d|)
diff = torch.min(diff, 1.0 - diff)
# SmoothL1
return torch.where(
diff < 1.0 / 360.0, # beta ≈ 1° 归一化
0.5 * diff * diff / (1.0 / 360.0),
diff - 0.5 * (1.0 / 360.0),
).mean()
def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
"""循环 MAE (归一化空间)。"""
diff = np.abs(pred - target)
diff = np.minimum(diff, 1.0 - diff)
return float(np.mean(diff))
def _export_onnx(
    model: nn.Module,
    model_name: str,
    img_h: int,
    img_w: int,
    *,
    label_range: tuple[int, int] | tuple[float, float],
):
    """Export the model to ONNX and write its metadata next to it.

    The model is switched to eval mode and moved to the CPU, then traced with
    a random dummy input of shape (1, 1, img_h, img_w). The opset version and
    the optional dynamic batch axis come from ONNX_CONFIG.

    Args:
        model: Model to export; it is left in eval mode on the CPU afterwards.
        model_name: Base name of the output file (``<model_name>.onnx``).
        img_h: Input image height used for the dummy input.
        img_w: Input image width used for the dummy input.
        label_range: (low, high) label range stored in the model metadata.
    """
    model.eval()

    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    # Make sure the target directory exists so the export cannot fail on a
    # missing folder after training has already completed.
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    dummy = torch.randn(1, 1, img_h, img_w)

    # Only mark the batch dimension as dynamic when the config asks for it.
    if ONNX_CONFIG["dynamic_batch"]:
        dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    else:
        dynamic_axes = None

    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )

    write_model_metadata(
        onnx_path,
        {
            "model_name": model_name,
            "task": "regression",
            "label_range": list(label_range),
            "input_shape": [1, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
def train_regression_model(
    model_name: str,
    model: nn.Module,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """
    Run the shared regression training pipeline and export the best model.

    Steps: seed the RNGs, make sure the synthetic dataset is present, build
    the train/validation loaders (mixing in real PNG samples when available),
    train with Adam + cosine annealing, save a checkpoint whenever the
    tolerance accuracy is at least as good as the best so far, and finally
    export the best checkpoint to ONNX.

    Args:
        model_name: Model name used for output files (threed_rotate / threed_slider).
        model: PyTorch model instance (RegressionCNN).
        synthetic_dir: Directory with synthetic data.
        real_dir: Directory with real data; mixed in only if it contains PNG files.
        generator_cls: Generator class (used to generate data automatically).
        config_key: Key into TRAIN_CONFIG / REGRESSION_RANGE (3d_rotate / 3d_slider).

    Returns:
        The best tolerance accuracy (fraction of validation samples whose
        error is within the tolerance); may come from a resumed checkpoint.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    label_range = REGRESSION_RANGE[config_key]
    lo, hi = label_range
    span = hi - lo
    is_circular = config_key == "3d_rotate"
    device = get_device()

    # Tolerance for the accuracy metric, in label units:
    # ±5° for the rotation task, ±3px otherwise.
    tolerance = 5.0 if is_circular else 3.0

    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    config_snapshot = {
        "image_size": IMAGE_SIZE[config_key],
        "label_range": label_range,
    }
    if config_key in GENERATE_CONFIG:
        config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
    elif config_key == "slide_cnn":
        config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
        config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
    else:
        config_snapshot["train_config"] = cfg

    dataset_spec = build_dataset_spec(
        generator_cls,
        config_key=config_key,
        config_snapshot=config_snapshot,
    )
    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=generator_cls,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
    )
    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]

    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    # List the real PNG files once and reuse the result for both the check
    # and the log line.
    real_pngs = list(real_path.glob("*.png")) if real_path.exists() else []
    if real_pngs:
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(real_pngs)}")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation set uses the augmentation-free transform: build a second
    # dataset over the same directories and restrict it to the validation
    # indices chosen by random_split.
    val_ds_clean = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    if is_circular:
        loss_fn = _circular_smooth_l1
    else:
        loss_fn = nn.SmoothL1Loss()

    best_mae = float("inf")
    best_tol_acc = 0.0
    start_epoch = 1
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # Create the checkpoint directory up front so torch.save cannot fail on a
    # missing folder after an epoch of training.
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- 3.5 Resume from checkpoint ----
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
        if dataset_state["refreshed"]:
            print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
        elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
            print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
        else:
            if ckpt_data_spec_hash is None:
                print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
            model.load_state_dict(ckpt["model_state_dict"])
            best_tol_acc = ckpt.get("best_tol_acc", 0.0)
            best_mae = ckpt.get("best_mae", float("inf"))
            start_epoch = ckpt.get("epoch", 0) + 1
            # Fast-forward the scheduler to the epoch being resumed.
            for _ in range(start_epoch - 1):
                scheduler.step()
            print(
                f"[续训] 从 epoch {start_epoch} 继续, "
                f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
            )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)

            # NOTE(review): assumes preds and targets have matching shapes;
            # confirm against the model output and RegressionDataset.
            preds = model(images)
            loss = loss_fn(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images)
                all_preds.append(preds.cpu().numpy())
                all_targets.append(targets.numpy())

        all_preds = np.concatenate(all_preds, axis=0).flatten()
        all_targets = np.concatenate(all_targets, axis=0).flatten()

        # Scale back to the original label range before computing metrics.
        preds_real = all_preds * span + lo
        targets_real = all_targets * span + lo

        abs_err = np.abs(preds_real - targets_real)
        if is_circular:
            # Wrap onto one period first; otherwise a prediction more than a
            # full period away gives a negative "distance" that would count
            # as within tolerance and pull the MAE down.
            abs_err = np.mod(abs_err, span)
            abs_err = np.minimum(abs_err, span - abs_err)
        mae = float(np.mean(abs_err))
        within_tol = abs_err <= tolerance
        tol_acc = float(np.mean(within_tol))

        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f} "
            f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model (ranked by tolerance accuracy) ----
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae
            torch.save({
                "model_state_dict": model.state_dict(),
                "label_range": label_range,
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
                "synthetic_data_spec_hash": current_data_spec_hash,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
    # Reload the best checkpoint so the exported weights are the best ones,
    # not those of the last epoch.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w, label_range=label_range)
    return best_tol_acc