"""
Train the dispatch classifier (CaptchaClassifier).

Builds a class-balanced mix of synthetic samples from every captcha type and
trains a classifier that distinguishes normal / math / 3d_text / 3d_rotate /
3d_slider.

Data source: the data/classifier/ directory (one sub-directory per type).

Usage:
    python -m training.train_classifier
"""

import os
import random
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CAPTCHA_TYPES,
    NUM_CAPTCHA_TYPES,
    IMAGE_SIZE,
    TRAIN_CONFIG,
    GENERATE_CONFIG,
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
    REGRESSION_RANGE,
    CLASSIFIER_DIR,
    SYNTHETIC_NORMAL_DIR,
    SYNTHETIC_MATH_DIR,
    SYNTHETIC_3D_TEXT_DIR,
    SYNTHETIC_3D_ROTATE_DIR,
    SYNTHETIC_3D_SLIDER_DIR,
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    RANDOM_SEED,
    get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from inference.model_metadata import write_model_metadata
from models.classifier import CaptchaClassifier
from training.data_fingerprint import (
    build_dataset_spec,
    ensure_synthetic_dataset,
    labels_cover_tokens,
)
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform


def _build_config_snapshot(cls_name, chars_map):
    """Return the config subset used to fingerprint the synthetic set of *cls_name*.

    Key insertion order (generate_config, image_size, chars, label_range) is
    kept stable in case the fingerprint depends on it.
    """
    snapshot = {
        "generate_config": GENERATE_CONFIG[cls_name],
        "image_size": IMAGE_SIZE[cls_name],
    }
    if cls_name in chars_map:
        snapshot["chars"] = chars_map[cls_name]
    if cls_name in REGRESSION_RANGE:
        snapshot["label_range"] = REGRESSION_RANGE[cls_name]
    return snapshot


def _rebuild_class_dir(cls_name, sources):
    """Recreate CLASSIFIER_DIR/<cls_name> so it holds exactly *sources*.

    Each source PNG is symlinked into the class directory. If symlinking raises
    OSError (e.g. unsupported on the platform), the file is copied instead.
    """
    cls_dir = CLASSIFIER_DIR / cls_name
    cls_dir.mkdir(parents=True, exist_ok=True)

    # The classifier directory is derived data: wipe it on every run so it
    # matches the current source data and fingerprint state.
    for stale in cls_dir.glob("*.png"):
        stale.unlink()

    for src in tqdm(sources, desc=f"准备 {cls_name}", leave=False):
        dst = cls_dir / src.name
        # Symlinks save disk space; fall back to a real copy if they fail.
        try:
            dst.symlink_to(src.resolve())
        except OSError:
            shutil.copy2(src, dst)


def _prepare_classifier_data():
    """
    Prepare the classifier training data.

    For every captcha type, make sure its synthetic dataset exists (generating
    or refreshing it via ``ensure_synthetic_dataset`` when needed). Then link or
    copy the same number of samples (``synthetic_samples // NUM_CAPTCHA_TYPES``)
    into ``CLASSIFIER_DIR/<type>/`` so the classes are balanced.
    """
    cfg = TRAIN_CONFIG["classifier"]
    per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES

    # Per type: (class name, synthetic data directory, generator class)
    type_info = [
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
        ("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
        ("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
    ]
    chars_map = {
        "normal": NORMAL_CHARS,
        "math": MATH_CHARS,
        "3d_text": THREED_CHARS,
    }

    for cls_name, syn_dir, gen_cls in type_info:
        syn_dir = Path(syn_dir)
        config_snapshot = _build_config_snapshot(cls_name, chars_map)

        validator = None
        if cls_name == "math":
            # The math set must contain labels covering every configured operator.
            required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
            validator = lambda files, tokens=required_ops: labels_cover_tokens(files, tokens)

        dataset_state = ensure_synthetic_dataset(
            syn_dir,
            generator_cls=gen_cls,
            spec=build_dataset_spec(
                gen_cls,
                config_key=cls_name,
                config_snapshot=config_snapshot,
            ),
            gen_count=per_class,
            min_count=per_class,
            validator=validator,
            # NOTE(review): presumably lets existing normal/math data without a
            # fingerprint be adopted as-is (see the "adopted" branch below).
            adopt_if_missing=cls_name in {"normal", "math"},
        )
        if dataset_state["refreshed"]:
            print(f"[数据] {cls_name} 合成数据已刷新: {dataset_state['sample_count']} 张")
        elif dataset_state["adopted"]:
            print(f"[数据] {cls_name} 合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
        else:
            print(f"[数据] {cls_name} 合成数据已就绪: {dataset_state['sample_count']} 张")

        # Sorted so the selected subset is deterministic across runs.
        selected = sorted(syn_dir.glob("*.png"))[:per_class]
        _rebuild_class_dir(cls_name, selected)
        print(f"[数据] {cls_name} 分类器数据就绪: {len(selected)} 张")


def _seed_everything(seed):
    """Seed Python, NumPy and torch (plus all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _build_dataloaders(cfg, img_h, img_w, pin_memory):
    """Split CLASSIFIER_DIR into train/val loaders.

    The training subset uses the augmenting train transform. The validation
    subset reuses the same split indices but with the non-augmenting val
    transform.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        RuntimeError: If the classifier directory contains no samples.
    """
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        # Otherwise the shuffling sampler fails later with an unrelated error.
        raise RuntimeError(f"[数据] 分类器数据为空: {CLASSIFIER_DIR}")

    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation set without augmentation: same samples as val_ds, clean transform.
    val_ds_clean = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    return train_loader, val_loader


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over *loader* and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix(loss=f"{batch_loss:.4f}")
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, device):
    """Return top-1 accuracy of *model* on *loader* (0.0 for an empty loader)."""
    model.eval()
    correct = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total_val += labels.size(0)
    return correct / max(total_val, 1)


def _export_onnx(model, ckpt_path, img_h, img_w):
    """Load the best checkpoint into *model* (moved to CPU) and export it to ONNX.

    The export traces a single-channel ``(1, 1, img_h, img_w)`` dummy input.
    Model metadata is written alongside the ONNX file.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    onnx_path = ONNX_DIR / "classifier.onnx"
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    dummy = torch.randn(1, 1, img_h, img_w)

    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    write_model_metadata(
        onnx_path,
        {
            "model_name": "classifier",
            "task": "classifier",
            "class_names": list(CAPTCHA_TYPES),
            "input_shape": [1, img_h, img_w],
        },
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def main():
    """Train the classifier, keep the best checkpoint, and export it to ONNX.

    Returns:
        The best validation accuracy reached during training.

    Raises:
        RuntimeError: If no classifier samples are available, or if no
            checkpoint was written during this run (e.g. ``epochs`` is 0).
    """
    cfg = TRAIN_CONFIG["classifier"]
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()

    # Seed before data preparation so the split and model init are reproducible.
    _seed_everything(RANDOM_SEED)

    print("=" * 60)
    print("训练调度分类器 (CaptchaClassifier)")
    print(f" 类别: {CAPTCHA_TYPES}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    # ---- 1. Prepare data ----
    _prepare_classifier_data()

    # ---- 2. Build datasets ----
    # Pinned host memory only helps CUDA transfers; elsewhere it just warns.
    pin_memory = torch.device(device).type == "cuda"
    train_loader, val_loader = _build_dataloaders(cfg, img_h, img_w, pin_memory)

    # ---- 3. Model / optimizer / scheduler ----
    model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    criterion = nn.CrossEntropyLoss()

    ckpt_path = CHECKPOINTS_DIR / "classifier.pth"
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    best_acc = 0.0
    # Epoch whose weights are in ckpt_path for THIS run; None = nothing saved yet.
    # Without this, a run whose accuracy never exceeds 0.0 would export a stale
    # checkpoint left over from a previous run (or crash on a missing file).
    saved_epoch = None

    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        avg_loss = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            device,
            desc=f"Epoch {epoch}/{cfg['epochs']}",
        )
        # Read the lr actually used for this epoch BEFORE stepping the scheduler.
        lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # ---- 5. Validation ----
        val_acc = _evaluate(model, val_loader, device)
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={val_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save best model ----
        if saved_epoch is None or val_acc > best_acc:
            best_acc = val_acc
            saved_epoch = epoch
            torch.save({
                "model_state_dict": model.state_dict(),
                "class_names": CAPTCHA_TYPES,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    if saved_epoch is None:
        raise RuntimeError(
            f"No checkpoint was saved during this run (epochs={cfg['epochs']}); "
            f"refusing to export {ckpt_path}"
        )
    _export_onnx(model, ckpt_path, img_h, img_w)

    return best_acc


if __name__ == "__main__":
    main()