Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -0,0 +1,248 @@
"""
Train the dedicated Siamese model for FunCaptcha `4_3d_rollball_animals`.

Data layout:
    data/real/funcaptcha/4_3d_rollball_animals/
        0_xxx.png
        1_xxx.jpg
        2_xxx.jpeg

The filename prefix is the index of the correct candidate.
"""
from __future__ import annotations
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
FUN_CAPTCHA_TASKS,
IMAGE_SIZE,
RANDOM_SEED,
TRAIN_CONFIG,
get_device,
)
from inference.export_onnx import _load_and_export
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import (
FunCaptchaChallengeDataset,
build_train_rgb_transform,
build_val_rgb_transform,
)
# Default FunCaptcha question id; main() uses it as the key into
# FUN_CAPTCHA_TASKS when no other question is passed.
QUESTION = "4_3d_rollball_animals"
def _set_seed():
    """Seed the Python, NumPy and PyTorch RNGs with ``RANDOM_SEED``.

    When CUDA is available, every CUDA device is seeded explicitly as well.
    """
    seed = RANDOM_SEED
    # Same order as before: stdlib random, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _flatten_pairs(
candidates: torch.Tensor,
reference: torch.Tensor,
answer_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, num_candidates, channels, img_h, img_w = candidates.shape
references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
targets = F.one_hot(answer_idx, num_classes=num_candidates).float()
return (
candidates.reshape(batch_size * num_candidates, channels, img_h, img_w),
references.reshape(batch_size * num_candidates, channels, img_h, img_w),
targets.reshape(batch_size * num_candidates, 1),
)
def _evaluate(
    model: FunCaptchaSiamese,
    loader: DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Score ``model`` on ``loader`` and return ``(challenge_acc, pair_acc)``.

    * Challenge accuracy: fraction of challenges whose highest-scoring
      candidate is the labelled answer.
    * Pair accuracy: fraction of individual (candidate, reference) pairs
      whose sigmoid probability, thresholded at 0.5, matches the pair label.

    The model is switched to eval mode and gradients are disabled. Both
    ratios are ``0.0`` when the loader yields nothing.
    """
    model.eval()
    challenges_seen = challenges_hit = 0
    pairs_seen = pairs_hit = 0

    with torch.no_grad():
        for batch in loader:
            candidates, reference, answer_idx = (t.to(device) for t in batch)
            rows, cols = candidates.size(0), candidates.size(1)

            flat_cands, flat_refs, flat_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            # One score per pair, arranged as (challenge, candidate).
            scores = model(flat_cands, flat_refs).view(rows, cols)

            # Challenge level: does the arg-max candidate equal the answer?
            best_candidate = scores.argmax(dim=1)
            challenges_hit += (best_candidate == answer_idx).sum().item()
            challenges_seen += answer_idx.size(0)

            # Pair level: independent binary decision for every pair.
            labels = flat_targets.view(rows, cols)
            decisions = (torch.sigmoid(scores) >= 0.5).float()
            pairs_hit += (decisions == labels).sum().item()
            pairs_seen += labels.numel()

    challenge_acc = challenges_hit / max(challenges_seen, 1)
    pair_acc = pairs_hit / max(pairs_seen, 1)
    return challenge_acc, pair_acc
def main(question: str = QUESTION) -> float:
    """Train the FunCaptcha Siamese model for ``question`` and export it.

    Steps: seed the RNGs, build the train/validation datasets from the task's
    data directory, resume from an existing checkpoint if one is present,
    train with a pair-wise BCE loss, save a checkpoint whenever the
    challenge-level validation accuracy matches or beats the best so far,
    and finally call ``_load_and_export`` for the task's artifact.

    Args:
        question: Key into ``FUN_CAPTCHA_TASKS`` selecting the task config.

    Returns:
        The best challenge-level validation accuracy, including a value
        restored from a resumed checkpoint.

    Raises:
        FileNotFoundError: If the dataset contains no samples.
        ValueError: If no sample would be left for training after the
            validation split.
    """
    task_cfg = FUN_CAPTCHA_TASKS[question]
    # NOTE(review): hyper-parameters and image size are always read from the
    # "funcaptcha_rollball_animals" entries, whatever ``question`` is.
    # Confirm this is intended before adding more questions.
    cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
    img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
    device = get_device()
    data_dir = Path(task_cfg["data_dir"])
    ckpt_name = task_cfg["checkpoint_name"]
    ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"

    # Seed before random_split so the train/val partition is reproducible.
    _set_seed()

    print("=" * 60)
    print(f"训练 FunCaptcha 专项模型 ({question})")
    print(f" 数据目录: {data_dir}")
    print(f" 候选数: {task_cfg['num_candidates']}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    train_transform = build_train_rgb_transform(img_h, img_w)
    val_transform = build_val_rgb_transform(img_h, img_w)

    full_dataset = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise FileNotFoundError(
            f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
        )

    # Always hold out at least one sample for validation.
    val_size = max(1, int(total * cfg["val_split"]))
    train_size = total - val_size
    if train_size <= 0:
        raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")

    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation split must use the validation transform, so build a
    # second dataset and restrict it to exactly the held-out samples.
    val_ds_clean = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    # Pinned host memory only helps host-to-CUDA copies; on other devices
    # the DataLoader warns, so enable it only for CUDA.
    pin_memory = str(device).startswith("cuda")
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])

    # Each challenge yields one positive pair and (num_candidates - 1)
    # negative pairs, so positives are up-weighted by that ratio.
    pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_acc = 0.0
    start_epoch = 1
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = float(ckpt.get("best_acc", 0.0))
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        # Only model weights are restored; the optimizer starts fresh. The
        # scheduler is fast-forwarded so the LR matches the resumed epoch.
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")

    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for candidates, reference, answer_idx in pbar:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)

            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference)
            loss = criterion(logits, pair_targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        challenge_acc, pair_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={challenge_acc:.4f} "
            f"pair_acc={pair_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ">=" keeps the most recent weights among equally good epochs.
        if challenge_acc >= best_acc:
            best_acc = challenge_acc
            # The checkpoint directory may not exist yet on a fresh setup.
            ckpt_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "question": question,
                    "num_candidates": task_cfg["num_candidates"],
                    "tile_size": list(task_cfg["tile_size"]),
                    "reference_box": list(task_cfg["reference_box"]),
                    "answer_index_base": task_cfg["answer_index_base"],
                    "input_shape": [task_cfg["channels"], img_h, img_w],
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
    _load_and_export(task_cfg["artifact_name"])
    return best_acc
# Script entry point: train with the default question.
if __name__ == "__main__":
    main()