Add slide and rotate interactive captcha solvers
New solver subsystem with independent models: - GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection - RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction - SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback - RotateSolver with ONNX sin/cos → atan2 inference - Generators, training scripts, CLI commands, and slide track utility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
65
training/train_slide.py
Normal file
65
training/train_slide.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
训练滑块缺口检测 CNN (GapDetectorCNN)
|
||||
|
||||
复用 train_regression_utils 的通用回归训练流程。
|
||||
|
||||
用法: python -m training.train_slide
|
||||
"""
|
||||
|
||||
from config import (
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_TRAIN_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
SLIDE_DATA_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.slide_gen import SlideDataGenerator
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
|
||||
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
|
||||
# 以便复用 train_regression_utils
|
||||
import config as _cfg
|
||||
|
||||
|
||||
def main():
|
||||
solver_cfg = SOLVER_CONFIG["slide"]
|
||||
train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
|
||||
img_h, img_w = solver_cfg["cnn_input_size"]
|
||||
|
||||
model = GapDetectorCNN(img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练滑块缺口检测 CNN (GapDetectorCNN)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print(f" 任务: 预测缺口 x 坐标百分比")
|
||||
print("=" * 60)
|
||||
|
||||
# 直接使用 train_regression_utils 中的逻辑
|
||||
# 但需要临时注入配置
|
||||
_cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg
|
||||
_cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
|
||||
_cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]
|
||||
|
||||
from training.train_regression_utils import train_regression_model
|
||||
|
||||
# 确保数据目录存在
|
||||
SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
real_dir = SLIDE_DATA_DIR / "real"
|
||||
real_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
train_regression_model(
|
||||
model_name="gap_detector",
|
||||
model=model,
|
||||
synthetic_dir=str(SLIDE_DATA_DIR),
|
||||
real_dir=str(real_dir),
|
||||
generator_cls=SlideDataGenerator,
|
||||
config_key="slide_cnn",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user