Add slide and rotate interactive captcha solvers
New solver subsystem with independent models: - GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection - RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction - SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback - RotateSolver with ONNX sin/cos → atan2 inference - Generators, training scripts, CLI commands, and slide track utility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
90
CLAUDE.md
90
CLAUDE.md
@@ -33,7 +33,10 @@ captcha-breaker/
|
|||||||
│ │ ├── 3d_text/
|
│ │ ├── 3d_text/
|
||||||
│ │ ├── 3d_rotate/
|
│ │ ├── 3d_rotate/
|
||||||
│ │ └── 3d_slider/
|
│ │ └── 3d_slider/
|
||||||
│ └── classifier/ # 调度分类器训练数据 (混合各类型)
|
│ ├── classifier/ # 调度分类器训练数据 (混合各类型)
|
||||||
|
│ └── solver/ # Solver 训练数据
|
||||||
|
│ ├── slide/ # 滑块缺口检测训练数据
|
||||||
|
│ └── rotate/ # 旋转角度回归训练数据
|
||||||
├── generators/
|
├── generators/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── base.py # 生成器基类
|
│ ├── base.py # 生成器基类
|
||||||
@@ -41,13 +44,17 @@ captcha-breaker/
|
|||||||
│ ├── math_gen.py # 算式验证码生成器 (如 3+8=?)
|
│ ├── math_gen.py # 算式验证码生成器 (如 3+8=?)
|
||||||
│ ├── threed_gen.py # 3D立体文字验证码生成器
|
│ ├── threed_gen.py # 3D立体文字验证码生成器
|
||||||
│ ├── threed_rotate_gen.py # 3D旋转验证码生成器
|
│ ├── threed_rotate_gen.py # 3D旋转验证码生成器
|
||||||
│ └── threed_slider_gen.py # 3D滑块验证码生成器
|
│ ├── threed_slider_gen.py # 3D滑块验证码生成器
|
||||||
|
│ ├── slide_gen.py # 滑块缺口训练数据生成器
|
||||||
|
│ └── rotate_solver_gen.py # 旋转求解器训练数据生成器
|
||||||
├── models/
|
├── models/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式)
|
│ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式)
|
||||||
│ ├── classifier.py # 调度分类模型
|
│ ├── classifier.py # 调度分类模型
|
||||||
│ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN)
|
│ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN)
|
||||||
│ └── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
|
│ ├── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
|
||||||
|
│ ├── gap_detector.py # 滑块缺口检测CNN (~1MB)
|
||||||
|
│ └── rotation_regressor.py # 旋转角度回归 sin/cos (~2MB)
|
||||||
├── training/
|
├── training/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── train_classifier.py # 训练调度模型
|
│ ├── train_classifier.py # 训练调度模型
|
||||||
@@ -56,6 +63,8 @@ captcha-breaker/
|
|||||||
│ ├── train_3d_text.py # 训练3D文字识别
|
│ ├── train_3d_text.py # 训练3D文字识别
|
||||||
│ ├── train_3d_rotate.py # 训练3D旋转回归
|
│ ├── train_3d_rotate.py # 训练3D旋转回归
|
||||||
│ ├── train_3d_slider.py # 训练3D滑块回归
|
│ ├── train_3d_slider.py # 训练3D滑块回归
|
||||||
|
│ ├── train_slide.py # 训练滑块缺口检测
|
||||||
|
│ ├── train_rotate_solver.py # 训练旋转角度回归
|
||||||
│ ├── train_utils.py # CTC 训练通用逻辑
|
│ ├── train_utils.py # CTC 训练通用逻辑
|
||||||
│ ├── train_regression_utils.py # 回归训练通用逻辑
|
│ ├── train_regression_utils.py # 回归训练通用逻辑
|
||||||
│ └── dataset.py # 通用 Dataset 类
|
│ └── dataset.py # 通用 Dataset 类
|
||||||
@@ -64,20 +73,32 @@ captcha-breaker/
|
|||||||
│ ├── pipeline.py # 核心推理流水线 (调度+识别)
|
│ ├── pipeline.py # 核心推理流水线 (调度+识别)
|
||||||
│ ├── export_onnx.py # PyTorch → ONNX 导出脚本
|
│ ├── export_onnx.py # PyTorch → ONNX 导出脚本
|
||||||
│ └── math_eval.py # 算式计算模块
|
│ └── math_eval.py # 算式计算模块
|
||||||
|
├── solvers/ # 交互式验证码求解器
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── base.py # 求解器基类
|
||||||
|
│ ├── slide_solver.py # 滑块求解 (OpenCV + CNN)
|
||||||
|
│ └── rotate_solver.py # 旋转求解 (ONNX sin/cos)
|
||||||
|
├── utils/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ └── slide_utils.py # 滑块轨迹生成工具
|
||||||
├── checkpoints/ # 训练产出的模型文件
|
├── checkpoints/ # 训练产出的模型文件
|
||||||
│ ├── classifier.pth
|
│ ├── classifier.pth
|
||||||
│ ├── normal.pth
|
│ ├── normal.pth
|
||||||
│ ├── math.pth
|
│ ├── math.pth
|
||||||
│ ├── threed_text.pth
|
│ ├── threed_text.pth
|
||||||
│ ├── threed_rotate.pth
|
│ ├── threed_rotate.pth
|
||||||
│ └── threed_slider.pth
|
│ ├── threed_slider.pth
|
||||||
|
│ ├── gap_detector.pth
|
||||||
|
│ └── rotation_regressor.pth
|
||||||
├── onnx_models/ # 导出的 ONNX 模型
|
├── onnx_models/ # 导出的 ONNX 模型
|
||||||
│ ├── classifier.onnx
|
│ ├── classifier.onnx
|
||||||
│ ├── normal.onnx
|
│ ├── normal.onnx
|
||||||
│ ├── math.onnx
|
│ ├── math.onnx
|
||||||
│ ├── threed_text.onnx
|
│ ├── threed_text.onnx
|
||||||
│ ├── threed_rotate.onnx
|
│ ├── threed_rotate.onnx
|
||||||
│ └── threed_slider.onnx
|
│ ├── threed_slider.onnx
|
||||||
|
│ ├── gap_detector.onnx
|
||||||
|
│ └── rotation_regressor.onnx
|
||||||
├── server.py # FastAPI 推理服务 (可选)
|
├── server.py # FastAPI 推理服务 (可选)
|
||||||
├── cli.py # 命令行入口
|
├── cli.py # 命令行入口
|
||||||
└── tests/
|
└── tests/
|
||||||
@@ -462,3 +483,62 @@ uv run python cli.py serve --port 8080
|
|||||||
6. 实现 cli.py 统一入口
|
6. 实现 cli.py 统一入口
|
||||||
7. 可选: server.py HTTP 服务
|
7. 可选: server.py HTTP 服务
|
||||||
8. 编写 tests/
|
8. 编写 tests/
|
||||||
|
|
||||||
|
## 交互式 Solver 扩展
|
||||||
|
|
||||||
|
### 概述
|
||||||
|
|
||||||
|
在现有验证码识别架构之上,新增滑块 (slide) 和旋转 (rotate) 两种交互式验证码求解能力。与现有 3d_rotate/3d_slider 的区别:
|
||||||
|
|
||||||
|
- **3d_slider** (合成拼图回归) → **slide solver**: 真实滑块验证码,OpenCV 优先 + CNN 兜底
|
||||||
|
- **3d_rotate** (合成圆盘 sigmoid 回归) → **rotate solver**: 真实旋转验证码,sin/cos 编码 + 自然图
|
||||||
|
|
||||||
|
每个 solver 模型独立训练、独立导出 ONNX、独立替换,互不依赖。
|
||||||
|
|
||||||
|
### 滑块求解器 (SlideSolver)
|
||||||
|
|
||||||
|
- 三种方法按优先级: 模板匹配 → 边缘检测 → CNN 兜底
|
||||||
|
- 模型: `GapDetectorCNN` (1x128x256 灰度 → sigmoid [0,1])
|
||||||
|
- OpenCV 延迟导入,未安装时退化到 CNN only
|
||||||
|
- 输出: `{"gap_x", "gap_x_percent", "confidence", "method"}`
|
||||||
|
|
||||||
|
### 旋转求解器 (RotateSolver)
|
||||||
|
|
||||||
|
- ONNX 推理 → (sin, cos) → atan2 → 角度
|
||||||
|
- 模型: `RotationRegressor` (3x128x128 RGB → tanh (sin θ, cos θ))
|
||||||
|
- 输出: `{"angle", "confidence"}`
|
||||||
|
|
||||||
|
### Solver CLI 用法
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 生成训练数据
|
||||||
|
uv run python cli.py generate-solver slide --num 30000
|
||||||
|
uv run python cli.py generate-solver rotate --num 50000
|
||||||
|
|
||||||
|
# 训练 (各模型独立)
|
||||||
|
uv run python cli.py train-solver slide
|
||||||
|
uv run python cli.py train-solver rotate
|
||||||
|
|
||||||
|
# 求解
|
||||||
|
uv run python cli.py solve slide --bg bg.png [--tpl tpl.png]
|
||||||
|
uv run python cli.py solve rotate --image img.png
|
||||||
|
|
||||||
|
# 导出 (已集成到 export --all)
|
||||||
|
uv run python cli.py export --model gap_detector
|
||||||
|
uv run python cli.py export --model rotation_regressor
|
||||||
|
```
|
||||||
|
|
||||||
|
### 滑块轨迹生成
|
||||||
|
|
||||||
|
`utils/slide_utils.py` 提供 `generate_slide_track(distance)`:
|
||||||
|
- 贝塞尔曲线 ease-out 加速减速
|
||||||
|
- y 轴 ±1~3px 随机抖动
|
||||||
|
- 时间间隔不均匀
|
||||||
|
- 末尾微小过冲回退
|
||||||
|
|
||||||
|
### Solver 目标指标
|
||||||
|
|
||||||
|
| 模型 | 准确率目标 | 推理延迟 | 模型体积 |
|
||||||
|
|------|-----------|---------|---------|
|
||||||
|
| 滑块 CNN (±5px) | > 85% | < 30ms | ~1MB |
|
||||||
|
| 旋转回归 (±5°) | > 85% | < 30ms | ~2MB |
|
||||||
|
|||||||
108
cli.py
108
cli.py
@@ -13,6 +13,11 @@ CaptchaBreaker 命令行入口
|
|||||||
python cli.py predict image.png --type normal
|
python cli.py predict image.png --type normal
|
||||||
python cli.py predict-dir ./test_images/
|
python cli.py predict-dir ./test_images/
|
||||||
python cli.py serve --port 8080
|
python cli.py serve --port 8080
|
||||||
|
python cli.py generate-solver slide --num 30000
|
||||||
|
python cli.py train-solver slide
|
||||||
|
python cli.py train-solver rotate
|
||||||
|
python cli.py solve slide --bg bg.png [--tpl tpl.png]
|
||||||
|
python cli.py solve rotate --image img.png
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -195,6 +200,90 @@ def cmd_serve(args):
|
|||||||
uvicorn.run(app, host=args.host, port=args.port)
|
uvicorn.run(app, host=args.host, port=args.port)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_generate_solver(args):
|
||||||
|
"""生成 solver 训练数据。"""
|
||||||
|
from config import SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR
|
||||||
|
from generators.slide_gen import SlideDataGenerator
|
||||||
|
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||||
|
|
||||||
|
solver_type = args.type
|
||||||
|
num = args.num
|
||||||
|
|
||||||
|
gen_map = {
|
||||||
|
"slide": (SlideDataGenerator, SLIDE_DATA_DIR),
|
||||||
|
"rotate": (RotateSolverDataGenerator, ROTATE_SOLVER_DATA_DIR),
|
||||||
|
}
|
||||||
|
|
||||||
|
if solver_type not in gen_map:
|
||||||
|
print(f"未知 solver 类型: {solver_type} 可选: {', '.join(gen_map.keys())}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
gen_cls, out_dir = gen_map[solver_type]
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"生成 solver/{solver_type} 数据: {num} 张 → {out_dir}")
|
||||||
|
gen = gen_cls()
|
||||||
|
gen.generate_dataset(num, str(out_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_train_solver(args):
|
||||||
|
"""训练 solver 模型。"""
|
||||||
|
solver_type = args.type
|
||||||
|
|
||||||
|
if solver_type == "slide":
|
||||||
|
from training.train_slide import main as train_fn
|
||||||
|
elif solver_type == "rotate":
|
||||||
|
from training.train_rotate_solver import main as train_fn
|
||||||
|
else:
|
||||||
|
print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
train_fn()
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_solve(args):
|
||||||
|
"""求解验证码。"""
|
||||||
|
solver_type = args.type
|
||||||
|
|
||||||
|
if solver_type == "slide":
|
||||||
|
from solvers.slide_solver import SlideSolver
|
||||||
|
|
||||||
|
bg_path = args.bg
|
||||||
|
tpl_path = getattr(args, "tpl", None)
|
||||||
|
if not Path(bg_path).exists():
|
||||||
|
print(f"文件不存在: {bg_path}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
solver = SlideSolver()
|
||||||
|
result = solver.solve(bg_path, template_image=tpl_path)
|
||||||
|
|
||||||
|
print(f"背景图: {bg_path}")
|
||||||
|
if tpl_path:
|
||||||
|
print(f"模板图: {tpl_path}")
|
||||||
|
print(f"缺口 x: {result['gap_x']} px")
|
||||||
|
print(f"缺口 x%: {result['gap_x_percent']:.4f}")
|
||||||
|
print(f"置信度: {result['confidence']:.4f}")
|
||||||
|
print(f"方法: {result['method']}")
|
||||||
|
|
||||||
|
elif solver_type == "rotate":
|
||||||
|
from solvers.rotate_solver import RotateSolver
|
||||||
|
|
||||||
|
image_path = args.image
|
||||||
|
if not Path(image_path).exists():
|
||||||
|
print(f"文件不存在: {image_path}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
solver = RotateSolver()
|
||||||
|
result = solver.solve(image_path)
|
||||||
|
|
||||||
|
print(f"图片: {image_path}")
|
||||||
|
print(f"角度: {result['angle']}°")
|
||||||
|
print(f"置信度: {result['confidence']}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="captcha-breaker",
|
prog="captcha-breaker",
|
||||||
@@ -247,6 +336,22 @@ def main():
|
|||||||
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
|
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
|
||||||
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
|
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
|
||||||
|
|
||||||
|
# ---- generate-solver ----
|
||||||
|
p_gen_solver = subparsers.add_parser("generate-solver", help="生成 solver 训练数据")
|
||||||
|
p_gen_solver.add_argument("type", help="solver 类型: slide, rotate")
|
||||||
|
p_gen_solver.add_argument("--num", type=int, required=True, help="生成数量")
|
||||||
|
|
||||||
|
# ---- train-solver ----
|
||||||
|
p_train_solver = subparsers.add_parser("train-solver", help="训练 solver 模型")
|
||||||
|
p_train_solver.add_argument("type", help="solver 类型: slide, rotate")
|
||||||
|
|
||||||
|
# ---- solve ----
|
||||||
|
p_solve = subparsers.add_parser("solve", help="求解交互式验证码")
|
||||||
|
p_solve.add_argument("type", help="solver 类型: slide, rotate")
|
||||||
|
p_solve.add_argument("--bg", help="背景图路径 (slide 必需)")
|
||||||
|
p_solve.add_argument("--tpl", default=None, help="模板图路径 (slide 可选)")
|
||||||
|
p_solve.add_argument("--image", help="图片路径 (rotate 必需)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.command is None:
|
if args.command is None:
|
||||||
@@ -260,6 +365,9 @@ def main():
|
|||||||
"predict": cmd_predict,
|
"predict": cmd_predict,
|
||||||
"predict-dir": cmd_predict_dir,
|
"predict-dir": cmd_predict_dir,
|
||||||
"serve": cmd_serve,
|
"serve": cmd_serve,
|
||||||
|
"generate-solver": cmd_generate_solver,
|
||||||
|
"train-solver": cmd_train_solver,
|
||||||
|
"solve": cmd_solve,
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd_map[args.command](args)
|
cmd_map[args.command](args)
|
||||||
|
|||||||
43
config.py
43
config.py
@@ -34,6 +34,11 @@ REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
|
|||||||
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
|
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
|
||||||
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
|
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
|
||||||
|
|
||||||
|
# Solver 数据目录
|
||||||
|
SOLVER_DATA_DIR = DATA_DIR / "solver"
|
||||||
|
SLIDE_DATA_DIR = SOLVER_DATA_DIR / "slide"
|
||||||
|
ROTATE_SOLVER_DATA_DIR = SOLVER_DATA_DIR / "rotate"
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 模型输出目录
|
# 模型输出目录
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -47,6 +52,7 @@ for _dir in [
|
|||||||
REAL_NORMAL_DIR, REAL_MATH_DIR,
|
REAL_NORMAL_DIR, REAL_MATH_DIR,
|
||||||
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
|
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
|
||||||
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
|
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
|
||||||
|
SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR,
|
||||||
]:
|
]:
|
||||||
_dir.mkdir(parents=True, exist_ok=True)
|
_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -241,3 +247,40 @@ SERVER_CONFIG = {
|
|||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": 8080,
|
"port": 8080,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Solver 配置 (交互式验证码求解)
|
||||||
|
# ============================================================
|
||||||
|
SOLVER_CONFIG = {
|
||||||
|
"slide": {
|
||||||
|
"canny_low": 50,
|
||||||
|
"canny_high": 150,
|
||||||
|
"cnn_input_size": (128, 256), # H, W
|
||||||
|
},
|
||||||
|
"rotate": {
|
||||||
|
"input_size": (128, 128), # H, W
|
||||||
|
"channels": 3, # RGB
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
SOLVER_TRAIN_CONFIG = {
|
||||||
|
"slide_cnn": {
|
||||||
|
"epochs": 50,
|
||||||
|
"batch_size": 64,
|
||||||
|
"lr": 1e-3,
|
||||||
|
"synthetic_samples": 30000,
|
||||||
|
"val_split": 0.1,
|
||||||
|
},
|
||||||
|
"rotate": {
|
||||||
|
"epochs": 80,
|
||||||
|
"batch_size": 64,
|
||||||
|
"lr": 5e-4,
|
||||||
|
"synthetic_samples": 50000,
|
||||||
|
"val_split": 0.1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
SOLVER_REGRESSION_RANGE = {
|
||||||
|
"slide": (0, 1), # 归一化百分比
|
||||||
|
"rotate": (0, 360), # 角度
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
数据生成器包
|
数据生成器包
|
||||||
|
|
||||||
提供五种验证码类型的数据生成器:
|
提供七种验证码类型的数据生成器:
|
||||||
- NormalCaptchaGenerator: 普通字符验证码
|
- NormalCaptchaGenerator: 普通字符验证码
|
||||||
- MathCaptchaGenerator: 算式验证码
|
- MathCaptchaGenerator: 算式验证码
|
||||||
- ThreeDCaptchaGenerator: 3D 立体文字验证码
|
- ThreeDCaptchaGenerator: 3D 立体文字验证码
|
||||||
- ThreeDRotateGenerator: 3D 旋转验证码
|
- ThreeDRotateGenerator: 3D 旋转验证码
|
||||||
- ThreeDSliderGenerator: 3D 滑块验证码
|
- ThreeDSliderGenerator: 3D 滑块验证码
|
||||||
|
- SlideDataGenerator: 滑块验证码求解器训练数据
|
||||||
|
- RotateSolverDataGenerator: 旋转验证码求解器训练数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from generators.base import BaseCaptchaGenerator
|
from generators.base import BaseCaptchaGenerator
|
||||||
@@ -15,6 +17,8 @@ from generators.math_gen import MathCaptchaGenerator
|
|||||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||||
|
from generators.slide_gen import SlideDataGenerator
|
||||||
|
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseCaptchaGenerator",
|
"BaseCaptchaGenerator",
|
||||||
@@ -23,4 +27,6 @@ __all__ = [
|
|||||||
"ThreeDCaptchaGenerator",
|
"ThreeDCaptchaGenerator",
|
||||||
"ThreeDRotateGenerator",
|
"ThreeDRotateGenerator",
|
||||||
"ThreeDSliderGenerator",
|
"ThreeDSliderGenerator",
|
||||||
|
"SlideDataGenerator",
|
||||||
|
"RotateSolverDataGenerator",
|
||||||
]
|
]
|
||||||
|
|||||||
156
generators/rotate_solver_gen.py
Normal file
156
generators/rotate_solver_gen.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
旋转验证码求解器数据生成器
|
||||||
|
|
||||||
|
生成旋转验证码训练数据:随机图案 (色块/渐变/几何图形),随机旋转 0-359°。
|
||||||
|
裁剪为圆形 (黑色背景填充圆外区域)。
|
||||||
|
|
||||||
|
标签 = 旋转角度 (整数)
|
||||||
|
文件名格式: {angle}_{index:06d}.png
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw, ImageFilter, ImageFont
|
||||||
|
|
||||||
|
from config import SOLVER_CONFIG
|
||||||
|
from generators.base import BaseCaptchaGenerator
|
||||||
|
|
||||||
|
_FONT_PATHS = [
|
||||||
|
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
|
||||||
|
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
|
||||||
|
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
|
||||||
|
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
|
||||||
|
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RotateSolverDataGenerator(BaseCaptchaGenerator):
|
||||||
|
"""旋转验证码求解器数据生成器。"""
|
||||||
|
|
||||||
|
def __init__(self, seed: int | None = None):
|
||||||
|
from config import RANDOM_SEED
|
||||||
|
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
|
||||||
|
|
||||||
|
self.cfg = SOLVER_CONFIG["rotate"]
|
||||||
|
self.height, self.width = self.cfg["input_size"] # (H, W)
|
||||||
|
|
||||||
|
self._fonts: list[str] = []
|
||||||
|
for p in _FONT_PATHS:
|
||||||
|
try:
|
||||||
|
ImageFont.truetype(p, 20)
|
||||||
|
self._fonts.append(p)
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
|
||||||
|
rng = self.rng
|
||||||
|
|
||||||
|
# 随机旋转角度 0-359
|
||||||
|
angle = rng.randint(0, 359)
|
||||||
|
if text is None:
|
||||||
|
text = str(angle)
|
||||||
|
|
||||||
|
size = self.width # 正方形
|
||||||
|
radius = size // 2
|
||||||
|
|
||||||
|
# 1. 生成正向图案 (未旋转)
|
||||||
|
content = self._random_pattern(rng, size)
|
||||||
|
|
||||||
|
# 2. 旋转图案
|
||||||
|
rotated = content.rotate(-angle, resample=Image.BICUBIC, expand=False)
|
||||||
|
|
||||||
|
# 3. 裁剪为圆形 (黑色背景)
|
||||||
|
result = Image.new("RGB", (size, size), (0, 0, 0))
|
||||||
|
mask = Image.new("L", (size, size), 0)
|
||||||
|
mask_draw = ImageDraw.Draw(mask)
|
||||||
|
mask_draw.ellipse([0, 0, size - 1, size - 1], fill=255)
|
||||||
|
|
||||||
|
result.paste(rotated, (0, 0), mask)
|
||||||
|
|
||||||
|
# 4. 轻微模糊
|
||||||
|
result = result.filter(ImageFilter.GaussianBlur(radius=0.5))
|
||||||
|
|
||||||
|
return result, text
|
||||||
|
|
||||||
|
def _random_pattern(self, rng: random.Random, size: int) -> Image.Image:
|
||||||
|
"""生成随机图案 (带明显方向性,便于模型学习旋转)。"""
|
||||||
|
img = Image.new("RGB", (size, size))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
# 渐变背景
|
||||||
|
base_r = rng.randint(100, 220)
|
||||||
|
base_g = rng.randint(100, 220)
|
||||||
|
base_b = rng.randint(100, 220)
|
||||||
|
for y in range(size):
|
||||||
|
ratio = y / max(size - 1, 1)
|
||||||
|
r = int(base_r * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||||
|
g = int(base_g * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||||
|
b = int(base_b * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||||
|
draw.line([(0, y), (size, y)], fill=(r, g, b))
|
||||||
|
|
||||||
|
cx, cy = size // 2, size // 2
|
||||||
|
|
||||||
|
# 添加不对称几何图形 (让模型能感知方向)
|
||||||
|
pattern_type = rng.choice(["triangle", "arrow", "text", "shapes"])
|
||||||
|
|
||||||
|
if pattern_type == "triangle":
|
||||||
|
# 顶部三角形标记
|
||||||
|
color = tuple(rng.randint(180, 255) for _ in range(3))
|
||||||
|
ts = size // 4
|
||||||
|
draw.polygon(
|
||||||
|
[(cx, cy - ts), (cx - ts // 2, cy), (cx + ts // 2, cy)],
|
||||||
|
fill=color,
|
||||||
|
)
|
||||||
|
# 底部小圆
|
||||||
|
draw.ellipse(
|
||||||
|
[cx - 8, cy + ts // 2, cx + 8, cy + ts // 2 + 16],
|
||||||
|
fill=tuple(rng.randint(50, 150) for _ in range(3)),
|
||||||
|
)
|
||||||
|
|
||||||
|
elif pattern_type == "arrow":
|
||||||
|
# 向上的箭头
|
||||||
|
color = tuple(rng.randint(180, 255) for _ in range(3))
|
||||||
|
arrow_len = size // 3
|
||||||
|
draw.line([(cx, cy - arrow_len), (cx, cy + arrow_len // 2)], fill=color, width=4)
|
||||||
|
draw.polygon(
|
||||||
|
[(cx, cy - arrow_len - 5), (cx - 10, cy - arrow_len + 10), (cx + 10, cy - arrow_len + 10)],
|
||||||
|
fill=color,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif pattern_type == "text" and self._fonts:
|
||||||
|
# 文字 (有天然方向性)
|
||||||
|
font_path = rng.choice(self._fonts)
|
||||||
|
font_size = size // 3
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype(font_path, font_size)
|
||||||
|
ch = rng.choice("ABCDEFGHJKLMNPRSTUVWXYZ23456789")
|
||||||
|
bbox = font.getbbox(ch)
|
||||||
|
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||||
|
draw.text(
|
||||||
|
(cx - tw // 2 - bbox[0], cy - th // 2 - bbox[1]),
|
||||||
|
ch,
|
||||||
|
fill=tuple(rng.randint(0, 80) for _ in range(3)),
|
||||||
|
font=font,
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 混合不对称形状
|
||||||
|
# 上方矩形
|
||||||
|
w, h = rng.randint(15, 30), rng.randint(10, 20)
|
||||||
|
color = tuple(rng.randint(150, 255) for _ in range(3))
|
||||||
|
draw.rectangle([cx - w, cy - size // 3, cx + w, cy - size // 3 + h], fill=color)
|
||||||
|
# 右下小圆
|
||||||
|
r = rng.randint(5, 12)
|
||||||
|
color2 = tuple(rng.randint(50, 150) for _ in range(3))
|
||||||
|
draw.ellipse([cx + size // 5, cy + size // 5, cx + size // 5 + r * 2, cy + size // 5 + r * 2], fill=color2)
|
||||||
|
|
||||||
|
# 添加纹理噪声
|
||||||
|
for _ in range(rng.randint(20, 60)):
|
||||||
|
nx, ny = rng.randint(0, size - 1), rng.randint(0, size - 1)
|
||||||
|
nc = tuple(rng.randint(80, 220) for _ in range(3))
|
||||||
|
draw.point((nx, ny), fill=nc)
|
||||||
|
|
||||||
|
return img
|
||||||
112
generators/slide_gen.py
Normal file
112
generators/slide_gen.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
滑块验证码数据生成器
|
||||||
|
|
||||||
|
生成滑块验证码训练数据:随机纹理/色块背景 + 方形缺口 + 阴影效果。
|
||||||
|
标签 = 缺口中心 x 坐标 (整数)
|
||||||
|
文件名格式: {gap_x}_{index:06d}.png
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw, ImageFilter
|
||||||
|
|
||||||
|
from config import SOLVER_CONFIG
|
||||||
|
from generators.base import BaseCaptchaGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class SlideDataGenerator(BaseCaptchaGenerator):
|
||||||
|
"""滑块验证码数据生成器。"""
|
||||||
|
|
||||||
|
def __init__(self, seed: int | None = None):
|
||||||
|
from config import RANDOM_SEED
|
||||||
|
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
|
||||||
|
|
||||||
|
self.cfg = SOLVER_CONFIG["slide"]
|
||||||
|
self.height, self.width = self.cfg["cnn_input_size"] # (H, W)
|
||||||
|
self.gap_size = 40 # 缺口大小
|
||||||
|
|
||||||
|
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
|
||||||
|
rng = self.rng
|
||||||
|
gs = self.gap_size
|
||||||
|
|
||||||
|
# 缺口 x 范围: 留出边距
|
||||||
|
margin = gs + 10
|
||||||
|
gap_x = rng.randint(margin, self.width - margin)
|
||||||
|
gap_y = rng.randint(10, self.height - gs - 10)
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
text = str(gap_x)
|
||||||
|
|
||||||
|
# 1. 生成纹理背景
|
||||||
|
img = self._textured_background(rng)
|
||||||
|
|
||||||
|
# 2. 绘制缺口 (半透明灰色区域 + 阴影)
|
||||||
|
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||||
|
overlay_draw = ImageDraw.Draw(overlay)
|
||||||
|
|
||||||
|
# 阴影 (稍大一圈)
|
||||||
|
overlay_draw.rectangle(
|
||||||
|
[gap_x + 2, gap_y + 2, gap_x + gs + 2, gap_y + gs + 2],
|
||||||
|
fill=(0, 0, 0, 60),
|
||||||
|
)
|
||||||
|
# 缺口本体
|
||||||
|
overlay_draw.rectangle(
|
||||||
|
[gap_x, gap_y, gap_x + gs, gap_y + gs],
|
||||||
|
fill=(80, 80, 80, 160),
|
||||||
|
outline=(60, 60, 60, 200),
|
||||||
|
width=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
img = img.convert("RGBA")
|
||||||
|
img = Image.alpha_composite(img, overlay)
|
||||||
|
img = img.convert("RGB")
|
||||||
|
|
||||||
|
# 3. 轻微模糊
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=0.3))
|
||||||
|
|
||||||
|
return img, text
|
||||||
|
|
||||||
|
def _textured_background(self, rng: random.Random) -> Image.Image:
|
||||||
|
"""生成带纹理的彩色背景。"""
|
||||||
|
img = Image.new("RGB", (self.width, self.height))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
# 渐变底色
|
||||||
|
base_r = rng.randint(80, 200)
|
||||||
|
base_g = rng.randint(80, 200)
|
||||||
|
base_b = rng.randint(80, 200)
|
||||||
|
for y in range(self.height):
|
||||||
|
ratio = y / max(self.height - 1, 1)
|
||||||
|
r = int(base_r + 40 * ratio)
|
||||||
|
g = int(base_g - 20 * ratio)
|
||||||
|
b = int(base_b + 20 * ratio)
|
||||||
|
r, g, b = max(0, min(255, r)), max(0, min(255, g)), max(0, min(255, b))
|
||||||
|
draw.line([(0, y), (self.width, y)], fill=(r, g, b))
|
||||||
|
|
||||||
|
# 纹理噪声
|
||||||
|
for _ in range(self.width * self.height // 6):
|
||||||
|
x = rng.randint(0, self.width - 1)
|
||||||
|
y = rng.randint(0, self.height - 1)
|
||||||
|
pixel = img.getpixel((x, y))
|
||||||
|
noise = tuple(
|
||||||
|
max(0, min(255, c + rng.randint(-30, 30)))
|
||||||
|
for c in pixel
|
||||||
|
)
|
||||||
|
draw.point((x, y), fill=noise)
|
||||||
|
|
||||||
|
# 随机色块 (模拟图案)
|
||||||
|
for _ in range(rng.randint(4, 8)):
|
||||||
|
x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
|
||||||
|
x2, y2 = x1 + rng.randint(15, 50), y1 + rng.randint(10, 30)
|
||||||
|
color = tuple(rng.randint(50, 230) for _ in range(3))
|
||||||
|
draw.rectangle([x1, y1, x2, y2], fill=color)
|
||||||
|
|
||||||
|
# 随机圆形
|
||||||
|
for _ in range(rng.randint(2, 5)):
|
||||||
|
cx = rng.randint(10, self.width - 10)
|
||||||
|
cy = rng.randint(10, self.height - 10)
|
||||||
|
cr = rng.randint(5, 20)
|
||||||
|
color = tuple(rng.randint(50, 230) for _ in range(3))
|
||||||
|
draw.ellipse([cx - cr, cy - cr, cx + cr, cy + cr], fill=color)
|
||||||
|
|
||||||
|
return img
|
||||||
@@ -18,11 +18,14 @@ from config import (
|
|||||||
THREED_CHARS,
|
THREED_CHARS,
|
||||||
NUM_CAPTCHA_TYPES,
|
NUM_CAPTCHA_TYPES,
|
||||||
REGRESSION_RANGE,
|
REGRESSION_RANGE,
|
||||||
|
SOLVER_CONFIG,
|
||||||
)
|
)
|
||||||
from models.classifier import CaptchaClassifier
|
from models.classifier import CaptchaClassifier
|
||||||
from models.lite_crnn import LiteCRNN
|
from models.lite_crnn import LiteCRNN
|
||||||
from models.threed_cnn import ThreeDCNN
|
from models.threed_cnn import ThreeDCNN
|
||||||
from models.regression_cnn import RegressionCNN
|
from models.regression_cnn import RegressionCNN
|
||||||
|
from models.gap_detector import GapDetectorCNN
|
||||||
|
from models.rotation_regressor import RotationRegressor
|
||||||
|
|
||||||
|
|
||||||
def export_model(
|
def export_model(
|
||||||
@@ -52,7 +55,7 @@ def export_model(
|
|||||||
dummy = torch.randn(1, *input_shape)
|
dummy = torch.randn(1, *input_shape)
|
||||||
|
|
||||||
# 分类器和识别器的 dynamic_axes 不同
|
# 分类器和识别器的 dynamic_axes 不同
|
||||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
|
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider", "gap_detector", "rotation_regressor"):
|
||||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||||
else:
|
else:
|
||||||
# CTC 模型: output shape = (T, B, C)
|
# CTC 模型: output shape = (T, B, C)
|
||||||
@@ -110,6 +113,14 @@ def _load_and_export(model_name: str):
|
|||||||
h, w = IMAGE_SIZE["3d_slider"]
|
h, w = IMAGE_SIZE["3d_slider"]
|
||||||
model = RegressionCNN(img_h=h, img_w=w)
|
model = RegressionCNN(img_h=h, img_w=w)
|
||||||
input_shape = (1, h, w)
|
input_shape = (1, h, w)
|
||||||
|
elif model_name == "gap_detector":
|
||||||
|
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
|
||||||
|
model = GapDetectorCNN(img_h=h, img_w=w)
|
||||||
|
input_shape = (1, h, w)
|
||||||
|
elif model_name == "rotation_regressor":
|
||||||
|
h, w = SOLVER_CONFIG["rotate"]["input_size"]
|
||||||
|
model = RotationRegressor(img_h=h, img_w=w)
|
||||||
|
input_shape = (3, h, w)
|
||||||
else:
|
else:
|
||||||
print(f"[错误] 未知模型: {model_name}")
|
print(f"[错误] 未知模型: {model_name}")
|
||||||
return
|
return
|
||||||
@@ -119,11 +130,15 @@ def _load_and_export(model_name: str):
|
|||||||
|
|
||||||
|
|
||||||
def export_all():
|
def export_all():
|
||||||
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
|
"""依次导出全部模型 (含 solver 模型)。"""
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
print("导出全部 ONNX 模型")
|
print("导出全部 ONNX 模型")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
|
for name in [
|
||||||
|
"classifier", "normal", "math", "threed_text",
|
||||||
|
"threed_rotate", "threed_slider",
|
||||||
|
"gap_detector", "rotation_regressor",
|
||||||
|
]:
|
||||||
_load_and_export(name)
|
_load_and_export(name)
|
||||||
print("\n全部导出完成。")
|
print("\n全部导出完成。")
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,27 @@
|
|||||||
"""
|
"""
|
||||||
模型定义包
|
模型定义包
|
||||||
|
|
||||||
提供四种模型:
|
提供六种模型:
|
||||||
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
|
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
|
||||||
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
|
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
|
||||||
- ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
|
- ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
|
||||||
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
|
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
|
||||||
|
- GapDetectorCNN: 滑块缺口检测 CNN (~1MB)
|
||||||
|
- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from models.classifier import CaptchaClassifier
|
from models.classifier import CaptchaClassifier
|
||||||
from models.lite_crnn import LiteCRNN
|
from models.lite_crnn import LiteCRNN
|
||||||
from models.threed_cnn import ThreeDCNN
|
from models.threed_cnn import ThreeDCNN
|
||||||
from models.regression_cnn import RegressionCNN
|
from models.regression_cnn import RegressionCNN
|
||||||
|
from models.gap_detector import GapDetectorCNN
|
||||||
|
from models.rotation_regressor import RotationRegressor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CaptchaClassifier",
|
"CaptchaClassifier",
|
||||||
"LiteCRNN",
|
"LiteCRNN",
|
||||||
"ThreeDCNN",
|
"ThreeDCNN",
|
||||||
"RegressionCNN",
|
"RegressionCNN",
|
||||||
|
"GapDetectorCNN",
|
||||||
|
"RotationRegressor",
|
||||||
]
|
]
|
||||||
|
|||||||
82
models/gap_detector.py
Normal file
82
models/gap_detector.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""
|
||||||
|
滑块缺口检测 CNN (GapDetectorCNN)
|
||||||
|
|
||||||
|
用于检测滑块验证码中缺口的 x 坐标位置。
|
||||||
|
输出 sigmoid 归一化到 [0,1],推理时按图片宽度缩放回像素坐标。
|
||||||
|
|
||||||
|
架构:
|
||||||
|
Conv(1→32) + BN + ReLU + Pool
|
||||||
|
Conv(32→64) + BN + ReLU + Pool
|
||||||
|
Conv(64→128) + BN + ReLU + Pool
|
||||||
|
Conv(128→128) + BN + ReLU + Pool
|
||||||
|
AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid
|
||||||
|
|
||||||
|
约 250K 参数,~1MB。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class GapDetectorCNN(nn.Module):
|
||||||
|
"""
|
||||||
|
滑块缺口检测 CNN,输出缺口 x 坐标的归一化百分比 [0,1]。
|
||||||
|
|
||||||
|
与 RegressionCNN 架构相同,但语义上专用于滑块缺口检测,
|
||||||
|
默认输入尺寸 1x128x256 (灰度)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, img_h: int = 128, img_w: int = 256):
|
||||||
|
super().__init__()
|
||||||
|
self.img_h = img_h
|
||||||
|
self.img_w = img_w
|
||||||
|
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
# block 1: 1 → 32, H/2, W/2
|
||||||
|
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(32),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 2: 32 → 64, H/4, W/4
|
||||||
|
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 3: 64 → 128, H/8, W/8
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(128),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 4: 128 → 128, H/16, W/16
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(128),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
|
||||||
|
self.regressor = nn.Sequential(
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(64, 1),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (batch, 1, H, W) 灰度图
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: (batch, 1) sigmoid 输出 [0, 1],表示缺口 x 坐标百分比
|
||||||
|
"""
|
||||||
|
feat = self.features(x)
|
||||||
|
feat = self.pool(feat) # (B, 128, 1, 1)
|
||||||
|
feat = feat.flatten(1) # (B, 128)
|
||||||
|
out = self.regressor(feat) # (B, 1)
|
||||||
|
return out
|
||||||
82
models/rotation_regressor.py
Normal file
82
models/rotation_regressor.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""
|
||||||
|
旋转角度回归模型 (RotationRegressor)
|
||||||
|
|
||||||
|
用于预测旋转验证码的正确旋转角度。
|
||||||
|
使用 sin/cos 编码避免 0°/360° 边界问题。
|
||||||
|
RGB 输入,输出 (sin θ, cos θ) ∈ [-1,1]。
|
||||||
|
|
||||||
|
架构:
|
||||||
|
Conv(3→32) + BN + ReLU + Pool
|
||||||
|
Conv(32→64) + BN + ReLU + Pool
|
||||||
|
Conv(64→128) + BN + ReLU + Pool
|
||||||
|
Conv(128→256) + BN + ReLU + Pool
|
||||||
|
AdaptiveAvgPool2d(1) → FC(256→128) → ReLU → FC(128→2) → Tanh
|
||||||
|
|
||||||
|
约 400K 参数,~2MB。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class RotationRegressor(nn.Module):
|
||||||
|
"""
|
||||||
|
旋转角度回归模型。
|
||||||
|
|
||||||
|
RGB 输入 3x128x128,输出 (sin θ, cos θ)。
|
||||||
|
推理时用 atan2(sin, cos) 转换为角度。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, img_h: int = 128, img_w: int = 128):
|
||||||
|
super().__init__()
|
||||||
|
self.img_h = img_h
|
||||||
|
self.img_w = img_w
|
||||||
|
|
||||||
|
self.features = nn.Sequential(
|
||||||
|
# block 1: 3 → 32, H/2, W/2
|
||||||
|
nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(32),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 2: 32 → 64, H/4, W/4
|
||||||
|
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 3: 64 → 128, H/8, W/8
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(128),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
|
||||||
|
# block 4: 128 → 256, H/16, W/16
|
||||||
|
nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
|
||||||
|
nn.BatchNorm2d(256),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||||
|
|
||||||
|
self.regressor = nn.Sequential(
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(128, 2),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (batch, 3, H, W) RGB 图
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
output: (batch, 2) → (sin θ, cos θ) ∈ [-1, 1]
|
||||||
|
"""
|
||||||
|
feat = self.features(x)
|
||||||
|
feat = self.pool(feat) # (B, 256, 1, 1)
|
||||||
|
feat = feat.flatten(1) # (B, 256)
|
||||||
|
out = self.regressor(feat) # (B, 2)
|
||||||
|
return out
|
||||||
@@ -20,6 +20,9 @@ server = [
|
|||||||
"uvicorn>=0.23.0",
|
"uvicorn>=0.23.0",
|
||||||
"python-multipart>=0.0.6",
|
"python-multipart>=0.0.6",
|
||||||
]
|
]
|
||||||
|
cv = [
|
||||||
|
"opencv-python>=4.8.0",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
captcha = "cli:main"
|
captcha = "cli:main"
|
||||||
|
|||||||
17
solvers/__init__.py
Normal file
17
solvers/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
验证码求解器包
|
||||||
|
|
||||||
|
提供两种交互式验证码求解器:
|
||||||
|
- SlideSolver: 滑块验证码求解 (OpenCV 优先 + CNN 兜底)
|
||||||
|
- RotateSolver: 旋转验证码求解 (ONNX sin/cos 回归)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from solvers.base import BaseSolver
|
||||||
|
from solvers.slide_solver import SlideSolver
|
||||||
|
from solvers.rotate_solver import RotateSolver
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseSolver",
|
||||||
|
"SlideSolver",
|
||||||
|
"RotateSolver",
|
||||||
|
]
|
||||||
21
solvers/base.py
Normal file
21
solvers/base.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
求解器基类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSolver:
|
||||||
|
"""验证码求解器基类。"""
|
||||||
|
|
||||||
|
def solve(self, image: Image.Image, **kwargs) -> dict:
|
||||||
|
"""
|
||||||
|
求解验证码。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 输入图片
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含求解结果的字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
80
solvers/rotate_solver.py
Normal file
80
solvers/rotate_solver.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""
|
||||||
|
旋转验证码求解器
|
||||||
|
|
||||||
|
ONNX 推理 → (sin, cos) → atan2 → 角度
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from config import ONNX_DIR, SOLVER_CONFIG
|
||||||
|
from solvers.base import BaseSolver
|
||||||
|
|
||||||
|
|
||||||
|
class RotateSolver(BaseSolver):
|
||||||
|
"""旋转验证码求解器。"""
|
||||||
|
|
||||||
|
def __init__(self, onnx_path: str | Path | None = None):
|
||||||
|
self.cfg = SOLVER_CONFIG["rotate"]
|
||||||
|
self._onnx_session = None
|
||||||
|
self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "rotation_regressor.onnx"
|
||||||
|
|
||||||
|
def _load_onnx(self):
|
||||||
|
"""延迟加载 ONNX 模型。"""
|
||||||
|
if self._onnx_session is not None:
|
||||||
|
return
|
||||||
|
if not self._onnx_path.exists():
|
||||||
|
raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
|
||||||
|
import onnxruntime as ort
|
||||||
|
self._onnx_session = ort.InferenceSession(
|
||||||
|
str(self._onnx_path), providers=["CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def solve(self, image: Image.Image | str | Path, **kwargs) -> dict:
|
||||||
|
"""
|
||||||
|
求解旋转验证码。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 输入图片 (RGB)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"angle": float, "confidence": float}
|
||||||
|
"""
|
||||||
|
if isinstance(image, (str, Path)):
|
||||||
|
image = Image.open(str(image)).convert("RGB")
|
||||||
|
else:
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
self._load_onnx()
|
||||||
|
|
||||||
|
h, w = self.cfg["input_size"]
|
||||||
|
|
||||||
|
# 预处理: RGB resize + normalize
|
||||||
|
img = image.resize((w, h))
|
||||||
|
arr = np.array(img, dtype=np.float32) / 255.0
|
||||||
|
# Normalize per channel: (x - 0.5) / 0.5
|
||||||
|
arr = (arr - 0.5) / 0.5
|
||||||
|
# HWC → CHW → NCHW
|
||||||
|
arr = arr.transpose(2, 0, 1)[np.newaxis, :, :, :]
|
||||||
|
|
||||||
|
outputs = self._onnx_session.run(None, {"input": arr})
|
||||||
|
sin_val = float(outputs[0][0][0])
|
||||||
|
cos_val = float(outputs[0][0][1])
|
||||||
|
|
||||||
|
# atan2 → 角度
|
||||||
|
angle_rad = math.atan2(sin_val, cos_val)
|
||||||
|
angle_deg = math.degrees(angle_rad)
|
||||||
|
if angle_deg < 0:
|
||||||
|
angle_deg += 360.0
|
||||||
|
|
||||||
|
# 置信度: sin^2 + cos^2 接近 1 表示预测稳定
|
||||||
|
magnitude = math.sqrt(sin_val ** 2 + cos_val ** 2)
|
||||||
|
confidence = min(magnitude, 1.0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"angle": round(angle_deg, 1),
|
||||||
|
"confidence": round(confidence, 3),
|
||||||
|
}
|
||||||
179
solvers/slide_solver.py
Normal file
179
solvers/slide_solver.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
滑块验证码求解器
|
||||||
|
|
||||||
|
三种求解方法 (按优先级):
|
||||||
|
1. 模板匹配: 背景图 + 模板图 → Canny → matchTemplate
|
||||||
|
2. 边缘检测: 单图 Canny → findContours → 筛选方形轮廓
|
||||||
|
3. CNN 兜底: ONNX 推理 → sigmoid → x 百分比 → 像素
|
||||||
|
|
||||||
|
OpenCV 延迟导入,未安装时退化到 CNN only。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from config import ONNX_DIR, SOLVER_CONFIG
|
||||||
|
from solvers.base import BaseSolver
|
||||||
|
|
||||||
|
|
||||||
|
class SlideSolver(BaseSolver):
|
||||||
|
"""滑块验证码求解器。"""
|
||||||
|
|
||||||
|
def __init__(self, onnx_path: str | Path | None = None):
|
||||||
|
self.cfg = SOLVER_CONFIG["slide"]
|
||||||
|
self._onnx_session = None
|
||||||
|
self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "gap_detector.onnx"
|
||||||
|
|
||||||
|
# 检测 OpenCV 可用性
|
||||||
|
self._cv2_available = False
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
self._cv2_available = True
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _load_onnx(self):
|
||||||
|
"""延迟加载 ONNX 模型。"""
|
||||||
|
if self._onnx_session is not None:
|
||||||
|
return
|
||||||
|
if not self._onnx_path.exists():
|
||||||
|
raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
|
||||||
|
import onnxruntime as ort
|
||||||
|
self._onnx_session = ort.InferenceSession(
|
||||||
|
str(self._onnx_path), providers=["CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def solve(
|
||||||
|
self,
|
||||||
|
bg_image: Image.Image | str | Path,
|
||||||
|
template_image: Image.Image | str | Path | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
求解滑块验证码。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bg_image: 背景图 (必需)
|
||||||
|
template_image: 模板/拼图块图 (可选,有则优先模板匹配)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"gap_x": int, "gap_x_percent": float, "confidence": float, "method": str}
|
||||||
|
"""
|
||||||
|
bg = self._load_image(bg_image)
|
||||||
|
|
||||||
|
# 方法 1: 模板匹配
|
||||||
|
if template_image is not None and self._cv2_available:
|
||||||
|
tpl = self._load_image(template_image)
|
||||||
|
result = self._template_match(bg, tpl)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 方法 2: 边缘检测
|
||||||
|
if self._cv2_available:
|
||||||
|
result = self._edge_detect(bg)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 方法 3: CNN 兜底
|
||||||
|
return self._cnn_predict(bg)
|
||||||
|
|
||||||
|
def _load_image(self, img: Image.Image | str | Path) -> Image.Image:
|
||||||
|
if isinstance(img, (str, Path)):
|
||||||
|
return Image.open(str(img)).convert("RGB")
|
||||||
|
return img.convert("RGB")
|
||||||
|
|
||||||
|
def _template_match(self, bg: Image.Image, tpl: Image.Image) -> dict | None:
|
||||||
|
"""模板匹配法。"""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
bg_gray = np.array(bg.convert("L"))
|
||||||
|
tpl_gray = np.array(tpl.convert("L"))
|
||||||
|
|
||||||
|
# Canny 边缘
|
||||||
|
bg_edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||||
|
tpl_edges = cv2.Canny(tpl_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||||
|
|
||||||
|
if tpl_edges.sum() == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = cv2.matchTemplate(bg_edges, tpl_edges, cv2.TM_CCOEFF_NORMED)
|
||||||
|
_, max_val, _, max_loc = cv2.minMaxLoc(result)
|
||||||
|
|
||||||
|
if max_val < 0.3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
gap_x = max_loc[0] + tpl_gray.shape[1] // 2
|
||||||
|
return {
|
||||||
|
"gap_x": int(gap_x),
|
||||||
|
"gap_x_percent": gap_x / bg_gray.shape[1],
|
||||||
|
"confidence": float(max_val),
|
||||||
|
"method": "template_match",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _edge_detect(self, bg: Image.Image) -> dict | None:
|
||||||
|
"""边缘检测法:找方形轮廓。"""
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
bg_gray = np.array(bg.convert("L"))
|
||||||
|
h, w = bg_gray.shape
|
||||||
|
|
||||||
|
edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||||
|
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
best = None
|
||||||
|
best_score = 0
|
||||||
|
|
||||||
|
for cnt in contours:
|
||||||
|
area = cv2.contourArea(cnt)
|
||||||
|
# 面积筛选: 缺口大小在合理范围
|
||||||
|
if area < (h * w * 0.005) or area > (h * w * 0.15):
|
||||||
|
continue
|
||||||
|
|
||||||
|
x, y, cw, ch = cv2.boundingRect(cnt)
|
||||||
|
aspect = min(cw, ch) / max(cw, ch) if max(cw, ch) > 0 else 0
|
||||||
|
# 近似方形
|
||||||
|
if aspect < 0.5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 评分: 面积适中 + 近似方形
|
||||||
|
score = aspect * (area / (h * w * 0.05))
|
||||||
|
if score > best_score:
|
||||||
|
best_score = score
|
||||||
|
best = (x + cw // 2, cw, ch, score)
|
||||||
|
|
||||||
|
if best is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
gap_x, _, _, score = best
|
||||||
|
return {
|
||||||
|
"gap_x": int(gap_x),
|
||||||
|
"gap_x_percent": gap_x / w,
|
||||||
|
"confidence": min(float(score), 1.0),
|
||||||
|
"method": "edge_detect",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _cnn_predict(self, bg: Image.Image) -> dict:
|
||||||
|
"""CNN 推理兜底。"""
|
||||||
|
self._load_onnx()
|
||||||
|
|
||||||
|
h, w = self.cfg["cnn_input_size"]
|
||||||
|
orig_w = bg.width
|
||||||
|
|
||||||
|
# 预处理: 灰度 + resize + normalize
|
||||||
|
img = bg.convert("L").resize((w, h))
|
||||||
|
arr = np.array(img, dtype=np.float32) / 255.0
|
||||||
|
arr = (arr - 0.5) / 0.5
|
||||||
|
arr = arr[np.newaxis, np.newaxis, :, :] # (1, 1, H, W)
|
||||||
|
|
||||||
|
outputs = self._onnx_session.run(None, {"input": arr})
|
||||||
|
percent = float(outputs[0][0][0])
|
||||||
|
|
||||||
|
gap_x = int(percent * orig_w)
|
||||||
|
return {
|
||||||
|
"gap_x": gap_x,
|
||||||
|
"gap_x_percent": percent,
|
||||||
|
"confidence": 0.5, # CNN 无置信度
|
||||||
|
"method": "cnn",
|
||||||
|
}
|
||||||
@@ -224,3 +224,55 @@ class RegressionDataset(Dataset):
|
|||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
return img, torch.tensor([label], dtype=torch.float32)
|
return img, torch.tensor([label], dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 旋转求解器用数据集 (sin/cos 编码)
|
||||||
|
# ============================================================
|
||||||
|
class RotateSolverDataset(Dataset):
|
||||||
|
"""
|
||||||
|
旋转求解器数据集。
|
||||||
|
|
||||||
|
从目录中读取 {angle}_{xxx}.png 文件,
|
||||||
|
将角度转换为 (sin θ, cos θ) 目标。
|
||||||
|
RGB 输入,不转灰度。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dirs: list[str | Path],
|
||||||
|
transform: transforms.Compose | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dirs: 数据目录列表
|
||||||
|
transform: 图片预处理/增强 (RGB)
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
self.transform = transform
|
||||||
|
self.samples: list[tuple[str, float, float]] = [] # (路径, sin, cos)
|
||||||
|
|
||||||
|
for d in dirs:
|
||||||
|
d = Path(d)
|
||||||
|
if not d.exists():
|
||||||
|
continue
|
||||||
|
for f in sorted(d.glob("*.png")):
|
||||||
|
raw_label = f.stem.rsplit("_", 1)[0]
|
||||||
|
try:
|
||||||
|
angle = float(raw_label)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
rad = math.radians(angle)
|
||||||
|
self.samples.append((str(f), math.sin(rad), math.cos(rad)))
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
import torch
|
||||||
|
path, sin_val, cos_val = self.samples[idx]
|
||||||
|
img = Image.open(path).convert("RGB")
|
||||||
|
if self.transform:
|
||||||
|
img = self.transform(img)
|
||||||
|
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
|
||||||
|
|
||||||
|
|||||||
245
training/train_rotate_solver.py
Normal file
245
training/train_rotate_solver.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""
|
||||||
|
训练旋转验证码角度回归模型 (RotationRegressor)
|
||||||
|
|
||||||
|
自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。
|
||||||
|
|
||||||
|
用法: python -m training.train_rotate_solver
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
SOLVER_CONFIG,
|
||||||
|
SOLVER_TRAIN_CONFIG,
|
||||||
|
ROTATE_SOLVER_DATA_DIR,
|
||||||
|
CHECKPOINTS_DIR,
|
||||||
|
ONNX_DIR,
|
||||||
|
ONNX_CONFIG,
|
||||||
|
AUGMENT_CONFIG,
|
||||||
|
RANDOM_SEED,
|
||||||
|
get_device,
|
||||||
|
)
|
||||||
|
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||||
|
from models.rotation_regressor import RotationRegressor
|
||||||
|
from training.dataset import RotateSolverDataset
|
||||||
|
|
||||||
|
|
||||||
|
def _set_seed(seed: int = RANDOM_SEED):
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
|
||||||
|
"""循环 MAE (度数)。"""
|
||||||
|
diff = np.abs(pred_angles - gt_angles)
|
||||||
|
diff = np.minimum(diff, 360.0 - diff)
|
||||||
|
return float(np.mean(diff))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||||
|
"""RGB 训练增强 (不转灰度)。"""
|
||||||
|
aug = AUGMENT_CONFIG
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.Resize((img_h, img_w)),
|
||||||
|
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
|
||||||
|
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||||
|
"""RGB 验证 transform。"""
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.Resize((img_h, img_w)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _export_onnx(model: nn.Module, img_h: int, img_w: int):
|
||||||
|
"""导出 ONNX (RGB 3通道输入)。"""
|
||||||
|
model.eval()
|
||||||
|
onnx_path = ONNX_DIR / "rotation_regressor.onnx"
|
||||||
|
dummy = torch.randn(1, 3, img_h, img_w)
|
||||||
|
torch.onnx.export(
|
||||||
|
model.cpu(),
|
||||||
|
dummy,
|
||||||
|
str(onnx_path),
|
||||||
|
opset_version=ONNX_CONFIG["opset_version"],
|
||||||
|
input_names=["input"],
|
||||||
|
output_names=["output"],
|
||||||
|
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
|
||||||
|
if ONNX_CONFIG["dynamic_batch"]
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cfg = SOLVER_TRAIN_CONFIG["rotate"]
|
||||||
|
solver_cfg = SOLVER_CONFIG["rotate"]
|
||||||
|
img_h, img_w = solver_cfg["input_size"]
|
||||||
|
tolerance = 5.0 # ±5°
|
||||||
|
|
||||||
|
_set_seed()
|
||||||
|
device = get_device()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("训练旋转验证码角度回归模型 (RotationRegressor)")
|
||||||
|
print(f" 输入尺寸: {img_h}×{img_w} RGB")
|
||||||
|
print(f" 编码: sin/cos")
|
||||||
|
print(f" 容差: ±{tolerance}°")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# ---- 1. 检查 / 生成合成数据 ----
|
||||||
|
syn_path = ROTATE_SOLVER_DATA_DIR
|
||||||
|
syn_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
existing = list(syn_path.glob("*.png"))
|
||||||
|
if len(existing) < cfg["synthetic_samples"]:
|
||||||
|
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||||
|
gen = RotateSolverDataGenerator()
|
||||||
|
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||||
|
else:
|
||||||
|
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||||
|
|
||||||
|
# ---- 2. 构建数据集 ----
|
||||||
|
data_dirs = [str(syn_path)]
|
||||||
|
real_dir = syn_path / "real"
|
||||||
|
real_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if list(real_dir.glob("*.png")):
|
||||||
|
data_dirs.append(str(real_dir))
|
||||||
|
print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))} 张")
|
||||||
|
|
||||||
|
train_transform = _build_train_transform(img_h, img_w)
|
||||||
|
val_transform = _build_val_transform(img_h, img_w)
|
||||||
|
|
||||||
|
full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
|
||||||
|
total = len(full_dataset)
|
||||||
|
val_size = int(total * cfg["val_split"])
|
||||||
|
train_size = total - val_size
|
||||||
|
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
|
||||||
|
|
||||||
|
val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
|
||||||
|
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||||
|
num_workers=0, pin_memory=True,
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||||
|
num_workers=0, pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||||
|
|
||||||
|
# ---- 3. 模型 / 优化器 / 调度器 ----
|
||||||
|
model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
|
||||||
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
|
||||||
|
loss_fn = nn.MSELoss()
|
||||||
|
|
||||||
|
best_mae = float("inf")
|
||||||
|
best_tol_acc = 0.0
|
||||||
|
ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
|
||||||
|
|
||||||
|
# ---- 4. 训练循环 ----
|
||||||
|
for epoch in range(1, cfg["epochs"] + 1):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||||
|
for images, targets in pbar:
|
||||||
|
images = images.to(device)
|
||||||
|
targets = targets.to(device) # (B, 2) → (sin, cos)
|
||||||
|
|
||||||
|
preds = model(images) # (B, 2)
|
||||||
|
loss = loss_fn(preds, targets)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
avg_loss = total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
|
# ---- 5. 验证 ----
|
||||||
|
model.eval()
|
||||||
|
all_pred_angles = []
|
||||||
|
all_gt_angles = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for images, targets in val_loader:
|
||||||
|
images = images.to(device)
|
||||||
|
preds = model(images).cpu().numpy() # (B, 2)
|
||||||
|
targets_np = targets.numpy() # (B, 2)
|
||||||
|
|
||||||
|
# sin/cos → angle
|
||||||
|
for i in range(len(preds)):
|
||||||
|
pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1]))
|
||||||
|
if pred_angle < 0:
|
||||||
|
pred_angle += 360.0
|
||||||
|
gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1]))
|
||||||
|
if gt_angle < 0:
|
||||||
|
gt_angle += 360.0
|
||||||
|
all_pred_angles.append(pred_angle)
|
||||||
|
all_gt_angles.append(gt_angle)
|
||||||
|
|
||||||
|
pred_arr = np.array(all_pred_angles)
|
||||||
|
gt_arr = np.array(all_gt_angles)
|
||||||
|
|
||||||
|
mae = _circular_mae_deg(pred_arr, gt_arr)
|
||||||
|
diff = np.abs(pred_arr - gt_arr)
|
||||||
|
diff = np.minimum(diff, 360.0 - diff)
|
||||||
|
tol_acc = float(np.mean(diff <= tolerance))
|
||||||
|
lr = scheduler.get_last_lr()[0]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Epoch {epoch:3d}/{cfg['epochs']} "
|
||||||
|
f"loss={avg_loss:.4f} "
|
||||||
|
f"MAE={mae:.2f}° "
|
||||||
|
f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
|
||||||
|
f"lr={lr:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- 6. 保存最佳模型 ----
|
||||||
|
if tol_acc >= best_tol_acc:
|
||||||
|
best_tol_acc = tol_acc
|
||||||
|
best_mae = mae
|
||||||
|
torch.save({
|
||||||
|
"model_state_dict": model.state_dict(),
|
||||||
|
"best_mae": best_mae,
|
||||||
|
"best_tol_acc": best_tol_acc,
|
||||||
|
"epoch": epoch,
|
||||||
|
}, ckpt_path)
|
||||||
|
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
|
||||||
|
|
||||||
|
# ---- 7. 导出 ONNX ----
|
||||||
|
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||||
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
|
_export_onnx(model, img_h, img_w)
|
||||||
|
|
||||||
|
return best_tol_acc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
65
training/train_slide.py
Normal file
65
training/train_slide.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
训练滑块缺口检测 CNN (GapDetectorCNN)
|
||||||
|
|
||||||
|
复用 train_regression_utils 的通用回归训练流程。
|
||||||
|
|
||||||
|
用法: python -m training.train_slide
|
||||||
|
"""
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
SOLVER_CONFIG,
|
||||||
|
SOLVER_TRAIN_CONFIG,
|
||||||
|
SOLVER_REGRESSION_RANGE,
|
||||||
|
SLIDE_DATA_DIR,
|
||||||
|
CHECKPOINTS_DIR,
|
||||||
|
ONNX_DIR,
|
||||||
|
ONNX_CONFIG,
|
||||||
|
RANDOM_SEED,
|
||||||
|
get_device,
|
||||||
|
)
|
||||||
|
from generators.slide_gen import SlideDataGenerator
|
||||||
|
from models.gap_detector import GapDetectorCNN
|
||||||
|
|
||||||
|
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
|
||||||
|
# 以便复用 train_regression_utils
|
||||||
|
import config as _cfg
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
solver_cfg = SOLVER_CONFIG["slide"]
|
||||||
|
train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
|
||||||
|
img_h, img_w = solver_cfg["cnn_input_size"]
|
||||||
|
|
||||||
|
model = GapDetectorCNN(img_h=img_h, img_w=img_w)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("训练滑块缺口检测 CNN (GapDetectorCNN)")
|
||||||
|
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||||
|
print(f" 任务: 预测缺口 x 坐标百分比")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 直接使用 train_regression_utils 中的逻辑
|
||||||
|
# 但需要临时注入配置
|
||||||
|
_cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg
|
||||||
|
_cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
|
||||||
|
_cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]
|
||||||
|
|
||||||
|
from training.train_regression_utils import train_regression_model
|
||||||
|
|
||||||
|
# 确保数据目录存在
|
||||||
|
SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
real_dir = SLIDE_DATA_DIR / "real"
|
||||||
|
real_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
train_regression_model(
|
||||||
|
model_name="gap_detector",
|
||||||
|
model=model,
|
||||||
|
synthetic_dir=str(SLIDE_DATA_DIR),
|
||||||
|
real_dir=str(real_dir),
|
||||||
|
generator_cls=SlideDataGenerator,
|
||||||
|
config_key="slide_cnn",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
工具函数包
|
||||||
|
"""
|
||||||
75
utils/slide_utils.py
Normal file
75
utils/slide_utils.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
滑块轨迹生成工具
|
||||||
|
|
||||||
|
生成模拟人类操作的滑块拖拽轨迹。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def generate_slide_track(
|
||||||
|
distance: int,
|
||||||
|
duration: float = 1.0,
|
||||||
|
seed: int | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
生成滑块拖拽轨迹。
|
||||||
|
|
||||||
|
使用贝塞尔曲线 ease-out 加速减速,末尾带微小过冲回退。
|
||||||
|
y 轴 ±1~3px 随机抖动,时间间隔不均匀。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distance: 滑动距离 (像素)
|
||||||
|
duration: 总时长 (秒)
|
||||||
|
seed: 随机种子
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[{"x": float, "y": float, "t": float}, ...]
|
||||||
|
"""
|
||||||
|
rng = random.Random(seed)
|
||||||
|
|
||||||
|
if distance <= 0:
|
||||||
|
return [{"x": 0.0, "y": 0.0, "t": 0.0}]
|
||||||
|
|
||||||
|
track = []
|
||||||
|
# 采样点数
|
||||||
|
num_points = rng.randint(30, 60)
|
||||||
|
total_ms = duration * 1000
|
||||||
|
|
||||||
|
# 生成不均匀时间点
|
||||||
|
raw_times = sorted([rng.random() for _ in range(num_points - 2)])
|
||||||
|
times = [0.0] + raw_times + [1.0]
|
||||||
|
|
||||||
|
# 过冲距离
|
||||||
|
overshoot = rng.uniform(2, 6)
|
||||||
|
overshoot_start = 0.85 # 85% 时到达目标 + 过冲
|
||||||
|
|
||||||
|
for t_norm in times:
|
||||||
|
t_ms = round(t_norm * total_ms, 1)
|
||||||
|
|
||||||
|
if t_norm <= overshoot_start:
|
||||||
|
# ease-out: 快速启动,缓慢减速
|
||||||
|
progress = t_norm / overshoot_start
|
||||||
|
eased = 1 - (1 - progress) ** 3 # cubic ease-out
|
||||||
|
x = eased * (distance + overshoot)
|
||||||
|
else:
|
||||||
|
# 过冲回退段
|
||||||
|
retract_progress = (t_norm - overshoot_start) / (1 - overshoot_start)
|
||||||
|
eased_retract = retract_progress ** 2 # ease-in 回退
|
||||||
|
x = (distance + overshoot) - overshoot * eased_retract
|
||||||
|
|
||||||
|
# y 轴随机抖动
|
||||||
|
y_jitter = rng.uniform(-3, 3) if t_norm > 0.05 else 0.0
|
||||||
|
|
||||||
|
track.append({
|
||||||
|
"x": round(x, 1),
|
||||||
|
"y": round(y_jitter, 1),
|
||||||
|
"t": t_ms,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 确保最后一个点精确到达目标
|
||||||
|
track[-1]["x"] = float(distance)
|
||||||
|
track[-1]["y"] = round(rng.uniform(-0.5, 0.5), 1)
|
||||||
|
|
||||||
|
return track
|
||||||
Reference in New Issue
Block a user