Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider

Split the single "3d" captcha type into three independent expert models:
- 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN)
- 3d_rotate: rotation angle regression (new RegressionCNN, circular loss)
- 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss)

CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated
to 50000 (10000 per class). New generators, model, dataset, training
utilities, and full pipeline/export/CLI support for all subtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 13:55:53 +08:00
parent 760b80ee5e
commit f5be7671bc
20 changed files with 1109 additions and 142 deletions


@@ -1,25 +1,33 @@
# Repository Guidelines
## Project Structure & Module Organization
Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas, `models/` contains the classifier and expert OCR models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.
Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas (5 types: normal, math, 3d_text, 3d_rotate, 3d_slider), `models/` contains the classifier, CTC expert models, and regression models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.
## Build, Test, and Development Commands
Use `uv` for environment and dependency management.
- `uv sync` installs the base runtime dependencies from `pyproject.toml`.
- `uv sync --extra server` installs HTTP service dependencies.
- `uv run captcha generate --type normal --num 1000` generates synthetic training data.
- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d -> classifier`.
- `uv run captcha generate --type normal --num 1000` generates synthetic training data. Types: `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, `classifier`.
- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d_text -> 3d_rotate -> 3d_slider -> classifier`.
- `uv run captcha export --all` exports all trained models to ONNX.
- `uv run captcha export --model 3d_text` exports a single model; `3d_text` is automatically mapped to `threed_text`.
- `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory.
- `uv run captcha serve --port 8080` starts the optional HTTP API when `server.py` is implemented.
## Coding Style & Naming Conventions
Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d`, and `classifier`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding unless the pipeline is intentionally redesigned. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.
Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Checkpoint/ONNX file names use `threed_text`, `threed_rotate`, `threed_slider` (underscored, no hyphens). Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding for OCR models. Regression models (3d_rotate, 3d_slider) output sigmoid [0,1] scaled by `REGRESSION_RANGE`. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.
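The greedy CTC decoding rule mentioned above (take the per-timestep argmax, collapse consecutive repeats, drop blanks) can be sketched in a few lines. The blank index 0 and the toy charset below are illustrative assumptions, not the repo's configured charset:

```python
# Greedy CTC decoding sketch: collapse consecutive repeats, then drop blanks.
# BLANK = 0 and the toy charset are illustrative assumptions.
BLANK = 0
CHARS = "0123456789ABCDEF"  # hypothetical charset; index i+1 maps to CHARS[i]

def ctc_greedy_decode(argmax_indices):
    """Decode a sequence of per-timestep argmax indices into a string."""
    out = []
    prev = None
    for idx in argmax_indices:
        if idx != BLANK and idx != prev:  # skip blanks and immediate repeats
            out.append(CHARS[idx - 1])
        prev = idx
    return "".join(out)

# A blank between two identical indices separates them into two characters:
print(ctc_greedy_decode([11, 11, 0, 11, 0, 4, 4]))  # -> "AA3"
```

This is why greedy decoding suffices for captchas: sequences are short and the model rarely needs the alternative alignments a beam search would explore.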
## Training & Data Rules
- All training scripts must set the global random seed (`random`, `numpy`, `torch`) via `config.RANDOM_SEED` before training begins.
- All DataLoaders use `num_workers=0` for cross-platform consistency.
- Generator parameters (rotation, noise, shadow, etc.) must come from `config.GENERATE_CONFIG`, not hardcoded values.
- `CRNNDataset` emits a `warnings.warn` when a label contains characters outside the configured charset, rather than silently dropping them.
- `RegressionDataset` parses numeric labels from filenames and normalizes to [0,1] via `label_range`.
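The `RegressionDataset` rule above can be sketched as follows; `parse_regression_label` is a hypothetical helper for illustration, not the actual dataset implementation:

```python
# Sketch of reading a regression label from a filename and normalizing it
# to [0,1] via label_range. parse_regression_label is a hypothetical helper.
def parse_regression_label(filename: str, label_range: tuple[int, int]) -> float:
    """'137_000042.png' with range (0, 360) -> 137/360."""
    raw = int(filename.split("_")[0])  # numeric label before the first "_"
    lo, hi = label_range
    return (raw - lo) / (hi - lo)      # min-max normalize to [0,1]

print(parse_regression_label("137_000042.png", (0, 360)))  # rotation angle
print(parse_regression_label("87_000001.png", (10, 200)))  # slider offset
```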
## Data & Testing Guidelines
Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.
Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. For regression types, label is the numeric value (angle or offset). Sample targets are defined in `config.py`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.
## Commit & Pull Request Guidelines
Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.

CLAUDE.md

@@ -24,29 +24,40 @@ captcha-breaker/
│   ├── synthetic/                # synthetic training data (auto-generated, not committed to git)
│   │   ├── normal/               # plain character type
│   │   ├── math/                 # arithmetic type
│   │   ├── 3d/                   # 3D type
│   │   ├── 3d_text/              # 3D text type
│   │   ├── 3d_rotate/            # 3D rotation type
│   │   └── 3d_slider/            # 3D slider type
│   ├── real/                     # real captcha samples (manually labeled)
│   │   ├── normal/
│   │   ├── math/
│   │   ├── 3d/
│   │   ├── 3d_text/
│   │   ├── 3d_rotate/
│   │   └── 3d_slider/
│   └── classifier/               # dispatch-classifier training data (mix of all types)
├── generators/
│   ├── __init__.py
│   ├── base.py                   # generator base class
│   ├── normal_gen.py             # plain character captcha generator
│   ├── math_gen.py               # arithmetic captcha generator (e.g. 3+8=?)
│   ├── threed_gen.py             # 3D captcha generator
│   ├── threed_gen.py             # 3D text captcha generator
│   ├── threed_rotate_gen.py      # 3D rotation captcha generator
│   └── threed_slider_gen.py      # 3D slider captcha generator
├── models/
│   ├── __init__.py
│   ├── lite_crnn.py              # lightweight CRNN (plain characters and arithmetic)
│   ├── classifier.py             # dispatch classification model
│   ├── threed_cnn.py             # dedicated 3D captcha model (deeper CNN)
│   ├── threed_cnn.py             # dedicated 3D text captcha model (deeper CNN)
│   └── regression_cnn.py         # regression CNN (3D rotation + slider, ~1MB)
├── training/
│   ├── __init__.py
│   ├── train_classifier.py       # train the dispatch model
│   ├── train_normal.py           # train plain character recognition
│   ├── train_math.py             # train arithmetic recognition
│   ├── train_3d.py               # train 3D recognition
│   ├── train_3d_text.py          # train 3D text recognition
│   ├── train_3d_rotate.py        # train 3D rotation regression
│   ├── train_3d_slider.py        # train 3D slider regression
│   ├── train_utils.py            # shared CTC training logic
│   ├── train_regression_utils.py # shared regression training logic
│   └── dataset.py                # generic Dataset classes
├── inference/
│   ├── __init__.py
@@ -57,12 +68,16 @@ captcha-breaker/
│   ├── classifier.pth
│   ├── normal.pth
│   ├── math.pth
│   ├── threed.pth
│   ├── threed_text.pth
│   ├── threed_rotate.pth
│   └── threed_slider.pth
├── onnx_models/                  # exported ONNX models
│   ├── classifier.onnx
│   ├── normal.onnx
│   ├── math.onnx
│   ├── threed.onnx
│   ├── threed_text.onnx
│   ├── threed_rotate.onnx
│   └── threed_slider.onnx
├── server.py                     # FastAPI inference service (optional)
├── cli.py                        # command-line entrypoint
└── tests/
@@ -78,13 +93,13 @@ captcha-breaker/
```
input image → preprocess → dispatch classifier → route to expert model → postprocess → result
         ┌────────────────┐
         ▼        ▼       ▼
      normal     math     3d
      (CRNN)    (CRNN)   (CNN)
         │        │       │
         ▼        ▼       ▼
      "A3B8"  "3+8=?"→11 "X9K2"
    ┌────────┬───┼───────┬──────────┐
 normal    math   3d_text  3d_rotate  3d_slider
 (CRNN)   (CRNN)   (CNN)   (RegCNN)   (RegCNN)
 "A3B8" "3+8=?"→11 "X9K2"   "135"      "87"
```
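The routing step in the diagram above amounts to a dispatch table keyed by the classifier's output. A minimal sketch, with stubbed handlers in place of the real expert models (the handler bodies are illustrative, not the pipeline API):

```python
# Dispatch-table sketch of the routing step; handler bodies are stubs.
def solve(captcha_type: str, image) -> str:
    handlers = {
        "normal":    lambda img: "A3B8",  # CTC OCR (stubbed)
        "math":      lambda img: "11",    # OCR + rule-based eval (stubbed)
        "3d_text":   lambda img: "X9K2",  # deeper CTC CNN (stubbed)
        "3d_rotate": lambda img: "135",   # regression, degrees (stubbed)
        "3d_slider": lambda img: "87",    # regression, pixels (stubbed)
    }
    if captcha_type not in handlers:
        raise ValueError(f"unknown captcha type: {captcha_type}")
    return handlers[captcha_type](image)

print(solve("3d_rotate", None))  # -> "135"
```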
### Dispatch classifier (classifier.py)
@@ -102,7 +117,7 @@ class CaptchaClassifier(nn.Module):
    """
    Lightweight classifier; a few conv layers suffice to tell captcha types apart.
    The types differ a lot visually (presence of operators, 3D effects, etc.),
    so classification is easy.
    """
    def __init__(self, num_types=3):
    def __init__(self, num_types=5):
        # 4 conv layers + GAP + FC
        # Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64)
        # AdaptiveAvgPool2d(1) -> Linear(64, num_types)
@@ -142,14 +157,32 @@ def eval_captcha_math(expr: str) -> str:
pass
```
### 3D recognition expert (threed_cnn.py)
### 3D text recognition expert (threed_cnn.py)
- Task: recognize captchas with 3D perspective/shadow effects
- Task: recognize text captchas with 3D perspective/shadow effects
- Architecture: deeper CNN + CRNN, or a ResNet-lite backbone
- Input: grayscale 1x60x160
- Needs stronger feature extraction to handle perspective distortion and shadows
- Model size target: < 5MB
### 3D rotation expert (regression_cnn.py, 3d_rotate mode)
- Task: predict the correct rotation angle of a rotation captcha
- Architecture: lightweight regression CNN (4 conv layers + GAP + FC + Sigmoid)
- Input: grayscale 1x80x80
- Output: sigmoid value in [0,1], scaled back to degrees by (0, 360)
- Label range: 0-359°
- Model size target: ~1MB
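The commit message mentions a circular loss for 3d_rotate while the training config lists SmoothL1; either way, angle error should respect wraparound: 359° and 1° are 2° apart, not 358°. A minimal sketch of the wraparound-aware error underlying a "circular" loss (illustrative, not necessarily the repo's training loss):

```python
# Wraparound-aware angle error: the shortest arc between two angles.
# Sketch of the idea behind a "circular" rotation loss; the actual
# training loss in this repo may differ (config lists SmoothL1).
def circular_angle_error(pred_deg: float, target_deg: float) -> float:
    diff = abs(pred_deg - target_deg) % 360.0
    return min(diff, 360.0 - diff)

print(circular_angle_error(359.0, 1.0))   # -> 2.0, not 358.0
print(circular_angle_error(180.0, 90.0))  # -> 90.0
```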
### 3D slider expert (regression_cnn.py, 3d_slider mode)
- Task: predict the x offset of the slider puzzle gap
- Architecture: same regression CNN as above, different input size
- Input: grayscale 1x80x240
- Output: sigmoid value in [0,1], scaled back to pixel offset by (10, 200)
- Label range: 10-200px
- Model size target: ~1MB
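The two regression experts above share one architecture because global average pooling makes the head input-size-agnostic. A minimal sketch, with illustrative layer widths (not the repo's exact regression_cnn.py):

```python
import torch
import torch.nn as nn

class RegressionCNNSketch(nn.Module):
    """4 conv blocks + GAP + FC + Sigmoid, output in [0,1].
    Layer widths are illustrative assumptions."""
    def __init__(self):
        super().__init__()
        def block(cin, cout):
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout), nn.ReLU(), nn.MaxPool2d(2))
        self.features = nn.Sequential(block(1, 16), block(16, 32),
                                      block(32, 64), block(64, 64))
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                  nn.Linear(64, 1), nn.Sigmoid())
    def forward(self, x):
        return self.head(self.features(x))

# One model class serves both input sizes thanks to global average pooling:
m = RegressionCNNSketch().eval()
with torch.no_grad():
    print(m(torch.randn(2, 1, 80, 80)).shape)   # rotation input
    print(m(torch.randn(2, 1, 80, 240)).shape)  # slider input
```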
## Data generator spec
### Base class (base.py)
@@ -185,13 +218,26 @@ class BaseCaptchaGenerator:
- Label format: `3+8` (store the expression itself, not the result)
- Visual style: matches the target arithmetic captchas
### 3D generator (threed_gen.py)
### 3D text generator (threed_gen.py)
- Uses Pillow affine transforms to simulate 3D perspective
- Adds shadow effects
- Characters have depth and tilt
- Character rotation angle is configured centrally via `config.py` `GENERATE_CONFIG["3d_text"]["rotation_range"]`
- Label: the raw character content
### 3D rotation generator (threed_rotate_gen.py)
- Draws a character plus a direction marker on a disc, randomly rotated 0-359°
- Label = rotation angle (integer string)
- Filename format: `{angle}_{index:06d}.png`
### 3D slider generator (threed_slider_gen.py)
- Textured background + puzzle gap + puzzle piece on the left
- Label = gap x-coordinate offset (integer string)
- Filename format: `{offset}_{index:06d}.png`
## Training spec
### Common training config
@@ -204,7 +250,7 @@ TRAIN_CONFIG = {
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
        'synthetic_samples': 30000,  # 10000 per class
        'synthetic_samples': 50000,  # 10000 per class × 5 classes
},
'normal': {
'epochs': 50,
@@ -222,7 +268,7 @@ TRAIN_CONFIG = {
'synthetic_samples': 60000,
'loss': 'CTCLoss',
},
'threed': {
'3d_text': {
'epochs': 80,
'batch_size': 64,
'lr': 5e-4,
@@ -230,18 +276,36 @@ TRAIN_CONFIG = {
'synthetic_samples': 80000,
'loss': 'CTCLoss',
},
'3d_rotate': {
'epochs': 60,
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
'synthetic_samples': 60000,
'loss': 'SmoothL1',
},
'3d_slider': {
'epochs': 60,
'batch_size': 128,
'lr': 1e-3,
'scheduler': 'cosine',
'synthetic_samples': 60000,
'loss': 'SmoothL1',
},
}
```
### Training script requirements
Every training script must:
1. Check whether synthetic data exists; if not, call the generator automatically
2. Support mixing in real data (if data/real/{type}/ has files)
3. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
4. Log training: epoch, loss, overall accuracy, character-level accuracy
5. Save the best model to checkpoints/
6. Export ONNX to onnx_models/ automatically when training finishes
1. Set the global random seed (random/numpy/torch) via `config.RANDOM_SEED` before training starts, for reproducibility
2. Check whether synthetic data exists; if not, call the generator automatically
3. Support mixing in real data (if data/real/{type}/ has files)
4. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
5. Log training: epoch, loss, overall accuracy, plus character-level accuracy (CTC) or MAE and tolerance accuracy (regression)
6. Save the best model to checkpoints/
7. Export ONNX to onnx_models/ automatically when training finishes
8. All DataLoaders use `num_workers=0` to avoid multiprocessing compatibility issues
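Requirement 1 above (global seeding) can be sketched as a small helper; `set_global_seed` is a hypothetical name, and numpy/torch are seeded only when importable so the sketch runs anywhere:

```python
import random

def set_global_seed(seed: int) -> None:
    """Seed random, plus numpy/torch when available (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass

set_global_seed(42)
a = random.random()
set_global_seed(42)
assert a == random.random()  # same seed, same first draw
```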
### Data augmentation strategy
@@ -281,14 +345,14 @@ class CaptchaPipeline:
pass
    def classify(self, image: Image.Image) -> str:
        """Dispatch classification; returns the type name: 'normal' / 'math' / '3d'"""
        """Dispatch classification; returns the type name: 'normal' / 'math' / '3d_text' / '3d_rotate' / '3d_slider'"""
        pass
    def solve(self, image) -> str:
        """
        Full recognition flow:
        1. Classify the captcha type
        2. Route to the matching expert model
        2. Route to the matching expert model (CTC or regression)
        3. Post-process (arithmetic type needs evaluation)
        4. Return the final answer string
@@ -311,7 +375,7 @@ def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'):
pass
def export_all():
"""依次导出 classifier, normal, math, threed个模型"""
"""依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型"""
pass
```
@@ -325,22 +389,28 @@ uv sync --extra server  # with HTTP service dependencies
# Generate training data
uv run python cli.py generate --type normal --num 60000
uv run python cli.py generate --type math --num 60000
uv run python cli.py generate --type 3d --num 80000
uv run python cli.py generate --type classifier --num 30000
uv run python cli.py generate --type 3d_text --num 80000
uv run python cli.py generate --type 3d_rotate --num 60000
uv run python cli.py generate --type 3d_slider --num 60000
uv run python cli.py generate --type classifier --num 50000
# Train models
uv run python cli.py train --model classifier
uv run python cli.py train --model normal
uv run python cli.py train --model math
uv run python cli.py train --model 3d
uv run python cli.py train --model 3d_text
uv run python cli.py train --model 3d_rotate
uv run python cli.py train --model 3d_slider
uv run python cli.py train --all  # train everything in dependency order
# Export ONNX
uv run python cli.py export --all
uv run python cli.py export --model 3d_text  # "3d_text" is automatically mapped to "threed_text"
# Inference
uv run python cli.py predict image.png                  # auto classify + recognize
uv run python cli.py predict image.png --type normal    # skip classification
uv run python cli.py predict image.png --type 3d_rotate # force the rotation type
uv run python cli.py predict-dir ./test_images/         # batch recognition
# Start the HTTP service (install the optional server dependencies first)
@@ -360,11 +430,12 @@ uv run python cli.py serve --port 8080
1. **Train all models in float32 and export ONNX without quantization**; get accuracy right first
2. **CTC decoding uses greedy decoding everywhere**; no beam search needed (greedy is enough for captchas)
3. **The charset is defined centrally in config.py**: normal currently keeps confusable characters; 3d uses the de-confused charset
3. **The charset is defined centrally in config.py**: normal currently keeps confusable characters; 3d_text uses the de-confused charset
4. **Arithmetic recognition is two-step**: OCR the string first, then evaluate with rules; never let the model output the number directly
5. **Generator random seed**: set a seed when generating data for reproducibility
5. **Random seed**: both data generation and training set the global seed (random/numpy/torch) via `config.RANDOM_SEED` for reproducibility
6. **Real data filename format**: `{label}_{anything}.png`; the label part is the annotation
7. **Model save format**: PyTorch checkpoint containing model_state_dict, chars, best_acc, epoch
11. **Dataset charset filtering**: when loading labels, `CRNNDataset` emits a warning if a character falls outside the charset, to surface annotation/charset mismatches
7. **Model save format**: CTC checkpoints contain model_state_dict, chars, best_acc, epoch; regression checkpoints contain model_state_dict, label_range, best_mae, best_tol_acc, epoch
8. **No GPU-only features**; CPU must be able to both train and infer (just slower)
9. **Type extension**: to add a captcha type, just (1) add a generator, (2) add an expert model, (3) add a class to the dispatcher and retrain
10. **Doc sync**: any change to project structure, config, or architecture must be mirrored in CLAUDE.md to keep docs and code consistent
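The two checkpoint layouts from the save-format rule above, sketched as plain dicts (values are illustrative stubs; in practice `model_state_dict` holds a torch state dict):

```python
# The two checkpoint layouts, with stub values for illustration.
ctc_ckpt = {
    "model_state_dict": {},  # torch state dict in practice
    "chars": "23456789ABCDEFGHJKMNPQRSTUVWXYZ",
    "best_acc": 0.87,
    "epoch": 63,
}
regression_ckpt = {
    "model_state_dict": {},
    "label_range": (0, 360),   # from config.REGRESSION_RANGE
    "best_mae": 2.4,           # mean absolute error, in label units
    "best_tol_acc": 0.91,      # accuracy within tolerance (e.g. ±5°)
    "epoch": 41,
}
# A loader can report whichever accuracy field is present:
for ckpt in (ctc_ckpt, regression_ckpt):
    acc = ckpt.get("best_acc") or ckpt.get("best_tol_acc")
    print(acc)
```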
@@ -373,18 +444,20 @@ uv run python cli.py serve --port 8080
| Model | Accuracy target | Latency | Size |
|------|-----------|---------|---------|
| Dispatch classifier | > 99% | < 5ms | < 500KB |
| Dispatch classifier (5 classes) | > 99% | < 5ms | < 500KB |
| Plain characters | > 95% | < 30ms | < 2MB |
| Arithmetic | > 93% | < 30ms | < 2MB |
| 3D | > 85% | < 50ms | < 5MB |
| Full pipeline | - | < 80ms | < 10MB total |
| 3D text | > 85% | < 50ms | < 5MB |
| 3D rotation (±5°) | > 85% | < 30ms | ~1MB |
| 3D slider (±3px) | > 90% | < 30ms | ~1MB |
| Full pipeline | - | < 80ms | < 12MB total |
## Development order
1. Implement config.py and generators/ first
2. Implement all model definitions in models/
3. Implement the generic dataset classes in training/dataset.py
4. Train in order: normal → math → 3d → classifier
4. Train in order: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier
5. Implement inference/pipeline.py and export_onnx.py
6. Implement the unified cli.py entrypoint
7. Optional: server.py HTTP service

cli.py

@@ -3,6 +3,9 @@ CaptchaBreaker command-line entrypoint
Usage:
python cli.py generate --type normal --num 60000
python cli.py generate --type 3d_text --num 80000
python cli.py generate --type 3d_rotate --num 60000
python cli.py generate --type 3d_slider --num 60000
python cli.py train --model normal
python cli.py train --all
python cli.py export --all
@@ -20,15 +23,21 @@ from pathlib import Path
def cmd_generate(args):
"""生成训练数据。"""
from config import (
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
)
from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator
from generators import (
NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator,
ThreeDRotateGenerator, ThreeDSliderGenerator,
)
gen_map = {
"normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
"math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
"3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR),
"3d_text": (ThreeDCaptchaGenerator, SYNTHETIC_3D_TEXT_DIR),
"3d_rotate": (ThreeDRotateGenerator, SYNTHETIC_3D_ROTATE_DIR),
"3d_slider": (ThreeDSliderGenerator, SYNTHETIC_3D_SLIDER_DIR),
}
captcha_type = args.type
@@ -50,25 +59,31 @@ def cmd_generate(args):
gen = gen_cls()
gen.generate_dataset(num, str(out_dir))
else:
print(f"未知类型: {captcha_type} 可选: normal, math, 3d, classifier")
valid = ", ".join(list(gen_map.keys()) + ["classifier"])
print(f"未知类型: {captcha_type} 可选: {valid}")
sys.exit(1)
def cmd_train(args):
"""训练模型。"""
if args.all:
# 按依赖顺序: normal → math → 3d → classifier
print("按顺序训练全部模型: normal → math → 3d → classifier\n")
print("按顺序训练全部模型: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier\n")
from training.train_normal import main as train_normal
from training.train_math import main as train_math
from training.train_3d import main as train_3d
from training.train_3d_text import main as train_3d_text
from training.train_3d_rotate import main as train_3d_rotate
from training.train_3d_slider import main as train_3d_slider
from training.train_classifier import main as train_classifier
train_normal()
print("\n")
train_math()
print("\n")
train_3d()
train_3d_text()
print("\n")
train_3d_rotate()
print("\n")
train_3d_slider()
print("\n")
train_classifier()
return
@@ -78,12 +93,16 @@ def cmd_train(args):
from training.train_normal import main as train_fn
elif model == "math":
from training.train_math import main as train_fn
elif model == "3d":
from training.train_3d import main as train_fn
elif model == "3d_text":
from training.train_3d_text import main as train_fn
elif model == "3d_rotate":
from training.train_3d_rotate import main as train_fn
elif model == "3d_slider":
from training.train_3d_slider import main as train_fn
elif model == "classifier":
from training.train_classifier import main as train_fn
else:
print(f"未知模型: {model} 可选: normal, math, 3d, classifier")
print(f"未知模型: {model} 可选: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier")
sys.exit(1)
train_fn()
@@ -96,7 +115,14 @@ def cmd_export(args):
if args.all:
export_all()
elif args.model:
_load_and_export(args.model)
# alias mapping
alias = {
"3d_text": "threed_text",
"3d_rotate": "threed_rotate",
"3d_slider": "threed_slider",
}
name = alias.get(args.model, args.model)
_load_and_export(name)
else:
print("请指定 --all 或 --model <name>")
sys.exit(1)
@@ -137,19 +163,19 @@ def cmd_predict_dir(args):
sys.exit(1)
print(f"批量识别: {len(images)} 张图片\n")
print(f"{'文件名':<30} {'类型':<8} {'结果':<15} {'耗时(ms)':>8}")
print("-" * 65)
print(f"{'文件名':<30} {'类型':<10} {'结果':<15} {'耗时(ms)':>8}")
print("-" * 67)
total_ms = 0.0
for img_path in images:
result = pipeline.solve(str(img_path), captcha_type=args.type)
total_ms += result["time_ms"]
print(
f"{img_path.name:<30} {result['type']:<8} "
f"{img_path.name:<30} {result['type']:<10} "
f"{result['result']:<15} {result['time_ms']:>8.1f}"
)
print("-" * 65)
print("-" * 67)
print(f"总计: {len(images)} 张 平均: {total_ms / len(images):.1f} ms 总耗时: {total_ms:.1f} ms")
@@ -178,28 +204,43 @@ def main():
    # ---- generate ----
    p_gen = subparsers.add_parser("generate", help="generate training data")
    p_gen.add_argument("--type", required=True, help="captcha type: normal, math, 3d, classifier")
    p_gen.add_argument(
        "--type", required=True,
        help="captcha type: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
    )
    p_gen.add_argument("--num", type=int, required=True, help="number of samples to generate")
    # ---- train ----
    p_train = subparsers.add_parser("train", help="train models")
    p_train.add_argument("--model", help="model name: normal, math, 3d, classifier")
    p_train.add_argument(
        "--model",
        help="model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
    )
    p_train.add_argument("--all", action="store_true", help="train all models in dependency order")
    # ---- export ----
    p_export = subparsers.add_parser("export", help="export ONNX models")
    p_export.add_argument("--model", help="model name: normal, math, 3d, classifier, threed")
    p_export.add_argument(
        "--model",
        help="model name: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier",
    )
    p_export.add_argument("--all", action="store_true", help="export all models")
    # ---- predict ----
    p_pred = subparsers.add_parser("predict", help="recognize a single captcha")
    p_pred.add_argument("image", help="image path")
    p_pred.add_argument("--type", default=None, help="specify the type and skip classification: normal, math, 3d")
    p_pred.add_argument(
        "--type", default=None,
        help="specify the type and skip classification: normal, math, 3d_text, 3d_rotate, 3d_slider",
    )
    # ---- predict-dir ----
    p_pdir = subparsers.add_parser("predict-dir", help="batch-recognize captchas in a directory")
    p_pdir.add_argument("directory", help="image directory path")
    p_pdir.add_argument("--type", default=None, help="specify the type and skip classification: normal, math, 3d")
    p_pdir.add_argument(
        "--type", default=None,
        help="specify the type and skip classification: normal, math, 3d_text, 3d_rotate, 3d_slider",
    )
    # ---- serve ----
    p_serve = subparsers.add_parser("serve", help="start the HTTP recognition service")


@@ -23,12 +23,16 @@ CLASSIFIER_DIR = DATA_DIR / "classifier"
# synthetic data subdirectories
SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal"
SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math"
SYNTHETIC_3D_DIR = SYNTHETIC_DIR / "3d"
SYNTHETIC_3D_TEXT_DIR = SYNTHETIC_DIR / "3d_text"
SYNTHETIC_3D_ROTATE_DIR = SYNTHETIC_DIR / "3d_rotate"
SYNTHETIC_3D_SLIDER_DIR = SYNTHETIC_DIR / "3d_slider"
# real data subdirectories
REAL_NORMAL_DIR = REAL_DIR / "normal"
REAL_MATH_DIR = REAL_DIR / "math"
REAL_3D_DIR = REAL_DIR / "3d"
REAL_3D_TEXT_DIR = REAL_DIR / "3d_text"
REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate"
REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider"
# ============================================================
# model output directories
@@ -38,8 +42,10 @@ ONNX_DIR = PROJECT_ROOT / "onnx_models"
# ensure key directories exist
for _dir in [
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_DIR,
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR,
REAL_NORMAL_DIR, REAL_MATH_DIR,
REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR,
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
]:
_dir.mkdir(parents=True, exist_ok=True)
@@ -57,7 +63,7 @@ MATH_CHARS = "0123456789+-×÷=?"
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ"
# captcha type list (dispatch classifier output)
CAPTCHA_TYPES = ["normal", "math", "3d"]
CAPTCHA_TYPES = ["normal", "math", "3d_text", "3d_rotate", "3d_slider"]
NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES)
# ============================================================
@@ -67,7 +73,9 @@ IMAGE_SIZE = {
"classifier": (64, 128), # 调度分类器输入
"normal": (40, 120), # 普通字符识别
"math": (40, 160), # 算式识别 (更宽)
"3d": (60, 160), # 3D 立体识别
"3d_text": (60, 160), # 3D 立体文字识别
"3d_rotate": (80, 80), # 3D 旋转角度回归 (正方形)
"3d_slider": (80, 240), # 3D 滑块偏移回归
}
# ============================================================
@@ -91,11 +99,25 @@ GENERATE_CONFIG = {
"rotation_range": (-10, 10),
"noise_line_range": (2, 4),
},
"3d": {
"3d_text": {
"char_count_range": (4, 5),
"image_size": (160, 60), # 生成图片尺寸 (W, H)
"shadow_offset": (3, 3), # 阴影偏移
"perspective_intensity": 0.3, # 透视变换强度
"rotation_range": (-20, 20), # 字符旋转角度
},
"3d_rotate": {
"image_size": (80, 80), # 生成图片尺寸 (W, H)
"disc_radius": 35, # 圆盘半径
"marker_size": 8, # 方向标记大小
"bg_color_range": (200, 240), # 背景色范围
},
"3d_slider": {
"image_size": (240, 80), # 生成图片尺寸 (W, H)
"puzzle_size": (40, 40), # 拼图块大小 (W, H)
"gap_x_range": (50, 200), # 缺口 x 坐标范围
"piece_left_margin": 5, # 拼图块左侧留白
"bg_noise_intensity": 30, # 背景纹理噪声强度
},
}
@@ -108,7 +130,7 @@ TRAIN_CONFIG = {
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 30000, # 每类 10000
"synthetic_samples": 50000, # 每类 10000 × 5 类
"val_split": 0.1, # 验证集比例
},
"normal": {
@@ -129,7 +151,7 @@ TRAIN_CONFIG = {
"loss": "CTCLoss",
"val_split": 0.1,
},
"threed": {
"3d_text": {
"epochs": 80,
"batch_size": 64,
"lr": 5e-4,
@@ -138,6 +160,24 @@ TRAIN_CONFIG = {
"loss": "CTCLoss",
"val_split": 0.1,
},
"3d_rotate": {
"epochs": 60,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "SmoothL1",
"val_split": 0.1,
},
"3d_slider": {
"epochs": 60,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "SmoothL1",
"val_split": 0.1,
},
}
# ============================================================
@@ -163,6 +203,14 @@ ONNX_CONFIG = {
"dynamic_batch": True, # 支持动态 batch size
}
# ============================================================
# Regression model label ranges
# ============================================================
REGRESSION_RANGE = {
    "3d_rotate": (0, 360),   # rotation angle 0-359°
    "3d_slider": (10, 200),  # slider x offset (pixels)
}
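A regression model's sigmoid output is mapped back to label units with the (lo, hi) pair from this table. A minimal round-trip sketch (the `denormalize` helper name is illustrative):

```python
# Denormalize a sigmoid output in [0,1] back to label units using the
# (lo, hi) range; mirrors REGRESSION_RANGE. `denormalize` is a hypothetical helper.
REGRESSION_RANGE = {
    "3d_rotate": (0, 360),   # rotation angle, degrees
    "3d_slider": (10, 200),  # slider x offset, pixels
}

def denormalize(pred: float, captcha_type: str) -> int:
    lo, hi = REGRESSION_RANGE[captcha_type]
    return round(lo + pred * (hi - lo))

print(denormalize(0.375, "3d_rotate"))  # -> 135
print(denormalize(0.5, "3d_slider"))    # -> 105
```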
# ============================================================
# inference config
# ============================================================


@@ -1,20 +1,26 @@
"""
数据生成器包
提供种验证码类型的数据生成器:
提供种验证码类型的数据生成器:
- NormalCaptchaGenerator: 普通字符验证码
- MathCaptchaGenerator: 算式验证码
- ThreeDCaptchaGenerator: 3D 立体验证码
- ThreeDCaptchaGenerator: 3D 立体文字验证码
- ThreeDRotateGenerator: 3D 旋转验证码
- ThreeDSliderGenerator: 3D 滑块验证码
"""
from generators.base import BaseCaptchaGenerator
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
__all__ = [
"BaseCaptchaGenerator",
"NormalCaptchaGenerator",
"MathCaptchaGenerator",
"ThreeDCaptchaGenerator",
"ThreeDRotateGenerator",
"ThreeDSliderGenerator",
]


@@ -45,7 +45,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d"]
self.cfg = GENERATE_CONFIG["3d_text"]
self.chars = THREED_CHARS
self.width, self.height = self.cfg["image_size"]
@@ -154,7 +154,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
char_img = self._perspective_transform(char_img, rng)
# random rotation
angle = rng.randint(-20, 20)
angle = rng.randint(*self.cfg["rotation_range"])
char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)
# paste onto the canvas


@@ -0,0 +1,122 @@
"""
3D 旋转验证码生成器
生成旋转验证码:圆盘上绘制字符 + 方向标记,随机旋转 0-359°。
用户需将圆盘旋转到正确角度。
标签 = 旋转角度(整数)
文件名格式: {angle}_{index:06d}.png
"""
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import GENERATE_CONFIG, THREED_CHARS
from generators.base import BaseCaptchaGenerator
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
]
_DISC_COLORS = [
(180, 200, 220),
(200, 220, 200),
(220, 200, 190),
(200, 200, 220),
(210, 210, 200),
]
class ThreeDRotateGenerator(BaseCaptchaGenerator):
"""3D 旋转验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d_rotate"]
self.chars = THREED_CHARS
self.width, self.height = self.cfg["image_size"]
self._fonts: list[str] = []
for p in _FONT_PATHS:
try:
ImageFont.truetype(p, 20)
self._fonts.append(p)
except OSError:
continue
if not self._fonts:
raise RuntimeError("未找到任何可用字体,无法生成验证码")
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
# 随机旋转角度 0-359
angle = rng.randint(0, 359)
if text is None:
text = str(angle)
# 1. 背景
bg_val = rng.randint(*self.cfg["bg_color_range"])
img = Image.new("RGB", (self.width, self.height), (bg_val, bg_val, bg_val))
draw = ImageDraw.Draw(img)
# 2. 绘制圆盘
cx, cy = self.width // 2, self.height // 2
r = self.cfg["disc_radius"]
disc_color = rng.choice(_DISC_COLORS)
draw.ellipse(
[cx - r, cy - r, cx + r, cy + r],
fill=disc_color, outline=(100, 100, 100), width=2,
)
# 3. 在圆盘上绘制字符和方向标记 (未旋转状态)
disc_img = Image.new("RGBA", (r * 2 + 4, r * 2 + 4), (0, 0, 0, 0))
disc_draw = ImageDraw.Draw(disc_img)
dc = r + 2 # disc center
# 字符 (圆盘中心)
font_path = rng.choice(self._fonts)
font_size = int(r * 0.6)
font = ImageFont.truetype(font_path, font_size)
ch = rng.choice(self.chars)
bbox = font.getbbox(ch)
tw = bbox[2] - bbox[0]
th = bbox[3] - bbox[1]
disc_draw.text(
(dc - tw // 2 - bbox[0], dc - th // 2 - bbox[1]),
ch, fill=(50, 50, 50, 255), font=font,
)
# 方向标记 (三角箭头,指向上方)
ms = self.cfg["marker_size"]
marker_y = dc - r + ms + 2
disc_draw.polygon(
[(dc, marker_y - ms), (dc - ms // 2, marker_y), (dc + ms // 2, marker_y)],
fill=(220, 60, 60, 255),
)
# 4. 旋转圆盘内容
disc_img = disc_img.rotate(-angle, resample=Image.BICUBIC, expand=False)
# 5. 粘贴到背景
paste_x = cx - dc
paste_y = cy - dc
img.paste(disc_img, (paste_x, paste_y), disc_img)
# 6. 添加少量噪点
for _ in range(rng.randint(20, 50)):
nx, ny = rng.randint(0, self.width - 1), rng.randint(0, self.height - 1)
nc = tuple(rng.randint(100, 200) for _ in range(3))
draw.point((nx, ny), fill=nc)
# 7. 轻微模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.5))
return img, text


@@ -0,0 +1,113 @@
"""
3D 滑块验证码生成器
生成滑块拼图验证码:纹理背景 + 拼图缺口 + 拼图块在左侧。
用户需将拼图块滑动到缺口位置。
标签 = 缺口 x 坐标偏移(整数)
文件名格式: {offset}_{index:06d}.png
"""
import random
from PIL import Image, ImageDraw, ImageFilter
from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator
class ThreeDSliderGenerator(BaseCaptchaGenerator):
"""3D 滑块验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d_slider"]
self.width, self.height = self.cfg["image_size"]
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
pw, ph = self.cfg["puzzle_size"]
gap_x_lo, gap_x_hi = self.cfg["gap_x_range"]
# 缺口位置
gap_x = rng.randint(gap_x_lo, gap_x_hi)
gap_y = rng.randint(10, self.height - ph - 10)
if text is None:
text = str(gap_x)
# 1. 生成纹理背景
img = self._textured_background(rng)
# 2. 从缺口位置截取拼图块内容
piece_content = img.crop((gap_x, gap_y, gap_x + pw, gap_y + ph)).copy()
# 3. 绘制缺口 (半透明灰色区域)
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
overlay_draw = ImageDraw.Draw(overlay)
overlay_draw.rectangle(
[gap_x, gap_y, gap_x + pw, gap_y + ph],
fill=(80, 80, 80, 160),
outline=(60, 60, 60, 200),
width=2,
)
img = img.convert("RGBA")
img = Image.alpha_composite(img, overlay)
img = img.convert("RGB")
# 4. 绘制拼图块在左侧
piece_x = self.cfg["piece_left_margin"]
piece_img = Image.new("RGBA", (pw + 4, ph + 4), (0, 0, 0, 0))
piece_draw = ImageDraw.Draw(piece_img)
# 阴影
piece_draw.rectangle([2, 2, pw + 3, ph + 3], fill=(0, 0, 0, 80))
# 内容
piece_img.paste(piece_content, (0, 0))
# 边框
piece_draw.rectangle([0, 0, pw - 1, ph - 1], outline=(255, 255, 255, 200), width=2)
img_rgba = img.convert("RGBA")
img_rgba.paste(piece_img, (piece_x, gap_y), piece_img)
img = img_rgba.convert("RGB")
# 5. 轻微模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.3))
return img, text
def _textured_background(self, rng: random.Random) -> Image.Image:
"""生成带纹理的彩色背景。"""
img = Image.new("RGB", (self.width, self.height))
draw = ImageDraw.Draw(img)
# 渐变底色
base_r, base_g, base_b = rng.randint(100, 180), rng.randint(100, 180), rng.randint(100, 180)
for y in range(self.height):
ratio = y / max(self.height - 1, 1)
r = int(base_r + 30 * ratio)
g = int(base_g - 20 * ratio)
b = int(base_b + 10 * ratio)
draw.line([(0, y), (self.width, y)], fill=(r, g, b))
# 添加纹理噪声
noise_intensity = self.cfg["bg_noise_intensity"]
for _ in range(self.width * self.height // 8):
x = rng.randint(0, self.width - 1)
y = rng.randint(0, self.height - 1)
pixel = img.getpixel((x, y))
noise = tuple(
max(0, min(255, c + rng.randint(-noise_intensity, noise_intensity)))
for c in pixel
)
draw.point((x, y), fill=noise)
# 随机色块 (模拟图案)
for _ in range(rng.randint(3, 6)):
x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
x2, y2 = x1 + rng.randint(15, 40), y1 + rng.randint(10, 25)
color = tuple(rng.randint(60, 220) for _ in range(3))
draw.rectangle([x1, y1, x2, y2], fill=color)
return img


@@ -17,10 +17,12 @@ from config import (
MATH_CHARS,
THREED_CHARS,
NUM_CAPTCHA_TYPES,
REGRESSION_RANGE,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
def export_model(
@@ -34,7 +36,7 @@ def export_model(
Args:
model: PyTorch model with loaded weights
model_name: model name (classifier / normal / math / threed)
model_name: model name (classifier / normal / math / threed_text / threed_rotate / threed_slider)
input_shape: input shape (C, H, W)
onnx_dir: output directory (defaults to config.ONNX_DIR)
"""
@@ -50,7 +52,7 @@ def export_model(
dummy = torch.randn(1, *input_shape)
# the classifier and the recognizers use different dynamic_axes
if model_name == "classifier":
if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"):
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
else:
# CTC models: output shape = (T, B, C)
@@ -78,7 +80,8 @@ def _load_and_export(model_name: str):
return
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(f"[Load] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")
acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?')
print(f"[Load] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}")
if model_name == "classifier":
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
@@ -94,11 +97,19 @@ def _load_and_export(model_name: str):
h, w = IMAGE_SIZE["math"]
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "threed":
elif model_name == "threed_text":
chars = ckpt.get("chars", THREED_CHARS)
h, w = IMAGE_SIZE["3d"]
h, w = IMAGE_SIZE["3d_text"]
model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "threed_rotate":
h, w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "threed_slider":
h, w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=h, img_w=w)
input_shape = (1, h, w)
else:
print(f"[Error] unknown model: {model_name}")
return
@@ -108,11 +119,11 @@ def _load_and_export(model_name: str):
def export_all():
"""Export the classifier, normal, math, threed models in turn."""
"""Export all six models in turn: classifier, normal, math, threed_text, threed_rotate, threed_slider."""
print("=" * 50)
print("Exporting all ONNX models")
print("=" * 50)
for name in ["classifier", "normal", "math", "threed"]:
for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]:
_load_and_export(name)
print("\nAll exports complete.")

View File

@@ -4,7 +4,7 @@
Loads the routing model and all expert models (ONNX format) and provides a unified solve(image) interface.
Inference flow:
input image → preprocess → routing classifier → route to expert model → CTC decode → post-process → output
input image → preprocess → routing classifier → route to expert model → CTC decode / regression scaling → post-process → output
For the math type, math_eval is also called after decoding to compute the result.
"""
@@ -23,6 +23,7 @@ from config import (
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
REGRESSION_RANGE,
)
from inference.math_eval import eval_captcha_math
@@ -59,19 +60,24 @@ class CaptchaPipeline:
self.mean = INFERENCE_CONFIG["normalize_mean"]
self.std = INFERENCE_CONFIG["normalize_std"]
# charset mapping
# charset mapping (only CTC models need this)
self._chars = {
"normal": NORMAL_CHARS,
"math": MATH_CHARS,
"3d": THREED_CHARS,
"3d_text": THREED_CHARS,
}
# regression model types
self._regression_types = {"3d_rotate", "3d_slider"}
# expert model name → ONNX filename
self._model_files = {
"classifier": "classifier.onnx",
"normal": "normal.onnx",
"math": "math.onnx",
"3d": "threed.onnx",
"3d_text": "threed_text.onnx",
"3d_rotate": "threed_rotate.onnx",
"3d_slider": "threed_slider.onnx",
}
# load all available models
@@ -116,7 +122,7 @@ class CaptchaPipeline:
def classify(self, image: Image.Image) -> str:
"""
Routing classification; returns the type name: 'normal' / 'math' / '3d'.
Routing classification; returns the type name.
Raises:
RuntimeError: classifier model not loaded
@@ -141,7 +147,7 @@ class CaptchaPipeline:
Args:
image: PIL.Image, a file path (str/Path), or bytes
captcha_type: pass a type to skip classification ('normal'/'math'/'3d')
captcha_type: pass a type to skip classification ('normal'/'math'/'3d_text'/'3d_rotate'/'3d_slider')
Returns:
dict: {
@@ -166,22 +172,32 @@ class CaptchaPipeline:
f"Expert model '{captcha_type}' is not loaded; train and export the corresponding ONNX model first"
)
size_key = captcha_type # "normal"/"math"/"3d"
size_key = captcha_type
inp = self.preprocess(img, IMAGE_SIZE[size_key])
session = self._sessions[captcha_type]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
# 4. CTC greedy decode
# 4. Branch: CTC decode vs regression
if captcha_type in self._regression_types:
# regression model: outputs a (batch, 1) sigmoid value
output = session.run(None, {input_name: inp})[0] # (1, 1)
sigmoid_val = float(output[0, 0])
lo, hi = REGRESSION_RANGE[captcha_type]
real_val = sigmoid_val * (hi - lo) + lo
raw_text = f"{real_val:.1f}"
result = str(int(round(real_val)))
else:
# CTC model
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
chars = self._chars[captcha_type]
raw_text = self._ctc_greedy_decode(logits, chars)
# 5. Post-process
# Post-process
if captcha_type == "math":
try:
result = eval_captcha_math(raw_text)
except ValueError:
result = raw_text # fall back to the raw text if parsing fails
result = raw_text
else:
result = raw_text
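The regression branch maps the model's sigmoid output back onto the real label range before rounding. A standalone sketch of that scaling (the (lo, hi) ranges here are illustrative stand-ins for REGRESSION_RANGE entries):

```python
def scale_regression_output(sigmoid_val, label_range):
    """Map a [0, 1] sigmoid output back to the original label range and round."""
    lo, hi = label_range
    real_val = sigmoid_val * (hi - lo) + lo
    return int(round(real_val))

angle = scale_regression_output(0.5, (0.0, 360.0))    # rotation angle, mid-range
offset = scale_regression_output(0.25, (0.0, 200.0))  # slider offset, assumed range
```

Because training normalizes labels with the same (lo, hi) pair, this inverse mapping is exact up to the rounding step.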

View File

@@ -1,18 +1,21 @@
"""
Model definition package
Provides three models:
Provides four models:
- CaptchaClassifier: routing classifier (lightweight CNN, < 500KB)
- LiteCRNN: lightweight CRNN (normal characters + math expressions, < 2MB)
- ThreeDCNN: dedicated 3D captcha model (ResNet-lite + BiLSTM, < 5MB)
- ThreeDCNN: dedicated 3D text captcha model (ResNet-lite + BiLSTM, < 5MB)
- RegressionCNN: regression CNN (3D rotation + slider, ~1MB)
"""
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
__all__ = [
"CaptchaClassifier",
"LiteCRNN",
"ThreeDCNN",
"RegressionCNN",
]

models/regression_cnn.py Normal file
View File

@@ -0,0 +1,86 @@
"""
Regression CNN model
Regression model shared by 3d_rotate and 3d_slider.
The sigmoid output is normalized to [0, 1]; inference scales it back to the original range via label_range.
Architecture:
Conv(1→32) + BN + ReLU + Pool
Conv(32→64) + BN + ReLU + Pool
Conv(64→128) + BN + ReLU + Pool
Conv(128→128) + BN + ReLU + Pool
AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid
About 250K parameters, ~1MB.
"""
import torch
import torch.nn as nn
class RegressionCNN(nn.Module):
"""
Lightweight regression CNN for 3d_rotate (angle) and 3d_slider (offset) prediction.
Outputs a sigmoid value in [0, 1]; scale it by label_range to recover the actual range.
"""
def __init__(self, img_h: int = 80, img_w: int = 80):
"""
Args:
img_h: input image height
img_w: input image width
"""
super().__init__()
self.img_h = img_h
self.img_w = img_w
self.features = nn.Sequential(
# block 1: 1 → 32, H/2, W/2
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 2: 32 → 64, H/4, W/4
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 3: 64 → 128, H/8, W/8
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 4: 128 → 128, H/16, W/16
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
)
self.pool = nn.AdaptiveAvgPool2d(1)
self.regressor = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(64, 1),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, 1, H, W) grayscale image
Returns:
output: (batch, 1) sigmoid output in [0, 1]
"""
feat = self.features(x)
feat = self.pool(feat) # (B, 128, 1, 1)
feat = feat.flatten(1) # (B, 128)
out = self.regressor(feat) # (B, 1)
return out
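The "about 250K parameters" figure in the docstring can be checked by hand from the layer sizes. A quick count (convolutions are bias-free in this model, and each BatchNorm2d contributes a learnable weight and bias per channel; running statistics are buffers, not parameters):

```python
def conv_params(cin, cout, k=3):
    # bias=False in the model, so only the kernel weights count
    return cin * cout * k * k

def bn_params(c):
    # BatchNorm2d: learnable weight + bias per channel
    return 2 * c

total = (
    conv_params(1, 32) + bn_params(32)
    + conv_params(32, 64) + bn_params(64)
    + conv_params(64, 128) + bn_params(128)
    + conv_params(128, 128) + bn_params(128)
    + 128 * 64 + 64   # FC 128 → 64 (weights + bias)
    + 64 * 1 + 1      # FC 64 → 1 (weights + bias)
)
```

This sums to roughly 249K trainable parameters, consistent with the ~1MB float32 size claimed above.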

View File

@@ -1,10 +1,13 @@
"""
Training script package
- dataset.py: CRNNDataset / CaptchaDataset general-purpose dataset classes
- dataset.py: CRNNDataset / CaptchaDataset / RegressionDataset general-purpose dataset classes
- train_utils.py: shared CTC training logic (train_ctc_model)
- train_regression_utils.py: shared regression training logic (train_regression_model)
- train_normal.py: trains normal character recognition (LiteCRNN - normal)
- train_math.py: trains math expression recognition (LiteCRNN - math)
- train_3d.py: trains 3D recognition (ThreeDCNN)
- train_3d_text.py: trains 3D text recognition (ThreeDCNN)
- train_3d_rotate.py: trains 3D rotation regression (RegressionCNN)
- train_3d_slider.py: trains 3D slider regression (RegressionCNN)
- train_classifier.py: trains the routing classifier (CaptchaClassifier)
"""

View File

@@ -1,16 +1,19 @@
"""
General-purpose Dataset classes
Provides two dataset classes:
Provides three dataset classes:
- CaptchaDataset: for classifier training (image → class label)
- CRNNDataset: for CRNN/CTC recognition training (image → encoded character sequence)
- RegressionDataset: for regression model training (image → numeric label in [0, 1])
Filename convention: {label}_{anything}.png
- classifier: label can be any characters; the containing subdirectory name is the class
- recognizer: label is the annotation itself (e.g. "A3B8" or "3+8")
- regressor: label is a numeric value (e.g. "135" or "87")
"""
import os
import warnings
from pathlib import Path
from PIL import Image
@@ -98,7 +101,15 @@ class CRNNDataset(Dataset):
img = self.transform(img)
# encode the label as an integer sequence
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
target = []
for c in label:
if c in self.char_to_idx:
target.append(self.char_to_idx[c])
else:
warnings.warn(
f"Label '{label}' contains the out-of-charset character '{c}', skipped (file: {path})",
stacklevel=2,
)
return img, target, label
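The out-of-charset handling above can be exercised in isolation. A self-contained sketch of the same encode-and-warn loop (the char_to_idx mapping here is a toy example; real charsets come from config):

```python
import warnings

def encode_label(label, char_to_idx, path="<mem>"):
    """Encode a label string as integer indices, warning on unknown characters."""
    target = []
    for c in label:
        if c in char_to_idx:
            target.append(char_to_idx[c])
        else:
            warnings.warn(
                f"Label '{label}' contains out-of-charset character '{c}' (file: {path})",
                stacklevel=2,
            )
    return target
```

Warning instead of raising keeps training alive when a stray file slips into the data directory, while still surfacing the mislabeled sample in the logs.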
@staticmethod
@@ -119,7 +130,7 @@ class CaptchaDataset(Dataset):
"""
Classifier training dataset.
Each subdirectory name is a class name (e.g. "normal", "math", "3d");
Each subdirectory name is a class name (e.g. "normal", "math", "3d_text");
all .png files inside the directory belong to that class.
"""
@@ -157,3 +168,59 @@ class CaptchaDataset(Dataset):
if self.transform:
img = self.transform(img)
return img, label
# ============================================================
# Dataset for regression models
# ============================================================
class RegressionDataset(Dataset):
"""
Regression model dataset (3d_rotate / 3d_slider).
Reads {value}_{xxx}.png files from the given directories,
parsing value as a float and normalizing it to [0, 1].
"""
def __init__(
self,
dirs: list[str | Path],
label_range: tuple[float, float],
transform: transforms.Compose | None = None,
):
"""
Args:
dirs: list of data directories
label_range: (min_val, max_val) original label range
transform: image preprocessing/augmentation
"""
self.label_range = label_range
self.lo, self.hi = label_range
self.transform = transform
self.samples: list[tuple[str, float]] = [] # (file path, normalized label)
for d in dirs:
d = Path(d)
if not d.exists():
continue
for f in sorted(d.glob("*.png")):
raw_label = f.stem.rsplit("_", 1)[0]
try:
value = float(raw_label)
except ValueError:
continue
# normalize to [0, 1]
norm = (value - self.lo) / max(self.hi - self.lo, 1e-6)
norm = max(0.0, min(1.0, norm))
self.samples.append((str(f), norm))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
import torch
path, label = self.samples[idx]
img = Image.open(path).convert("RGB")
if self.transform:
img = self.transform(img)
return img, torch.tensor([label], dtype=torch.float32)
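The filename-to-label path above reduces to a pure function. A sketch of how a stem such as "135_0007" becomes a normalized regression target (returning None stands in for the dataset's skip-on-parse-failure behavior):

```python
def parse_regression_label(stem, lo, hi):
    """Parse a '{value}_{anything}' stem into a [0, 1] normalized label, or None."""
    raw = stem.rsplit("_", 1)[0]
    try:
        value = float(raw)
    except ValueError:
        return None  # unparseable stem: the dataset silently skips the file
    norm = (value - lo) / max(hi - lo, 1e-6)
    return max(0.0, min(1.0, norm))  # clamp out-of-range labels
```

rsplit from the right means a value like "13.5" (containing no underscore) and a stem like "135_aug_0007" both resolve correctly, since only the final suffix is stripped.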

View File

@@ -0,0 +1,38 @@
"""
Train the 3D rotation captcha regression model (RegressionCNN)
Usage: python -m training.train_3d_rotate
"""
from config import (
IMAGE_SIZE,
SYNTHETIC_3D_ROTATE_DIR,
REAL_3D_ROTATE_DIR,
)
from generators.threed_rotate_gen import ThreeDRotateGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model
def main():
img_h, img_w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=img_h, img_w=img_w)
print("=" * 60)
print("Training the 3D rotation captcha regression model (RegressionCNN)")
print(f" Input size: {img_h}×{img_w}")
print(" Task: predict rotation angle 0-359°")
print("=" * 60)
train_regression_model(
model_name="threed_rotate",
model=model,
synthetic_dir=SYNTHETIC_3D_ROTATE_DIR,
real_dir=REAL_3D_ROTATE_DIR,
generator_cls=ThreeDRotateGenerator,
config_key="3d_rotate",
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
"""
Train the 3D slider captcha regression model (RegressionCNN)
Usage: python -m training.train_3d_slider
"""
from config import (
IMAGE_SIZE,
SYNTHETIC_3D_SLIDER_DIR,
REAL_3D_SLIDER_DIR,
)
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model
def main():
img_h, img_w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=img_h, img_w=img_w)
print("=" * 60)
print("Training the 3D slider captcha regression model (RegressionCNN)")
print(f" Input size: {img_h}×{img_w}")
print(" Task: predict the slider offset x coordinate")
print("=" * 60)
train_regression_model(
model_name="threed_slider",
model=model,
synthetic_dir=SYNTHETIC_3D_SLIDER_DIR,
real_dir=REAL_3D_SLIDER_DIR,
generator_cls=ThreeDSliderGenerator,
config_key="3d_slider",
)
if __name__ == "__main__":
main()

View File

@@ -1,14 +1,14 @@
"""
Train the 3D captcha recognition model (ThreeDCNN)
Train the 3D text captcha recognition model (ThreeDCNN)
Usage: python -m training.train_3d
Usage: python -m training.train_3d_text
"""
from config import (
THREED_CHARS,
IMAGE_SIZE,
SYNTHETIC_3D_DIR,
REAL_3D_DIR,
SYNTHETIC_3D_TEXT_DIR,
REAL_3D_TEXT_DIR,
)
from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN
@@ -16,23 +16,23 @@ from training.train_utils import train_ctc_model
def main():
img_h, img_w = IMAGE_SIZE["3d"]
img_h, img_w = IMAGE_SIZE["3d_text"]
model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)
print("=" * 60)
print("Training the 3D captcha recognition model (ThreeDCNN)")
print("Training the 3D text captcha recognition model (ThreeDCNN)")
print(f" Charset: {THREED_CHARS} ({len(THREED_CHARS)} characters)")
print(f" Input size: {img_h}×{img_w}")
print("=" * 60)
train_ctc_model(
model_name="threed",
model_name="threed_text",
model=model,
chars=THREED_CHARS,
synthetic_dir=SYNTHETIC_3D_DIR,
real_dir=REAL_3D_DIR,
synthetic_dir=SYNTHETIC_3D_TEXT_DIR,
real_dir=REAL_3D_TEXT_DIR,
generator_cls=ThreeDCaptchaGenerator,
config_key="threed",
config_key="3d_text",
)

View File

@@ -1,16 +1,18 @@
"""
Train the routing classifier (CaptchaClassifier)
Mixes samples from every captcha type to train a classifier that distinguishes normal / math / 3d.
Mixes samples from every captcha type to train a classifier that distinguishes normal / math / 3d_text / 3d_rotate / 3d_slider.
Data source: data/classifier/ directory (organized into per-type subdirectories)
Usage: python -m training.train_classifier
"""
import os
import random
import shutil
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
@@ -24,15 +26,20 @@ from config import (
CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_DIR,
SYNTHETIC_3D_TEXT_DIR,
SYNTHETIC_3D_ROTATE_DIR,
SYNTHETIC_3D_SLIDER_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
RANDOM_SEED,
get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
@@ -52,7 +59,9 @@ def _prepare_classifier_data():
type_info = [
("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator),
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
]
for cls_name, syn_dir, gen_cls in type_info:
@@ -95,6 +104,13 @@ def main():
img_h, img_w = IMAGE_SIZE["classifier"]
device = get_device()
# set random seeds
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(RANDOM_SEED)
print("=" * 60)
print("Training the routing classifier (CaptchaClassifier)")
print(f" Classes: {CAPTCHA_TYPES}")
@@ -128,11 +144,11 @@ def main():
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, pin_memory=True,
num_workers=0, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, pin_memory=True,
num_workers=0, pin_memory=True,
)
print(f"[Data] train: {train_size} val: {val_size}")

View File

@@ -0,0 +1,264 @@
"""
Shared regression training logic
Provides train_regression_model(), shared by train_3d_rotate / train_3d_slider.
Responsibilities:
1. Check for synthetic data; call the generator automatically if missing
2. Build RegressionDataset / DataLoader (mixing in real data)
3. Regression training loop + cosine scheduler
4. Log output: epoch, loss, MAE, tolerance accuracy
5. Save the best model to checkpoints/
6. Export ONNX when training finishes
"""
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
REGRESSION_RANGE,
RANDOM_SEED,
get_device,
)
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
"""Set global random seeds so training is reproducible."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Circular-distance SmoothL1 loss for angle regression (handles the 0°/360° boundary).
Both pred and target are in the [0, 1] range.
"""
diff = torch.abs(pred - target)
# circular distance: min(|d|, 1-|d|)
diff = torch.min(diff, 1.0 - diff)
# SmoothL1
return torch.where(
diff < 1.0 / 360.0, # beta ≈ 1°, normalized
0.5 * diff * diff / (1.0 / 360.0),
diff - 0.5 * (1.0 / 360.0),
).mean()
def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
"""Circular MAE (in normalized space)."""
diff = np.abs(pred - target)
diff = np.minimum(diff, 1.0 - diff)
return float(np.mean(diff))
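Both circular helpers above reduce to the same distance on a unit circle: take the absolute difference, then wrap it around 1.0. In plain Python, for two normalized angles:

```python
def circular_distance(a, b):
    """Distance between two values in [0, 1) treated as points on a circle."""
    d = abs(a - b)
    return min(d, 1.0 - d)

# 0.98 and 0.02 (i.e. ~353° and ~7°) are only 0.04 apart across
# the 0/1 boundary, not 0.96 apart as a naive |a - b| would claim
d = circular_distance(0.98, 0.02)
```

Without this wrap, a prediction of 359° against a 1° target would be penalized as a near-maximal error, which is exactly the failure mode the circular SmoothL1 loss avoids.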
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
"""Export the model to ONNX format."""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
dummy = torch.randn(1, 1, img_h, img_w)
torch.onnx.export(
model.cpu(),
dummy,
str(onnx_path),
opset_version=ONNX_CONFIG["opset_version"],
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
if ONNX_CONFIG["dynamic_batch"]
else None,
)
print(f"[ONNX] Export complete: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
def train_regression_model(
model_name: str,
model: nn.Module,
synthetic_dir: str | Path,
real_dir: str | Path,
generator_cls,
config_key: str,
):
"""
Generic regression training flow.
Args:
model_name: model name (used for saved files: threed_rotate / threed_slider)
model: PyTorch model instance (RegressionCNN)
synthetic_dir: synthetic data directory
real_dir: real data directory
generator_cls: generator class (used to auto-generate data)
config_key: key into TRAIN_CONFIG / REGRESSION_RANGE (3d_rotate / 3d_slider)
"""
cfg = TRAIN_CONFIG[config_key]
img_h, img_w = IMAGE_SIZE[config_key]
label_range = REGRESSION_RANGE[config_key]
lo, hi = label_range
is_circular = config_key == "3d_rotate"
device = get_device()
# tolerance configuration
if config_key == "3d_rotate":
tolerance = 5.0 # ±5°
else:
tolerance = 3.0 # ±3px
_set_seed()
# ---- 1. Check / generate synthetic data ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[Data] Synthetic data insufficient ({len(existing)}/{cfg['synthetic_samples']}), generating...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
else:
print(f"[Data] Synthetic data ready: {len(existing)}")
# ---- 2. Build datasets ----
data_dirs = [str(syn_path)]
real_path = Path(real_dir)
if real_path.exists() and list(real_path.glob("*.png")):
data_dirs.append(str(real_path))
print(f"[Data] Mixing in real data: {len(list(real_path.glob('*.png')))}")
train_transform = build_train_transform(img_h, img_w)
val_transform = build_val_transform(img_h, img_w)
full_dataset = RegressionDataset(
dirs=data_dirs, label_range=label_range, transform=train_transform,
)
total = len(full_dataset)
val_size = int(total * cfg["val_split"])
train_size = total - val_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
# validation set uses the augmentation-free transform
val_ds_clean = RegressionDataset(
dirs=data_dirs, label_range=label_range, transform=val_transform,
)
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=0, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=0, pin_memory=True,
)
print(f"[Data] train: {train_size} val: {val_size}")
# ---- 3. Optimizer / scheduler / loss ----
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
if is_circular:
loss_fn = _circular_smooth_l1
else:
loss_fn = nn.SmoothL1Loss()
best_mae = float("inf")
best_tol_acc = 0.0
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
# ---- 4. Training loop ----
for epoch in range(1, cfg["epochs"] + 1):
model.train()
total_loss = 0.0
num_batches = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, targets in pbar:
images = images.to(device)
targets = targets.to(device)
preds = model(images)
loss = loss_fn(preds, targets)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
total_loss += loss.item()
num_batches += 1
pbar.set_postfix(loss=f"{loss.item():.4f}")
scheduler.step()
avg_loss = total_loss / max(num_batches, 1)
# ---- 5. Validation ----
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
for images, targets in val_loader:
images = images.to(device)
preds = model(images)
all_preds.append(preds.cpu().numpy())
all_targets.append(targets.numpy())
all_preds = np.concatenate(all_preds, axis=0).flatten()
all_targets = np.concatenate(all_targets, axis=0).flatten()
# scale back to the original range to compute MAE
preds_real = all_preds * (hi - lo) + lo
targets_real = all_targets * (hi - lo) + lo
if is_circular:
# circular MAE
diff = np.abs(preds_real - targets_real)
diff = np.minimum(diff, (hi - lo) - diff)
mae = float(np.mean(diff))
within_tol = diff <= tolerance
else:
mae = float(np.mean(np.abs(preds_real - targets_real)))
within_tol = np.abs(preds_real - targets_real) <= tolerance
tol_acc = float(np.mean(within_tol))
lr = scheduler.get_last_lr()[0]
print(
f"Epoch {epoch:3d}/{cfg['epochs']} "
f"loss={avg_loss:.4f} "
f"MAE={mae:.2f} "
f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
f"lr={lr:.6f}"
)
# ---- 6. Save the best model (judged by tolerance accuracy) ----
if tol_acc >= best_tol_acc:
best_tol_acc = tol_acc
best_mae = mae
torch.save({
"model_state_dict": model.state_dict(),
"label_range": label_range,
"best_mae": best_mae,
"best_tol_acc": best_tol_acc,
"epoch": epoch,
}, ckpt_path)
print(f" → Saved best model tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
# ---- 7. Export ONNX ----
print(f"\n[Training done] best tolerance accuracy: {best_tol_acc:.4f} best MAE: {best_mae:.2f}")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
return best_tol_acc
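The validation metric above counts predictions within a fixed tolerance of the target, with the circular wrap applied only for angles. A dependency-free sketch of that metric:

```python
def tolerance_accuracy(preds, targets, tolerance, circular_span=None):
    """Fraction of predictions within `tolerance` of their targets.

    circular_span: if set (e.g. 360.0 for angles in degrees), distances
    wrap around it, matching the circular-MAE branch of the trainer.
    """
    hits = 0
    for p, t in zip(preds, targets):
        d = abs(p - t)
        if circular_span is not None:
            d = min(d, circular_span - d)  # wrap across the boundary
        if d <= tolerance:
            hits += 1
    return hits / len(preds)

# 358° vs 2° wraps to a 4° error and counts as a hit at ±5°;
# 10° vs 20° is a 10° error and misses
acc = tolerance_accuracy([358.0, 10.0], [2.0, 20.0], tolerance=5.0, circular_span=360.0)
```

Selecting the checkpoint by this accuracy rather than raw MAE favors models that are "close enough" on most samples over models that trade a few large misses for slightly lower average error.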

View File

@@ -1,7 +1,7 @@
"""
Shared CTC training logic
Provides train_ctc_model(), shared by train_normal / train_math / train_3d.
Provides train_ctc_model(), shared by train_normal / train_math / train_3d_text.
Responsibilities:
1. Check for synthetic data; call the generator automatically if missing
2. Build Dataset / DataLoader (mixing in real data)
@@ -12,8 +12,10 @@ Shared CTC training logic
"""
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
@@ -25,11 +27,21 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
RANDOM_SEED,
get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
def _set_seed(seed: int = RANDOM_SEED):
"""Set global random seeds so training is reproducible."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# ============================================================
# Accuracy computation
# ============================================================
@@ -104,9 +116,12 @@ def train_ctc_model(
config_key: key into TRAIN_CONFIG
"""
cfg = TRAIN_CONFIG[config_key]
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
img_h, img_w = IMAGE_SIZE[config_key]
device = get_device()
# set random seeds
_set_seed()
# ---- 1. Check / generate synthetic data ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
@@ -139,11 +154,11 @@ def train_ctc_model(
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
print(f"[Data] train: {train_size} val: {val_size}")
@@ -166,12 +181,11 @@ def train_ctc_model(
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, targets, target_lengths, _ in pbar:
images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
logits = model(images) # (T, B, C)
T, B, C = logits.shape
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
# cuDNN CTC requires targets/lengths on CPU
input_lengths = torch.full((B,), T, dtype=torch.int32)
log_probs = logits.log_softmax(2)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
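After log_softmax, inference takes the per-timestep argmax and applies CTC's collapse rule: merge adjacent repeats, then drop blanks. A minimal greedy decoder over precomputed argmax indices (blank assumed at index 0, with real characters starting at index 1, which is the usual CTC charset layout):

```python
def ctc_greedy_decode(argmax_indices, chars, blank=0):
    """Collapse repeated indices, drop blanks, map the rest to characters."""
    out = []
    prev = blank
    for idx in argmax_indices:
        if idx != blank and idx != prev:
            out.append(chars[idx - 1])  # index 0 is blank, chars start at 1
        prev = idx
    return "".join(out)

# timesteps "A A _ A B B" collapse to "AAB":
# the blank between the A's keeps the repeated letter
decoded = ctc_greedy_decode([1, 1, 0, 1, 2, 2], "AB")
```

This is why CTC can emit doubled characters like "AA": the network must place a blank between the two occurrences, and the loss above trains it to do exactly that.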