""" 训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。 数据格式: data/real/funcaptcha/4_3d_rollball_animals/ 0_xxx.png 1_xxx.jpg 2_xxx.jpeg 文件名前缀表示正确候选索引。 """ from __future__ import annotations import random from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, random_split from tqdm import tqdm from config import ( CHECKPOINTS_DIR, FUN_CAPTCHA_TASKS, IMAGE_SIZE, RANDOM_SEED, TRAIN_CONFIG, get_device, ) from inference.export_onnx import _load_and_export from models.fun_captcha_siamese import FunCaptchaSiamese from training.dataset import ( FunCaptchaChallengeDataset, build_train_rgb_transform, build_val_rgb_transform, ) QUESTION = "4_3d_rollball_animals" def _set_seed(): random.seed(RANDOM_SEED) np.random.seed(RANDOM_SEED) torch.manual_seed(RANDOM_SEED) if torch.cuda.is_available(): torch.cuda.manual_seed_all(RANDOM_SEED) def _flatten_pairs( candidates: torch.Tensor, reference: torch.Tensor, answer_idx: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, num_candidates, channels, img_h, img_w = candidates.shape references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1) targets = F.one_hot(answer_idx, num_classes=num_candidates).float() return ( candidates.reshape(batch_size * num_candidates, channels, img_h, img_w), references.reshape(batch_size * num_candidates, channels, img_h, img_w), targets.reshape(batch_size * num_candidates, 1), ) def _evaluate( model: FunCaptchaSiamese, loader: DataLoader, device: torch.device, ) -> tuple[float, float]: model.eval() challenge_correct = 0 challenge_total = 0 pair_correct = 0 pair_total = 0 with torch.no_grad(): for candidates, reference, answer_idx in loader: candidates = candidates.to(device) reference = reference.to(device) answer_idx = answer_idx.to(device) pair_candidates, pair_reference, pair_targets = _flatten_pairs( candidates, reference, answer_idx ) logits = model(pair_candidates, pair_reference).view(candidates.size(0), candidates.size(1)) preds = logits.argmax(dim=1) challenge_correct += (preds == answer_idx).sum().item() challenge_total += answer_idx.size(0) pair_probs = torch.sigmoid(logits) pair_preds = (pair_probs >= 0.5).float() target_matrix = pair_targets.view(candidates.size(0), candidates.size(1)) pair_correct += (pair_preds == target_matrix).sum().item() pair_total += target_matrix.numel() return ( challenge_correct / max(challenge_total, 1), pair_correct / max(pair_total, 1), ) def main(question: str = QUESTION): task_cfg = FUN_CAPTCHA_TASKS[question] cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"] img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"] device = get_device() data_dir = Path(task_cfg["data_dir"]) ckpt_name = task_cfg["checkpoint_name"] ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth" _set_seed() print("=" * 60) print(f"训练 FunCaptcha 专项模型 ({question})") print(f" 数据目录: {data_dir}") print(f" 候选数: {task_cfg['num_candidates']}") print(f" 输入尺寸: {img_h}×{img_w}") print("=" * 60) train_transform = build_train_rgb_transform(img_h, img_w) val_transform = build_val_rgb_transform(img_h, img_w) full_dataset = FunCaptchaChallengeDataset( dirs=[data_dir], task_config=task_cfg, transform=train_transform, ) total = len(full_dataset) if total == 0: raise FileNotFoundError( f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}" ) val_size = max(1, int(total * cfg["val_split"])) train_size = total - val_size if train_size <= 0: raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}") train_ds, val_ds = random_split(full_dataset, [train_size, val_size]) val_ds_clean = FunCaptchaChallengeDataset( dirs=[data_dir], task_config=task_cfg, transform=val_transform, ) val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices] train_loader = DataLoader( train_ds, batch_size=cfg["batch_size"], shuffle=True, num_workers=0, pin_memory=True, ) val_loader = DataLoader( val_ds_clean, batch_size=cfg["batch_size"], shuffle=False, num_workers=0, pin_memory=True, ) print(f"[数据] 训练: {train_size} 验证: {val_size}") model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"]) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"]) pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device) criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) best_acc = 0.0 start_epoch = 1 if ckpt_path.exists(): ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) model.load_state_dict(ckpt["model_state_dict"]) best_acc = float(ckpt.get("best_acc", 0.0)) start_epoch = int(ckpt.get("epoch", 0)) + 1 for _ in range(start_epoch - 1): scheduler.step() print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}") for epoch in range(start_epoch, cfg["epochs"] + 1): model.train() total_loss = 0.0 num_batches = 0 pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False) for candidates, reference, answer_idx in pbar: candidates = candidates.to(device) reference = reference.to(device) answer_idx = answer_idx.to(device) pair_candidates, pair_reference, pair_targets = _flatten_pairs( candidates, reference, answer_idx ) logits = model(pair_candidates, pair_reference) loss = criterion(logits, pair_targets) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) optimizer.step() total_loss += loss.item() num_batches += 1 pbar.set_postfix(loss=f"{loss.item():.4f}") scheduler.step() avg_loss = total_loss / max(num_batches, 1) challenge_acc, pair_acc = _evaluate(model, val_loader, device) lr = scheduler.get_last_lr()[0] print( f"Epoch {epoch:3d}/{cfg['epochs']} " f"loss={avg_loss:.4f} " f"acc={challenge_acc:.4f} " f"pair_acc={pair_acc:.4f} " f"lr={lr:.6f}" ) if challenge_acc >= best_acc: best_acc = challenge_acc torch.save( { "model_state_dict": model.state_dict(), "best_acc": best_acc, "epoch": epoch, "question": question, "num_candidates": task_cfg["num_candidates"], "tile_size": list(task_cfg["tile_size"]), "reference_box": list(task_cfg["reference_box"]), "answer_index_base": task_cfg["answer_index_base"], "input_shape": [task_cfg["channels"], img_h, img_w], }, ckpt_path, ) print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}") print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}") _load_and_export(task_cfg["artifact_name"]) return best_acc if __name__ == "__main__": main()