Add slide and rotate interactive captcha solvers
New solver subsystem with independent models:

- GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection
- RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction
- SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback
- RotateSolver with ONNX sin/cos → atan2 inference
- Generators, training scripts, CLI commands, and slide track utility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -224,3 +224,55 @@ class RegressionDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([label], dtype=torch.float32)
|
||||
|
||||
|
||||
# ============================================================
# Dataset for the rotation solver (sin/cos encoding)
# ============================================================
class RotateSolverDataset(Dataset):
    """Dataset for the rotation-captcha solver.

    Scans the given directories for ``{angle}_{xxx}.png`` files and
    encodes each angle as a ``(sin θ, cos θ)`` regression target.
    Images are loaded as RGB — no grayscale conversion.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: list of data directories to scan
            transform: image preprocessing / augmentation pipeline (RGB)
        """
        import math

        self.transform = transform
        # Each sample is (file path, sin θ, cos θ).
        self.samples: list[tuple[str, float, float]] = []

        for directory in map(Path, dirs):
            if not directory.exists():
                continue
            for png in sorted(directory.glob("*.png")):
                prefix = png.stem.rsplit("_", 1)[0]
                try:
                    degrees = float(prefix)
                except ValueError:
                    # File name does not start with a parsable angle; skip it.
                    continue
                theta = math.radians(degrees)
                self.samples.append((str(png), math.sin(theta), math.cos(theta)))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        import torch

        path, sin_t, cos_t = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor([sin_t, cos_t], dtype=torch.float32)
245
training/train_rotate_solver.py
Normal file
245
training/train_rotate_solver.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
训练旋转验证码角度回归模型 (RotationRegressor)
|
||||
|
||||
自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。
|
||||
|
||||
用法: python -m training.train_rotate_solver
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import (
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_TRAIN_CONFIG,
|
||||
ROTATE_SOLVER_DATA_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
AUGMENT_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from training.dataset import RotateSolverDataset
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
    """Seed every RNG (random, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
||||
def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
|
||||
"""循环 MAE (度数)。"""
|
||||
diff = np.abs(pred_angles - gt_angles)
|
||||
diff = np.minimum(diff, 360.0 - diff)
|
||||
return float(np.mean(diff))
|
||||
|
||||
|
||||
def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Training-time RGB augmentation pipeline (no grayscale conversion)."""
    aug = AUGMENT_CONFIG
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] into [-1, 1] on all three channels.
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
    return transforms.Compose(steps)
||||
def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Deterministic RGB preprocessing for validation (no augmentation)."""
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    return transforms.Compose([
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        normalize,
    ])
||||
def _export_onnx(model: nn.Module, img_h: int, img_w: int):
    """Export the trained model to ONNX (RGB, 3-channel input).

    Moves the model to CPU for the export and writes
    ``rotation_regressor.onnx`` into ``ONNX_DIR``.
    """
    model.eval()
    onnx_path = ONNX_DIR / "rotation_regressor.onnx"
    sample_input = torch.randn(1, 3, img_h, img_w)
    # Optional dynamic batch axis, controlled by the project ONNX config.
    axes = (
        {"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None
    )
    torch.onnx.export(
        model.cpu(),
        sample_input,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
||||
def main():
|
||||
cfg = SOLVER_TRAIN_CONFIG["rotate"]
|
||||
solver_cfg = SOLVER_CONFIG["rotate"]
|
||||
img_h, img_w = solver_cfg["input_size"]
|
||||
tolerance = 5.0 # ±5°
|
||||
|
||||
_set_seed()
|
||||
device = get_device()
|
||||
|
||||
print("=" * 60)
|
||||
print("训练旋转验证码角度回归模型 (RotationRegressor)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w} RGB")
|
||||
print(f" 编码: sin/cos")
|
||||
print(f" 容差: ±{tolerance}°")
|
||||
print("=" * 60)
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = ROTATE_SOLVER_DATA_DIR
|
||||
syn_path.mkdir(parents=True, exist_ok=True)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = RotateSolverDataGenerator()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
real_dir = syn_path / "real"
|
||||
real_dir.mkdir(parents=True, exist_ok=True)
|
||||
if list(real_dir.glob("*.png")):
|
||||
data_dirs.append(str(real_dir))
|
||||
print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))} 张")
|
||||
|
||||
train_transform = _build_train_transform(img_h, img_w)
|
||||
val_transform = _build_val_transform(img_h, img_w)
|
||||
|
||||
full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
|
||||
total = len(full_dataset)
|
||||
val_size = int(total * cfg["val_split"])
|
||||
train_size = total - val_size
|
||||
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
|
||||
|
||||
val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
|
||||
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
|
||||
# ---- 3. 模型 / 优化器 / 调度器 ----
|
||||
model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
|
||||
loss_fn = nn.MSELoss()
|
||||
|
||||
best_mae = float("inf")
|
||||
best_tol_acc = 0.0
|
||||
ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device) # (B, 2) → (sin, cos)
|
||||
|
||||
preds = model(images) # (B, 2)
|
||||
loss = loss_fn(preds, targets)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
||||
|
||||
scheduler.step()
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
# ---- 5. 验证 ----
|
||||
model.eval()
|
||||
all_pred_angles = []
|
||||
all_gt_angles = []
|
||||
with torch.no_grad():
|
||||
for images, targets in val_loader:
|
||||
images = images.to(device)
|
||||
preds = model(images).cpu().numpy() # (B, 2)
|
||||
targets_np = targets.numpy() # (B, 2)
|
||||
|
||||
# sin/cos → angle
|
||||
for i in range(len(preds)):
|
||||
pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1]))
|
||||
if pred_angle < 0:
|
||||
pred_angle += 360.0
|
||||
gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1]))
|
||||
if gt_angle < 0:
|
||||
gt_angle += 360.0
|
||||
all_pred_angles.append(pred_angle)
|
||||
all_gt_angles.append(gt_angle)
|
||||
|
||||
pred_arr = np.array(all_pred_angles)
|
||||
gt_arr = np.array(all_gt_angles)
|
||||
|
||||
mae = _circular_mae_deg(pred_arr, gt_arr)
|
||||
diff = np.abs(pred_arr - gt_arr)
|
||||
diff = np.minimum(diff, 360.0 - diff)
|
||||
tol_acc = float(np.mean(diff <= tolerance))
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{cfg['epochs']} "
|
||||
f"loss={avg_loss:.4f} "
|
||||
f"MAE={mae:.2f}° "
|
||||
f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
|
||||
f"lr={lr:.6f}"
|
||||
)
|
||||
|
||||
# ---- 6. 保存最佳模型 ----
|
||||
if tol_acc >= best_tol_acc:
|
||||
best_tol_acc = tol_acc
|
||||
best_mae = mae
|
||||
torch.save({
|
||||
"model_state_dict": model.state_dict(),
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
|
||||
|
||||
# ---- 7. 导出 ONNX ----
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, img_h, img_w)
|
||||
|
||||
return best_tol_acc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
65
training/train_slide.py
Normal file
65
training/train_slide.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
训练滑块缺口检测 CNN (GapDetectorCNN)
|
||||
|
||||
复用 train_regression_utils 的通用回归训练流程。
|
||||
|
||||
用法: python -m training.train_slide
|
||||
"""
|
||||
|
||||
from config import (
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_TRAIN_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
SLIDE_DATA_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.slide_gen import SlideDataGenerator
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
|
||||
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
|
||||
# 以便复用 train_regression_utils
|
||||
import config as _cfg
|
||||
|
||||
|
||||
def main():
    """Train GapDetectorCNN for slide-captcha gap detection.

    Reuses the generic regression pipeline in ``train_regression_utils``
    by temporarily injecting this solver's settings into the shared
    ``config`` module under the ``"slide_cnn"`` key.
    """
    solver_cfg = SOLVER_CONFIG["slide"]
    train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
    img_h, img_w = solver_cfg["cnn_input_size"]

    model = GapDetectorCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("训练滑块缺口检测 CNN (GapDetectorCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(f" 任务: 预测缺口 x 坐标百分比")
    print("=" * 60)

    # Reuse the logic in train_regression_utils directly,
    # which requires temporarily injecting this solver's config
    # into the shared config module first.
    _cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg
    _cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
    _cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]

    # NOTE(review): deliberately imported *after* the injection above —
    # presumably train_regression_utils reads these config mappings;
    # keep this statement order.
    from training.train_regression_utils import train_regression_model

    # Make sure the data directories exist before training starts.
    SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
    real_dir = SLIDE_DATA_DIR / "real"
    real_dir.mkdir(parents=True, exist_ok=True)

    train_regression_model(
        model_name="gap_detector",
        model=model,
        synthetic_dir=str(SLIDE_DATA_DIR),
        real_dir=str(real_dir),
        generator_cls=SlideDataGenerator,
        config_key="slide_cnn",
    )


if __name__ == "__main__":
    main()
Reference in New Issue
Block a user