Files
CaptchBreaker/training/train_utils.py

319 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
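Shared CTC training logic.
Provides train_ctc_model(), which is shared by train_normal / train_math / train_3d_text.
Responsibilities:
1. Check for the synthetic data and invoke the generator automatically when it is missing
2. Build the Dataset / DataLoader (mixing in real data when available)
3. Run the CTC training loop with a cosine scheduler
4. Log per epoch: epoch, loss, whole-sample accuracy, character-level accuracy
5. Save the best model to checkpoints/
6. Export to ONNX when training finishes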
"""
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
labels_cover_tokens,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
    """Seed the random number generators used during training.

    Seeds Python's ``random`` module, NumPy's global generator and
    PyTorch's default generator. When CUDA is available, every CUDA
    device is seeded as well.

    Args:
        seed: Seed value; defaults to ``RANDOM_SEED`` from ``config``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# ============================================================
# Accuracy computation
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]):
"""返回 (整体准确率, 字符级准确率)。"""
total_samples = len(preds)
correct_samples = 0
total_chars = 0
correct_chars = 0
for pred, label in zip(preds, labels):
if pred == label:
correct_samples += 1
# 字符级: 逐位比较 (取较短长度)
max_len = max(len(pred), len(label))
if max_len == 0:
continue
for i in range(max_len):
total_chars += 1
if i < len(pred) and i < len(label) and pred[i] == label[i]:
correct_chars += 1
sample_acc = correct_samples / max(total_samples, 1)
char_acc = correct_chars / max(total_chars, 1)
return sample_acc, char_acc
# ============================================================
# ONNX export
# ============================================================
def _export_onnx(
    model: nn.Module,
    model_name: str,
    img_h: int,
    img_w: int,
    *,
    chars: str,
):
    """Export ``model`` to ``ONNX_DIR/<model_name>.onnx`` and write its metadata.

    The model is switched to eval mode and moved to the CPU, then traced
    with a random dummy input of shape ``(1, 1, img_h, img_w)``. When
    ``ONNX_CONFIG["dynamic_batch"]`` is truthy, axis 0 of the input and
    axis 1 of the output are exported as a dynamic ``batch`` axis.

    Args:
        model: Model to export; note that ``model.cpu()`` moves it in place.
        model_name: Base name of the ``.onnx`` file, also stored in the metadata.
        img_h: Height of the dummy input.
        img_w: Width of the dummy input.
        chars: Character set string recorded in the metadata.
    """
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)

    cpu_model = model.cpu()
    opset = ONNX_CONFIG["opset_version"]
    if ONNX_CONFIG["dynamic_batch"]:
        axes = {"input": {0: "batch"}, "output": {1: "batch"}}
    else:
        axes = None

    torch.onnx.export(
        cpu_model,
        dummy,
        str(onnx_path),
        opset_version=opset,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )

    metadata = {
        "model_name": model_name,
        "task": "ctc",
        "chars": chars,
        "input_shape": [1, img_h, img_w],
    }
    write_model_metadata(onnx_path, metadata)
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
# ============================================================
# Core training function
# ============================================================
def _build_loaders(data_dirs, chars, img_h, img_w, cfg):
    """Split the samples found in ``data_dirs`` and wrap both parts in DataLoaders.

    The training subset keeps the augmenting transform. The validation
    samples are re-wrapped in a second ``CRNNDataset`` that uses the
    validation transform, so that evaluation runs on un-augmented images.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Same sample list as the validation split, but with the clean transform.
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    return train_loader, val_loader


def _resume_from_checkpoint(
    ckpt_path,
    model,
    scheduler,
    device,
    *,
    data_refreshed,
    current_data_spec_hash,
):
    """Restore weights and progress from ``ckpt_path`` when it is still usable.

    The checkpoint is ignored when the synthetic data was just refreshed,
    or when it records a data fingerprint that differs from the current
    one. A checkpoint without any fingerprint is accepted. Only the model
    weights are restored; the optimizer state is not part of the checkpoint.

    Returns:
        ``(best_acc, start_epoch)``; ``(0.0, 1)`` when training starts fresh.
    """
    if not ckpt_path.exists():
        return 0.0, 1

    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")

    if data_refreshed:
        print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
        return 0.0, 1
    if ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
        print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
        return 0.0, 1

    if ckpt_data_spec_hash is None:
        print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
    model.load_state_dict(ckpt["model_state_dict"])
    best_acc = ckpt.get("best_acc", 0.0)
    start_epoch = ckpt.get("epoch", 0) + 1
    # Fast-forward the scheduler so the learning rate matches the resumed epoch.
    for _ in range(start_epoch - 1):
        scheduler.step()
    print(
        f"[续训] 从 epoch {start_epoch} 继续, "
        f"best_acc={best_acc:.4f}"
    )
    return best_acc, start_epoch


def _train_one_epoch(model, loader, optimizer, ctc_loss, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, targets, target_lengths, _ in pbar:
        images = images.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)

        logits = model(images)  # (T, B, C)
        T, B, C = logits.shape
        # Every sample in the batch uses the full T time steps.
        input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
        log_probs = logits.log_softmax(2)
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        # Read the scalar once; each .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device):
    """Greedy-decode ``loader`` and return ``(sample_acc, char_acc)``."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, _, _, labels in loader:
            logits = model(images.to(device))
            all_preds.extend(model.greedy_decode(logits))
            all_labels.extend(labels)
    return _calc_accuracy(all_preds, all_labels)


def train_ctc_model(
    model_name: str,
    model: nn.Module,
    chars: str,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """Run the shared CTC training pipeline and export the best model to ONNX.

    Steps: seed the RNGs, make sure the synthetic dataset exists and matches
    the current spec, build the train/validation loaders (mixing in real
    ``*.png`` samples when ``real_dir`` contains any), optionally resume from
    an existing checkpoint, train with Adam + cosine annealing, save the
    checkpoint whenever validation accuracy is at least the best seen so far,
    and finally export the best weights to ONNX.

    Args:
        model_name: Model name used for the checkpoint / ONNX file names
            (normal / math / threed).
        model: PyTorch model instance (LiteCRNN or ThreeDCNN). It must
            return logits shaped ``(T, B, C)`` and provide ``greedy_decode``.
        chars: Character set string.
        synthetic_dir: Directory holding the synthetic data.
        real_dir: Directory holding real data.
        generator_cls: Generator class used to create synthetic data on demand.
        config_key: Key into ``TRAIN_CONFIG`` / ``IMAGE_SIZE`` / ``GENERATE_CONFIG``.

    Returns:
        The best whole-sample validation accuracy.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    device = get_device()

    # Seed before any dataset generation or splitting.
    _set_seed()

    # ---- 1. Check / generate the synthetic data ----
    syn_path = Path(synthetic_dir)
    dataset_spec = build_dataset_spec(
        generator_cls,
        config_key=config_key,
        config_snapshot={
            "generate_config": GENERATE_CONFIG[config_key],
            "chars": chars,
            "image_size": IMAGE_SIZE[config_key],
        },
    )

    validator = None
    if config_key == "math":
        required_ops = tuple(GENERATE_CONFIG["math"]["operators"])

        def _labels_cover_required_ops(files):
            return labels_cover_tokens(files, required_ops)

        validator = _labels_cover_required_ops

    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=generator_cls,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
        validator=validator,
        adopt_if_missing=config_key in {"normal", "math"},
    )
    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]

    # ---- 2. Build the datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    # Scan the real-data directory once and reuse the listing for the count.
    real_pngs = list(real_path.glob("*.png")) if real_path.exists() else []
    if real_pngs:
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(real_pngs)}")

    train_loader, val_loader = _build_loaders(data_dirs, chars, img_h, img_w, cfg)

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"

    # ---- 3.5 Resume from checkpoint ----
    best_acc, start_epoch = _resume_from_checkpoint(
        ckpt_path,
        model,
        scheduler,
        device,
        data_refreshed=dataset_state["refreshed"],
        current_data_spec_hash=current_data_spec_hash,
    )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        avg_loss = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            ctc_loss,
            device,
            desc=f"Epoch {epoch}/{cfg['epochs']}",
        )
        scheduler.step()

        # ---- 5. Validation ----
        sample_acc, char_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        # ">=" (not ">") so that a fresh run always writes a checkpoint after
        # its first epoch, even at 0 accuracy; the export step below reads it.
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "chars": chars,
                "best_acc": best_acc,
                "epoch": epoch,
                "synthetic_data_spec_hash": current_data_spec_hash,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w, chars=chars)
    return best_acc