Files
CaptchBreaker/training/train_rotate_solver.py

275 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
训练旋转验证码角度回归模型 (RotationRegressor)
自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。
用法: python -m training.train_rotate_solver
"""
import math
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
from config import (
SOLVER_CONFIG,
SOLVER_TRAIN_CONFIG,
ROTATE_SOLVER_DATA_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
AUGMENT_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.rotate_solver_gen import RotateSolverDataGenerator
from inference.model_metadata import write_model_metadata
from models.rotation_regressor import RotationRegressor
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
from training.dataset import RotateSolverDataset
def _set_seed(seed: int = RANDOM_SEED):
    """Seed every RNG source (random, numpy, torch) for reproducible runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA has its own per-device generators; seed them all when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
"""循环 MAE (度数)。"""
diff = np.abs(pred_angles - gt_angles)
diff = np.minimum(diff, 360.0 - diff)
return float(np.mean(diff))
def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Training-time augmentation pipeline for RGB input (no grayscale conversion)."""
    aug = AUGMENT_CONFIG
    steps = [
        transforms.Resize((img_h, img_w)),
        # Photometric jitter + blur per the shared augmentation config.
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        # Map [0, 1] to [-1, 1] per channel.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Deterministic RGB validation transform: resize + normalize only, no augmentation."""
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        # Same [-1, 1] normalization as training so inputs match the model.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
def _export_onnx(model: nn.Module, img_h: int, img_w: int):
    """Export the trained regressor to ONNX (3-channel RGB input) and write metadata."""
    model.eval()
    onnx_path = ONNX_DIR / "rotation_regressor.onnx"
    dummy = torch.randn(1, 3, img_h, img_w)
    # Only declare a dynamic batch axis when the export config asks for it.
    axes = None
    if ONNX_CONFIG["dynamic_batch"]:
        axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
    # Sidecar metadata lets inference code discover input shape / output encoding.
    write_model_metadata(
        onnx_path,
        {
            "model_name": "rotation_regressor",
            "task": "rotation_solver",
            "output_encoding": "sin_cos",
            "input_shape": [3, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
def main():
    """Train the rotation-captcha angle regressor end to end.

    Pipeline: ensure synthetic data exists (regenerating if the dataset spec
    changed), build train/val splits, run a custom sin/cos-regression training
    loop, checkpoint the best model by tolerance accuracy, and export to ONNX.

    Returns:
        float: best validation tolerance accuracy (fraction of samples within ±5°).
    """
    cfg = SOLVER_TRAIN_CONFIG["rotate"]
    solver_cfg = SOLVER_CONFIG["rotate"]
    img_h, img_w = solver_cfg["input_size"]
    tolerance = 5.0  # ±5° — a prediction within this of ground truth counts as correct
    _set_seed()
    device = get_device()
    print("=" * 60)
    print("训练旋转验证码角度回归模型 (RotationRegressor)")
    print(f" 输入尺寸: {img_h}×{img_w} RGB")
    print(f" 编码: sin/cos")
    print(f" 容差: ±{tolerance}°")
    print("=" * 60)
    # ---- 1. Check / generate synthetic data ----
    syn_path = ROTATE_SOLVER_DATA_DIR
    syn_path.mkdir(parents=True, exist_ok=True)
    # Spec fingerprints the generator + config so stale data gets regenerated.
    dataset_spec = build_dataset_spec(
        RotateSolverDataGenerator,
        config_key="rotate_solver",
        config_snapshot={
            "solver_config": SOLVER_CONFIG["rotate"],
            "train_config": {
                "synthetic_samples": cfg["synthetic_samples"],
            },
        },
    )
    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=RotateSolverDataGenerator,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
    )
    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
    # Stored in the checkpoint so inference can detect data/model mismatches.
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_dir = syn_path / "real"
    real_dir.mkdir(parents=True, exist_ok=True)
    # Optionally mix in real captured samples dropped into the "real" subfolder.
    if list(real_dir.glob("*.png")):
        data_dirs.append(str(real_dir))
        print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))}")
    train_transform = _build_train_transform(img_h, img_w)
    val_transform = _build_val_transform(img_h, img_w)
    full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Rebuild the val split on a second dataset so validation uses the
    # augmentation-free transform while keeping the exact split indices.
    val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Model / optimizer / scheduler ----
    model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # MSE directly on the (sin, cos) pair — no angle-wrapping issues in the loss.
    loss_fn = nn.MSELoss()
    best_mae = float("inf")
    best_tol_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 2) → (sin, cos)
            preds = model(images)  # (B, 2)
            loss = loss_fn(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize the regression training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        all_pred_angles = []
        all_gt_angles = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images).cpu().numpy()  # (B, 2)
                targets_np = targets.numpy()  # (B, 2)
                # sin/cos → angle; atan2 returns (-180°, 180°], shift into [0°, 360°)
                for i in range(len(preds)):
                    pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1]))
                    if pred_angle < 0:
                        pred_angle += 360.0
                    gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1]))
                    if gt_angle < 0:
                        gt_angle += 360.0
                    all_pred_angles.append(pred_angle)
                    all_gt_angles.append(gt_angle)
        pred_arr = np.array(all_pred_angles)
        gt_arr = np.array(all_gt_angles)
        mae = _circular_mae_deg(pred_arr, gt_arr)
        # Circular distance per sample for the tolerance-accuracy metric.
        diff = np.abs(pred_arr - gt_arr)
        diff = np.minimum(diff, 360.0 - diff)
        tol_acc = float(np.mean(diff <= tolerance))
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f}° "
            f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save best model ----
        # ">=" lets later epochs win ties, so the checkpoint tracks the most
        # recent model at the best tolerance accuracy.
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae
            torch.save({
                "model_state_dict": model.state_dict(),
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
                "synthetic_data_spec_hash": current_data_spec_hash,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
    # Reload the best checkpoint (not the last epoch) before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, img_h, img_w)
    return best_tol_acc
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()