Files
CaptchBreaker/training/train_slide.py
Hua 9b5f29083e Add slide and rotate interactive captcha solvers
New solver subsystem with independent models:
- GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection
- RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction
- SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback
- RotateSolver with ONNX sin/cos → atan2 inference
- Generators, training scripts, CLI commands, and slide track utility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 18:07:06 +08:00

66 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
训练滑块缺口检测 CNN (GapDetectorCNN)
复用 train_regression_utils 的通用回归训练流程。
用法: python -m training.train_slide
"""
from config import (
SOLVER_CONFIG,
SOLVER_TRAIN_CONFIG,
SOLVER_REGRESSION_RANGE,
SLIDE_DATA_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.slide_gen import SlideDataGenerator
from models.gap_detector import GapDetectorCNN
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
# 以便复用 train_regression_utils
import config as _cfg
def main():
    """Train the slide-captcha gap detector CNN (GapDetectorCNN).

    Builds a ``GapDetectorCNN`` sized from
    ``SOLVER_CONFIG["slide"]["cnn_input_size"]`` and registers the slide
    solver's training settings under the ``"slide_cnn"`` key of the generic
    ``config.TRAIN_CONFIG`` / ``config.IMAGE_SIZE`` /
    ``config.REGRESSION_RANGE`` tables. It then delegates the training run to
    ``train_regression_model``, using ``SLIDE_DATA_DIR`` as the synthetic data
    directory and ``SLIDE_DATA_DIR / "real"`` as the real-sample directory.
    Both directories are created if they do not exist.
    """
    solver_cfg = SOLVER_CONFIG["slide"]
    train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
    img_h, img_w = solver_cfg["cnn_input_size"]

    model = GapDetectorCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("训练滑块缺口检测 CNN (GapDetectorCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(" 任务: 预测缺口 x 坐标百分比")
    print("=" * 60)

    # Reuse the generic regression training loop by registering the slide
    # solver's settings under config_key="slide_cnn" in the shared config
    # tables that train_regression_model is given the key for.
    # NOTE(review): these writes mutate the `config` module's dicts in place
    # and are never reverted after training. The entries persist for the rest
    # of the process, so the injection is not really "temporary".
    _cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg
    _cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
    _cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]

    # Deliberately imported only after the injection above. Presumably this is
    # so any import-time reads of config in train_regression_utils already
    # see the "slide_cnn" entries (TODO confirm).
    from training.train_regression_utils import train_regression_model

    real_dir = SLIDE_DATA_DIR / "real"
    # parents=True also creates SLIDE_DATA_DIR itself when it is missing.
    real_dir.mkdir(parents=True, exist_ok=True)

    train_regression_model(
        model_name="gap_detector",
        model=model,
        synthetic_dir=str(SLIDE_DATA_DIR),
        real_dir=str(real_dir),
        generator_cls=SlideDataGenerator,
        config_key="slide_cnn",
    )
# Entry point: run training only when this module is executed directly
# (e.g. `python -m training.train_slide`), not when it is imported.
if __name__ == "__main__":
    main()