From f5be7671bceaa4dfe3e94e91c51a0f59e1db1a80 Mon Sep 17 00:00:00 2001 From: Hua Date: Wed, 11 Mar 2026 13:55:53 +0800 Subject: [PATCH] Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 --- AGENTS.md | 18 +- CLAUDE.md | 157 ++++++++---- cli.py | 83 +++++-- config.py | 66 +++++- generators/__init__.py | 10 +- generators/threed_gen.py | 4 +- generators/threed_rotate_gen.py | 122 ++++++++++ generators/threed_slider_gen.py | 113 +++++++++ inference/export_onnx.py | 25 +- inference/pipeline.py | 54 +++-- models/__init__.py | 7 +- models/regression_cnn.py | 86 +++++++ training/__init__.py | 15 +- training/dataset.py | 77 +++++- training/train_3d_rotate.py | 38 +++ training/train_3d_slider.py | 38 +++ training/{train_3d.py => train_3d_text.py} | 20 +- training/train_classifier.py | 26 +- training/train_regression_utils.py | 264 +++++++++++++++++++++ training/train_utils.py | 28 ++- 20 files changed, 1109 insertions(+), 142 deletions(-) create mode 100644 generators/threed_rotate_gen.py create mode 100644 generators/threed_slider_gen.py create mode 100644 models/regression_cnn.py create mode 100644 training/train_3d_rotate.py create mode 100644 training/train_3d_slider.py rename training/{train_3d.py => train_3d_text.py} (60%) create mode 100644 training/train_regression_utils.py diff --git a/AGENTS.md b/AGENTS.md index 14e4833..90f6d7e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,25 +1,33 @@ # Repository Guidelines ## Project 
Structure & Module Organization -Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas, `models/` contains the classifier and expert OCR models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`. +Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas (5 types: normal, math, 3d_text, 3d_rotate, 3d_slider), `models/` contains the classifier, CTC expert models, and regression models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`. ## Build, Test, and Development Commands Use `uv` for environment and dependency management. - `uv sync` installs the base runtime dependencies from `pyproject.toml`. - `uv sync --extra server` installs HTTP service dependencies. -- `uv run captcha generate --type normal --num 1000` generates synthetic training data. -- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d -> classifier`. +- `uv run captcha generate --type normal --num 1000` generates synthetic training data. Types: `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, `classifier`. +- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d_text -> 3d_rotate -> 3d_slider -> classifier`. - `uv run captcha export --all` exports all trained models to ONNX. +- `uv run captcha export --model 3d_text` exports a single model; `3d_text` is automatically mapped to `threed_text`. - `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification. 
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory. - `uv run captcha serve --port 8080` starts the optional HTTP API when `server.py` is implemented. ## Coding Style & Naming Conventions -Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d`, and `classifier`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding unless the pipeline is intentionally redesigned. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`. +Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d_text`, `3d_rotate`, `3d_slider`, and `classifier`. Checkpoint/ONNX file names use `threed_text`, `threed_rotate`, `threed_slider` (underscored, no hyphens). Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding for OCR models. Regression models (3d_rotate, 3d_slider) output sigmoid [0,1] scaled by `REGRESSION_RANGE`. `normal` uses the local configured charset and currently includes confusing characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`. + +## Training & Data Rules +- All training scripts must set the global random seed (`random`, `numpy`, `torch`) via `config.RANDOM_SEED` before training begins. +- All DataLoaders use `num_workers=0` for cross-platform consistency. +- Generator parameters (rotation, noise, shadow, etc.) must come from `config.GENERATE_CONFIG`, not hardcoded values. 
+- `CRNNDataset` emits a `warnings.warn` when a label contains characters outside the configured charset, rather than silently dropping them. +- `RegressionDataset` parses numeric labels from filenames and normalizes to [0,1] via `label_range`. ## Data & Testing Guidelines -Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_*.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing. +Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. For regression types, label is the numeric value (angle or offset). Sample targets are defined in `config.py`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_*.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing. ## Commit & Pull Request Guidelines Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.
diff --git a/CLAUDE.md b/CLAUDE.md index 9ffbaa0..bcf06bd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,29 +24,40 @@ captcha-breaker/ │ ├── synthetic/ # 合成训练数据 (自动生成,不入 git) │ │ ├── normal/ # 普通字符型 │ │ ├── math/ # 算式型 -│ │ └── 3d/ # 3D立体型 +│ │ ├── 3d_text/ # 3D立体文字型 +│ │ ├── 3d_rotate/ # 3D旋转型 +│ │ └── 3d_slider/ # 3D滑块型 │ ├── real/ # 真实验证码样本 (手动标注) │ │ ├── normal/ │ │ ├── math/ -│ │ └── 3d/ +│ │ ├── 3d_text/ +│ │ ├── 3d_rotate/ +│ │ └── 3d_slider/ │ └── classifier/ # 调度分类器训练数据 (混合各类型) ├── generators/ │ ├── __init__.py │ ├── base.py # 生成器基类 │ ├── normal_gen.py # 普通字符验证码生成器 │ ├── math_gen.py # 算式验证码生成器 (如 3+8=?) -│ └── threed_gen.py # 3D立体验证码生成器 +│ ├── threed_gen.py # 3D立体文字验证码生成器 +│ ├── threed_rotate_gen.py # 3D旋转验证码生成器 +│ └── threed_slider_gen.py # 3D滑块验证码生成器 ├── models/ │ ├── __init__.py │ ├── lite_crnn.py # 轻量 CRNN (用于普通字符和算式) │ ├── classifier.py # 调度分类模型 -│ └── threed_cnn.py # 3D验证码专用模型 (更深的CNN) +│ ├── threed_cnn.py # 3D文字验证码专用模型 (更深的CNN) +│ └── regression_cnn.py # 回归CNN (3D旋转+滑块, ~1MB) ├── training/ │ ├── __init__.py │ ├── train_classifier.py # 训练调度模型 │ ├── train_normal.py # 训练普通字符识别 │ ├── train_math.py # 训练算式识别 -│ ├── train_3d.py # 训练3D识别 +│ ├── train_3d_text.py # 训练3D文字识别 +│ ├── train_3d_rotate.py # 训练3D旋转回归 +│ ├── train_3d_slider.py # 训练3D滑块回归 +│ ├── train_utils.py # CTC 训练通用逻辑 +│ ├── train_regression_utils.py # 回归训练通用逻辑 │ └── dataset.py # 通用 Dataset 类 ├── inference/ │ ├── __init__.py @@ -57,12 +68,16 @@ captcha-breaker/ │ ├── classifier.pth │ ├── normal.pth │ ├── math.pth -│ └── threed.pth +│ ├── threed_text.pth +│ ├── threed_rotate.pth +│ └── threed_slider.pth ├── onnx_models/ # 导出的 ONNX 模型 │ ├── classifier.onnx │ ├── normal.onnx │ ├── math.onnx -│ └── threed.onnx +│ ├── threed_text.onnx +│ ├── threed_rotate.onnx +│ └── threed_slider.onnx ├── server.py # FastAPI 推理服务 (可选) ├── cli.py # 命令行入口 └── tests/ @@ -78,13 +93,13 @@ captcha-breaker/ ``` 输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → 后处理 → 输出结果 │ - ┌────────┼────────┐ - ▼ ▼ ▼ - normal math 3d - (CRNN) (CRNN) (CNN) - 
│ │ │ - ▼ ▼ ▼ - "A3B8" "3+8=?"→11 "X9K2" + ┌────────┬───┼───────┬──────────┐ + ▼ ▼ ▼ ▼ ▼ + normal math 3d_text 3d_rotate 3d_slider + (CRNN) (CRNN) (CNN) (RegCNN) (RegCNN) + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ + "A3B8" "3+8=?"→11 "X9K2" "135" "87" ``` ### 调度分类器 (classifier.py) @@ -102,7 +117,7 @@ class CaptchaClassifier(nn.Module): 轻量分类器,几层卷积即可区分不同类型验证码。 不同类型验证码视觉差异大(有无运算符、3D效果等),分类很容易。 """ - def __init__(self, num_types=3): + def __init__(self, num_types=5): # 4层卷积 + GAP + FC # Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64) # AdaptiveAvgPool2d(1) -> Linear(64, num_types) @@ -142,14 +157,32 @@ def eval_captcha_math(expr: str) -> str: pass ``` -### 3D立体识别专家 (threed_cnn.py) +### 3D立体文字识别专家 (threed_cnn.py) -- 任务: 识别带 3D 透视/阴影效果的验证码 +- 任务: 识别带 3D 透视/阴影效果的文字验证码 - 架构: 更深的 CNN + CRNN,或 ResNet-lite backbone - 输入: 灰度图 1x60x160 - 需要更强的特征提取能力来处理透视变形和阴影 - 模型体积目标: < 5MB +### 3D旋转识别专家 (regression_cnn.py - 3d_rotate 模式) + +- 任务: 预测旋转验证码的正确旋转角度 +- 架构: 轻量回归 CNN (4层卷积 + GAP + FC + Sigmoid) +- 输入: 灰度图 1x80x80 +- 输出: [0,1] sigmoid 值,按 (0, 360) 缩放回角度 +- 标签范围: 0-359° +- 模型体积目标: ~1MB + +### 3D滑块识别专家 (regression_cnn.py - 3d_slider 模式) + +- 任务: 预测滑块拼图缺口的 x 偏移 +- 架构: 同上回归 CNN,不同输入尺寸 +- 输入: 灰度图 1x80x240 +- 输出: [0,1] sigmoid 值,按 (10, 200) 缩放回像素偏移 +- 标签范围: 10-200px +- 模型体积目标: ~1MB + ## 数据生成器规范 ### 基类 (base.py) @@ -185,13 +218,26 @@ class BaseCaptchaGenerator: - 标签格式: `3+8` (存储算式本身,不存结果) - 视觉风格: 与目标算式验证码一致 -### 3D生成器 (threed_gen.py) +### 3D文字生成器 (threed_gen.py) - 使用 Pillow 的仿射变换模拟 3D 透视 - 添加阴影效果 - 字符有深度感和倾斜 +- 字符旋转角度由 `config.py` `GENERATE_CONFIG["3d_text"]["rotation_range"]` 统一配置 - 标签: 纯字符内容 +### 3D旋转生成器 (threed_rotate_gen.py) + +- 圆盘上绘制字符 + 方向标记,随机旋转 0-359° +- 标签 = 旋转角度(整数字符串) +- 文件名格式: `{angle}_{index:06d}.png` + +### 3D滑块生成器 (threed_slider_gen.py) + +- 纹理背景 + 拼图缺口 + 拼图块在左侧 +- 标签 = 缺口 x 坐标偏移(整数字符串) +- 文件名格式: `{offset}_{index:06d}.png` + ## 训练规范 ### 通用训练配置 @@ -204,7 +250,7 @@ TRAIN_CONFIG = { 'batch_size': 128, 'lr': 1e-3, 'scheduler': 'cosine', - 'synthetic_samples': 30000, # 每类 10000 + 
'synthetic_samples': 50000, # 每类 10000 × 5 类 }, 'normal': { 'epochs': 50, @@ -222,7 +268,7 @@ TRAIN_CONFIG = { 'synthetic_samples': 60000, 'loss': 'CTCLoss', }, - 'threed': { + '3d_text': { 'epochs': 80, 'batch_size': 64, 'lr': 5e-4, @@ -230,18 +276,36 @@ TRAIN_CONFIG = { 'synthetic_samples': 80000, 'loss': 'CTCLoss', }, + '3d_rotate': { + 'epochs': 60, + 'batch_size': 128, + 'lr': 1e-3, + 'scheduler': 'cosine', + 'synthetic_samples': 60000, + 'loss': 'SmoothL1', + }, + '3d_slider': { + 'epochs': 60, + 'batch_size': 128, + 'lr': 1e-3, + 'scheduler': 'cosine', + 'synthetic_samples': 60000, + 'loss': 'SmoothL1', + }, } ``` ### 训练脚本要求 每个训练脚本必须: -1. 检查合成数据是否已生成,没有则自动调用生成器 -2. 支持混合真实数据 (如果 data/real/{type}/ 有文件) -3. 使用数据增强: RandomAffine, ColorJitter, GaussianBlur, RandomErasing -4. 输出训练日志: epoch, loss, 整体准确率, 字符级准确率 -5. 保存最佳模型到 checkpoints/ -6. 训练结束自动导出 ONNX 到 onnx_models/ +1. 训练开始前设置全局随机种子 (random/numpy/torch),使用 `config.RANDOM_SEED`,保证可复现 +2. 检查合成数据是否已生成,没有则自动调用生成器 +3. 支持混合真实数据 (如果 data/real/{type}/ 有文件) +4. 使用数据增强: RandomAffine, ColorJitter, GaussianBlur, RandomErasing +5. 输出训练日志: epoch, loss, 整体准确率, 字符级准确率 (CTC) 或 MAE, 容差准确率 (回归) +6. 保存最佳模型到 checkpoints/ +7. 训练结束自动导出 ONNX 到 onnx_models/ +8. DataLoader 统一使用 `num_workers=0` 避免多进程兼容问题 ### 数据增强策略 @@ -281,14 +345,14 @@ class CaptchaPipeline: pass def classify(self, image: Image.Image) -> str: - """调度分类,返回类型名: 'normal' / 'math' / '3d'""" + """调度分类,返回类型名: 'normal' / 'math' / '3d_text' / '3d_rotate' / '3d_slider'""" pass def solve(self, image) -> str: """ 完整识别流程: 1. 分类验证码类型 - 2. 路由到对应专家模型 + 2. 路由到对应专家模型 (CTC 或回归) 3. 后处理 (算式型需要计算结果) 4. 
返回最终答案字符串 @@ -311,7 +375,7 @@ def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'): pass def export_all(): - """依次导出 classifier, normal, math, threed 四个模型""" + """依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型""" pass ``` @@ -325,23 +389,29 @@ uv sync --extra server # 含 HTTP 服务依赖 # 生成训练数据 uv run python cli.py generate --type normal --num 60000 uv run python cli.py generate --type math --num 60000 -uv run python cli.py generate --type 3d --num 80000 -uv run python cli.py generate --type classifier --num 30000 +uv run python cli.py generate --type 3d_text --num 80000 +uv run python cli.py generate --type 3d_rotate --num 60000 +uv run python cli.py generate --type 3d_slider --num 60000 +uv run python cli.py generate --type classifier --num 50000 # 训练模型 uv run python cli.py train --model classifier uv run python cli.py train --model normal uv run python cli.py train --model math -uv run python cli.py train --model 3d +uv run python cli.py train --model 3d_text +uv run python cli.py train --model 3d_rotate +uv run python cli.py train --model 3d_slider uv run python cli.py train --all # 按依赖顺序全部训练 # 导出 ONNX uv run python cli.py export --all +uv run python cli.py export --model 3d_text # "3d_text" 自动映射为 "threed_text" # 推理 -uv run python cli.py predict image.png # 自动分类+识别 -uv run python cli.py predict image.png --type normal # 跳过分类直接识别 -uv run python cli.py predict-dir ./test_images/ # 批量识别 +uv run python cli.py predict image.png # 自动分类+识别 +uv run python cli.py predict image.png --type normal # 跳过分类直接识别 +uv run python cli.py predict image.png --type 3d_rotate # 指定为旋转类型 +uv run python cli.py predict-dir ./test_images/ # 批量识别 # 启动 HTTP 服务 (需先安装 server 可选依赖) uv run python cli.py serve --port 8080 @@ -360,11 +430,12 @@ uv run python cli.py serve --port 8080 1. **所有模型用 float32 训练,导出 ONNX 时不做量化**,先保证精度 2. **CTC 解码统一用贪心解码**,不需要 beam search,验证码场景贪心够用 -3. **字符集由 config.py 统一定义**: 当前 normal 保留易混淆字符,3d 继续使用去混淆字符集 +3. 
**字符集由 config.py 统一定义**: 当前 normal 保留易混淆字符,3d_text 继续使用去混淆字符集 4. **算式识别分两步**: 先 OCR 识别字符串,再用规则计算,不要让模型直接输出数值 -5. **生成器的随机种子**: 生成数据时设置 seed 保证可复现 +5. **随机种子**: 生成数据和训练时均通过 `config.RANDOM_SEED` 设置全局种子 (random/numpy/torch),保证可复现 6. **真实数据文件名格式**: `{label}_{任意}.png`,label 部分是标注内容 -7. **模型保存格式**: PyTorch checkpoint 包含 model_state_dict, chars, best_acc, epoch +11. **数据集字符过滤**: `CRNNDataset` 加载标签时,若发现字符不在字符集内会发出 warning,便于排查标注/字符集不匹配问题 +7. **模型保存格式**: CTC checkpoint 包含 model_state_dict, chars, best_acc, epoch; 回归 checkpoint 包含 model_state_dict, label_range, best_mae, best_tol_acc, epoch 8. **不使用 GPU 特有功能**,确保 CPU 也能训练和推理 (只是慢一些) 9. **类型扩展**: 新增验证码类型时,只需 (1) 加生成器 (2) 加专家模型 (3) 调度器加一个类别重新训练 10. **文档同步**: 对项目结构、配置、架构等做出变更时,必须同步更新 CLAUDE.md 中的对应内容,保持文档与代码一致 @@ -373,18 +444,20 @@ uv run python cli.py serve --port 8080 | 模型 | 准确率目标 | 推理延迟 | 模型体积 | |------|-----------|---------|---------| -| 调度分类器 | > 99% | < 5ms | < 500KB | +| 调度分类器 (5类) | > 99% | < 5ms | < 500KB | | 普通字符 | > 95% | < 30ms | < 2MB | | 算式识别 | > 93% | < 30ms | < 2MB | -| 3D立体 | > 85% | < 50ms | < 5MB | -| 全流水线 | - | < 80ms | < 10MB 总计 | +| 3D立体文字 | > 85% | < 50ms | < 5MB | +| 3D旋转 (±5°) | > 85% | < 30ms | ~1MB | +| 3D滑块 (±3px) | > 90% | < 30ms | ~1MB | +| 全流水线 | - | < 80ms | < 12MB 总计 | ## 开发顺序 1. 先实现 config.py 和 generators/ 2. 实现 models/ 中所有模型定义 3. 实现 training/dataset.py 通用数据集类 -4. 按顺序训练: normal → math → 3d → classifier +4. 按顺序训练: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier 5. 实现 inference/pipeline.py 和 export_onnx.py 6. 实现 cli.py 统一入口 7. 
可选: server.py HTTP 服务 diff --git a/cli.py b/cli.py index ff5aaff..faa2fe5 100644 --- a/cli.py +++ b/cli.py @@ -3,6 +3,9 @@ CaptchaBreaker 命令行入口 用法: python cli.py generate --type normal --num 60000 + python cli.py generate --type 3d_text --num 80000 + python cli.py generate --type 3d_rotate --num 60000 + python cli.py generate --type 3d_slider --num 60000 python cli.py train --model normal python cli.py train --all python cli.py export --all @@ -20,15 +23,21 @@ from pathlib import Path def cmd_generate(args): """生成训练数据。""" from config import ( - SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR, + SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, + SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR, CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES, ) - from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator + from generators import ( + NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator, + ThreeDRotateGenerator, ThreeDSliderGenerator, + ) gen_map = { "normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR), "math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR), - "3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR), + "3d_text": (ThreeDCaptchaGenerator, SYNTHETIC_3D_TEXT_DIR), + "3d_rotate": (ThreeDRotateGenerator, SYNTHETIC_3D_ROTATE_DIR), + "3d_slider": (ThreeDSliderGenerator, SYNTHETIC_3D_SLIDER_DIR), } captcha_type = args.type @@ -50,25 +59,31 @@ def cmd_generate(args): gen = gen_cls() gen.generate_dataset(num, str(out_dir)) else: - print(f"未知类型: {captcha_type} 可选: normal, math, 3d, classifier") + valid = ", ".join(list(gen_map.keys()) + ["classifier"]) + print(f"未知类型: {captcha_type} 可选: {valid}") sys.exit(1) def cmd_train(args): """训练模型。""" if args.all: - # 按依赖顺序: normal → math → 3d → classifier - print("按顺序训练全部模型: normal → math → 3d → classifier\n") + print("按顺序训练全部模型: normal → math → 3d_text → 3d_rotate → 3d_slider → classifier\n") from training.train_normal import main as 
train_normal from training.train_math import main as train_math - from training.train_3d import main as train_3d + from training.train_3d_text import main as train_3d_text + from training.train_3d_rotate import main as train_3d_rotate + from training.train_3d_slider import main as train_3d_slider from training.train_classifier import main as train_classifier train_normal() print("\n") train_math() print("\n") - train_3d() + train_3d_text() + print("\n") + train_3d_rotate() + print("\n") + train_3d_slider() print("\n") train_classifier() return @@ -78,12 +93,16 @@ def cmd_train(args): from training.train_normal import main as train_fn elif model == "math": from training.train_math import main as train_fn - elif model == "3d": - from training.train_3d import main as train_fn + elif model == "3d_text": + from training.train_3d_text import main as train_fn + elif model == "3d_rotate": + from training.train_3d_rotate import main as train_fn + elif model == "3d_slider": + from training.train_3d_slider import main as train_fn elif model == "classifier": from training.train_classifier import main as train_fn else: - print(f"未知模型: {model} 可选: normal, math, 3d, classifier") + print(f"未知模型: {model} 可选: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier") sys.exit(1) train_fn() @@ -96,7 +115,14 @@ def cmd_export(args): if args.all: export_all() elif args.model: - _load_and_export(args.model) + # 别名映射 + alias = { + "3d_text": "threed_text", + "3d_rotate": "threed_rotate", + "3d_slider": "threed_slider", + } + name = alias.get(args.model, args.model) + _load_and_export(name) else: print("请指定 --all 或 --model ") sys.exit(1) @@ -137,19 +163,19 @@ def cmd_predict_dir(args): sys.exit(1) print(f"批量识别: {len(images)} 张图片\n") - print(f"{'文件名':<30} {'类型':<8} {'结果':<15} {'耗时(ms)':>8}") - print("-" * 65) + print(f"{'文件名':<30} {'类型':<10} {'结果':<15} {'耗时(ms)':>8}") + print("-" * 67) total_ms = 0.0 for img_path in images: result = pipeline.solve(str(img_path), captcha_type=args.type) 
total_ms += result["time_ms"] print( - f"{img_path.name:<30} {result['type']:<8} " + f"{img_path.name:<30} {result['type']:<10} " f"{result['result']:<15} {result['time_ms']:>8.1f}" ) - print("-" * 65) + print("-" * 67) print(f"总计: {len(images)} 张 平均: {total_ms / len(images):.1f} ms 总耗时: {total_ms:.1f} ms") @@ -178,28 +204,43 @@ def main(): # ---- generate ---- p_gen = subparsers.add_parser("generate", help="生成训练数据") - p_gen.add_argument("--type", required=True, help="验证码类型: normal, math, 3d, classifier") + p_gen.add_argument( + "--type", required=True, + help="验证码类型: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier", + ) p_gen.add_argument("--num", type=int, required=True, help="生成数量") # ---- train ---- p_train = subparsers.add_parser("train", help="训练模型") - p_train.add_argument("--model", help="模型名: normal, math, 3d, classifier") + p_train.add_argument( + "--model", + help="模型名: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier", + ) p_train.add_argument("--all", action="store_true", help="按依赖顺序训练全部模型") # ---- export ---- p_export = subparsers.add_parser("export", help="导出 ONNX 模型") - p_export.add_argument("--model", help="模型名: normal, math, 3d, classifier, threed") + p_export.add_argument( + "--model", + help="模型名: normal, math, 3d_text, 3d_rotate, 3d_slider, classifier", + ) p_export.add_argument("--all", action="store_true", help="导出全部模型") # ---- predict ---- p_pred = subparsers.add_parser("predict", help="识别单张验证码") p_pred.add_argument("image", help="图片路径") - p_pred.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d") + p_pred.add_argument( + "--type", default=None, + help="指定类型跳过分类: normal, math, 3d_text, 3d_rotate, 3d_slider", + ) # ---- predict-dir ---- p_pdir = subparsers.add_parser("predict-dir", help="批量识别目录中的验证码") p_pdir.add_argument("directory", help="图片目录路径") - p_pdir.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d") + p_pdir.add_argument( + "--type", default=None, + help="指定类型跳过分类: normal, math, 
3d_text, 3d_rotate, 3d_slider", + ) # ---- serve ---- p_serve = subparsers.add_parser("serve", help="启动 HTTP 识别服务") diff --git a/config.py b/config.py index 7f7d280..4a18afa 100644 --- a/config.py +++ b/config.py @@ -23,12 +23,16 @@ CLASSIFIER_DIR = DATA_DIR / "classifier" # 合成数据子目录 SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal" SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math" -SYNTHETIC_3D_DIR = SYNTHETIC_DIR / "3d" +SYNTHETIC_3D_TEXT_DIR = SYNTHETIC_DIR / "3d_text" +SYNTHETIC_3D_ROTATE_DIR = SYNTHETIC_DIR / "3d_rotate" +SYNTHETIC_3D_SLIDER_DIR = SYNTHETIC_DIR / "3d_slider" # 真实数据子目录 REAL_NORMAL_DIR = REAL_DIR / "normal" REAL_MATH_DIR = REAL_DIR / "math" -REAL_3D_DIR = REAL_DIR / "3d" +REAL_3D_TEXT_DIR = REAL_DIR / "3d_text" +REAL_3D_ROTATE_DIR = REAL_DIR / "3d_rotate" +REAL_3D_SLIDER_DIR = REAL_DIR / "3d_slider" # ============================================================ # 模型输出目录 @@ -38,8 +42,10 @@ ONNX_DIR = PROJECT_ROOT / "onnx_models" # 确保关键目录存在 for _dir in [ - SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR, - REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_DIR, + SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, + SYNTHETIC_3D_TEXT_DIR, SYNTHETIC_3D_ROTATE_DIR, SYNTHETIC_3D_SLIDER_DIR, + REAL_NORMAL_DIR, REAL_MATH_DIR, + REAL_3D_TEXT_DIR, REAL_3D_ROTATE_DIR, REAL_3D_SLIDER_DIR, CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR, ]: _dir.mkdir(parents=True, exist_ok=True) @@ -57,7 +63,7 @@ MATH_CHARS = "0123456789+-×÷=?" 
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ" # 验证码类型列表 (调度分类器输出) -CAPTCHA_TYPES = ["normal", "math", "3d"] +CAPTCHA_TYPES = ["normal", "math", "3d_text", "3d_rotate", "3d_slider"] NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES) # ============================================================ @@ -67,7 +73,9 @@ IMAGE_SIZE = { "classifier": (64, 128), # 调度分类器输入 "normal": (40, 120), # 普通字符识别 "math": (40, 160), # 算式识别 (更宽) - "3d": (60, 160), # 3D 立体识别 + "3d_text": (60, 160), # 3D 立体文字识别 + "3d_rotate": (80, 80), # 3D 旋转角度回归 (正方形) + "3d_slider": (80, 240), # 3D 滑块偏移回归 } # ============================================================ @@ -91,11 +99,25 @@ GENERATE_CONFIG = { "rotation_range": (-10, 10), "noise_line_range": (2, 4), }, - "3d": { + "3d_text": { "char_count_range": (4, 5), "image_size": (160, 60), # 生成图片尺寸 (W, H) "shadow_offset": (3, 3), # 阴影偏移 "perspective_intensity": 0.3, # 透视变换强度 + "rotation_range": (-20, 20), # 字符旋转角度 + }, + "3d_rotate": { + "image_size": (80, 80), # 生成图片尺寸 (W, H) + "disc_radius": 35, # 圆盘半径 + "marker_size": 8, # 方向标记大小 + "bg_color_range": (200, 240), # 背景色范围 + }, + "3d_slider": { + "image_size": (240, 80), # 生成图片尺寸 (W, H) + "puzzle_size": (40, 40), # 拼图块大小 (W, H) + "gap_x_range": (50, 200), # 缺口 x 坐标范围 + "piece_left_margin": 5, # 拼图块左侧留白 + "bg_noise_intensity": 30, # 背景纹理噪声强度 }, } @@ -108,7 +130,7 @@ TRAIN_CONFIG = { "batch_size": 128, "lr": 1e-3, "scheduler": "cosine", - "synthetic_samples": 30000, # 每类 10000 + "synthetic_samples": 50000, # 每类 10000 × 5 类 "val_split": 0.1, # 验证集比例 }, "normal": { @@ -129,7 +151,7 @@ TRAIN_CONFIG = { "loss": "CTCLoss", "val_split": 0.1, }, - "threed": { + "3d_text": { "epochs": 80, "batch_size": 64, "lr": 5e-4, @@ -138,6 +160,24 @@ TRAIN_CONFIG = { "loss": "CTCLoss", "val_split": 0.1, }, + "3d_rotate": { + "epochs": 60, + "batch_size": 128, + "lr": 1e-3, + "scheduler": "cosine", + "synthetic_samples": 60000, + "loss": "SmoothL1", + "val_split": 0.1, + }, + "3d_slider": { + "epochs": 60, + "batch_size": 128, + "lr": 
1e-3, + "scheduler": "cosine", + "synthetic_samples": 60000, + "loss": "SmoothL1", + "val_split": 0.1, + }, } # ============================================================ @@ -163,6 +203,14 @@ ONNX_CONFIG = { "dynamic_batch": True, # 支持动态 batch size } +# ============================================================ +# 回归模型标签范围 +# ============================================================ +REGRESSION_RANGE = { + "3d_rotate": (0, 360), # 旋转角度 0-359° + "3d_slider": (10, 200), # 滑块 x 偏移 (像素) +} + # ============================================================ # 推理配置 # ============================================================ diff --git a/generators/__init__.py b/generators/__init__.py index d59cf50..dab6a9c 100644 --- a/generators/__init__.py +++ b/generators/__init__.py @@ -1,20 +1,26 @@ """ 数据生成器包 -提供三种验证码类型的数据生成器: +提供五种验证码类型的数据生成器: - NormalCaptchaGenerator: 普通字符验证码 - MathCaptchaGenerator: 算式验证码 -- ThreeDCaptchaGenerator: 3D 立体验证码 +- ThreeDCaptchaGenerator: 3D 立体文字验证码 +- ThreeDRotateGenerator: 3D 旋转验证码 +- ThreeDSliderGenerator: 3D 滑块验证码 """ from generators.base import BaseCaptchaGenerator from generators.normal_gen import NormalCaptchaGenerator from generators.math_gen import MathCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator +from generators.threed_rotate_gen import ThreeDRotateGenerator +from generators.threed_slider_gen import ThreeDSliderGenerator __all__ = [ "BaseCaptchaGenerator", "NormalCaptchaGenerator", "MathCaptchaGenerator", "ThreeDCaptchaGenerator", + "ThreeDRotateGenerator", + "ThreeDSliderGenerator", ] diff --git a/generators/threed_gen.py b/generators/threed_gen.py index 18f2924..615cb4d 100644 --- a/generators/threed_gen.py +++ b/generators/threed_gen.py @@ -45,7 +45,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator): from config import RANDOM_SEED super().__init__(seed=seed if seed is not None else RANDOM_SEED) - self.cfg = GENERATE_CONFIG["3d"] + self.cfg = GENERATE_CONFIG["3d_text"] self.chars = THREED_CHARS 
self.width, self.height = self.cfg["image_size"] @@ -154,7 +154,7 @@ class ThreeDCaptchaGenerator(BaseCaptchaGenerator): char_img = self._perspective_transform(char_img, rng) # 随机旋转 - angle = rng.randint(-20, 20) + angle = rng.randint(*self.cfg["rotation_range"]) char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True) # 粘贴到画布 diff --git a/generators/threed_rotate_gen.py b/generators/threed_rotate_gen.py new file mode 100644 index 0000000..9599753 --- /dev/null +++ b/generators/threed_rotate_gen.py @@ -0,0 +1,122 @@ +""" +3D 旋转验证码生成器 + +生成旋转验证码:圆盘上绘制字符 + 方向标记,随机旋转 0-359°。 +用户需将圆盘旋转到正确角度。 + +标签 = 旋转角度(整数) +文件名格式: {angle}_{index:06d}.png +""" + +import random + +from PIL import Image, ImageDraw, ImageFilter, ImageFont + +from config import GENERATE_CONFIG, THREED_CHARS +from generators.base import BaseCaptchaGenerator + +_FONT_PATHS = [ + "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf", + "/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf", + "/usr/share/fonts/liberation/LiberationSans-Bold.ttf", + "/usr/share/fonts/liberation/LiberationSerif-Bold.ttf", + "/usr/share/fonts/gnu-free/FreeSansBold.otf", +] + +_DISC_COLORS = [ + (180, 200, 220), + (200, 220, 200), + (220, 200, 190), + (200, 200, 220), + (210, 210, 200), +] + + +class ThreeDRotateGenerator(BaseCaptchaGenerator): + """3D 旋转验证码生成器。""" + + def __init__(self, seed: int | None = None): + from config import RANDOM_SEED + super().__init__(seed=seed if seed is not None else RANDOM_SEED) + + self.cfg = GENERATE_CONFIG["3d_rotate"] + self.chars = THREED_CHARS + self.width, self.height = self.cfg["image_size"] + + self._fonts: list[str] = [] + for p in _FONT_PATHS: + try: + ImageFont.truetype(p, 20) + self._fonts.append(p) + except OSError: + continue + if not self._fonts: + raise RuntimeError("未找到任何可用字体,无法生成验证码") + + def generate(self, text: str | None = None) -> tuple[Image.Image, str]: + rng = self.rng + + # 随机旋转角度 0-359 + angle = rng.randint(0, 359) + + if text is None: + text = str(angle) + + # 1. 
背景 + bg_val = rng.randint(*self.cfg["bg_color_range"]) + img = Image.new("RGB", (self.width, self.height), (bg_val, bg_val, bg_val)) + draw = ImageDraw.Draw(img) + + # 2. 绘制圆盘 + cx, cy = self.width // 2, self.height // 2 + r = self.cfg["disc_radius"] + disc_color = rng.choice(_DISC_COLORS) + draw.ellipse( + [cx - r, cy - r, cx + r, cy + r], + fill=disc_color, outline=(100, 100, 100), width=2, + ) + + # 3. 在圆盘上绘制字符和方向标记 (未旋转状态) + disc_img = Image.new("RGBA", (r * 2 + 4, r * 2 + 4), (0, 0, 0, 0)) + disc_draw = ImageDraw.Draw(disc_img) + dc = r + 2 # disc center + + # 字符 (圆盘中心) + font_path = rng.choice(self._fonts) + font_size = int(r * 0.6) + font = ImageFont.truetype(font_path, font_size) + ch = rng.choice(self.chars) + bbox = font.getbbox(ch) + tw = bbox[2] - bbox[0] + th = bbox[3] - bbox[1] + disc_draw.text( + (dc - tw // 2 - bbox[0], dc - th // 2 - bbox[1]), + ch, fill=(50, 50, 50, 255), font=font, + ) + + # 方向标记 (三角箭头,指向上方) + ms = self.cfg["marker_size"] + marker_y = dc - r + ms + 2 + disc_draw.polygon( + [(dc, marker_y - ms), (dc - ms // 2, marker_y), (dc + ms // 2, marker_y)], + fill=(220, 60, 60, 255), + ) + + # 4. 旋转圆盘内容 + disc_img = disc_img.rotate(-angle, resample=Image.BICUBIC, expand=False) + + # 5. 粘贴到背景 + paste_x = cx - dc + paste_y = cy - dc + img.paste(disc_img, (paste_x, paste_y), disc_img) + + # 6. 添加少量噪点 + for _ in range(rng.randint(20, 50)): + nx, ny = rng.randint(0, self.width - 1), rng.randint(0, self.height - 1) + nc = tuple(rng.randint(100, 200) for _ in range(3)) + draw.point((nx, ny), fill=nc) + + # 7. 
"""
3D slider captcha generator.

Produces slider-puzzle captchas: a textured background with a puzzle gap
cut out and the matching puzzle piece rendered near the left edge.  The
solver must slide the piece horizontally to the gap.

Label = x offset of the gap (integer).
Filename format: {offset}_{index:06d}.png
"""

import random

from PIL import Image, ImageDraw, ImageFilter

from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator


class ThreeDSliderGenerator(BaseCaptchaGenerator):
    """3D slider captcha generator."""

    def __init__(self, seed: int | None = None):
        # Local import keeps the RANDOM_SEED lookup lazy, matching the
        # other generators in this package.
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

        self.cfg = GENERATE_CONFIG["3d_slider"]
        self.width, self.height = self.cfg["image_size"]

    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        """
        Render one slider captcha.

        Args:
            text: optional gap offset.  When given, it is parsed as a
                number and clamped into ``gap_x_range`` so the returned
                label always matches the rendered gap.  (Previously a
                caller-supplied ``text`` was echoed back unchanged while
                the gap stayed random, so label and image could disagree.)
                When ``None``, a random offset is drawn.

        Returns:
            (image, label) where label is the gap x offset as a string.
        """
        rng = self.rng
        pw, ph = self.cfg["puzzle_size"]
        gap_x_lo, gap_x_hi = self.cfg["gap_x_range"]

        # Draw the random position first so the RNG call sequence is
        # identical to before for the common text=None path.
        gap_x = rng.randint(gap_x_lo, gap_x_hi)
        gap_y = rng.randint(10, self.height - ph - 10)

        if text is None:
            text = str(gap_x)
        else:
            try:
                # Honor the requested offset, clamped to the valid range.
                gap_x = min(max(int(round(float(text))), gap_x_lo), gap_x_hi)
            except ValueError:
                pass  # unparsable text: keep the random gap
            # Relabel so the label always matches the rendered gap.
            text = str(gap_x)

        # 1. Textured background.
        img = self._textured_background(rng)

        # 2. Copy the gap region as the puzzle-piece content.
        piece_content = img.crop((gap_x, gap_y, gap_x + pw, gap_y + ph)).copy()

        # 3. Darken the gap with a semi-transparent grey overlay.
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle(
            [gap_x, gap_y, gap_x + pw, gap_y + ph],
            fill=(80, 80, 80, 160),
            outline=(60, 60, 60, 200),
            width=2,
        )
        img = img.convert("RGBA")
        img = Image.alpha_composite(img, overlay)
        img = img.convert("RGB")

        # 4. Render the piece (shadow + content + border) on the left,
        #    vertically aligned with the gap.
        piece_x = self.cfg["piece_left_margin"]
        piece_img = Image.new("RGBA", (pw + 4, ph + 4), (0, 0, 0, 0))
        piece_draw = ImageDraw.Draw(piece_img)
        piece_draw.rectangle([2, 2, pw + 3, ph + 3], fill=(0, 0, 0, 80))  # drop shadow
        piece_img.paste(piece_content, (0, 0))
        piece_draw.rectangle([0, 0, pw - 1, ph - 1], outline=(255, 255, 255, 200), width=2)

        img_rgba = img.convert("RGBA")
        img_rgba.paste(piece_img, (piece_x, gap_y), piece_img)
        img = img_rgba.convert("RGB")

        # 5. Slight blur so the composited edges look natural.
        img = img.filter(ImageFilter.GaussianBlur(radius=0.3))

        return img, text

    def _textured_background(self, rng: random.Random) -> Image.Image:
        """Build a colored, textured background (gradient + noise + patches)."""
        img = Image.new("RGB", (self.width, self.height))
        draw = ImageDraw.Draw(img)

        # Vertical gradient base color.
        base_r, base_g, base_b = rng.randint(100, 180), rng.randint(100, 180), rng.randint(100, 180)
        for y in range(self.height):
            ratio = y / max(self.height - 1, 1)
            r = int(base_r + 30 * ratio)
            g = int(base_g - 20 * ratio)
            b = int(base_b + 10 * ratio)
            draw.line([(0, y), (self.width, y)], fill=(r, g, b))

        # Per-pixel texture noise.  A PixelAccess object replaces the slow
        # getpixel/ImageDraw.point round-trip; the RNG call order (x, y,
        # then three channel deltas) is unchanged, so seeded output is
        # identical.
        noise_intensity = self.cfg["bg_noise_intensity"]
        px = img.load()
        for _ in range(self.width * self.height // 8):
            x = rng.randint(0, self.width - 1)
            y = rng.randint(0, self.height - 1)
            pixel = px[x, y]
            px[x, y] = tuple(
                max(0, min(255, c + rng.randint(-noise_intensity, noise_intensity)))
                for c in pixel
            )

        # Random color patches to fake larger-scale pattern.
        for _ in range(rng.randint(3, 6)):
            x1, y1 = rng.randint(0, self.width - 30), rng.randint(0, self.height - 20)
            x2, y2 = x1 + rng.randint(15, 40), y1 + rng.randint(10, 25)
            color = tuple(rng.randint(60, 220) for _ in range(3))
            draw.rectangle([x1, y1, x2, y2], fill=color)

        return img
REGRESSION_RANGE, ) from models.classifier import CaptchaClassifier from models.lite_crnn import LiteCRNN from models.threed_cnn import ThreeDCNN +from models.regression_cnn import RegressionCNN def export_model( @@ -34,7 +36,7 @@ def export_model( Args: model: 已加载权重的 PyTorch 模型 - model_name: 模型名 (classifier / normal / math / threed) + model_name: 模型名 (classifier / normal / math / threed_text / threed_rotate / threed_slider) input_shape: 输入形状 (C, H, W) onnx_dir: 输出目录 (默认使用 config.ONNX_DIR) """ @@ -50,7 +52,7 @@ def export_model( dummy = torch.randn(1, *input_shape) # 分类器和识别器的 dynamic_axes 不同 - if model_name == "classifier": + if model_name == "classifier" or model_name in ("threed_rotate", "threed_slider"): dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} else: # CTC 模型: output shape = (T, B, C) @@ -78,7 +80,8 @@ def _load_and_export(model_name: str): return ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) - print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}") + acc_info = ckpt.get('best_acc') or ckpt.get('best_tol_acc', '?') + print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={acc_info}") if model_name == "classifier": model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES) @@ -94,11 +97,19 @@ def _load_and_export(model_name: str): h, w = IMAGE_SIZE["math"] model = LiteCRNN(chars=chars, img_h=h, img_w=w) input_shape = (1, h, w) - elif model_name == "threed": + elif model_name == "threed_text": chars = ckpt.get("chars", THREED_CHARS) - h, w = IMAGE_SIZE["3d"] + h, w = IMAGE_SIZE["3d_text"] model = ThreeDCNN(chars=chars, img_h=h, img_w=w) input_shape = (1, h, w) + elif model_name == "threed_rotate": + h, w = IMAGE_SIZE["3d_rotate"] + model = RegressionCNN(img_h=h, img_w=w) + input_shape = (1, h, w) + elif model_name == "threed_slider": + h, w = IMAGE_SIZE["3d_slider"] + model = RegressionCNN(img_h=h, img_w=w) + input_shape = (1, h, w) else: print(f"[错误] 未知模型: {model_name}") return @@ 
-108,11 +119,11 @@ def _load_and_export(model_name: str): def export_all(): - """依次导出 classifier, normal, math, threed 四个模型。""" + """依次导出 classifier, normal, math, threed_text, threed_rotate, threed_slider 六个模型。""" print("=" * 50) print("导出全部 ONNX 模型") print("=" * 50) - for name in ["classifier", "normal", "math", "threed"]: + for name in ["classifier", "normal", "math", "threed_text", "threed_rotate", "threed_slider"]: _load_and_export(name) print("\n全部导出完成。") diff --git a/inference/pipeline.py b/inference/pipeline.py index c537718..aa781b2 100644 --- a/inference/pipeline.py +++ b/inference/pipeline.py @@ -4,7 +4,7 @@ 加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。 推理流程: - 输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出 + 输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 / 回归缩放 → 后处理 → 输出 对算式类型,解码后还会调用 math_eval 计算结果。 """ @@ -23,6 +23,7 @@ from config import ( NORMAL_CHARS, MATH_CHARS, THREED_CHARS, + REGRESSION_RANGE, ) from inference.math_eval import eval_captcha_math @@ -59,19 +60,24 @@ class CaptchaPipeline: self.mean = INFERENCE_CONFIG["normalize_mean"] self.std = INFERENCE_CONFIG["normalize_std"] - # 字符集映射 + # 字符集映射 (仅 CTC 模型需要) self._chars = { "normal": NORMAL_CHARS, "math": MATH_CHARS, - "3d": THREED_CHARS, + "3d_text": THREED_CHARS, } + # 回归模型类型 + self._regression_types = {"3d_rotate", "3d_slider"} + # 专家模型名 → ONNX 文件名 self._model_files = { "classifier": "classifier.onnx", "normal": "normal.onnx", "math": "math.onnx", - "3d": "threed.onnx", + "3d_text": "threed_text.onnx", + "3d_rotate": "threed_rotate.onnx", + "3d_slider": "threed_slider.onnx", } # 加载所有可用模型 @@ -116,7 +122,7 @@ class CaptchaPipeline: def classify(self, image: Image.Image) -> str: """ - 调度分类,返回类型名: 'normal' / 'math' / '3d'。 + 调度分类,返回类型名。 Raises: RuntimeError: 分类器模型未加载 @@ -141,7 +147,7 @@ class CaptchaPipeline: Args: image: PIL.Image 或文件路径 (str/Path) 或 bytes - captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d') + captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d_text'/'3d_rotate'/'3d_slider') Returns: dict: { @@ 
-166,24 +172,34 @@ class CaptchaPipeline: f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型" ) - size_key = captcha_type # "normal"/"math"/"3d" + size_key = captcha_type inp = self.preprocess(img, IMAGE_SIZE[size_key]) session = self._sessions[captcha_type] input_name = session.get_inputs()[0].name - logits = session.run(None, {input_name: inp})[0] # (T, 1, C) - # 4. CTC 贪心解码 - chars = self._chars[captcha_type] - raw_text = self._ctc_greedy_decode(logits, chars) - - # 5. 后处理 - if captcha_type == "math": - try: - result = eval_captcha_math(raw_text) - except ValueError: - result = raw_text # 解析失败则返回原始文本 + # 4. 分支: CTC 解码 vs 回归 + if captcha_type in self._regression_types: + # 回归模型: 输出 (batch, 1) sigmoid 值 + output = session.run(None, {input_name: inp})[0] # (1, 1) + sigmoid_val = float(output[0, 0]) + lo, hi = REGRESSION_RANGE[captcha_type] + real_val = sigmoid_val * (hi - lo) + lo + raw_text = f"{real_val:.1f}" + result = str(int(round(real_val))) else: - result = raw_text + # CTC 模型 + logits = session.run(None, {input_name: inp})[0] # (T, 1, C) + chars = self._chars[captcha_type] + raw_text = self._ctc_greedy_decode(logits, chars) + + # 后处理 + if captcha_type == "math": + try: + result = eval_captcha_math(raw_text) + except ValueError: + result = raw_text + else: + result = raw_text elapsed = (time.perf_counter() - t0) * 1000 diff --git a/models/__init__.py b/models/__init__.py index cc521af..9a8efc8 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,18 +1,21 @@ """ 模型定义包 -提供三种模型: +提供四种模型: - CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB) - LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB) -- ThreeDCNN: 3D 验证码专用模型 (ResNet-lite + BiLSTM, < 5MB) +- ThreeDCNN: 3D 文字验证码专用模型 (ResNet-lite + BiLSTM, < 5MB) +- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB) """ from models.classifier import CaptchaClassifier from models.lite_crnn import LiteCRNN from models.threed_cnn import ThreeDCNN +from models.regression_cnn import RegressionCNN __all__ = [ "CaptchaClassifier", "LiteCRNN", "ThreeDCNN", 
+ "RegressionCNN", ] diff --git a/models/regression_cnn.py b/models/regression_cnn.py new file mode 100644 index 0000000..dff7034 --- /dev/null +++ b/models/regression_cnn.py @@ -0,0 +1,86 @@ +""" +回归 CNN 模型 + +3d_rotate 和 3d_slider 共用的回归模型。 +输出 sigmoid 归一化到 [0,1],推理时按 label_range 缩放回原始范围。 + +架构: + Conv(1→32) + BN + ReLU + Pool + Conv(32→64) + BN + ReLU + Pool + Conv(64→128) + BN + ReLU + Pool + Conv(128→128) + BN + ReLU + Pool + AdaptiveAvgPool2d(1) → FC(128→64) → ReLU → Dropout(0.2) → FC(64→1) → Sigmoid + +约 250K 参数,~1MB。 +""" + +import torch +import torch.nn as nn + + +class RegressionCNN(nn.Module): + """ + 轻量回归 CNN,用于 3d_rotate (角度) 和 3d_slider (偏移) 预测。 + + 输出 [0, 1] 范围的 sigmoid 值,需要按 label_range 缩放到实际范围。 + """ + + def __init__(self, img_h: int = 80, img_w: int = 80): + """ + Args: + img_h: 输入图片高度 + img_w: 输入图片宽度 + """ + super().__init__() + self.img_h = img_h + self.img_w = img_w + + self.features = nn.Sequential( + # block 1: 1 → 32, H/2, W/2 + nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 2: 32 → 64, H/4, W/4 + nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 3: 64 → 128, H/8, W/8 + nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + # block 4: 128 → 128, H/16, W/16 + nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + ) + + self.pool = nn.AdaptiveAvgPool2d(1) + + self.regressor = nn.Sequential( + nn.Linear(128, 64), + nn.ReLU(inplace=True), + nn.Dropout(0.2), + nn.Linear(64, 1), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch, 1, H, W) 灰度图 + + Returns: + output: (batch, 1) sigmoid 输出 [0, 1] + """ + feat = self.features(x) + feat = self.pool(feat) # (B, 128, 1, 1) + feat 
= feat.flatten(1) # (B, 128) + out = self.regressor(feat) # (B, 1) + return out diff --git a/training/__init__.py b/training/__init__.py index 8d05557..232f3b0 100644 --- a/training/__init__.py +++ b/training/__init__.py @@ -1,10 +1,13 @@ """ 训练脚本包 -- dataset.py: CRNNDataset / CaptchaDataset 通用数据集类 -- train_utils.py: CTC 训练通用逻辑 (train_ctc_model) -- train_normal.py: 训练普通字符识别 (LiteCRNN - normal) -- train_math.py: 训练算式识别 (LiteCRNN - math) -- train_3d.py: 训练 3D 立体识别 (ThreeDCNN) -- train_classifier.py: 训练调度分类器 (CaptchaClassifier) +- dataset.py: CRNNDataset / CaptchaDataset / RegressionDataset 通用数据集类 +- train_utils.py: CTC 训练通用逻辑 (train_ctc_model) +- train_regression_utils.py: 回归训练通用逻辑 (train_regression_model) +- train_normal.py: 训练普通字符识别 (LiteCRNN - normal) +- train_math.py: 训练算式识别 (LiteCRNN - math) +- train_3d_text.py: 训练 3D 立体文字识别 (ThreeDCNN) +- train_3d_rotate.py: 训练 3D 旋转回归 (RegressionCNN) +- train_3d_slider.py: 训练 3D 滑块回归 (RegressionCNN) +- train_classifier.py: 训练调度分类器 (CaptchaClassifier) """ diff --git a/training/dataset.py b/training/dataset.py index ef44790..becf24e 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -1,16 +1,19 @@ """ 通用 Dataset 类 -提供两种数据集: -- CaptchaDataset: 用于分类器训练 (图片 → 类别标签) -- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码) +提供三种数据集: +- CaptchaDataset: 用于分类器训练 (图片 → 类别标签) +- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码) +- RegressionDataset: 用于回归模型训练 (图片 → 数值标签 [0,1]) 文件名格式约定: {label}_{任意}.png - 分类器: label 可为任意字符,所在子目录名即为类别 - 识别器: label 即标注内容 (如 "A3B8" 或 "3+8") + - 回归器: label 为数值 (如 "135" 或 "87") """ import os +import warnings from pathlib import Path from PIL import Image @@ -98,7 +101,15 @@ class CRNNDataset(Dataset): img = self.transform(img) # 编码标签为整数序列 - target = [self.char_to_idx[c] for c in label if c in self.char_to_idx] + target = [] + for c in label: + if c in self.char_to_idx: + target.append(self.char_to_idx[c]) + else: + warnings.warn( + f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})", + stacklevel=2, + ) return img, 
target, label @staticmethod @@ -119,7 +130,7 @@ class CaptchaDataset(Dataset): """ 分类器训练数据集。 - 每个子目录名为类别名 (如 "normal", "math", "3d"), + 每个子目录名为类别名 (如 "normal", "math", "3d_text"), 目录内所有 .png 文件属于该类。 """ @@ -157,3 +168,59 @@ class CaptchaDataset(Dataset): if self.transform: img = self.transform(img) return img, label + + +# ============================================================ +# 回归模型用数据集 +# ============================================================ +class RegressionDataset(Dataset): + """ + 回归模型数据集 (3d_rotate / 3d_slider)。 + + 从目录中读取 {value}_{xxx}.png 文件, + 将 value 解析为浮点数并归一化到 [0, 1]。 + """ + + def __init__( + self, + dirs: list[str | Path], + label_range: tuple[float, float], + transform: transforms.Compose | None = None, + ): + """ + Args: + dirs: 数据目录列表 + label_range: (min_val, max_val) 标签原始范围 + transform: 图片预处理/增强 + """ + self.label_range = label_range + self.lo, self.hi = label_range + self.transform = transform + + self.samples: list[tuple[str, float]] = [] # (文件路径, 归一化标签) + for d in dirs: + d = Path(d) + if not d.exists(): + continue + for f in sorted(d.glob("*.png")): + raw_label = f.stem.rsplit("_", 1)[0] + try: + value = float(raw_label) + except ValueError: + continue + # 归一化到 [0, 1] + norm = (value - self.lo) / max(self.hi - self.lo, 1e-6) + norm = max(0.0, min(1.0, norm)) + self.samples.append((str(f), norm)) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int): + import torch + path, label = self.samples[idx] + img = Image.open(path).convert("RGB") + if self.transform: + img = self.transform(img) + return img, torch.tensor([label], dtype=torch.float32) + diff --git a/training/train_3d_rotate.py b/training/train_3d_rotate.py new file mode 100644 index 0000000..665959c --- /dev/null +++ b/training/train_3d_rotate.py @@ -0,0 +1,38 @@ +""" +训练 3D 旋转验证码回归模型 (RegressionCNN) + +用法: python -m training.train_3d_rotate +""" + +from config import ( + IMAGE_SIZE, + SYNTHETIC_3D_ROTATE_DIR, + REAL_3D_ROTATE_DIR, +) 
"""
Train the 3D slider captcha regression model (RegressionCNN).

Usage: python -m training.train_3d_slider
"""

from config import (
    IMAGE_SIZE,
    SYNTHETIC_3D_SLIDER_DIR,
    REAL_3D_SLIDER_DIR,
)
from generators.threed_slider_gen import ThreeDSliderGenerator
from models.regression_cnn import RegressionCNN
from training.train_regression_utils import train_regression_model


def main():
    """Entry point: announce the run, build the model, delegate to the shared trainer."""
    img_h, img_w = IMAGE_SIZE["3d_slider"]

    banner = "=" * 60
    print(banner)
    print("训练 3D 滑块验证码回归模型 (RegressionCNN)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(" 任务: 预测滑块偏移 x 坐标")
    print(banner)

    net = RegressionCNN(img_h=img_h, img_w=img_w)
    train_regression_model(
        model_name="threed_slider",
        model=net,
        synthetic_dir=SYNTHETIC_3D_SLIDER_DIR,
        real_dir=REAL_3D_SLIDER_DIR,
        generator_cls=ThreeDSliderGenerator,
        config_key="3d_slider",
    )


if __name__ == "__main__":
    main()
-训练 3D 立体验证码识别模型 (ThreeDCNN) +训练 3D 立体文字验证码识别模型 (ThreeDCNN) -用法: python -m training.train_3d +用法: python -m training.train_3d_text """ from config import ( THREED_CHARS, IMAGE_SIZE, - SYNTHETIC_3D_DIR, - REAL_3D_DIR, + SYNTHETIC_3D_TEXT_DIR, + REAL_3D_TEXT_DIR, ) from generators.threed_gen import ThreeDCaptchaGenerator from models.threed_cnn import ThreeDCNN @@ -16,23 +16,23 @@ from training.train_utils import train_ctc_model def main(): - img_h, img_w = IMAGE_SIZE["3d"] + img_h, img_w = IMAGE_SIZE["3d_text"] model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w) print("=" * 60) - print("训练 3D 立体验证码识别模型 (ThreeDCNN)") + print("训练 3D 立体文字验证码识别模型 (ThreeDCNN)") print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)") print(f" 输入尺寸: {img_h}×{img_w}") print("=" * 60) train_ctc_model( - model_name="threed", + model_name="threed_text", model=model, chars=THREED_CHARS, - synthetic_dir=SYNTHETIC_3D_DIR, - real_dir=REAL_3D_DIR, + synthetic_dir=SYNTHETIC_3D_TEXT_DIR, + real_dir=REAL_3D_TEXT_DIR, generator_cls=ThreeDCaptchaGenerator, - config_key="threed", + config_key="3d_text", ) diff --git a/training/train_classifier.py b/training/train_classifier.py index 49bec3e..ad1c277 100644 --- a/training/train_classifier.py +++ b/training/train_classifier.py @@ -1,16 +1,18 @@ """ 训练调度分类器 (CaptchaClassifier) -从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d。 +从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d_text / 3d_rotate / 3d_slider。 数据来源: data/classifier/ 目录 (按类型子目录组织) 用法: python -m training.train_classifier """ import os +import random import shutil from pathlib import Path +import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader, random_split @@ -24,15 +26,20 @@ from config import ( CLASSIFIER_DIR, SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, - SYNTHETIC_3D_DIR, + SYNTHETIC_3D_TEXT_DIR, + SYNTHETIC_3D_ROTATE_DIR, + SYNTHETIC_3D_SLIDER_DIR, CHECKPOINTS_DIR, ONNX_DIR, ONNX_CONFIG, + RANDOM_SEED, get_device, ) from generators.normal_gen import 
NormalCaptchaGenerator from generators.math_gen import MathCaptchaGenerator from generators.threed_gen import ThreeDCaptchaGenerator +from generators.threed_rotate_gen import ThreeDRotateGenerator +from generators.threed_slider_gen import ThreeDSliderGenerator from models.classifier import CaptchaClassifier from training.dataset import CaptchaDataset, build_train_transform, build_val_transform @@ -52,7 +59,9 @@ def _prepare_classifier_data(): type_info = [ ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator), ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator), - ("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator), + ("3d_text", SYNTHETIC_3D_TEXT_DIR, ThreeDCaptchaGenerator), + ("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator), + ("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator), ] for cls_name, syn_dir, gen_cls in type_info: @@ -95,6 +104,13 @@ def main(): img_h, img_w = IMAGE_SIZE["classifier"] device = get_device() + # 设置随机种子 + random.seed(RANDOM_SEED) + np.random.seed(RANDOM_SEED) + torch.manual_seed(RANDOM_SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(RANDOM_SEED) + print("=" * 60) print("训练调度分类器 (CaptchaClassifier)") print(f" 类别: {CAPTCHA_TYPES}") @@ -128,11 +144,11 @@ def main(): train_loader = DataLoader( train_ds, batch_size=cfg["batch_size"], shuffle=True, - num_workers=2, pin_memory=True, + num_workers=0, pin_memory=True, ) val_loader = DataLoader( val_ds_clean, batch_size=cfg["batch_size"], shuffle=False, - num_workers=2, pin_memory=True, + num_workers=0, pin_memory=True, ) print(f"[数据] 训练: {train_size} 验证: {val_size}") diff --git a/training/train_regression_utils.py b/training/train_regression_utils.py new file mode 100644 index 0000000..d0a2dd8 --- /dev/null +++ b/training/train_regression_utils.py @@ -0,0 +1,264 @@ +""" +回归训练通用逻辑 + +提供 train_regression_model() 函数,被 train_3d_rotate / train_3d_slider 共用。 +职责: +1. 检查合成数据,不存在则自动调用生成器 +2. 构建 RegressionDataset / DataLoader(含真实数据混合) +3. 
Regression training loop + cosine LR scheduler
4. Log per epoch: loss, MAE, tolerance accuracy
5. Save the best checkpoint under checkpoints/
6. Export ONNX when training finishes
"""

import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    TRAIN_CONFIG,
    IMAGE_SIZE,
    REGRESSION_RANGE,
    RANDOM_SEED,
    get_device,
)
from training.dataset import RegressionDataset, build_train_transform, build_val_transform


def _set_seed(seed: int = RANDOM_SEED):
    """Seed python / numpy / torch RNGs so training runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _circular_smooth_l1(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Smooth-L1 loss over the circular distance, for angle regression
    (handles the 0°/360° wrap-around).  Both tensors are in [0, 1].
    """
    diff = torch.abs(pred - target)
    # Circular distance: min(|d|, 1 - |d|) in the normalized unit circle.
    diff = torch.min(diff, 1.0 - diff)
    # Smooth-L1 with beta = 1/360 (≈ 1° in normalized units): quadratic
    # below beta, linear above.
    return torch.where(
        diff < 1.0 / 360.0,
        0.5 * diff * diff / (1.0 / 360.0),
        diff - 0.5 * (1.0 / 360.0),
    ).mean()


def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
    """Circular MAE in normalized [0, 1] space.

    NOTE(review): currently unused — the validation loop below recomputes
    the same quantity inline in real units.  Wire it in or drop it.
    """
    diff = np.abs(pred - target)
    diff = np.minimum(diff, 1.0 - diff)
    return float(np.mean(diff))


def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
    """Export `model` to ONNX_DIR/{model_name}.onnx.

    Note: `model.cpu()` mutates the module in place — the model lives on
    CPU afterwards, which is fine since this runs at the end of training.
    """
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        # Regression output is (batch, 1): only the batch dim is dynamic.
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")


def train_regression_model(
    model_name: str,
    model: nn.Module,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """
    Shared regression training flow for the 3d_rotate / 3d_slider experts.

    Args:
        model_name: file stem for checkpoint/ONNX output (threed_rotate / threed_slider)
        model: PyTorch model instance (RegressionCNN)
        synthetic_dir: synthetic-data directory (auto-filled when short)
        real_dir: real-sample directory, mixed in when present
        generator_cls: generator class used to synthesize missing data
        config_key: key into TRAIN_CONFIG / REGRESSION_RANGE ("3d_rotate" / "3d_slider")

    Returns:
        Best tolerance accuracy reached on the validation split.
    """
    cfg = TRAIN_CONFIG[config_key]
    img_h, img_w = IMAGE_SIZE[config_key]
    label_range = REGRESSION_RANGE[config_key]
    lo, hi = label_range
    is_circular = config_key == "3d_rotate"  # angles wrap around 0°/360°
    device = get_device()

    # Tolerance used for the "close enough" accuracy metric.
    if config_key == "3d_rotate":
        tolerance = 5.0  # ±5°
    else:
        tolerance = 3.0  # ±3 px

    _set_seed()

    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
        gen = generator_cls()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[数据] 合成数据已就绪: {len(existing)} 张")

    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    if real_path.exists() and list(real_path.glob("*.png")):
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(list(real_path.glob('*.png')))} 张")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # Validation reuses the split indices but with an augmentation-free
    # transform: build a second dataset over the same files and narrow its
    # sample list to the validation indices.
    val_ds_clean = RegressionDataset(
        dirs=data_dirs, label_range=label_range, transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=0, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=0, pin_memory=True,
    )

    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])

    # Angle regression needs the wrap-aware loss; the slider offset is an
    # ordinary bounded scalar.
    if is_circular:
        loss_fn = _circular_smooth_l1
    else:
        loss_fn = nn.SmoothL1Loss()

    best_mae = float("inf")
    best_tol_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"

    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets in pbar:
            images = images.to(device)
            targets = targets.to(device)

            preds = model(images)
            loss = loss_fn(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to keep early epochs stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(device)
                preds = model(images)
                all_preds.append(preds.cpu().numpy())
                all_targets.append(targets.numpy())

        all_preds = np.concatenate(all_preds, axis=0).flatten()
        all_targets = np.concatenate(all_targets, axis=0).flatten()

        # Scale back to the raw label range to report MAE in real units.
        preds_real = all_preds * (hi - lo) + lo
        targets_real = all_targets * (hi - lo) + lo

        if is_circular:
            # Circular MAE: wrap distance over the label span.
            # NOTE(review): the wrap period is taken as (hi - lo); if
            # REGRESSION_RANGE["3d_rotate"] is (0, 359) the true period is
            # 360 — confirm the config stores (0, 360).
            diff = np.abs(preds_real - targets_real)
            diff = np.minimum(diff, (hi - lo) - diff)
            mae = float(np.mean(diff))
            within_tol = diff <= tolerance
        else:
            mae = float(np.mean(np.abs(preds_real - targets_real)))
            within_tol = np.abs(preds_real - targets_real) <= tolerance

        tol_acc = float(np.mean(within_tol))
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"MAE={mae:.2f} "
            f"tol_acc(±{tolerance:.0f})={tol_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model (ranked by tolerance accuracy) ----
        # >= keeps the latest checkpoint on ties; later epochs run at a
        # lower LR and tend to generalize at least as well.
        if tol_acc >= best_tol_acc:
            best_tol_acc = tol_acc
            best_mae = mae  # MAE at the best-tol_acc epoch, not the global minimum
            torch.save({
                "model_state_dict": model.state_dict(),
                "label_range": label_range,
                "best_mae": best_mae,
                "best_tol_acc": best_tol_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")

    # ---- 7. Export ONNX (from the best checkpoint, not the last epoch) ----
    print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w)

    return best_tol_acc
get_device() + # 设置随机种子 + _set_seed() + # ---- 1. 检查 / 生成合成数据 ---- syn_path = Path(synthetic_dir) existing = list(syn_path.glob("*.png")) @@ -139,11 +154,11 @@ def train_ctc_model( train_loader = DataLoader( train_ds, batch_size=cfg["batch_size"], shuffle=True, - num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True, + num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True, ) val_loader = DataLoader( val_ds_clean, batch_size=cfg["batch_size"], shuffle=False, - num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True, + num_workers=0, collate_fn=CRNNDataset.collate_fn, pin_memory=True, ) print(f"[数据] 训练: {train_size} 验证: {val_size}") @@ -166,12 +181,11 @@ def train_ctc_model( pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False) for images, targets, target_lengths, _ in pbar: images = images.to(device) - targets = targets.to(device) - target_lengths = target_lengths.to(device) logits = model(images) # (T, B, C) T, B, C = logits.shape - input_lengths = torch.full((B,), T, dtype=torch.int32, device=device) + # cuDNN CTC requires targets/lengths on CPU + input_lengths = torch.full((B,), T, dtype=torch.int32) log_probs = logits.log_softmax(2) loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)