"""
训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。

数据格式:

data/real/funcaptcha/4_3d_rollball_animals/
    0_xxx.png
    1_xxx.jpg
    2_xxx.jpeg

文件名前缀表示正确候选索引。
"""
|
||
|
||
from __future__ import annotations

# --- standard library ---
import random
from pathlib import Path

# --- third-party ---
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# --- project-local ---
from config import (
    CHECKPOINTS_DIR,
    FUN_CAPTCHA_TASKS,
    IMAGE_SIZE,
    RANDOM_SEED,
    TRAIN_CONFIG,
    get_device,
)
from inference.export_onnx import _load_and_export
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import (
    FunCaptchaChallengeDataset,
    build_train_rgb_transform,
    build_val_rgb_transform,
)


# FunCaptcha question key this script trains by default.
QUESTION = "4_3d_rollball_animals"
|
||
|
||
|
||
def _set_seed():
    """Seed Python, NumPy and torch RNGs with RANDOM_SEED for reproducibility."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(RANDOM_SEED)
    # CUDA has its own per-device generators; seed them all when available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
|
||
|
||
|
||
def _flatten_pairs(
|
||
candidates: torch.Tensor,
|
||
reference: torch.Tensor,
|
||
answer_idx: torch.Tensor,
|
||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
batch_size, num_candidates, channels, img_h, img_w = candidates.shape
|
||
references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
|
||
targets = F.one_hot(answer_idx, num_classes=num_candidates).float()
|
||
return (
|
||
candidates.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||
references.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||
targets.reshape(batch_size * num_candidates, 1),
|
||
)
|
||
|
||
|
||
def _evaluate(
    model: FunCaptchaSiamese,
    loader: DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Validate the model; return (challenge accuracy, per-pair accuracy)."""
    model.eval()
    n_challenges = 0
    n_challenge_hits = 0
    n_pairs = 0
    n_pair_hits = 0

    with torch.no_grad():
        for candidates, reference, answer_idx in loader:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)
            batch, n_cand = candidates.size(0), candidates.size(1)

            flat_cand, flat_ref, flat_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(flat_cand, flat_ref).view(batch, n_cand)

            # Challenge-level: argmax over candidates must match the answer.
            n_challenge_hits += (logits.argmax(dim=1) == answer_idx).sum().item()
            n_challenges += batch

            # Pair-level: each sigmoid score thresholded at 0.5.
            predicted = (torch.sigmoid(logits) >= 0.5).float()
            expected = flat_targets.view(batch, n_cand)
            n_pair_hits += (predicted == expected).sum().item()
            n_pairs += expected.numel()

    return (
        n_challenge_hits / max(n_challenges, 1),
        n_pair_hits / max(n_pairs, 1),
    )
|
||
|
||
|
||
def main(question: str = QUESTION) -> float:
    """Train the FunCaptcha Siamese model for *question* and export to ONNX.

    Supports resuming from an existing checkpoint, saves the best weights by
    challenge-level validation accuracy, then runs the ONNX exporter.

    Args:
        question: FunCaptcha task key; must exist in FUN_CAPTCHA_TASKS.

    Returns:
        Best challenge-level validation accuracy observed.

    Raises:
        FileNotFoundError: no training samples found under the data dir.
        ValueError: too few samples to carve out a validation split.
    """
    task_cfg = FUN_CAPTCHA_TASKS[question]
    # NOTE(review): config keys below are hard-coded to
    # "funcaptcha_rollball_animals" even though `question` is a parameter —
    # presumably all questions routed here share one training config; confirm
    # before reusing this entry point for other tasks.
    cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
    img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
    device = get_device()
    data_dir = Path(task_cfg["data_dir"])
    ckpt_name = task_cfg["checkpoint_name"]
    ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"

    # Seed all RNGs before the dataset split and weight init so runs repeat.
    _set_seed()

    print("=" * 60)
    print(f"训练 FunCaptcha 专项模型 ({question})")
    print(f" 数据目录: {data_dir}")
    print(f" 候选数: {task_cfg['num_candidates']}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    train_transform = build_train_rgb_transform(img_h, img_w)
    val_transform = build_val_rgb_transform(img_h, img_w)

    full_dataset = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise FileNotFoundError(
            f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
        )

    # Always reserve at least one sample for validation.
    val_size = max(1, int(total * cfg["val_split"]))
    train_size = total - val_size
    if train_size <= 0:
        raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")

    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # random_split subsets keep full_dataset's train (augmenting) transform,
    # so rebuild the validation subset on a dataset instance that applies the
    # deterministic val transform, reusing the split's sample indices.
    val_ds_clean = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # _flatten_pairs yields one positive pair per challenge vs.
    # (num_candidates - 1) negatives; weight positives up to balance the BCE.
    pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_acc = 0.0
    start_epoch = 1

    # Resume from an existing checkpoint if present.
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = float(ckpt.get("best_acc", 0.0))
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        # Replay scheduler steps so the LR matches the resumed epoch
        # (optimizer/scheduler states are not stored in the checkpoint).
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")

    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for candidates, reference, answer_idx in pbar:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)

            # Expand each challenge into per-(candidate, reference) pairs.
            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference)
            loss = criterion(logits, pair_targets)

            optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        challenge_acc, pair_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={challenge_acc:.4f} "
            f"pair_acc={pair_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # `>=` keeps the most recent weights when accuracy ties.
        if challenge_acc >= best_acc:
            best_acc = challenge_acc
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "question": question,
                    # Inference-time metadata bundled with the weights.
                    "num_candidates": task_cfg["num_candidates"],
                    "tile_size": list(task_cfg["tile_size"]),
                    "reference_box": list(task_cfg["reference_box"]),
                    "answer_index_base": task_cfg["answer_index_base"],
                    "input_shape": [task_cfg["channels"], img_h, img_w],
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
    # Export the saved checkpoint to ONNX for deployment.
    _load_and_export(task_cfg["artifact_name"])
    return best_acc
|
||
|
||
|
||
# CLI entry point: train the default `4_3d_rollball_animals` task.
if __name__ == "__main__":
    main()
|