Files
CaptchBreaker/training/train_funcaptcha_rollball.py

249 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。
数据格式:
data/real/funcaptcha/4_3d_rollball_animals/
0_xxx.png
1_xxx.jpg
2_xxx.jpeg
文件名前缀表示正确候选索引。
"""
from __future__ import annotations
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
FUN_CAPTCHA_TASKS,
IMAGE_SIZE,
RANDOM_SEED,
TRAIN_CONFIG,
get_device,
)
from inference.export_onnx import _load_and_export
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import (
FunCaptchaChallengeDataset,
build_train_rgb_transform,
build_val_rgb_transform,
)
# Default FunCaptcha task key: `main` uses it as the default `question`
# argument and looks it up in FUN_CAPTCHA_TASKS.
QUESTION = "4_3d_rollball_animals"
def _set_seed():
    """Seed the Python, NumPy and PyTorch RNGs with ``RANDOM_SEED``.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(RANDOM_SEED)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(RANDOM_SEED)
def _flatten_pairs(
    candidates: torch.Tensor,
    reference: torch.Tensor,
    answer_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Turn a batch of challenges into a flat batch of (candidate, reference) pairs.

    Args:
        candidates: 5-D tensor unpacked as
            ``(batch, num_candidates, channels, height, width)``.
        reference: one reference image per challenge; it gains a candidate
            axis and is broadcast so every candidate is paired with the
            reference of its own challenge.
        answer_idx: integer index of the correct candidate per challenge
            (passed to ``F.one_hot``).

    Returns:
        A tuple ``(pair_candidates, pair_references, pair_targets)`` whose
        leading dimension is ``batch * num_candidates``. ``pair_targets`` has
        shape ``(batch * num_candidates, 1)`` and holds ``1.0`` for the
        correct candidate of each challenge and ``0.0`` elsewhere.
    """
    batch_size, num_candidates, channels, img_h, img_w = candidates.shape
    num_pairs = batch_size * num_candidates
    image_shape = (num_pairs, channels, img_h, img_w)

    # Broadcast the reference across the candidate axis; reshape materialises it.
    tiled_reference = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
    one_hot = F.one_hot(answer_idx, num_classes=num_candidates).float()

    flat_candidates = candidates.reshape(image_shape)
    flat_reference = tiled_reference.reshape(image_shape)
    flat_targets = one_hot.reshape(num_pairs, 1)
    return flat_candidates, flat_reference, flat_targets
def _evaluate(
    model: FunCaptchaSiamese,
    loader: DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Score ``model`` on ``loader`` without tracking gradients.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: network called as ``model(pair_candidates, pair_reference)``;
            its output is viewed as ``(batch, num_candidates)`` logits.
        loader: iterable yielding ``(candidates, reference, answer_idx)``.
        device: device every batch tensor is moved to before the forward pass.

    Returns:
        ``(challenge_accuracy, pair_accuracy)``. Challenge accuracy is the
        share of challenges whose highest-logit candidate equals
        ``answer_idx``; pair accuracy is the share of individual
        candidate/reference pairs whose ``sigmoid(logit) >= 0.5`` decision
        matches the one-hot target. Both are ``0.0`` for an empty loader.
    """
    model.eval()
    hits_by_challenge = 0
    seen_challenges = 0
    hits_by_pair = 0
    seen_pairs = 0

    with torch.no_grad():
        for cand_batch, ref_batch, answers in loader:
            cand_batch = cand_batch.to(device)
            ref_batch = ref_batch.to(device)
            answers = answers.to(device)
            batch_size, num_candidates = cand_batch.size(0), cand_batch.size(1)

            flat_cand, flat_ref, flat_targets = _flatten_pairs(
                cand_batch, ref_batch, answers
            )
            scores = model(flat_cand, flat_ref).view(batch_size, num_candidates)

            # Challenge level: the best-scoring candidate must be the answer.
            top_choice = scores.argmax(dim=1)
            hits_by_challenge += (top_choice == answers).sum().item()
            seen_challenges += answers.size(0)

            # Pair level: every candidate is judged independently at 0.5.
            binary_preds = (torch.sigmoid(scores) >= 0.5).float()
            expected = flat_targets.view(batch_size, num_candidates)
            hits_by_pair += (binary_preds == expected).sum().item()
            seen_pairs += expected.numel()

    challenge_acc = hits_by_challenge / max(seen_challenges, 1)
    pair_acc = hits_by_pair / max(seen_pairs, 1)
    return challenge_acc, pair_acc
def main(question: str = QUESTION) -> float:
    """Train, checkpoint and export the FunCaptcha Siamese model for ``question``.

    The function seeds the RNGs, splits the dataset into train/validation
    subsets, optionally resumes from an existing checkpoint, trains with a
    positively-weighted BCE loss, saves the best checkpoint by validation
    challenge accuracy and finally calls the ONNX export helper.

    Args:
        question: key into ``FUN_CAPTCHA_TASKS`` selecting the task config
            (data directory, checkpoint name, candidate count, ...). Note
            that the training hyper-parameters and the image size are always
            read from the ``"funcaptcha_rollball_animals"`` entries of
            ``TRAIN_CONFIG`` / ``IMAGE_SIZE``, regardless of ``question``.

    Returns:
        The best validation challenge accuracy seen (including the value
        restored from a checkpoint when resuming).

    Raises:
        FileNotFoundError: if the dataset contains no samples.
        ValueError: if no training samples remain after the validation split.
    """
    task_cfg = FUN_CAPTCHA_TASKS[question]
    cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
    img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
    device = get_device()
    data_dir = Path(task_cfg["data_dir"])
    ckpt_name = task_cfg["checkpoint_name"]
    ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"

    # Seed before anything consumes randomness so the split below and the
    # model initialisation are reproducible (and identical when resuming).
    _set_seed()

    print("=" * 60)
    print(f"训练 FunCaptcha 专项模型 ({question})")
    print(f" 数据目录: {data_dir}")
    print(f" 候选数: {task_cfg['num_candidates']}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    train_transform = build_train_rgb_transform(img_h, img_w)
    val_transform = build_val_rgb_transform(img_h, img_w)

    full_dataset = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise FileNotFoundError(
            f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
        )

    # Always hold out at least one sample for validation.
    val_size = max(1, int(total * cfg["val_split"]))
    train_size = total - val_size
    if train_size <= 0:
        raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")

    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # `val_ds` shares the training transform of `full_dataset`, so build a
    # second dataset with the validation transform and restrict its sample
    # list to the held-out indices.
    val_ds_clean = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])

    # Each challenge contributes one positive pair and (num_candidates - 1)
    # negative pairs; weight the positives to offset that imbalance.
    pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_acc = 0.0
    start_epoch = 1
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = float(ckpt.get("best_acc", 0.0))
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        # The checkpoint stores no optimizer/scheduler state, so replay the
        # schedule to bring the learning rate to the resumed epoch.
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")

    # torch.save does not create missing directories; make sure the target
    # exists up front instead of failing after the first full epoch.
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for candidates, reference, answer_idx in pbar:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)
            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference)
            loss = criterion(logits, pair_targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        challenge_acc, pair_acc = _evaluate(model, val_loader, device)
        # Read after scheduler.step(): this is the rate the next epoch will use.
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={challenge_acc:.4f} "
            f"pair_acc={pair_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # `>=` means a tie also overwrites the checkpoint (latest epoch wins).
        if challenge_acc >= best_acc:
            best_acc = challenge_acc
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "question": question,
                    "num_candidates": task_cfg["num_candidates"],
                    "tile_size": list(task_cfg["tile_size"]),
                    "reference_box": list(task_cfg["reference_box"]),
                    "answer_index_base": task_cfg["answer_index_base"],
                    "input_shape": [task_cfg["channels"], img_h, img_w],
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
    _load_and_export(task_cfg["artifact_name"])
    return best_acc


if __name__ == "__main__":
    main()