Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider

Split the single "3d" captcha type into three independent expert models:
- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated
to 50000 (10000 per class). New generators, model, dataset, training
utilities, and full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 13:55:53 +08:00
parent 760b80ee5e
commit f5be7671bc
20 changed files with 1109 additions and 142 deletions

View File

@@ -1,7 +1,7 @@
"""
CTC 训练通用逻辑
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d 共用。
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d_text 共用。
职责:
1. 检查合成数据,不存在则自动调用生成器
2. 构建 Dataset / DataLoader（含真实数据混合）
@@ -12,8 +12,10 @@ CTC 训练通用逻辑
"""
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
@@ -25,11 +27,21 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
RANDOM_SEED,
get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
    """Seed every RNG in play (stdlib random, NumPy, torch, CUDA) so training is reproducible."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Only touch the CUDA generators when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ============================================================
# 准确率计算
# ============================================================
@@ -104,9 +116,12 @@ def train_ctc_model(
config_key: TRAIN_CONFIG 中的键名
"""
cfg = TRAIN_CONFIG[config_key]
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
img_h, img_w = IMAGE_SIZE[config_key]
device = get_device()
# 设置随机种子
_set_seed()
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
@@ -139,11 +154,11 @@ def train_ctc_model(
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
print(f"[数据] 训练: {train_size} 验证: {val_size}")
@@ -166,12 +181,11 @@ def train_ctc_model(
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, targets, target_lengths, _ in pbar:
images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
logits = model(images) # (T, B, C)
T, B, C = logits.shape
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
# cuDNN CTC requires targets/lengths on CPU
input_lengths = torch.full((B,), T, dtype=torch.int32)
log_probs = logits.log_softmax(2)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)