"""
Shared CTC training logic.

Provides train_ctc_model(), which is shared by
train_normal / train_math / train_3d_text.

Responsibilities:
    1. Check the synthetic data; invoke the generator automatically if missing
    2. Build the Dataset / DataLoader (optionally mixing in real data)
    3. CTC training loop + cosine scheduler
    4. Log: epoch, loss, whole-sample accuracy, character-level accuracy
    5. Save the best model to checkpoints/
    6. Export ONNX when training finishes
"""

import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    TRAIN_CONFIG,
    IMAGE_SIZE,
    GENERATE_CONFIG,
    RANDOM_SEED,
    get_device,
)
from inference.model_metadata import write_model_metadata
from training.data_fingerprint import (
    build_dataset_spec,
    ensure_synthetic_dataset,
    labels_cover_tokens,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform


def _set_seed(seed: int = RANDOM_SEED) -> None:
    """Seed Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ============================================================
# Accuracy computation
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]) -> tuple[float, float]:
    """Return ``(sample_accuracy, char_accuracy)`` for paired predictions.

    Sample accuracy is the fraction of predictions that equal their label
    exactly, relative to ``len(preds)``.

    Character accuracy compares each pair position by position. The
    denominator for a pair is the LONGER of the two strings, so missing or
    extra characters count as errors. A pair where both strings are empty
    adds nothing to the character totals.

    Both ratios are 0.0 when there is nothing to compare.
    """
    total_samples = len(preds)
    correct_samples = 0
    total_chars = 0
    correct_chars = 0

    for pred, label in zip(preds, labels):
        if pred == label:
            correct_samples += 1
        # Positions beyond the shorter string can never match, so only the
        # overlapping prefix is compared while the longer length is charged.
        total_chars += max(len(pred), len(label))
        correct_chars += sum(p == c for p, c in zip(pred, label))

    sample_acc = correct_samples / max(total_samples, 1)
    char_acc = correct_chars / max(total_chars, 1)
    return sample_acc, char_acc


# ============================================================
# ONNX export
# ============================================================
def _export_onnx(
    model: nn.Module,
    model_name: str,
    img_h: int,
    img_w: int,
    *,
    chars: str,
) -> None:
    """Export ``model`` to ``ONNX_DIR/<model_name>.onnx`` and write its metadata.

    The model is switched to eval mode and moved to the CPU before export
    (it stays there afterwards). The export is traced with a random
    ``(1, 1, img_h, img_w)`` input. When ``ONNX_CONFIG["dynamic_batch"]`` is
    truthy, axis 0 of the input and axis 1 of the output are marked dynamic.

    Args:
        model: Model to export.
        model_name: Base name of the output file.
        img_h: Input image height used for the dummy input.
        img_w: Input image width used for the dummy input.
        chars: Character set string stored in the model metadata.
    """
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    # Make sure the output directory exists so the export cannot fail on a
    # fresh checkout.
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    dummy = torch.randn(1, 1, img_h, img_w)
    if ONNX_CONFIG["dynamic_batch"]:
        dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}
    else:
        dynamic_axes = None

    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    write_model_metadata(
        onnx_path,
        {
            "model_name": model_name,
            "task": "ctc",
            "chars": chars,
            "input_shape": [1, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


# ============================================================
# Core training function
# ============================================================
def train_ctc_model(
    model_name: str,
    model: nn.Module,
    chars: str,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
) -> float:
    """Run the shared CTC training pipeline and export the best model to ONNX.

    Args:
        model_name: Model name (used for saved files: normal / math / threed).
        model: PyTorch model instance (LiteCRNN or ThreeDCNN). It must provide
            a ``greedy_decode(logits)`` method, which is used for validation.
        chars: Character set string.
        synthetic_dir: Directory holding the synthetic data.
        real_dir: Directory holding real data; mixed in when it contains
            ``*.png`` files.
        generator_cls: Generator class (used to create data automatically).
        config_key: Key into TRAIN_CONFIG / IMAGE_SIZE / GENERATE_CONFIG.

    Returns:
        The best whole-sample validation accuracy reached (including the
        value restored from a resumed checkpoint).
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    device = get_device()

    # Seed all RNGs.
    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    dataset_spec = build_dataset_spec(
        generator_cls,
        config_key=config_key,
        config_snapshot={
            "generate_config": GENERATE_CONFIG[config_key],
            "chars": chars,
            "image_size": IMAGE_SIZE[config_key],
        },
    )

    validator = None
    if config_key == "math":
        required_ops = tuple(GENERATE_CONFIG["math"]["operators"])

        def _covers_required_ops(files):
            return labels_cover_tokens(files, required_ops)

        validator = _covers_required_ops

    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=generator_cls,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
        validator=validator,
        adopt_if_missing=config_key in {"normal", "math"},
    )
    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]

    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    # Glob once and reuse the result for both the check and the log message.
    real_images = list(real_path.glob("*.png")) if real_path.exists() else []
    if real_images:
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(real_images)} 张")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation split must not be augmented: build a second dataset with
    # the validation transform and restrict it to the held-out samples.
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        collate_fn=CRNNDataset.collate_fn,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        collate_fn=CRNNDataset.collate_fn,
        pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    best_acc = 0.0
    start_epoch = 1
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # Create the checkpoint directory up front so a missing folder cannot
    # abort the run after a full epoch of work.
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # ---- 3.5 Resume from checkpoint ----
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
        if dataset_state["refreshed"]:
            print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
        elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
            print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
        else:
            if ckpt_data_spec_hash is None:
                print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
            model.load_state_dict(ckpt["model_state_dict"])
            best_acc = ckpt.get("best_acc", 0.0)
            start_epoch = ckpt.get("epoch", 0) + 1
            # Fast-forward the scheduler to the epoch being resumed. Only the
            # model weights are restored; the optimizer state is not stored in
            # the checkpoint.
            for _ in range(start_epoch - 1):
                scheduler.step()
            print(
                f"[续训] 从 epoch {start_epoch} 继续, "
                f"best_acc={best_acc:.4f}"
            )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets, target_lengths, _ in pbar:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            logits = model(images)  # (T, B, C)
            T, B, C = logits.shape
            # Every sample in the batch uses the full output sequence length.
            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)

            log_probs = logits.log_softmax(2)
            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, _, _, labels in val_loader:
                images = images.to(device)
                logits = model(images)
                preds = model.greedy_decode(logits)
                all_preds.extend(preds)
                all_labels.extend(labels)

        sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        # ">=" so that ties overwrite the checkpoint with the newer weights.
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "chars": chars,
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "synthetic_data_spec_hash": current_data_spec_hash,
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w, chars=chars)

    return best_acc