Align task API and add FunCaptcha support
This commit is contained in:
248
training/train_funcaptcha_rollball.py
Normal file
248
training/train_funcaptcha_rollball.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。
|
||||
|
||||
数据格式:
|
||||
data/real/funcaptcha/4_3d_rollball_animals/
|
||||
0_xxx.png
|
||||
1_xxx.jpg
|
||||
2_xxx.jpeg
|
||||
|
||||
文件名前缀表示正确候选索引。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import (
|
||||
CHECKPOINTS_DIR,
|
||||
FUN_CAPTCHA_TASKS,
|
||||
IMAGE_SIZE,
|
||||
RANDOM_SEED,
|
||||
TRAIN_CONFIG,
|
||||
get_device,
|
||||
)
|
||||
from inference.export_onnx import _load_and_export
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
from training.dataset import (
|
||||
FunCaptchaChallengeDataset,
|
||||
build_train_rgb_transform,
|
||||
build_val_rgb_transform,
|
||||
)
|
||||
|
||||
|
||||
QUESTION = "4_3d_rollball_animals"
|
||||
|
||||
|
||||
def _set_seed():
|
||||
random.seed(RANDOM_SEED)
|
||||
np.random.seed(RANDOM_SEED)
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(RANDOM_SEED)
|
||||
|
||||
|
||||
def _flatten_pairs(
|
||||
candidates: torch.Tensor,
|
||||
reference: torch.Tensor,
|
||||
answer_idx: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_candidates, channels, img_h, img_w = candidates.shape
|
||||
references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
|
||||
targets = F.one_hot(answer_idx, num_classes=num_candidates).float()
|
||||
return (
|
||||
candidates.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||||
references.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||||
targets.reshape(batch_size * num_candidates, 1),
|
||||
)
|
||||
|
||||
|
||||
def _evaluate(
|
||||
model: FunCaptchaSiamese,
|
||||
loader: DataLoader,
|
||||
device: torch.device,
|
||||
) -> tuple[float, float]:
|
||||
model.eval()
|
||||
challenge_correct = 0
|
||||
challenge_total = 0
|
||||
pair_correct = 0
|
||||
pair_total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for candidates, reference, answer_idx in loader:
|
||||
candidates = candidates.to(device)
|
||||
reference = reference.to(device)
|
||||
answer_idx = answer_idx.to(device)
|
||||
|
||||
pair_candidates, pair_reference, pair_targets = _flatten_pairs(
|
||||
candidates, reference, answer_idx
|
||||
)
|
||||
logits = model(pair_candidates, pair_reference).view(candidates.size(0), candidates.size(1))
|
||||
preds = logits.argmax(dim=1)
|
||||
challenge_correct += (preds == answer_idx).sum().item()
|
||||
challenge_total += answer_idx.size(0)
|
||||
|
||||
pair_probs = torch.sigmoid(logits)
|
||||
pair_preds = (pair_probs >= 0.5).float()
|
||||
target_matrix = pair_targets.view(candidates.size(0), candidates.size(1))
|
||||
pair_correct += (pair_preds == target_matrix).sum().item()
|
||||
pair_total += target_matrix.numel()
|
||||
|
||||
return (
|
||||
challenge_correct / max(challenge_total, 1),
|
||||
pair_correct / max(pair_total, 1),
|
||||
)
|
||||
|
||||
|
||||
def main(question: str = QUESTION):
|
||||
task_cfg = FUN_CAPTCHA_TASKS[question]
|
||||
cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
|
||||
img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
|
||||
device = get_device()
|
||||
data_dir = Path(task_cfg["data_dir"])
|
||||
ckpt_name = task_cfg["checkpoint_name"]
|
||||
ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"
|
||||
|
||||
_set_seed()
|
||||
|
||||
print("=" * 60)
|
||||
print(f"训练 FunCaptcha 专项模型 ({question})")
|
||||
print(f" 数据目录: {data_dir}")
|
||||
print(f" 候选数: {task_cfg['num_candidates']}")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print("=" * 60)
|
||||
|
||||
train_transform = build_train_rgb_transform(img_h, img_w)
|
||||
val_transform = build_val_rgb_transform(img_h, img_w)
|
||||
|
||||
full_dataset = FunCaptchaChallengeDataset(
|
||||
dirs=[data_dir],
|
||||
task_config=task_cfg,
|
||||
transform=train_transform,
|
||||
)
|
||||
total = len(full_dataset)
|
||||
if total == 0:
|
||||
raise FileNotFoundError(
|
||||
f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
|
||||
)
|
||||
|
||||
val_size = max(1, int(total * cfg["val_split"]))
|
||||
train_size = total - val_size
|
||||
if train_size <= 0:
|
||||
raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")
|
||||
|
||||
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
|
||||
val_ds_clean = FunCaptchaChallengeDataset(
|
||||
dirs=[data_dir],
|
||||
task_config=task_cfg,
|
||||
transform=val_transform,
|
||||
)
|
||||
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=cfg["batch_size"],
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean,
|
||||
batch_size=cfg["batch_size"],
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
|
||||
model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
|
||||
pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
|
||||
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
||||
|
||||
best_acc = 0.0
|
||||
start_epoch = 1
|
||||
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = float(ckpt.get("best_acc", 0.0))
|
||||
start_epoch = int(ckpt.get("epoch", 0)) + 1
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")
|
||||
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for candidates, reference, answer_idx in pbar:
|
||||
candidates = candidates.to(device)
|
||||
reference = reference.to(device)
|
||||
answer_idx = answer_idx.to(device)
|
||||
|
||||
pair_candidates, pair_reference, pair_targets = _flatten_pairs(
|
||||
candidates, reference, answer_idx
|
||||
)
|
||||
logits = model(pair_candidates, pair_reference)
|
||||
loss = criterion(logits, pair_targets)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
||||
|
||||
scheduler.step()
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
challenge_acc, pair_acc = _evaluate(model, val_loader, device)
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{cfg['epochs']} "
|
||||
f"loss={avg_loss:.4f} "
|
||||
f"acc={challenge_acc:.4f} "
|
||||
f"pair_acc={pair_acc:.4f} "
|
||||
f"lr={lr:.6f}"
|
||||
)
|
||||
|
||||
if challenge_acc >= best_acc:
|
||||
best_acc = challenge_acc
|
||||
torch.save(
|
||||
{
|
||||
"model_state_dict": model.state_dict(),
|
||||
"best_acc": best_acc,
|
||||
"epoch": epoch,
|
||||
"question": question,
|
||||
"num_candidates": task_cfg["num_candidates"],
|
||||
"tile_size": list(task_cfg["tile_size"]),
|
||||
"reference_box": list(task_cfg["reference_box"]),
|
||||
"answer_index_base": task_cfg["answer_index_base"],
|
||||
"input_shape": [task_cfg["channels"], img_h, img_w],
|
||||
},
|
||||
ckpt_path,
|
||||
)
|
||||
print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
|
||||
|
||||
print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
|
||||
_load_and_export(task_cfg["artifact_name"])
|
||||
return best_acc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user