Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models:

- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, and training utilities, plus full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
AGENTS.md (18 lines changed)

@@ -1,25 +1,33 @@
# Repository Guidelines

## Project Structure & Module Organization

-Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas, `models/` contains the classifier and expert OCR models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.
+Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas (5 types: normal, math, 3d_text, 3d_rotate, 3d_slider), `models/` contains the classifier, CTC expert models, and regression models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.

## Build, Test, and Development Commands

Use `uv` for environment and dependency management.

- `uv sync` installs the base runtime dependencies from `pyproject.toml`.
- `uv sync --extra server` installs HTTP service dependencies.
-- `uv run captcha generate --type normal --num 1000` generates synthetic training data.
-- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d -> classifier`.
+- `uv run captcha generate --type normal --num 1000` generates synthetic training data. Types: `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, `classifier`.
+- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d_text -> 3d_rotate -> 3d_slider -> classifier`.
- `uv run captcha export --all` exports all trained models to ONNX.
+- `uv run captcha export --model 3d_text` exports a single model; `3d_text` is automatically mapped to `threed_text`.
- `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory.
- `uv run captcha serve --port 8080` starts the optional HTTP API when `server.py` is implemented.

## Coding Style & Naming Conventions

-Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d`, and `classifier`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding unless the pipeline is intentionally redesigned. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.
+Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Checkpoint/ONNX file names use `threed_text`, `threed_rotate`, `threed_slider` (underscored, no hyphens). Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding for OCR models. Regression models (3d_rotate, 3d_slider) output sigmoid [0,1] scaled by `REGRESSION_RANGE`. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.

+## Training & Data Rules
+
+- All training scripts must set the global random seed (`random`, `numpy`, `torch`) via `config.RANDOM_SEED` before training begins.
+- All DataLoaders use `num_workers=0` for cross-platform consistency.
+- Generator parameters (rotation, noise, shadow, etc.) must come from `config.GENERATE_CONFIG`, not hardcoded values.
+- `CRNNDataset` emits a `warnings.warn` when a label contains characters outside the configured charset, rather than silently dropping them.
+- `RegressionDataset` parses numeric labels from filenames and normalizes to [0,1] via `label_range`.
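The `RegressionDataset` rule above can be sketched as follows; `parse_regression_label` is a hypothetical helper name, not the repo's actual implementation:

```python
from pathlib import Path

def parse_regression_label(filename: str, label_range: tuple[float, float]) -> float:
    """Parse the numeric label from '{label}_{index:06d}.png' and normalize it to [0, 1]."""
    stem = Path(filename).stem          # e.g. "105_000001"
    value = float(stem.split("_")[0])   # label part before the first underscore
    lo, hi = label_range
    return (value - lo) / (hi - lo)

# A 3d_slider sample with gap offset 105 and label_range (10, 200):
print(parse_regression_label("105_000001.png", (10, 200)))  # 0.5
```

The inverse scaling (multiply by `hi - lo`, add `lo`) recovers the pixel offset or angle from a model prediction.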
## Data & Testing Guidelines

-Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.
+Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. For regression types, the label is the numeric value (angle or offset). Sample targets are defined in `config.py`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.

## Commit & Pull Request Guidelines

Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.
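The filename convention and smoke-test guidance above can be sketched as a minimal pytest file; the test name and helper are hypothetical, not from the repo:

```python
# tests/test_filename_convention.py — hypothetical smoke test for the label-prefix rule
def parse_label(name: str) -> str:
    """Labels sit before the first underscore in both synthetic and real filenames."""
    return name.split("_")[0]

def test_label_prefix_convention():
    assert parse_label("135_000042.png") == "135"     # regression label (angle)
    assert parse_label("A3B8_000001.png") == "A3B8"   # OCR label
    assert parse_label("3+8_batch7.png") == "3+8"     # math label, arbitrary suffix
```

Run with `uv run pytest tests/test_filename_convention.py`.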
CLAUDE.md (151 lines changed)

@@ -24,29 +24,40 @@ captcha-breaker/
│   ├── synthetic/                  # Synthetic training data (auto-generated, not in git)
│   │   ├── normal/                 # Plain character type
│   │   ├── math/                   # Arithmetic type
-│   │   └── 3d/                    # 3D type
+│   │   ├── 3d_text/               # 3D text type
+│   │   ├── 3d_rotate/             # 3D rotation type
+│   │   └── 3d_slider/             # 3D slider type
│   ├── real/                       # Real captcha samples (manually labeled)
│   │   ├── normal/
│   │   ├── math/
-│   │   └── 3d/
+│   │   ├── 3d_text/
+│   │   ├── 3d_rotate/
+│   │   └── 3d_slider/
│   └── classifier/                 # Router-classifier training data (mix of all types)
├── generators/
│   ├── __init__.py
│   ├── base.py                     # Generator base class
│   ├── normal_gen.py               # Plain character captcha generator
│   ├── math_gen.py                 # Arithmetic captcha generator (e.g. 3+8=?)
-│   └── threed_gen.py              # 3D captcha generator
+│   ├── threed_gen.py              # 3D text captcha generator
+│   ├── threed_rotate_gen.py       # 3D rotation captcha generator
+│   └── threed_slider_gen.py       # 3D slider captcha generator
├── models/
│   ├── __init__.py
│   ├── lite_crnn.py                # Lightweight CRNN (plain characters and arithmetic)
│   ├── classifier.py               # Router classification model
-│   └── threed_cnn.py              # 3D-specific model (deeper CNN)
+│   ├── threed_cnn.py              # 3D-text-specific model (deeper CNN)
+│   └── regression_cnn.py          # Regression CNN (3D rotation + slider, ~1MB)
├── training/
│   ├── __init__.py
│   ├── train_classifier.py         # Train the router model
│   ├── train_normal.py             # Train plain character recognition
│   ├── train_math.py               # Train arithmetic recognition
-│   ├── train_3d.py                # Train 3D recognition
+│   ├── train_3d_text.py           # Train 3D text recognition
+│   ├── train_3d_rotate.py         # Train 3D rotation regression
+│   ├── train_3d_slider.py         # Train 3D slider regression
│   ├── train_utils.py              # Shared CTC training logic
+│   ├── train_regression_utils.py  # Shared regression training logic
│   └── dataset.py                  # Generic Dataset classes
├── inference/
│   ├── __init__.py
@@ -57,12 +68,16 @@ captcha-breaker/
│   ├── classifier.pth
│   ├── normal.pth
│   ├── math.pth
-│   └── threed.pth
+│   ├── threed_text.pth
+│   ├── threed_rotate.pth
+│   └── threed_slider.pth
├── onnx_models/                    # Exported ONNX models
│   ├── classifier.onnx
│   ├── normal.onnx
│   ├── math.onnx
-│   └── threed.onnx
+│   ├── threed_text.onnx
+│   ├── threed_rotate.onnx
+│   └── threed_slider.onnx
├── server.py                       # FastAPI inference service (optional)
├── cli.py                          # Command-line entrypoint
└── tests/
@@ -78,13 +93,13 @@ captcha-breaker/

```
Input image → preprocess → router classifier → route to expert model → post-process → output
-                    │
-           ┌────────┼────────┐
-           ▼        ▼        ▼
-        normal    math      3d
-        (CRNN)   (CRNN)    (CNN)
-           │        │        │
-           ▼        ▼        ▼
-       "A3B8"  "3+8=?"→11  "X9K2"
+                    │
+    ┌────────┬──────┼──────┬─────────┐
+    ▼        ▼      ▼      ▼         ▼
+ normal    math  3d_text 3d_rotate 3d_slider
+ (CRNN)   (CRNN)  (CNN)  (RegCNN)  (RegCNN)
+    │        │      │      │         │
+    ▼        ▼      ▼      ▼         ▼
+ "A3B8" "3+8=?"→11 "X9K2" "135"     "87"
```

### Router classifier (classifier.py)
@@ -102,7 +117,7 @@ class CaptchaClassifier(nn.Module):
    A lightweight classifier; a few conv layers suffice to tell captcha types apart.
    Different types differ visually a lot (presence of operators, 3D effects, etc.), so classification is easy.
    """
-    def __init__(self, num_types=3):
+    def __init__(self, num_types=5):
        # 4 conv layers + GAP + FC
        # Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64)
        # AdaptiveAvgPool2d(1) -> Linear(64, num_types)
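The comments above can be turned into a runnable PyTorch sketch; kernel sizes, activations, and pooling between blocks are assumptions beyond what the comments state:

```python
import torch
import torch.nn as nn

class CaptchaClassifier(nn.Module):
    """4 conv blocks + global average pooling + linear head, per the outline above."""

    def __init__(self, num_types: int = 5):
        super().__init__()
        chans = [1, 16, 32, 64, 64]
        blocks = []
        for cin, cout in zip(chans, chans[1:]):
            blocks += [nn.Conv2d(cin, cout, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.features = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool2d(1)   # GAP -> (B, 64, 1, 1)
        self.fc = nn.Linear(64, num_types)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(self.features(x)).flatten(1)
        return self.fc(x)

# Classifier input size from IMAGE_SIZE: grayscale (64, 128)
logits = CaptchaClassifier()(torch.randn(2, 1, 64, 128))
print(logits.shape)  # torch.Size([2, 5])
```

CPU-safe, float32 by default, and small enough to stay under the < 500KB size target after export.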
@@ -142,14 +157,32 @@ def eval_captcha_math(expr: str) -> str:
    pass
```
-### 3D recognition expert (threed_cnn.py)
+### 3D text recognition expert (threed_cnn.py)

-- Task: recognize captchas with 3D perspective/shadow effects
+- Task: recognize text captchas with 3D perspective/shadow effects
- Architecture: deeper CNN + CRNN, or a ResNet-lite backbone
- Input: grayscale image 1x60x160
- Needs stronger feature extraction to handle perspective distortion and shadows
- Model size target: < 5MB

+### 3D rotation expert (regression_cnn.py, 3d_rotate mode)
+
+- Task: predict the correct rotation angle of a rotation captcha
+- Architecture: lightweight regression CNN (4 conv layers + GAP + FC + Sigmoid)
+- Input: grayscale image 1x80x80
+- Output: a sigmoid value in [0,1], scaled back to an angle by (0, 360)
+- Label range: 0-359°
+- Model size target: ~1MB
+
+### 3D slider expert (regression_cnn.py, 3d_slider mode)
+
+- Task: predict the x offset of the slider puzzle gap
+- Architecture: the same regression CNN, with a different input size
+- Input: grayscale image 1x80x240
+- Output: a sigmoid value in [0,1], scaled back to a pixel offset by (10, 200)
+- Label range: 10-200px
+- Model size target: ~1MB
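A minimal PyTorch sketch of the regression expert described above; channel counts and kernel sizes are assumptions beyond the "4 conv + GAP + FC + Sigmoid" outline:

```python
import torch
import torch.nn as nn

class RegressionCNN(nn.Module):
    """Lightweight regression head: 4 conv blocks + GAP + FC + Sigmoid -> value in [0, 1]."""

    def __init__(self):
        super().__init__()
        chans = [1, 16, 32, 64, 64]
        layers = []
        for cin, cout in zip(chans, chans[1:]):
            layers += [nn.Conv2d(cin, cout, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        self.features = nn.Sequential(*layers)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 1), nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x)).squeeze(1)  # shape (B,), values in [0, 1]

def denormalize(pred: torch.Tensor, label_range: tuple[float, float]) -> torch.Tensor:
    """Scale a sigmoid output back to the label's physical range (degrees or pixels)."""
    lo, hi = label_range
    return pred * (hi - lo) + lo

out = RegressionCNN()(torch.randn(2, 1, 80, 80))  # 3d_rotate input size
angles = denormalize(out, (0, 360))               # back to degrees
```

Because only the input size differs, the same module serves both the 3d_rotate (1x80x80) and 3d_slider (1x80x240) modes, thanks to the adaptive pooling.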
## Data generator spec

### Base class (base.py)

@@ -185,13 +218,26 @@ class BaseCaptchaGenerator:
- Label format: `3+8` (store the expression itself, not the result)
- Visual style: consistent with the target arithmetic captchas

-### 3D generator (threed_gen.py)
+### 3D text generator (threed_gen.py)

- Uses Pillow affine transforms to simulate 3D perspective
- Adds shadow effects
- Characters have a sense of depth and tilt
+- Character rotation angles are configured centrally via `config.py` `GENERATE_CONFIG["3d_text"]["rotation_range"]`
- Label: the plain character content

+### 3D rotation generator (threed_rotate_gen.py)
+
+- Draws a character + direction marker on a disc, randomly rotated 0-359°
+- Label = rotation angle (integer string)
+- Filename format: `{angle}_{index:06d}.png`
+
+### 3D slider generator (threed_slider_gen.py)
+
+- Textured background + puzzle gap + puzzle piece on the left
+- Label = gap x-coordinate offset (integer string)
+- Filename format: `{offset}_{index:06d}.png`
## Training spec

### Common training config

@@ -204,7 +250,7 @@ TRAIN_CONFIG = {
    'batch_size': 128,
    'lr': 1e-3,
    'scheduler': 'cosine',
-    'synthetic_samples': 30000,  # 10000 per class
+    'synthetic_samples': 50000,  # 10000 × 5 classes
},
'normal': {
    'epochs': 50,
@@ -222,7 +268,7 @@ TRAIN_CONFIG = {
    'synthetic_samples': 60000,
    'loss': 'CTCLoss',
},
-'threed': {
+'3d_text': {
    'epochs': 80,
    'batch_size': 64,
    'lr': 5e-4,
@@ -230,18 +276,36 @@ TRAIN_CONFIG = {
    'synthetic_samples': 80000,
    'loss': 'CTCLoss',
},
+'3d_rotate': {
+    'epochs': 60,
+    'batch_size': 128,
+    'lr': 1e-3,
+    'scheduler': 'cosine',
+    'synthetic_samples': 60000,
+    'loss': 'SmoothL1',
+},
+'3d_slider': {
+    'epochs': 60,
+    'batch_size': 128,
+    'lr': 1e-3,
+    'scheduler': 'cosine',
+    'synthetic_samples': 60000,
+    'loss': 'SmoothL1',
+},
}
```
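The commit message describes the 3d_rotate objective as a circular loss while the config lists SmoothL1. One common way to reconcile the two (a sketch under that assumption, not the repo's actual code) is SmoothL1 applied to the wrap-around angular difference:

```python
import torch
import torch.nn.functional as F

def circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """SmoothL1 on the wrap-around distance between normalized angles in [0, 1].

    A prediction of 0.99 vs a target of 0.01 is only 0.02 apart on the circle,
    not 0.98, so the raw difference is wrapped into [-0.5, 0.5] first.
    """
    diff = pred - target
    diff = torch.remainder(diff + 0.5, 1.0) - 0.5  # wrap into [-0.5, 0.5]
    return F.smooth_l1_loss(diff, torch.zeros_like(diff))

near = circular_smooth_l1(torch.tensor([0.99]), torch.tensor([0.01]))  # tiny: crosses 0°/360°
far = circular_smooth_l1(torch.tensor([0.0]), torch.tensor([0.5]))     # maximal angular error
```

The 3d_slider target is an offset, not an angle, so plain SmoothL1 on the normalized value suffices there.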
### Training script requirements

Every training script must:
-1. Check whether synthetic data has been generated; if not, call the generator automatically
-2. Support mixing in real data (if data/real/{type}/ has files)
-3. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
-4. Log training: epoch, loss, overall accuracy, character-level accuracy
-5. Save the best model to checkpoints/
-6. Automatically export ONNX to onnx_models/ when training finishes
+1. Set the global random seed (random/numpy/torch) via `config.RANDOM_SEED` before training starts, for reproducibility
+2. Check whether synthetic data has been generated; if not, call the generator automatically
+3. Support mixing in real data (if data/real/{type}/ has files)
+4. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
+5. Log training: epoch, loss, overall accuracy, plus character-level accuracy (CTC) or MAE and tolerance accuracy (regression)
+6. Save the best model to checkpoints/
+7. Automatically export ONNX to onnx_models/ when training finishes
+8. Use `num_workers=0` for all DataLoaders to avoid multiprocessing compatibility issues
### Data augmentation strategy
@@ -281,14 +345,14 @@ class CaptchaPipeline:
        pass

    def classify(self, image: Image.Image) -> str:
-        """Router classification; returns the type name: 'normal' / 'math' / '3d'"""
+        """Router classification; returns the type name: 'normal' / 'math' / '3d_text' / '3d_rotate' / '3d_slider'"""
        pass

    def solve(self, image) -> str:
        """
        Full recognition flow:
        1. Classify the captcha type
-        2. Route to the matching expert model
+        2. Route to the matching expert model (CTC or regression)
        3. Post-process (arithmetic captchas are evaluated)
        4. Return the final answer string
@@ -311,7 +375,7 @@ def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'):
    pass


def export_all():
-    """Export the classifier, normal, math, and threed models in order (four models)"""
+    """Export the classifier, normal, math, threed_text, threed_rotate, and threed_slider models in order (six models)"""
    pass
```
@@ -325,22 +389,28 @@ uv sync --extra server  # with HTTP service dependencies

# Generate training data
uv run python cli.py generate --type normal --num 60000
uv run python cli.py generate --type math --num 60000
-uv run python cli.py generate --type 3d --num 80000
-uv run python cli.py generate --type classifier --num 30000
+uv run python cli.py generate --type 3d_text --num 80000
+uv run python cli.py generate --type 3d_rotate --num 60000
+uv run python cli.py generate --type 3d_slider --num 60000
+uv run python cli.py generate --type classifier --num 50000

# Train models
uv run python cli.py train --model classifier
uv run python cli.py train --model normal
uv run python cli.py train --model math
-uv run python cli.py train --model 3d
+uv run python cli.py train --model 3d_text
+uv run python cli.py train --model 3d_rotate
+uv run python cli.py train --model 3d_slider
uv run python cli.py train --all  # train everything in dependency order

# Export ONNX
uv run python cli.py export --all
+uv run python cli.py export --model 3d_text  # "3d_text" is automatically mapped to "threed_text"

# Inference
uv run python cli.py predict image.png                   # auto classify + recognize
uv run python cli.py predict image.png --type normal     # skip classification
+uv run python cli.py predict image.png --type 3d_rotate  # force the rotation type
uv run python cli.py predict-dir ./test_images/          # batch recognition

# Start the HTTP service (install the server extra first)
@@ -360,11 +430,12 @@ uv run python cli.py serve --port 8080

1. **Train all models in float32 and export ONNX without quantization**; get accuracy right first
2. **CTC decoding is always greedy**; beam search is unnecessary, greedy is enough for captchas
-3. **The charset is defined centrally in config.py**: normal currently keeps confusable characters, while 3d keeps using the de-confused charset
+3. **The charset is defined centrally in config.py**: normal currently keeps confusable characters, while 3d_text keeps using the de-confused charset
4. **Arithmetic recognition is two steps**: OCR the string first, then evaluate it with rules; never make the model output the numeric value directly
-5. **Generator random seed**: set a seed when generating data, for reproducibility
+5. **Random seed**: both data generation and training set the global seed (random/numpy/torch) via `config.RANDOM_SEED`, for reproducibility
6. **Real-data filename format**: `{label}_{anything}.png`; the label part is the annotation
-7. **Model save format**: PyTorch checkpoint containing model_state_dict, chars, best_acc, epoch
11. **Dataset character filtering**: `CRNNDataset` warns when a label contains a character outside the charset, to help debug annotation/charset mismatches
+7. **Model save format**: CTC checkpoints contain model_state_dict, chars, best_acc, epoch; regression checkpoints contain model_state_dict, label_range, best_mae, best_tol_acc, epoch
8. **No GPU-only features**; training and inference must also work on CPU (just slower)
9. **Adding a type**: a new captcha type only needs (1) a generator, (2) an expert model, (3) one more router class and a retrain
10. **Keep docs in sync**: changes to project structure, config, or architecture must be mirrored in CLAUDE.md
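Rule 2's greedy CTC decoding can be sketched in plain Python; this is a minimal illustration over per-timestep score lists, while the real decoder operates on model logits:

```python
def ctc_greedy_decode(logits, chars: str, blank: int = 0) -> str:
    """Greedy CTC decode: argmax per timestep, collapse repeats, drop blanks.

    `logits` is a (T, C) sequence of per-timestep class scores; index 0 is
    the CTC blank and index i >= 1 maps to chars[i - 1].
    """
    best = [max(range(len(step)), key=step.__getitem__) for step in logits]
    out, prev = [], blank
    for idx in best:
        if idx != blank and idx != prev:
            out.append(chars[idx - 1])
        prev = idx
    return "".join(out)

# Argmax path: blank, A, A, blank, A, B — the repeated A collapses,
# but the A after the blank is kept, so the result is "AAB".
decoded = ctc_greedy_decode(
    [[9, 0, 1], [0, 9, 1], [0, 9, 1], [9, 0, 1], [0, 9, 1], [0, 1, 9]],
    chars="AB",
)
```

The blank symbol is what lets CTC emit the same character twice in a row, which is why dropping it must happen after collapsing repeats, not before.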
@@ -373,18 +444,20 @@ uv run python cli.py serve --port 8080

| Model | Accuracy target | Inference latency | Model size |
|------|-----------|---------|---------|
-| Router classifier | > 99% | < 5ms | < 500KB |
+| Router classifier (5 classes) | > 99% | < 5ms | < 500KB |
| Plain characters | > 95% | < 30ms | < 2MB |
| Arithmetic | > 93% | < 30ms | < 2MB |
-| 3D | > 85% | < 50ms | < 5MB |
-| Full pipeline | - | < 80ms | < 10MB total |
+| 3D text | > 85% | < 50ms | < 5MB |
+| 3D rotation (±5°) | > 85% | < 30ms | ~1MB |
+| 3D slider (±3px) | > 90% | < 30ms | ~1MB |
+| Full pipeline | - | < 80ms | < 12MB total |
## Development order

1. Implement config.py and generators/ first
2. Implement all model definitions in models/
3. Implement the generic dataset classes in training/dataset.py
-4. Train in order: normal → math → 3d → classifier
+4. Train in order: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier
5. Implement inference/pipeline.py and export_onnx.py
6. Implement the unified cli.py entrypoint
7. Optional: server.py HTTP service
cli.py (83 lines changed)

@@ -3,6 +3,9 @@ CaptchaBreaker command-line entrypoint
Usage:
    python cli.py generate --type normal --num 60000
+   python cli.py generate --type 3d_text --num 80000
+   python cli.py generate --type 3d_rotate --num 60000
+   python cli.py generate --type 3d_slider --num 60000
    python cli.py train --model normal
    python cli.py train --all
    python cli.py export --all
@@ -20,15 +23,21 @@ from pathlib import Path

def cmd_generate(args):
    """Generate training data."""
    from config import (
-        SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
+        SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
+        SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
        CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
    )
-    from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator
+    from generators import (
+        NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator,
+        ThreeDRotateGenerator, ThreeDSliderGenerator,
+    )

    gen_map = {
        "normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
        "math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
-        "3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR),
+        "3d_text": (ThreeDCaptchaGenerator, SYNTHETIC_3D_TEXT_DIR),
+        "3d_rotate": (ThreeDRotateGenerator, SYNTHETIC_3D_ROTATE_DIR),
+        "3d_slider": (ThreeDSliderGenerator, SYNTHETIC_3D_SLIDER_DIR),
    }

    captcha_type = args.type
@@ -50,25 +59,31 @@ def cmd_generate(args):
        gen = gen_cls()
        gen.generate_dataset(num, str(out_dir))
    else:
-        print(f"Unknown type: {captcha_type}; options: normal, math, 3d, classifier")
+        valid = ", ".join(list(gen_map.keys()) + ["classifier"])
+        print(f"Unknown type: {captcha_type}; options: {valid}")
        sys.exit(1)
def cmd_train(args):
    """Train models."""
    if args.all:
-        # Dependency order: normal → math → 3d → classifier
-        print("Training all models in order: normal → math → 3d → classifier\n")
+        print("Training all models in order: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier\n")
        from training.train_normal import main as train_normal
        from training.train_math import main as train_math
-        from training.train_3d import main as train_3d
+        from training.train_3d_text import main as train_3d_text
+        from training.train_3d_rotate import main as train_3d_rotate
+        from training.train_3d_slider import main as train_3d_slider
        from training.train_classifier import main as train_classifier

        train_normal()
        print("\n")
        train_math()
        print("\n")
-        train_3d()
+        train_3d_text()
        print("\n")
+        train_3d_rotate()
+        print("\n")
+        train_3d_slider()
+        print("\n")
        train_classifier()
        return
@@ -78,12 +93,16 @@ def cmd_train(args):
        from training.train_normal import main as train_fn
    elif model == "math":
        from training.train_math import main as train_fn
-    elif model == "3d":
-        from training.train_3d import main as train_fn
+    elif model == "3d_text":
+        from training.train_3d_text import main as train_fn
+    elif model == "3d_rotate":
+        from training.train_3d_rotate import main as train_fn
+    elif model == "3d_slider":
+        from training.train_3d_slider import main as train_fn
    elif model == "classifier":
        from training.train_classifier import main as train_fn
    else:
-        print(f"Unknown model: {model}; options: normal, math, 3d, classifier")
+        print(f"Unknown model: {model}; options: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier")
        sys.exit(1)

    train_fn()
@@ -96,7 +115,14 @@ def cmd_export(args):
    if args.all:
        export_all()
    elif args.model:
-        _load_and_export(args.model)
+        # Alias mapping
+        alias = {
+            "3d_text": "threed_text",
+            "3d_rotate": "threed_rotate",
+            "3d_slider": "threed_slider",
+        }
+        name = alias.get(args.model, args.model)
+        _load_and_export(name)
    else:
        print("Specify --all or --model <name>")
        sys.exit(1)
@@ -137,19 +163,19 @@ def cmd_predict_dir(args):
        sys.exit(1)

    print(f"Batch recognition: {len(images)} images\n")
-    print(f"{'File':<30} {'Type':<8} {'Result':<15} {'Time(ms)':>8}")
-    print("-" * 65)
+    print(f"{'File':<30} {'Type':<10} {'Result':<15} {'Time(ms)':>8}")
+    print("-" * 67)

    total_ms = 0.0
    for img_path in images:
        result = pipeline.solve(str(img_path), captcha_type=args.type)
        total_ms += result["time_ms"]
        print(
-            f"{img_path.name:<30} {result['type']:<8} "
+            f"{img_path.name:<30} {result['type']:<10} "
            f"{result['result']:<15} {result['time_ms']:>8.1f}"
        )

-    print("-" * 65)
+    print("-" * 67)
    print(f"Total: {len(images)} images  avg: {total_ms / len(images):.1f} ms  total: {total_ms:.1f} ms")
@@ -178,28 +204,43 @@ def main():

    # ---- generate ----
    p_gen = subparsers.add_parser("generate", help="Generate training data")
-    p_gen.add_argument("--type", required=True, help="Captcha type: normal, math, 3d, classifier")
+    p_gen.add_argument(
+        "--type", required=True,
+        help="Captcha type: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
+    )
    p_gen.add_argument("--num", type=int, required=True, help="Number of samples to generate")

    # ---- train ----
    p_train = subparsers.add_parser("train", help="Train models")
-    p_train.add_argument("--model", help="Model name: normal, math, 3d, classifier")
+    p_train.add_argument(
+        "--model",
+        help="Model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
+    )
    p_train.add_argument("--all", action="store_true", help="Train all models in dependency order")

    # ---- export ----
    p_export = subparsers.add_parser("export", help="Export ONNX models")
-    p_export.add_argument("--model", help="Model name: normal, math, 3d, classifier, threed")
+    p_export.add_argument(
+        "--model",
+        help="Model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
+    )
    p_export.add_argument("--all", action="store_true", help="Export all models")

    # ---- predict ----
    p_pred = subparsers.add_parser("predict", help="Recognize a single captcha")
    p_pred.add_argument("image", help="Image path")
-    p_pred.add_argument("--type", default=None, help="Skip classification with a fixed type: normal, math, 3d")
+    p_pred.add_argument(
+        "--type", default=None,
+        help="Skip classification with a fixed type: normal, math, 3d_text, 3d_rotate, 3d_slider",
+    )

    # ---- predict-dir ----
    p_pdir = subparsers.add_parser("predict-dir", help="Batch-recognize captchas in a directory")
    p_pdir.add_argument("directory", help="Image directory path")
-    p_pdir.add_argument("--type", default=None, help="Skip classification with a fixed type: normal, math, 3d")
+    p_pdir.add_argument(
+        "--type", default=None,
+        help="Skip classification with a fixed type: normal, math, 3d_text, 3d_rotate, 3d_slider",
+    )

    # ---- serve ----
    p_serve = subparsers.add_parser("serve", help="Start the HTTP recognition service")
config.py (66 lines changed)

@@ -23,12 +23,16 @@ CLASSIFIER_DIR = DATA_DIR / "classifier"
# Synthetic data subdirectories
SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal"
SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math"
-SYNTHETIC_3D_DIR = SYNTHETIC_DIR / "3d"
+SYNTHETIC_3D_TEXT_DIR = SYNTHETIC_DIR / "3d_text"
+SYNTHETIC_3D_ROTATE_DIR = SYNTHETIC_DIR / "3d_rotate"
+SYNTHETIC_3D_SLIDER_DIR = SYNTHETIC_DIR / "3d_slider"

# Real data subdirectories
REAL_NORMAL_DIR = REAL_DIR / "normal"
REAL_MATH_DIR = REAL_DIR / "math"
-REAL_3D_DIR = REAL_DIR / "3d"
+REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
+REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
+REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"

# ============================================================
# Model output directories
@@ -38,8 +42,10 @@ ONNX_DIR = PROJECT_ROOT / "onnx_models"

# Make sure key directories exist
for _dir in [
-    SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
-    REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_DIR,
+    SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
+    SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
+    REAL_NORMAL_DIR, REAL_MATH_DIR,
+    REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
    CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
]:
    _dir.mkdir(parents=True, exist_ok=True)
@@ -57,7 +63,7 @@ MATH_CHARS = "0123456789+-×÷=?"
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ"

# Captcha type list (router classifier output)
-CAPTCHA_TYPES = ["normal", "math", "3d"]
+CAPTCHA_TYPES = ["normal", "math", "3d_text", "3d_rotate", "3d_slider"]
NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES)

# ============================================================
@@ -67,7 +73,9 @@ IMAGE_SIZE = {
    "classifier": (64, 128),  # Router classifier input
    "normal": (40, 120),      # Plain character recognition
    "math": (40, 160),        # Arithmetic recognition (wider)
-    "3d": (60, 160),          # 3D recognition
+    "3d_text": (60, 160),     # 3D text recognition
+    "3d_rotate": (80, 80),    # 3D rotation angle regression (square)
+    "3d_slider": (80, 240),   # 3D slider offset regression
}
# ============================================================
@@ -91,11 +99,25 @@ GENERATE_CONFIG = {
    "rotation_range": (-10, 10),
    "noise_line_range": (2, 4),
},
-"3d": {
+"3d_text": {
    "char_count_range": (4, 5),
    "image_size": (160, 60),         # Generated image size (W, H)
    "shadow_offset": (3, 3),         # Shadow offset
    "perspective_intensity": 0.3,    # Perspective transform strength
+    "rotation_range": (-20, 20),     # Character rotation angles
},
+"3d_rotate": {
+    "image_size": (80, 80),          # Generated image size (W, H)
+    "disc_radius": 35,               # Disc radius
+    "marker_size": 8,                # Direction marker size
+    "bg_color_range": (200, 240),    # Background color range
+},
+"3d_slider": {
+    "image_size": (240, 80),         # Generated image size (W, H)
+    "puzzle_size": (40, 40),         # Puzzle piece size (W, H)
+    "gap_x_range": (50, 200),        # Gap x-coordinate range
+    "piece_left_margin": 5,          # Left margin for the puzzle piece
+    "bg_noise_intensity": 30,        # Background texture noise strength
+},
}
@@ -108,7 +130,7 @@ TRAIN_CONFIG = {
    "batch_size": 128,
    "lr": 1e-3,
    "scheduler": "cosine",
-    "synthetic_samples": 30000,  # 10000 per class
+    "synthetic_samples": 50000,  # 10000 × 5 classes
    "val_split": 0.1,            # Validation split
},
"normal": {
@@ -129,7 +151,7 @@ TRAIN_CONFIG = {
    "loss": "CTCLoss",
    "val_split": 0.1,
},
-"threed": {
+"3d_text": {
    "epochs": 80,
    "batch_size": 64,
    "lr": 5e-4,
@@ -138,6 +160,24 @@ TRAIN_CONFIG = {
    "loss": "CTCLoss",
    "val_split": 0.1,
},
+"3d_rotate": {
+    "epochs": 60,
+    "batch_size": 128,
+    "lr": 1e-3,
+    "scheduler": "cosine",
+    "synthetic_samples": 60000,
+    "loss": "SmoothL1",
+    "val_split": 0.1,
+},
+"3d_slider": {
+    "epochs": 60,
+    "batch_size": 128,
+    "lr": 1e-3,
+    "scheduler": "cosine",
+    "synthetic_samples": 60000,
+    "loss": "SmoothL1",
+    "val_split": 0.1,
+},
}
# ============================================================
@@ -163,6 +203,14 @@ ONNX_CONFIG = {
    "dynamic_batch": True,  # Support dynamic batch size
}

+# ============================================================
+# Regression model label ranges
+# ============================================================
+REGRESSION_RANGE = {
+    "3d_rotate": (0, 360),   # Rotation angle 0-359°
+    "3d_slider": (10, 200),  # Slider x offset (pixels)
+}

# ============================================================
# Inference config
# ============================================================
@@ -1,20 +1,26 @@
|
||||
"""
|
||||
数据生成器包
|
||||
|
||||
提供三种验证码类型的数据生成器:
|
||||
提供五种验证码类型的数据生成器:
|
||||
- NormalCaptchaGenerator: 普通字符验证码
|
||||
- MathCaptchaGenerator: 算式验证码
|
||||
- ThreeDCaptchaGenerator: 3D 立体验证码
|
||||
- ThreeDCaptchaGenerator: 3D 立体文字验证码
|
||||
- ThreeDRotateGenerator: 3D 旋转验证码
|
||||
- ThreeDSliderGenerator: 3D 滑块验证码
|
||||
"""
|
||||
|
||||
from generators.base import BaseCaptchaGenerator
|
||||
from generators.normal_gen import NormalCaptchaGenerator
|
||||
from generators.math_gen import MathCaptchaGenerator
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
|
||||
__all__ = [
|
||||
"BaseCaptchaGenerator",
|
||||
"NormalCaptchaGenerator",
|
||||
"MathCaptchaGenerator",
|
||||
"ThreeDCaptchaGenerator",
|
||||
"ThreeDRotateGenerator",
|
||||
"ThreeDSliderGenerator",
|
||||
]
|
||||
|
||||
generators/threed_gen.py

@@ -45,7 +45,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

-        self.cfg = GENERATE_CONFIG["3d"]
+        self.cfg = GENERATE_CONFIG["3d_text"]
        self.chars = THREED_CHARS
        self.width, self.height = self.cfg["image_size"]

@@ -154,7 +154,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
        char_img = self._perspective_transform(char_img, rng)

        # Random rotation
-        angle = rng.randint(-20, 20)
+        angle = rng.randint(*self.cfg["rotation_range"])
        char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)

        # Paste onto the canvas
generators/threed_rotate_gen.py (new file, 122 lines)

@@ -0,0 +1,122 @@
"""
|
||||
3D 旋转验证码生成器
|
||||
|
||||
生成旋转验证码:圆盘上绘制字符 + 方向标记,随机旋转 0-359°。
|
||||
用户需将圆盘旋转到正确角度。
|
||||
|
||||
标签 = 旋转角度(整数)
|
||||
文件名格式: {angle}_{index:06d}.png
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFilter, ImageFont
|
||||
|
||||
from config import GENERATE_CONFIG, THREED_CHARS
|
||||
from generators.base import BaseCaptchaGenerator
|
||||
|
||||
_FONT_PATHS = [
|
||||
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
|
||||
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
|
||||
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
|
||||
]
|
||||
|
||||
_DISC_COLORS = [
|
||||
(180, 200, 220),
|
||||
(200, 220, 200),
|
||||
(220, 200, 190),
|
||||
(200, 200, 220),
|
||||
(210, 210, 200),
|
||||
]
|
||||
|
||||
|
||||
class ThreeDRotateGenerator(BaseCaptchaGenerator):
|
||||
"""3D 旋转验证码生成器。"""
|
||||
|
||||
def __init__(self, seed: int | None = None):
|
||||
from config import RANDOM_SEED
|
||||
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
|
||||
|
||||
self.cfg = GENERATE_CONFIG["3d_rotate"]
|
||||
self.chars = THREED_CHARS
|
||||
self.width, self.height = self.cfg["image_size"]
|
||||
|
||||
self._fonts: list[str] = []
|
||||
for p in _FONT_PATHS:
|
||||
try:
|
||||
ImageFont.truetype(p, 20)
|
||||
self._fonts.append(p)
|
||||
except OSError:
|
||||
continue
|
||||
if not self._fonts:
|
||||
raise RuntimeError("未找到任何可用字体,无法生成验证码")
|
||||
|
||||
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
|
||||
rng = self.rng
|
||||
|
||||
# 随机旋转角度 0-359
|
||||
angle = rng.randint(0, 359)
|
||||
|
||||
if text is None:
|
||||
text = str(angle)
|
||||
|
||||
# 1. 背景
|
||||
bg_val = rng.randint(*self.cfg["bg_color_range"])
|
||||
img = Image.new("RGB", (self.width, self.height), (bg_val, bg_val, bg_val))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# 2. 绘制圆盘
|
||||
cx, cy = self.width // 2, self.height // 2
|
||||
r = self.cfg["disc_radius"]
|
||||
disc_color = rng.choice(_DISC_COLORS)
|
||||
draw.ellipse(
|
||||
[cx - r, cy - r, cx + r, cy + r],
|
||||
fill=disc_color, outline=(100, 100, 100), width=2,
|
||||
)
|
||||
|
||||
# 3. 在圆盘上绘制字符和方向标记 (未旋转状态)
|
||||
disc_img = Image.new("RGBA", (r * 2 + 4, r * 2 + 4), (0, 0, 0, 0))
|
||||
disc_draw = ImageDraw.Draw(disc_img)
|
||||
dc = r + 2 # disc center
|
||||
|
||||
# 字符 (圆盘中心)
|
||||
font_path = rng.choice(self._fonts)
|
||||
font_size = int(r * 0.6)
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
ch = rng.choice(self.chars)
|
||||
bbox = font.getbbox(ch)
|
||||
tw = bbox[2] - bbox[0]
|
||||
th = bbox[3] - bbox[1]
|
||||
disc_draw.text(
|
||||
(dc - tw // 2 - bbox[0], dc - th // 2 - bbox[1]),
|
||||
ch, fill=(50, 50, 50, 255), font=font,
|
||||
)
|
||||
|
||||
# 方向标记 (三角箭头,指向上方)
|
||||
ms = self.cfg["marker_size"]
|
||||
marker_y = dc - r + ms + 2
|
||||
disc_draw.polygon(
|
||||
[(dc, marker_y - ms), (dc - ms // 2, marker_y), (dc + ms // 2, marker_y)],
|
||||
fill=(220, 60, 60, 255),
|
||||
)
|
||||
|
||||
# 4. 旋转圆盘内容
|
||||
disc_img = disc_img.rotate(-angle, resample=Image.BICUBIC, expand=False)
|
||||
|
||||
# 5. 粘贴到背景
|
||||
paste_x = cx - dc
|
||||
paste_y = cy - dc
|
||||
img.paste(disc_img, (paste_x, paste_y), disc_img)
|
||||
|
||||
# 6. 添加少量噪点
|
||||
for _ in range(rng.randint(20, 50)):
|
||||
nx, ny = rng.randint(0, self.width - 1), rng.randint(0, self.height - 1)
|
||||
nc = tuple(rng.randint(100, 200) for _ in range(3))
|
||||
draw.point((nx, ny), fill=nc)
|
||||
|
||||
# 7. 轻微模糊
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=0.5))
|
||||
|
||||
return img, text
|
||||
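Both new generators encode the label directly in the filename ({angle}_{index:06d}.png). A small sketch of that round trip; `encode_name` and `decode_label` are illustrative helpers, not functions from the repository:

```python
from pathlib import Path


def encode_name(angle: int, index: int) -> str:
    """Build a filename following the {angle}_{index:06d}.png convention."""
    return f"{angle}_{index:06d}.png"


def decode_label(filename: str) -> int:
    """Recover the label: everything before the last underscore in the stem."""
    return int(Path(filename).stem.rsplit("_", 1)[0])


print(decode_label(encode_name(135, 42)))  # prints 135
```

Splitting on the last underscore keeps the parse stable even if a label ever contains an underscore itself.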
113
generators/threed_slider_gen.py
Normal file
@@ -0,0 +1,113 @@
"""
3D slider captcha generator

Generates a slider puzzle captcha: textured background + puzzle gap + puzzle
piece on the left. The user must slide the piece into the gap.

Label = gap x-coordinate offset (integer)
Filename format: {offset}_{index:06d}.png
"""

import random

from PIL import Image, ImageDraw, ImageFilter

from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator


class ThreeDSliderGenerator(BaseCaptchaGenerator):
    """3D slider captcha generator."""

    def __init__(self, seed: int | None = None):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

        self.cfg = GENERATE_CONFIG["3d_slider"]
        self.width, self.height = self.cfg["image_size"]

    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        rng = self.rng
        pw, ph = self.cfg["puzzle_size"]
        gap_x_lo, gap_x_hi = self.cfg["gap_x_range"]

        # Gap position
        gap_x = rng.randint(gap_x_lo, gap_x_hi)
        gap_y = rng.randint(10, self.height - ph - 10)

        if text is None:
            text = str(gap_x)

        # 1. Generate the textured background
        img = self._textured_background(rng)

        # 2. Crop the puzzle-piece content from the gap position
        piece_content = img.crop((gap_x, gap_y, gap_x + pw, gap_y + ph)).copy()

        # 3. Draw the gap (semi-transparent gray area)
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle(
            [gap_x, gap_y, gap_x + pw, gap_y + ph],
            fill=(80, 80, 80, 160),
            outline=(60, 60, 60, 200),
            width=2,
        )
        img = img.convert("RGBA")
        img = Image.alpha_composite(img, overlay)
        img = img.convert("RGB")

        # 4. Draw the puzzle piece on the left
        piece_x = self.cfg["piece_left_margin"]
        piece_img = Image.new("RGBA", (pw + 4, ph + 4), (0, 0, 0, 0))
        piece_draw = ImageDraw.Draw(piece_img)
        # Shadow
        piece_draw.rectangle([2, 2, pw + 3, ph + 3], fill=(0, 0, 0, 80))
        # Content
        piece_img.paste(piece_content, (0, 0))
        # Border
        piece_draw.rectangle([0, 0, pw - 1, ph - 1], outline=(255, 255, 255, 200), width=2)

        img_rgba = img.convert("RGBA")
        img_rgba.paste(piece_img, (piece_x, gap_y), piece_img)
        img = img_rgba.convert("RGB")

        # 5. Slight blur
        img = img.filter(ImageFilter.GaussianBlur(radius=0.3))

        return img, text

    def _textured_background(self, rng: random.Random) -> Image.Image:
        """Generate a textured, colored background."""
        img = Image.new("RGB", (self.width, self.height))
        draw = ImageDraw.Draw(img)

        # Gradient base color
        base_r, base_g, base_b = rng.randint(100, 180), rng.randint(100, 180), rng.randint(100, 180)
        for y in range(self.height):
            ratio = y / max(self.height - 1, 1)
            r = int(base_r + 30 * ratio)
            g = int(base_g - 20 * ratio)
            b = int(base_b + 10 * ratio)
            draw.line([(0, y), (self.width, y)], fill=(r, g, b))

        # Texture noise
        noise_intensity = self.cfg["bg_noise_intensity"]
        for _ in range(self.width * self.height // 8):
            x = rng.randint(0, self.width - 1)
            y = rng.randint(0, self.height - 1)
            pixel = img.getpixel((x, y))
            noise = tuple(
                max(0, min(255, c + rng.randint(-noise_intensity, noise_intensity)))
                for c in pixel
            )
            draw.point((x, y), fill=noise)

        # Random color blocks (simulated patterns)
        for _ in range(rng.randint(3, 6)):
            x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
            x2, y2 = x1 + rng.randint(15, 40), y1 + rng.randint(10, 25)
            color = tuple(rng.randint(60, 220) for _ in range(3))
            draw.rectangle([x1, y1, x2, y2], fill=color)

        return img
@@ -17,10 +17,12 @@ from config import (
    MATH_CHARS,
    THREED_CHARS,
    NUM_CAPTCHA_TYPES,
    REGRESSION_RANGE,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN


def export_model(
@@ -34,7 +36,7 @@ def export_model(

    Args:
        model: PyTorch model with weights already loaded
        model_name: model name (classifier / normal / math / threed)
        model_name: model name (classifier / normal / math / threed_text / threed_rotate / threed_slider)
        input_shape: input shape (C, H, W)
        onnx_dir: output directory (defaults to config.ONNX_DIR)
    """
@@ -50,7 +52,7 @@
    dummy = torch.randn(1, *input_shape)

    # The classifier and the recognizers need different dynamic_axes
    if model_name == "classifier":
    if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
        dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    else:
        # CTC models: output shape = (T, B, C)
@@ -78,7 +80,8 @@ def _load_and_export(model_name: str):
        return

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    print(f"[Load] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")
    acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
    print(f"[Load] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")

    if model_name == "classifier":
        model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
@@ -94,11 +97,19 @@
        h, w = IMAGE_SIZE["math"]
        model = LiteCRNN(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
    elif model_name == "threed":
    elif model_name == "threed_text":
        chars = ckpt.get("chars", THREED_CHARS)
        h, w = IMAGE_SIZE["3d"]
        h, w = IMAGE_SIZE["3d_text"]
        model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
    elif model_name == "threed_rotate":
        h, w = IMAGE_SIZE["3d_rotate"]
        model = RegressionCNN(img_h=h, img_w=w)
        input_shape = (1, h, w)
    elif model_name == "threed_slider":
        h, w = IMAGE_SIZE["3d_slider"]
        model = RegressionCNN(img_h=h, img_w=w)
        input_shape = (1, h, w)
    else:
        print(f"[Error] unknown model: {model_name}")
        return
@@ -108,11 +119,11 @@


def export_all():
    """Export the classifier, normal, math, threed models in turn."""
    """Export the classifier, normal, math, threed_text, threed_rotate, threed_slider models in turn."""
    print("=" * 50)
    print("Exporting all ONNX models")
    print("=" * 50)
    for name in ["classifier", "normal", "math", "threed"]:
    for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
        _load_and_export(name)
    print("\nAll exports complete.")
@@ -4,7 +4,7 @@
Loads the dispatch model and all expert models (ONNX format) and provides a unified solve(image) interface.

Inference flow:
    input image → preprocess → dispatch classifier → route to expert model → CTC decode → postprocess → output
    input image → preprocess → dispatch classifier → route to expert model → CTC decode / regression rescale → postprocess → output

For math captchas, math_eval is called after decoding to compute the result.
"""
@@ -23,6 +23,7 @@ from config import (
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
    REGRESSION_RANGE,
)
from inference.math_eval import eval_captcha_math

@@ -59,19 +60,24 @@ class CaptchaPipeline:
        self.mean = INFERENCE_CONFIG["normalize_mean"]
        self.std = INFERENCE_CONFIG["normalize_std"]

        # Character-set mapping
        # Character-set mapping (only needed by CTC models)
        self._chars = {
            "normal": NORMAL_CHARS,
            "math": MATH_CHARS,
            "3d": THREED_CHARS,
            "3d_text": THREED_CHARS,
        }

        # Regression model types
        self._regression_types = {"3d_rotate", "3d_slider"}

        # Expert model name → ONNX filename
        self._model_files = {
            "classifier": "classifier.onnx",
            "normal": "normal.onnx",
            "math": "math.onnx",
            "3d": "threed.onnx",
            "3d_text": "threed_text.onnx",
            "3d_rotate": "threed_rotate.onnx",
            "3d_slider": "threed_slider.onnx",
        }

        # Load all available models
@@ -116,7 +122,7 @@

    def classify(self, image: Image.Image) -> str:
        """
        Dispatch classification; returns the type name: 'normal' / 'math' / '3d'.
        Dispatch classification; returns the type name.

        Raises:
            RuntimeError: classifier model not loaded
@@ -141,7 +147,7 @@

        Args:
            image: PIL.Image, a file path (str/Path), or bytes
            captcha_type: pass a type to skip classification ('normal'/'math'/'3d')
            captcha_type: pass a type to skip classification ('normal'/'math'/'3d_text'/'3d_rotate'/'3d_slider')

        Returns:
            dict: {
@@ -166,22 +172,32 @@
                f"Expert model '{captcha_type}' not loaded; train and export the corresponding ONNX model first"
            )

        size_key = captcha_type  # "normal"/"math"/"3d"
        size_key = captcha_type
        inp = self.preprocess(img, IMAGE_SIZE[size_key])
        session = self._sessions[captcha_type]
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inp})[0]  # (T, 1, C)

        # 4. CTC greedy decode
        # 4. Branch: CTC decode vs regression
        if captcha_type in self._regression_types:
            # Regression models output a (batch, 1) sigmoid value
            output = session.run(None, {input_name: inp})[0]  # (1, 1)
            sigmoid_val = float(output[0, 0])
            lo, hi = REGRESSION_RANGE[captcha_type]
            real_val = sigmoid_val * (hi - lo) + lo
            raw_text = f"{real_val:.1f}"
            result = str(int(round(real_val)))
        else:
            # CTC models
            logits = session.run(None, {input_name: inp})[0]  # (T, 1, C)
            chars = self._chars[captcha_type]
            raw_text = self._ctc_greedy_decode(logits, chars)

        # 5. Postprocess
        # Postprocess
        if captcha_type == "math":
            try:
                result = eval_captcha_math(raw_text)
            except ValueError:
                result = raw_text  # fall back to the raw text on parse failure
                result = raw_text
        else:
            result = raw_text
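The pipeline's `_ctc_greedy_decode` helper is called above but not shown in this diff. A minimal NumPy sketch of greedy CTC decoding over logits shaped (T, 1, C), under two assumptions not confirmed by the diff: index 0 is the CTC blank, and character indices are offset by one:

```python
import numpy as np


def ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
    """Greedy CTC decode: argmax per time step, collapse repeats, drop blanks (index 0)."""
    best = logits[:, 0, :].argmax(axis=-1)  # (T,) best class per step
    out = []
    prev = -1
    for idx in best:
        if idx != prev and idx != 0:
            out.append(chars[idx - 1])  # charset offset by 1 for the blank
        prev = idx
    return "".join(out)


# Toy example: T=6 steps, C=4 classes (blank + "abc")
t = np.zeros((6, 1, 4))
for i, cls in enumerate([1, 1, 0, 2, 2, 3]):  # "aa.bbc" collapses to "abc"
    t[i, 0, cls] = 1.0
print(ctc_greedy_decode(t, "abc"))  # prints abc
```

Collapsing repeats before dropping blanks is what lets the blank separate genuinely repeated characters.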
@@ -1,18 +1,21 @@
"""
Model definitions package

Provides three models:
Provides four models:
- CaptchaClassifier: dispatch classifier (lightweight CNN, < 500KB)
- LiteCRNN: lightweight CRNN (normal characters + math expressions, < 2MB)
- ThreeDCNN: dedicated 3D captcha model (ResNet-lite + BiLSTM, < 5MB)
- ThreeDCNN: dedicated 3D text captcha model (ResNet-lite + BiLSTM, < 5MB)
- RegressionCNN: regression CNN (3D rotate + slider, ~1MB)
"""

from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN

__all__ = [
    "CaptchaClassifier",
    "LiteCRNN",
    "ThreeDCNN",
    "RegressionCNN",
]
86
models/regression_cnn.py
Normal file
@@ -0,0 +1,86 @@
"""
Regression CNN model

Shared regression model for 3d_rotate and 3d_slider.
The output is sigmoid-normalized to [0, 1]; at inference time it is scaled back to the original range via label_range.

Architecture:
    Conv(1→32) + BN + ReLU + Pool
    Conv(32→64) + BN + ReLU + Pool
    Conv(64→128) + BN + ReLU + Pool
    Conv(128→128) + BN + ReLU + Pool
    AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid

About 250K parameters, ~1MB.
"""

import torch
import torch.nn as nn


class RegressionCNN(nn.Module):
    """
    Lightweight regression CNN for 3d_rotate (angle) and 3d_slider (offset) prediction.

    Outputs a sigmoid value in [0, 1] that must be scaled to the actual range via label_range.
    """

    def __init__(self, img_h: int = 80, img_w: int = 80):
        """
        Args:
            img_h: input image height
            img_w: input image width
        """
        super().__init__()
        self.img_h = img_h
        self.img_w = img_w

        self.features = nn.Sequential(
            # block 1: 1 → 32, H/2, W/2
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 2: 32 → 64, H/4, W/4
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 3: 64 → 128, H/8, W/8
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 4: 128 → 128, H/16, W/16
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.regressor = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale image

        Returns:
            output: (batch, 1) sigmoid output in [0, 1]
        """
        feat = self.features(x)
        feat = self.pool(feat)  # (B, 128, 1, 1)
        feat = feat.flatten(1)  # (B, 128)
        out = self.regressor(feat)  # (B, 1)
        return out
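At inference time the sigmoid output must be mapped back through the label range, exactly as the pipeline's regression branch does. A standalone sketch of that mapping; the helper name `denormalize` and the (0, 359) rotation range are illustrative:

```python
def denormalize(sigmoid_val: float, lo: float, hi: float) -> int:
    """Map a sigmoid output in [0, 1] back to the label range and round to an integer."""
    real = sigmoid_val * (hi - lo) + lo
    return int(round(real))


# Midpoint of an assumed 0-359° rotation range
print(denormalize(0.5, 0, 359))  # prints 180
```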
@@ -1,10 +1,13 @@
"""
Training scripts package

- dataset.py: CRNNDataset / CaptchaDataset generic dataset classes
- dataset.py: CRNNDataset / CaptchaDataset / RegressionDataset generic dataset classes
- train_utils.py: shared CTC training logic (train_ctc_model)
- train_regression_utils.py: shared regression training logic (train_regression_model)
- train_normal.py: train normal character recognition (LiteCRNN - normal)
- train_math.py: train math-expression recognition (LiteCRNN - math)
- train_3d.py: train 3D recognition (ThreeDCNN)
- train_3d_text.py: train 3D text recognition (ThreeDCNN)
- train_3d_rotate.py: train 3D rotation regression (RegressionCNN)
- train_3d_slider.py: train 3D slider regression (RegressionCNN)
- train_classifier.py: train the dispatch classifier (CaptchaClassifier)
"""
@@ -1,16 +1,19 @@
"""
Generic Dataset classes

Provides two datasets:
Provides three datasets:
- CaptchaDataset: for classifier training (image → class label)
- CRNNDataset: for CRNN/CTC recognition training (image → encoded character sequence)
- RegressionDataset: for regression model training (image → numeric label in [0, 1])

Filename convention: {label}_{anything}.png
- classifier: label can be any string; the containing subdirectory name is the class
- recognizer: label is the annotation itself (e.g. "A3B8" or "3+8")
- regressor: label is a number (e.g. "135" or "87")
"""

import os
import warnings
from pathlib import Path

from PIL import Image
@@ -98,7 +101,15 @@ class CRNNDataset(Dataset):
            img = self.transform(img)

        # Encode the label as an integer sequence
        target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
        target = []
        for c in label:
            if c in self.char_to_idx:
                target.append(self.char_to_idx[c])
            else:
                warnings.warn(
                    f"Label '{label}' contains a character outside the charset: '{c}'; skipped (file: {path})",
                    stacklevel=2,
                )
        return img, target, label

    @staticmethod
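The encoding loop above is easy to exercise in isolation. A pure-Python sketch with a toy charset; reserving index 0 for the CTC blank is an assumption mirroring common CTC setups, not something shown in this hunk:

```python
def encode_label(label: str, chars: str) -> list[int]:
    """Map each in-charset character to its index; skip anything outside the charset."""
    char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # index 0 reserved for the CTC blank
    return [char_to_idx[c] for c in label if c in char_to_idx]


print(encode_label("A3_B", "AB3"))  # prints [1, 3, 2]
```

The underscore is silently dropped here; the dataset version additionally emits a warning so mislabeled files surface during training.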
@@ -119,7 +130,7 @@ class CaptchaDataset(Dataset):
    """
    Classifier training dataset.

    Each subdirectory name is a class name (e.g. "normal", "math", "3d");
    Each subdirectory name is a class name (e.g. "normal", "math", "3d_text");
    all .png files inside belong to that class.
    """

@@ -157,3 +168,59 @@ class CaptchaDataset(Dataset):
        if self.transform:
            img = self.transform(img)
        return img, label
# ============================================================
|
||||
# 回归模型用数据集
|
||||
# ============================================================
|
||||
class RegressionDataset(Dataset):
|
||||
"""
|
||||
回归模型数据集 (3d_rotate / 3d_slider)。
|
||||
|
||||
从目录中读取 {value}_{xxx}.png 文件,
|
||||
将 value 解析为浮点数并归一化到 [0, 1]。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirs: list[str | Path],
|
||||
label_range: tuple[float, float],
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dirs: 数据目录列表
|
||||
label_range: (min_val, max_val) 标签原始范围
|
||||
transform: 图片预处理/增强
|
||||
"""
|
||||
self.label_range = label_range
|
||||
self.lo, self.hi = label_range
|
||||
self.transform = transform
|
||||
|
||||
self.samples: list[tuple[str, float]] = [] # (文件路径, 归一化标签)
|
||||
for d in dirs:
|
||||
d = Path(d)
|
||||
if not d.exists():
|
||||
continue
|
||||
for f in sorted(d.glob("*.png")):
|
||||
raw_label = f.stem.rsplit("_", 1)[0]
|
||||
try:
|
||||
value = float(raw_label)
|
||||
except ValueError:
|
||||
continue
|
||||
# 归一化到 [0, 1]
|
||||
norm = (value - self.lo) / max(self.hi - self.lo, 1e-6)
|
||||
norm = max(0.0, min(1.0, norm))
|
||||
self.samples.append((str(f), norm))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
import torch
|
||||
path, label = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([label], dtype=torch.float32)
|
||||
|
||||
|
||||
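The normalization step in `RegressionDataset` can be checked on its own. A standalone sketch (`normalize_label` is an illustrative name, not part of the repository):

```python
def normalize_label(value: float, lo: float, hi: float) -> float:
    """Normalize a raw label to [0, 1], clamping out-of-range values."""
    norm = (value - lo) / max(hi - lo, 1e-6)
    return max(0.0, min(1.0, norm))


print(normalize_label(180.0, 0.0, 360.0))  # prints 0.5
```

The `max(hi - lo, 1e-6)` guard avoids a division by zero for a degenerate range, and clamping keeps mislabeled files from producing targets the sigmoid head can never reach.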
38
training/train_3d_rotate.py
Normal file
@@ -0,0 +1,38 @@
"""
Train the 3D rotation captcha regression model (RegressionCNN)

Usage: python -m training.train_3d_rotate
"""

from config import (
    IMAGE_SIZE,
    SYNTHETIC_3D_ROTATE_DIR,
    REAL_3D_ROTATE_DIR,
)
from generators.threed_rotate_gen import ThreeDRotateGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model


def main():
    img_h, img_w = IMAGE_SIZE["3d_rotate"]
    model = RegressionCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the 3D rotation captcha regression model (RegressionCNN)")
    print(f"  Input size: {img_h}×{img_w}")
    print(f"  Task: predict the rotation angle 0-359°")
    print("=" * 60)

    train_regression_model(
        model_name="threed_rotate",
        model=model,
        synthetic_dir=SYNTHETIC_3D_ROTATE_DIR,
        real_dir=REAL_3D_ROTATE_DIR,
        generator_cls=ThreeDRotateGenerator,
        config_key="3d_rotate",
    )


if __name__ == "__main__":
    main()
38
training/train_3d_slider.py
Normal file
@@ -0,0 +1,38 @@
"""
Train the 3D slider captcha regression model (RegressionCNN)

Usage: python -m training.train_3d_slider
"""

from config import (
    IMAGE_SIZE,
    SYNTHETIC_3D_SLIDER_DIR,
    REAL_3D_SLIDER_DIR,
)
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model


def main():
    img_h, img_w = IMAGE_SIZE["3d_slider"]
    model = RegressionCNN(img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the 3D slider captcha regression model (RegressionCNN)")
    print(f"  Input size: {img_h}×{img_w}")
    print(f"  Task: predict the slider offset x coordinate")
    print("=" * 60)

    train_regression_model(
        model_name="threed_slider",
        model=model,
        synthetic_dir=SYNTHETIC_3D_SLIDER_DIR,
        real_dir=REAL_3D_SLIDER_DIR,
        generator_cls=ThreeDSliderGenerator,
        config_key="3d_slider",
    )


if __name__ == "__main__":
    main()
@@ -1,14 +1,14 @@
"""
Train the 3D captcha recognition model (ThreeDCNN)
Train the 3D text captcha recognition model (ThreeDCNN)

Usage: python -m training.train_3d
Usage: python -m training.train_3d_text
"""

from config import (
    THREED_CHARS,
    IMAGE_SIZE,
    SYNTHETIC_3D_DIR,
    REAL_3D_DIR,
    SYNTHETIC_3D_TEXT_DIR,
    REAL_3D_TEXT_DIR,
)
from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN
@@ -16,23 +16,23 @@ from training.train_utils import train_ctc_model


def main():
    img_h, img_w = IMAGE_SIZE["3d"]
    img_h, img_w = IMAGE_SIZE["3d_text"]
    model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the 3D captcha recognition model (ThreeDCNN)")
    print("Training the 3D text captcha recognition model (ThreeDCNN)")
    print(f"  Charset: {THREED_CHARS} ({len(THREED_CHARS)} chars)")
    print(f"  Input size: {img_h}×{img_w}")
    print("=" * 60)

    train_ctc_model(
        model_name="threed",
        model_name="threed_text",
        model=model,
        chars=THREED_CHARS,
        synthetic_dir=SYNTHETIC_3D_DIR,
        real_dir=REAL_3D_DIR,
        synthetic_dir=SYNTHETIC_3D_TEXT_DIR,
        real_dir=REAL_3D_TEXT_DIR,
        generator_cls=ThreeDCaptchaGenerator,
        config_key="threed",
        config_key="3d_text",
    )
@@ -1,16 +1,18 @@
"""
Train the dispatch classifier (CaptchaClassifier)

Mixes samples from every captcha type to train a classifier that distinguishes normal / math / 3d.
Mixes samples from every captcha type to train a classifier that distinguishes normal / math / 3d_text / 3d_rotate / 3d_slider.
Data source: data/classifier/ directory (organized into per-type subdirectories)

Usage: python -m training.train_classifier
"""

import os
import random
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
@@ -24,15 +26,20 @@ from config import (
    CLASSIFIER_DIR,
    SYNTHETIC_NORMAL_DIR,
    SYNTHETIC_MATH_DIR,
    SYNTHETIC_3D_DIR,
    SYNTHETIC_3D_TEXT_DIR,
    SYNTHETIC_3D_ROTATE_DIR,
    SYNTHETIC_3D_SLIDER_DIR,
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    RANDOM_SEED,
    get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform

@@ -52,7 +59,9 @@ def _prepare_classifier_data():
    type_info = [
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
        ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
        ("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
        ("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
    ]

    for cls_name, syn_dir, gen_cls in type_info:
@@ -95,6 +104,13 @@ def main():
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()

    # Set random seeds
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

    print("=" * 60)
    print("Training the dispatch classifier (CaptchaClassifier)")
    print(f"  Classes: {CAPTCHA_TYPES}")
@@ -128,11 +144,11 @@

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=2, pin_memory=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=2, pin_memory=True,
        num_workers=0, pin_memory=True,
    )

    print(f"[Data] train: {train_size}  val: {val_size}")
264
training/train_regression_utils.py
Normal file
@@ -0,0 +1,264 @@
"""
Shared regression training logic

Provides train_regression_model(), shared by train_3d_rotate / train_3d_slider.
Responsibilities:
1. Check for synthetic data; invoke the generator automatically if it is missing
2. Build RegressionDataset / DataLoader (mixing in real data)
3. Regression training loop + cosine scheduler
4. Log: epoch, loss, MAE, tolerance accuracy
5. Save the best model to checkpoints/
6. Export ONNX when training finishes
"""

import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    TRAIN_CONFIG,
    IMAGE_SIZE,
    REGRESSION_RANGE,
    RANDOM_SEED,
    get_device,
)
from training.dataset import RegressionDataset, build_train_transform, build_val_transform


def _set_seed(seed: int = RANDOM_SEED):
    """Set global random seeds so training is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Circular-distance SmoothL1 loss for angle regression (handles the 0°/360° boundary).
    Both pred and target are in [0, 1].
    """
    diff = torch.abs(pred - target)
    # Circular distance: min(|d|, 1-|d|)
    diff = torch.min(diff, 1.0 - diff)
    # SmoothL1
    return torch.where(
        diff < 1.0 / 360.0,  # beta ≈ 1° normalized
        0.5 * diff * diff / (1.0 / 360.0),
        diff - 0.5 * (1.0 / 360.0),
    ).mean()


def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
    """Circular MAE (in normalized space)."""
    diff = np.abs(pred - target)
    diff = np.minimum(diff, 1.0 - diff)
    return float(np.mean(diff))
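The circular metric shared by `_circular_smooth_l1` and `_circular_mae` wraps around the 0°/360° boundary. A tiny standalone sketch of min(|d|, 1 - |d|) in normalized space, showing that 359° and 1° are treated as 2° apart rather than 358°:

```python
def circular_dist(a: float, b: float) -> float:
    """Distance on the unit circle for values normalized to [0, 1]: min(|d|, 1 - |d|)."""
    d = abs(a - b)
    return min(d, 1.0 - d)


# 359° vs 1°, normalized by 360: the wrap-around distance is 2/360
print(round(circular_dist(359 / 360, 1 / 360), 6))  # prints 0.005556
```

Without the wrap, a prediction of 359° for a 1° target would incur a near-maximal loss and pull the model in the wrong direction at the boundary.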
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
    """Export the model to ONNX format."""
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    print(f"[ONNX] export complete: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def train_regression_model(
    model_name: str,
    model: nn.Module,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """
    Generic regression training flow.

    Args:
        model_name: model name (used for saved files: threed_rotate / threed_slider)
        model: PyTorch model instance (RegressionCNN)
        synthetic_dir: synthetic data directory
        real_dir: real data directory
        generator_cls: generator class (used to auto-generate data)
        config_key: key into TRAIN_CONFIG / REGRESSION_RANGE (3d_rotate / 3d_slider)
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    label_range = REGRESSION_RANGE[config_key]
    lo, hi = label_range
    is_circular = config_key == "3d_rotate"
    device = get_device()

    # Tolerance configuration
    if config_key == "3d_rotate":
        tolerance = 5.0  # ±5°
    else:
        tolerance = 3.0  # ±3px

    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[Data] not enough synthetic data ({len(existing)}/{cfg['synthetic_samples']}); generating...")
        gen = generator_cls()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[Data] synthetic data ready: {len(existing)} images")

    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    if real_path.exists() and list(real_path.glob("*.png")):
        data_dirs.append(str(real_path))
        print(f"[Data] mixing in real data: {len(list(real_path.glob('*.png')))} images")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation set uses the augmentation-free transform
    val_ds_clean = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )

    print(f"[Data] train: {train_size}  val: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])

    if is_circular:
        loss_fn = _circular_smooth_l1
    else:
        loss_fn = nn.SmoothL1Loss()

    best_mae = float("inf")
    best_tol_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"

    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)

            preds = model(images)
            loss = loss_fn(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images)
                all_preds.append(preds.cpu().numpy())
                all_targets.append(targets.numpy())

        all_preds = np.concatenate(all_preds, axis=0).flatten()
        all_targets = np.concatenate(all_targets, axis=0).flatten()

        # Scale back to the original range to compute MAE
        preds_real = all_preds * (hi - lo) + lo
        targets_real = all_targets * (hi - lo) + lo

        if is_circular:
            # Circular MAE
            diff = np.abs(preds_real - targets_real)
            diff = np.minimum(diff, (hi - lo) - diff)
            mae = float(np.mean(diff))
            within_tol = diff <= tolerance
        else:
            mae = float(np.mean(np.abs(preds_real - targets_real)))
            within_tol = np.abs(preds_real - targets_real) <= tolerance

        tol_acc = float(np.mean(within_tol))
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
|
||||
f"MAE={mae:.2f} "
|
||||
f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
|
||||
f"lr={lr:.6f}"
|
||||
)
|
||||
|
||||
# ---- 6. 保存最佳模型 (以容差准确率为准) ----
|
||||
if tol_acc >= best_tol_acc:
|
||||
best_tol_acc = tol_acc
|
||||
best_mae = mae
|
||||
torch.save({
|
||||
"model_state_dict": model.state_dict(),
|
||||
"label_range": label_range,
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
|
||||
|
||||
# ---- 7. 导出 ONNX ----
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
|
||||
return best_tol_acc
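The trainer above selects `_circular_smooth_l1` for the rotation subtype, but that function's body falls outside this excerpt. A minimal sketch of how such a loss could work, consistent with the shortest-arc MAE computed in the validation step; the normalization of predictions and targets to [0, 1) is taken from the surrounding code, but the implementation itself is illustrative, not the repository's actual one:

```python
import torch
import torch.nn.functional as F


def circular_smooth_l1(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Smooth L1 over the shortest arc between angles normalized to [0, 1)."""
    diff = torch.abs(preds - targets)
    # Wrap around the circle: distance(0.99, 0.01) is 0.02, not 0.98.
    diff = torch.minimum(diff, 1.0 - diff)
    return F.smooth_l1_loss(diff, torch.zeros_like(diff))


# Near the wrap point a plain SmoothL1Loss would see an error of 0.98.
loss = circular_smooth_l1(torch.tensor([0.99]), torch.tensor([0.01]))
print(round(loss.item(), 6))  # 0.0002
```

The same shortest-arc idea appears in validation, where the circular MAE uses `min(|a - b|, range - |a - b|)` after de-normalizing to the original label range.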
@@ -1,7 +1,7 @@
 """
 CTC 训练通用逻辑
 
-提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d 共用。
+提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d_text 共用。
 职责:
 1. 检查合成数据,不存在则自动调用生成器
 2. 构建 Dataset / DataLoader(含真实数据混合)
@@ -12,8 +12,10 @@ CTC 训练通用逻辑
 """
 
 import os
+import random
 from pathlib import Path
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader, random_split
@@ -25,11 +27,21 @@ from config import (
     ONNX_CONFIG,
     TRAIN_CONFIG,
     IMAGE_SIZE,
+    RANDOM_SEED,
     get_device,
 )
 from training.dataset import CRNNDataset, build_train_transform, build_val_transform
 
 
+def _set_seed(seed: int = RANDOM_SEED):
+    """设置全局随机种子,保证训练可复现。"""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
 # ============================================================
 # 准确率计算
 # ============================================================
@@ -104,9 +116,12 @@ def train_ctc_model(
         config_key: TRAIN_CONFIG 中的键名
     """
     cfg = TRAIN_CONFIG[config_key]
-    img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
+    img_h, img_w = IMAGE_SIZE[config_key]
     device = get_device()
 
+    # 设置随机种子
+    _set_seed()
+
     # ---- 1. 检查 / 生成合成数据 ----
     syn_path = Path(synthetic_dir)
     existing = list(syn_path.glob("*.png"))
@@ -139,11 +154,11 @@ def train_ctc_model(
 
     train_loader = DataLoader(
         train_ds, batch_size=cfg["batch_size"], shuffle=True,
-        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
+        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
     )
     val_loader = DataLoader(
         val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
-        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
+        num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
     )
 
     print(f"[数据] 训练: {train_size} 验证: {val_size}")
@@ -166,12 +181,11 @@ def train_ctc_model(
         pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
         for images, targets, target_lengths, _ in pbar:
             images = images.to(device)
-            targets = targets.to(device)
-            target_lengths = target_lengths.to(device)
 
             logits = model(images)  # (T, B, C)
             T, B, C = logits.shape
-            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
+            # cuDNN CTC requires targets/lengths on CPU
+            input_lengths = torch.full((B,), T, dtype=torch.int32)
 
             log_probs = logits.log_softmax(2)
             loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)