From 9b5f29083ee9d642c0c5cbe455a92c2238b1dc2b Mon Sep 17 00:00:00 2001 From: Hua Date: Wed, 11 Mar 2026 18:07:06 +0800 Subject: [PATCH] Add slide and rotate interactive captcha solvers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New solver subsystem with independent models: - GapDetectorCNN (1x128x256 grayscale → sigmoid) for slide gap detection - RotationRegressor (3x128x128 RGB → sin/cos via tanh) for rotation angle prediction - SlideSolver with 3-tier strategy: template match → edge detect → CNN fallback - RotateSolver with ONNX sin/cos → atan2 inference - Generators, training scripts, CLI commands, and slide track utility Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 90 +++++++++++- cli.py | 108 ++++++++++++++ config.py | 43 ++++++ generators/__init__.py | 8 +- generators/rotate_solver_gen.py | 156 ++++++++++++++++++++ generators/slide_gen.py | 112 +++++++++++++++ inference/export_onnx.py | 21 ++- models/__init__.py | 8 +- models/gap_detector.py | 82 +++++++++++ models/rotation_regressor.py | 82 +++++++++++ pyproject.toml | 3 + solvers/__init__.py | 17 +++ solvers/base.py | 21 +++ solvers/rotate_solver.py | 80 +++++++++++ solvers/slide_solver.py | 179 +++++++++++++++++++++++ training/dataset.py | 52 +++++++ training/train_rotate_solver.py | 245 ++++++++++++++++++++++++++++++++ training/train_slide.py | 65 +++++++++ utils/__init__.py | 3 + utils/slide_utils.py | 75 ++++++++++ 20 files changed, 1440 insertions(+), 10 deletions(-) create mode 100644 generators/rotate_solver_gen.py create mode 100644 generators/slide_gen.py create mode 100644 models/gap_detector.py create mode 100644 models/rotation_regressor.py create mode 100644 solvers/__init__.py create mode 100644 solvers/base.py create mode 100644 solvers/rotate_solver.py create mode 100644 solvers/slide_solver.py create mode 100644 training/train_rotate_solver.py create mode 100644 training/train_slide.py create mode 100644 utils/__init__.py create mode 100644 
utils/slide_utils.py diff --git a/CLAUDE.md b/CLAUDE.md index bcf06bd..5e3fa49 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -33,7 +33,10 @@ captcha-breaker/ │ │ ├── 3d_text/ │ │ ├── 3d_rotate/ │ │ └── 3d_slider/ -│ └── classifier/ # 调度分类器训练数据 (混合各类型) +│ ├── classifier/ # 调度分类器训练数据 (混合各类型) +│ └── solver/ # Solver 训练数据 +│ ├── slide/ # 滑块缺口检测训练数据 +│ └── rotate/ # 旋转角度回归训练数据 ├── generators/ │ ├── __init__.py │ ├── base.py # 生成器基类 @@ -41,13 +44,17 @@ captcha-breaker/ │ ├── math_gen.py # 算式验证码生成器 (如 3+8=?) │ ├── threed_gen.py # 3D立体文字验证码生成器 │ ├── threed_rotate_gen.py # 3D旋转验证码生成器 -│ └── threed_slider_gen.py # 3D滑块验证码生成器 +│ ├── threed_slider_gen.py # 3D滑块验证码生成器 +│ ├── slide_gen.py # 滑块缺口训练数据生成器 +│ └── rotate_solver_gen.py # 旋转求解器训练数据生成器 ├── models/ │ ├── __init__.py │ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式) │ ├── classifier.py # 调度分类模型 │ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN) -│ └── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB) +│ ├── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB) +│ ├── gap_detector.py # 滑块缺口检测CNN (~1MB) +│ └── rotation_regressor.py # 旋转角度回归 sin/cos (~2MB) ├── training/ │ ├── __init__.py │ ├── train_classifier.py # 训练调度模型 @@ -56,6 +63,8 @@ captcha-breaker/ │ ├── train_3d_text.py # 训练3D文字识别 │ ├── train_3d_rotate.py # 训练3D旋转回归 │ ├── train_3d_slider.py # 训练3D滑块回归 +│ ├── train_slide.py # 训练滑块缺口检测 +│ ├── train_rotate_solver.py # 训练旋转角度回归 │ ├── train_utils.py # CTC 训练通用逻辑 │ ├── train_regression_utils.py # 回归训练通用逻辑 │ └── dataset.py # 通用 Dataset 类 @@ -64,20 +73,32 @@ captcha-breaker/ │ ├── pipeline.py # 核心推理流水线 (调度+识别) │ ├── export_onnx.py # PyTorch → ONNX 导出脚本 │ └── math_eval.py # 算式计算模块 +├── solvers/ # 交互式验证码求解器 +│ ├── __init__.py +│ ├── base.py # 求解器基类 +│ ├── slide_solver.py # 滑块求解 (OpenCV + CNN) +│ └── rotate_solver.py # 旋转求解 (ONNX sin/cos) +├── utils/ +│ ├── __init__.py +│ └── slide_utils.py # 滑块轨迹生成工具 ├── checkpoints/ # 训练产出的模型文件 │ ├── classifier.pth │ ├── normal.pth │ ├── math.pth │ ├── threed_text.pth │ ├── threed_rotate.pth -│ └── threed_slider.pth +│ ├── 
threed_slider.pth +│ ├── gap_detector.pth +│ └── rotation_regressor.pth ├── onnx_models/ # 导出的 ONNX 模型 │ ├── classifier.onnx │ ├── normal.onnx │ ├── math.onnx │ ├── threed_text.onnx │ ├── threed_rotate.onnx -│ └── threed_slider.onnx +│ ├── threed_slider.onnx +│ ├── gap_detector.onnx +│ └── rotation_regressor.onnx ├── server.py # FastAPI 推理服务 (可选) ├── cli.py # 命令行入口 └── tests/ @@ -462,3 +483,62 @@ uv run python cli.py serve --port 8080 6. 实现 cli.py 统一入口 7. 可选: server.py HTTP 服务 8. 编写 tests/ + +## 交互式 Solver 扩展 + +### 概述 + +在现有验证码识别架构之上,新增滑块 (slide) 和旋转 (rotate) 两种交互式验证码求解能力。与现有 3d_rotate/3d_slider 的区别: + +- **3d_slider** (合成拼图回归) → **slide solver**: 真实滑块验证码,OpenCV 优先 + CNN 兜底 +- **3d_rotate** (合成圆盘 sigmoid 回归) → **rotate solver**: 真实旋转验证码,sin/cos 编码 + 自然图 + +每个 solver 模型独立训练、独立导出 ONNX、独立替换,互不依赖。 + +### 滑块求解器 (SlideSolver) + +- 三种方法按优先级: 模板匹配 → 边缘检测 → CNN 兜底 +- 模型: `GapDetectorCNN` (1x128x256 灰度 → sigmoid [0,1]) +- OpenCV 延迟导入,未安装时退化到 CNN only +- 输出: `{"gap_x", "gap_x_percent", "confidence", "method"}` + +### 旋转求解器 (RotateSolver) + +- ONNX 推理 → (sin, cos) → atan2 → 角度 +- 模型: `RotationRegressor` (3x128x128 RGB → tanh (sin θ, cos θ)) +- 输出: `{"angle", "confidence"}` + +### Solver CLI 用法 + +```bash +# 生成训练数据 +uv run python cli.py generate-solver slide --num 30000 +uv run python cli.py generate-solver rotate --num 50000 + +# 训练 (各模型独立) +uv run python cli.py train-solver slide +uv run python cli.py train-solver rotate + +# 求解 +uv run python cli.py solve slide --bg bg.png [--tpl tpl.png] +uv run python cli.py solve rotate --image img.png + +# 导出 (已集成到 export --all) +uv run python cli.py export --model gap_detector +uv run python cli.py export --model rotation_regressor +``` + +### 滑块轨迹生成 + +`utils/slide_utils.py` 提供 `generate_slide_track(distance)`: +- 贝塞尔曲线 ease-out 加速减速 +- y 轴 ±1~3px 随机抖动 +- 时间间隔不均匀 +- 末尾微小过冲回退 + +### Solver 目标指标 + +| 模型 | 准确率目标 | 推理延迟 | 模型体积 | +|------|-----------|---------|---------| +| 滑块 CNN (±5px) | > 85% | < 30ms | ~1MB | +| 旋转回归 (±5°) | > 85% | < 
30ms | ~2MB | diff --git a/cli.py b/cli.py index faa2fe5..488ff0e 100644 --- a/cli.py +++ b/cli.py @@ -13,6 +13,11 @@ CaptchaBreaker 命令行入口 python cli.py predict image.png --type normal python cli.py predict-dir ./test_images/ python cli.py serve --port 8080 + python cli.py generate-solver slide --num 30000 + python cli.py train-solver slide + python cli.py train-solver rotate + python cli.py solve slide --bg bg.png [--tpl tpl.png] + python cli.py solve rotate --image img.png """ import argparse @@ -195,6 +200,90 @@ def cmd_serve(args): uvicorn.run(app, host=args.host, port=args.port) +def cmd_generate_solver(args): + """生成 solver 训练数据。""" + from config import SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR + from generators.slide_gen import SlideDataGenerator + from generators.rotate_solver_gen import RotateSolverDataGenerator + + solver_type = args.type + num = args.num + + gen_map = { + "slide": (SlideDataGenerator, SLIDE_DATA_DIR), + "rotate": (RotateSolverDataGenerator, ROTATE_SOLVER_DATA_DIR), + } + + if solver_type not in gen_map: + print(f"未知 solver 类型: {solver_type} 可选: {', '.join(gen_map.keys())}") + sys.exit(1) + + gen_cls, out_dir = gen_map[solver_type] + out_dir.mkdir(parents=True, exist_ok=True) + print(f"生成 solver/{solver_type} 数据: {num} 张 → {out_dir}") + gen = gen_cls() + gen.generate_dataset(num, str(out_dir)) + + +def cmd_train_solver(args): + """训练 solver 模型。""" + solver_type = args.type + + if solver_type == "slide": + from training.train_slide import main as train_fn + elif solver_type == "rotate": + from training.train_rotate_solver import main as train_fn + else: + print(f"未知 solver 类型: {solver_type} 可选: slide, rotate") + sys.exit(1) + + train_fn() + + +def cmd_solve(args): + """求解验证码。""" + solver_type = args.type + + if solver_type == "slide": + from solvers.slide_solver import SlideSolver + + bg_path = args.bg + tpl_path = getattr(args, "tpl", None) + if not Path(bg_path).exists(): + print(f"文件不存在: {bg_path}") + sys.exit(1) + + solver = SlideSolver() 
+ result = solver.solve(bg_path, template_image=tpl_path) + + print(f"背景图: {bg_path}") + if tpl_path: + print(f"模板图: {tpl_path}") + print(f"缺口 x: {result['gap_x']} px") + print(f"缺口 x%: {result['gap_x_percent']:.4f}") + print(f"置信度: {result['confidence']:.4f}") + print(f"方法: {result['method']}") + + elif solver_type == "rotate": + from solvers.rotate_solver import RotateSolver + + image_path = args.image + if not Path(image_path).exists(): + print(f"文件不存在: {image_path}") + sys.exit(1) + + solver = RotateSolver() + result = solver.solve(image_path) + + print(f"图片: {image_path}") + print(f"角度: {result['angle']}°") + print(f"置信度: {result['confidence']}") + + else: + print(f"未知 solver 类型: {solver_type} 可选: slide, rotate") + sys.exit(1) + + def main(): parser = argparse.ArgumentParser( prog="captcha-breaker", @@ -247,6 +336,22 @@ def main(): p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)") p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)") + # ---- generate-solver ---- + p_gen_solver = subparsers.add_parser("generate-solver", help="生成 solver 训练数据") + p_gen_solver.add_argument("type", help="solver 类型: slide, rotate") + p_gen_solver.add_argument("--num", type=int, required=True, help="生成数量") + + # ---- train-solver ---- + p_train_solver = subparsers.add_parser("train-solver", help="训练 solver 模型") + p_train_solver.add_argument("type", help="solver 类型: slide, rotate") + + # ---- solve ---- + p_solve = subparsers.add_parser("solve", help="求解交互式验证码") + p_solve.add_argument("type", help="solver 类型: slide, rotate") + p_solve.add_argument("--bg", help="背景图路径 (slide 必需)") + p_solve.add_argument("--tpl", default=None, help="模板图路径 (slide 可选)") + p_solve.add_argument("--image", help="图片路径 (rotate 必需)") + args = parser.parse_args() if args.command is None: @@ -260,6 +365,9 @@ def main(): "predict": cmd_predict, "predict-dir": cmd_predict_dir, "serve": cmd_serve, + "generate-solver": cmd_generate_solver, + "train-solver": 
cmd_train_solver, + "solve": cmd_solve, } cmd_map[args.command](args) diff --git a/config.py b/config.py index 4a18afa..f12d9cf 100644 --- a/config.py +++ b/config.py @@ -34,6 +34,11 @@ REAL_3D_TEXT_DIR = REAL_DIR / "3d_text" REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate" REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider" +# Solver 数据目录 +SOLVER_DATA_DIR = DATA_DIR / "solver" +SLIDE_DATA_DIR = SOLVER_DATA_DIR / "slide" +ROTATE_SOLVER_DATA_DIR = SOLVER_DATA_DIR / "rotate" + # ============================================================ # 模型输出目录 # ============================================================ @@ -47,6 +52,7 @@ for _dir in [ REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR, CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR, + SLIDE_DATA_DIR, ROTATE_SOLVER_DATA_DIR, ]: _dir.mkdir(parents=True, exist_ok=True) @@ -241,3 +247,40 @@ SERVER_CONFIG = { "host": "0.0.0.0", "port": 8080, } + +# ============================================================ +# Solver 配置 (交互式验证码求解) +# ============================================================ +SOLVER_CONFIG = { + "slide": { + "canny_low": 50, + "canny_high": 150, + "cnn_input_size": (128, 256), # H, W + }, + "rotate": { + "input_size": (128, 128), # H, W + "channels": 3, # RGB + }, +} + +SOLVER_TRAIN_CONFIG = { + "slide_cnn": { + "epochs": 50, + "batch_size": 64, + "lr": 1e-3, + "synthetic_samples": 30000, + "val_split": 0.1, + }, + "rotate": { + "epochs": 80, + "batch_size": 64, + "lr": 5e-4, + "synthetic_samples": 50000, + "val_split": 0.1, + }, +} + +SOLVER_REGRESSION_RANGE = { + "slide": (0, 1), # 归一化百分比 + "rotate": (0, 360), # 角度 +} diff --git a/generators/__init__.py b/generators/__init__.py index dab6a9c..4c31907 100644 --- a/generators/__init__.py +++ b/generators/__init__.py @@ -1,12 +1,14 @@ """ 数据生成器包 -提供五种验证码类型的数据生成器: +提供七种验证码类型的数据生成器: - NormalCaptchaGenerator: 普通字符验证码 - MathCaptchaGenerator: 算式验证码 - ThreeDCaptchaGenerator: 3D 立体文字验证码 - ThreeDRotateGenerator: 3D 旋转验证码 - 
ThreeDSliderGenerator: 3D 滑块验证码 +- SlideDataGenerator: 滑块验证码求解器训练数据 +- RotateSolverDataGenerator: 旋转验证码求解器训练数据 """ from generators.base import BaseCaptchaGenerator @@ -15,6 +17,8 @@ from generators.math_gen import MathCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator from generators.threed_rotate_gen import ThreeDRotateGenerator from generators.threed_slider_gen import ThreeDSliderGenerator +from generators.slide_gen import SlideDataGenerator +from generators.rotate_solver_gen import RotateSolverDataGenerator __all__ = [ "BaseCaptchaGenerator", @@ -23,4 +27,6 @@ __all__ = [ "ThreeDCaptchaGenerator", "ThreeDRotateGenerator", "ThreeDSliderGenerator", + "SlideDataGenerator", + "RotateSolverDataGenerator", ] diff --git a/generators/rotate_solver_gen.py b/generators/rotate_solver_gen.py new file mode 100644 index 0000000..e414b2f --- /dev/null +++ b/generators/rotate_solver_gen.py @@ -0,0 +1,156 @@ +""" +旋转验证码求解器数据生成器 + +生成旋转验证码训练数据:随机图案 (色块/渐变/几何图形),随机旋转 0-359°。 +裁剪为圆形 (黑色背景填充圆外区域)。 + +标签 = 旋转角度 (整数) +文件名格式: {angle}_{index:06d}.png +""" + +import math +import random + +from PIL import Image, ImageDraw, ImageFilter, ImageFont + +from config import SOLVER_CONFIG +from generators.base import BaseCaptchaGenerator + +_FONT_PATHS = [ + "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf", + "/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf", + "/usr/share/fonts/liberation/LiberationSans-Bold.ttf", + "/usr/share/fonts/liberation/LiberationSerif-Bold.ttf", + "/usr/share/fonts/gnu-free/FreeSansBold.otf", +] + + +class RotateSolverDataGenerator(BaseCaptchaGenerator): + """旋转验证码求解器数据生成器。""" + + def __init__(self, seed: int | None = None): + from config import RANDOM_SEED + super().__init__(seed=seed if seed is not None else RANDOM_SEED) + + self.cfg = SOLVER_CONFIG["rotate"] + self.height, self.width = self.cfg["input_size"] # (H, W) + + self._fonts: list[str] = [] + for p in _FONT_PATHS: + try: + ImageFont.truetype(p, 20) + self._fonts.append(p) + except OSError: + 
continue + + def generate(self, text: str | None = None) -> tuple[Image.Image, str]: + rng = self.rng + + # 随机旋转角度 0-359 + angle = rng.randint(0, 359) + if text is None: + text = str(angle) + + size = self.width # 正方形 + radius = size // 2 + + # 1. 生成正向图案 (未旋转) + content = self._random_pattern(rng, size) + + # 2. 旋转图案 + rotated = content.rotate(-angle, resample=Image.BICUBIC, expand=False) + + # 3. 裁剪为圆形 (黑色背景) + result = Image.new("RGB", (size, size), (0, 0, 0)) + mask = Image.new("L", (size, size), 0) + mask_draw = ImageDraw.Draw(mask) + mask_draw.ellipse([0, 0, size - 1, size - 1], fill=255) + + result.paste(rotated, (0, 0), mask) + + # 4. 轻微模糊 + result = result.filter(ImageFilter.GaussianBlur(radius=0.5)) + + return result, text + + def _random_pattern(self, rng: random.Random, size: int) -> Image.Image: + """生成随机图案 (带明显方向性,便于模型学习旋转)。""" + img = Image.new("RGB", (size, size)) + draw = ImageDraw.Draw(img) + + # 渐变背景 + base_r = rng.randint(100, 220) + base_g = rng.randint(100, 220) + base_b = rng.randint(100, 220) + for y in range(size): + ratio = y / max(size - 1, 1) + r = int(base_r * (1 - ratio) + rng.randint(40, 120) * ratio) + g = int(base_g * (1 - ratio) + rng.randint(40, 120) * ratio) + b = int(base_b * (1 - ratio) + rng.randint(40, 120) * ratio) + draw.line([(0, y), (size, y)], fill=(r, g, b)) + + cx, cy = size // 2, size // 2 + + # 添加不对称几何图形 (让模型能感知方向) + pattern_type = rng.choice(["triangle", "arrow", "text", "shapes"]) + + if pattern_type == "triangle": + # 顶部三角形标记 + color = tuple(rng.randint(180, 255) for _ in range(3)) + ts = size // 4 + draw.polygon( + [(cx, cy - ts), (cx - ts // 2, cy), (cx + ts // 2, cy)], + fill=color, + ) + # 底部小圆 + draw.ellipse( + [cx - 8, cy + ts // 2, cx + 8, cy + ts // 2 + 16], + fill=tuple(rng.randint(50, 150) for _ in range(3)), + ) + + elif pattern_type == "arrow": + # 向上的箭头 + color = tuple(rng.randint(180, 255) for _ in range(3)) + arrow_len = size // 3 + draw.line([(cx, cy - arrow_len), (cx, cy + arrow_len // 2)], 
fill=color, width=4) + draw.polygon( + [(cx, cy - arrow_len - 5), (cx - 10, cy - arrow_len + 10), (cx + 10, cy - arrow_len + 10)], + fill=color, + ) + + elif pattern_type == "text" and self._fonts: + # 文字 (有天然方向性) + font_path = rng.choice(self._fonts) + font_size = size // 3 + try: + font = ImageFont.truetype(font_path, font_size) + ch = rng.choice("ABCDEFGHJKLMNPRSTUVWXYZ23456789") + bbox = font.getbbox(ch) + tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1] + draw.text( + (cx - tw // 2 - bbox[0], cy - th // 2 - bbox[1]), + ch, + fill=tuple(rng.randint(0, 80) for _ in range(3)), + font=font, + ) + except OSError: + pass + + else: + # 混合不对称形状 + # 上方矩形 + w, h = rng.randint(15, 30), rng.randint(10, 20) + color = tuple(rng.randint(150, 255) for _ in range(3)) + draw.rectangle([cx - w, cy - size // 3, cx + w, cy - size // 3 + h], fill=color) + # 右下小圆 + r = rng.randint(5, 12) + color2 = tuple(rng.randint(50, 150) for _ in range(3)) + draw.ellipse([cx + size // 5, cy + size // 5, cx + size // 5 + r * 2, cy + size // 5 + r * 2], fill=color2) + + # 添加纹理噪声 + for _ in range(rng.randint(20, 60)): + nx, ny = rng.randint(0, size - 1), rng.randint(0, size - 1) + nc = tuple(rng.randint(80, 220) for _ in range(3)) + draw.point((nx, ny), fill=nc) + + return img diff --git a/generators/slide_gen.py b/generators/slide_gen.py new file mode 100644 index 0000000..8260a37 --- /dev/null +++ b/generators/slide_gen.py @@ -0,0 +1,112 @@ +""" +滑块验证码数据生成器 + +生成滑块验证码训练数据:随机纹理/色块背景 + 方形缺口 + 阴影效果。 +标签 = 缺口中心 x 坐标 (整数) +文件名格式: {gap_x}_{index:06d}.png +""" + +import random + +from PIL import Image, ImageDraw, ImageFilter + +from config import SOLVER_CONFIG +from generators.base import BaseCaptchaGenerator + + +class SlideDataGenerator(BaseCaptchaGenerator): + """滑块验证码数据生成器。""" + + def __init__(self, seed: int | None = None): + from config import RANDOM_SEED + super().__init__(seed=seed if seed is not None else RANDOM_SEED) + + self.cfg = SOLVER_CONFIG["slide"] + self.height, self.width = 
self.cfg["cnn_input_size"] # (H, W) + self.gap_size = 40 # 缺口大小 + + def generate(self, text: str | None = None) -> tuple[Image.Image, str]: + rng = self.rng + gs = self.gap_size + + # 缺口 x 范围: 留出边距 + margin = gs + 10 + gap_x = rng.randint(margin, self.width - margin) + gap_y = rng.randint(10, self.height - gs - 10) + + if text is None: + text = str(gap_x) + + # 1. 生成纹理背景 + img = self._textured_background(rng) + + # 2. 绘制缺口 (半透明灰色区域 + 阴影) + overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) + overlay_draw = ImageDraw.Draw(overlay) + + # 阴影 (稍大一圈) + overlay_draw.rectangle( + [gap_x + 2, gap_y + 2, gap_x + gs + 2, gap_y + gs + 2], + fill=(0, 0, 0, 60), + ) + # 缺口本体 + overlay_draw.rectangle( + [gap_x, gap_y, gap_x + gs, gap_y + gs], + fill=(80, 80, 80, 160), + outline=(60, 60, 60, 200), + width=2, + ) + + img = img.convert("RGBA") + img = Image.alpha_composite(img, overlay) + img = img.convert("RGB") + + # 3. 轻微模糊 + img = img.filter(ImageFilter.GaussianBlur(radius=0.3)) + + return img, text + + def _textured_background(self, rng: random.Random) -> Image.Image: + """生成带纹理的彩色背景。""" + img = Image.new("RGB", (self.width, self.height)) + draw = ImageDraw.Draw(img) + + # 渐变底色 + base_r = rng.randint(80, 200) + base_g = rng.randint(80, 200) + base_b = rng.randint(80, 200) + for y in range(self.height): + ratio = y / max(self.height - 1, 1) + r = int(base_r + 40 * ratio) + g = int(base_g - 20 * ratio) + b = int(base_b + 20 * ratio) + r, g, b = max(0, min(255, r)), max(0, min(255, g)), max(0, min(255, b)) + draw.line([(0, y), (self.width, y)], fill=(r, g, b)) + + # 纹理噪声 + for _ in range(self.width * self.height // 6): + x = rng.randint(0, self.width - 1) + y = rng.randint(0, self.height - 1) + pixel = img.getpixel((x, y)) + noise = tuple( + max(0, min(255, c + rng.randint(-30, 30))) + for c in pixel + ) + draw.point((x, y), fill=noise) + + # 随机色块 (模拟图案) + for _ in range(rng.randint(4, 8)): + x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20) + x2, y2 = 
x1 + rng.randint(15, 50), y1 + rng.randint(10, 30) + color = tuple(rng.randint(50, 230) for _ in range(3)) + draw.rectangle([x1, y1, x2, y2], fill=color) + + # 随机圆形 + for _ in range(rng.randint(2, 5)): + cx = rng.randint(10, self.width - 10) + cy = rng.randint(10, self.height - 10) + cr = rng.randint(5, 20) + color = tuple(rng.randint(50, 230) for _ in range(3)) + draw.ellipse([cx - cr, cy - cr, cx + cr, cy + cr], fill=color) + + return img diff --git a/inference/export_onnx.py b/inference/export_onnx.py index 7701444..be77cde 100644 --- a/inference/export_onnx.py +++ b/inference/export_onnx.py @@ -18,11 +18,14 @@ from config import ( THREED_CHARS, NUM_CAPTCHA_TYPES, REGRESSION_RANGE, + SOLVER_CONFIG, ) from models.classifier import CaptchaClassifier from models.lite_crnn import LiteCRNN from models.threed_cnn import ThreeDCNN from models.regression_cnn import RegressionCNN +from models.gap_detector import GapDetectorCNN +from models.rotation_regressor import RotationRegressor def export_model( @@ -52,7 +55,7 @@ def export_model( dummy = torch.randn(1, *input_shape) # 分类器和识别器的 dynamic_axes 不同 - if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"): + if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider", "gap_detector", "rotation_regressor"): dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} else: # CTC 模型: output shape = (T, B, C) @@ -110,6 +113,14 @@ def _load_and_export(model_name: str): h, w = IMAGE_SIZE["3d_slider"] model = RegressionCNN(img_h=h, img_w=w) input_shape = (1, h, w) + elif model_name == "gap_detector": + h, w = SOLVER_CONFIG["slide"]["cnn_input_size"] + model = GapDetectorCNN(img_h=h, img_w=w) + input_shape = (1, h, w) + elif model_name == "rotation_regressor": + h, w = SOLVER_CONFIG["rotate"]["input_size"] + model = RotationRegressor(img_h=h, img_w=w) + input_shape = (3, h, w) else: print(f"[错误] 未知模型: {model_name}") return @@ -119,11 +130,15 @@ def _load_and_export(model_name: 
str): def export_all(): - """依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。""" + """依次导出全部模型 (含 solver 模型)。""" print("=" * 50) print("导出全部 ONNX 模型") print("=" * 50) - for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]: + for name in [ + "classifier", "normal", "math", "threed_text", + "threed_rotate", "threed_slider", + "gap_detector", "rotation_regressor", + ]: _load_and_export(name) print("\n全部导出完成。") diff --git a/models/__init__.py b/models/__init__.py index 9a8efc8..5fa558b 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,21 +1,27 @@ """ 模型定义包 -提供四种模型: +提供六种模型: - CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB) - LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB) - ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB) - RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB) +- GapDetectorCNN: 滑块缺口检测 CNN (~1MB) +- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB) """ from models.classifier import CaptchaClassifier from models.lite_crnn import LiteCRNN from models.threed_cnn import ThreeDCNN from models.regression_cnn import RegressionCNN +from models.gap_detector import GapDetectorCNN +from models.rotation_regressor import RotationRegressor __all__ = [ "CaptchaClassifier", "LiteCRNN", "ThreeDCNN", "RegressionCNN", + "GapDetectorCNN", + "RotationRegressor", ] diff --git a/models/gap_detector.py b/models/gap_detector.py new file mode 100644 index 0000000..5db8dc4 --- /dev/null +++ b/models/gap_detector.py @@ -0,0 +1,82 @@ +""" +滑块缺口检测 CNN (GapDetectorCNN) + +用于检测滑块验证码中缺口的 x 坐标位置。 +输出 sigmoid 归一化到 [0,1],推理时按图片宽度缩放回像素坐标。 + +架构: + Conv(1→32) + BN + ReLU + Pool + Conv(32→64) + BN + ReLU + Pool + Conv(64→128) + BN + ReLU + Pool + Conv(128→128) + BN + ReLU + Pool + AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid + +约 250K 参数,~1MB。 +""" + +import torch +import torch.nn as nn + + +class GapDetectorCNN(nn.Module): + """ + 滑块缺口检测 CNN,输出缺口 x 坐标的归一化百分比 [0,1]。 + + 与 RegressionCNN 架构相同,但语义上专用于滑块缺口检测, 
+ 默认输入尺寸 1x128x256 (灰度)。 + """ + + def __init__(self, img_h: int = 128, img_w: int = 256): + super().__init__() + self.img_h = img_h + self.img_w = img_w + + self.features = nn.Sequential( + # block 1: 1 → 32, H/2, W/2 + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 2: 32 → 64, H/4, W/4 + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 3: 64 → 128, H/8, W/8 + nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 4: 128 → 128, H/16, W/16 + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + + self.regressor = nn.Sequential( + nn.Linear(128, 64), + nn.ReLU(inplace=True), + nn.Dropout(0.2), + nn.Linear(64, 1), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch, 1, H, W) 灰度图 + + Returns: + output: (batch, 1) sigmoid 输出 [0, 1],表示缺口 x 坐标百分比 + """ + feat = self.features(x) + feat = self.pool(feat) # (B, 128, 1, 1) + feat = feat.flatten(1) # (B, 128) + out = self.regressor(feat) # (B, 1) + return out diff --git a/models/rotation_regressor.py b/models/rotation_regressor.py new file mode 100644 index 0000000..a38b02f --- /dev/null +++ b/models/rotation_regressor.py @@ -0,0 +1,82 @@ +""" +旋转角度回归模型 (RotationRegressor) + +用于预测旋转验证码的正确旋转角度。 +使用 sin/cos 编码避免 0°/360° 边界问题。 +RGB 输入,输出 (sin θ, cos θ) ∈ [-1,1]。 + +架构: + Conv(3→32) + BN + ReLU + Pool + Conv(32→64) + BN + ReLU + Pool + Conv(64→128) + BN + ReLU + Pool + Conv(128→256) + BN + ReLU + Pool + AdaptiveAvgPool2d(1) → FC(256→128) → ReLU → FC(128→2) → Tanh + +约 400K 参数,~2MB。 +""" + +import torch +import torch.nn as nn + + +class RotationRegressor(nn.Module): + """ + 旋转角度回归模型。 + + RGB 输入 
3x128x128,输出 (sin θ, cos θ)。 + 推理时用 atan2(sin, cos) 转换为角度。 + """ + + def __init__(self, img_h: int = 128, img_w: int = 128): + super().__init__() + self.img_h = img_h + self.img_w = img_w + + self.features = nn.Sequential( + # block 1: 3 → 32, H/2, W/2 + nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 2: 32 → 64, H/4, W/4 + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 3: 64 → 128, H/8, W/8 + nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 4: 128 → 256, H/16, W/16 + nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + + self.regressor = nn.Sequential( + nn.Linear(256, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 2), + nn.Tanh(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch, 3, H, W) RGB 图 + + Returns: + output: (batch, 2) → (sin θ, cos θ) ∈ [-1, 1] + """ + feat = self.features(x) + feat = self.pool(feat) # (B, 256, 1, 1) + feat = feat.flatten(1) # (B, 256) + out = self.regressor(feat) # (B, 2) + return out diff --git a/pyproject.toml b/pyproject.toml index c671713..f770f7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ server = [ "uvicorn>=0.23.0", "python-multipart>=0.0.6", ] +cv = [ + "opencv-python>=4.8.0", +] [project.scripts] captcha = "cli:main" diff --git a/solvers/__init__.py b/solvers/__init__.py new file mode 100644 index 0000000..c485fb5 --- /dev/null +++ b/solvers/__init__.py @@ -0,0 +1,17 @@ +""" +验证码求解器包 + +提供两种交互式验证码求解器: +- SlideSolver: 滑块验证码求解 (OpenCV 优先 + CNN 兜底) +- RotateSolver: 旋转验证码求解 (ONNX sin/cos 回归) +""" + +from solvers.base import BaseSolver +from solvers.slide_solver import SlideSolver 
+from solvers.rotate_solver import RotateSolver + +__all__ = [ + "BaseSolver", + "SlideSolver", + "RotateSolver", +] diff --git a/solvers/base.py b/solvers/base.py new file mode 100644 index 0000000..7783b62 --- /dev/null +++ b/solvers/base.py @@ -0,0 +1,21 @@ +""" +求解器基类 +""" + +from PIL import Image + + +class BaseSolver: + """验证码求解器基类。""" + + def solve(self, image: Image.Image, **kwargs) -> dict: + """ + 求解验证码。 + + Args: + image: 输入图片 + + Returns: + 包含求解结果的字典 + """ + raise NotImplementedError diff --git a/solvers/rotate_solver.py b/solvers/rotate_solver.py new file mode 100644 index 0000000..9f1496f --- /dev/null +++ b/solvers/rotate_solver.py @@ -0,0 +1,80 @@ +""" +旋转验证码求解器 + +ONNX 推理 → (sin, cos) → atan2 → 角度 +""" + +import math +from pathlib import Path + +import numpy as np +from PIL import Image + +from config import ONNX_DIR, SOLVER_CONFIG +from solvers.base import BaseSolver + + +class RotateSolver(BaseSolver): + """旋转验证码求解器。""" + + def __init__(self, onnx_path: str | Path | None = None): + self.cfg = SOLVER_CONFIG["rotate"] + self._onnx_session = None + self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "rotation_regressor.onnx" + + def _load_onnx(self): + """延迟加载 ONNX 模型。""" + if self._onnx_session is not None: + return + if not self._onnx_path.exists(): + raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}") + import onnxruntime as ort + self._onnx_session = ort.InferenceSession( + str(self._onnx_path), providers=["CPUExecutionProvider"] + ) + + def solve(self, image: Image.Image | str | Path, **kwargs) -> dict: + """ + 求解旋转验证码。 + + Args: + image: 输入图片 (RGB) + + Returns: + {"angle": float, "confidence": float} + """ + if isinstance(image, (str, Path)): + image = Image.open(str(image)).convert("RGB") + else: + image = image.convert("RGB") + + self._load_onnx() + + h, w = self.cfg["input_size"] + + # 预处理: RGB resize + normalize + img = image.resize((w, h)) + arr = np.array(img, dtype=np.float32) / 255.0 + # Normalize per channel: (x - 
0.5) / 0.5 + arr = (arr - 0.5) / 0.5 + # HWC → CHW → NCHW + arr = arr.transpose(2, 0, 1)[np.newaxis, :, :, :] + + outputs = self._onnx_session.run(None, {"input": arr}) + sin_val = float(outputs[0][0][0]) + cos_val = float(outputs[0][0][1]) + + # atan2 → 角度 + angle_rad = math.atan2(sin_val, cos_val) + angle_deg = math.degrees(angle_rad) + if angle_deg < 0: + angle_deg += 360.0 + + # 置信度: sin^2 + cos^2 接近 1 表示预测稳定 + magnitude = math.sqrt(sin_val ** 2 + cos_val ** 2) + confidence = min(magnitude, 1.0) + + return { + "angle": round(angle_deg, 1), + "confidence": round(confidence, 3), + } diff --git a/solvers/slide_solver.py b/solvers/slide_solver.py new file mode 100644 index 0000000..0045159 --- /dev/null +++ b/solvers/slide_solver.py @@ -0,0 +1,179 @@ +""" +滑块验证码求解器 + +三种求解方法 (按优先级): +1. 模板匹配: 背景图 + 模板图 → Canny → matchTemplate +2. 边缘检测: 单图 Canny → findContours → 筛选方形轮廓 +3. CNN 兜底: ONNX 推理 → sigmoid → x 百分比 → 像素 + +OpenCV 延迟导入,未安装时退化到 CNN only。 +""" + +from pathlib import Path + +import numpy as np +from PIL import Image + +from config import ONNX_DIR, SOLVER_CONFIG +from solvers.base import BaseSolver + + +class SlideSolver(BaseSolver): + """滑块验证码求解器。""" + + def __init__(self, onnx_path: str | Path | None = None): + self.cfg = SOLVER_CONFIG["slide"] + self._onnx_session = None + self._onnx_path = Path(onnx_path) if onnx_path else ONNX_DIR / "gap_detector.onnx" + + # 检测 OpenCV 可用性 + self._cv2_available = False + try: + import cv2 + self._cv2_available = True + except ImportError: + pass + + def _load_onnx(self): + """延迟加载 ONNX 模型。""" + if self._onnx_session is not None: + return + if not self._onnx_path.exists(): + raise FileNotFoundError(f"ONNX 模型不存在: {self._onnx_path}") + import onnxruntime as ort + self._onnx_session = ort.InferenceSession( + str(self._onnx_path), providers=["CPUExecutionProvider"] + ) + + def solve( + self, + bg_image: Image.Image | str | Path, + template_image: Image.Image | str | Path | None = None, + **kwargs, + ) -> dict: + """ + 求解滑块验证码。 + + 
Args: + bg_image: 背景图 (必需) + template_image: 模板/拼图块图 (可选,有则优先模板匹配) + + Returns: + {"gap_x": int, "gap_x_percent": float, "confidence": float, "method": str} + """ + bg = self._load_image(bg_image) + + # 方法 1: 模板匹配 + if template_image is not None and self._cv2_available: + tpl = self._load_image(template_image) + result = self._template_match(bg, tpl) + if result is not None: + return result + + # 方法 2: 边缘检测 + if self._cv2_available: + result = self._edge_detect(bg) + if result is not None: + return result + + # 方法 3: CNN 兜底 + return self._cnn_predict(bg) + + def _load_image(self, img: Image.Image | str | Path) -> Image.Image: + if isinstance(img, (str, Path)): + return Image.open(str(img)).convert("RGB") + return img.convert("RGB") + + def _template_match(self, bg: Image.Image, tpl: Image.Image) -> dict | None: + """模板匹配法。""" + import cv2 + + bg_gray = np.array(bg.convert("L")) + tpl_gray = np.array(tpl.convert("L")) + + # Canny 边缘 + bg_edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"]) + tpl_edges = cv2.Canny(tpl_gray, self.cfg["canny_low"], self.cfg["canny_high"]) + + if tpl_edges.sum() == 0: + return None + + result = cv2.matchTemplate(bg_edges, tpl_edges, cv2.TM_CCOEFF_NORMED) + _, max_val, _, max_loc = cv2.minMaxLoc(result) + + if max_val < 0.3: + return None + + gap_x = max_loc[0] + tpl_gray.shape[1] // 2 + return { + "gap_x": int(gap_x), + "gap_x_percent": gap_x / bg_gray.shape[1], + "confidence": float(max_val), + "method": "template_match", + } + + def _edge_detect(self, bg: Image.Image) -> dict | None: + """边缘检测法:找方形轮廓。""" + import cv2 + + bg_gray = np.array(bg.convert("L")) + h, w = bg_gray.shape + + edges = cv2.Canny(bg_gray, self.cfg["canny_low"], self.cfg["canny_high"]) + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + best = None + best_score = 0 + + for cnt in contours: + area = cv2.contourArea(cnt) + # 面积筛选: 缺口大小在合理范围 + if area < (h * w * 0.005) or area > (h * w * 0.15): + continue + + x, 
y, cw, ch = cv2.boundingRect(cnt) + aspect = min(cw, ch) / max(cw, ch) if max(cw, ch) > 0 else 0 + # 近似方形 + if aspect < 0.5: + continue + + # 评分: 面积适中 + 近似方形 + score = aspect * (area / (h * w * 0.05)) + if score > best_score: + best_score = score + best = (x + cw // 2, cw, ch, score) + + if best is None: + return None + + gap_x, _, _, score = best + return { + "gap_x": int(gap_x), + "gap_x_percent": gap_x / w, + "confidence": min(float(score), 1.0), + "method": "edge_detect", + } + + def _cnn_predict(self, bg: Image.Image) -> dict: + """CNN 推理兜底。""" + self._load_onnx() + + h, w = self.cfg["cnn_input_size"] + orig_w = bg.width + + # 预处理: 灰度 + resize + normalize + img = bg.convert("L").resize((w, h)) + arr = np.array(img, dtype=np.float32) / 255.0 + arr = (arr - 0.5) / 0.5 + arr = arr[np.newaxis, np.newaxis, :, :] # (1, 1, H, W) + + outputs = self._onnx_session.run(None, {"input": arr}) + percent = float(outputs[0][0][0]) + + gap_x = int(percent * orig_w) + return { + "gap_x": gap_x, + "gap_x_percent": percent, + "confidence": 0.5, # CNN 无置信度 + "method": "cnn", + } diff --git a/training/dataset.py b/training/dataset.py index becf24e..4e571f2 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -224,3 +224,55 @@ class RegressionDataset(Dataset): img = self.transform(img) return img, torch.tensor([label], dtype=torch.float32) + +# ============================================================ +# 旋转求解器用数据集 (sin/cos 编码) +# ============================================================ +class RotateSolverDataset(Dataset): + """ + 旋转求解器数据集。 + + 从目录中读取 {angle}_{xxx}.png 文件, + 将角度转换为 (sin θ, cos θ) 目标。 + RGB 输入,不转灰度。 + """ + + def __init__( + self, + dirs: list[str | Path], + transform: transforms.Compose | None = None, + ): + """ + Args: + dirs: 数据目录列表 + transform: 图片预处理/增强 (RGB) + """ + import math + + self.transform = transform + self.samples: list[tuple[str, float, float]] = [] # (路径, sin, cos) + + for d in dirs: + d = Path(d) + if not d.exists(): + continue + for f 
in sorted(d.glob("*.png")): + raw_label = f.stem.split("_", 1)[0] + try: + angle = float(raw_label) + except ValueError: + continue + rad = math.radians(angle) + self.samples.append((str(f), math.sin(rad), math.cos(rad))) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int): + import torch + path, sin_val, cos_val = self.samples[idx] + img = Image.open(path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, torch.tensor([sin_val, cos_val], dtype=torch.float32) + diff --git a/training/train_rotate_solver.py b/training/train_rotate_solver.py new file mode 100644 index 0000000..c77d787 --- /dev/null +++ b/training/train_rotate_solver.py @@ -0,0 +1,245 @@ +""" +训练旋转验证码角度回归模型 (RotationRegressor) + +自定义训练循环 (sin/cos 编码),不复用 train_regression_utils。 + +用法: python -m training.train_rotate_solver +""" + +import math +import random +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, random_split +from torchvision import transforms +from tqdm import tqdm + +from config import ( + SOLVER_CONFIG, + SOLVER_TRAIN_CONFIG, + ROTATE_SOLVER_DATA_DIR, + CHECKPOINTS_DIR, + ONNX_DIR, + ONNX_CONFIG, + AUGMENT_CONFIG, + RANDOM_SEED, + get_device, +) + +from generators.rotate_solver_gen import RotateSolverDataGenerator +from models.rotation_regressor import RotationRegressor +from training.dataset import RotateSolverDataset + + +def _set_seed(seed: int = RANDOM_SEED): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def _circular_mae_deg(pred_angles: np.ndarray, gt_angles: np.ndarray) -> float: + """循环 MAE (度数)。""" + diff = np.abs(pred_angles - gt_angles) + diff = np.minimum(diff, 360.0 - diff) + return float(np.mean(diff)) + + +def _build_train_transform(img_h: int, img_w: int) -> transforms.Compose: + """RGB 训练增强 (不转灰度)。""" + aug = AUGMENT_CONFIG + return
transforms.Compose([ + transforms.Resize((img_h, img_w)), + transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]), + transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + + +def _build_val_transform(img_h: int, img_w: int) -> transforms.Compose: + """RGB 验证 transform。""" + return transforms.Compose([ + transforms.Resize((img_h, img_w)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + + +def _export_onnx(model: nn.Module, img_h: int, img_w: int): + """导出 ONNX (RGB 3通道输入)。""" + model.eval() + onnx_path = ONNX_DIR / "rotation_regressor.onnx" + dummy = torch.randn(1, 3, img_h, img_w) + torch.onnx.export( + model.cpu(), + dummy, + str(onnx_path), + opset_version=ONNX_CONFIG["opset_version"], + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}} + if ONNX_CONFIG["dynamic_batch"] + else None, + ) + print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)") + + +def main(): + cfg = SOLVER_TRAIN_CONFIG["rotate"] + solver_cfg = SOLVER_CONFIG["rotate"] + img_h, img_w = solver_cfg["input_size"] + tolerance = 5.0 # ±5° + + _set_seed() + device = get_device() + + print("=" * 60) + print("训练旋转验证码角度回归模型 (RotationRegressor)") + print(f" 输入尺寸: {img_h}×{img_w} RGB") + print(f" 编码: sin/cos") + print(f" 容差: ±{tolerance}°") + print("=" * 60) + + # ---- 1. 检查 / 生成合成数据 ---- + syn_path = ROTATE_SOLVER_DATA_DIR + syn_path.mkdir(parents=True, exist_ok=True) + existing = list(syn_path.glob("*.png")) + if len(existing) < cfg["synthetic_samples"]: + print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...") + gen = RotateSolverDataGenerator() + gen.generate_dataset(cfg["synthetic_samples"], str(syn_path)) + else: + print(f"[数据] 合成数据已就绪: {len(existing)} 张") + + # ---- 2. 
构建数据集 ---- + data_dirs = [str(syn_path)] + real_dir = syn_path / "real" + real_dir.mkdir(parents=True, exist_ok=True) + if list(real_dir.glob("*.png")): + data_dirs.append(str(real_dir)) + print(f"[数据] 混合真实数据: {len(list(real_dir.glob('*.png')))} 张") + + train_transform = _build_train_transform(img_h, img_w) + val_transform = _build_val_transform(img_h, img_w) + + full_dataset = RotateSolverDataset(dirs=data_dirs, transform=train_transform) + total = len(full_dataset) + val_size = int(total * cfg["val_split"]) + train_size = total - val_size + train_ds, val_ds = random_split(full_dataset, [train_size, val_size]) + + val_ds_clean = RotateSolverDataset(dirs=data_dirs, transform=val_transform) + val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices] + + train_loader = DataLoader( + train_ds, batch_size=cfg["batch_size"], shuffle=True, + num_workers=0, pin_memory=True, + ) + val_loader = DataLoader( + val_ds_clean, batch_size=cfg["batch_size"], shuffle=False, + num_workers=0, pin_memory=True, + ) + + print(f"[数据] 训练: {train_size} 验证: {val_size}") + + # ---- 3. 模型 / 优化器 / 调度器 ---- + model = RotationRegressor(img_h=img_h, img_w=img_w).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"]) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"]) + loss_fn = nn.MSELoss() + + best_mae = float("inf") + best_tol_acc = 0.0 + ckpt_path = CHECKPOINTS_DIR / "rotation_regressor.pth" + + # ---- 4. 
训练循环 ---- + for epoch in range(1, cfg["epochs"] + 1): + model.train() + total_loss = 0.0 + num_batches = 0 + + pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False) + for images, targets in pbar: + images = images.to(device) + targets = targets.to(device) # (B, 2) → (sin, cos) + + preds = model(images) # (B, 2) + loss = loss_fn(preds, targets) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) + optimizer.step() + + total_loss += loss.item() + num_batches += 1 + pbar.set_postfix(loss=f"{loss.item():.4f}") + + scheduler.step() + avg_loss = total_loss / max(num_batches, 1) + + # ---- 5. 验证 ---- + model.eval() + all_pred_angles = [] + all_gt_angles = [] + with torch.no_grad(): + for images, targets in val_loader: + images = images.to(device) + preds = model(images).cpu().numpy() # (B, 2) + targets_np = targets.numpy() # (B, 2) + + # sin/cos → angle + for i in range(len(preds)): + pred_angle = math.degrees(math.atan2(preds[i][0], preds[i][1])) + if pred_angle < 0: + pred_angle += 360.0 + gt_angle = math.degrees(math.atan2(targets_np[i][0], targets_np[i][1])) + if gt_angle < 0: + gt_angle += 360.0 + all_pred_angles.append(pred_angle) + all_gt_angles.append(gt_angle) + + pred_arr = np.array(all_pred_angles) + gt_arr = np.array(all_gt_angles) + + mae = _circular_mae_deg(pred_arr, gt_arr) + diff = np.abs(pred_arr - gt_arr) + diff = np.minimum(diff, 360.0 - diff) + tol_acc = float(np.mean(diff <= tolerance)) + lr = scheduler.get_last_lr()[0] + + print( + f"Epoch {epoch:3d}/{cfg['epochs']} " + f"loss={avg_loss:.4f} " + f"MAE={mae:.2f}° " + f"tol_acc(±{tolerance:.0f}°)={tol_acc:.4f} " + f"lr={lr:.6f}" + ) + + # ---- 6. 
保存最佳模型 ---- + if tol_acc >= best_tol_acc: + best_tol_acc = tol_acc + best_mae = mae + torch.save({ + "model_state_dict": model.state_dict(), + "best_mae": best_mae, + "best_tol_acc": best_tol_acc, + "epoch": epoch, + }, ckpt_path) + print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}") + + # ---- 7. 导出 ONNX ---- + print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}°") + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) + model.load_state_dict(ckpt["model_state_dict"]) + _export_onnx(model, img_h, img_w) + + return best_tol_acc + + +if __name__ == "__main__": + main() diff --git a/training/train_slide.py b/training/train_slide.py new file mode 100644 index 0000000..1875c86 --- /dev/null +++ b/training/train_slide.py @@ -0,0 +1,65 @@ +""" +训练滑块缺口检测 CNN (GapDetectorCNN) + +复用 train_regression_utils 的通用回归训练流程。 + +用法: python -m training.train_slide +""" + +from config import ( + SOLVER_CONFIG, + SOLVER_TRAIN_CONFIG, + SOLVER_REGRESSION_RANGE, + SLIDE_DATA_DIR, + CHECKPOINTS_DIR, + ONNX_DIR, + ONNX_CONFIG, + RANDOM_SEED, + get_device, +) +from generators.slide_gen import SlideDataGenerator +from models.gap_detector import GapDetectorCNN + +# 注入 solver 配置到 TRAIN_CONFIG / IMAGE_SIZE / REGRESSION_RANGE +# 以便复用 train_regression_utils +import config as _cfg + + +def main(): + solver_cfg = SOLVER_CONFIG["slide"] + train_cfg = SOLVER_TRAIN_CONFIG["slide_cnn"] + img_h, img_w = solver_cfg["cnn_input_size"] + + model = GapDetectorCNN(img_h=img_h, img_w=img_w) + + print("=" * 60) + print("训练滑块缺口检测 CNN (GapDetectorCNN)") + print(f" 输入尺寸: {img_h}×{img_w}") + print(f" 任务: 预测缺口 x 坐标百分比") + print("=" * 60) + + # 直接使用 train_regression_utils 中的逻辑 + # 但需要临时注入配置 + _cfg.TRAIN_CONFIG["slide_cnn"] = train_cfg + _cfg.IMAGE_SIZE["slide_cnn"] = (img_h, img_w) + _cfg.REGRESSION_RANGE["slide_cnn"] = SOLVER_REGRESSION_RANGE["slide"] + + from training.train_regression_utils import train_regression_model + + # 确保数据目录存在 + 
SLIDE_DATA_DIR.mkdir(parents=True, exist_ok=True) + real_dir = SLIDE_DATA_DIR / "real" + real_dir.mkdir(parents=True, exist_ok=True) + + train_regression_model( + model_name="gap_detector", + model=model, + synthetic_dir=str(SLIDE_DATA_DIR), + real_dir=str(real_dir), + generator_cls=SlideDataGenerator, + config_key="slide_cnn", + ) + + +if __name__ == "__main__": + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..7507b9d --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,3 @@ +""" +工具函数包 +""" diff --git a/utils/slide_utils.py b/utils/slide_utils.py new file mode 100644 index 0000000..1e161d6 --- /dev/null +++ b/utils/slide_utils.py @@ -0,0 +1,75 @@ +""" +滑块轨迹生成工具 + +生成模拟人类操作的滑块拖拽轨迹。 +""" + +import math +import random + + +def generate_slide_track( + distance: int, + duration: float = 1.0, + seed: int | None = None, +) -> list[dict]: + """ + 生成滑块拖拽轨迹。 + + 使用贝塞尔曲线 ease-out 加速减速,末尾带微小过冲回退。 + y 轴 ±1~3px 随机抖动,时间间隔不均匀。 + + Args: + distance: 滑动距离 (像素) + duration: 总时长 (秒) + seed: 随机种子 + + Returns: + [{"x": float, "y": float, "t": float}, ...] 
+ """ + rng = random.Random(seed) + + if distance <= 0: + return [{"x": 0.0, "y": 0.0, "t": 0.0}] + + track = [] + # 采样点数 + num_points = rng.randint(30, 60) + total_ms = duration * 1000 + + # 生成不均匀时间点 + raw_times = sorted([rng.random() for _ in range(num_points - 2)]) + times = [0.0] + raw_times + [1.0] + + # 过冲距离 + overshoot = rng.uniform(2, 6) + overshoot_start = 0.85 # 85% 时到达目标 + 过冲 + + for t_norm in times: + t_ms = round(t_norm * total_ms, 1) + + if t_norm <= overshoot_start: + # ease-out: 快速启动,缓慢减速 + progress = t_norm / overshoot_start + eased = 1 - (1 - progress) ** 3 # cubic ease-out + x = eased * (distance + overshoot) + else: + # 过冲回退段 + retract_progress = (t_norm - overshoot_start) / (1 - overshoot_start) + eased_retract = retract_progress ** 2 # ease-in 回退 + x = (distance + overshoot) - overshoot * eased_retract + + # y 轴随机抖动 + y_jitter = rng.uniform(-3, 3) if t_norm > 0.05 else 0.0 + + track.append({ + "x": round(x, 1), + "y": round(y_jitter, 1), + "t": t_ms, + }) + + # 确保最后一个点精确到达目标 + track[-1]["x"] = float(distance) + track[-1]["y"] = round(rng.uniform(-0.5, 0.5), 1) + + return track