Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider

Split the single "3d" captcha type into three independent expert models:
- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated
to 50000 (10000 per class). New generators, model, dataset, training
utilities, and full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Authored by Hua, 2026-03-11 13:55:53 +08:00
parent 760b80ee5e
commit f5be7671bc
20 changed files with 1109 additions and 142 deletions


@@ -1,25 +1,33 @@
# Repository Guidelines
## Project Structure & Module Organization
Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas (5 types: normal, math, 3d_text, 3d_rotate, 3d_slider), `models/` contains the classifier, CTC expert models, and regression models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.
## Build, Test, and Development Commands
Use `uv` for environment and dependency management.
- `uv sync` installs the base runtime dependencies from `pyproject.toml`.
- `uv sync --extra server` installs HTTP service dependencies.
- `uv run captcha generate --type normal --num 1000` generates synthetic training data. Types: `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, `classifier`.
- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d_text -> 3d_rotate -> 3d_slider -> classifier`.
- `uv run captcha export --all` exports all trained models to ONNX.
- `uv run captcha export --model 3d_text` exports a single model; `3d_text` is automatically mapped to `threed_text`.
- `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory.
- `uv run captcha serve --port 8080` starts the optional HTTP API when `server.py` is implemented.
## Coding Style & Naming Conventions
Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Checkpoint/ONNX file names use `threed_text`, `threed_rotate`, `threed_slider` (underscored, no hyphens). Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding for OCR models. Regression models (3d_rotate, 3d_slider) output a sigmoid value in [0,1] that is scaled by `REGRESSION_RANGE`. `normal` uses the locally configured charset and currently includes confusable characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.
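The guidelines above require greedy CTC decoding for the OCR models. A minimal sketch of what that means, assuming class index 0 is the CTC blank and `chars` maps class index `i+1` to a character (a common convention, not necessarily the project's exact mapping):

```python
import numpy as np

def ctc_greedy_decode(logits: np.ndarray, chars: str, blank: int = 0) -> str:
    """Greedy CTC decode: argmax per timestep, collapse repeats, drop blanks.

    logits: (T, num_classes) array; index 0 is assumed to be the CTC blank,
    and class i+1 maps to chars[i].
    """
    best = logits.argmax(axis=1)  # best class id per timestep
    out = []
    prev = blank
    for idx in best:
        # Emit a character only when it is not blank and differs from the
        # previous timestep (CTC repeat-collapse rule).
        if idx != blank and idx != prev:
            out.append(chars[idx - 1])
        prev = idx
    return "".join(out)
```

This is why no beam search is needed here: captcha outputs are short and the per-step distributions are sharp.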
## Training & Data Rules
- All training scripts must set the global random seed (`random`, `numpy`, `torch`) via `config.RANDOM_SEED` before training begins.
- All DataLoaders use `num_workers=0` for cross-platform consistency.
- Generator parameters (rotation, noise, shadow, etc.) must come from `config.GENERATE_CONFIG`, not hardcoded values.
- `CRNNDataset` emits a `warnings.warn` when a label contains characters outside the configured charset, rather than silently dropping them.
- `RegressionDataset` parses numeric labels from filenames and normalizes to [0,1] via `label_range`.
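As a concrete illustration of the last rule, parsing and normalizing a regression label from a filename could look like this (the helper name is hypothetical; only the `{label}_{index:06d}.png` convention and the [0,1] normalization via `label_range` come from the rules above):

```python
from pathlib import Path

def parse_regression_label(filename: str, label_range: tuple[float, float]) -> float:
    """Read the numeric label before the first '_' and normalize to [0,1]."""
    lo, hi = label_range
    raw = float(Path(filename).name.split("_", 1)[0])  # label precedes first "_"
    return (raw - lo) / (hi - lo)
```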
## Data & Testing Guidelines
Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. For regression types, the label is the numeric value (angle or offset). Sample targets are defined in `config.py`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.
## Commit & Pull Request Guidelines
Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.

CLAUDE.md

@@ -24,29 +24,40 @@ captcha-breaker/
│   ├── synthetic/              # synthetic training data (auto-generated, not committed)
│   │   ├── normal/             # plain character type
│   │   ├── math/               # arithmetic type
│   │   ├── 3d_text/            # 3D perspective text
│   │   ├── 3d_rotate/          # 3D rotation type
│   │   └── 3d_slider/          # 3D slider type
│   ├── real/                   # real captcha samples (manually labeled)
│   │   ├── normal/
│   │   ├── math/
│   │   ├── 3d_text/
│   │   ├── 3d_rotate/
│   │   └── 3d_slider/
│   └── classifier/             # routing-classifier training data (mix of all types)
├── generators/
│   ├── __init__.py
│   ├── base.py                 # generator base class
│   ├── normal_gen.py           # plain character captcha generator
│   ├── math_gen.py             # arithmetic captcha generator (e.g. 3+8=?)
│   ├── threed_gen.py           # 3D text captcha generator
│   ├── threed_rotate_gen.py    # 3D rotation captcha generator
│   └── threed_slider_gen.py    # 3D slider captcha generator
├── models/
│   ├── __init__.py
│   ├── lite_crnn.py            # lightweight CRNN (plain characters and arithmetic)
│   ├── classifier.py           # routing classifier model
│   ├── threed_cnn.py           # dedicated 3D text model (deeper CNN)
│   └── regression_cnn.py       # regression CNN (3D rotate + slider, ~1MB)
├── training/
│   ├── __init__.py
│   ├── train_classifier.py     # train the routing model
│   ├── train_normal.py         # train plain character recognition
│   ├── train_math.py           # train arithmetic recognition
│   ├── train_3d_text.py        # train 3D text recognition
│   ├── train_3d_rotate.py      # train 3D rotation regression
│   ├── train_3d_slider.py      # train 3D slider regression
│   ├── train_utils.py          # shared CTC training logic
│   ├── train_regression_utils.py # shared regression training logic
│   └── dataset.py              # generic Dataset classes
├── inference/
│   ├── __init__.py
@@ -57,12 +68,16 @@ captcha-breaker/
│   ├── classifier.pth
│   ├── normal.pth
│   ├── math.pth
│   ├── threed_text.pth
│   ├── threed_rotate.pth
│   └── threed_slider.pth
├── onnx_models/                # exported ONNX models
│   ├── classifier.onnx
│   ├── normal.onnx
│   ├── math.onnx
│   ├── threed_text.onnx
│   ├── threed_rotate.onnx
│   └── threed_slider.onnx
├── server.py                   # FastAPI inference service (optional)
├── cli.py                      # CLI entrypoint
└── tests/
@@ -78,13 +93,13 @@ captcha-breaker/
```
input image → preprocess → routing classifier → route to expert → post-process → output
              ┌─────────┬─────────┼────────────┬────────────┐
              ▼         ▼         ▼            ▼            ▼
           normal     math     3d_text     3d_rotate    3d_slider
           (CRNN)    (CRNN)     (CNN)      (RegCNN)     (RegCNN)
              │         │         │            │            │
              ▼         ▼         ▼            ▼            ▼
           "A3B8"  "3+8=?"→11   "X9K2"       "135"        "87"
```
### Routing Classifier (classifier.py)
@@ -102,7 +117,7 @@ class CaptchaClassifier(nn.Module):
Lightweight classifier; a few conv layers suffice to tell captcha types apart.
Different captcha types look very different (operators, 3D effects, etc.), so classification is easy.
"""
def __init__(self, num_types=5):
# 4 conv layers + GAP + FC
# Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64)
# AdaptiveAvgPool2d(1) -> Linear(64, num_types)
@@ -142,14 +157,32 @@ def eval_captcha_math(expr: str) -> str:
pass
```
### 3D Text Recognition Expert (threed_cnn.py)
- Task: recognize text captchas with 3D perspective/shadow effects
- Architecture: deeper CNN + CRNN, or a ResNet-lite backbone
- Input: grayscale 1x60x160
- Needs stronger feature extraction to handle perspective distortion and shadows
- Model size target: < 5MB
### 3D Rotation Expert (regression_cnn.py, 3d_rotate mode)
- Task: predict the correct rotation angle of a rotation captcha
- Architecture: lightweight regression CNN (4 conv layers + GAP + FC + Sigmoid)
- Input: grayscale 1x80x80
- Output: sigmoid value in [0,1], scaled back to an angle via the (0, 360) range
- Label range: 0-359°
- Model size target: ~1MB
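The commit message describes a circular loss for the rotation head (the training config also lists SmoothL1; how the two combine is not shown in this diff). A minimal numpy sketch of the circular idea, assuming predictions and targets are angles normalized to [0,1]:

```python
import numpy as np

def circular_l1(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute error on the circle: the distance between 0.99 and 0.01
    turns is 0.02, not 0.98. A torch training loss would use the same
    expression on tensors so it stays differentiable."""
    diff = np.abs(pred - target)
    return float(np.minimum(diff, 1.0 - diff).mean())
```

Without the wrap-around, a model predicting 359° for a 1° target would be penalized as if it were maximally wrong.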
### 3D Slider Expert (regression_cnn.py, 3d_slider mode)
- Task: predict the x offset of the slider puzzle gap
- Architecture: same regression CNN, with a different input size
- Input: grayscale 1x80x240
- Output: sigmoid value in [0,1], scaled back to a pixel offset via the (10, 200) range
- Label range: 10-200px
- Model size target: ~1MB
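The two regression experts above share one architecture. A sketch of that shape (layer widths are illustrative, not the project's exact definition); note that global average pooling is what lets the same network accept both the 80x80 rotate input and the 80x240 slider input:

```python
import torch
import torch.nn as nn

class RegressionCNN(nn.Module):
    """Sketch of the shared regression expert: 4 conv blocks + GAP + FC + Sigmoid."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # GAP -> (N, 64, 1, 1), input-size agnostic
        )
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(64, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns (N,) values in [0,1]; the caller scales them via REGRESSION_RANGE.
        return self.head(self.features(x)).squeeze(1)
```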
## Data Generator Spec
### Base Class (base.py)
@@ -185,13 +218,26 @@ class BaseCaptchaGenerator:
- Label format: `3+8` (store the expression itself, not the result)
- Visual style: matches the target arithmetic captchas
### 3D Text Generator (threed_gen.py)
- Uses Pillow affine transforms to simulate 3D perspective
- Adds shadow effects
- Characters have depth and tilt
- Character rotation angles are configured centrally via `GENERATE_CONFIG["3d_text"]["rotation_range"]` in `config.py`
- Label: the plain character content
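The shadow effect called out above can be sketched with Pillow by drawing a darkened copy of the glyph at `shadow_offset` before the glyph itself (the default font and colors here are illustrative, not the project's):

```python
from PIL import Image, ImageDraw, ImageFont

def draw_char_with_shadow(canvas, ch, xy, shadow_offset=(3, 3)):
    """Draw ch with a drop shadow: shadow first at xy + offset, then the glyph."""
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()
    draw.text((xy[0] + shadow_offset[0], xy[1] + shadow_offset[1]),
              ch, fill=(150, 150, 150), font=font)  # shadow layer
    draw.text(xy, ch, fill=(30, 30, 30), font=font)  # glyph on top

canvas = Image.new("RGB", (160, 60), (255, 255, 255))  # 3d_text image_size
for i, ch in enumerate("X9K2"):
    draw_char_with_shadow(canvas, ch, (20 + 30 * i, 20))
```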
### 3D Rotation Generator (threed_rotate_gen.py)
- Draws characters plus a direction marker on a disc, rotated randomly by 0-359°
- Label = rotation angle (integer string)
- Filename format: `{angle}_{index:06d}.png`
### 3D Slider Generator (threed_slider_gen.py)
- Textured background + puzzle gap + puzzle piece on the left
- Label = gap x-coordinate offset (integer string)
- Filename format: `{offset}_{index:06d}.png`
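Putting the slider spec together, a minimal sample generator might look like this (flat rectangles stand in for the textured background and puzzle shape; only the sizes, ranges, and filename convention come from the spec and `GENERATE_CONFIG` values shown in this commit):

```python
import random
from PIL import Image, ImageDraw

def make_slider_sample(index: int, rng: random.Random):
    """Sketch: gap at a random x offset, piece at the left margin,
    filename label = gap offset."""
    w, h = 240, 80            # image_size
    piece_w, piece_h = 40, 40  # puzzle_size
    offset = rng.randint(50, 200)  # gap_x_range
    img = Image.new("L", (w, h), 220)
    draw = ImageDraw.Draw(img)
    top = (h - piece_h) // 2
    draw.rectangle([offset, top, offset + piece_w, top + piece_h], fill=120)  # gap
    draw.rectangle([5, top, 5 + piece_w, top + piece_h], outline=0)  # piece, left margin 5
    filename = f"{offset}_{index:06d}.png"  # label encodes the offset
    return img, filename

img, name = make_slider_sample(0, random.Random(42))
```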
## Training Spec
### Common Training Configuration
@@ -204,7 +250,7 @@ TRAIN_CONFIG = {
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
'synthetic_samples': 50000,  # 10000 per class × 5 classes
},
'normal': {
'epochs': 50,
@@ -222,7 +268,7 @@ TRAIN_CONFIG = {
'synthetic_samples': 60000,
'loss': 'CTCLoss',
},
'3d_text': {
'epochs': 80,
'batch_size': 64,
'lr': 5e-4,
@@ -230,18 +276,36 @@ TRAIN_CONFIG = {
'synthetic_samples': 80000,
'loss': 'CTCLoss',
},
'3d_rotate': {
'epochs': 60,
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
'synthetic_samples': 60000,
'loss': 'SmoothL1',
},
'3d_slider': {
'epochs': 60,
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
'synthetic_samples': 60000,
'loss': 'SmoothL1',
},
}
```
### Training Script Requirements
Each training script must:
1. Set the global random seed (random/numpy/torch) via `config.RANDOM_SEED` before training starts, for reproducibility
2. Check whether synthetic data exists and call the generator automatically if not
3. Support mixing in real data (if data/real/{type}/ has files)
4. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
5. Log training metrics: epoch, loss, and either overall + character-level accuracy (CTC) or MAE + tolerance accuracy (regression)
6. Save the best model to checkpoints/
7. Export ONNX to onnx_models/ automatically when training finishes
8. Use `num_workers=0` for every DataLoader to avoid multiprocessing compatibility issues
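The seeding requirement above can be sketched as a single helper (in the real scripts the value comes from `config.RANDOM_SEED`; the helper name is illustrative):

```python
import random

import numpy as np
import torch

def set_global_seed(seed: int) -> None:
    """Seed every RNG the training scripts touch: random, numpy, and torch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_global_seed(42)
```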
### Data Augmentation Strategy
@@ -281,14 +345,14 @@ class CaptchaPipeline:
pass
def classify(self, image: Image.Image) -> str:
"""Classify the image and return a type name: 'normal' / 'math' / '3d_text' / '3d_rotate' / '3d_slider'"""
pass
def solve(self, image) -> str:
"""
Full recognition flow:
1. Classify the captcha type
2. Route to the matching expert model (CTC or regression)
3. Post-process (arithmetic captchas compute the result)
4. Return the final answer string
@@ -311,7 +375,7 @@ def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'):
pass
def export_all():
"""Export all six models in order: classifier, normal, math, threed_text, threed_rotate, threed_slider"""
pass
```
@@ -325,22 +389,28 @@ uv sync --extra server # with HTTP service dependencies
# generate training data
uv run python cli.py generate --type normal --num 60000
uv run python cli.py generate --type math --num 60000
uv run python cli.py generate --type 3d_text --num 80000
uv run python cli.py generate --type 3d_rotate --num 60000
uv run python cli.py generate --type 3d_slider --num 60000
uv run python cli.py generate --type classifier --num 50000
# train models
uv run python cli.py train --model classifier
uv run python cli.py train --model normal
uv run python cli.py train --model math
uv run python cli.py train --model 3d_text
uv run python cli.py train --model 3d_rotate
uv run python cli.py train --model 3d_slider
uv run python cli.py train --all  # train everything in dependency order
# export ONNX
uv run python cli.py export --all
uv run python cli.py export --model 3d_text  # "3d_text" is automatically mapped to "threed_text"
# inference
uv run python cli.py predict image.png  # auto classify + recognize
uv run python cli.py predict image.png --type normal  # skip classification and recognize directly
uv run python cli.py predict image.png --type 3d_rotate  # force the rotation type
uv run python cli.py predict-dir ./test_images/  # batch recognition
# start the HTTP service (install the server extra first)
@@ -360,11 +430,12 @@ uv run python cli.py serve --port 8080
1. **All models train in float32, and ONNX export does no quantization**, to guarantee accuracy first
2. **CTC decoding always uses greedy decoding**; no beam search is needed, greedy is enough for captchas
3. **Charsets are defined centrally in config.py**: normal currently keeps confusable characters; 3d_text keeps using the de-confused charset
4. **Arithmetic recognition is two-step**: first OCR the string, then compute with rules; never let the model output the numeric result directly
5. **Random seeds**: both data generation and training set a global seed (random/numpy/torch) via `config.RANDOM_SEED` for reproducibility
6. **Real data filename format**: `{label}_{anything}.png`; the label part is the annotation
7. **Model save format**: CTC checkpoints contain model_state_dict, chars, best_acc, epoch; regression checkpoints contain model_state_dict, label_range, best_mae, best_tol_acc, epoch
8. **No GPU-only features**, so CPU can also train and infer (just slower)
9. **Type extension**: a new captcha type only needs (1) a generator, (2) an expert model, (3) one more classifier class plus retraining
10. **Doc sync**: changes to project structure, config, or architecture must be mirrored in CLAUDE.md to keep docs and code consistent
11. **Dataset charset filtering**: when `CRNNDataset` loads labels, characters outside the charset trigger a warning, making annotation/charset mismatches easy to spot
@@ -373,18 +444,20 @@ uv run python cli.py serve --port 8080
| Model | Accuracy target | Inference latency | Model size |
|------|-----------|---------|---------|
| Routing classifier (5 classes) | > 99% | < 5ms | < 500KB |
| Plain characters | > 95% | < 30ms | < 2MB |
| Arithmetic | > 93% | < 30ms | < 2MB |
| 3D text | > 85% | < 50ms | < 5MB |
| 3D rotation (±5°) | > 85% | < 30ms | ~1MB |
| 3D slider (±3px) | > 90% | < 30ms | ~1MB |
| Full pipeline | - | < 80ms | < 12MB total |
## Development Order
1. Implement config.py and generators/ first
2. Implement all model definitions in models/
3. Implement the generic dataset classes in training/dataset.py
4. Train in order: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier
5. Implement inference/pipeline.py and export_onnx.py
6. Implement the unified cli.py entrypoint
7. Optional: server.py HTTP service

cli.py

@@ -3,6 +3,9 @@ CaptchaBreaker CLI entrypoint
Usage:
python cli.py generate --type normal --num 60000
python cli.py generate --type 3d_text --num 80000
python cli.py generate --type 3d_rotate --num 60000
python cli.py generate --type 3d_slider --num 60000
python cli.py train --model normal
python cli.py train --all
python cli.py export --all
@@ -20,15 +23,21 @@ from pathlib import Path
def cmd_generate(args):
"""Generate training data."""
from config import (
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
)
from generators import (
NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator,
ThreeDRotateGenerator, ThreeDSliderGenerator,
)
gen_map = {
"normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
"math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
"3d_text": (ThreeDCaptchaGenerator, SYNTHETIC_3D_TEXT_DIR),
"3d_rotate": (ThreeDRotateGenerator, SYNTHETIC_3D_ROTATE_DIR),
"3d_slider": (ThreeDSliderGenerator, SYNTHETIC_3D_SLIDER_DIR),
}
captcha_type = args.type
@@ -50,25 +59,31 @@ def cmd_generate(args):
gen = gen_cls()
gen.generate_dataset(num, str(out_dir))
else:
valid = ", ".join(list(gen_map.keys()) + ["classifier"])
print(f"Unknown type: {captcha_type}; valid: {valid}")
sys.exit(1)
def cmd_train(args):
"""Train one or all models."""
if args.all:
print("Training all models in order: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier\n")
from training.train_normal import main as train_normal
from training.train_math import main as train_math
from training.train_3d_text import main as train_3d_text
from training.train_3d_rotate import main as train_3d_rotate
from training.train_3d_slider import main as train_3d_slider
from training.train_classifier import main as train_classifier
train_normal()
print("\n")
train_math()
print("\n")
train_3d_text()
print("\n")
train_3d_rotate()
print("\n")
train_3d_slider()
print("\n")
train_classifier()
return
@@ -78,12 +93,16 @@ def cmd_train(args):
from training.train_normal import main as train_fn
elif model == "math":
from training.train_math import main as train_fn
elif model == "3d_text":
from training.train_3d_text import main as train_fn
elif model == "3d_rotate":
from training.train_3d_rotate import main as train_fn
elif model == "3d_slider":
from training.train_3d_slider import main as train_fn
elif model == "classifier":
from training.train_classifier import main as train_fn
else:
print(f"Unknown model: {model}; valid: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier")
sys.exit(1)
train_fn()
@@ -96,7 +115,14 @@ def cmd_export(args):
if args.all:
export_all()
elif args.model:
# alias mapping
alias = {
"3d_text": "threed_text",
"3d_rotate": "threed_rotate",
"3d_slider": "threed_slider",
}
name = alias.get(args.model, args.model)
_load_and_export(name)
else:
print("Specify --all or --model <name>")
sys.exit(1)
@@ -137,19 +163,19 @@ def cmd_predict_dir(args):
sys.exit(1)
print(f"Batch recognition: {len(images)} images\n")
print(f"{'filename':<30} {'type':<10} {'result':<15} {'time(ms)':>8}")
print("-" * 67)
total_ms = 0.0
for img_path in images:
result = pipeline.solve(str(img_path), captcha_type=args.type)
total_ms += result["time_ms"]
print(
f"{img_path.name:<30} {result['type']:<10} "
f"{result['result']:<15} {result['time_ms']:>8.1f}"
)
print("-" * 67)
print(f"Total: {len(images)} images, average {total_ms / len(images):.1f} ms, overall {total_ms:.1f} ms")
@@ -178,28 +204,43 @@ def main():
# ---- generate ----
p_gen = subparsers.add_parser("generate", help="generate training data")
p_gen.add_argument(
"--type", required=True,
help="captcha type: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
)
p_gen.add_argument("--num", type=int, required=True, help="number of samples to generate")
# ---- train ----
p_train = subparsers.add_parser("train", help="train models")
p_train.add_argument(
"--model",
help="model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
)
p_train.add_argument("--all", action="store_true", help="train all models in dependency order")
# ---- export ----
p_export = subparsers.add_parser("export", help="export ONNX models")
p_export.add_argument(
"--model",
help="model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
)
p_export.add_argument("--all", action="store_true", help="export all models")
# ---- predict ----
p_pred = subparsers.add_parser("predict", help="recognize a single captcha")
p_pred.add_argument("image", help="image path")
p_pred.add_argument(
"--type", default=None,
help="skip classification by specifying a type: normal, math, 3d_text, 3d_rotate, 3d_slider",
)
# ---- predict-dir ----
p_pdir = subparsers.add_parser("predict-dir", help="batch-recognize captchas in a directory")
p_pdir.add_argument("directory", help="image directory path")
p_pdir.add_argument(
"--type", default=None,
help="skip classification by specifying a type: normal, math, 3d_text, 3d_rotate, 3d_slider",
)
# ---- serve ----
p_serve = subparsers.add_parser("serve", help="start the HTTP recognition service")

config.py

@@ -23,12 +23,16 @@ CLASSIFIER_DIR = DATA_DIR / "classifier"
# synthetic data subdirectories
SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal"
SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math"
SYNTHETIC_3D_TEXT_DIR = SYNTHETIC_DIR / "3d_text"
SYNTHETIC_3D_ROTATE_DIR = SYNTHETIC_DIR / "3d_rotate"
SYNTHETIC_3D_SLIDER_DIR = SYNTHETIC_DIR / "3d_slider"
# real data subdirectories
REAL_NORMAL_DIR = REAL_DIR / "normal"
REAL_MATH_DIR = REAL_DIR / "math"
REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
# ============================================================
# model output directories
@@ -38,8 +42,10 @@ ONNX_DIR = PROJECT_ROOT / "onnx_models"
# make sure the key directories exist
for _dir in [
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
REAL_NORMAL_DIR, REAL_MATH_DIR,
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
]:
_dir.mkdir(parents=True, exist_ok=True)
@@ -57,7 +63,7 @@ MATH_CHARS = "0123456789+-×÷=?"
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ"
# captcha type list (routing classifier output classes)
CAPTCHA_TYPES = ["normal", "math", "3d_text", "3d_rotate", "3d_slider"]
NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES)
# ============================================================
@@ -67,7 +73,9 @@ IMAGE_SIZE = {
"classifier": (64, 128),  # routing classifier input
"normal": (40, 120),  # plain character recognition
"math": (40, 160),  # arithmetic recognition (wider)
"3d_text": (60, 160),  # 3D text recognition
"3d_rotate": (80, 80),  # 3D rotation angle regression (square)
"3d_slider": (80, 240),  # 3D slider offset regression
}
# ============================================================
@@ -91,11 +99,25 @@ GENERATE_CONFIG = {
"rotation_range": (-10, 10),
"noise_line_range": (2, 4),
},
"3d_text": {
"char_count_range": (4, 5),
"image_size": (160, 60),  # generated image size (W, H)
"shadow_offset": (3, 3),  # shadow offset
"perspective_intensity": 0.3,  # perspective transform strength
"rotation_range": (-20, 20),  # character rotation angle
},
"3d_rotate": {
"image_size": (80, 80),  # generated image size (W, H)
"disc_radius": 35,  # disc radius
"marker_size": 8,  # direction marker size
"bg_color_range": (200, 240),  # background color range
},
"3d_slider": {
"image_size": (240, 80),  # generated image size (W, H)
"puzzle_size": (40, 40),  # puzzle piece size (W, H)
"gap_x_range": (50, 200),  # gap x-coordinate range
"piece_left_margin": 5,  # left margin before the puzzle piece
"bg_noise_intensity": 30,  # background texture noise strength
},
}
@@ -108,7 +130,7 @@ TRAIN_CONFIG = {
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 50000,  # 10000 per class × 5 classes
"val_split": 0.1,  # validation split fraction
},
"normal": {
@@ -129,7 +151,7 @@ TRAIN_CONFIG = {
"loss": "CTCLoss",
"val_split": 0.1,
},
"3d_text": {
"epochs": 80,
"batch_size": 64,
"lr": 5e-4,
@@ -138,6 +160,24 @@ TRAIN_CONFIG = {
"loss": "CTCLoss", "loss": "CTCLoss",
"val_split": 0.1, "val_split": 0.1,
}, },
"3d_rotate": {
"epochs": 60,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "SmoothL1",
"val_split": 0.1,
},
"3d_slider": {
"epochs": 60,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "SmoothL1",
"val_split": 0.1,
},
}
# ============================================================
@@ -163,6 +203,14 @@ ONNX_CONFIG = {
"dynamic_batch": True,  # support dynamic batch size
}
# ============================================================
# regression model label ranges
# ============================================================
REGRESSION_RANGE = {
"3d_rotate": (0, 360),  # rotation angle, 0-359°
"3d_slider": (10, 200),  # slider x offset (pixels)
}
# ============================================================
# inference configuration
# ============================================================
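As stated in the design rules, the regression models emit a sigmoid value in [0,1]; mapping it back to a label uses `REGRESSION_RANGE`. A self-contained sketch (the dict is copied from the config above; the function name is illustrative, not necessarily the project's API):

```python
REGRESSION_RANGE = {
    "3d_rotate": (0, 360),   # rotation angle
    "3d_slider": (10, 200),  # slider x offset in pixels
}

def denormalize(captcha_type: str, y: float) -> float:
    """Scale a sigmoid output y in [0,1] back to the configured label range."""
    lo, hi = REGRESSION_RANGE[captcha_type]
    return lo + y * (hi - lo)
```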

View File

@@ -1,20 +1,26 @@
""" """
数据生成器包 数据生成器包
提供三种验证码类型的数据生成器: 提供五种验证码类型的数据生成器:
- NormalCaptchaGenerator: 普通字符验证码 - NormalCaptchaGenerator: 普通字符验证码
- MathCaptchaGenerator: 算式验证码 - MathCaptchaGenerator: 算式验证码
- ThreeDCaptchaGenerator: 3D 立体验证码 - ThreeDCaptchaGenerator: 3D 立体文字验证码
- ThreeDRotateGenerator: 3D 旋转验证码
- ThreeDSliderGenerator: 3D 滑块验证码
""" """
from generators.base import BaseCaptchaGenerator from generators.base import BaseCaptchaGenerator
from generators.normal_gen import NormalCaptchaGenerator from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
__all__ = [ __all__ = [
"BaseCaptchaGenerator", "BaseCaptchaGenerator",
"NormalCaptchaGenerator", "NormalCaptchaGenerator",
"MathCaptchaGenerator", "MathCaptchaGenerator",
"ThreeDCaptchaGenerator", "ThreeDCaptchaGenerator",
"ThreeDRotateGenerator",
"ThreeDSliderGenerator",
] ]

View File

@@ -45,7 +45,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
from config import RANDOM_SEED from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED) super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d"] self.cfg = GENERATE_CONFIG["3d_text"]
self.chars = THREED_CHARS self.chars = THREED_CHARS
self.width, self.height = self.cfg["image_size"] self.width, self.height = self.cfg["image_size"]
@@ -154,7 +154,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
char_img = self._perspective_transform(char_img, rng) char_img = self._perspective_transform(char_img, rng)
# 随机旋转 # 随机旋转
angle = rng.randint(-20, 20) angle = rng.randint(*self.cfg["rotation_range"])
char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True) char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)
# 粘贴到画布 # 粘贴到画布

View File

@@ -0,0 +1,122 @@
"""
3D 旋转验证码生成器
生成旋转验证码:圆盘上绘制字符 + 方向标记,随机旋转 0-359°。
用户需将圆盘旋转到正确角度。
标签 = 旋转角度(整数)
文件名格式: {angle}_{index:06d}.png
"""
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import GENERATE_CONFIG, THREED_CHARS
from generators.base import BaseCaptchaGenerator
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
]
_DISC_COLORS = [
(180, 200, 220),
(200, 220, 200),
(220, 200, 190),
(200, 200, 220),
(210, 210, 200),
]
class ThreeDRotateGenerator(BaseCaptchaGenerator):
"""3D 旋转验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d_rotate"]
self.chars = THREED_CHARS
self.width, self.height = self.cfg["image_size"]
self._fonts: list[str] = []
for p in _FONT_PATHS:
try:
ImageFont.truetype(p, 20)
self._fonts.append(p)
except OSError:
continue
if not self._fonts:
raise RuntimeError("未找到任何可用字体,无法生成验证码")
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
# 随机旋转角度 0-359
angle = rng.randint(0, 359)
if text is None:
text = str(angle)
# 1. 背景
bg_val = rng.randint(*self.cfg["bg_color_range"])
img = Image.new("RGB", (self.width, self.height), (bg_val, bg_val, bg_val))
draw = ImageDraw.Draw(img)
# 2. 绘制圆盘
cx, cy = self.width // 2, self.height // 2
r = self.cfg["disc_radius"]
disc_color = rng.choice(_DISC_COLORS)
draw.ellipse(
[cx - r, cy - r, cx + r, cy + r],
fill=disc_color, outline=(100, 100, 100), width=2,
)
# 3. 在圆盘上绘制字符和方向标记 (未旋转状态)
disc_img = Image.new("RGBA", (r * 2 + 4, r * 2 + 4), (0, 0, 0, 0))
disc_draw = ImageDraw.Draw(disc_img)
dc = r + 2 # disc center
# 字符 (圆盘中心)
font_path = rng.choice(self._fonts)
font_size = int(r * 0.6)
font = ImageFont.truetype(font_path, font_size)
ch = rng.choice(self.chars)
bbox = font.getbbox(ch)
tw = bbox[2] - bbox[0]
th = bbox[3] - bbox[1]
disc_draw.text(
(dc - tw // 2 - bbox[0], dc - th // 2 - bbox[1]),
ch, fill=(50, 50, 50, 255), font=font,
)
# 方向标记 (三角箭头,指向上方)
ms = self.cfg["marker_size"]
marker_y = dc - r + ms + 2
disc_draw.polygon(
[(dc, marker_y - ms), (dc - ms // 2, marker_y), (dc + ms // 2, marker_y)],
fill=(220, 60, 60, 255),
)
# 4. 旋转圆盘内容
disc_img = disc_img.rotate(-angle, resample=Image.BICUBIC, expand=False)
# 5. 粘贴到背景
paste_x = cx - dc
paste_y = cy - dc
img.paste(disc_img, (paste_x, paste_y), disc_img)
# 6. 添加少量噪点
for _ in range(rng.randint(20, 50)):
nx, ny = rng.randint(0, self.width - 1), rng.randint(0, self.height - 1)
nc = tuple(rng.randint(100, 200) for _ in range(3))
draw.point((nx, ny), fill=nc)
# 7. 轻微模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.5))
return img, text
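生成器以 rotate(-angle) 旋转圆盘内容、以 angle 作为标签;求解时应反向旋转 angle 才能复位。下面用模运算验证这一约定(纯算术示意,变量名为说明用,非仓库代码):

```python
angle = 135                      # 标签,即文件名前缀
displayed = (-angle) % 360       # rotate(-angle) 后圆盘标记的朝向
correction = (360 - displayed) % 360  # 复位所需的旋转量
assert correction == angle

# 文件名格式 {angle}_{index:06d}.png,与 docstring 约定一致
fname = f"{angle}_{7:06d}.png"
print(fname)  # 135_000007.png
```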

View File

@@ -0,0 +1,113 @@
"""
3D 滑块验证码生成器
生成滑块拼图验证码:纹理背景 + 拼图缺口 + 拼图块在左侧。
用户需将拼图块滑动到缺口位置。
标签 = 缺口 x 坐标偏移(整数)
文件名格式: {offset}_{index:06d}.png
"""
import random
from PIL import Image, ImageDraw, ImageFilter
from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator
class ThreeDSliderGenerator(BaseCaptchaGenerator):
"""3D 滑块验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d_slider"]
self.width, self.height = self.cfg["image_size"]
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
pw, ph = self.cfg["puzzle_size"]
gap_x_lo, gap_x_hi = self.cfg["gap_x_range"]
# 缺口位置
gap_x = rng.randint(gap_x_lo, gap_x_hi)
gap_y = rng.randint(10, self.height - ph - 10)
if text is None:
text = str(gap_x)
# 1. 生成纹理背景
img = self._textured_background(rng)
# 2. 从缺口位置截取拼图块内容
piece_content = img.crop((gap_x, gap_y, gap_x + pw, gap_y + ph)).copy()
# 3. 绘制缺口 (半透明灰色区域)
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
overlay_draw.rectangle(
[gap_x, gap_y, gap_x + pw, gap_y + ph],
fill=(80, 80, 80, 160),
outline=(60, 60, 60, 200),
width=2,
)
img = img.convert("RGBA")
img = Image.alpha_composite(img, overlay)
img = img.convert("RGB")
# 4. 绘制拼图块在左侧
piece_x = self.cfg["piece_left_margin"]
piece_img = Image.new("RGBA", (pw + 4, ph + 4), (0, 0, 0, 0))
piece_draw = ImageDraw.Draw(piece_img)
# 阴影
piece_draw.rectangle([2, 2, pw + 3, ph + 3], fill=(0, 0, 0, 80))
# 内容
piece_img.paste(piece_content, (0, 0))
# 边框
piece_draw.rectangle([0, 0, pw - 1, ph - 1], outline=(255, 255, 255, 200), width=2)
img_rgba = img.convert("RGBA")
img_rgba.paste(piece_img, (piece_x, gap_y), piece_img)
img = img_rgba.convert("RGB")
# 5. 轻微模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.3))
return img, text
def _textured_background(self, rng: random.Random) -> Image.Image:
"""生成带纹理的彩色背景。"""
img = Image.new("RGB", (self.width, self.height))
draw = ImageDraw.Draw(img)
# 渐变底色
base_r, base_g, base_b = rng.randint(100, 180), rng.randint(100, 180), rng.randint(100, 180)
for y in range(self.height):
ratio = y / max(self.height - 1, 1)
r = int(base_r + 30 * ratio)
g = int(base_g - 20 * ratio)
b = int(base_b + 10 * ratio)
draw.line([(0, y), (self.width, y)], fill=(r, g, b))
# 添加纹理噪声
noise_intensity = self.cfg["bg_noise_intensity"]
for _ in range(self.width * self.height // 8):
x = rng.randint(0, self.width - 1)
y = rng.randint(0, self.height - 1)
pixel = img.getpixel((x, y))
noise = tuple(
max(0, min(255, c + rng.randint(-noise_intensity, noise_intensity)))
for c in pixel
)
draw.point((x, y), fill=noise)
# 随机色块 (模拟图案)
for _ in range(rng.randint(3, 6)):
x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
x2, y2 = x1 + rng.randint(15, 40), y1 + rng.randint(10, 25)
color = tuple(rng.randint(60, 220) for _ in range(3))
draw.rectangle([x1, y1, x2, y2], fill=color)
return img
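可以用纯算术检查 3d_slider 默认配置的自洽性(数值取自上文 GENERATE_CONFIG,仅作示意):最右侧缺口加上拼图块宽度不超出画布,且 gap_y 的取值区间非空:

```python
W, H = 240, 80                    # image_size (W, H)
pw, ph = 40, 40                   # puzzle_size
gap_x_lo, gap_x_hi = 50, 200      # gap_x_range
gap_y_lo = 10
gap_y_hi = H - ph - 10            # generate() 中 gap_y 的上界

assert gap_x_hi + pw <= W         # 缺口右缘 240 恰好贴边,不越界
assert gap_y_lo <= gap_y_hi       # 区间非空
print(gap_y_hi)  # 30
```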

View File

@@ -17,10 +17,12 @@ from config import (
MATH_CHARS, MATH_CHARS,
THREED_CHARS, THREED_CHARS,
NUM_CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
REGRESSION_RANGE,
) )
from models.classifier import CaptchaClassifier from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
def export_model( def export_model(
@@ -34,7 +36,7 @@ def export_model(
Args: Args:
model: 已加载权重的 PyTorch 模型 model: 已加载权重的 PyTorch 模型
model_name: 模型名 (classifier / normal / math / threed) model_name: 模型名 (classifier / normal / math / threed_text / threed_rotate / threed_slider)
input_shape: 输入形状 (C, H, W) input_shape: 输入形状 (C, H, W)
onnx_dir: 输出目录 (默认使用 config.ONNX_DIR) onnx_dir: 输出目录 (默认使用 config.ONNX_DIR)
""" """
@@ -50,7 +52,7 @@ def export_model(
dummy = torch.randn(1, *input_shape) dummy = torch.randn(1, *input_shape)
# 分类器和识别器的 dynamic_axes 不同 # 分类器和识别器的 dynamic_axes 不同
if model_name == "classifier": if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
else: else:
# CTC 模型: output shape = (T, B, C) # CTC 模型: output shape = (T, B, C)
@@ -78,7 +80,8 @@ def _load_and_export(model_name: str):
return return
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}") acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")
if model_name == "classifier": if model_name == "classifier":
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES) model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
@@ -94,11 +97,19 @@ def _load_and_export(model_name: str):
h, w = IMAGE_SIZE["math"] h, w = IMAGE_SIZE["math"]
model = LiteCRNN(chars=chars, img_h=h, img_w=w) model = LiteCRNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w) input_shape = (1, h, w)
elif model_name == "threed": elif model_name == "threed_text":
chars = ckpt.get("chars", THREED_CHARS) chars = ckpt.get("chars", THREED_CHARS)
h, w = IMAGE_SIZE["3d"] h, w = IMAGE_SIZE["3d_text"]
model = ThreeDCNN(chars=chars, img_h=h, img_w=w) model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w) input_shape = (1, h, w)
elif model_name == "threed_rotate":
h, w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "threed_slider":
h, w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
else: else:
print(f"[错误] 未知模型: {model_name}") print(f"[错误] 未知模型: {model_name}")
return return
@@ -108,11 +119,11 @@ def _load_and_export(model_name: str):
def export_all(): def export_all():
"""依次导出 classifier, normal, math, threed个模型。""" """依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。"""
print("=" * 50) print("=" * 50)
print("导出全部 ONNX 模型") print("导出全部 ONNX 模型")
print("=" * 50) print("=" * 50)
for name in ["classifier", "normal", "math", "threed"]: for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
_load_and_export(name) _load_and_export(name)
print("\n全部导出完成。") print("\n全部导出完成。")

View File

@@ -4,7 +4,7 @@
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。 加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
推理流程: 推理流程:
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出 输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 / 回归缩放 → 后处理 → 输出
对算式类型,解码后还会调用 math_eval 计算结果。 对算式类型,解码后还会调用 math_eval 计算结果。
""" """
@@ -23,6 +23,7 @@ from config import (
NORMAL_CHARS, NORMAL_CHARS,
MATH_CHARS, MATH_CHARS,
THREED_CHARS, THREED_CHARS,
REGRESSION_RANGE,
) )
from inference.math_eval import eval_captcha_math from inference.math_eval import eval_captcha_math
@@ -59,19 +60,24 @@ class CaptchaPipeline:
self.mean = INFERENCE_CONFIG["normalize_mean"] self.mean = INFERENCE_CONFIG["normalize_mean"]
self.std = INFERENCE_CONFIG["normalize_std"] self.std = INFERENCE_CONFIG["normalize_std"]
# 字符集映射 # 字符集映射 (仅 CTC 模型需要)
self._chars = { self._chars = {
"normal": NORMAL_CHARS, "normal": NORMAL_CHARS,
"math": MATH_CHARS, "math": MATH_CHARS,
"3d": THREED_CHARS, "3d_text": THREED_CHARS,
} }
# 回归模型类型
self._regression_types = {"3d_rotate", "3d_slider"}
# 专家模型名 → ONNX 文件名 # 专家模型名 → ONNX 文件名
self._model_files = { self._model_files = {
"classifier": "classifier.onnx", "classifier": "classifier.onnx",
"normal": "normal.onnx", "normal": "normal.onnx",
"math": "math.onnx", "math": "math.onnx",
"3d": "threed.onnx", "3d_text": "threed_text.onnx",
"3d_rotate": "threed_rotate.onnx",
"3d_slider": "threed_slider.onnx",
} }
# 加载所有可用模型 # 加载所有可用模型
@@ -116,7 +122,7 @@ class CaptchaPipeline:
def classify(self, image: Image.Image) -> str: def classify(self, image: Image.Image) -> str:
""" """
调度分类,返回类型名: 'normal' / 'math' / '3d' 调度分类,返回类型名。
Raises: Raises:
RuntimeError: 分类器模型未加载 RuntimeError: 分类器模型未加载
@@ -141,7 +147,7 @@ class CaptchaPipeline:
Args: Args:
image: PIL.Image 或文件路径 (str/Path) 或 bytes image: PIL.Image 或文件路径 (str/Path) 或 bytes
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d') captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d_text'/'3d_rotate'/'3d_slider')
Returns: Returns:
dict: { dict: {
@@ -166,22 +172,32 @@ class CaptchaPipeline:
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型" f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
) )
size_key = captcha_type # "normal"/"math"/"3d" size_key = captcha_type
inp = self.preprocess(img, IMAGE_SIZE[size_key]) inp = self.preprocess(img, IMAGE_SIZE[size_key])
session = self._sessions[captcha_type] session = self._sessions[captcha_type]
input_name = session.get_inputs()[0].name input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
# 4. CTC 贪心解码 # 4. 分支: CTC 解码 vs 回归
if captcha_type in self._regression_types:
# 回归模型: 输出 (batch, 1) sigmoid 值
output = session.run(None, {input_name: inp})[0] # (1, 1)
sigmoid_val = float(output[0, 0])
lo, hi = REGRESSION_RANGE[captcha_type]
real_val = sigmoid_val * (hi - lo) + lo
raw_text = f"{real_val:.1f}"
result = str(int(round(real_val)))
else:
# CTC 模型
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
chars = self._chars[captcha_type] chars = self._chars[captcha_type]
raw_text = self._ctc_greedy_decode(logits, chars) raw_text = self._ctc_greedy_decode(logits, chars)
# 5. 后处理 # 后处理
if captcha_type == "math": if captcha_type == "math":
try: try:
result = eval_captcha_math(raw_text) result = eval_captcha_math(raw_text)
except ValueError: except ValueError:
result = raw_text # 解析失败则返回原始文本 result = raw_text
else: else:
result = raw_text result = raw_text
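回归分支的后处理可以浓缩为一个纯函数示意(`postprocess` 为说明用的假想函数,缩放与取整逻辑与上文管线代码一致):

```python
REGRESSION_RANGE = {"3d_rotate": (0, 360), "3d_slider": (10, 200)}

def postprocess(sigmoid_val: float, captcha_type: str) -> tuple[str, str]:
    """sigmoid 输出 → (raw_text, result):保留一位小数的原始值与四舍五入整数。"""
    lo, hi = REGRESSION_RANGE[captcha_type]
    real_val = sigmoid_val * (hi - lo) + lo
    return f"{real_val:.1f}", str(int(round(real_val)))

raw, result = postprocess(0.5, "3d_rotate")
print(raw, result)  # 180.0 180
```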

View File

@@ -1,18 +1,21 @@
""" """
模型定义包 模型定义包
提供三种模型: 提供四种模型:
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB) - CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB) - LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
- ThreeDCNN: 3D 验证码专用模型 (ResNet-lite + BiLSTM, < 5MB) - ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
""" """
from models.classifier import CaptchaClassifier from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
__all__ = [ __all__ = [
"CaptchaClassifier", "CaptchaClassifier",
"LiteCRNN", "LiteCRNN",
"ThreeDCNN", "ThreeDCNN",
"RegressionCNN",
] ]

86
models/regression_cnn.py Normal file
View File

@@ -0,0 +1,86 @@
"""
回归 CNN 模型
3d_rotate 和 3d_slider 共用的回归模型。
输出 sigmoid 归一化到 [0,1],推理时按 label_range 缩放回原始范围。
架构:
Conv(1→32) + BN + ReLU + Pool
Conv(32→64) + BN + ReLU + Pool
Conv(64→128) + BN + ReLU + Pool
Conv(128→128) + BN + ReLU + Pool
AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid
约 250K 参数,~1MB。
"""
import torch
import torch.nn as nn
class RegressionCNN(nn.Module):
"""
轻量回归 CNN用于 3d_rotate (角度) 和 3d_slider (偏移) 预测。
输出 [0, 1] 范围的 sigmoid 值,需要按 label_range 缩放到实际范围。
"""
def __init__(self, img_h: int = 80, img_w: int = 80):
"""
Args:
img_h: 输入图片高度
img_w: 输入图片宽度
"""
super().__init__()
self.img_h = img_h
self.img_w = img_w
self.features = nn.Sequential(
# block 1: 1 → 32, H/2, W/2
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 2: 32 → 64, H/4, W/4
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 3: 64 → 128, H/8, W/8
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 4: 128 → 128, H/16, W/16
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
)
self.pool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(64, 1),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, 1, H, W) 灰度图
Returns:
output: (batch, 1) sigmoid 输出 [0, 1]
"""
feat = self.features(x)
feat = self.pool(feat) # (B, 128, 1, 1)
feat = feat.flatten(1) # (B, 128)
out = self.regressor(feat) # (B, 1)
return out
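文档字符串声称约 250K 参数,可以用纯算术核对(卷积层 bias=False;BatchNorm 只计 weight/bias 两组可训练参数,不含 running stats):

```python
def conv_params(cin: int, cout: int, k: int = 3) -> int:
    return cin * cout * k * k        # bias=False

def bn_params(c: int) -> int:
    return 2 * c                     # weight + bias

total = (conv_params(1, 32)   + bn_params(32)
       + conv_params(32, 64)  + bn_params(64)
       + conv_params(64, 128) + bn_params(128)
       + conv_params(128, 128)+ bn_params(128)
       + (128 * 64 + 64)             # FC 128→64
       + (64 * 1 + 1))               # FC 64→1
print(total)  # 248929,即 docstring 所说的"约 250K 参数"
```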

View File

@@ -1,10 +1,13 @@
""" """
训练脚本包 训练脚本包
- dataset.py: CRNNDataset / CaptchaDataset 通用数据集类 - dataset.py: CRNNDataset / CaptchaDataset / RegressionDataset 通用数据集类
- train_utils.py: CTC 训练通用逻辑 (train_ctc_model) - train_utils.py: CTC 训练通用逻辑 (train_ctc_model)
- train_regression_utils.py: 回归训练通用逻辑 (train_regression_model)
- train_normal.py: 训练普通字符识别 (LiteCRNN - normal) - train_normal.py: 训练普通字符识别 (LiteCRNN - normal)
- train_math.py: 训练算式识别 (LiteCRNN - math) - train_math.py: 训练算式识别 (LiteCRNN - math)
- train_3d.py: 训练 3D 立体识别 (ThreeDCNN) - train_3d_text.py: 训练 3D 立体文字识别 (ThreeDCNN)
- train_3d_rotate.py: 训练 3D 旋转回归 (RegressionCNN)
- train_3d_slider.py: 训练 3D 滑块回归 (RegressionCNN)
- train_classifier.py: 训练调度分类器 (CaptchaClassifier) - train_classifier.py: 训练调度分类器 (CaptchaClassifier)
""" """

View File

@@ -1,16 +1,19 @@
""" """
通用 Dataset 类 通用 Dataset 类
提供两种数据集: 提供三种数据集:
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签) - CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码) - CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
- RegressionDataset: 用于回归模型训练 (图片 → 数值标签 [0,1])
文件名格式约定: {label}_{任意}.png 文件名格式约定: {label}_{任意}.png
- 分类器: label 可为任意字符,所在子目录名即为类别 - 分类器: label 可为任意字符,所在子目录名即为类别
- 识别器: label 即标注内容 (如 "A3B8""3+8") - 识别器: label 即标注内容 (如 "A3B8""3+8")
- 回归器: label 为数值 (如 "135""87")
""" """
import os import os
import warnings
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
@@ -98,7 +101,15 @@ class CRNNDataset(Dataset):
img = self.transform(img) img = self.transform(img)
# 编码标签为整数序列 # 编码标签为整数序列
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx] target = []
for c in label:
if c in self.char_to_idx:
target.append(self.char_to_idx[c])
else:
warnings.warn(
f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
stacklevel=2,
)
return img, target, label return img, target, label
@staticmethod @staticmethod
@@ -119,7 +130,7 @@ class CaptchaDataset(Dataset):
""" """
分类器训练数据集。 分类器训练数据集。
每个子目录名为类别名 (如 "normal", "math", "3d") 每个子目录名为类别名 (如 "normal", "math", "3d_text")
目录内所有 .png 文件属于该类。 目录内所有 .png 文件属于该类。
""" """
@@ -157,3 +168,59 @@ class CaptchaDataset(Dataset):
if self.transform: if self.transform:
img = self.transform(img) img = self.transform(img)
return img, label return img, label
# ============================================================
# 回归模型用数据集
# ============================================================
class RegressionDataset(Dataset):
"""
回归模型数据集 (3d_rotate / 3d_slider)。
从目录中读取 {value}_{xxx}.png 文件,
将 value 解析为浮点数并归一化到 [0, 1]。
"""
def __init__(
self,
dirs: list[str | Path],
label_range: tuple[float, float],
transform: transforms.Compose | None = None,
):
"""
Args:
dirs: 数据目录列表
label_range: (min_val, max_val) 标签原始范围
transform: 图片预处理/增强
"""
self.label_range = label_range
self.lo, self.hi = label_range
self.transform = transform
self.samples: list[tuple[str, float]] = [] # (文件路径, 归一化标签)
for d in dirs:
d = Path(d)
if not d.exists():
continue
for f in sorted(d.glob("*.png")):
raw_label = f.stem.rsplit("_", 1)[0]
try:
value = float(raw_label)
except ValueError:
continue
# 归一化到 [0, 1]
norm = (value - self.lo) / max(self.hi - self.lo, 1e-6)
norm = max(0.0, min(1.0, norm))
self.samples.append((str(f), norm))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
import torch
path, label = self.samples[idx]
img = Image.open(path).convert("RGB")
if self.transform:
img = self.transform(img)
return img, torch.tensor([label], dtype=torch.float32)
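RegressionDataset 的标签解析与归一化可以这样单步验证(数值为示例;lo/hi 取 3d_rotate 的范围,rsplit 与钳位逻辑与上文实现一致):

```python
stem = "135_000007"                  # 文件名去掉 .png 后缀
raw_label = stem.rsplit("_", 1)[0]   # "135",rsplit 保证 value 中的下划线不受影响
value = float(raw_label)

lo, hi = 0.0, 360.0
norm = max(0.0, min(1.0, (value - lo) / max(hi - lo, 1e-6)))
print(norm)  # 0.375
```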

View File

@@ -0,0 +1,38 @@
"""
训练 3D 旋转验证码回归模型 (RegressionCNN)
用法: python -m training.train_3d_rotate
"""
from config import (
IMAGE_SIZE,
SYNTHETIC_3D_ROTATE_DIR,
REAL_3D_ROTATE_DIR,
)
from generators.threed_rotate_gen import ThreeDRotateGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model
def main():
img_h, img_w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=img_h, img_w=img_w)
print("=" * 60)
print("训练 3D 旋转验证码回归模型 (RegressionCNN)")
print(f" 输入尺寸: {img_h}×{img_w}")
print(f" 任务: 预测旋转角度 0-359°")
print("=" * 60)
train_regression_model(
model_name="threed_rotate",
model=model,
synthetic_dir=SYNTHETIC_3D_ROTATE_DIR,
real_dir=REAL_3D_ROTATE_DIR,
generator_cls=ThreeDRotateGenerator,
config_key="3d_rotate",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
"""
训练 3D 滑块验证码回归模型 (RegressionCNN)
用法: python -m training.train_3d_slider
"""
from config import (
IMAGE_SIZE,
SYNTHETIC_3D_SLIDER_DIR,
REAL_3D_SLIDER_DIR,
)
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model
def main():
img_h, img_w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=img_h, img_w=img_w)
print("=" * 60)
print("训练 3D 滑块验证码回归模型 (RegressionCNN)")
print(f" 输入尺寸: {img_h}×{img_w}")
print(f" 任务: 预测滑块偏移 x 坐标")
print("=" * 60)
train_regression_model(
model_name="threed_slider",
model=model,
synthetic_dir=SYNTHETIC_3D_SLIDER_DIR,
real_dir=REAL_3D_SLIDER_DIR,
generator_cls=ThreeDSliderGenerator,
config_key="3d_slider",
)
if __name__ == "__main__":
main()

View File

@@ -1,14 +1,14 @@
""" """
训练 3D 立体验证码识别模型 (ThreeDCNN) 训练 3D 立体文字验证码识别模型 (ThreeDCNN)
用法: python -m training.train_3d 用法: python -m training.train_3d_text
""" """
from config import ( from config import (
THREED_CHARS, THREED_CHARS,
IMAGE_SIZE, IMAGE_SIZE,
SYNTHETIC_3D_DIR, SYNTHETIC_3D_TEXT_DIR,
REAL_3D_DIR, REAL_3D_TEXT_DIR,
) )
from generators.threed_gen import ThreeDCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN from models.threed_cnn import ThreeDCNN
@@ -16,23 +16,23 @@ from training.train_utils import train_ctc_model
def main(): def main():
img_h, img_w = IMAGE_SIZE["3d"] img_h, img_w = IMAGE_SIZE["3d_text"]
model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w) model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)
print("=" * 60) print("=" * 60)
print("训练 3D 立体验证码识别模型 (ThreeDCNN)") print("训练 3D 立体文字验证码识别模型 (ThreeDCNN)")
print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)") print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)")
print(f" 输入尺寸: {img_h}×{img_w}") print(f" 输入尺寸: {img_h}×{img_w}")
print("=" * 60) print("=" * 60)
train_ctc_model( train_ctc_model(
model_name="threed", model_name="threed_text",
model=model, model=model,
chars=THREED_CHARS, chars=THREED_CHARS,
synthetic_dir=SYNTHETIC_3D_DIR, synthetic_dir=SYNTHETIC_3D_TEXT_DIR,
real_dir=REAL_3D_DIR, real_dir=REAL_3D_TEXT_DIR,
generator_cls=ThreeDCaptchaGenerator, generator_cls=ThreeDCaptchaGenerator,
config_key="threed", config_key="3d_text",
) )

View File

@@ -1,16 +1,18 @@
""" """
训练调度分类器 (CaptchaClassifier) 训练调度分类器 (CaptchaClassifier)
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d。 从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d_text / 3d_rotate / 3d_slider
数据来源: data/classifier/ 目录 (按类型子目录组织) 数据来源: data/classifier/ 目录 (按类型子目录组织)
用法: python -m training.train_classifier 用法: python -m training.train_classifier
""" """
import os import os
import random
import shutil import shutil
from pathlib import Path from pathlib import Path
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
@@ -24,15 +26,20 @@ from config import (
CLASSIFIER_DIR, CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR, SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR, SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_DIR, SYNTHETIC_3D_TEXT_DIR,
SYNTHETIC_3D_ROTATE_DIR,
SYNTHETIC_3D_SLIDER_DIR,
CHECKPOINTS_DIR, CHECKPOINTS_DIR,
ONNX_DIR, ONNX_DIR,
ONNX_CONFIG, ONNX_CONFIG,
RANDOM_SEED,
get_device, get_device,
) )
from generators.normal_gen import NormalCaptchaGenerator from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.classifier import CaptchaClassifier from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
@@ -52,7 +59,9 @@ def _prepare_classifier_data():
type_info = [ type_info = [
("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator), ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator), ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator), ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
] ]
for cls_name, syn_dir, gen_cls in type_info: for cls_name, syn_dir, gen_cls in type_info:
@@ -95,6 +104,13 @@ def main():
img_h, img_w = IMAGE_SIZE["classifier"] img_h, img_w = IMAGE_SIZE["classifier"]
device = get_device() device = get_device()
# 设置随机种子
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(RANDOM_SEED)
print("=" * 60) print("=" * 60)
print("训练调度分类器 (CaptchaClassifier)") print("训练调度分类器 (CaptchaClassifier)")
print(f" 类别: {CAPTCHA_TYPES}") print(f" 类别: {CAPTCHA_TYPES}")
@@ -128,11 +144,11 @@ def main():
train_loader = DataLoader( train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True, train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, pin_memory=True, num_workers=0, pin_memory=True,
) )
val_loader = DataLoader( val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False, val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, pin_memory=True, num_workers=0, pin_memory=True,
) )
print(f"[数据] 训练: {train_size} 验证: {val_size}") print(f"[数据] 训练: {train_size} 验证: {val_size}")

View File

@@ -0,0 +1,264 @@
"""
回归训练通用逻辑
提供 train_regression_model() 函数,被 train_3d_rotate / train_3d_slider 共用。
职责:
1. 检查合成数据,不存在则自动调用生成器
2. 构建 RegressionDataset / DataLoader含真实数据混合
3. 回归训练循环 + cosine scheduler
4. 输出日志: epoch, loss, MAE, tolerance 准确率
5. 保存最佳模型到 checkpoints/
6. 训练结束导出 ONNX
"""
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
REGRESSION_RANGE,
RANDOM_SEED,
get_device,
)
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
"""设置全局随机种子,保证训练可复现。"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
循环距离 SmoothL1 loss用于角度回归 (处理 0°/360° 边界)。
pred 和 target 都在 [0, 1] 范围。
"""
diff = torch.abs(pred - target)
# 循环距离: min(|d|, 1-|d|)
diff = torch.min(diff, 1.0 - diff)
# SmoothL1
return torch.where(
diff < 1.0 / 360.0, # beta ≈ 1° 归一化
0.5 * diff * diff / (1.0 / 360.0),
diff - 0.5 * (1.0 / 360.0),
).mean()
def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
"""循环 MAE (归一化空间)。"""
diff = np.abs(pred - target)
diff = np.minimum(diff, 1.0 - diff)
return float(np.mean(diff))
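循环距离避免了 0°/360° 边界处的虚假大误差。下面是角度空间的纯 Python 等价实现(`circular_error_deg` 为说明用的假想函数,逻辑与上面归一化空间的 `_circular_mae` 一致):

```python
def circular_error_deg(pred: float, target: float, period: float = 360.0) -> float:
    """角度空间的循环误差:取顺时针/逆时针两个方向中较短的一段。"""
    d = abs(pred - target) % period
    return min(d, period - d)

print(circular_error_deg(358, 2))  # 4,而非朴素差值 356
```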
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
"""导出模型为 ONNX 格式。"""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
dummy = torch.randn(1, 1, img_h, img_w)
torch.onnx.export(
model.cpu(),
dummy,
str(onnx_path),
opset_version=ONNX_CONFIG["opset_version"],
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
if ONNX_CONFIG["dynamic_batch"]
else None,
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
def train_regression_model(
model_name: str,
model: nn.Module,
synthetic_dir: str | Path,
real_dir: str | Path,
generator_cls,
config_key: str,
):
"""
通用回归训练流程。
Args:
model_name: 模型名称 (用于保存文件: threed_rotate / threed_slider)
model: PyTorch 模型实例 (RegressionCNN)
synthetic_dir: 合成数据目录
real_dir: 真实数据目录
generator_cls: 生成器类 (用于自动生成数据)
config_key: TRAIN_CONFIG / REGRESSION_RANGE 中的键名 (3d_rotate / 3d_slider)
"""
cfg = TRAIN_CONFIG[config_key]
img_h, img_w = IMAGE_SIZE[config_key]
label_range = REGRESSION_RANGE[config_key]
lo, hi = label_range
is_circular = config_key == "3d_rotate"
device = get_device()
# 容差配置
if config_key == "3d_rotate":
tolerance = 5.0 # ±5°
else:
tolerance = 3.0 # ±3px
_set_seed()
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
real_path = Path(real_dir)
if real_path.exists() and list(real_path.glob("*.png")):
data_dirs.append(str(real_path))
print(f"[数据] 混合真实数据: {len(list(real_path.glob('*.png')))}")
train_transform = build_train_transform(img_h, img_w)
val_transform = build_val_transform(img_h, img_w)
full_dataset = RegressionDataset(
dirs=data_dirs, label_range=label_range, transform=train_transform,
)
total = len(full_dataset)
val_size = int(total * cfg["val_split"])
train_size = total - val_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
# 验证集使用无增强 transform
val_ds_clean = RegressionDataset(
dirs=data_dirs, label_range=label_range, transform=val_transform,
)
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=0, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=0, pin_memory=True,
)
print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    if is_circular:
        loss_fn = _circular_smooth_l1
    else:
        loss_fn = nn.SmoothL1Loss()
    best_mae = float("inf")
    best_tol_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)
            preds = model(images)
            loss = loss_fn(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images)
                all_preds.append(preds.cpu().numpy())
                all_targets.append(targets.numpy())
        all_preds = np.concatenate(all_preds, axis=0).flatten()
        all_targets = np.concatenate(all_targets, axis=0).flatten()
        # Scale back to the original label range before computing MAE
        preds_real = all_preds * (hi - lo) + lo
        targets_real = all_targets * (hi - lo) + lo
        if is_circular:
            # Circular MAE: take the shorter arc around the wrap-around point
            diff = np.abs(preds_real - targets_real)
            diff = np.minimum(diff, (hi - lo) - diff)
            mae = float(np.mean(diff))
            within_tol = diff <= tolerance
        else:
            mae = float(np.mean(np.abs(preds_real - targets_real)))
            within_tol = np.abs(preds_real - targets_real) <= tolerance
        tol_acc = float(np.mean(within_tol))
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f} "
            f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save best model (selected by tolerance accuracy) ----
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae
            torch.save({
                "model_state_dict": model.state_dict(),
                "label_range": label_range,
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f"  → Saved best model tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[Training done] best tolerance accuracy: {best_tol_acc:.4f}  best MAE: {best_mae:.2f}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w)
    return best_tol_acc
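The validation step above denormalizes the model's [0, 1] outputs back to the label range, then (for `3d_rotate`) measures error along the shorter arc of the circle so that 359° and 1° count as 2° apart, not 358°. A minimal standalone sketch of that denormalize-then-circular-MAE step (NumPy only; the helper name and the 0–360° range are illustrative, not taken from the repo):

```python
import numpy as np

def circular_mae(preds_norm, targets_norm, lo=0.0, hi=360.0):
    """Denormalize [0, 1] outputs to [lo, hi], then take the mean
    absolute error along the shorter arc of the circle."""
    preds = np.asarray(preds_norm) * (hi - lo) + lo
    targets = np.asarray(targets_norm) * (hi - lo) + lo
    diff = np.abs(preds - targets)
    diff = np.minimum(diff, (hi - lo) - diff)  # wrap around instead of going the long way
    return float(np.mean(diff))

# 0.75 -> 270 degrees vs 0.0 -> 0 degrees: the short way around is 90, not 270
print(circular_mae([0.75], [0.0]))  # → 90.0
```

The same wrap-around distance is what `_circular_smooth_l1` has to respect on the loss side; a plain SmoothL1 on raw angles would heavily penalize predictions that are actually almost correct near the 0°/360° boundary.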

View File

@@ -1,7 +1,7 @@
 """
 Shared CTC training logic.
-Provides train_ctc_model(), used by train_normal / train_math / train_3d.
+Provides train_ctc_model(), used by train_normal / train_math / train_3d_text.
 Responsibilities:
 1. Check for synthetic data; invoke the generator automatically if missing
 2. Build Dataset / DataLoader (with real data mixed in)
@@ -12,8 +12,10 @@ Shared CTC training logic.
 """
 import os
+import random
 from pathlib import Path
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader, random_split
@@ -25,11 +27,21 @@ from config import (
     ONNX_CONFIG,
     TRAIN_CONFIG,
     IMAGE_SIZE,
+    RANDOM_SEED,
     get_device,
 )
 from training.dataset import CRNNDataset, build_train_transform, build_val_transform
+
+
+def _set_seed(seed: int = RANDOM_SEED):
+    """Set the global random seeds so training runs are reproducible."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
 # ============================================================
 # Accuracy computation
 # ============================================================
@@ -104,9 +116,12 @@ def train_ctc_model(
         config_key: key into TRAIN_CONFIG
     """
     cfg = TRAIN_CONFIG[config_key]
-    img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
+    img_h, img_w = IMAGE_SIZE[config_key]
     device = get_device()
+    # Seed the RNGs
+    _set_seed()
     # ---- 1. Check / generate synthetic data ----
     syn_path = Path(synthetic_dir)
     existing = list(syn_path.glob("*.png"))
@@ -139,11 +154,11 @@
     train_loader = DataLoader(
         train_ds, batch_size=cfg["batch_size"], shuffle=True,
-        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
+        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
     )
     val_loader = DataLoader(
         val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
-        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
+        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
     )
     print(f"[data] train: {train_size}  val: {val_size}")
@@ -166,12 +181,11 @@
         pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
         for images, targets, target_lengths, _ in pbar:
             images = images.to(device)
-            targets = targets.to(device)
-            target_lengths = target_lengths.to(device)
             logits = model(images)  # (T, B, C)
             T, B, C = logits.shape
-            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
+            # cuDNN CTC requires targets/lengths on CPU
+            input_lengths = torch.full((B,), T, dtype=torch.int32)
             log_probs = logits.log_softmax(2)
             loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)