Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""
|
||||
训练脚本包
|
||||
|
||||
- dataset.py: CRNNDataset / CaptchaDataset 通用数据集类
|
||||
- train_utils.py: CTC 训练通用逻辑 (train_ctc_model)
|
||||
- train_normal.py: 训练普通字符识别 (LiteCRNN - normal)
|
||||
- train_math.py: 训练算式识别 (LiteCRNN - math)
|
||||
- train_3d.py: 训练 3D 立体识别 (ThreeDCNN)
|
||||
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
|
||||
- dataset.py: CRNNDataset / CaptchaDataset / RegressionDataset 通用数据集类
|
||||
- train_utils.py: CTC 训练通用逻辑 (train_ctc_model)
|
||||
- train_regression_utils.py: 回归训练通用逻辑 (train_regression_model)
|
||||
- train_normal.py: 训练普通字符识别 (LiteCRNN - normal)
|
||||
- train_math.py: 训练算式识别 (LiteCRNN - math)
|
||||
- train_3d_text.py: 训练 3D 立体文字识别 (ThreeDCNN)
|
||||
- train_3d_rotate.py: 训练 3D 旋转回归 (RegressionCNN)
|
||||
- train_3d_slider.py: 训练 3D 滑块回归 (RegressionCNN)
|
||||
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
|
||||
"""
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
通用 Dataset 类
|
||||
|
||||
提供两种数据集:
|
||||
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
|
||||
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
|
||||
提供三种数据集:
|
||||
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
|
||||
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
|
||||
- RegressionDataset: 用于回归模型训练 (图片 → 数值标签 [0,1])
|
||||
|
||||
文件名格式约定: {label}_{任意}.png
|
||||
- 分类器: label 可为任意字符,所在子目录名即为类别
|
||||
- 识别器: label 即标注内容 (如 "A3B8" 或 "3+8")
|
||||
- 回归器: label 为数值 (如 "135" 或 "87")
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
@@ -98,7 +101,15 @@ class CRNNDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
|
||||
# 编码标签为整数序列
|
||||
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
|
||||
target = []
|
||||
for c in label:
|
||||
if c in self.char_to_idx:
|
||||
target.append(self.char_to_idx[c])
|
||||
else:
|
||||
warnings.warn(
|
||||
f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
|
||||
stacklevel=2,
|
||||
)
|
||||
return img, target, label
|
||||
|
||||
@staticmethod
|
||||
@@ -119,7 +130,7 @@ class CaptchaDataset(Dataset):
|
||||
"""
|
||||
分类器训练数据集。
|
||||
|
||||
每个子目录名为类别名 (如 "normal", "math", "3d"),
|
||||
每个子目录名为类别名 (如 "normal", "math", "3d_text"),
|
||||
目录内所有 .png 文件属于该类。
|
||||
"""
|
||||
|
||||
@@ -157,3 +168,59 @@ class CaptchaDataset(Dataset):
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 回归模型用数据集
|
||||
# ============================================================
|
||||
class RegressionDataset(Dataset):
|
||||
"""
|
||||
回归模型数据集 (3d_rotate / 3d_slider)。
|
||||
|
||||
从目录中读取 {value}_{xxx}.png 文件,
|
||||
将 value 解析为浮点数并归一化到 [0, 1]。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirs: list[str | Path],
|
||||
label_range: tuple[float, float],
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dirs: 数据目录列表
|
||||
label_range: (min_val, max_val) 标签原始范围
|
||||
transform: 图片预处理/增强
|
||||
"""
|
||||
self.label_range = label_range
|
||||
self.lo, self.hi = label_range
|
||||
self.transform = transform
|
||||
|
||||
self.samples: list[tuple[str, float]] = [] # (文件路径, 归一化标签)
|
||||
for d in dirs:
|
||||
d = Path(d)
|
||||
if not d.exists():
|
||||
continue
|
||||
for f in sorted(d.glob("*.png")):
|
||||
raw_label = f.stem.rsplit("_", 1)[0]
|
||||
try:
|
||||
value = float(raw_label)
|
||||
except ValueError:
|
||||
continue
|
||||
# 归一化到 [0, 1]
|
||||
norm = (value - self.lo) / max(self.hi - self.lo, 1e-6)
|
||||
norm = max(0.0, min(1.0, norm))
|
||||
self.samples.append((str(f), norm))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
import torch
|
||||
path, label = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([label], dtype=torch.float32)
|
||||
|
||||
|
||||
38
training/train_3d_rotate.py
Normal file
38
training/train_3d_rotate.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
训练 3D 旋转验证码回归模型 (RegressionCNN)
|
||||
|
||||
用法: python -m training.train_3d_rotate
|
||||
"""
|
||||
|
||||
from config import (
|
||||
IMAGE_SIZE,
|
||||
SYNTHETIC_3D_ROTATE_DIR,
|
||||
REAL_3D_ROTATE_DIR,
|
||||
)
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from training.train_regression_utils import train_regression_model
|
||||
|
||||
|
||||
def main():
|
||||
img_h, img_w = IMAGE_SIZE["3d_rotate"]
|
||||
model = RegressionCNN(img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练 3D 旋转验证码回归模型 (RegressionCNN)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print(f" 任务: 预测旋转角度 0-359°")
|
||||
print("=" * 60)
|
||||
|
||||
train_regression_model(
|
||||
model_name="threed_rotate",
|
||||
model=model,
|
||||
synthetic_dir=SYNTHETIC_3D_ROTATE_DIR,
|
||||
real_dir=REAL_3D_ROTATE_DIR,
|
||||
generator_cls=ThreeDRotateGenerator,
|
||||
config_key="3d_rotate",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
38
training/train_3d_slider.py
Normal file
38
training/train_3d_slider.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
训练 3D 滑块验证码回归模型 (RegressionCNN)
|
||||
|
||||
用法: python -m training.train_3d_slider
|
||||
"""
|
||||
|
||||
from config import (
|
||||
IMAGE_SIZE,
|
||||
SYNTHETIC_3D_SLIDER_DIR,
|
||||
REAL_3D_SLIDER_DIR,
|
||||
)
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from training.train_regression_utils import train_regression_model
|
||||
|
||||
|
||||
def main():
|
||||
img_h, img_w = IMAGE_SIZE["3d_slider"]
|
||||
model = RegressionCNN(img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练 3D 滑块验证码回归模型 (RegressionCNN)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print(f" 任务: 预测滑块偏移 x 坐标")
|
||||
print("=" * 60)
|
||||
|
||||
train_regression_model(
|
||||
model_name="threed_slider",
|
||||
model=model,
|
||||
synthetic_dir=SYNTHETIC_3D_SLIDER_DIR,
|
||||
real_dir=REAL_3D_SLIDER_DIR,
|
||||
generator_cls=ThreeDSliderGenerator,
|
||||
config_key="3d_slider",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
训练 3D 立体验证码识别模型 (ThreeDCNN)
|
||||
训练 3D 立体文字验证码识别模型 (ThreeDCNN)
|
||||
|
||||
用法: python -m training.train_3d
|
||||
用法: python -m training.train_3d_text
|
||||
"""
|
||||
|
||||
from config import (
|
||||
THREED_CHARS,
|
||||
IMAGE_SIZE,
|
||||
SYNTHETIC_3D_DIR,
|
||||
REAL_3D_DIR,
|
||||
SYNTHETIC_3D_TEXT_DIR,
|
||||
REAL_3D_TEXT_DIR,
|
||||
)
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
@@ -16,23 +16,23 @@ from training.train_utils import train_ctc_model
|
||||
|
||||
|
||||
def main():
|
||||
img_h, img_w = IMAGE_SIZE["3d"]
|
||||
img_h, img_w = IMAGE_SIZE["3d_text"]
|
||||
model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练 3D 立体验证码识别模型 (ThreeDCNN)")
|
||||
print("训练 3D 立体文字验证码识别模型 (ThreeDCNN)")
|
||||
print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print("=" * 60)
|
||||
|
||||
train_ctc_model(
|
||||
model_name="threed",
|
||||
model_name="threed_text",
|
||||
model=model,
|
||||
chars=THREED_CHARS,
|
||||
synthetic_dir=SYNTHETIC_3D_DIR,
|
||||
real_dir=REAL_3D_DIR,
|
||||
synthetic_dir=SYNTHETIC_3D_TEXT_DIR,
|
||||
real_dir=REAL_3D_TEXT_DIR,
|
||||
generator_cls=ThreeDCaptchaGenerator,
|
||||
config_key="threed",
|
||||
config_key="3d_text",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"""
|
||||
训练调度分类器 (CaptchaClassifier)
|
||||
|
||||
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d。
|
||||
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d_text / 3d_rotate / 3d_slider。
|
||||
数据来源: data/classifier/ 目录 (按类型子目录组织)
|
||||
|
||||
用法: python -m training.train_classifier
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@@ -24,15 +26,20 @@ from config import (
|
||||
CLASSIFIER_DIR,
|
||||
SYNTHETIC_NORMAL_DIR,
|
||||
SYNTHETIC_MATH_DIR,
|
||||
SYNTHETIC_3D_DIR,
|
||||
SYNTHETIC_3D_TEXT_DIR,
|
||||
SYNTHETIC_3D_ROTATE_DIR,
|
||||
SYNTHETIC_3D_SLIDER_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.normal_gen import NormalCaptchaGenerator
|
||||
from generators.math_gen import MathCaptchaGenerator
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
from models.classifier import CaptchaClassifier
|
||||
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
|
||||
|
||||
@@ -52,7 +59,9 @@ def _prepare_classifier_data():
|
||||
type_info = [
|
||||
("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
|
||||
("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
|
||||
("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
|
||||
("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
|
||||
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
|
||||
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
|
||||
]
|
||||
|
||||
for cls_name, syn_dir, gen_cls in type_info:
|
||||
@@ -95,6 +104,13 @@ def main():
|
||||
img_h, img_w = IMAGE_SIZE["classifier"]
|
||||
device = get_device()
|
||||
|
||||
# 设置随机种子
|
||||
random.seed(RANDOM_SEED)
|
||||
np.random.seed(RANDOM_SEED)
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(RANDOM_SEED)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练调度分类器 (CaptchaClassifier)")
|
||||
print(f" 类别: {CAPTCHA_TYPES}")
|
||||
@@ -128,11 +144,11 @@ def main():
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=2, pin_memory=True,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=2, pin_memory=True,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
|
||||
264
training/train_regression_utils.py
Normal file
264
training/train_regression_utils.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
回归训练通用逻辑
|
||||
|
||||
提供 train_regression_model() 函数,被 train_3d_rotate / train_3d_slider 共用。
|
||||
职责:
|
||||
1. 检查合成数据,不存在则自动调用生成器
|
||||
2. 构建 RegressionDataset / DataLoader(含真实数据混合)
|
||||
3. 回归训练循环 + cosine scheduler
|
||||
4. 输出日志: epoch, loss, MAE, tolerance 准确率
|
||||
5. 保存最佳模型到 checkpoints/
|
||||
6. 训练结束导出 ONNX
|
||||
"""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import (
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
REGRESSION_RANGE,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
"""设置全局随机种子,保证训练可复现。"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
循环距离 SmoothL1 loss,用于角度回归 (处理 0°/360° 边界)。
|
||||
pred 和 target 都在 [0, 1] 范围。
|
||||
"""
|
||||
diff = torch.abs(pred - target)
|
||||
# 循环距离: min(|d|, 1-|d|)
|
||||
diff = torch.min(diff, 1.0 - diff)
|
||||
# SmoothL1
|
||||
return torch.where(
|
||||
diff < 1.0 / 360.0, # beta ≈ 1° 归一化
|
||||
0.5 * diff * diff / (1.0 / 360.0),
|
||||
diff - 0.5 * (1.0 / 360.0),
|
||||
).mean()
|
||||
|
||||
|
||||
def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
|
||||
"""循环 MAE (归一化空间)。"""
|
||||
diff = np.abs(pred - target)
|
||||
diff = np.minimum(diff, 1.0 - diff)
|
||||
return float(np.mean(diff))
|
||||
|
||||
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
dummy = torch.randn(1, 1, img_h, img_w)
|
||||
torch.onnx.export(
|
||||
model.cpu(),
|
||||
dummy,
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_CONFIG["opset_version"],
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
def train_regression_model(
|
||||
model_name: str,
|
||||
model: nn.Module,
|
||||
synthetic_dir: str | Path,
|
||||
real_dir: str | Path,
|
||||
generator_cls,
|
||||
config_key: str,
|
||||
):
|
||||
"""
|
||||
通用回归训练流程。
|
||||
|
||||
Args:
|
||||
model_name: 模型名称 (用于保存文件: threed_rotate / threed_slider)
|
||||
model: PyTorch 模型实例 (RegressionCNN)
|
||||
synthetic_dir: 合成数据目录
|
||||
real_dir: 真实数据目录
|
||||
generator_cls: 生成器类 (用于自动生成数据)
|
||||
config_key: TRAIN_CONFIG / REGRESSION_RANGE 中的键名 (3d_rotate / 3d_slider)
|
||||
"""
|
||||
cfg = TRAIN_CONFIG[config_key]
|
||||
img_h, img_w = IMAGE_SIZE[config_key]
|
||||
label_range = REGRESSION_RANGE[config_key]
|
||||
lo, hi = label_range
|
||||
is_circular = config_key == "3d_rotate"
|
||||
device = get_device()
|
||||
|
||||
# 容差配置
|
||||
if config_key == "3d_rotate":
|
||||
tolerance = 5.0 # ±5°
|
||||
else:
|
||||
tolerance = 3.0 # ±3px
|
||||
|
||||
_set_seed()
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
real_path = Path(real_dir)
|
||||
if real_path.exists() and list(real_path.glob("*.png")):
|
||||
data_dirs.append(str(real_path))
|
||||
print(f"[数据] 混合真实数据: {len(list(real_path.glob('*.png')))} 张")
|
||||
|
||||
train_transform = build_train_transform(img_h, img_w)
|
||||
val_transform = build_val_transform(img_h, img_w)
|
||||
|
||||
full_dataset = RegressionDataset(
|
||||
dirs=data_dirs, label_range=label_range, transform=train_transform,
|
||||
)
|
||||
total = len(full_dataset)
|
||||
val_size = int(total * cfg["val_split"])
|
||||
train_size = total - val_size
|
||||
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
|
||||
|
||||
# 验证集使用无增强 transform
|
||||
val_ds_clean = RegressionDataset(
|
||||
dirs=data_dirs, label_range=label_range, transform=val_transform,
|
||||
)
|
||||
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
|
||||
# ---- 3. 优化器 / 调度器 / 损失 ----
|
||||
model = model.to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
|
||||
|
||||
if is_circular:
|
||||
loss_fn = _circular_smooth_l1
|
||||
else:
|
||||
loss_fn = nn.SmoothL1Loss()
|
||||
|
||||
best_mae = float("inf")
|
||||
best_tol_acc = 0.0
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
preds = model(images)
|
||||
loss = loss_fn(preds, targets)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
||||
|
||||
scheduler.step()
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
# ---- 5. 验证 ----
|
||||
model.eval()
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
with torch.no_grad():
|
||||
for images, targets in val_loader:
|
||||
images = images.to(device)
|
||||
preds = model(images)
|
||||
all_preds.append(preds.cpu().numpy())
|
||||
all_targets.append(targets.numpy())
|
||||
|
||||
all_preds = np.concatenate(all_preds, axis=0).flatten()
|
||||
all_targets = np.concatenate(all_targets, axis=0).flatten()
|
||||
|
||||
# 缩放回原始范围计算 MAE
|
||||
preds_real = all_preds * (hi - lo) + lo
|
||||
targets_real = all_targets * (hi - lo) + lo
|
||||
|
||||
if is_circular:
|
||||
# 循环 MAE
|
||||
diff = np.abs(preds_real - targets_real)
|
||||
diff = np.minimum(diff, (hi - lo) - diff)
|
||||
mae = float(np.mean(diff))
|
||||
within_tol = diff <= tolerance
|
||||
else:
|
||||
mae = float(np.mean(np.abs(preds_real - targets_real)))
|
||||
within_tol = np.abs(preds_real - targets_real) <= tolerance
|
||||
|
||||
tol_acc = float(np.mean(within_tol))
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{cfg['epochs']} "
|
||||
f"loss={avg_loss:.4f} "
|
||||
f"MAE={mae:.2f} "
|
||||
f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
|
||||
f"lr={lr:.6f}"
|
||||
)
|
||||
|
||||
# ---- 6. 保存最佳模型 (以容差准确率为准) ----
|
||||
if tol_acc >= best_tol_acc:
|
||||
best_tol_acc = tol_acc
|
||||
best_mae = mae
|
||||
torch.save({
|
||||
"model_state_dict": model.state_dict(),
|
||||
"label_range": label_range,
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
|
||||
|
||||
# ---- 7. 导出 ONNX ----
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
|
||||
return best_tol_acc
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
CTC 训练通用逻辑
|
||||
|
||||
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d 共用。
|
||||
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d_text 共用。
|
||||
职责:
|
||||
1. 检查合成数据,不存在则自动调用生成器
|
||||
2. 构建 Dataset / DataLoader(含真实数据混合)
|
||||
@@ -12,8 +12,10 @@ CTC 训练通用逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@@ -25,11 +27,21 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
"""设置全局随机种子,保证训练可复现。"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 准确率计算
|
||||
# ============================================================
|
||||
@@ -104,9 +116,12 @@ def train_ctc_model(
|
||||
config_key: TRAIN_CONFIG 中的键名
|
||||
"""
|
||||
cfg = TRAIN_CONFIG[config_key]
|
||||
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
|
||||
img_h, img_w = IMAGE_SIZE[config_key]
|
||||
device = get_device()
|
||||
|
||||
# 设置随机种子
|
||||
_set_seed()
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
@@ -139,11 +154,11 @@ def train_ctc_model(
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
@@ -166,12 +181,11 @@ def train_ctc_model(
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets, target_lengths, _ in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device)
|
||||
target_lengths = target_lengths.to(device)
|
||||
|
||||
logits = model(images) # (T, B, C)
|
||||
T, B, C = logits.shape
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
|
||||
# cuDNN CTC requires targets/lengths on CPU
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32)
|
||||
|
||||
log_probs = logits.log_softmax(2)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
|
||||
Reference in New Issue
Block a user