Files
CaptchBreaker/training/train_utils.py
Hua 788ddcae1a Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 19:05:47 +08:00

263 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
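Shared CTC training logic.
Provides train_ctc_model(), which is shared by train_normal / train_math / train_3d_text.
Responsibilities:
1. Check the synthetic data and call the generator automatically when there is not enough of it
2. Build the Dataset / DataLoader (mixing in real data when available)
3. CTC training loop + cosine scheduler
4. Log output: epoch, loss, whole-sample accuracy, character-level accuracy
5. Save the best model to checkpoints/
6. Export to ONNX when training finishes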
"""
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
RANDOM_SEED,
get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
    """Seed the global random number generators to make runs repeatable.

    Seeds Python's ``random`` module, NumPy and PyTorch with ``seed``; when
    CUDA is available, every CUDA device is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# ============================================================
# 准确率计算
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]):
    """Return ``(sample_accuracy, char_accuracy)`` for paired predictions.

    Args:
        preds: Decoded prediction strings.
        labels: Ground-truth strings, paired with ``preds`` by position.

    Returns:
        A 2-tuple of floats:

        * sample accuracy -- number of exact-match pairs divided by
          ``len(preds)``;
        * character accuracy -- number of positions at which both strings
          hold the same character, divided by the summed length of the
          *longer* string of each pair, so missing and surplus characters
          both count as errors.

        Each value is ``0.0`` when its denominator would be zero.
    """
    correct_samples = 0
    correct_chars = 0
    total_chars = 0
    for pred, label in zip(preds, labels):
        if pred == label:
            correct_samples += 1
        # The denominator uses the longer of the two strings (not the
        # shorter), so a length mismatch lowers the character accuracy.
        total_chars += max(len(pred), len(label))
        # zip() stops at the shorter string, i.e. exactly the positions at
        # which both strings have a character to compare.
        correct_chars += sum(p == t for p, t in zip(pred, label))
    sample_acc = correct_samples / max(len(preds), 1)
    char_acc = correct_chars / max(total_chars, 1)
    return sample_acc, char_acc
# ============================================================
# ONNX 导出
# ============================================================
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
    """Export ``model`` to ``ONNX_DIR / "<model_name>.onnx"``.

    The model is switched to eval mode and moved to the CPU (both in place),
    then exported using a random dummy input of shape
    ``(1, 1, img_h, img_w)``. The path and size of the written file are
    printed afterwards.

    Args:
        model: Network to export. It is left in eval mode on the CPU.
        model_name: File stem of the exported ``.onnx`` file.
        img_h: Height of the dummy input.
        img_w: Width of the dummy input.
    """
    model.eval()
    # Create the output directory up front so that a missing folder cannot
    # fail the export after training has already finished.
    ONNX_DIR.mkdir(parents=True, exist_ok=True)
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)

    # With dynamic batching enabled, the batch dimension is axis 0 of the
    # input and axis 1 of the output.
    if ONNX_CONFIG["dynamic_batch"]:
        dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}
    else:
        dynamic_axes = None

    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
# ============================================================
# 核心训练函数
# ============================================================
def train_ctc_model(
    model_name: str,
    model: nn.Module,
    chars: str,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """Run the shared CTC training pipeline and export the best model.

    Steps: make sure enough synthetic data exists (generating it otherwise),
    build the train/validation loaders (mixing in real data when present),
    train with CTC loss, Adam and a cosine learning-rate schedule, save a
    checkpoint whenever validation accuracy ties or beats the best so far,
    and finally export the best weights to ONNX.

    Args:
        model_name: Model name, used for the checkpoint and ONNX file names.
        model: PyTorch model instance. Its forward output is unpacked as
            ``(T, B, C)`` and it must provide ``greedy_decode(logits)``.
        chars: Character-set string, passed to the dataset and stored in the
            checkpoint.
        synthetic_dir: Directory holding the synthetic ``*.png`` samples.
        real_dir: Directory holding real ``*.png`` samples; used only if it
            exists and contains at least one PNG.
        generator_cls: Generator class; instantiated with no arguments and
            asked to ``generate_dataset(count, out_dir)`` when synthetic data
            is insufficient.
        config_key: Key into ``TRAIN_CONFIG`` and ``IMAGE_SIZE``.

    Returns:
        The best whole-sample validation accuracy.

    NOTE: a checkpoint is written only when validation accuracy ties or
    beats the best so far, so its ``epoch`` field is the epoch of the best
    model. Resuming therefore continues from that epoch, and with a fresh
    optimizer because optimizer state is not stored in the checkpoint.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    device = get_device()

    # Seed the global RNGs before any data splitting or shuffling.
    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
        gen = generator_cls()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[数据] 合成数据已就绪: {len(existing)}")

    # ---- 2. Build the datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    # Scan the real-data directory once and reuse the listing for both the
    # emptiness check and the log line.
    real_pngs = list(real_path.glob("*.png")) if real_path.exists() else []
    if real_pngs:
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(real_pngs)}")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)
    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation split must use the validation transform, so build a
    # second dataset with that transform and restrict its sample list to the
    # held-out indices chosen by random_split.
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    best_acc = 0.0
    start_epoch = 1
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # Make sure the checkpoint directory exists before the first save.
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- 3.5 Resume from an existing checkpoint ----
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = ckpt.get("best_acc", 0.0)
        start_epoch = ckpt.get("epoch", 0) + 1
        # Fast-forward the scheduler by one step per already-completed epoch.
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(
            f"[续训] 从 epoch {start_epoch} 继续, "
            f"best_acc={best_acc:.4f}"
        )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets, target_lengths, _ in pbar:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            logits = model(images)  # (T, B, C)
            T, B, _ = logits.shape
            # Every sample in the batch uses the full T output steps.
            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
            log_probs = logits.log_softmax(2)
            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Read the scalar once and reuse it for the running total and
            # the progress bar.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            pbar.set_postfix(loss=f"{batch_loss:.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, _, _, labels in val_loader:
                images = images.to(device)
                logits = model(images)
                preds = model.greedy_decode(logits)
                all_preds.extend(preds)
                all_labels.extend(labels)
        sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        # ">=" means a tie also overwrites the checkpoint with the newer
        # weights.
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "chars": chars,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting; the in-memory model holds
    # the weights of the last epoch, which may not be the best one.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w)
    return best_acc