Add slide and rotate interactive captcha solvers
New solver subsystem with independent models: - GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection - RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction - SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback - RotateSolver with ONNX sin/cos → atan2 inference - Generators, training scripts, CLI commands, and slide track utility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
90
CLAUDE.md
90
CLAUDE.md
@@ -33,7 +33,10 @@ captcha-breaker/
|
||||
│ │ ├── 3d_text/
|
||||
│ │ ├── 3d_rotate/
|
||||
│ │ └── 3d_slider/
|
||||
│ └── classifier/ # 调度分类器训练数据 (混合各类型)
|
||||
│ ├── classifier/ # 调度分类器训练数据 (混合各类型)
|
||||
│ └── solver/ # Solver 训练数据
|
||||
│ ├── slide/ # 滑块缺口检测训练数据
|
||||
│ └── rotate/ # 旋转角度回归训练数据
|
||||
├── generators/
|
||||
│ ├── __init__.py
|
||||
│ ├── base.py # 生成器基类
|
||||
@@ -41,13 +44,17 @@ captcha-breaker/
|
||||
│ ├── math_gen.py # 算式验证码生成器 (如 3+8=?)
|
||||
│ ├── threed_gen.py # 3D立体文字验证码生成器
|
||||
│ ├── threed_rotate_gen.py # 3D旋转验证码生成器
|
||||
│ └── threed_slider_gen.py # 3D滑块验证码生成器
|
||||
│ ├── threed_slider_gen.py # 3D滑块验证码生成器
|
||||
│ ├── slide_gen.py # 滑块缺口训练数据生成器
|
||||
│ └── rotate_solver_gen.py # 旋转求解器训练数据生成器
|
||||
├── models/
|
||||
│ ├── __init__.py
|
||||
│ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式)
|
||||
│ ├── classifier.py # 调度分类模型
|
||||
│ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN)
|
||||
│ └── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
|
||||
│ ├── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB)
|
||||
│ ├── gap_detector.py # 滑块缺口检测CNN (~1MB)
|
||||
│ └── rotation_regressor.py # 旋转角度回归 sin/cos (~2MB)
|
||||
├── training/
|
||||
│ ├── __init__.py
|
||||
│ ├── train_classifier.py # 训练调度模型
|
||||
@@ -56,6 +63,8 @@ captcha-breaker/
|
||||
│ ├── train_3d_text.py # 训练3D文字识别
|
||||
│ ├── train_3d_rotate.py # 训练3D旋转回归
|
||||
│ ├── train_3d_slider.py # 训练3D滑块回归
|
||||
│ ├── train_slide.py # 训练滑块缺口检测
|
||||
│ ├── train_rotate_solver.py # 训练旋转角度回归
|
||||
│ ├── train_utils.py # CTC 训练通用逻辑
|
||||
│ ├── train_regression_utils.py # 回归训练通用逻辑
|
||||
│ └── dataset.py # 通用 Dataset 类
|
||||
@@ -64,20 +73,32 @@ captcha-breaker/
|
||||
│ ├── pipeline.py # 核心推理流水线 (调度+识别)
|
||||
│ ├── export_onnx.py # PyTorch → ONNX 导出脚本
|
||||
│ └── math_eval.py # 算式计算模块
|
||||
├── solvers/ # 交互式验证码求解器
|
||||
│ ├── __init__.py
|
||||
│ ├── base.py # 求解器基类
|
||||
│ ├── slide_solver.py # 滑块求解 (OpenCV + CNN)
|
||||
│ └── rotate_solver.py # 旋转求解 (ONNX sin/cos)
|
||||
├── utils/
|
||||
│ ├── __init__.py
|
||||
│ └── slide_utils.py # 滑块轨迹生成工具
|
||||
├── checkpoints/ # 训练产出的模型文件
|
||||
│ ├── classifier.pth
|
||||
│ ├── normal.pth
|
||||
│ ├── math.pth
|
||||
│ ├── threed_text.pth
|
||||
│ ├── threed_rotate.pth
|
||||
│ └── threed_slider.pth
|
||||
│ ├── threed_slider.pth
|
||||
│ ├── gap_detector.pth
|
||||
│ └── rotation_regressor.pth
|
||||
├── onnx_models/ # 导出的 ONNX 模型
|
||||
│ ├── classifier.onnx
|
||||
│ ├── normal.onnx
|
||||
│ ├── math.onnx
|
||||
│ ├── threed_text.onnx
|
||||
│ ├── threed_rotate.onnx
|
||||
│ └── threed_slider.onnx
|
||||
│ ├── threed_slider.onnx
|
||||
│ ├── gap_detector.onnx
|
||||
│ └── rotation_regressor.onnx
|
||||
├── server.py # FastAPI 推理服务 (可选)
|
||||
├── cli.py # 命令行入口
|
||||
└── tests/
|
||||
@@ -462,3 +483,62 @@ uv run python cli.py serve --port 8080
|
||||
6. 实现 cli.py 统一入口
|
||||
7. 可选: server.py HTTP 服务
|
||||
8. 编写 tests/
|
||||
|
||||
## 交互式 Solver 扩展
|
||||
|
||||
### 概述
|
||||
|
||||
在现有验证码识别架构之上,新增滑块 (slide) 和旋转 (rotate) 两种交互式验证码求解能力。与现有 3d_rotate/3d_slider 的区别:
|
||||
|
||||
- **3d_slider** (合成拼图回归) → **slide solver**: 真实滑块验证码,OpenCV 优先 + CNN 兜底
|
||||
- **3d_rotate** (合成圆盘 sigmoid 回归) → **rotate solver**: 真实旋转验证码,sin/cos 编码 + 自然图
|
||||
|
||||
每个 solver 模型独立训练、独立导出 ONNX、独立替换,互不依赖。
|
||||
|
||||
### 滑块求解器 (SlideSolver)
|
||||
|
||||
- 三种方法按优先级: 模板匹配 → 边缘检测 → CNN 兜底
|
||||
- 模型: `GapDetectorCNN` (1x128x256 灰度 → sigmoid [0,1])
|
||||
- OpenCV 延迟导入,未安装时退化到 CNN only
|
||||
- 输出: `{"gap_x", "gap_x_percent", "confidence", "method"}`
|
||||
|
||||
### 旋转求解器 (RotateSolver)
|
||||
|
||||
- ONNX 推理 → (sin, cos) → atan2 → 角度
|
||||
- 模型: `RotationRegressor` (3x128x128 RGB → tanh (sin θ, cos θ))
|
||||
- 输出: `{"angle", "confidence"}`
|
||||
|
||||
### Solver CLI 用法
|
||||
|
||||
```bash
|
||||
# 生成训练数据
|
||||
uv run python cli.py generate-solver slide --num 30000
|
||||
uv run python cli.py generate-solver rotate --num 50000
|
||||
|
||||
# 训练 (各模型独立)
|
||||
uv run python cli.py train-solver slide
|
||||
uv run python cli.py train-solver rotate
|
||||
|
||||
# 求解
|
||||
uv run python cli.py solve slide --bg bg.png [--tpl tpl.png]
|
||||
uv run python cli.py solve rotate --image img.png
|
||||
|
||||
# 导出 (已集成到 export --all)
|
||||
uv run python cli.py export --model gap_detector
|
||||
uv run python cli.py export --model rotation_regressor
|
||||
```
|
||||
|
||||
### 滑块轨迹生成
|
||||
|
||||
`utils/slide_utils.py` 提供 `generate_slide_track(distance)`:
|
||||
- 贝塞尔曲线 ease-out 加速减速
|
||||
- y 轴 ±1~3px 随机抖动
|
||||
- 时间间隔不均匀
|
||||
- 末尾微小过冲回退
|
||||
|
||||
### Solver 目标指标
|
||||
|
||||
| 模型 | 准确率目标 | 推理延迟 | 模型体积 |
|
||||
|------|-----------|---------|---------|
|
||||
| 滑块 CNN (±5px) | > 85% | < 30ms | ~1MB |
|
||||
| 旋转回归 (±5°) | > 85% | < 30ms | ~2MB |
|
||||
|
||||
108
cli.py
108
cli.py
@@ -13,6 +13,11 @@ CaptchaBreaker 命令行入口
|
||||
python cli.py predict image.png --type normal
|
||||
python cli.py predict-dir ./test_images/
|
||||
python cli.py serve --port 8080
|
||||
python cli.py generate-solver slide --num 30000
|
||||
python cli.py train-solver slide
|
||||
python cli.py train-solver rotate
|
||||
python cli.py solve slide --bg bg.png [--tpl tpl.png]
|
||||
python cli.py solve rotate --image img.png
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -195,6 +200,90 @@ def cmd_serve(args):
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
def cmd_generate_solver(args):
|
||||
"""生成 solver 训练数据。"""
|
||||
from config import SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR
|
||||
from generators.slide_gen import SlideDataGenerator
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
|
||||
solver_type = args.type
|
||||
num = args.num
|
||||
|
||||
gen_map = {
|
||||
"slide": (SlideDataGenerator, SLIDE_DATA_DIR),
|
||||
"rotate": (RotateSolverDataGenerator, ROTATE_SOLVER_DATA_DIR),
|
||||
}
|
||||
|
||||
if solver_type not in gen_map:
|
||||
print(f"未知 solver 类型: {solver_type} 可选: {', '.join(gen_map.keys())}")
|
||||
sys.exit(1)
|
||||
|
||||
gen_cls, out_dir = gen_map[solver_type]
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"生成 solver/{solver_type} 数据: {num} 张 → {out_dir}")
|
||||
gen = gen_cls()
|
||||
gen.generate_dataset(num, str(out_dir))
|
||||
|
||||
|
||||
def cmd_train_solver(args):
|
||||
"""训练 solver 模型。"""
|
||||
solver_type = args.type
|
||||
|
||||
if solver_type == "slide":
|
||||
from training.train_slide import main as train_fn
|
||||
elif solver_type == "rotate":
|
||||
from training.train_rotate_solver import main as train_fn
|
||||
else:
|
||||
print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
|
||||
sys.exit(1)
|
||||
|
||||
train_fn()
|
||||
|
||||
|
||||
def cmd_solve(args):
|
||||
"""求解验证码。"""
|
||||
solver_type = args.type
|
||||
|
||||
if solver_type == "slide":
|
||||
from solvers.slide_solver import SlideSolver
|
||||
|
||||
bg_path = args.bg
|
||||
tpl_path = getattr(args, "tpl", None)
|
||||
if not Path(bg_path).exists():
|
||||
print(f"文件不存在: {bg_path}")
|
||||
sys.exit(1)
|
||||
|
||||
solver = SlideSolver()
|
||||
result = solver.solve(bg_path, template_image=tpl_path)
|
||||
|
||||
print(f"背景图: {bg_path}")
|
||||
if tpl_path:
|
||||
print(f"模板图: {tpl_path}")
|
||||
print(f"缺口 x: {result['gap_x']} px")
|
||||
print(f"缺口 x%: {result['gap_x_percent']:.4f}")
|
||||
print(f"置信度: {result['confidence']:.4f}")
|
||||
print(f"方法: {result['method']}")
|
||||
|
||||
elif solver_type == "rotate":
|
||||
from solvers.rotate_solver import RotateSolver
|
||||
|
||||
image_path = args.image
|
||||
if not Path(image_path).exists():
|
||||
print(f"文件不存在: {image_path}")
|
||||
sys.exit(1)
|
||||
|
||||
solver = RotateSolver()
|
||||
result = solver.solve(image_path)
|
||||
|
||||
print(f"图片: {image_path}")
|
||||
print(f"角度: {result['angle']}°")
|
||||
print(f"置信度: {result['confidence']}")
|
||||
|
||||
else:
|
||||
print(f"未知 solver 类型: {solver_type} 可选: slide, rotate")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="captcha-breaker",
|
||||
@@ -247,6 +336,22 @@ def main():
|
||||
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
|
||||
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
|
||||
|
||||
# ---- generate-solver ----
|
||||
p_gen_solver = subparsers.add_parser("generate-solver", help="生成 solver 训练数据")
|
||||
p_gen_solver.add_argument("type", help="solver 类型: slide, rotate")
|
||||
p_gen_solver.add_argument("--num", type=int, required=True, help="生成数量")
|
||||
|
||||
# ---- train-solver ----
|
||||
p_train_solver = subparsers.add_parser("train-solver", help="训练 solver 模型")
|
||||
p_train_solver.add_argument("type", help="solver 类型: slide, rotate")
|
||||
|
||||
# ---- solve ----
|
||||
p_solve = subparsers.add_parser("solve", help="求解交互式验证码")
|
||||
p_solve.add_argument("type", help="solver 类型: slide, rotate")
|
||||
p_solve.add_argument("--bg", help="背景图路径 (slide 必需)")
|
||||
p_solve.add_argument("--tpl", default=None, help="模板图路径 (slide 可选)")
|
||||
p_solve.add_argument("--image", help="图片路径 (rotate 必需)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None:
|
||||
@@ -260,6 +365,9 @@ def main():
|
||||
"predict": cmd_predict,
|
||||
"predict-dir": cmd_predict_dir,
|
||||
"serve": cmd_serve,
|
||||
"generate-solver": cmd_generate_solver,
|
||||
"train-solver": cmd_train_solver,
|
||||
"solve": cmd_solve,
|
||||
}
|
||||
|
||||
cmd_map[args.command](args)
|
||||
|
||||
43
config.py
43
config.py
@@ -34,6 +34,11 @@ REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
|
||||
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
|
||||
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
|
||||
|
||||
# Solver 数据目录
|
||||
SOLVER_DATA_DIR = DATA_DIR / "solver"
|
||||
SLIDE_DATA_DIR = SOLVER_DATA_DIR / "slide"
|
||||
ROTATE_SOLVER_DATA_DIR = SOLVER_DATA_DIR / "rotate"
|
||||
|
||||
# ============================================================
|
||||
# 模型输出目录
|
||||
# ============================================================
|
||||
@@ -47,6 +52,7 @@ for _dir in [
|
||||
REAL_NORMAL_DIR, REAL_MATH_DIR,
|
||||
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
|
||||
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
|
||||
SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR,
|
||||
]:
|
||||
_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -241,3 +247,40 @@ SERVER_CONFIG = {
|
||||
"host": "0.0.0.0",
|
||||
"port": 8080,
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
# Solver 配置 (交互式验证码求解)
|
||||
# ============================================================
|
||||
SOLVER_CONFIG = {
|
||||
"slide": {
|
||||
"canny_low": 50,
|
||||
"canny_high": 150,
|
||||
"cnn_input_size": (128, 256), # H, W
|
||||
},
|
||||
"rotate": {
|
||||
"input_size": (128, 128), # H, W
|
||||
"channels": 3, # RGB
|
||||
},
|
||||
}
|
||||
|
||||
SOLVER_TRAIN_CONFIG = {
|
||||
"slide_cnn": {
|
||||
"epochs": 50,
|
||||
"batch_size": 64,
|
||||
"lr": 1e-3,
|
||||
"synthetic_samples": 30000,
|
||||
"val_split": 0.1,
|
||||
},
|
||||
"rotate": {
|
||||
"epochs": 80,
|
||||
"batch_size": 64,
|
||||
"lr": 5e-4,
|
||||
"synthetic_samples": 50000,
|
||||
"val_split": 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
SOLVER_REGRESSION_RANGE = {
|
||||
"slide": (0, 1), # 归一化百分比
|
||||
"rotate": (0, 360), # 角度
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""
|
||||
数据生成器包
|
||||
|
||||
提供五种验证码类型的数据生成器:
|
||||
提供七种验证码类型的数据生成器:
|
||||
- NormalCaptchaGenerator: 普通字符验证码
|
||||
- MathCaptchaGenerator: 算式验证码
|
||||
- ThreeDCaptchaGenerator: 3D 立体文字验证码
|
||||
- ThreeDRotateGenerator: 3D 旋转验证码
|
||||
- ThreeDSliderGenerator: 3D 滑块验证码
|
||||
- SlideDataGenerator: 滑块验证码求解器训练数据
|
||||
- RotateSolverDataGenerator: 旋转验证码求解器训练数据
|
||||
"""
|
||||
|
||||
from generators.base import BaseCaptchaGenerator
|
||||
@@ -15,6 +17,8 @@ from generators.math_gen import MathCaptchaGenerator
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
from generators.slide_gen import SlideDataGenerator
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
|
||||
__all__ = [
|
||||
"BaseCaptchaGenerator",
|
||||
@@ -23,4 +27,6 @@ __all__ = [
|
||||
"ThreeDCaptchaGenerator",
|
||||
"ThreeDRotateGenerator",
|
||||
"ThreeDSliderGenerator",
|
||||
"SlideDataGenerator",
|
||||
"RotateSolverDataGenerator",
|
||||
]
|
||||
|
||||
156
generators/rotate_solver_gen.py
Normal file
156
generators/rotate_solver_gen.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
旋转验证码求解器数据生成器
|
||||
|
||||
生成旋转验证码训练数据:随机图案 (色块/渐变/几何图形),随机旋转 0-359°。
|
||||
裁剪为圆形 (黑色背景填充圆外区域)。
|
||||
|
||||
标签 = 旋转角度 (整数)
|
||||
文件名格式: {angle}_{index:06d}.png
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont
|
||||
|
||||
from config import SOLVER_CONFIG
|
||||
from generators.base import BaseCaptchaGenerator
|
||||
|
||||
_FONT_PATHS = [
|
||||
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
|
||||
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
|
||||
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
|
||||
]
|
||||
|
||||
|
||||
class RotateSolverDataGenerator(BaseCaptchaGenerator):
|
||||
"""旋转验证码求解器数据生成器。"""
|
||||
|
||||
def __init__(self, seed: int | None = None):
|
||||
from config import RANDOM_SEED
|
||||
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
|
||||
|
||||
self.cfg = SOLVER_CONFIG["rotate"]
|
||||
self.height, self.width = self.cfg["input_size"] # (H, W)
|
||||
|
||||
self._fonts: list[str] = []
|
||||
for p in _FONT_PATHS:
|
||||
try:
|
||||
ImageFont.truetype(p, 20)
|
||||
self._fonts.append(p)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
|
||||
rng = self.rng
|
||||
|
||||
# 随机旋转角度 0-359
|
||||
angle = rng.randint(0, 359)
|
||||
if text is None:
|
||||
text = str(angle)
|
||||
|
||||
size = self.width # 正方形
|
||||
radius = size // 2
|
||||
|
||||
# 1. 生成正向图案 (未旋转)
|
||||
content = self._random_pattern(rng, size)
|
||||
|
||||
# 2. 旋转图案
|
||||
rotated = content.rotate(-angle, resample=Image.BICUBIC, expand=False)
|
||||
|
||||
# 3. 裁剪为圆形 (黑色背景)
|
||||
result = Image.new("RGB", (size, size), (0, 0, 0))
|
||||
mask = Image.new("L", (size, size), 0)
|
||||
mask_draw = ImageDraw.Draw(mask)
|
||||
mask_draw.ellipse([0, 0, size - 1, size - 1], fill=255)
|
||||
|
||||
result.paste(rotated, (0, 0), mask)
|
||||
|
||||
# 4. 轻微模糊
|
||||
result = result.filter(ImageFilter.GaussianBlur(radius=0.5))
|
||||
|
||||
return result, text
|
||||
|
||||
def _random_pattern(self, rng: random.Random, size: int) -> Image.Image:
|
||||
"""生成随机图案 (带明显方向性,便于模型学习旋转)。"""
|
||||
img = Image.new("RGB", (size, size))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# 渐变背景
|
||||
base_r = rng.randint(100, 220)
|
||||
base_g = rng.randint(100, 220)
|
||||
base_b = rng.randint(100, 220)
|
||||
for y in range(size):
|
||||
ratio = y / max(size - 1, 1)
|
||||
r = int(base_r * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||
g = int(base_g * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||
b = int(base_b * (1 - ratio) + rng.randint(40, 120) * ratio)
|
||||
draw.line([(0, y), (size, y)], fill=(r, g, b))
|
||||
|
||||
cx, cy = size // 2, size // 2
|
||||
|
||||
# 添加不对称几何图形 (让模型能感知方向)
|
||||
pattern_type = rng.choice(["triangle", "arrow", "text", "shapes"])
|
||||
|
||||
if pattern_type == "triangle":
|
||||
# 顶部三角形标记
|
||||
color = tuple(rng.randint(180, 255) for _ in range(3))
|
||||
ts = size // 4
|
||||
draw.polygon(
|
||||
[(cx, cy - ts), (cx - ts // 2, cy), (cx + ts // 2, cy)],
|
||||
fill=color,
|
||||
)
|
||||
# 底部小圆
|
||||
draw.ellipse(
|
||||
[cx - 8, cy + ts // 2, cx + 8, cy + ts // 2 + 16],
|
||||
fill=tuple(rng.randint(50, 150) for _ in range(3)),
|
||||
)
|
||||
|
||||
elif pattern_type == "arrow":
|
||||
# 向上的箭头
|
||||
color = tuple(rng.randint(180, 255) for _ in range(3))
|
||||
arrow_len = size // 3
|
||||
draw.line([(cx, cy - arrow_len), (cx, cy + arrow_len // 2)], fill=color, width=4)
|
||||
draw.polygon(
|
||||
[(cx, cy - arrow_len - 5), (cx - 10, cy - arrow_len + 10), (cx + 10, cy - arrow_len + 10)],
|
||||
fill=color,
|
||||
)
|
||||
|
||||
elif pattern_type == "text" and self._fonts:
|
||||
# 文字 (有天然方向性)
|
||||
font_path = rng.choice(self._fonts)
|
||||
font_size = size // 3
|
||||
try:
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
ch = rng.choice("ABCDEFGHJKLMNPRSTUVWXYZ23456789")
|
||||
bbox = font.getbbox(ch)
|
||||
tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
draw.text(
|
||||
(cx - tw // 2 - bbox[0], cy - th // 2 - bbox[1]),
|
||||
ch,
|
||||
fill=tuple(rng.randint(0, 80) for _ in range(3)),
|
||||
font=font,
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
else:
|
||||
# 混合不对称形状
|
||||
# 上方矩形
|
||||
w, h = rng.randint(15, 30), rng.randint(10, 20)
|
||||
color = tuple(rng.randint(150, 255) for _ in range(3))
|
||||
draw.rectangle([cx - w, cy - size // 3, cx + w, cy - size // 3 + h], fill=color)
|
||||
# 右下小圆
|
||||
r = rng.randint(5, 12)
|
||||
color2 = tuple(rng.randint(50, 150) for _ in range(3))
|
||||
draw.ellipse([cx + size // 5, cy + size // 5, cx + size // 5 + r * 2, cy + size // 5 + r * 2], fill=color2)
|
||||
|
||||
# 添加纹理噪声
|
||||
for _ in range(rng.randint(20, 60)):
|
||||
nx, ny = rng.randint(0, size - 1), rng.randint(0, size - 1)
|
||||
nc = tuple(rng.randint(80, 220) for _ in range(3))
|
||||
draw.point((nx, ny), fill=nc)
|
||||
|
||||
return img
|
||||
112
generators/slide_gen.py
Normal file
112
generators/slide_gen.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
滑块验证码数据生成器
|
||||
|
||||
生成滑块验证码训练数据:随机纹理/色块背景 + 方形缺口 + 阴影效果。
|
||||
标签 = 缺口中心 x 坐标 (整数)
|
||||
文件名格式: {gap_x}_{index:06d}.png
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
|
||||
from config import SOLVER_CONFIG
|
||||
from generators.base import BaseCaptchaGenerator
|
||||
|
||||
|
||||
class SlideDataGenerator(BaseCaptchaGenerator):
|
||||
"""滑块验证码数据生成器。"""
|
||||
|
||||
def __init__(self, seed: int | None = None):
|
||||
from config import RANDOM_SEED
|
||||
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
|
||||
|
||||
self.cfg = SOLVER_CONFIG["slide"]
|
||||
self.height, self.width = self.cfg["cnn_input_size"] # (H, W)
|
||||
self.gap_size = 40 # 缺口大小
|
||||
|
||||
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
|
||||
rng = self.rng
|
||||
gs = self.gap_size
|
||||
|
||||
# 缺口 x 范围: 留出边距
|
||||
margin = gs + 10
|
||||
gap_x = rng.randint(margin, self.width - margin)
|
||||
gap_y = rng.randint(10, self.height - gs - 10)
|
||||
|
||||
if text is None:
|
||||
text = str(gap_x)
|
||||
|
||||
# 1. 生成纹理背景
|
||||
img = self._textured_background(rng)
|
||||
|
||||
# 2. 绘制缺口 (半透明灰色区域 + 阴影)
|
||||
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
|
||||
overlay_draw = ImageDraw.Draw(overlay)
|
||||
|
||||
# 阴影 (稍大一圈)
|
||||
overlay_draw.rectangle(
|
||||
[gap_x + 2, gap_y + 2, gap_x + gs + 2, gap_y + gs + 2],
|
||||
fill=(0, 0, 0, 60),
|
||||
)
|
||||
# 缺口本体
|
||||
overlay_draw.rectangle(
|
||||
[gap_x, gap_y, gap_x + gs, gap_y + gs],
|
||||
fill=(80, 80, 80, 160),
|
||||
outline=(60, 60, 60, 200),
|
||||
width=2,
|
||||
)
|
||||
|
||||
img = img.convert("RGBA")
|
||||
img = Image.alpha_composite(img, overlay)
|
||||
img = img.convert("RGB")
|
||||
|
||||
# 3. 轻微模糊
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=0.3))
|
||||
|
||||
return img, text
|
||||
|
||||
def _textured_background(self, rng: random.Random) -> Image.Image:
|
||||
"""生成带纹理的彩色背景。"""
|
||||
img = Image.new("RGB", (self.width, self.height))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# 渐变底色
|
||||
base_r = rng.randint(80, 200)
|
||||
base_g = rng.randint(80, 200)
|
||||
base_b = rng.randint(80, 200)
|
||||
for y in range(self.height):
|
||||
ratio = y / max(self.height - 1, 1)
|
||||
r = int(base_r + 40 * ratio)
|
||||
g = int(base_g - 20 * ratio)
|
||||
b = int(base_b + 20 * ratio)
|
||||
r, g, b = max(0, min(255, r)), max(0, min(255, g)), max(0, min(255, b))
|
||||
draw.line([(0, y), (self.width, y)], fill=(r, g, b))
|
||||
|
||||
# 纹理噪声
|
||||
for _ in range(self.width * self.height // 6):
|
||||
x = rng.randint(0, self.width - 1)
|
||||
y = rng.randint(0, self.height - 1)
|
||||
pixel = img.getpixel((x, y))
|
||||
noise = tuple(
|
||||
max(0, min(255, c + rng.randint(-30, 30)))
|
||||
for c in pixel
|
||||
)
|
||||
draw.point((x, y), fill=noise)
|
||||
|
||||
# 随机色块 (模拟图案)
|
||||
for _ in range(rng.randint(4, 8)):
|
||||
x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
|
||||
x2, y2 = x1 + rng.randint(15, 50), y1 + rng.randint(10, 30)
|
||||
color = tuple(rng.randint(50, 230) for _ in range(3))
|
||||
draw.rectangle([x1, y1, x2, y2], fill=color)
|
||||
|
||||
# 随机圆形
|
||||
for _ in range(rng.randint(2, 5)):
|
||||
cx = rng.randint(10, self.width - 10)
|
||||
cy = rng.randint(10, self.height - 10)
|
||||
cr = rng.randint(5, 20)
|
||||
color = tuple(rng.randint(50, 230) for _ in range(3))
|
||||
draw.ellipse([cx - cr, cy - cr, cx + cr, cy + cr], fill=color)
|
||||
|
||||
return img
|
||||
@@ -18,11 +18,14 @@ from config import (
|
||||
THREED_CHARS,
|
||||
NUM_CAPTCHA_TYPES,
|
||||
REGRESSION_RANGE,
|
||||
SOLVER_CONFIG,
|
||||
)
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
|
||||
|
||||
def export_model(
|
||||
@@ -52,7 +55,7 @@ def export_model(
|
||||
dummy = torch.randn(1, *input_shape)
|
||||
|
||||
# 分类器和识别器的 dynamic_axes 不同
|
||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
|
||||
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider", "gap_detector", "rotation_regressor"):
|
||||
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
else:
|
||||
# CTC 模型: output shape = (T, B, C)
|
||||
@@ -110,6 +113,14 @@ def _load_and_export(model_name: str):
|
||||
h, w = IMAGE_SIZE["3d_slider"]
|
||||
model = RegressionCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "gap_detector":
|
||||
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
|
||||
model = GapDetectorCNN(img_h=h, img_w=w)
|
||||
input_shape = (1, h, w)
|
||||
elif model_name == "rotation_regressor":
|
||||
h, w = SOLVER_CONFIG["rotate"]["input_size"]
|
||||
model = RotationRegressor(img_h=h, img_w=w)
|
||||
input_shape = (3, h, w)
|
||||
else:
|
||||
print(f"[错误] 未知模型: {model_name}")
|
||||
return
|
||||
@@ -119,11 +130,15 @@ def _load_and_export(model_name: str):
|
||||
|
||||
|
||||
def export_all():
|
||||
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
|
||||
"""依次导出全部模型 (含 solver 模型)。"""
|
||||
print("=" * 50)
|
||||
print("导出全部 ONNX 模型")
|
||||
print("=" * 50)
|
||||
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
|
||||
for name in [
|
||||
"classifier", "normal", "math", "threed_text",
|
||||
"threed_rotate", "threed_slider",
|
||||
"gap_detector", "rotation_regressor",
|
||||
]:
|
||||
_load_and_export(name)
|
||||
print("\n全部导出完成。")
|
||||
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
"""
|
||||
模型定义包
|
||||
|
||||
提供四种模型:
|
||||
提供六种模型:
|
||||
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
|
||||
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
|
||||
- ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
|
||||
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
|
||||
- GapDetectorCNN: 滑块缺口检测 CNN (~1MB)
|
||||
- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB)
|
||||
"""
|
||||
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
|
||||
__all__ = [
|
||||
"CaptchaClassifier",
|
||||
"LiteCRNN",
|
||||
"ThreeDCNN",
|
||||
"RegressionCNN",
|
||||
"GapDetectorCNN",
|
||||
"RotationRegressor",
|
||||
]
|
||||
|
||||
82
models/gap_detector.py
Normal file
82
models/gap_detector.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
滑块缺口检测 CNN (GapDetectorCNN)
|
||||
|
||||
用于检测滑块验证码中缺口的 x 坐标位置。
|
||||
输出 sigmoid 归一化到 [0,1],推理时按图片宽度缩放回像素坐标。
|
||||
|
||||
架构:
|
||||
Conv(1→32) + BN + ReLU + Pool
|
||||
Conv(32→64) + BN + ReLU + Pool
|
||||
Conv(64→128) + BN + ReLU + Pool
|
||||
Conv(128→128) + BN + ReLU + Pool
|
||||
AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid
|
||||
|
||||
约 250K 参数,~1MB。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GapDetectorCNN(nn.Module):
|
||||
"""
|
||||
滑块缺口检测 CNN,输出缺口 x 坐标的归一化百分比 [0,1]。
|
||||
|
||||
与 RegressionCNN 架构相同,但语义上专用于滑块缺口检测,
|
||||
默认输入尺寸 1x128x256 (灰度)。
|
||||
"""
|
||||
|
||||
def __init__(self, img_h: int = 128, img_w: int = 256):
|
||||
super().__init__()
|
||||
self.img_h = img_h
|
||||
self.img_w = img_w
|
||||
|
||||
self.features = nn.Sequential(
|
||||
# block 1: 1 → 32, H/2, W/2
|
||||
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 2: 32 → 64, H/4, W/4
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 3: 64 → 128, H/8, W/8
|
||||
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 4: 128 → 128, H/16, W/16
|
||||
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
)
|
||||
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.regressor = nn.Sequential(
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(64, 1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (batch, 1, H, W) 灰度图
|
||||
|
||||
Returns:
|
||||
output: (batch, 1) sigmoid 输出 [0, 1],表示缺口 x 坐标百分比
|
||||
"""
|
||||
feat = self.features(x)
|
||||
feat = self.pool(feat) # (B, 128, 1, 1)
|
||||
feat = feat.flatten(1) # (B, 128)
|
||||
out = self.regressor(feat) # (B, 1)
|
||||
return out
|
||||
82
models/rotation_regressor.py
Normal file
82
models/rotation_regressor.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
旋转角度回归模型 (RotationRegressor)
|
||||
|
||||
用于预测旋转验证码的正确旋转角度。
|
||||
使用 sin/cos 编码避免 0°/360° 边界问题。
|
||||
RGB 输入,输出 (sin θ, cos θ) ∈ [-1,1]。
|
||||
|
||||
架构:
|
||||
Conv(3→32) + BN + ReLU + Pool
|
||||
Conv(32→64) + BN + ReLU + Pool
|
||||
Conv(64→128) + BN + ReLU + Pool
|
||||
Conv(128→256) + BN + ReLU + Pool
|
||||
AdaptiveAvgPool2d(1) → FC(256→128) → ReLU → FC(128→2) → Tanh
|
||||
|
||||
约 400K 参数,~2MB。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RotationRegressor(nn.Module):
|
||||
"""
|
||||
旋转角度回归模型。
|
||||
|
||||
RGB 输入 3x128x128,输出 (sin θ, cos θ)。
|
||||
推理时用 atan2(sin, cos) 转换为角度。
|
||||
"""
|
||||
|
||||
def __init__(self, img_h: int = 128, img_w: int = 128):
|
||||
super().__init__()
|
||||
self.img_h = img_h
|
||||
self.img_w = img_w
|
||||
|
||||
self.features = nn.Sequential(
|
||||
# block 1: 3 → 32, H/2, W/2
|
||||
nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 2: 32 → 64, H/4, W/4
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 3: 64 → 128, H/8, W/8
|
||||
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
# block 4: 128 → 256, H/16, W/16
|
||||
nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
)
|
||||
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
|
||||
self.regressor = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(128, 2),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (batch, 3, H, W) RGB 图
|
||||
|
||||
Returns:
|
||||
output: (batch, 2) → (sin θ, cos θ) ∈ [-1, 1]
|
||||
"""
|
||||
feat = self.features(x)
|
||||
feat = self.pool(feat) # (B, 256, 1, 1)
|
||||
feat = feat.flatten(1) # (B, 256)
|
||||
out = self.regressor(feat) # (B, 2)
|
||||
return out
|
||||
@@ -20,6 +20,9 @@ server = [
|
||||
"uvicorn>=0.23.0",
|
||||
"python-multipart>=0.0.6",
|
||||
]
|
||||
cv = [
|
||||
"opencv-python>=4.8.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
captcha = "cli:main"
|
||||
|
||||
17
solvers/__init__.py
Normal file
17
solvers/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
验证码求解器包
|
||||
|
||||
提供两种交互式验证码求解器:
|
||||
- SlideSolver: 滑块验证码求解 (OpenCV 优先 + CNN 兜底)
|
||||
- RotateSolver: 旋转验证码求解 (ONNX sin/cos 回归)
|
||||
"""
|
||||
|
||||
from solvers.base import BaseSolver
|
||||
from solvers.slide_solver import SlideSolver
|
||||
from solvers.rotate_solver import RotateSolver
|
||||
|
||||
__all__ = [
|
||||
"BaseSolver",
|
||||
"SlideSolver",
|
||||
"RotateSolver",
|
||||
]
|
||||
21
solvers/base.py
Normal file
21
solvers/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
求解器基类
|
||||
"""
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class BaseSolver:
|
||||
"""验证码求解器基类。"""
|
||||
|
||||
def solve(self, image: Image.Image, **kwargs) -> dict:
|
||||
"""
|
||||
求解验证码。
|
||||
|
||||
Args:
|
||||
image: 输入图片
|
||||
|
||||
Returns:
|
||||
包含求解结果的字典
|
||||
"""
|
||||
raise NotImplementedError
|
||||
80
solvers/rotate_solver.py
Normal file
80
solvers/rotate_solver.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
旋转验证码求解器
|
||||
|
||||
ONNX 推理 → (sin, cos) → atan2 → 角度
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import ONNX_DIR, SOLVER_CONFIG
|
||||
from solvers.base import BaseSolver
|
||||
|
||||
|
||||
class RotateSolver(BaseSolver):
|
||||
"""旋转验证码求解器。"""
|
||||
|
||||
def __init__(self, onnx_path: str | Path | None = None):
|
||||
self.cfg = SOLVER_CONFIG["rotate"]
|
||||
self._onnx_session = None
|
||||
self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "rotation_regressor.onnx"
|
||||
|
||||
def _load_onnx(self):
|
||||
"""延迟加载 ONNX 模型。"""
|
||||
if self._onnx_session is not None:
|
||||
return
|
||||
if not self._onnx_path.exists():
|
||||
raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
|
||||
import onnxruntime as ort
|
||||
self._onnx_session = ort.InferenceSession(
|
||||
str(self._onnx_path), providers=["CPUExecutionProvider"]
|
||||
)
|
||||
|
||||
def solve(self, image: Image.Image | str | Path, **kwargs) -> dict:
|
||||
"""
|
||||
求解旋转验证码。
|
||||
|
||||
Args:
|
||||
image: 输入图片 (RGB)
|
||||
|
||||
Returns:
|
||||
{"angle": float, "confidence": float}
|
||||
"""
|
||||
if isinstance(image, (str, Path)):
|
||||
image = Image.open(str(image)).convert("RGB")
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
|
||||
self._load_onnx()
|
||||
|
||||
h, w = self.cfg["input_size"]
|
||||
|
||||
# 预处理: RGB resize + normalize
|
||||
img = image.resize((w, h))
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
# Normalize per channel: (x - 0.5) / 0.5
|
||||
arr = (arr - 0.5) / 0.5
|
||||
# HWC → CHW → NCHW
|
||||
arr = arr.transpose(2, 0, 1)[np.newaxis, :, :, :]
|
||||
|
||||
outputs = self._onnx_session.run(None, {"input": arr})
|
||||
sin_val = float(outputs[0][0][0])
|
||||
cos_val = float(outputs[0][0][1])
|
||||
|
||||
# atan2 → 角度
|
||||
angle_rad = math.atan2(sin_val, cos_val)
|
||||
angle_deg = math.degrees(angle_rad)
|
||||
if angle_deg < 0:
|
||||
angle_deg += 360.0
|
||||
|
||||
# 置信度: sin^2 + cos^2 接近 1 表示预测稳定
|
||||
magnitude = math.sqrt(sin_val ** 2 + cos_val ** 2)
|
||||
confidence = min(magnitude, 1.0)
|
||||
|
||||
return {
|
||||
"angle": round(angle_deg, 1),
|
||||
"confidence": round(confidence, 3),
|
||||
}
|
||||
179
solvers/slide_solver.py
Normal file
179
solvers/slide_solver.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
滑块验证码求解器
|
||||
|
||||
三种求解方法 (按优先级):
|
||||
1. 模板匹配: 背景图 + 模板图 → Canny → matchTemplate
|
||||
2. 边缘检测: 单图 Canny → findContours → 筛选方形轮廓
|
||||
3. CNN 兜底: ONNX 推理 → sigmoid → x 百分比 → 像素
|
||||
|
||||
OpenCV 延迟导入,未安装时退化到 CNN only。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import ONNX_DIR, SOLVER_CONFIG
|
||||
from solvers.base import BaseSolver
|
||||
|
||||
|
||||
class SlideSolver(BaseSolver):
|
||||
"""滑块验证码求解器。"""
|
||||
|
||||
def __init__(self, onnx_path: str | Path | None = None):
|
||||
self.cfg = SOLVER_CONFIG["slide"]
|
||||
self._onnx_session = None
|
||||
self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "gap_detector.onnx"
|
||||
|
||||
# 检测 OpenCV 可用性
|
||||
self._cv2_available = False
|
||||
try:
|
||||
import cv2
|
||||
self._cv2_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
def _load_onnx(self):
|
||||
"""延迟加载 ONNX 模型。"""
|
||||
if self._onnx_session is not None:
|
||||
return
|
||||
if not self._onnx_path.exists():
|
||||
raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}")
|
||||
import onnxruntime as ort
|
||||
self._onnx_session = ort.InferenceSession(
|
||||
str(self._onnx_path), providers=["CPUExecutionProvider"]
|
||||
)
|
||||
|
||||
def solve(
|
||||
self,
|
||||
bg_image: Image.Image | str | Path,
|
||||
template_image: Image.Image | str | Path | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
求解滑块验证码。
|
||||
|
||||
Args:
|
||||
bg_image: 背景图 (必需)
|
||||
template_image: 模板/拼图块图 (可选,有则优先模板匹配)
|
||||
|
||||
Returns:
|
||||
{"gap_x": int, "gap_x_percent": float, "confidence": float, "method": str}
|
||||
"""
|
||||
bg = self._load_image(bg_image)
|
||||
|
||||
# 方法 1: 模板匹配
|
||||
if template_image is not None and self._cv2_available:
|
||||
tpl = self._load_image(template_image)
|
||||
result = self._template_match(bg, tpl)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# 方法 2: 边缘检测
|
||||
if self._cv2_available:
|
||||
result = self._edge_detect(bg)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# 方法 3: CNN 兜底
|
||||
return self._cnn_predict(bg)
|
||||
|
||||
def _load_image(self, img: Image.Image | str | Path) -> Image.Image:
|
||||
if isinstance(img, (str, Path)):
|
||||
return Image.open(str(img)).convert("RGB")
|
||||
return img.convert("RGB")
|
||||
|
||||
def _template_match(self, bg: Image.Image, tpl: Image.Image) -> dict | None:
|
||||
"""模板匹配法。"""
|
||||
import cv2
|
||||
|
||||
bg_gray = np.array(bg.convert("L"))
|
||||
tpl_gray = np.array(tpl.convert("L"))
|
||||
|
||||
# Canny 边缘
|
||||
bg_edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||
tpl_edges = cv2.Canny(tpl_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||
|
||||
if tpl_edges.sum() == 0:
|
||||
return None
|
||||
|
||||
result = cv2.matchTemplate(bg_edges, tpl_edges, cv2.TM_CCOEFF_NORMED)
|
||||
_, max_val, _, max_loc = cv2.minMaxLoc(result)
|
||||
|
||||
if max_val < 0.3:
|
||||
return None
|
||||
|
||||
gap_x = max_loc[0] + tpl_gray.shape[1] // 2
|
||||
return {
|
||||
"gap_x": int(gap_x),
|
||||
"gap_x_percent": gap_x / bg_gray.shape[1],
|
||||
"confidence": float(max_val),
|
||||
"method": "template_match",
|
||||
}
|
||||
|
||||
def _edge_detect(self, bg: Image.Image) -> dict | None:
|
||||
"""边缘检测法:找方形轮廓。"""
|
||||
import cv2
|
||||
|
||||
bg_gray = np.array(bg.convert("L"))
|
||||
h, w = bg_gray.shape
|
||||
|
||||
edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"])
|
||||
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
best = None
|
||||
best_score = 0
|
||||
|
||||
for cnt in contours:
|
||||
area = cv2.contourArea(cnt)
|
||||
# 面积筛选: 缺口大小在合理范围
|
||||
if area < (h * w * 0.005) or area > (h * w * 0.15):
|
||||
continue
|
||||
|
||||
x, y, cw, ch = cv2.boundingRect(cnt)
|
||||
aspect = min(cw, ch) / max(cw, ch) if max(cw, ch) > 0 else 0
|
||||
# 近似方形
|
||||
if aspect < 0.5:
|
||||
continue
|
||||
|
||||
# 评分: 面积适中 + 近似方形
|
||||
score = aspect * (area / (h * w * 0.05))
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best = (x + cw // 2, cw, ch, score)
|
||||
|
||||
if best is None:
|
||||
return None
|
||||
|
||||
gap_x, _, _, score = best
|
||||
return {
|
||||
"gap_x": int(gap_x),
|
||||
"gap_x_percent": gap_x / w,
|
||||
"confidence": min(float(score), 1.0),
|
||||
"method": "edge_detect",
|
||||
}
|
||||
|
||||
def _cnn_predict(self, bg: Image.Image) -> dict:
|
||||
"""CNN 推理兜底。"""
|
||||
self._load_onnx()
|
||||
|
||||
h, w = self.cfg["cnn_input_size"]
|
||||
orig_w = bg.width
|
||||
|
||||
# 预处理: 灰度 + resize + normalize
|
||||
img = bg.convert("L").resize((w, h))
|
||||
arr = np.array(img, dtype=np.float32) / 255.0
|
||||
arr = (arr - 0.5) / 0.5
|
||||
arr = arr[np.newaxis, np.newaxis, :, :] # (1, 1, H, W)
|
||||
|
||||
outputs = self._onnx_session.run(None, {"input": arr})
|
||||
percent = float(outputs[0][0][0])
|
||||
|
||||
gap_x = int(percent * orig_w)
|
||||
return {
|
||||
"gap_x": gap_x,
|
||||
"gap_x_percent": percent,
|
||||
"confidence": 0.5, # CNN 无置信度
|
||||
"method": "cnn",
|
||||
}
|
||||
@@ -224,3 +224,55 @@ class RegressionDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([label], dtype=torch.float32)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 旋转求解器用数据集 (sin/cos 编码)
|
||||
# ============================================================
|
||||
class RotateSolverDataset(Dataset):
|
||||
"""
|
||||
旋转求解器数据集。
|
||||
|
||||
从目录中读取 {angle}_{xxx}.png 文件,
|
||||
将角度转换为 (sin θ, cos θ) 目标。
|
||||
RGB 输入,不转灰度。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirs: list[str | Path],
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dirs: 数据目录列表
|
||||
transform: 图片预处理/增强 (RGB)
|
||||
"""
|
||||
import math
|
||||
|
||||
self.transform = transform
|
||||
self.samples: list[tuple[str, float, float]] = [] # (路径, sin, cos)
|
||||
|
||||
for d in dirs:
|
||||
d = Path(d)
|
||||
if not d.exists():
|
||||
continue
|
||||
for f in sorted(d.glob("*.png")):
|
||||
raw_label = f.stem.rsplit("_", 1)[0]
|
||||
try:
|
||||
angle = float(raw_label)
|
||||
except ValueError:
|
||||
continue
|
||||
rad = math.radians(angle)
|
||||
self.samples.append((str(f), math.sin(rad), math.cos(rad)))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
import torch
|
||||
path, sin_val, cos_val = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
|
||||
|
||||
|
||||
245
training/train_rotate_solver.py
Normal file
245
training/train_rotate_solver.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
训练旋转验证码角度回归模型 (RotationRegressor)
|
||||
|
||||
自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。
|
||||
|
||||
用法: python -m training.train_rotate_solver
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import (
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_TRAIN_CONFIG,
|
||||
ROTATE_SOLVER_DATA_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
AUGMENT_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from training.dataset import RotateSolverDataset
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float:
|
||||
"""循环 MAE (度数)。"""
|
||||
diff = np.abs(pred_angles - gt_angles)
|
||||
diff = np.minimum(diff, 360.0 - diff)
|
||||
return float(np.mean(diff))
|
||||
|
||||
|
||||
def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
"""RGB 训练增强 (不转灰度)。"""
|
||||
aug = AUGMENT_CONFIG
|
||||
return transforms.Compose([
|
||||
transforms.Resize((img_h, img_w)),
|
||||
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
|
||||
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
|
||||
def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
"""RGB 验证 transform。"""
|
||||
return transforms.Compose([
|
||||
transforms.Resize((img_h, img_w)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
|
||||
def _export_onnx(model: nn.Module, img_h: int, img_w: int):
|
||||
"""导出 ONNX (RGB 3通道输入)。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / "rotation_regressor.onnx"
|
||||
dummy = torch.randn(1, 3, img_h, img_w)
|
||||
torch.onnx.export(
|
||||
model.cpu(),
|
||||
dummy,
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_CONFIG["opset_version"],
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
def main():
|
||||
cfg = SOLVER_TRAIN_CONFIG["rotate"]
|
||||
solver_cfg = SOLVER_CONFIG["rotate"]
|
||||
img_h, img_w = solver_cfg["input_size"]
|
||||
tolerance = 5.0 # ±5°
|
||||
|
||||
_set_seed()
|
||||
device = get_device()
|
||||
|
||||
print("=" * 60)
|
||||
print("训练旋转验证码角度回归模型 (RotationRegressor)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w} RGB")
|
||||
print(f" 编码: sin/cos")
|
||||
print(f" 容差: ±{tolerance}°")
|
||||
print("=" * 60)
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = ROTATE_SOLVER_DATA_DIR
|
||||
syn_path.mkdir(parents=True, exist_ok=True)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = RotateSolverDataGenerator()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
real_dir = syn_path / "real"
|
||||
real_dir.mkdir(parents=True, exist_ok=True)
|
||||
if list(real_dir.glob("*.png")):
|
||||
data_dirs.append(str(real_dir))
|
||||
print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))} 张")
|
||||
|
||||
train_transform = _build_train_transform(img_h, img_w)
|
||||
val_transform = _build_val_transform(img_h, img_w)
|
||||
|
||||
full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform)
|
||||
total = len(full_dataset)
|
||||
val_size = int(total * cfg["val_split"])
|
||||
train_size = total - val_size
|
||||
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
|
||||
|
||||
val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform)
|
||||
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_ds, batch_size=cfg["batch_size"], shuffle=True,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
|
||||
num_workers=0, pin_memory=True,
|
||||
)
|
||||
|
||||
print(f"[数据] 训练: {train_size} 验证: {val_size}")
|
||||
|
||||
# ---- 3. 模型 / 优化器 / 调度器 ----
|
||||
model = RotationRegressor(img_h=img_h, img_w=img_w).to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
|
||||
loss_fn = nn.MSELoss()
|
||||
|
||||
best_mae = float("inf")
|
||||
best_tol_acc = 0.0
|
||||
ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth"
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(1, cfg["epochs"] + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
|
||||
for images, targets in pbar:
|
||||
images = images.to(device)
|
||||
targets = targets.to(device) # (B, 2) → (sin, cos)
|
||||
|
||||
preds = model(images) # (B, 2)
|
||||
loss = loss_fn(preds, targets)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
pbar.set_postfix(loss=f"{loss.item():.4f}")
|
||||
|
||||
scheduler.step()
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
# ---- 5. 验证 ----
|
||||
model.eval()
|
||||
all_pred_angles = []
|
||||
all_gt_angles = []
|
||||
with torch.no_grad():
|
||||
for images, targets in val_loader:
|
||||
images = images.to(device)
|
||||
preds = model(images).cpu().numpy() # (B, 2)
|
||||
targets_np = targets.numpy() # (B, 2)
|
||||
|
||||
# sin/cos → angle
|
||||
for i in range(len(preds)):
|
||||
pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1]))
|
||||
if pred_angle < 0:
|
||||
pred_angle += 360.0
|
||||
gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1]))
|
||||
if gt_angle < 0:
|
||||
gt_angle += 360.0
|
||||
all_pred_angles.append(pred_angle)
|
||||
all_gt_angles.append(gt_angle)
|
||||
|
||||
pred_arr = np.array(all_pred_angles)
|
||||
gt_arr = np.array(all_gt_angles)
|
||||
|
||||
mae = _circular_mae_deg(pred_arr, gt_arr)
|
||||
diff = np.abs(pred_arr - gt_arr)
|
||||
diff = np.minimum(diff, 360.0 - diff)
|
||||
tol_acc = float(np.mean(diff <= tolerance))
|
||||
lr = scheduler.get_last_lr()[0]
|
||||
|
||||
print(
|
||||
f"Epoch {epoch:3d}/{cfg['epochs']} "
|
||||
f"loss={avg_loss:.4f} "
|
||||
f"MAE={mae:.2f}° "
|
||||
f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} "
|
||||
f"lr={lr:.6f}"
|
||||
)
|
||||
|
||||
# ---- 6. 保存最佳模型 ----
|
||||
if tol_acc >= best_tol_acc:
|
||||
best_tol_acc = tol_acc
|
||||
best_mae = mae
|
||||
torch.save({
|
||||
"model_state_dict": model.state_dict(),
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
|
||||
|
||||
# ---- 7. 导出 ONNX ----
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, img_h, img_w)
|
||||
|
||||
return best_tol_acc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
65
training/train_slide.py
Normal file
65
training/train_slide.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
训练滑块缺口检测 CNN (GapDetectorCNN)
|
||||
|
||||
复用 train_regression_utils 的通用回归训练流程。
|
||||
|
||||
用法: python -m training.train_slide
|
||||
"""
|
||||
|
||||
from config import (
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_TRAIN_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
SLIDE_DATA_DIR,
|
||||
CHECKPOINTS_DIR,
|
||||
ONNX_DIR,
|
||||
ONNX_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from generators.slide_gen import SlideDataGenerator
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
|
||||
# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE
|
||||
# 以便复用 train_regression_utils
|
||||
import config as _cfg
|
||||
|
||||
|
||||
def main():
|
||||
solver_cfg = SOLVER_CONFIG["slide"]
|
||||
train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"]
|
||||
img_h, img_w = solver_cfg["cnn_input_size"]
|
||||
|
||||
model = GapDetectorCNN(img_h=img_h, img_w=img_w)
|
||||
|
||||
print("=" * 60)
|
||||
print("训练滑块缺口检测 CNN (GapDetectorCNN)")
|
||||
print(f" 输入尺寸: {img_h}×{img_w}")
|
||||
print(f" 任务: 预测缺口 x 坐标百分比")
|
||||
print("=" * 60)
|
||||
|
||||
# 直接使用 train_regression_utils 中的逻辑
|
||||
# 但需要临时注入配置
|
||||
_cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg
|
||||
_cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w)
|
||||
_cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"]
|
||||
|
||||
from training.train_regression_utils import train_regression_model
|
||||
|
||||
# 确保数据目录存在
|
||||
SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
real_dir = SLIDE_DATA_DIR / "real"
|
||||
real_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
train_regression_model(
|
||||
model_name="gap_detector",
|
||||
model=model,
|
||||
synthetic_dir=str(SLIDE_DATA_DIR),
|
||||
real_dir=str(real_dir),
|
||||
generator_cls=SlideDataGenerator,
|
||||
config_key="slide_cnn",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
工具函数包
|
||||
"""
|
||||
75
utils/slide_utils.py
Normal file
75
utils/slide_utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
滑块轨迹生成工具
|
||||
|
||||
生成模拟人类操作的滑块拖拽轨迹。
|
||||
"""
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
|
||||
def generate_slide_track(
|
||||
distance: int,
|
||||
duration: float = 1.0,
|
||||
seed: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
生成滑块拖拽轨迹。
|
||||
|
||||
使用贝塞尔曲线 ease-out 加速减速,末尾带微小过冲回退。
|
||||
y 轴 ±1~3px 随机抖动,时间间隔不均匀。
|
||||
|
||||
Args:
|
||||
distance: 滑动距离 (像素)
|
||||
duration: 总时长 (秒)
|
||||
seed: 随机种子
|
||||
|
||||
Returns:
|
||||
[{"x": float, "y": float, "t": float}, ...]
|
||||
"""
|
||||
rng = random.Random(seed)
|
||||
|
||||
if distance <= 0:
|
||||
return [{"x": 0.0, "y": 0.0, "t": 0.0}]
|
||||
|
||||
track = []
|
||||
# 采样点数
|
||||
num_points = rng.randint(30, 60)
|
||||
total_ms = duration * 1000
|
||||
|
||||
# 生成不均匀时间点
|
||||
raw_times = sorted([rng.random() for _ in range(num_points - 2)])
|
||||
times = [0.0] + raw_times + [1.0]
|
||||
|
||||
# 过冲距离
|
||||
overshoot = rng.uniform(2, 6)
|
||||
overshoot_start = 0.85 # 85% 时到达目标 + 过冲
|
||||
|
||||
for t_norm in times:
|
||||
t_ms = round(t_norm * total_ms, 1)
|
||||
|
||||
if t_norm <= overshoot_start:
|
||||
# ease-out: 快速启动,缓慢减速
|
||||
progress = t_norm / overshoot_start
|
||||
eased = 1 - (1 - progress) ** 3 # cubic ease-out
|
||||
x = eased * (distance + overshoot)
|
||||
else:
|
||||
# 过冲回退段
|
||||
retract_progress = (t_norm - overshoot_start) / (1 - overshoot_start)
|
||||
eased_retract = retract_progress ** 2 # ease-in 回退
|
||||
x = (distance + overshoot) - overshoot * eased_retract
|
||||
|
||||
# y 轴随机抖动
|
||||
y_jitter = rng.uniform(-3, 3) if t_norm > 0.05 else 0.0
|
||||
|
||||
track.append({
|
||||
"x": round(x, 1),
|
||||
"y": round(y_jitter, 1),
|
||||
"t": t_ms,
|
||||
})
|
||||
|
||||
# 确保最后一个点精确到达目标
|
||||
track[-1]["x"] = float(distance)
|
||||
track[-1]["y"] = round(rng.uniform(-0.5, 0.5), 1)
|
||||
|
||||
return track
|
||||
Reference in New Issue
Block a user