Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
249 lines
7.8 KiB
Python
249 lines
7.8 KiB
Python
"""
|
||
训练调度分类器 (CaptchaClassifier)
|
||
|
||
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d_text / 3d_rotate / 3d_slider。
|
||
数据来源: data/classifier/ 目录 (按类型子目录组织)
|
||
|
||
用法: python -m training.train_classifier
|
||
"""
|
||
|
||
import os
|
||
import random
|
||
import shutil
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.utils.data import DataLoader, random_split
|
||
from tqdm import tqdm
|
||
|
||
from config import (
|
||
CAPTCHA_TYPES,
|
||
NUM_CAPTCHA_TYPES,
|
||
IMAGE_SIZE,
|
||
TRAIN_CONFIG,
|
||
CLASSIFIER_DIR,
|
||
SYNTHETIC_NORMAL_DIR,
|
||
SYNTHETIC_MATH_DIR,
|
||
SYNTHETIC_3D_TEXT_DIR,
|
||
SYNTHETIC_3D_ROTATE_DIR,
|
||
SYNTHETIC_3D_SLIDER_DIR,
|
||
CHECKPOINTS_DIR,
|
||
ONNX_DIR,
|
||
ONNX_CONFIG,
|
||
RANDOM_SEED,
|
||
get_device,
|
||
)
|
||
from generators.normal_gen import NormalCaptchaGenerator
|
||
from generators.math_gen import MathCaptchaGenerator
|
||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||
from models.classifier import CaptchaClassifier
|
||
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
|
||
|
||
|
||
def _prepare_classifier_data():
    """
    Prepare the classifier training data under ``CLASSIFIER_DIR / <type>/``.

    For every captcha type, images from that type's synthetic-data directory are
    symlinked (or copied, when a symlink cannot be created) into the classifier
    directory. Every class receives the same quota of
    ``TRAIN_CONFIG["classifier"]["synthetic_samples"] // NUM_CAPTCHA_TYPES``
    images so the classes stay balanced.

    A class whose classifier directory already holds at least the quota of
    usable images is skipped entirely. Otherwise, if the synthetic directory has
    fewer images than the quota, the type's generator is run first. The
    classifier directory is then cleared and refilled.
    """
    cfg = TRAIN_CONFIG["classifier"]
    per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES

    # Per type: (class name, synthetic-data directory, generator class)
    type_info = [
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
        ("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
        ("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
    ]

    for cls_name, syn_dir, gen_cls in type_info:
        cls_dir = CLASSIFIER_DIR / cls_name
        cls_dir.mkdir(parents=True, exist_ok=True)

        # Check readiness before touching the synthetic data, so an already
        # complete class never triggers (expensive) regeneration.
        # Path.exists() follows symlinks: links left dangling after the
        # synthetic directory was cleaned or regenerated are not counted.
        already = sum(1 for f in cls_dir.glob("*.png") if f.exists())
        if already >= per_class:
            print(f"[数据] {cls_name} 分类器数据已就绪: {already} 张")
            continue

        syn_dir = Path(syn_dir)
        existing = sorted(syn_dir.glob("*.png"))

        # Not enough synthetic images: generate some first.
        if len(existing) < per_class:
            print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
            gen = gen_cls()
            gen.generate_dataset(per_class, str(syn_dir))
            existing = sorted(syn_dir.glob("*.png"))

        # Clear stale entries (including dangling symlinks) before relinking.
        for f in cls_dir.glob("*.png"):
            f.unlink()

        selected = existing[:per_class]
        if len(selected) < per_class:
            print(f"[数据] 警告: {cls_name} 合成数据仍不足,仅有 {len(selected)}/{per_class} 张")

        for f in tqdm(selected, desc=f"准备 {cls_name}", leave=False):
            dst = cls_dir / f.name
            # Prefer a symlink to save disk space; fall back to a copy when the
            # platform or user cannot create symlinks (OSError).
            try:
                dst.symlink_to(f.resolve())
            except OSError:
                shutil.copy2(f, dst)

        print(f"[数据] {cls_name} 分类器数据就绪: {len(selected)} 张")
|
||
|
||
|
||
def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _build_loaders(cfg, img_h, img_w, pin_memory):
    """
    Build the train / validation DataLoaders over ``CLASSIFIER_DIR``.

    The split is drawn with ``random_split`` from the global torch RNG, so it is
    reproducible once the RNGs have been seeded. Validation samples are served
    through a second ``CaptchaDataset`` that uses the non-augmenting validation
    transform.

    Returns:
        tuple: ``(train_loader, val_loader, train_size, val_size)``.

    Raises:
        RuntimeError: if no classifier images were found.
    """
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise RuntimeError(f"No classifier training images found under {CLASSIFIER_DIR}")

    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation set without augmentation: same samples as val_ds, different transform.
    # Relies on CaptchaDataset serving items from its `.samples` list.
    val_ds_clean = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=pin_memory,
    )
    return train_loader, val_loader, train_size, val_size


def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one training epoch; return the mean per-batch loss (0.0 if ``loader`` is empty)."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` (0.0 if ``loader`` is empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / max(total, 1)


def _export_onnx(model, onnx_path, img_h, img_w):
    """Move ``model`` to CPU in eval mode and export it to ONNX at ``onnx_path``."""
    model = model.cpu().eval()
    # Single-channel dummy input of the classifier's configured size.
    dummy = torch.randn(1, 1, img_h, img_w)
    dynamic_axes = (
        {"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None
    )
    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def main():
    """
    Train the dispatch classifier (CaptchaClassifier) and export it to ONNX.

    Steps: seed all RNGs, prepare the balanced classifier dataset, train with
    Adam and cosine-annealed LR, and keep the checkpoint with the best
    validation accuracy. That checkpoint is then reloaded and exported.

    Returns:
        float: best validation accuracy reached during training.

    Raises:
        RuntimeError: if no classifier images are available, or if no
            checkpoint was written (e.g. ``epochs`` is 0).
    """
    cfg = TRAIN_CONFIG["classifier"]
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()
    # Pinned host memory only benefits host->CUDA copies; elsewhere PyTorch just warns.
    pin_memory = torch.device(device).type == "cuda"

    _seed_everything(RANDOM_SEED)

    print("=" * 60)
    print("训练调度分类器 (CaptchaClassifier)")
    print(f"  类别: {CAPTCHA_TYPES}")
    print(f"  输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    # ---- 1. Prepare data ----
    _prepare_classifier_data()

    # ---- 2. Build datasets / loaders ----
    train_loader, val_loader, train_size, val_size = _build_loaders(
        cfg, img_h, img_w, pin_memory
    )
    print(f"[数据] 训练: {train_size}  验证: {val_size}")

    # ---- 3. Model / optimizer / scheduler ----
    model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    criterion = nn.CrossEntropyLoss()

    CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
    ckpt_path = CHECKPOINTS_DIR / "classifier.pth"
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. With 0.0, a run that never beats 0% (e.g. an empty validation
    # split) saved nothing and the export step loaded a stale checkpoint from
    # an earlier run, or crashed.
    best_acc = -1.0

    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f"Epoch {epoch}/{cfg['epochs']}",
        )
        scheduler.step()

        # ---- 5. Validation ----
        val_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']}  "
            f"loss={avg_loss:.4f}  "
            f"acc={val_acc:.4f}  "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save best model ----
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "class_names": CAPTCHA_TYPES,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f"  → 保存最佳模型 acc={best_acc:.4f}  {ckpt_path}")

    if best_acc < 0:
        raise RuntimeError(
            "No checkpoint was saved during training (is epochs 0?); cannot export ONNX"
        )

    # ---- 7. Export ONNX from the best checkpoint ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])

    ONNX_DIR.mkdir(parents=True, exist_ok=True)
    onnx_path = ONNX_DIR / "classifier.onnx"
    _export_onnx(model, onnx_path, img_h, img_w)

    return best_acc
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point (python -m training.train_classifier); the returned
    # best accuracy is not used here.
    main()
|