Files
CaptchBreaker/training/train_classifier.py
Hua f5be7671bc Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models:
- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated
to 50000 (10000 per class). New generators, model, dataset, training
utilities, and full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 13:55:53 +08:00

249 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
训练调度分类器 (CaptchaClassifier)
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d_text / 3d_rotate / 3d_slider。
数据来源: data/classifier/ 目录 (按类型子目录组织)
用法: python -m training.train_classifier
"""
import os
import random
import shutil
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CAPTCHA_TYPES,
NUM_CAPTCHA_TYPES,
IMAGE_SIZE,
TRAIN_CONFIG,
CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_TEXT_DIR,
SYNTHETIC_3D_ROTATE_DIR,
SYNTHETIC_3D_SLIDER_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
def _prepare_classifier_data():
    """
    Prepare balanced training data for the dispatch classifier.

    Strategy: symlink (or copy, when symlinking fails) samples from each
    type's synthetic-data directory into data/classifier/{type}/, taking
    the same number per class so the classes stay balanced.
    If a type's synthetic data is missing or insufficient, it is generated
    first via that type's generator class.
    """
    cfg = TRAIN_CONFIG["classifier"]
    # Equal share per class keeps the classifier's training set balanced.
    per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES
    # Per type: (class name, synthetic data directory, generator class)
    type_info = [
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
        ("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
        ("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
    ]
    for cls_name, syn_dir, gen_cls in type_info:
        syn_dir = Path(syn_dir)
        existing = sorted(syn_dir.glob("*.png"))
        # Not enough synthetic samples on disk — generate the shortfall.
        if len(existing) < per_class:
            print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
            gen = gen_cls()
            gen.generate_dataset(per_class, str(syn_dir))
            existing = sorted(syn_dir.glob("*.png"))
        # Materialize this class's slice under the classifier directory.
        cls_dir = CLASSIFIER_DIR / cls_name
        cls_dir.mkdir(parents=True, exist_ok=True)
        already = len(list(cls_dir.glob("*.png")))
        if already >= per_class:
            print(f"[数据] {cls_name} 分类器数据已就绪: {already}")
            continue
        # Clear stale entries, then re-link a fresh deterministic selection
        # (sorted order makes reruns reproducible).
        for f in cls_dir.glob("*.png"):
            f.unlink()
        selected = existing[:per_class]
        if len(selected) < per_class:
            # Generation fell short of the target; warn loudly so a class
            # imbalance in the classifier's training set is not silent.
            print(f"[警告] {cls_name} 样本不足: {len(selected)}/{per_class},类别将不平衡")
        for f in tqdm(selected, desc=f"准备 {cls_name}", leave=False):
            dst = cls_dir / f.name
            # Symlink to save disk space; fall back to a real copy on
            # failure (e.g. Windows without symlink privileges).
            try:
                dst.symlink_to(f.resolve())
            except OSError:
                shutil.copy2(f, dst)
        print(f"[数据] {cls_name} 分类器数据就绪: {len(selected)}")
def main():
    """Train the dispatch classifier, keep the best checkpoint, export ONNX.

    Returns:
        float: best validation accuracy reached during training.
    """
    cfg = TRAIN_CONFIG["classifier"]
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()
    # Seed every RNG source so the split and augmentation are reproducible.
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    print("=" * 60)
    print("训练调度分类器 (CaptchaClassifier)")
    print(f" 类别: {CAPTCHA_TYPES}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)
    # ---- 1. Prepare data ----
    _prepare_classifier_data()
    # ---- 2. Build datasets ----
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)
    full_dataset = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation uses the same held-out samples but WITHOUT augmentation:
    # re-wrap them in a clean dataset carrying the val transform.
    val_ds_clean = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Model / optimizer / scheduler ----
    model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    # Ensure the output directory exists before the first torch.save.
    CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
    ckpt_path = CHECKPOINTS_DIR / "classifier.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        correct = 0
        total_val = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images)
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total_val += labels.size(0)
        val_acc = correct / max(total_val, 1)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={val_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save best model ----
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "class_names": CAPTCHA_TYPES,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    if not ckpt_path.exists():
        # Validation accuracy never exceeded 0.0 (e.g. empty val split),
        # so no checkpoint was written during training. Save the final
        # weights so the load/export below cannot crash.
        torch.save({
            "model_state_dict": model.state_dict(),
            "class_names": CAPTCHA_TYPES,
            "best_acc": best_acc,
            "epoch": cfg["epochs"],
        }, ckpt_path)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    # Ensure the export directory exists before writing the ONNX file.
    ONNX_DIR.mkdir(parents=True, exist_ok=True)
    onnx_path = ONNX_DIR / "classifier.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
    return best_acc
# Script entry point: run training only when executed directly
# (python -m training.train_classifier), not when imported.
if __name__ == "__main__":
    main()