Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
CTC 训练通用逻辑
|
||||
|
||||
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d 共用。
|
||||
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d_text 共用。
|
||||
职责:
|
||||
1. 检查合成数据,不存在则自动调用生成器
|
||||
2. 构建 Dataset / DataLoader(含真实数据混合)
|
||||
@@ -12,8 +12,10 @@ CTC 训练通用逻辑
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
@@ -25,11 +27,21 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
"""设置全局随机种子,保证训练可复现。"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 准确率计算
|
||||
# ============================================================
|
||||
@@ -104,9 +116,12 @@ def train_ctc_model(
|
||||
config_key: TRAIN_CONFIG 中的键名
|
||||
"""
|
||||
cfg = TRAIN_CONFIG[config_key]
|
||||
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
|
||||
img_h, img_w = IMAGE_SIZE[config_key]
|
||||
device = get_device()
|
||||
|
||||
# 设置随机种子
|
||||
_set_seed()
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
@@ -139,11 +154,11 @@ def train_ctc_model(
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
@@ -166,12 +181,11 @@ def train_ctc_model(
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets, target_lengths, _ in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device)
|
||||
target_lengths = target_lengths.to(device)
|
||||
|
||||
logits = model(images) # (T, B, C)
|
||||
T, B, C = logits.shape
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
|
||||
# cuDNN CTC requires targets/lengths on CPU
|
||||
input_lengths = torch.full((B,), T, dtype=torch.int32)
|
||||
|
||||
log_probs = logits.log_softmax(2)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
|
||||
Reference in New Issue
Block a user