"""
CTC 训练通用逻辑

提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d_text 共用。

职责:
1. 检查合成数据,不存在则自动调用生成器
2. 构建 Dataset / DataLoader(含真实数据混合)
3. CTC 训练循环 + cosine scheduler
4. 输出日志: epoch, loss, 整体准确率, 字符级准确率
5. 保存最佳模型到 checkpoints/
6. 训练结束导出 ONNX
"""

import os
|
||
import random
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.utils.data import DataLoader, random_split
|
||
from tqdm import tqdm
|
||
|
||
from config import (
|
||
CHECKPOINTS_DIR,
|
||
ONNX_DIR,
|
||
ONNX_CONFIG,
|
||
TRAIN_CONFIG,
|
||
IMAGE_SIZE,
|
||
GENERATE_CONFIG,
|
||
RANDOM_SEED,
|
||
get_device,
|
||
)
|
||
from inference.model_metadata import write_model_metadata
|
||
from training.data_fingerprint import (
|
||
build_dataset_spec,
|
||
ensure_synthetic_dataset,
|
||
labels_cover_tokens,
|
||
)
|
||
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
|
||
|
||
|
||
def _set_seed(seed: int = RANDOM_SEED):
    """Seed every RNG in use (random, numpy, torch, CUDA) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA keeps per-device generators; seed them all when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


||
# ============================================================
|
||
# 准确率计算
|
||
# ============================================================
|
||
def _calc_accuracy(preds: list[str], labels: list[str]):
|
||
"""返回 (整体准确率, 字符级准确率)。"""
|
||
total_samples = len(preds)
|
||
correct_samples = 0
|
||
total_chars = 0
|
||
correct_chars = 0
|
||
|
||
for pred, label in zip(preds, labels):
|
||
if pred == label:
|
||
correct_samples += 1
|
||
# 字符级: 逐位比较 (取较短长度)
|
||
max_len = max(len(pred), len(label))
|
||
if max_len == 0:
|
||
continue
|
||
for i in range(max_len):
|
||
total_chars += 1
|
||
if i < len(pred) and i < len(label) and pred[i] == label[i]:
|
||
correct_chars += 1
|
||
|
||
sample_acc = correct_samples / max(total_samples, 1)
|
||
char_acc = correct_chars / max(total_chars, 1)
|
||
return sample_acc, char_acc
|
||
|
||
|
||
# ============================================================
# ONNX export
# ============================================================
def _export_onnx(
    model: nn.Module,
    model_name: str,
    img_h: int,
    img_w: int,
    *,
    chars: str,
):
    """Export *model* to ONNX_DIR/<model_name>.onnx and write its metadata sidecar."""
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"

    # NOTE(review): output batch axis is 1, presumably because the CTC head
    # emits (T, B, C) — consistent with the training loop in this file.
    axes = None
    if ONNX_CONFIG["dynamic_batch"]:
        axes = {"input": {0: "batch"}, "output": {1: "batch"}}

    torch.onnx.export(
        model.cpu(),
        torch.randn(1, 1, img_h, img_w),
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
    write_model_metadata(
        onnx_path,
        {
            "model_name": model_name,
            "task": "ctc",
            "chars": chars,
            "input_shape": [1, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


||
# ============================================================
# Core training function
# ============================================================
def train_ctc_model(
    model_name: str,
    model: nn.Module,
    chars: str,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """
    Generic CTC training pipeline.

    Ensures the synthetic dataset exists (regenerating when its fingerprint is
    stale), builds train/val loaders (optionally mixing in real samples), runs
    a CTC training loop with a cosine LR schedule, checkpoints the best model,
    and finally exports the best weights to ONNX.

    Args:
        model_name: Model name (used for saved files: normal / math / threed).
        model: PyTorch model instance (LiteCRNN or ThreeDCNN).
        chars: Character-set string.
        synthetic_dir: Directory of synthetic training data.
        real_dir: Directory of real captcha data.
        generator_cls: Generator class used to auto-generate data.
        config_key: Key into TRAIN_CONFIG / IMAGE_SIZE / GENERATE_CONFIG.

    Returns:
        Best sample-level validation accuracy observed during training.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    device = get_device()

    # Seed every RNG so runs are reproducible.
    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    dataset_spec = build_dataset_spec(
        generator_cls,
        config_key=config_key,
        config_snapshot={
            "generate_config": GENERATE_CONFIG[config_key],
            "chars": chars,
            "image_size": IMAGE_SIZE[config_key],
        },
    )
    validator = None
    if config_key == "math":
        # Math data is only accepted when every operator occurs in the labels.
        required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
        validator = lambda files: labels_cover_tokens(files, required_ops)

    dataset_state = ensure_synthetic_dataset(
        syn_path,
        generator_cls=generator_cls,
        spec=dataset_spec,
        gen_count=cfg["synthetic_samples"],
        exact_count=cfg["synthetic_samples"],
        validator=validator,
        adopt_if_missing=config_key in {"normal", "math"},
    )
    if dataset_state["refreshed"]:
        print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
    elif dataset_state["adopted"]:
        print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
    else:
        print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
    # Spec hash of the current data; stored in checkpoints to detect staleness.
    current_data_spec_hash = dataset_state["manifest"]["spec_hash"]

    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    if real_path.exists() and list(real_path.glob("*.png")):
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(list(real_path.glob('*.png')))} 张")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation uses an augmentation-free transform over the same split
    # indices: a fresh dataset object is given the val subset's sample list.
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )

    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # blank=0: index 0 is reserved for the CTC blank token.
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    best_acc = 0.0
    start_epoch = 1
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"

    # ---- 3.5 Resume from checkpoint ----
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")

        # A checkpoint is only resumed when the synthetic data it was trained
        # on matches the current data fingerprint (or predates fingerprinting).
        if dataset_state["refreshed"]:
            print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
        elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
            print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
        else:
            if ckpt_data_spec_hash is None:
                print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
            model.load_state_dict(ckpt["model_state_dict"])
            best_acc = ckpt.get("best_acc", 0.0)
            start_epoch = ckpt.get("epoch", 0) + 1
            # Fast-forward the scheduler so the LR matches the resumed epoch.
            for _ in range(start_epoch - 1):
                scheduler.step()
            print(
                f"[续训] 从 epoch {start_epoch} 继续, "
                f"best_acc={best_acc:.4f}"
            )

    # ---- 4. Training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets, target_lengths, _ in pbar:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            logits = model(images)  # (T, B, C)
            T, B, C = logits.shape
            # Every sample uses the full time dimension as its input length.
            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)

            log_probs = logits.log_softmax(2)
            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients: CTC loss can spike early in training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, _, _, labels in val_loader:
                images = images.to(device)
                logits = model(images)
                preds = model.greedy_decode(logits)
                all_preds.extend(preds)
                all_labels.extend(labels)

        sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        # >= (not >) so the checkpoint is written on the very first epoch too.
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "chars": chars,
                "best_acc": best_acc,
                "epoch": epoch,
                "synthetic_data_spec_hash": current_data_spec_hash,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w, chars=chars)

    return best_acc
|