Add slide and rotate interactive captcha solvers

New solver subsystem with independent models:
- GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection
- RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction
- SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback
- RotateSolver with ONNX sin/cos → atan2 inference
- Generators, training scripts, CLI commands, and slide track utility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 18:07:06 +08:00
parent 90d6423551
commit 9b5f29083e
20 changed files with 1440 additions and 10 deletions

View File

@@ -33,7 +33,10 @@ captcha-breaker/
│ │ ├── 3d_text/
│ │ ├── 3d_rotate/
│ │ └── 3d_slider/
│ └── classifier/ # 调度分类器训练数据 (混合各类型)
│ ├── classifier/ # 调度分类器训练数据 (混合各类型)
│ └── solver/ # Solver 训练数据
│ ├── slide/ # 滑块缺口检测训练数据
│ └── rotate/ # 旋转角度回归训练数据
├── generators/
│ ├── __init__.py
│ ├── base.py # 生成器基类
@@ -41,13 +44,17 @@ captcha-breaker/
│ ├── math_gen.py # 算式验证码生成器 (如 3+8=?)
│ ├── threed_gen.py # 3D立体文字验证码生成器
│ ├── threed_rotate_gen.py # 3D旋转验证码生成器
│ └── threed_slider_gen.py # 3D滑块验证码生成器
│ ├── threed_slider_gen.py # 3D滑块验证码生成器
│ ├── slide_gen.py # 滑块缺口训练数据生成器
│ └── rotate_solver_gen.py # 旋转求解器训练数据生成器
├── models/
│ ├── __init__.py
│ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式)
│ ├── classifier.py # 调度分类模型
│ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN)
│ └── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
│ ├── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
│ ├── gap_detector.py # 滑块缺口检测CNN (~1MB)
│ └── rotation_regressor.py # 旋转角度回归 sin/cos (~2MB)
├── training/
│ ├── __init__.py
│ ├── train_classifier.py # 训练调度模型
@@ -56,6 +63,8 @@ captcha-breaker/
│ ├── train_3d_text.py # 训练3D文字识别
│ ├── train_3d_rotate.py # 训练3D旋转回归
│ ├── train_3d_slider.py # 训练3D滑块回归
│ ├── train_slide.py # 训练滑块缺口检测
│ ├── train_rotate_solver.py # 训练旋转角度回归
│ ├── train_utils.py # CTC 训练通用逻辑
│ ├── train_regression_utils.py # 回归训练通用逻辑
│ └── dataset.py # 通用 Dataset 类
@@ -64,20 +73,32 @@ captcha-breaker/
│ ├── pipeline.py # 核心推理流水线 (调度+识别)
│ ├── export_onnx.py # PyTorch → ONNX 导出脚本
│ └── math_eval.py # 算式计算模块
├── solvers/ # 交互式验证码求解器
│ ├── __init__.py
│ ├── base.py # 求解器基类
│ ├── slide_solver.py # 滑块求解 (OpenCV + CNN)
│ └── rotate_solver.py # 旋转求解 (ONNX sin/cos)
├── utils/
│ ├── __init__.py
│ └── slide_utils.py # 滑块轨迹生成工具
├── checkpoints/ # 训练产出的模型文件
│ ├── classifier.pth
│ ├── normal.pth
│ ├── math.pth
│ ├── threed_text.pth
│ ├── threed_rotate.pth
│ └── threed_slider.pth
│ ├── threed_slider.pth
│ ├── gap_detector.pth
│ └── rotation_regressor.pth
├── onnx_models/ # 导出的 ONNX 模型
│ ├── classifier.onnx
│ ├── normal.onnx
│ ├── math.onnx
│ ├── threed_text.onnx
│ ├── threed_rotate.onnx
│ └── threed_slider.onnx
│ ├── threed_slider.onnx
│ ├── gap_detector.onnx
│ └── rotation_regressor.onnx
├── server.py # FastAPI 推理服务 (可选)
├── cli.py # 命令行入口
└── tests/
@@ -462,3 +483,62 @@ uv run python cli.py serve --port 8080
6. 实现 cli.py 统一入口
7. 可选: server.py HTTP 服务
8. 编写 tests/
## 交互式 Solver 扩展
### 概述
在现有验证码识别架构之上,新增滑块 (slide) 和旋转 (rotate) 两种交互式验证码求解能力。与现有 3d_rotate/3d_slider 的区别:
- **3d_slider** (合成拼图回归) → **slide solver**: 真实滑块验证码,OpenCV 优先 + CNN 兜底
- **3d_rotate** (合成圆盘 sigmoid 回归) → **rotate solver**: 真实旋转验证码,sin/cos 编码 + 自然图
每个 solver 模型独立训练、独立导出 ONNX、独立替换互不依赖。
### 滑块求解器 (SlideSolver)
- 三种方法按优先级: 模板匹配 → 边缘检测 → CNN 兜底
- 模型: `GapDetectorCNN` (1x128x256 灰度 → sigmoid [0,1])
- OpenCV 延迟导入,未安装时退化到 CNN only
- 输出: `{"gap_x", "gap_x_percent", "confidence", "method"}`
### 旋转求解器 (RotateSolver)
- ONNX 推理 → (sin, cos) → atan2 → 角度
- 模型: `RotationRegressor` (3x128x128 RGB → tanh (sin θ, cos θ))
- 输出: `{"angle", "confidence"}`
### Solver CLI 用法
```bash
# 生成训练数据
uv run python cli.py generate-solver slide --num 30000
uv run python cli.py generate-solver rotate --num 50000
# 训练 (各模型独立)
uv run python cli.py train-solver slide
uv run python cli.py train-solver rotate
# 求解
uv run python cli.py solve slide --bg bg.png [--tpl tpl.png]
uv run python cli.py solve rotate --image img.png
# 导出 (已集成到 export --all)
uv run python cli.py export --model gap_detector
uv run python cli.py export --model rotation_regressor
```
### 滑块轨迹生成
`utils/slide_utils.py` 提供 `generate_slide_track(distance)`:
- 贝塞尔曲线 ease-out 加速减速
- y 轴 ±1~3px 随机抖动
- 时间间隔不均匀
- 末尾微小过冲回退
### Solver 目标指标
| 模型 | 准确率目标 | 推理延迟 | 模型体积 |
|------|-----------|---------|---------|
| 滑块 CNN (±5px) | > 85% | < 30ms | ~1MB |
| 旋转回归 (±5°) | > 85% | < 30ms | ~2MB |

108
cli.py
View File

@@ -13,6 +13,11 @@ CaptchaBreaker 命令行入口
python cli.py predict image.png --type normal
python cli.py predict-dir ./test_images/
python cli.py serve --port 8080
python cli.py generate-solver slide --num 30000
python cli.py train-solver slide
python cli.py train-solver rotate
python cli.py solve slide --bg bg.png [--tpl tpl.png]
python cli.py solve rotate --image img.png
"""
import argparse
@@ -195,6 +200,90 @@ def cmd_serve(args):
uvicorn.run(app, host=args.host, port=args.port)
def cmd_generate_solver(args):
    """Generate synthetic training data for a solver model.

    Args:
        args: argparse namespace with ``type`` ("slide" or "rotate") and
            ``num`` (number of images to generate).

    Exits with status 1 when the solver type is unknown.
    """
    solver_type = args.type
    num = args.num
    # Validate the type BEFORE importing config/generators so that a typo
    # fails fast instead of paying for (or crashing on) heavy imports.
    if solver_type not in ("slide", "rotate"):
        print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
        sys.exit(1)
    from config import SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR
    from generators.slide_gen import SlideDataGenerator
    from generators.rotate_solver_gen import RotateSolverDataGenerator
    gen_map = {
        "slide": (SlideDataGenerator, SLIDE_DATA_DIR),
        "rotate": (RotateSolverDataGenerator, ROTATE_SOLVER_DATA_DIR),
    }
    gen_cls, out_dir = gen_map[solver_type]
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"生成 solver/{solver_type} 数据: {num} 张 → {out_dir}")
    gen = gen_cls()
    gen.generate_dataset(num, str(out_dir))
def cmd_train_solver(args):
    """Train a solver model (slide gap detector or rotation regressor)."""
    solver_type = args.type
    if solver_type not in ("slide", "rotate"):
        print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
        sys.exit(1)
    # Import lazily so only the selected trainer's dependencies load.
    if solver_type == "slide":
        from training.train_slide import main as train_fn
    else:
        from training.train_rotate_solver import main as train_fn
    train_fn()
def cmd_solve(args):
    """Solve an interactive captcha (slide or rotate) from the CLI."""
    solver_type = args.type
    if solver_type == "slide":
        from solvers.slide_solver import SlideSolver
        bg_path = args.bg
        tpl_path = getattr(args, "tpl", None)
        if not Path(bg_path).exists():
            print(f"文件不存在: {bg_path}")
            sys.exit(1)
        outcome = SlideSolver().solve(bg_path, template_image=tpl_path)
        print(f"背景图: {bg_path}")
        if tpl_path:
            print(f"模板图: {tpl_path}")
        print(f"缺口 x: {outcome['gap_x']} px")
        print(f"缺口 x%: {outcome['gap_x_percent']:.4f}")
        print(f"置信度: {outcome['confidence']:.4f}")
        print(f"方法: {outcome['method']}")
        return
    if solver_type == "rotate":
        from solvers.rotate_solver import RotateSolver
        image_path = args.image
        if not Path(image_path).exists():
            print(f"文件不存在: {image_path}")
            sys.exit(1)
        outcome = RotateSolver().solve(image_path)
        print(f"图片: {image_path}")
        print(f"角度: {outcome['angle']}°")
        print(f"置信度: {outcome['confidence']}")
        return
    print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
    sys.exit(1)
def main():
parser = argparse.ArgumentParser(
prog="captcha-breaker",
@@ -247,6 +336,22 @@ def main():
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
# ---- generate-solver ----
p_gen_solver = subparsers.add_parser("generate-solver", help="生成 solver 训练数据")
p_gen_solver.add_argument("type", help="solver 类型: slide, rotate")
p_gen_solver.add_argument("--num", type=int, required=True, help="生成数量")
# ---- train-solver ----
p_train_solver = subparsers.add_parser("train-solver", help="训练 solver 模型")
p_train_solver.add_argument("type", help="solver 类型: slide, rotate")
# ---- solve ----
p_solve = subparsers.add_parser("solve", help="求解交互式验证码")
p_solve.add_argument("type", help="solver 类型: slide, rotate")
p_solve.add_argument("--bg", help="背景图路径 (slide 必需)")
p_solve.add_argument("--tpl", default=None, help="模板图路径 (slide 可选)")
p_solve.add_argument("--image", help="图片路径 (rotate 必需)")
args = parser.parse_args()
if args.command is None:
@@ -260,6 +365,9 @@ def main():
"predict": cmd_predict,
"predict-dir": cmd_predict_dir,
"serve": cmd_serve,
"generate-solver": cmd_generate_solver,
"train-solver": cmd_train_solver,
"solve": cmd_solve,
}
cmd_map[args.command](args)

View File

@@ -34,6 +34,11 @@ REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
# Solver training-data directories (one subdir per solver type).
SOLVER_DATA_DIR = DATA_DIR / "solver"
SLIDE_DATA_DIR = SOLVER_DATA_DIR / "slide"
ROTATE_SOLVER_DATA_DIR = SOLVER_DATA_DIR / "rotate"
# ============================================================
# 模型输出目录
# ============================================================
@@ -47,6 +52,7 @@ for _dir in [
REAL_NORMAL_DIR, REAL_MATH_DIR,
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR,
]:
_dir.mkdir(parents=True, exist_ok=True)
@@ -241,3 +247,40 @@ SERVER_CONFIG = {
"host": "0.0.0.0",
"port": 8080,
}
# ============================================================
# Solver configuration (interactive captcha solving)
# ============================================================
SOLVER_CONFIG = {
    "slide": {
        "canny_low": 50,  # Canny lower threshold used by SlideSolver
        "canny_high": 150,  # Canny upper threshold
        "cnn_input_size": (128, 256),  # (H, W) fed to GapDetectorCNN
    },
    "rotate": {
        "input_size": (128, 128),  # (H, W) fed to RotationRegressor
        "channels": 3,  # RGB
    },
}
# Training hyper-parameters for the solver models.
# NOTE(review): key is "slide_cnn" here but "slide" in SOLVER_CONFIG —
# confirm lookups use the right key.
SOLVER_TRAIN_CONFIG = {
    "slide_cnn": {
        "epochs": 50,
        "batch_size": 64,
        "lr": 1e-3,
        "synthetic_samples": 30000,
        "val_split": 0.1,
    },
    "rotate": {
        "epochs": 80,
        "batch_size": 64,
        "lr": 5e-4,
        "synthetic_samples": 50000,
        "val_split": 0.1,
    },
}
# Value range of each solver's regression target.
SOLVER_REGRESSION_RANGE = {
    "slide": (0, 1),  # normalized percentage of image width
    "rotate": (0, 360),  # angle in degrees
}

View File

@@ -1,12 +1,14 @@
"""
数据生成器包
提供五种验证码类型的数据生成器:
提供七种验证码类型的数据生成器:
- NormalCaptchaGenerator: 普通字符验证码
- MathCaptchaGenerator: 算式验证码
- ThreeDCaptchaGenerator: 3D 立体文字验证码
- ThreeDRotateGenerator: 3D 旋转验证码
- ThreeDSliderGenerator: 3D 滑块验证码
- SlideDataGenerator: 滑块验证码求解器训练数据
- RotateSolverDataGenerator: 旋转验证码求解器训练数据
"""
from generators.base import BaseCaptchaGenerator
@@ -15,6 +17,8 @@ from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from generators.slide_gen import SlideDataGenerator
from generators.rotate_solver_gen import RotateSolverDataGenerator
__all__ = [
"BaseCaptchaGenerator",
@@ -23,4 +27,6 @@ __all__ = [
"ThreeDCaptchaGenerator",
"ThreeDRotateGenerator",
"ThreeDSliderGenerator",
"SlideDataGenerator",
"RotateSolverDataGenerator",
]

View File

@@ -0,0 +1,156 @@
"""
旋转验证码求解器数据生成器
生成旋转验证码训练数据:随机图案 (色块/渐变/几何图形),随机旋转 0-359°。
裁剪为圆形 (黑色背景填充圆外区域)。
标签 = 旋转角度 (整数)
文件名格式: {angle}_{index:06d}.png
"""
import math
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import SOLVER_CONFIG
from generators.base import BaseCaptchaGenerator
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
]
class RotateSolverDataGenerator(BaseCaptchaGenerator):
    """Training-data generator for the rotation-captcha solver.

    Emits square RGB images containing a randomly rotated,
    direction-bearing pattern, cropped to a circle on a black
    background.  The label is the rotation angle in degrees (0-359);
    callers encode it in the filename as ``{angle}_{index:06d}.png``.
    """

    def __init__(self, seed: int | None = None):
        # Fall back to the project-wide seed so datasets are reproducible.
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)
        self.cfg = SOLVER_CONFIG["rotate"]
        self.height, self.width = self.cfg["input_size"]  # (H, W)
        # Probe every candidate font path once; keep only the loadable ones.
        self._fonts: list[str] = []
        for p in _FONT_PATHS:
            try:
                ImageFont.truetype(p, 20)
                self._fonts.append(p)
            except OSError:
                continue

    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        """Generate one sample.

        Args:
            text: optional label override; defaults to the angle as a string.

        Returns:
            (image, label) where label is the rotation angle in degrees.
        """
        rng = self.rng
        # Random rotation angle in [0, 359].
        angle = rng.randint(0, 359)
        if text is None:
            text = str(angle)
        size = self.width  # square canvas
        # 1. Draw the upright (unrotated) pattern.
        content = self._random_pattern(rng, size)
        # 2. Rotate it.  PIL rotates counter-clockwise for positive angles,
        #    so -angle turns the pattern clockwise by `angle` degrees.
        rotated = content.rotate(-angle, resample=Image.BICUBIC, expand=False)
        # 3. Crop to a circle (black outside the disc).
        result = Image.new("RGB", (size, size), (0, 0, 0))
        mask = Image.new("L", (size, size), 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.ellipse([0, 0, size - 1, size - 1], fill=255)
        result.paste(rotated, (0, 0), mask)
        # 4. Slight blur to soften rotation aliasing.
        result = result.filter(ImageFilter.GaussianBlur(radius=0.5))
        return result, text

    def _random_pattern(self, rng: random.Random, size: int) -> Image.Image:
        """Draw a random upright pattern with a clear visual orientation.

        The asymmetry (triangle-with-dot, arrow, glyph, or mixed shapes)
        is what lets the regressor infer the rotation angle.
        """
        img = Image.new("RGB", (size, size))
        draw = ImageDraw.Draw(img)
        # Vertical gradient background.
        base_r = rng.randint(100, 220)
        base_g = rng.randint(100, 220)
        base_b = rng.randint(100, 220)
        for y in range(size):
            ratio = y / max(size - 1, 1)
            r = int(base_r * (1 - ratio) + rng.randint(40, 120) * ratio)
            g = int(base_g * (1 - ratio) + rng.randint(40, 120) * ratio)
            b = int(base_b * (1 - ratio) + rng.randint(40, 120) * ratio)
            draw.line([(0, y), (size, y)], fill=(r, g, b))
        cx, cy = size // 2, size // 2
        # Add an asymmetric figure so the orientation is learnable.
        pattern_type = rng.choice(["triangle", "arrow", "text", "shapes"])
        if pattern_type == "triangle":
            # Triangle marker pointing up.
            color = tuple(rng.randint(180, 255) for _ in range(3))
            ts = size // 4
            draw.polygon(
                [(cx, cy - ts), (cx - ts // 2, cy), (cx + ts // 2, cy)],
                fill=color,
            )
            # Small disc below it.
            draw.ellipse(
                [cx - 8, cy + ts // 2, cx + 8, cy + ts // 2 + 16],
                fill=tuple(rng.randint(50, 150) for _ in range(3)),
            )
        elif pattern_type == "arrow":
            # Upward-pointing arrow.
            color = tuple(rng.randint(180, 255) for _ in range(3))
            arrow_len = size // 3
            draw.line([(cx, cy - arrow_len), (cx, cy + arrow_len // 2)], fill=color, width=4)
            draw.polygon(
                [(cx, cy - arrow_len - 5), (cx - 10, cy - arrow_len + 10), (cx + 10, cy - arrow_len + 10)],
                fill=color,
            )
        elif pattern_type == "text" and self._fonts:
            # A glyph (naturally orientation-bearing).
            font_path = rng.choice(self._fonts)
            font_size = size // 3
            try:
                font = ImageFont.truetype(font_path, font_size)
                ch = rng.choice("ABCDEFGHJKLMNPRSTUVWXYZ23456789")
                bbox = font.getbbox(ch)
                tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
                draw.text(
                    (cx - tw // 2 - bbox[0], cy - th // 2 - bbox[1]),
                    ch,
                    fill=tuple(rng.randint(0, 80) for _ in range(3)),
                    font=font,
                )
            except OSError:
                pass
        else:
            # Mixed asymmetric shapes.
            # Rectangle near the top.
            w, h = rng.randint(15, 30), rng.randint(10, 20)
            color = tuple(rng.randint(150, 255) for _ in range(3))
            draw.rectangle([cx - w, cy - size // 3, cx + w, cy - size // 3 + h], fill=color)
            # Small circle toward the bottom-right.
            r = rng.randint(5, 12)
            color2 = tuple(rng.randint(50, 150) for _ in range(3))
            draw.ellipse([cx + size // 5, cy + size // 5, cx + size // 5 + r * 2, cy + size // 5 + r * 2], fill=color2)
        # Texture noise.
        for _ in range(rng.randint(20, 60)):
            nx, ny = rng.randint(0, size - 1), rng.randint(0, size - 1)
            nc = tuple(rng.randint(80, 220) for _ in range(3))
            draw.point((nx, ny), fill=nc)
        return img

112
generators/slide_gen.py Normal file
View File

@@ -0,0 +1,112 @@
"""
滑块验证码数据生成器
生成滑块验证码训练数据:随机纹理/色块背景 + 方形缺口 + 阴影效果。
标签 = 缺口中心 x 坐标 (整数)
文件名格式: {gap_x}_{index:06d}.png
"""
import random
from PIL import Image, ImageDraw, ImageFilter
from config import SOLVER_CONFIG
from generators.base import BaseCaptchaGenerator
class SlideDataGenerator(BaseCaptchaGenerator):
    """Training-data generator for the slide-captcha gap detector.

    Produces textured background images with a square, semi-transparent
    "gap" overlay plus a drop shadow.

    NOTE(review): the module header says the label is the gap *center*
    x coordinate, but ``generate`` labels with ``gap_x``, the gap's
    LEFT edge — confirm which convention downstream training and
    solving code expect.
    """

    def __init__(self, seed: int | None = None):
        # Fall back to the project-wide seed so datasets are reproducible.
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)
        self.cfg = SOLVER_CONFIG["slide"]
        self.height, self.width = self.cfg["cnn_input_size"]  # (H, W)
        self.gap_size = 40  # gap square side length, px

    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        """Generate one sample; the label is the gap's left-edge x as a string."""
        rng = self.rng
        gs = self.gap_size
        # Keep the whole gap (and shadow) away from the image borders.
        margin = gs + 10
        gap_x = rng.randint(margin, self.width - margin)
        gap_y = rng.randint(10, self.height - gs - 10)
        if text is None:
            text = str(gap_x)
        # 1. Textured background.
        img = self._textured_background(rng)
        # 2. Draw the gap (semi-transparent grey square + shadow) on an overlay.
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        # Shadow, offset by 2 px.
        overlay_draw.rectangle(
            [gap_x + 2, gap_y + 2, gap_x + gs + 2, gap_y + gs + 2],
            fill=(0, 0, 0, 60),
        )
        # The gap itself.
        overlay_draw.rectangle(
            [gap_x, gap_y, gap_x + gs, gap_y + gs],
            fill=(80, 80, 80, 160),
            outline=(60, 60, 60, 200),
            width=2,
        )
        img = img.convert("RGBA")
        img = Image.alpha_composite(img, overlay)
        img = img.convert("RGB")
        # 3. Slight blur.
        img = img.filter(ImageFilter.GaussianBlur(radius=0.3))
        return img, text

    def _textured_background(self, rng: random.Random) -> Image.Image:
        """Build a colourful textured background: gradient + noise + shapes."""
        img = Image.new("RGB", (self.width, self.height))
        draw = ImageDraw.Draw(img)
        # Vertical gradient base colour.
        base_r = rng.randint(80, 200)
        base_g = rng.randint(80, 200)
        base_b = rng.randint(80, 200)
        for y in range(self.height):
            ratio = y / max(self.height - 1, 1)
            r = int(base_r + 40 * ratio)
            g = int(base_g - 20 * ratio)
            b = int(base_b + 20 * ratio)
            r, g, b = max(0, min(255, r)), max(0, min(255, g)), max(0, min(255, b))
            draw.line([(0, y), (self.width, y)], fill=(r, g, b))
        # Per-pixel noise on roughly one sixth of the pixels.
        for _ in range(self.width * self.height // 6):
            x = rng.randint(0, self.width - 1)
            y = rng.randint(0, self.height - 1)
            pixel = img.getpixel((x, y))
            noise = tuple(
                max(0, min(255, c + rng.randint(-30, 30)))
                for c in pixel
            )
            draw.point((x, y), fill=noise)
        # Random rectangles (imitating picture content).
        for _ in range(rng.randint(4, 8)):
            x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
            x2, y2 = x1 + rng.randint(15, 50), y1 + rng.randint(10, 30)
            color = tuple(rng.randint(50, 230) for _ in range(3))
            draw.rectangle([x1, y1, x2, y2], fill=color)
        # Random circles.
        for _ in range(rng.randint(2, 5)):
            cx = rng.randint(10, self.width - 10)
            cy = rng.randint(10, self.height - 10)
            cr = rng.randint(5, 20)
            color = tuple(rng.randint(50, 230) for _ in range(3))
            draw.ellipse([cx - cr, cy - cr, cx + cr, cy + cr], fill=color)
        return img

View File

@@ -18,11 +18,14 @@ from config import (
THREED_CHARS,
NUM_CAPTCHA_TYPES,
REGRESSION_RANGE,
SOLVER_CONFIG,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
def export_model(
@@ -52,7 +55,7 @@ def export_model(
dummy = torch.randn(1, *input_shape)
# 分类器和识别器的 dynamic_axes 不同
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider", "gap_detector", "rotation_regressor"):
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
else:
# CTC 模型: output shape = (T, B, C)
@@ -110,6 +113,14 @@ def _load_and_export(model_name: str):
h, w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "gap_detector":
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
model = GapDetectorCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "rotation_regressor":
h, w = SOLVER_CONFIG["rotate"]["input_size"]
model = RotationRegressor(img_h=h, img_w=w)
input_shape = (3, h, w)
else:
print(f"[错误] 未知模型: {model_name}")
return
@@ -119,11 +130,15 @@ def _load_and_export(model_name: str):
def export_all():
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
"""依次导出全部模型 (含 solver 模型)"""
print("=" * 50)
print("导出全部 ONNX 模型")
print("=" * 50)
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
for name in [
"classifier", "normal", "math", "threed_text",
"threed_rotate", "threed_slider",
"gap_detector", "rotation_regressor",
]:
_load_and_export(name)
print("\n全部导出完成。")

View File

@@ -1,21 +1,27 @@
"""
模型定义包
提供四种模型:
提供六种模型:
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
- ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
- GapDetectorCNN: 滑块缺口检测 CNN (~1MB)
- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB)
"""
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
__all__ = [
"CaptchaClassifier",
"LiteCRNN",
"ThreeDCNN",
"RegressionCNN",
"GapDetectorCNN",
"RotationRegressor",
]

82
models/gap_detector.py Normal file
View File

@@ -0,0 +1,82 @@
"""
滑块缺口检测 CNN (GapDetectorCNN)
用于检测滑块验证码中缺口的 x 坐标位置。
输出 sigmoid 归一化到 [0,1],推理时按图片宽度缩放回像素坐标。
架构:
Conv(1→32) + BN + ReLU + Pool
Conv(32→64) + BN + ReLU + Pool
Conv(64→128) + BN + ReLU + Pool
Conv(128→128) + BN + ReLU + Pool
AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid
约 250K 参数,~1MB。
"""
import torch
import torch.nn as nn
class GapDetectorCNN(nn.Module):
    """Slide-captcha gap detector.

    Predicts the gap's x position as a fraction of the image width in
    [0, 1] (sigmoid head); multiply by the pixel width at inference
    time.  Same topology as RegressionCNN but dedicated to gap
    detection; default input is a 1x128x256 grayscale image.  Roughly
    250K parameters (~1 MB).
    """

    # (in_channels, out_channels) per conv block; each block halves H and W,
    # so four blocks leave a H/16 x W/16 feature map.
    _BLOCK_CHANNELS = ((1, 32), (32, 64), (64, 128), (128, 128))

    def __init__(self, img_h: int = 128, img_w: int = 256):
        super().__init__()
        self.img_h = img_h
        self.img_w = img_w
        layers: list[nn.Module] = []
        for cin, cout in self._BLOCK_CHANNELS:
            layers += [
                nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a grayscale batch to gap-x fractions.

        Args:
            x: (batch, 1, H, W) grayscale images.

        Returns:
            (batch, 1) sigmoid outputs in [0, 1] — gap x as a width fraction.
        """
        out = self.features(x)
        out = self.pool(out).flatten(1)  # (B, 128)
        return self.regressor(out)       # (B, 1)

View File

@@ -0,0 +1,82 @@
"""
旋转角度回归模型 (RotationRegressor)
用于预测旋转验证码的正确旋转角度。
使用 sin/cos 编码避免 0°/360° 边界问题。
RGB 输入,输出 (sin θ, cos θ) ∈ [-1,1]。
架构:
Conv(3→32) + BN + ReLU + Pool
Conv(32→64) + BN + ReLU + Pool
Conv(64→128) + BN + ReLU + Pool
Conv(128→256) + BN + ReLU + Pool
AdaptiveAvgPool2d(1) → FC(256→128) → ReLU → FC(128→2) → Tanh
约 400K 参数,~2MB。
"""
import torch
import torch.nn as nn
class RotationRegressor(nn.Module):
    """Rotation-angle regressor.

    Takes a 3x128x128 RGB image and outputs (sin θ, cos θ) in [-1, 1]
    via a tanh head — the sin/cos encoding avoids the 0°/360°
    discontinuity.  Inference decodes the angle with atan2(sin, cos).
    Roughly 400K parameters (~2 MB).
    """

    # (in_channels, out_channels) per conv block; each block halves H and W.
    _BLOCK_CHANNELS = ((3, 32), (32, 64), (64, 128), (128, 256))

    def __init__(self, img_h: int = 128, img_w: int = 128):
        super().__init__()
        self.img_h = img_h
        self.img_w = img_w
        layers: list[nn.Module] = []
        for cin, cout in self._BLOCK_CHANNELS:
            layers += [
                nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an RGB batch to (sin θ, cos θ) pairs.

        Args:
            x: (batch, 3, H, W) RGB images.

        Returns:
            (batch, 2) tanh outputs — (sin θ, cos θ), each in [-1, 1].
        """
        out = self.features(x)
        out = self.pool(out).flatten(1)  # (B, 256)
        return self.regressor(out)       # (B, 2)

View File

@@ -20,6 +20,9 @@ server = [
"uvicorn>=0.23.0",
"python-multipart>=0.0.6",
]
cv = [
"opencv-python>=4.8.0",
]
[project.scripts]
captcha = "cli:main"

17
solvers/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
"""
验证码求解器包
提供两种交互式验证码求解器:
- SlideSolver: 滑块验证码求解 (OpenCV 优先 + CNN 兜底)
- RotateSolver: 旋转验证码求解 (ONNX sin/cos 回归)
"""
from solvers.base import BaseSolver
from solvers.slide_solver import SlideSolver
from solvers.rotate_solver import RotateSolver
__all__ = [
"BaseSolver",
"SlideSolver",
"RotateSolver",
]

21
solvers/base.py Normal file
View File

@@ -0,0 +1,21 @@
"""
求解器基类
"""
from PIL import Image
class BaseSolver:
    """Common interface for interactive-captcha solvers."""

    def solve(self, image: Image.Image, **kwargs) -> dict:
        """Solve a captcha.

        Args:
            image: input image.

        Returns:
            A dict describing the solution; keys are solver-specific.

        Raises:
            NotImplementedError: always, in the base class — subclasses
                must override.
        """
        raise NotImplementedError

80
solvers/rotate_solver.py Normal file
View File

@@ -0,0 +1,80 @@
"""
旋转验证码求解器
ONNX 推理 → (sin, cos) → atan2 → 角度
"""
import math
from pathlib import Path
import numpy as np
from PIL import Image
from config import ONNX_DIR, SOLVER_CONFIG
from solvers.base import BaseSolver
class RotateSolver(BaseSolver):
    """Rotation-captcha solver backed by an ONNX sin/cos regressor."""

    def __init__(self, onnx_path: str | Path | None = None):
        self.cfg = SOLVER_CONFIG["rotate"]
        self._onnx_session = None  # created lazily on first solve
        self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "rotation_regressor.onnx"

    def _load_onnx(self):
        """Create the ONNX session on first use (lazy)."""
        if self._onnx_session is not None:
            return
        if not self._onnx_path.exists():
            raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
        import onnxruntime as ort
        self._onnx_session = ort.InferenceSession(
            str(self._onnx_path), providers=["CPUExecutionProvider"]
        )

    def solve(self, image: Image.Image | str | Path, **kwargs) -> dict:
        """Solve a rotation captcha.

        Args:
            image: RGB image, or a path to one.

        Returns:
            {"angle": float (degrees in [0, 360)), "confidence": float}
        """
        if isinstance(image, (str, Path)):
            rgb = Image.open(str(image)).convert("RGB")
        else:
            rgb = image.convert("RGB")
        self._load_onnx()
        h, w = self.cfg["input_size"]
        # Preprocess: resize, scale to [0, 1], then normalize to [-1, 1].
        data = np.array(rgb.resize((w, h)), dtype=np.float32) / 255.0
        data = (data - 0.5) / 0.5
        # HWC -> CHW -> NCHW
        batch = data.transpose(2, 0, 1)[np.newaxis, :, :, :]
        pred = self._onnx_session.run(None, {"input": batch})[0][0]
        sin_val = float(pred[0])
        cos_val = float(pred[1])
        # Decode angle via atan2; wrap negatives into [0, 360).
        angle_deg = math.degrees(math.atan2(sin_val, cos_val)) % 360.0
        # A (sin, cos) vector near unit length indicates a stable prediction.
        confidence = min(math.hypot(sin_val, cos_val), 1.0)
        return {
            "angle": round(angle_deg, 1),
            "confidence": round(confidence, 3),
        }

179
solvers/slide_solver.py Normal file
View File

@@ -0,0 +1,179 @@
"""
滑块验证码求解器
三种求解方法 (按优先级):
1. 模板匹配: 背景图 + 模板图 → Canny → matchTemplate
2. 边缘检测: 单图 Canny → findContours → 筛选方形轮廓
3. CNN 兜底: ONNX 推理 → sigmoid → x 百分比 → 像素
OpenCV 延迟导入,未安装时退化到 CNN only。
"""
from pathlib import Path
import numpy as np
from PIL import Image
from config import ONNX_DIR, SOLVER_CONFIG
from solvers.base import BaseSolver
class SlideSolver(BaseSolver):
    """Slide-captcha solver.

    Tries three strategies in priority order:
      1. template matching (needs a puzzle-piece template and OpenCV),
      2. edge detection for a roughly square contour (OpenCV),
      3. CNN fallback through the gap_detector ONNX model.

    OpenCV is probed once at construction; without it the solver degrades
    to CNN-only.

    NOTE(review): the template-match path returns the gap's horizontal
    CENTER, while the CNN path returns percent * width whose meaning
    follows the training labels — confirm the two conventions agree.
    """

    def __init__(self, onnx_path: str | Path | None = None):
        # onnx_path: optional override for the gap_detector ONNX file.
        self.cfg = SOLVER_CONFIG["slide"]
        self._onnx_session = None  # created lazily in _load_onnx()
        self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "gap_detector.onnx"
        # Probe OpenCV availability once up front.
        self._cv2_available = False
        try:
            import cv2
            self._cv2_available = True
        except ImportError:
            pass

    def _load_onnx(self):
        """Create the ONNX session on first use (lazy)."""
        if self._onnx_session is not None:
            return
        if not self._onnx_path.exists():
            raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
        import onnxruntime as ort
        self._onnx_session = ort.InferenceSession(
            str(self._onnx_path), providers=["CPUExecutionProvider"]
        )

    def solve(
        self,
        bg_image: Image.Image | str | Path,
        template_image: Image.Image | str | Path | None = None,
        **kwargs,
    ) -> dict:
        """Solve a slide captcha.

        Args:
            bg_image: background image (required).
            template_image: puzzle-piece template (optional; when given and
                OpenCV is available, template matching is tried first).

        Returns:
            {"gap_x": int, "gap_x_percent": float, "confidence": float, "method": str}
        """
        bg = self._load_image(bg_image)
        # Method 1: template matching.
        if template_image is not None and self._cv2_available:
            tpl = self._load_image(template_image)
            result = self._template_match(bg, tpl)
            if result is not None:
                return result
        # Method 2: edge detection.
        if self._cv2_available:
            result = self._edge_detect(bg)
            if result is not None:
                return result
        # Method 3: CNN fallback.
        return self._cnn_predict(bg)

    def _load_image(self, img: Image.Image | str | Path) -> Image.Image:
        # Accept either a path or an already-open image; always return RGB.
        if isinstance(img, (str, Path)):
            return Image.open(str(img)).convert("RGB")
        return img.convert("RGB")

    def _template_match(self, bg: Image.Image, tpl: Image.Image) -> dict | None:
        """Template-matching strategy; returns None when the match is weak."""
        import cv2
        bg_gray = np.array(bg.convert("L"))
        tpl_gray = np.array(tpl.convert("L"))
        # Match on Canny edges, which are robust to colour differences.
        bg_edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
        tpl_edges = cv2.Canny(tpl_gray, self.cfg["canny_low"], self.cfg["canny_high"])
        if tpl_edges.sum() == 0:
            return None
        result = cv2.matchTemplate(bg_edges, tpl_edges, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(result)
        # Reject weak matches so the next strategy gets a chance.
        if max_val < 0.3:
            return None
        # max_loc is the match's top-left; shift to the template's center x.
        gap_x = max_loc[0] + tpl_gray.shape[1] // 2
        return {
            "gap_x": int(gap_x),
            "gap_x_percent": gap_x / bg_gray.shape[1],
            "confidence": float(max_val),
            "method": "template_match",
        }

    def _edge_detect(self, bg: Image.Image) -> dict | None:
        """Edge-detection strategy: look for a roughly square contour."""
        import cv2
        bg_gray = np.array(bg.convert("L"))
        h, w = bg_gray.shape
        edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        best = None
        best_score = 0
        for cnt in contours:
            area = cv2.contourArea(cnt)
            # Area filter: keep contours of plausible gap size (0.5%-15% of image).
            if area < (h * w * 0.005) or area > (h * w * 0.15):
                continue
            x, y, cw, ch = cv2.boundingRect(cnt)
            aspect = min(cw, ch) / max(cw, ch) if max(cw, ch) > 0 else 0
            # Require a roughly square bounding box.
            if aspect < 0.5:
                continue
            # Score: squareness times area relative to a "typical" gap (5% of image).
            score = aspect * (area / (h * w * 0.05))
            if score > best_score:
                best_score = score
                best = (x + cw // 2, cw, ch, score)
        if best is None:
            return None
        gap_x, _, _, score = best
        return {
            "gap_x": int(gap_x),
            "gap_x_percent": gap_x / w,
            "confidence": min(float(score), 1.0),
            "method": "edge_detect",
        }

    def _cnn_predict(self, bg: Image.Image) -> dict:
        """CNN fallback: ONNX sigmoid output → x fraction → pixel x."""
        self._load_onnx()
        h, w = self.cfg["cnn_input_size"]
        orig_w = bg.width
        # Preprocess: grayscale, resize to model input, normalize to [-1, 1].
        img = bg.convert("L").resize((w, h))
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - 0.5) / 0.5
        arr = arr[np.newaxis, np.newaxis, :, :]  # (1, 1, H, W)
        outputs = self._onnx_session.run(None, {"input": arr})
        percent = float(outputs[0][0][0])
        # Scale the fraction back to the ORIGINAL background width.
        gap_x = int(percent * orig_w)
        return {
            "gap_x": gap_x,
            "gap_x_percent": percent,
            "confidence": 0.5,  # the CNN head has no calibrated confidence
            "method": "cnn",
        }

View File

@@ -224,3 +224,55 @@ class RegressionDataset(Dataset):
img = self.transform(img)
return img, torch.tensor([label], dtype=torch.float32)
# ============================================================
# 旋转求解器用数据集 (sin/cos 编码)
# ============================================================
class RotateSolverDataset(Dataset):
    """Dataset for the rotation solver.

    Reads ``{angle}_{index}.png`` files from the given directories and
    encodes each angle as a (sin θ, cos θ) target.  Images stay RGB —
    no grayscale conversion.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: directories to scan for samples.
            transform: optional RGB preprocessing / augmentation.
        """
        import math
        self.transform = transform
        # Each entry: (file path, sin θ, cos θ).
        self.samples: list[tuple[str, float, float]] = []
        for directory in dirs:
            directory = Path(directory)
            if not directory.exists():
                continue
            for png in sorted(directory.glob("*.png")):
                # Filename pattern {angle}_{index}.png: the label is
                # everything before the last underscore.
                raw_label = png.stem.rsplit("_", 1)[0]
                try:
                    angle = float(raw_label)
                except ValueError:
                    continue
                theta = math.radians(angle)
                self.samples.append((str(png), math.sin(theta), math.cos(theta)))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        import torch
        path, sin_val, cos_val = self.samples[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor([sin_val, cos_val], dtype=torch.float32)

View File

@@ -0,0 +1,245 @@
"""
训练旋转验证码角度回归模型 (RotationRegressor)
自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。
用法: python -m training.train_rotate_solver
"""
import math
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
from config import (
SOLVER_CONFIG,
SOLVER_TRAIN_CONFIG,
ROTATE_SOLVER_DATA_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
AUGMENT_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.rotate_solver_gen import RotateSolverDataGenerator
from models.rotation_regressor import RotationRegressor
from training.dataset import RotateSolverDataset
def _set_seed(seed: int = RANDOM_SEED):
    """Seed every RNG source (python, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
    """Mean absolute angular error in degrees, wrapped around the circle."""
    gap = np.abs(pred_angles - gt_angles)
    # An absolute difference above 180° is shorter the other way around.
    wrapped = np.where(gap > 180.0, 360.0 - gap, gap)
    return float(wrapped.mean())
def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Training-time RGB augmentation pipeline (no grayscale conversion)."""
    aug = AUGMENT_CONFIG
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]
    return transforms.Compose(steps)
def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Validation-time RGB transform: resize and normalize only, no augmentation."""
    return transforms.Compose([
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
def _export_onnx(model: nn.Module, img_h: int, img_w: int):
    """Export the trained model to ONNX (3-channel RGB input)."""
    model.eval()
    onnx_path = ONNX_DIR / "rotation_regressor.onnx"
    sample = torch.randn(1, 3, img_h, img_w)
    # Allow a variable batch dimension only when the config asks for it.
    axes = None
    if ONNX_CONFIG["dynamic_batch"]:
        axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    torch.onnx.export(
        model.cpu(),
        sample,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
def main():
    """Train the RotationRegressor (sin/cos angle encoding) and export ONNX.

    Steps: (1) generate synthetic data if there is not enough on disk,
    (2) build train/val datasets — validation re-uses the split indices with
    a clean, augmentation-free transform, (3) train with MSE loss on
    (sin θ, cos θ) targets, (4) validate with circular MAE and ±5°
    tolerance accuracy, (5) checkpoint whenever tolerance accuracy improves,
    (6) reload the best checkpoint and export it to ONNX.

    Returns:
        Best validation tolerance accuracy reached during training.
    """
    cfg = SOLVER_TRAIN_CONFIG["rotate"]
    solver_cfg = SOLVER_CONFIG["rotate"]
    img_h, img_w = solver_cfg["input_size"]
    tolerance = 5.0  # ±5°: predictions within this of ground truth count as correct
    _set_seed()
    device = get_device()
    print("=" * 60)
    print("训练旋转验证码角度回归模型 (RotationRegressor)")
    print(f" 输入尺寸: {img_h}×{img_w} RGB")
    print(f" 编码: sin/cos")
    print(f" 容差: ±{tolerance}°")
    print("=" * 60)
    # ---- 1. Check for / generate synthetic data ----
    syn_path = ROTATE_SOLVER_DATA_DIR
    syn_path.mkdir(parents=True, exist_ok=True)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
        gen = RotateSolverDataGenerator()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[数据] 合成数据已就绪: {len(existing)}")
    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_dir = syn_path / "real"
    real_dir.mkdir(parents=True, exist_ok=True)
    if list(real_dir.glob("*.png")):
        # Mix in real captured samples when any are present.
        data_dirs.append(str(real_dir))
        print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))}")
    train_transform = _build_train_transform(img_h, img_w)
    val_transform = _build_val_transform(img_h, img_w)
    full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Re-wrap the validation split with the augmentation-free transform so
    # validation metrics are not distorted by training-time augmentation.
    val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Model / optimizer / scheduler ----
    model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    loss_fn = nn.MSELoss()
    best_mae = float("inf")
    best_tol_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)  # (B, 2) → (sin, cos)
            preds = model(images)  # (B, 2)
            loss = loss_fn(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to keep training stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        all_pred_angles = []
        all_gt_angles = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images).cpu().numpy()  # (B, 2)
                targets_np = targets.numpy()  # (B, 2)
                # Decode sin/cos back to an angle in [0, 360).
                for i in range(len(preds)):
                    pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1]))
                    if pred_angle < 0:
                        pred_angle += 360.0
                    gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1]))
                    if gt_angle < 0:
                        gt_angle += 360.0
                    all_pred_angles.append(pred_angle)
                    all_gt_angles.append(gt_angle)
        pred_arr = np.array(all_pred_angles)
        gt_arr = np.array(all_gt_angles)
        mae = _circular_mae_deg(pred_arr, gt_arr)
        diff = np.abs(pred_arr - gt_arr)
        diff = np.minimum(diff, 360.0 - diff)  # wrap-around distance on the circle
        tol_acc = float(np.mean(diff <= tolerance))
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f}° "
            f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save the best model (selected by tolerance accuracy) ----
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae
            torch.save({
                "model_state_dict": model.state_dict(),
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
    # ---- 7. Export ONNX from the best checkpoint ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, img_h, img_w)
    return best_tol_acc


if __name__ == "__main__":
    main()

65
training/train_slide.py Normal file
View File

@@ -0,0 +1,65 @@
"""
训练滑块缺口检测 CNN (GapDetectorCNN)
复用 train_regression_utils 的通用回归训练流程。
用法: python -m training.train_slide
"""
from config import (
SOLVER_CONFIG,
SOLVER_TRAIN_CONFIG,
SOLVER_REGRESSION_RANGE,
SLIDE_DATA_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.slide_gen import SlideDataGenerator
from models.gap_detector import GapDetectorCNN
# Inject the solver config into TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
# so that train_regression_utils can be reused unchanged
import config as _cfg
def main():
    """Train the slide-gap detector CNN via the shared regression pipeline.

    Registers the solver's configuration under the "slide_cnn" key in the
    global config tables so train_regression_utils can be reused unchanged.
    """
    slide_cfg = SOLVER_CONFIG["slide"]
    cnn_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
    img_h, img_w = slide_cfg["cnn_input_size"]
    detector = GapDetectorCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("训练滑块缺口检测 CNN (GapDetectorCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(f" 任务: 预测缺口 x 坐标百分比")
    print("=" * 60)

    # Temporarily inject the solver config so the generic regression
    # utilities pick it up under the "slide_cnn" key.
    _cfg.TRAIN_CONFIG["slide_cnn"] = cnn_cfg
    _cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
    _cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]

    # Deferred import, kept after the config injection (original ordering).
    from training.train_regression_utils import train_regression_model

    # Make sure the synthetic and real data directories exist.
    SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
    real_data_dir = SLIDE_DATA_DIR / "real"
    real_data_dir.mkdir(parents=True, exist_ok=True)

    train_regression_model(
        model_name="gap_detector",
        model=detector,
        synthetic_dir=str(SLIDE_DATA_DIR),
        real_dir=str(real_data_dir),
        generator_cls=SlideDataGenerator,
        config_key="slide_cnn",
    )


if __name__ == "__main__":
    main()

3
utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
工具函数包
"""

75
utils/slide_utils.py Normal file
View File

@@ -0,0 +1,75 @@
"""
滑块轨迹生成工具
生成模拟人类操作的滑块拖拽轨迹。
"""
import math
import random
def generate_slide_track(
distance: int,
duration: float = 1.0,
seed: int | None = None,
) -> list[dict]:
"""
生成滑块拖拽轨迹。
使用贝塞尔曲线 ease-out 加速减速,末尾带微小过冲回退。
y 轴 ±1~3px 随机抖动,时间间隔不均匀。
Args:
distance: 滑动距离 (像素)
duration: 总时长 (秒)
seed: 随机种子
Returns:
[{"x": float, "y": float, "t": float}, ...]
"""
rng = random.Random(seed)
if distance <= 0:
return [{"x": 0.0, "y": 0.0, "t": 0.0}]
track = []
# 采样点数
num_points = rng.randint(30, 60)
total_ms = duration * 1000
# 生成不均匀时间点
raw_times = sorted([rng.random() for _ in range(num_points - 2)])
times = [0.0] + raw_times + [1.0]
# 过冲距离
overshoot = rng.uniform(2, 6)
overshoot_start = 0.85 # 85% 时到达目标 + 过冲
for t_norm in times:
t_ms = round(t_norm * total_ms, 1)
if t_norm <= overshoot_start:
# ease-out: 快速启动,缓慢减速
progress = t_norm / overshoot_start
eased = 1 - (1 - progress) ** 3 # cubic ease-out
x = eased * (distance + overshoot)
else:
# 过冲回退段
retract_progress = (t_norm - overshoot_start) / (1 - overshoot_start)
eased_retract = retract_progress ** 2 # ease-in 回退
x = (distance + overshoot) - overshoot * eased_retract
# y 轴随机抖动
y_jitter = rng.uniform(-3, 3) if t_norm > 0.05 else 0.0
track.append({
"x": round(x, 1),
"y": round(y_jitter, 1),
"t": t_ms,
})
# 确保最后一个点精确到达目标
track[-1]["x"] = float(distance)
track[-1]["y"] = round(rng.uniform(-0.5, 0.5), 1)
return track