Initialize repository
.gitignore (vendored, new file, 14 lines)
@@ -0,0 +1,14 @@
.venv/
__pycache__/
*.py[cod]

.idea/
.claude/

data/synthetic/
data/classifier/

checkpoints/
onnx_models/

.DS_Store
AGENTS.md (new file, 28 lines)
@@ -0,0 +1,28 @@
# Repository Guidelines

## Project Structure & Module Organization
Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas, `models/` contains the classifier and expert OCR models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.

## Build, Test, and Development Commands
Use `uv` for environment and dependency management.

- `uv sync` installs the base runtime dependencies from `pyproject.toml`.
- `uv sync --extra server` installs HTTP service dependencies.
- `uv run captcha generate --type normal --num 1000` generates synthetic training data.
- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d -> classifier`.
- `uv run captcha export --all` exports all trained models to ONNX.
- `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory.
- `uv run captcha serve --port 8080` starts the optional HTTP API once `server.py` is implemented.

## Coding Style & Naming Conventions
Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d`, and `classifier`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding unless the pipeline is intentionally redesigned. `normal` uses the locally configured charset and currently includes easily confused characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.

## Data & Testing Guidelines
Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.

## Commit & Pull Request Guidelines
Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.

## Documentation Sync
Do not commit large generated datasets unless explicitly required. When a change affects project structure, commands, config, architecture, artifact paths, supported captcha types, or workflow rules, update `AGENTS.md` and `CLAUDE.md` in the same patch.
CLAUDE.md (new file, 391 lines)
@@ -0,0 +1,391 @@
# CLAUDE.md - Multi-Model Captcha Recognition System (CaptchaBreaker)

## Project Overview

Build a local captcha recognition system with a two-tier **dispatcher model + multiple expert models** architecture. The dispatcher classifies the captcha type; the expert models perform the actual recognition. All models are lightweight by design and are ultimately exported to ONNX for deployment.

## Tech Stack

- Python 3.10+
- uv (package management; dependencies defined in pyproject.toml)
- PyTorch 2.x (training)
- ONNX + ONNXRuntime (inference deployment)
- Pillow (image processing)
- FastAPI (optional, provides the HTTP recognition service)

## Project Structure

```
captcha-breaker/
├── CLAUDE.md
├── pyproject.toml           # Project config and dependencies (managed by uv)
├── config.py                # Global config (charsets, image sizes, paths, etc.)
├── data/
│   ├── synthetic/           # Synthetic training data (auto-generated, not in git)
│   │   ├── normal/          # Plain character captchas
│   │   ├── math/            # Arithmetic captchas
│   │   └── 3d/              # 3D-style captchas
│   ├── real/                # Real captcha samples (manually labeled)
│   │   ├── normal/
│   │   ├── math/
│   │   └── 3d/
│   └── classifier/          # Dispatcher training data (mixed types)
├── generators/
│   ├── __init__.py
│   ├── base.py              # Generator base class
│   ├── normal_gen.py        # Plain character captcha generator
│   ├── math_gen.py          # Arithmetic captcha generator (e.g. 3+8=?)
│   └── threed_gen.py        # 3D captcha generator
├── models/
│   ├── __init__.py
│   ├── lite_crnn.py         # Lightweight CRNN (plain characters and arithmetic)
│   ├── classifier.py        # Dispatcher classification model
│   └── threed_cnn.py        # Dedicated 3D captcha model (deeper CNN)
├── training/
│   ├── __init__.py
│   ├── train_classifier.py  # Train the dispatcher model
│   ├── train_normal.py      # Train plain character recognition
│   ├── train_math.py        # Train arithmetic recognition
│   ├── train_3d.py          # Train 3D recognition
│   └── dataset.py           # Shared Dataset class
├── inference/
│   ├── __init__.py
│   ├── pipeline.py          # Core inference pipeline (dispatch + recognize)
│   ├── export_onnx.py       # PyTorch → ONNX export script
│   └── math_eval.py         # Arithmetic evaluation module
├── checkpoints/             # Model files produced by training
│   ├── classifier.pth
│   ├── normal.pth
│   ├── math.pth
│   └── threed.pth
├── onnx_models/             # Exported ONNX models
│   ├── classifier.onnx
│   ├── normal.onnx
│   ├── math.onnx
│   └── threed.onnx
├── server.py                # FastAPI inference service (optional)
├── cli.py                   # Command-line entrypoint
└── tests/
    ├── test_generators.py
    ├── test_models.py
    └── test_pipeline.py
```

## Core Architecture

### Inference Pipeline

```
input image → preprocess → dispatcher classifier → route to expert → post-process → output
                                    │
                           ┌────────┼────────┐
                           ▼        ▼        ▼
                        normal    math       3d
                        (CRNN)   (CRNN)    (CNN)
                           │        │        │
                           ▼        ▼        ▼
                        "A3B8"  "3+8=?"→11  "X9K2"
```

### Dispatcher Classifier (classifier.py)

- Task: image classification — decide which captcha type an image belongs to
- Architecture: lightweight CNN, 3-4 conv layers + global average pooling + fully connected
- Input: grayscale image 1x64x128
- Output: softmax probability distribution; number of classes = number of captcha types
- Requirements: 99%+ accuracy, inference < 5ms
- Model size target: < 500KB

```python
class CaptchaClassifier(nn.Module):
    """
    Lightweight classifier; a few conv layers suffice to tell captcha types apart.
    The types differ strongly in appearance (presence of operators, 3D effects, etc.),
    so classification is easy.
    """
    def __init__(self, num_types=3):
        # 4 conv layers + GAP + FC
        # Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64)
        # AdaptiveAvgPool2d(1) -> Linear(64, num_types)
        pass
```
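As a sanity check on the < 500KB size target, the parameter count of the sketched conv stack can be estimated with back-of-the-envelope arithmetic. This assumes 3x3 kernels with bias, which the comment above does not pin down:

```python
# Rough parameter count for the sketched classifier.
# Assumption: 3x3 kernels with bias (the spec above leaves kernel size open).
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    return c_in * c_out * k * k + c_out  # weights + bias

total = (
    conv_params(1, 16)
    + conv_params(16, 32)
    + conv_params(32, 64)
    + conv_params(64, 64)
    + 64 * 3 + 3  # Linear(64, num_types=3)
)
print(total, total * 4)  # params, bytes at float32
```

At float32 this comes out around 236 KB, comfortably inside the 500 KB budget, so no quantization is needed to hit the target.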

### Plain Character Expert (lite_crnn.py - normal mode)

- Task: recognize colored character captchas (mixed digits + letters)
- Architecture: CRNN + CTC
- Charset: `0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ` (36 characters, including easily confused ones, trained per local config)
- Input: grayscale image 1x40x120
- Output: character sequence via greedy CTC decoding
- Captcha traits: light background, colored characters, light noise lines, slanted characters
- Model size target: < 2MB

### Arithmetic Expert (lite_crnn.py - math mode)

- Task: recognize arithmetic captchas and compute the result
- Architecture: reuse CRNN + CTC with a different charset
- Charset: `0123456789+-×÷=?` (digits + operators)
- Input: grayscale image 1x40x160 (arithmetic captchas are usually wider)
- Output: the recognized expression string, then handed to math_eval.py for evaluation
- Two steps: (1) OCR → "3+8=?" (2) regex parse and compute → 11
- Model size target: < 2MB

```python
# math_eval.py core logic
def eval_captcha_math(expr: str) -> str:
    """
    Parse and evaluate a captcha arithmetic expression.
    Supports: add, subtract, multiply, divide; one- to two-digit operands.
    Input:  "3+8=?" or "12×3=?" or "15-7=?"
    Output: "11" or "36" or "8"
    Extract the digits and operator with a regex; never use eval().
    """
    pass
```
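A minimal sketch of the two-step evaluation described above — regex parsing, no `eval()`. The spec leaves the actual function body open, so this is illustrative only:

```python
import re

def eval_captcha_math(expr: str) -> str:
    """Parse 'A op B' (optionally followed by '=?') and return the result as a string."""
    m = re.match(r"^\s*(\d{1,2})\s*([+\-×÷*/])\s*(\d{1,2})\s*(?:=\s*\??)?\s*$", expr)
    if not m:
        raise ValueError(f"unparsable expression: {expr!r}")
    a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
    if op == "+":
        return str(a + b)
    if op == "-":
        return str(a - b)
    if op in "×*":
        return str(a * b)
    return str(a // b)  # the generator only emits exact divisions

print(eval_captcha_math("3+8=?"))   # → 11
print(eval_captcha_math("12×3=?"))  # → 36
```

Tolerating both Unicode (`×÷`) and ASCII (`*/`) operators costs nothing here and guards against charset drift between the OCR model and the evaluator.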

### 3D Expert (threed_cnn.py)

- Task: recognize captchas with 3D perspective/shadow effects
- Architecture: deeper CNN + CRNN, or a ResNet-lite backbone
- Input: grayscale image 1x60x160
- Needs stronger feature extraction to handle perspective distortion and shadows
- Model size target: < 5MB

## Data Generator Spec

### Base Class (base.py)

```python
class BaseCaptchaGenerator:
    def generate(self, text=None) -> tuple[Image.Image, str]:
        """Generate one captcha; return (image, label text)."""
        raise NotImplementedError

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Batch generation; filename format: {label}_{index:06d}.png"""
        pass
```
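The `{label}_{index:06d}.png` convention (and the `{label}_{anything}.png` convention for real samples) means a dataset loader can recover the label by splitting on the last underscore. A sketch — the helper name is hypothetical, not part of the spec:

```python
from pathlib import Path

def label_from_filename(path: str) -> str:
    """Recover the label from '{label}_{index:06d}.png' or '{label}_{anything}.png'."""
    stem = Path(path).stem               # drop the .png suffix
    label, _, _ = stem.rpartition("_")   # split on the LAST underscore only
    return label

print(label_from_filename("data/synthetic/normal/A3B8_000042.png"))  # → A3B8
print(label_from_filename("data/real/math/3+8_site1.png"))           # → 3+8
```

Splitting on the last underscore (rather than the first) keeps math labels such as `3+8` intact even if a future charset were to include `_`-free multi-part labels; none of the configured charsets contain an underscore.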

### Plain Character Generator (normal_gen.py)

Target style to imitate:
- Light random background (RGB channels 230-255)
- Random color per character (dark: blue/red/green/purple/brown, etc.)
- Character count: 4-5
- Random character rotation of ±15°
- 2-5 light noise lines
- A small amount of noise dots
- Optional slight Gaussian blur

### Arithmetic Generator (math_gen.py)

- Generates images of the form `A op B = ?`
- A, B range: integers 1-30
- op: +, -, × (division only generated when it divides evenly)
- Ensure the result is a non-negative integer
- Label format: `3+8` (store the expression itself, not the result)
- Visual style: matches the target arithmetic captchas

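The "division only when it divides evenly" rule is easiest to enforce by sampling the quotient first and deriving the dividend, rather than rejection sampling. A sketch under the operand range above — `math_gen.py`'s actual sampling code is not shown in this document:

```python
import random

def random_division(rng: random.Random, lo: int = 1, hi: int = 30) -> tuple[int, str, int]:
    """Return (a, '÷', b) such that a ÷ b is an exact integer and a stays in range."""
    b = rng.randint(lo, hi)
    q = rng.randint(lo, max(lo, hi // b))  # pick the quotient first...
    a = b * q                              # ...so a is divisible by b by construction
    return a, "÷", b

rng = random.Random(42)
a, op, b = random_division(rng)
assert a % b == 0  # always exact, no retry loop needed
```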
### 3D Generator (threed_gen.py)

- Use Pillow affine transforms to simulate 3D perspective
- Add shadow effects
- Characters have depth and slant
- Label: the plain character content

## Training Spec

### Common Training Config

```python
# Defined in config.py
TRAIN_CONFIG = {
    'classifier': {
        'epochs': 30,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 30000,  # 10000 per class
    },
    'normal': {
        'epochs': 50,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 60000,
        'loss': 'CTCLoss',
    },
    'math': {
        'epochs': 50,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 60000,
        'loss': 'CTCLoss',
    },
    'threed': {
        'epochs': 80,
        'batch_size': 64,
        'lr': 5e-4,
        'scheduler': 'cosine',
        'synthetic_samples': 80000,
        'loss': 'CTCLoss',
    },
}
```

### Training Script Requirements

Every training script must:
1. Check whether synthetic data exists; if not, invoke the generator automatically
2. Support mixing in real data (if data/real/{type}/ has files)
3. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
4. Log training progress: epoch, loss, sequence accuracy, character-level accuracy
5. Save the best model to checkpoints/
6. Automatically export ONNX to onnx_models/ when training finishes

### Data Augmentation Strategy

```python
# Training-time augmentation
train_augment = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((H, W)),
    transforms.RandomAffine(degrees=8, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.GaussianBlur(3, sigma=(0.1, 0.5)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.RandomErasing(p=0.15, scale=(0.01, 0.05)),
])
```

## Inference Pipeline (pipeline.py)

```python
class CaptchaPipeline:
    """
    Core inference pipeline.
    Loads the dispatcher and all expert models (ONNX format).
    Provides a unified solve(image) interface.
    """

    def __init__(self, models_dir='onnx_models/'):
        """
        Load all ONNX models at init.
        Uses onnxruntime.InferenceSession.
        """
        pass

    def preprocess(self, image: Image.Image, target_size: tuple) -> np.ndarray:
        """Image preprocessing: resize, grayscale, normalize, convert to numpy."""
        pass

    def classify(self, image: Image.Image) -> str:
        """Dispatch classification; returns the type name: 'normal' / 'math' / '3d'."""
        pass

    def solve(self, image) -> str:
        """
        Full recognition flow:
        1. Classify the captcha type
        2. Route to the matching expert model
        3. Post-process (arithmetic captchas need evaluation)
        4. Return the final answer string

        image: PIL.Image, file path, or bytes
        """
        pass
```
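The routing logic of `solve()` can be sketched with plain callables standing in for the real ONNX sessions. This is a structural sketch only — the result-dict keys (`type`, `raw`, `result`, `time_ms`) mirror what cli.py reads, but the real method signatures are defined by the stubs above:

```python
import time

def solve(image, classify, experts, postprocess, captcha_type=None):
    """Stand-in for CaptchaPipeline.solve(): dispatch an image to the right expert."""
    start = time.perf_counter()
    ctype = captcha_type or classify(image)        # --type on the CLI skips this step
    raw = experts[ctype](image)                    # expert OCR output, e.g. "3+8=?"
    result = postprocess.get(ctype, lambda s: s)(raw)  # only math needs post-processing
    return {
        "type": ctype,
        "raw": raw,
        "result": result,
        "time_ms": (time.perf_counter() - start) * 1000,
    }

# Tiny demo with stub experts
out = solve(
    "img",
    classify=lambda img: "math",
    experts={"math": lambda img: "3+8=?"},
    postprocess={"math": lambda s: "11"},  # stands in for math_eval
)
print(out["result"])  # → 11
```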
## ONNX Export (export_onnx.py)

```python
def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'):
    """
    Export a single model to ONNX.
    - Use opset_version=18
    - Enable dynamic_axes to support dynamic batch size
    - After export, verify inference consistency with onnxruntime
    - Optional: simplify the ONNX model (onnxsim)
    """
    pass


def export_all():
    """Export all four models in order: classifier, normal, math, threed."""
    pass
```

## CLI Entrypoint (cli.py)

```bash
# Install dependencies
uv sync                  # core dependencies
uv sync --extra server   # plus HTTP service dependencies

# Generate training data
uv run python cli.py generate --type normal --num 60000
uv run python cli.py generate --type math --num 60000
uv run python cli.py generate --type 3d --num 80000
uv run python cli.py generate --type classifier --num 30000

# Train models
uv run python cli.py train --model classifier
uv run python cli.py train --model normal
uv run python cli.py train --model math
uv run python cli.py train --model 3d
uv run python cli.py train --all          # train everything in dependency order

# Export ONNX
uv run python cli.py export --all

# Inference
uv run python cli.py predict image.png                 # auto classify + recognize
uv run python cli.py predict image.png --type normal   # skip classification
uv run python cli.py predict-dir ./test_images/        # batch recognition

# Start the HTTP service (install the server extra first)
uv run python cli.py serve --port 8080
```

## HTTP Service (server.py, optional)

```python
# FastAPI service exposing a REST API
# POST /solve - upload an image, returns the recognition result
# Request: multipart/form-data with field name "image"
# Response: {"type": "normal", "result": "A3B8", "confidence": 0.95, "time_ms": 45}
```
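The response contract above can be pinned down with a quick round-trip through `json`, so tests can assert the exact field names clients will see. Values here are illustrative, taken from the sample response:

```python
import json

response = {
    "type": "normal",    # dispatcher's predicted captcha type
    "result": "A3B8",    # final answer (post-processed for math captchas)
    "confidence": 0.95,  # classifier softmax confidence
    "time_ms": 45,       # end-to-end latency in milliseconds
}

payload = json.dumps(response)
assert set(json.loads(payload)) == {"type", "result", "confidence", "time_ms"}
```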

## Key Constraints and Notes

1. **Train all models in float32 and export ONNX without quantization** — get accuracy right first
2. **Use greedy decoding for CTC everywhere**; no beam search needed — greedy is sufficient for captchas
3. **Charsets are defined centrally in config.py**: `normal` currently keeps easily confused characters; `3d` keeps using the de-confused charset
4. **Arithmetic recognition is two-step**: OCR the string first, then compute by rule; never have the model output the numeric value directly
5. **Generator random seeds**: set a seed when generating data so results are reproducible
6. **Real data filename format**: `{label}_{anything}.png`, where the label part is the annotation
7. **Model checkpoint format**: the PyTorch checkpoint contains model_state_dict, chars, best_acc, epoch
8. **No GPU-only features**; training and inference must also work on CPU (just slower)
9. **Type extension**: to add a captcha type, just (1) add a generator, (2) add an expert model, (3) add a class to the dispatcher and retrain it
10. **Documentation sync**: any change to project structure, config, or architecture must update the corresponding content in CLAUDE.md, keeping docs and code consistent
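Greedy CTC decoding (constraint 2) is just argmax per timestep, collapse consecutive repeats, drop blanks. A minimal sketch over a plain list of per-timestep argmax indices — blank is assumed at index 0 here, but the actual blank index depends on how config.py orders the charset:

```python
def ctc_greedy_decode(argmax_steps, chars, blank=0):
    """Collapse consecutive repeats, then drop blanks (greedy CTC decoding).

    argmax_steps: per-timestep argmax indices from the model output.
    Index 0 is assumed to be the CTC blank, so index i maps to chars[i - 1].
    """
    out = []
    prev = None
    for idx in argmax_steps:
        if idx != prev and idx != blank:  # skip repeats and blanks
            out.append(chars[idx - 1])
        prev = idx
    return "".join(out)

# "A", "A", blank, "A", "3", "3" collapses to "AA3":
# the blank separates the repeated A's, the trailing 3's merge.
print(ctc_greedy_decode([1, 1, 0, 1, 2, 2], chars="A3"))  # → AA3
```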

## Target Metrics

| Model | Accuracy target | Inference latency | Model size |
|-------|-----------------|-------------------|------------|
| Dispatcher classifier | > 99% | < 5ms | < 500KB |
| Plain characters | > 95% | < 30ms | < 2MB |
| Arithmetic | > 93% | < 30ms | < 2MB |
| 3D | > 85% | < 50ms | < 5MB |
| Full pipeline | - | < 80ms | < 10MB total |

## Development Order

1. Implement config.py and generators/ first
2. Implement all model definitions in models/
3. Implement the shared dataset class in training/dataset.py
4. Train in order: normal → math → 3d → classifier
5. Implement inference/pipeline.py and export_onnx.py
6. Implement the unified cli.py entrypoint
7. Optional: the server.py HTTP service
8. Write tests/
cli.py (new file, 228 lines)
@@ -0,0 +1,228 @@
"""
CaptchaBreaker command-line entrypoint

Usage:
    python cli.py generate --type normal --num 60000
    python cli.py train --model normal
    python cli.py train --all
    python cli.py export --all
    python cli.py predict image.png
    python cli.py predict image.png --type normal
    python cli.py predict-dir ./test_images/
    python cli.py serve --port 8080
"""

import argparse
import sys
from pathlib import Path


def cmd_generate(args):
    """Generate training data."""
    from config import (
        SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
        CLASSIFIER_DIR, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
    )
    from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator

    gen_map = {
        "normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
        "math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
        "3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR),
    }

    captcha_type = args.type
    num = args.num

    if captcha_type == "classifier":
        # Classifier data: generate num // num_types per type
        per_class = num // NUM_CAPTCHA_TYPES
        print(f"Generating classifier training data: {per_class} images per class")
        for cls_name in CAPTCHA_TYPES:
            gen_cls, out_dir = gen_map[cls_name]
            cls_dir = CLASSIFIER_DIR / cls_name
            cls_dir.mkdir(parents=True, exist_ok=True)
            gen = gen_cls()
            gen.generate_dataset(per_class, str(cls_dir))
    elif captcha_type in gen_map:
        gen_cls, out_dir = gen_map[captcha_type]
        print(f"Generating {captcha_type} data: {num} images → {out_dir}")
        gen = gen_cls()
        gen.generate_dataset(num, str(out_dir))
    else:
        print(f"Unknown type: {captcha_type}. Choices: normal, math, 3d, classifier")
        sys.exit(1)


def cmd_train(args):
    """Train models."""
    if args.all:
        # Dependency order: normal → math → 3d → classifier
        print("Training all models in order: normal → math → 3d → classifier\n")
        from training.train_normal import main as train_normal
        from training.train_math import main as train_math
        from training.train_3d import main as train_3d
        from training.train_classifier import main as train_classifier

        train_normal()
        print("\n")
        train_math()
        print("\n")
        train_3d()
        print("\n")
        train_classifier()
        return

    model = args.model
    if model == "normal":
        from training.train_normal import main as train_fn
    elif model == "math":
        from training.train_math import main as train_fn
    elif model == "3d":
        from training.train_3d import main as train_fn
    elif model == "classifier":
        from training.train_classifier import main as train_fn
    else:
        print(f"Unknown model: {model}. Choices: normal, math, 3d, classifier")
        sys.exit(1)

    train_fn()


def cmd_export(args):
    """Export ONNX models."""
    from inference.export_onnx import export_all, _load_and_export

    if args.all:
        export_all()
    elif args.model:
        _load_and_export(args.model)
    else:
        print("Specify --all or --model <name>")
        sys.exit(1)


def cmd_predict(args):
    """Run inference on a single image."""
    from inference.pipeline import CaptchaPipeline

    image_path = args.image
    if not Path(image_path).exists():
        print(f"File not found: {image_path}")
        sys.exit(1)

    pipeline = CaptchaPipeline()
    result = pipeline.solve(image_path, captcha_type=args.type)

    print(f"File:   {image_path}")
    print(f"Type:   {result['type']}")
    print(f"Raw:    {result['raw']}")
    print(f"Result: {result['result']}")
    print(f"Time:   {result['time_ms']:.1f} ms")


def cmd_predict_dir(args):
    """Run batch inference on a directory."""
    from inference.pipeline import CaptchaPipeline

    dir_path = Path(args.directory)
    if not dir_path.is_dir():
        print(f"Directory not found: {dir_path}")
        sys.exit(1)

    pipeline = CaptchaPipeline()
    images = sorted(dir_path.glob("*.png")) + sorted(dir_path.glob("*.jpg"))
    if not images:
        print(f"No images found in directory: {dir_path}")
        sys.exit(1)

    print(f"Batch recognition: {len(images)} images\n")
    print(f"{'Filename':<30} {'Type':<8} {'Result':<15} {'Time(ms)':>8}")
    print("-" * 65)

    total_ms = 0.0
    for img_path in images:
        result = pipeline.solve(str(img_path), captcha_type=args.type)
        total_ms += result["time_ms"]
        print(
            f"{img_path.name:<30} {result['type']:<8} "
            f"{result['result']:<15} {result['time_ms']:>8.1f}"
        )

    print("-" * 65)
    print(f"Total: {len(images)} images  Avg: {total_ms / len(images):.1f} ms  Total time: {total_ms:.1f} ms")


def cmd_serve(args):
    """Start the HTTP service."""
    try:
        from server import create_app
    except ImportError:
        # server.py not implemented yet, or dependencies missing
        print("The HTTP service requires FastAPI and uvicorn.")
        print("Install with: uv sync --extra server")
        print("Also make sure server.py is implemented.")
        sys.exit(1)

    import uvicorn
    app = create_app()
    uvicorn.run(app, host=args.host, port=args.port)


def main():
    parser = argparse.ArgumentParser(
        prog="captcha-breaker",
        description="Multi-model captcha recognition system - dispatcher + expert models",
    )
    subparsers = parser.add_subparsers(dest="command", help="subcommands")

    # ---- generate ----
    p_gen = subparsers.add_parser("generate", help="generate training data")
    p_gen.add_argument("--type", required=True, help="captcha type: normal, math, 3d, classifier")
    p_gen.add_argument("--num", type=int, required=True, help="number of samples")

    # ---- train ----
    p_train = subparsers.add_parser("train", help="train models")
    p_train.add_argument("--model", help="model name: normal, math, 3d, classifier")
    p_train.add_argument("--all", action="store_true", help="train all models in dependency order")

    # ---- export ----
    p_export = subparsers.add_parser("export", help="export ONNX models")
    p_export.add_argument("--model", help="model name: normal, math, 3d, classifier, threed")
    p_export.add_argument("--all", action="store_true", help="export all models")

    # ---- predict ----
    p_pred = subparsers.add_parser("predict", help="recognize a single captcha")
    p_pred.add_argument("image", help="image path")
    p_pred.add_argument("--type", default=None, help="set the type to skip classification: normal, math, 3d")

    # ---- predict-dir ----
    p_pdir = subparsers.add_parser("predict-dir", help="batch-recognize captchas in a directory")
    p_pdir.add_argument("directory", help="image directory path")
    p_pdir.add_argument("--type", default=None, help="set the type to skip classification: normal, math, 3d")

    # ---- serve ----
    p_serve = subparsers.add_parser("serve", help="start the HTTP recognition service")
    p_serve.add_argument("--host", default="0.0.0.0", help="listen address (default 0.0.0.0)")
    p_serve.add_argument("--port", type=int, default=8080, help="listen port (default 8080)")

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(0)

    cmd_map = {
        "generate": cmd_generate,
        "train": cmd_train,
        "export": cmd_export,
        "predict": cmd_predict,
        "predict-dir": cmd_predict_dir,
        "serve": cmd_serve,
    }

    cmd_map[args.command](args)


if __name__ == "__main__":
    main()
config.py (new file, 195 lines)
@@ -0,0 +1,195 @@
"""
Global config - Multi-Model Captcha Recognition System (CaptchaBreaker)

Defines all global constants: charsets, image sizes, paths, training hyperparameters, etc.
"""

from pathlib import Path

# ============================================================
# Project root
# ============================================================
PROJECT_ROOT = Path(__file__).resolve().parent

# ============================================================
# Data directories
# ============================================================
DATA_DIR = PROJECT_ROOT / "data"
SYNTHETIC_DIR = DATA_DIR / "synthetic"
REAL_DIR = DATA_DIR / "real"
CLASSIFIER_DIR = DATA_DIR / "classifier"

# Synthetic data subdirectories
SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal"
SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math"
SYNTHETIC_3D_DIR = SYNTHETIC_DIR / "3d"

# Real data subdirectories
REAL_NORMAL_DIR = REAL_DIR / "normal"
REAL_MATH_DIR = REAL_DIR / "math"
REAL_3D_DIR = REAL_DIR / "3d"

# ============================================================
# Model output directories
# ============================================================
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
ONNX_DIR = PROJECT_ROOT / "onnx_models"

# Make sure the key directories exist
for _dir in [
    SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
    REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_DIR,
    CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
]:
    _dir.mkdir(parents=True, exist_ok=True)

# ============================================================
# Charsets
# ============================================================
# Plain character captchas: keep easily confused characters per the current
# local config; covers all digits + uppercase letters
NORMAL_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Arithmetic captchas: digits + operators
MATH_CHARS = "0123456789+-×÷=?"

# 3D captchas: keep using the reduced charset with confusable characters removed
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ"

# Captcha type list (dispatcher classifier outputs)
CAPTCHA_TYPES = ["normal", "math", "3d"]
NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES)

# ============================================================
# Image size config (H, W)
# ============================================================
IMAGE_SIZE = {
    "classifier": (64, 128),  # dispatcher classifier input
    "normal": (40, 120),      # plain character recognition
    "math": (40, 160),        # arithmetic recognition (wider)
    "3d": (60, 160),          # 3D recognition
}

# ============================================================
# Captcha generation parameters
# ============================================================
GENERATE_CONFIG = {
    "normal": {
        "char_count_range": (4, 5),     # 4-5 characters
        "bg_color_range": (230, 255),   # light background, per RGB channel
        "rotation_range": (-15, 15),    # character rotation angle
        "noise_line_range": (2, 5),     # number of noise lines
        "noise_point_num": 100,         # number of noise dots
        "blur_radius": 0.8,             # Gaussian blur radius
        "image_size": (120, 40),        # generated image size (W, H)
    },
    "math": {
        "operand_range": (1, 30),       # operand range
        "operators": ["+", "-", "×"],   # supported operators (division only when exact)
        "image_size": (160, 40),        # generated image size (W, H)
        "bg_color_range": (230, 255),
        "rotation_range": (-10, 10),
        "noise_line_range": (2, 4),
    },
    "3d": {
        "char_count_range": (4, 5),
        "image_size": (160, 60),        # generated image size (W, H)
        "shadow_offset": (3, 3),        # shadow offset
        "perspective_intensity": 0.3,   # perspective transform strength
    },
}

# ============================================================
# Training config
# ============================================================
TRAIN_CONFIG = {
    "classifier": {
        "epochs": 30,
        "batch_size": 128,
        "lr": 1e-3,
        "scheduler": "cosine",
        "synthetic_samples": 30000,  # 10000 per class
        "val_split": 0.1,            # validation split ratio
    },
    "normal": {
        "epochs": 50,
        "batch_size": 128,
        "lr": 1e-3,
        "scheduler": "cosine",
        "synthetic_samples": 60000,
        "loss": "CTCLoss",
        "val_split": 0.1,
    },
    "math": {
        "epochs": 50,
        "batch_size": 128,
        "lr": 1e-3,
        "scheduler": "cosine",
        "synthetic_samples": 60000,
        "loss": "CTCLoss",
        "val_split": 0.1,
    },
    "threed": {
        "epochs": 80,
        "batch_size": 64,
        "lr": 5e-4,
        "scheduler": "cosine",
        "synthetic_samples": 80000,
        "loss": "CTCLoss",
        "val_split": 0.1,
    },
}

# ============================================================
# Data augmentation parameters (used at training time)
# ============================================================
AUGMENT_CONFIG = {
    "degrees": 8,                   # RandomAffine rotation range
    "translate": (0.05, 0.05),      # translation range
    "scale": (0.95, 1.05),          # scaling range
    "brightness": 0.3,              # ColorJitter brightness
    "contrast": 0.3,                # ColorJitter contrast
    "blur_kernel": 3,               # GaussianBlur kernel size
    "blur_sigma": (0.1, 0.5),       # GaussianBlur sigma
    "erasing_prob": 0.15,           # RandomErasing probability
    "erasing_scale": (0.01, 0.05),  # RandomErasing area fraction
}

# ============================================================
# ONNX export config
# ============================================================
ONNX_CONFIG = {
    "opset_version": 18,
    "dynamic_batch": True,  # support dynamic batch size
}

# ============================================================
# Inference config
# ============================================================
INFERENCE_CONFIG = {
    "default_models_dir": str(ONNX_DIR),
    "normalize_mean": 0.5,
    "normalize_std": 0.5,
}

# ============================================================
# Random seed (for reproducible data generation)
# ============================================================
RANDOM_SEED = 42

# ============================================================
# Device config (prefer GPU, fall back to CPU)
# Import torch lazily so generator-only usage does not require torch
# ============================================================
def get_device():
    """Return an available torch device, preferring GPU."""
    import torch
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================
# Server config (optional HTTP service)
# ============================================================
SERVER_CONFIG = {
    "host": "0.0.0.0",
    "port": 8080,
}
data/real/3d/.gitkeep (new file, 0 lines)
data/real/math/.gitkeep (new file, 0 lines)
data/real/normal/.gitkeep (new file, 0 lines)
generators/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""
Data generators package

Provides data generators for the three captcha types:
- NormalCaptchaGenerator: plain character captchas
- MathCaptchaGenerator: arithmetic captchas
- ThreeDCaptchaGenerator: 3D captchas
"""

from generators.base import BaseCaptchaGenerator
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator

__all__ = [
    "BaseCaptchaGenerator",
    "NormalCaptchaGenerator",
    "MathCaptchaGenerator",
    "ThreeDCaptchaGenerator",
]
generators/base.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
Captcha generator base class

All captcha generators inherit from this base class and implement generate().
The base class provides a generic generate_dataset() for batch generation.
"""

import random
from pathlib import Path
from PIL import Image
from tqdm import tqdm

from config import RANDOM_SEED


class BaseCaptchaGenerator:
    """Base class for captcha generators."""

    def __init__(self, seed: int = RANDOM_SEED):
        """
        Initialize the generator.

        Args:
            seed: random seed for reproducible data generation.
        """
        self.seed = seed
        self.rng = random.Random(seed)

    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        """
        Generate one captcha image.

        Args:
            text: label text to render. Randomly generated when None.

        Returns:
            (image, label text)
        """
        raise NotImplementedError

    def generate_dataset(self, num_samples: int, output_dir: str) -> None:
        """
        Generate a captcha dataset in batch.

        Filename format: {label}_{index:06d}.png

        Args:
            num_samples: number of samples to generate.
            output_dir: output directory path.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Reset the RNG so every batch run produces identical output
        self.rng = random.Random(self.seed)

        for i in tqdm(range(num_samples), desc=f"Generating → {output_path.name}"):
            img, label = self.generate()
            filename = f"{label}_{i:06d}.png"
            img.save(output_path / filename)
generators/math_gen.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""
Math-expression captcha generator.

Generates images of the form "A op B = ?":
- A, B: integers in the range 1-30
- op: +, -, × (division only produces evenly divisible pairs)
- Results are guaranteed to be non-negative integers
- Label format: "3+8" (the expression itself is stored, not the result)
- Visual style: light background, dark glyphs, noise lines
"""

import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont

from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator

# Fonts
_FONT_PATHS = [
    "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
    "/usr/share/fonts/TTF/DejaVuSansMono-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationMono-Bold.ttf",
    "/usr/share/fonts/gnu-free/FreeSansBold.otf",
]

# Dark color palette
_DARK_COLORS = [
    (0, 0, 180),
    (180, 0, 0),
    (0, 130, 0),
    (130, 0, 130),
    (120, 60, 0),
    (0, 0, 0),
    (50, 50, 150),
]

# Operator display mapping (used for rendering)
_OP_DISPLAY = {
    "+": "+",
    "-": "-",
    "×": "×",
    "÷": "÷",
}


class MathCaptchaGenerator(BaseCaptchaGenerator):
    """Math-expression captcha generator."""

    def __init__(self, seed: int | None = None):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

        self.cfg = GENERATE_CONFIG["math"]
        self.width, self.height = self.cfg["image_size"]
        self.operators = self.cfg["operators"]
        self.op_lo, self.op_hi = self.cfg["operand_range"]

        # Preload available fonts
        self._fonts: list[str] = []
        for p in _FONT_PATHS:
            try:
                ImageFont.truetype(p, 20)
                self._fonts.append(p)
            except OSError:
                continue
        if not self._fonts:
            raise RuntimeError("No usable fonts found; cannot generate captchas")

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        rng = self.rng

        # 1. Build the expression
        if text is None:
            a, op, b = self._random_expression(rng)
            text = f"{a}{op}{b}"
        else:
            a, op, b = self._parse_expression(text)

        # Display text: "3+8=?"
        display = f"{a}{_OP_DISPLAY.get(op, op)}{b}=?"

        # 2. Light background
        bg_lo, bg_hi = self.cfg["bg_color_range"]
        bg = tuple(rng.randint(bg_lo, bg_hi) for _ in range(3))
        img = Image.new("RGB", (self.width, self.height), bg)

        # 3. Draw the expression text
        self._draw_expression(img, display, rng)

        # 4. Noise lines
        self._draw_noise_lines(img, rng)

        # 5. Slight blur
        img = img.filter(ImageFilter.GaussianBlur(radius=0.6))

        return img, text

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    def _random_expression(self, rng: random.Random) -> tuple[int, str, int]:
        """Generate a valid random expression (a, op, b) with a non-negative integer result."""
        while True:
            op = rng.choice(self.operators)
            a = rng.randint(self.op_lo, self.op_hi)
            b = rng.randint(self.op_lo, self.op_hi)

            if op == "+":
                return a, op, b
            elif op == "-":
                # Ensure a >= b so the result is non-negative
                if a < b:
                    a, b = b, a
                return a, op, b
            elif op == "×":
                # Keep the product reasonably small
                if a * b <= 900:
                    return a, op, b
            elif op == "÷":
                # Only emit evenly divisible pairs
                if b != 0 and a % b == 0:
                    return a, op, b

    @staticmethod
    def _parse_expression(text: str) -> tuple[int, str, int]:
        """Parse a label such as '3+8' into (3, '+', 8)."""
        for op in ["×", "÷", "+", "-"]:
            if op in text:
                parts = text.split(op, 1)
                return int(parts[0]), op, int(parts[1])
        raise ValueError(f"Cannot parse expression: {text}")

    def _draw_expression(self, img: Image.Image, display: str, rng: random.Random) -> None:
        """Render the expression; each character is drawn separately with a slight rotation."""
        n = len(display)
        slot_w = self.width // n
        font_size = int(min(slot_w * 0.85, self.height * 0.65))
        font_size = max(font_size, 14)

        for i, ch in enumerate(display):
            font_path = rng.choice(self._fonts)

            # Some fonts may not cover glyphs such as ×; fall back to DejaVu
            try:
                font = ImageFont.truetype(font_path, font_size)
                bbox = font.getbbox(ch)
                if bbox[2] - bbox[0] <= 0:
                    raise ValueError
            except (OSError, ValueError):
                font = ImageFont.truetype(self._fonts[0], font_size)
                bbox = font.getbbox(ch)

            color = rng.choice(_DARK_COLORS)

            cw = bbox[2] - bbox[0] + 4
            ch_h = bbox[3] - bbox[1] + 4
            char_img = Image.new("RGBA", (cw, ch_h), (0, 0, 0, 0))
            ImageDraw.Draw(char_img).text((-bbox[0] + 2, -bbox[1] + 2), ch, fill=color, font=font)

            # Slight rotation
            angle = rng.randint(*self.cfg["rotation_range"])
            char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)

            x = slot_w * i + (slot_w - char_img.width) // 2
            y = (self.height - char_img.height) // 2 + rng.randint(-2, 2)
            x = max(0, min(x, self.width - char_img.width))
            y = max(0, min(y, self.height - char_img.height))

            img.paste(char_img, (x, y), char_img)

    def _draw_noise_lines(self, img: Image.Image, rng: random.Random) -> None:
        """Draw light-colored noise lines."""
        draw = ImageDraw.Draw(img)
        lo, hi = self.cfg["noise_line_range"]
        num = rng.randint(lo, hi)
        for _ in range(num):
            x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
            x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)
            color = tuple(rng.randint(150, 220) for _ in range(3))
            draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))
154
generators/normal_gen.py
Normal file
@@ -0,0 +1,154 @@
"""
Plain character captcha generator.

Visual style:
- Light random background (RGB channels 230-255)
- Each character in a random dark color (blue/red/green/purple/brown, etc.)
- 4-5 characters per image
- Characters randomly rotated within ±15°
- 2-5 light-colored noise lines
- A sprinkle of noise points
- Optional slight Gaussian blur
"""

import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont

from config import GENERATE_CONFIG, NORMAL_CHARS
from generators.base import BaseCaptchaGenerator

# System font candidates (bold/regular mix for variety)
_FONT_PATHS = [
    "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
    "/usr/share/fonts/TTF/DejaVuSansMono-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationMono-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
    "/usr/share/fonts/gnu-free/FreeSansBold.otf",
    "/usr/share/fonts/gnu-free/FreeMonoBold.otf",
]

# Dark color palette (R, G, B)
_DARK_COLORS = [
    (0, 0, 180),    # blue
    (180, 0, 0),    # red
    (0, 130, 0),    # green
    (130, 0, 130),  # purple
    (120, 60, 0),   # brown
    (0, 100, 100),  # teal
    (80, 80, 0),    # olive
    (0, 0, 0),      # black
    (100, 0, 50),   # dark rose
    (50, 50, 150),  # steel blue
]


class NormalCaptchaGenerator(BaseCaptchaGenerator):
    """Plain character captcha generator."""

    def __init__(self, seed: int | None = None):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

        self.cfg = GENERATE_CONFIG["normal"]
        self.chars = NORMAL_CHARS
        self.width, self.height = self.cfg["image_size"]

        # Preload available fonts
        self._fonts: list[str] = []
        for p in _FONT_PATHS:
            try:
                ImageFont.truetype(p, 20)
                self._fonts.append(p)
            except OSError:
                continue
        if not self._fonts:
            raise RuntimeError("No usable fonts found; cannot generate captchas")

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        rng = self.rng

        # 1. Random text
        if text is None:
            length = rng.randint(*self.cfg["char_count_range"])
            text = "".join(rng.choices(self.chars, k=length))

        # 2. Light background
        bg_lo, bg_hi = self.cfg["bg_color_range"]
        bg = tuple(rng.randint(bg_lo, bg_hi) for _ in range(3))
        img = Image.new("RGB", (self.width, self.height), bg)

        # 3. Draw characters one by one (rotate, then paste)
        self._draw_text(img, text, rng)

        # 4. Noise lines
        self._draw_noise_lines(img, rng)

        # 5. Noise points
        self._draw_noise_points(img, rng)

        # 6. Slight Gaussian blur
        if self.cfg["blur_radius"] > 0:
            img = img.filter(ImageFilter.GaussianBlur(radius=self.cfg["blur_radius"]))

        return img, text

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    def _draw_text(self, img: Image.Image, text: str, rng: random.Random) -> None:
        """Rotate and paste each character onto the canvas."""
        n = len(text)
        # Horizontal width available per character
        slot_w = self.width // n
        font_size = int(min(slot_w * 0.9, self.height * 0.7))
        font_size = max(font_size, 12)

        for i, ch in enumerate(text):
            font_path = rng.choice(self._fonts)
            font = ImageFont.truetype(font_path, font_size)
            color = rng.choice(_DARK_COLORS)

            # Draw the single character onto a temporary transparent layer
            bbox = font.getbbox(ch)
            cw = bbox[2] - bbox[0] + 4
            ch_h = bbox[3] - bbox[1] + 4
            char_img = Image.new("RGBA", (cw, ch_h), (0, 0, 0, 0))
            ImageDraw.Draw(char_img).text((-bbox[0] + 2, -bbox[1] + 2), ch, fill=color, font=font)

            # Random rotation
            angle = rng.randint(*self.cfg["rotation_range"])
            char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)

            # Paste position
            x = slot_w * i + (slot_w - char_img.width) // 2
            y = (self.height - char_img.height) // 2 + rng.randint(-3, 3)
            x = max(0, min(x, self.width - char_img.width))
            y = max(0, min(y, self.height - char_img.height))

            img.paste(char_img, (x, y), char_img)

    def _draw_noise_lines(self, img: Image.Image, rng: random.Random) -> None:
        """Draw light-colored noise lines."""
        draw = ImageDraw.Draw(img)
        lo, hi = self.cfg["noise_line_range"]
        num = rng.randint(lo, hi)
        for _ in range(num):
            x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
            x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)
            color = tuple(rng.randint(150, 220) for _ in range(3))
            draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))

    def _draw_noise_points(self, img: Image.Image, rng: random.Random) -> None:
        """Draw noise points."""
        draw = ImageDraw.Draw(img)
        for _ in range(self.cfg["noise_point_num"]):
            x = rng.randint(0, self.width - 1)
            y = rng.randint(0, self.height - 1)
            color = tuple(rng.randint(0, 200) for _ in range(3))
            draw.point((x, y), fill=color)
211
generators/threed_gen.py
Normal file
@@ -0,0 +1,211 @@
"""
3D captcha generator.

Generates captchas with a 3D perspective/shadow look:
- Affine transforms simulate 3D perspective
- Shadow effect (an offset, darkened copy of each glyph)
- Characters look tilted, with a sense of depth
- Gradient background reinforces the 3D effect
- Label: the raw character content
"""

import math
import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont

from config import GENERATE_CONFIG, THREED_CHARS
from generators.base import BaseCaptchaGenerator

# Fonts (bold faces render the 3D effect better)
_FONT_PATHS = [
    "/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
    "/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
    "/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
    "/usr/share/fonts/gnu-free/FreeSansBold.otf",
]

# Foreground colors — vivid, high contrast
_FRONT_COLORS = [
    (220, 50, 50),   # red
    (50, 100, 220),  # blue
    (30, 160, 30),   # green
    (200, 150, 0),   # gold
    (180, 50, 180),  # purple
    (0, 160, 160),   # cyan
    (220, 100, 0),   # orange
]


class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
    """3D captcha generator."""

    def __init__(self, seed: int | None = None):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)

        self.cfg = GENERATE_CONFIG["3d"]
        self.chars = THREED_CHARS
        self.width, self.height = self.cfg["image_size"]

        # Preload available fonts
        self._fonts: list[str] = []
        for p in _FONT_PATHS:
            try:
                ImageFont.truetype(p, 20)
                self._fonts.append(p)
            except OSError:
                continue
        if not self._fonts:
            raise RuntimeError("No usable fonts found; cannot generate captchas")

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        rng = self.rng

        # 1. Random text
        if text is None:
            length = rng.randint(*self.cfg["char_count_range"])
            text = "".join(rng.choices(self.chars, k=length))

        # 2. Gradient background (reinforces the sense of depth)
        img = self._gradient_background(rng)

        # 3. Draw characters one by one (shadow + perspective + foreground)
        self._draw_3d_text(img, text, rng)

        # 4. Noise lines (thicker, with depth)
        self._draw_depth_lines(img, rng)

        # 5. Slight Gaussian blur
        img = img.filter(ImageFilter.GaussianBlur(radius=0.7))

        return img, text

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    def _gradient_background(self, rng: random.Random) -> Image.Image:
        """Build a light top-to-bottom gradient background."""
        img = Image.new("RGB", (self.width, self.height))
        draw = ImageDraw.Draw(img)

        # Two random light colors
        c1 = tuple(rng.randint(200, 240) for _ in range(3))
        c2 = tuple(rng.randint(180, 220) for _ in range(3))

        for y in range(self.height):
            ratio = y / max(self.height - 1, 1)
            r = int(c1[0] + (c2[0] - c1[0]) * ratio)
            g = int(c1[1] + (c2[1] - c1[1]) * ratio)
            b = int(c1[2] + (c2[2] - c1[2]) * ratio)
            draw.line([(0, y), (self.width, y)], fill=(r, g, b))

        return img

    def _draw_3d_text(self, img: Image.Image, text: str, rng: random.Random) -> None:
        """Draw each character in 3D: shadow layer + perspective transform + foreground layer."""
        n = len(text)
        slot_w = self.width // n
        font_size = int(min(slot_w * 0.8, self.height * 0.65))
        font_size = max(font_size, 16)

        shadow_dx, shadow_dy = self.cfg["shadow_offset"]

        for i, ch in enumerate(text):
            font_path = rng.choice(self._fonts)
            font = ImageFont.truetype(font_path, font_size)
            front_color = rng.choice(_FRONT_COLORS)
            # Shadow color: a darkened version of the foreground color
            shadow_color = tuple(max(0, c - 80) for c in front_color)

            # Render the single character
            bbox = font.getbbox(ch)
            cw = bbox[2] - bbox[0] + 8
            ch_h = bbox[3] - bbox[1] + 8
            pad = max(shadow_dx, shadow_dy) + 4  # extra room for the shadow

            canvas_w = cw + pad * 2
            canvas_h = ch_h + pad * 2

            # --- Shadow layer ---
            shadow_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
            ImageDraw.Draw(shadow_img).text(
                (-bbox[0] + pad + shadow_dx, -bbox[1] + pad + shadow_dy),
                ch, fill=shadow_color + (180,), font=font
            )

            # --- Foreground layer ---
            front_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
            ImageDraw.Draw(front_img).text(
                (-bbox[0] + pad, -bbox[1] + pad),
                ch, fill=front_color + (255,), font=font
            )

            # Composite: shadow first, then foreground
            char_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
            char_img = Image.alpha_composite(char_img, shadow_img)
            char_img = Image.alpha_composite(char_img, front_img)

            # Perspective (affine) transform
            char_img = self._perspective_transform(char_img, rng)

            # Random rotation
            angle = rng.randint(-20, 20)
            char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)

            # Paste onto the canvas
            x = slot_w * i + (slot_w - char_img.width) // 2
            y = (self.height - char_img.height) // 2 + rng.randint(-4, 4)
            x = max(0, min(x, self.width - char_img.width))
            y = max(0, min(y, self.height - char_img.height))

            img.paste(char_img, (x, y), char_img)

    def _perspective_transform(self, img: Image.Image, rng: random.Random) -> Image.Image:
        """Apply an affine transform to a single character to fake 3D perspective."""
        w, h = img.size
        intensity = self.cfg["perspective_intensity"]

        # Random shear / scale parameters
        shear_x = rng.uniform(-intensity, intensity)
        shear_y = rng.uniform(-intensity * 0.5, intensity * 0.5)
        scale_x = rng.uniform(1.0 - intensity * 0.3, 1.0 + intensity * 0.3)
        scale_y = rng.uniform(1.0 - intensity * 0.3, 1.0 + intensity * 0.3)

        # Affine coefficients (a, b, c, d, e, f) -> (x', y') = (a*x+b*y+c, d*x+e*y+f)
        # Pillow's transform expects the inverse mapping coefficients
        a = scale_x
        b = shear_x
        d = shear_y
        e = scale_y
        # Offsets chosen so the image center stays fixed
        c = (1 - a) * w / 2 - b * h / 2
        f = -d * w / 2 + (1 - e) * h / 2

        return img.transform(
            (w, h), Image.AFFINE,
            (a, b, c, d, e, f),
            resample=Image.BICUBIC
        )

    def _draw_depth_lines(self, img: Image.Image, rng: random.Random) -> None:
        """Draw noise lines with a sense of depth (thicker, shadowed)."""
        draw = ImageDraw.Draw(img)
        num = rng.randint(2, 4)
        for _ in range(num):
            x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
            x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)

            # Shadow line
            shadow_color = tuple(rng.randint(80, 130) for _ in range(3))
            dx, dy = self.cfg["shadow_offset"]
            draw.line([(x1 + dx, y1 + dy), (x2 + dx, y2 + dy)],
                      fill=shadow_color, width=rng.randint(2, 3))

            # Foreground line
            color = tuple(rng.randint(120, 200) for _ in range(3))
            draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))
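The offsets `c` and `f` in `_perspective_transform` are what keep the glyph from drifting: they are derived so that the image center maps to itself under the affine map (x', y') = (a·x + b·y + c, d·x + e·y + f). A quick numeric check with arbitrary shear/scale values (the values below are illustrative, not from the config):

```python
# Verify the center-preserving property of the affine offsets used in
# _perspective_transform.
w, h = 60, 40
a, b, d, e = 1.1, 0.2, -0.1, 0.95  # arbitrary scale/shear coefficients

# Same offset formulas as the generator
c = (1 - a) * w / 2 - b * h / 2
f = -d * w / 2 + (1 - e) * h / 2

cx, cy = w / 2, h / 2
assert abs(a * cx + b * cy + c - cx) < 1e-9  # x-center is fixed
assert abs(d * cx + e * cy + f - cy) < 1e-9  # y-center is fixed
```

This holds for any `a`, `b`, `d`, `e`, which is why the random shear and scale never push the character off its slot.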
18
inference/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Inference package.

- pipeline.py: CaptchaPipeline, the core inference pipeline
- export_onnx.py: PyTorch → ONNX export
- math_eval.py: math-expression evaluation
"""

from inference.pipeline import CaptchaPipeline
from inference.math_eval import eval_captcha_math
from inference.export_onnx import export_model, export_all

__all__ = [
    "CaptchaPipeline",
    "eval_captcha_math",
    "export_model",
    "export_all",
]
121
inference/export_onnx.py
Normal file
@@ -0,0 +1,121 @@
"""
ONNX export script.

Loads trained PyTorch models from checkpoints/ and exports them as ONNX
into onnx_models/. Models can be exported individually or all at once.
"""

from pathlib import Path

import torch
import torch.nn as nn

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    IMAGE_SIZE,
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
    NUM_CAPTCHA_TYPES,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN


def export_model(
    model: nn.Module,
    model_name: str,
    input_shape: tuple,
    onnx_dir: str | None = None,
):
    """
    Export a single model to ONNX.

    Args:
        model: PyTorch model with weights already loaded
        model_name: model name (classifier / normal / math / threed)
        input_shape: input shape (C, H, W)
        onnx_dir: output directory (defaults to config.ONNX_DIR)
    """
    out_dir = Path(onnx_dir) if onnx_dir else ONNX_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    onnx_path = out_dir / f"{model_name}.onnx"

    model.eval()
    model.cpu()

    dummy = torch.randn(1, *input_shape)

    # The classifier and the recognizers need different dynamic_axes
    if model_name == "classifier":
        dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
    else:
        # CTC models: output shape = (T, B, C)
        dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}

    torch.onnx.export(
        model,
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
    )

    size_kb = onnx_path.stat().st_size / 1024
    print(f"[ONNX] Export finished: {onnx_path} ({size_kb:.1f} KB)")


def _load_and_export(model_name: str):
    """Load a model from its checkpoint and export it to ONNX."""
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    if not ckpt_path.exists():
        print(f"[skip] {model_name}: checkpoint not found ({ckpt_path})")
        return

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    print(f"[load] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")

    if model_name == "classifier":
        model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
        h, w = IMAGE_SIZE["classifier"]
        input_shape = (1, h, w)
    elif model_name == "normal":
        chars = ckpt.get("chars", NORMAL_CHARS)
        h, w = IMAGE_SIZE["normal"]
        model = LiteCRNN(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
    elif model_name == "math":
        chars = ckpt.get("chars", MATH_CHARS)
        h, w = IMAGE_SIZE["math"]
        model = LiteCRNN(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
    elif model_name == "threed":
        chars = ckpt.get("chars", THREED_CHARS)
        h, w = IMAGE_SIZE["3d"]
        model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
        input_shape = (1, h, w)
    else:
        print(f"[error] Unknown model: {model_name}")
        return

    model.load_state_dict(ckpt["model_state_dict"])
    export_model(model, model_name, input_shape)


def export_all():
    """Export classifier, normal, math, and threed in order."""
    print("=" * 50)
    print("Exporting all ONNX models")
    print("=" * 50)
    for name in ["classifier", "normal", "math", "threed"]:
        _load_and_export(name)
    print("\nAll exports finished.")


if __name__ == "__main__":
    export_all()
66
inference/math_eval.py
Normal file
@@ -0,0 +1,66 @@
"""
Math-expression evaluation.

Parses and evaluates the arithmetic expression recognized from a captcha.
Digits and operators are extracted with a regex; eval() is never used.

Supports: addition, subtraction, multiplication, and division with
one- and two-digit operands.
"""

import re


# Matches: number operator number (possibly followed by "=?" etc.)
_EXPR_PATTERN = re.compile(
    r"(\d+)\s*([+\-×÷xX*])\s*(\d+)"
)

# Operator normalization map
_OP_MAP = {
    "+": "+",
    "-": "-",
    "×": "×",
    "÷": "÷",
    "x": "×",
    "X": "×",
    "*": "×",
}


def eval_captcha_math(expr: str) -> str:
    """
    Parse and evaluate a captcha math expression.

    Supports addition, subtraction, multiplication, and division with
    one- and two-digit operands.
    Input: "3+8=?", "12×3=?", "15-7=?", or "3+8"
    Output: "11", "36", or "8"

    Digits and operators are extracted with a regex; eval() is never used.

    Raises:
        ValueError: the expression cannot be parsed.
    """
    match = _EXPR_PATTERN.search(expr)
    if not match:
        raise ValueError(f"Cannot parse expression: {expr!r}")

    a = int(match.group(1))
    op_raw = match.group(2)
    b = int(match.group(3))

    op = _OP_MAP.get(op_raw, op_raw)

    if op == "+":
        result = a + b
    elif op == "-":
        result = a - b
    elif op == "×":
        result = a * b
    elif op == "÷":
        if b == 0:
            raise ValueError(f"Division by zero: {expr!r}")
        result = a // b
    else:
        raise ValueError(f"Unsupported operator: {op!r} in expression: {expr!r}")

    return str(result)
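The regex-extract-then-dispatch approach above can be condensed into a standalone sketch, useful for testing the parsing behavior without the package on the path. This mirrors eval_captcha_math's logic (the helper name `eval_expr` is illustrative, not part of the project API):

```python
import re

# Standalone sketch mirroring inference/math_eval.py: regex extraction, no eval()
_PAT = re.compile(r"(\d+)\s*([+\-×÷xX*])\s*(\d+)")
_NORM = {"x": "×", "X": "×", "*": "×"}

def eval_expr(expr: str) -> str:
    m = _PAT.search(expr)
    if not m:
        raise ValueError(f"Cannot parse expression: {expr!r}")
    a, b = int(m.group(1)), int(m.group(3))
    op = _NORM.get(m.group(2), m.group(2))
    if op == "+":
        return str(a + b)
    if op == "-":
        return str(a - b)
    if op == "×":
        return str(a * b)
    if op == "÷":
        if b == 0:
            raise ValueError(f"Division by zero: {expr!r}")
        return str(a // b)
    raise ValueError(f"Unsupported operator: {op!r}")

assert eval_expr("3+8=?") == "11"
assert eval_expr("12x3") == "36"   # OCR may emit 'x' for '×'
assert eval_expr("15-7=?") == "8"
```

Normalizing `x`/`X`/`*` to `×` before dispatch is what makes the evaluator tolerant of the OCR confusing the multiplication glyph.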
231
inference/pipeline.py
Normal file
@@ -0,0 +1,231 @@
"""
Core inference pipeline.

Loads the routing classifier and all expert models (ONNX format) and exposes
a unified solve(image) interface.

Flow:
    input image → preprocess → routing classifier → expert model → CTC decode → post-process → output

For the math type, the decoded expression is additionally evaluated via math_eval.
"""

import io
import time
from pathlib import Path

import numpy as np
from PIL import Image

from config import (
    CAPTCHA_TYPES,
    IMAGE_SIZE,
    INFERENCE_CONFIG,
    NORMAL_CHARS,
    MATH_CHARS,
    THREED_CHARS,
)
from inference.math_eval import eval_captcha_math


def _try_import_ort():
    """Import onnxruntime lazily, with a friendly error message."""
    try:
        import onnxruntime as ort
        return ort
    except ImportError:
        raise ImportError(
            "Inference requires onnxruntime; install it with: uv pip install onnxruntime"
        )


class CaptchaPipeline:
    """
    Core inference pipeline.

    Loads the routing classifier and all expert models (ONNX format)
    and exposes a unified solve(image) interface.
    """

    def __init__(self, models_dir: str | None = None):
        """
        Load all ONNX models on initialization.

        Args:
            models_dir: ONNX model directory; defaults to the path in config
        """
        ort = _try_import_ort()

        self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
        self.mean = INFERENCE_CONFIG["normalize_mean"]
        self.std = INFERENCE_CONFIG["normalize_std"]

        # Charset mapping
        self._chars = {
            "normal": NORMAL_CHARS,
            "math": MATH_CHARS,
            "3d": THREED_CHARS,
        }

        # Expert model name → ONNX file name
        self._model_files = {
            "classifier": "classifier.onnx",
            "normal": "normal.onnx",
            "math": "math.onnx",
            "3d": "threed.onnx",
        }

        # Load every model that is available
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 2

        self._sessions: dict[str, "ort.InferenceSession"] = {}
        for name, fname in self._model_files.items():
            path = self.models_dir / fname
            if path.exists():
                self._sessions[name] = ort.InferenceSession(
                    str(path), sess_options=opts,
                    providers=["CPUExecutionProvider"],
                )

        loaded = list(self._sessions.keys())
        if not loaded:
            raise FileNotFoundError(
                f"No ONNX models found; train and export models into {self.models_dir} first"
            )

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
        """
        Image preprocessing: resize, grayscale, normalize, convert to numpy.

        Args:
            image: PIL Image
            target_size: (H, W)

        Returns:
            (1, 1, H, W) float32 ndarray
        """
        h, w = target_size
        img = image.convert("L").resize((w, h), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - self.mean) / self.std
        return arr.reshape(1, 1, h, w)

    def classify(self, image: Image.Image) -> str:
        """
        Route-classify the image; returns the type name: 'normal' / 'math' / '3d'.

        Raises:
            RuntimeError: the classifier model is not loaded
        """
        if "classifier" not in self._sessions:
            raise RuntimeError("Classifier model not loaded; train and export classifier.onnx first")

        inp = self.preprocess(image, IMAGE_SIZE["classifier"])
        session = self._sessions["classifier"]
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inp})[0]  # (1, num_types)
        idx = int(np.argmax(logits, axis=1)[0])
        return CAPTCHA_TYPES[idx]

    def solve(
        self,
        image,
        captcha_type: str | None = None,
    ) -> dict:
        """
        Full recognition flow.

        Args:
            image: PIL.Image, a file path (str/Path), or bytes
            captcha_type: pass a type ('normal'/'math'/'3d') to skip classification

        Returns:
            dict: {
                "type": str,      # captcha type
                "raw": str,       # raw OCR output
                "result": str,    # final answer (evaluated result for math captchas)
                "time_ms": float, # inference time in milliseconds
            }
        """
        t0 = time.perf_counter()

        # 1. Parse the input
        img = self._load_image(image)

        # 2. Classify
        if captcha_type is None:
            captcha_type = self.classify(img)

        # 3. Route to the expert model
        if captcha_type not in self._sessions:
            raise RuntimeError(
                f"Expert model '{captcha_type}' not loaded; train and export its ONNX model first"
            )

        size_key = captcha_type  # "normal"/"math"/"3d"
        inp = self.preprocess(img, IMAGE_SIZE[size_key])
        session = self._sessions[captcha_type]
        input_name = session.get_inputs()[0].name
        logits = session.run(None, {input_name: inp})[0]  # (T, 1, C)

        # 4. Greedy CTC decoding
        chars = self._chars[captcha_type]
        raw_text = self._ctc_greedy_decode(logits, chars)

        # 5. Post-processing
        if captcha_type == "math":
            try:
                result = eval_captcha_math(raw_text)
            except ValueError:
                result = raw_text  # fall back to the raw text if parsing fails
        else:
            result = raw_text

        elapsed = (time.perf_counter() - t0) * 1000

        return {
            "type": captcha_type,
            "raw": raw_text,
            "result": result,
            "time_ms": round(elapsed, 2),
        }

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    @staticmethod
    def _load_image(image) -> Image.Image:
        """Normalize the supported input types into a PIL Image."""
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, Path)):
            return Image.open(image).convert("RGB")
        if isinstance(image, bytes):
            return Image.open(io.BytesIO(image)).convert("RGB")
        raise TypeError(f"Unsupported image input type: {type(image)}")

    @staticmethod
    def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
        """
        Greedy CTC decoding (numpy version).

        Args:
            logits: (T, B, C) ONNX output
            chars: charset (excluding blank; blank = index 0)

        Returns:
            the decoded string
        """
        # Take batch index 0
        preds = np.argmax(logits[:, 0, :], axis=1)  # (T,)
        decoded = []
        prev = -1
        for idx in preds:
            if idx != 0 and idx != prev:
                decoded.append(chars[idx - 1])
            prev = idx
        return "".join(decoded)
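The greedy CTC collapse rule used in the pipeline (take the argmax per time step, drop blanks, merge consecutive repeats) can be exercised in isolation with a toy logits tensor. A self-contained sketch with the same logic, on a hypothetical 2-character charset:

```python
import numpy as np

def ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
    # Same collapse rule as CaptchaPipeline._ctc_greedy_decode:
    # blank = index 0; skip blanks and repeated indices
    preds = np.argmax(logits[:, 0, :], axis=1)  # (T,)
    out, prev = [], -1
    for idx in preds:
        if idx != 0 and idx != prev:
            out.append(chars[idx - 1])
        prev = idx
    return "".join(out)

# Toy (T=6, B=1, C=3) one-hot logits over charset "AB":
# frames A A blank B B blank collapse to "AB"
frames = [1, 1, 0, 2, 2, 0]
logits = np.zeros((6, 1, 3), dtype=np.float32)
for t, k in enumerate(frames):
    logits[t, 0, k] = 1.0

assert ctc_greedy_decode(logits, "AB") == "AB"
```

The blank between the two runs is what allows genuine doubled characters: "A blank A" decodes to "AA", while "A A" collapses to a single "A".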
16
main.py
Normal file
@@ -0,0 +1,16 @@
# This is a sample Python script.

# Press Shift+F10 to run it, or replace it with your own code.
# Double-press Shift to search everywhere for classes, files, tool windows, actions, and settings.


def print_hi(name):
    # Use a breakpoint in the line below to debug the script.
    print(f'Hi, {name}')  # Press Ctrl+8 to toggle the breakpoint.


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print_hi('PyCharm')

# Visit https://www.jetbrains.com/help/pycharm/ for PyCharm help
18
models/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Model definitions.

Three models are provided:
- CaptchaClassifier: routing classifier (light CNN, < 500KB)
- LiteCRNN: lightweight CRNN (plain characters + math expressions, < 2MB)
- ThreeDCNN: dedicated 3D captcha model (ResNet-lite + BiLSTM, < 5MB)
"""

from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN

__all__ = [
    "CaptchaClassifier",
    "LiteCRNN",
    "ThreeDCNN",
]
72
models/classifier.py
Normal file
@@ -0,0 +1,72 @@
"""
Routing classifier model.

Lightweight CNN classifier that determines the captcha type (normal / math / 3d).
The captcha types differ strongly in appearance, so the classification task is simple.

Architecture: 4 conv layers + GAP + FC
Input: grayscale image 1×64×128
Output: softmax probability distribution (num_types classes)
Size target: < 500KB
"""

import torch
import torch.nn as nn


class CaptchaClassifier(nn.Module):
    """
    Lightweight classifier.

    4 conv layers (each Conv + BN + ReLU + MaxPool)
    → global average pooling → fully connected → num_types outputs.
    """

    def __init__(self, num_types: int = 3):
        super().__init__()
        self.num_types = num_types

        self.features = nn.Sequential(
            # block 1: 1 -> 16, 64x128 -> 32x64
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 2: 16 -> 32, 32x64 -> 16x32
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 3: 32 -> 64, 16x32 -> 8x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # block 4: 64 -> 64, 8x16 -> 4x8
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # global average pooling → output (batch, 64, 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Linear(64, num_types)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, 64, 128) grayscale images

        Returns:
            logits: (batch, num_types) raw outputs before softmax
        """
        x = self.features(x)
        x = self.gap(x)            # (B, 64, 1, 1)
        x = x.view(x.size(0), -1)  # (B, 64)
        x = self.classifier(x)     # (B, num_types)
        return x
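A quick shape sanity check for the classifier above: each `MaxPool2d(2, 2)` halves both spatial dimensions, so the documented 64×128 → 4×8 trace can be reproduced with plain integer arithmetic (a torch-free sketch):

```python
def pool_hw(h: int, w: int, n_pools: int) -> tuple[int, int]:
    """Halve height and width n_pools times — what MaxPool2d(2, 2) does for even sizes."""
    for _ in range(n_pools):
        h, w = h // 2, w // 2
    return h, w

# 64x128 -> 32x64 -> 16x32 -> 8x16 -> 4x8; GAP then squeezes 4x8 down to 1x1
print(pool_hw(64, 128, 4))  # -> (4, 8)
```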
141
models/lite_crnn.py
Normal file
@@ -0,0 +1,141 @@
"""
Lightweight CRNN model (Convolutional Recurrent Neural Network).

OCR model for normal character captchas and math expression captchas.
The two modes differ only in charset and input size; they share one architecture.

Architecture: CNN feature extraction → sequence mapping → BiLSTM → FC → CTC decoding
CTC output length = feature-map width (after the width-direction pools)
The CTC blank sits at index 0; characters are mapped starting from index 1.

- normal mode: input 1×40×120, 30-character charset, size < 2MB
- math mode: input 1×40×160, 16-character charset, size < 2MB
"""

import torch
import torch.nn as nn


class LiteCRNN(nn.Module):
    """
    Lightweight CRNN + CTC.

    The CNN pools the height 4 times (40→20→10→5→2, then AdaptiveAvgPool to 1)
    and the width only twice (keeping enough sequence length for CTC).
    The RNN is a single-layer BiLSTM.
    """

    def __init__(self, chars: str, img_h: int = 40, img_w: int = 120):
        """
        Args:
            chars: charset string (without the CTC blank)
            img_h: input image height
            img_w: input image width
        """
        super().__init__()
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        # number of CTC classes = number of characters + 1 (blank at index 0)
        self.num_classes = len(chars) + 1

        # ---- CNN feature extraction ----
        self.cnn = nn.Sequential(
            # block 1: 1 -> 32, H/2, W unchanged
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),  # H/2, W unchanged

            # block 2: 32 -> 64, H/2, W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # H/2, W/2

            # block 3: 64 -> 128, H/2, W unchanged
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),  # H/2, W unchanged

            # block 4: 128 -> 128, H/2, W/2
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # H/2, W/2
        )

        # after the 4 height pools: img_h / 16 (e.g. 40 → 2; adaptive pooling absorbs remainders)
        # AdaptiveAvgPool then squeezes the height down to 1
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))  # (B, 128, 1, W')

        # ---- RNN sequence modeling ----
        self.rnn_input_size = 128
        self.rnn_hidden = 96
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # ---- output layer ----
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale images

        Returns:
            logits: (seq_len, batch, num_classes),
            i.e. the (T, B, C) layout expected by CTC
        """
        # CNN
        conv = self.cnn(x)                # (B, 128, H', W')
        conv = self.adaptive_pool(conv)   # (B, 128, 1, W')
        conv = conv.squeeze(2)            # (B, 128, W')
        conv = conv.permute(0, 2, 1)      # (B, W', 128) — batch_first sequence

        # RNN
        rnn_out, _ = self.rnn(conv)       # (B, W', 192) — 2 × rnn_hidden

        # FC
        logits = self.fc(rnn_out)         # (B, W', num_classes)
        logits = logits.permute(1, 0, 2)  # (T, B, C) — CTC layout

        return logits

    @property
    def seq_len(self) -> int:
        """CTC sequence length (feature-map width) for the configured input width."""
        # the width goes through two /2 pools
        return self.img_w // 4

    # ----------------------------------------------------------
    # Greedy CTC decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """
        Greedy CTC decoding.

        Args:
            logits: (T, B, C) raw model outputs

        Returns:
            list of decoded strings, length = batch size
        """
        # (T, B, C) -> (B, T)
        preds = logits.argmax(dim=2).permute(1, 0)  # (B, T)
        results = []
        for pred in preds:
            chars = []
            prev = -1
            for idx in pred.tolist():
                if idx != 0 and idx != prev:  # 0 = blank
                    chars.append(self.chars[idx - 1])  # characters start at index 1
                prev = idx
            results.append("".join(chars))
        return results
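The collapse rule implemented by `greedy_decode` above (merge repeated indices, then drop blanks) can be sketched without torch; the index list stands in for a per-timestep argmax, and `chars` is a hypothetical 3-character charset:

```python
def ctc_greedy_collapse(indices: list[int], chars: str) -> str:
    """Collapse a per-timestep argmax sequence: merge repeats, then drop blanks (index 0)."""
    out = []
    prev = -1
    for idx in indices:
        if idx != 0 and idx != prev:
            out.append(chars[idx - 1])  # characters start at index 1
        prev = idx
    return "".join(out)

chars = "AB8"  # hypothetical 3-character charset
print(ctc_greedy_collapse([0, 1, 1, 0, 2, 2, 3], chars))  # -> AB8 (repeats merge)
print(ctc_greedy_collapse([1, 0, 1], chars))              # -> AA (a blank separates a true double letter)
```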
155
models/threed_cnn.py
Normal file
@@ -0,0 +1,155 @@
"""
Dedicated model for 3D stereoscopic captchas.

Uses a deeper CNN backbone (ResNet-style residual blocks) + CRNN sequence modeling,
giving it stronger feature extraction against perspective distortion and shadow effects.

Architecture: ResNet-lite backbone → adaptive pooling → BiLSTM → FC → CTC
Input: grayscale image 1×60×160
Size target: < 5MB
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Simplified residual block: Conv-BN-ReLU-Conv-BN + shortcut."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        out = F.relu(out + residual, inplace=True)
        return out


class ThreeDCNN(nn.Module):
    """
    Recognition model dedicated to 3D captchas.

    The backbone combines plain conv stages with 2 residual blocks,
    with channels growing stepwise:
    1 → 32 → 64 → 64(res) → 128 → 128(res) → 128
    The height is compressed by pooling and then normalized with adaptive pooling,
    while the width keeps the sequence length.
    A BiLSTM + FC head produces the CTC sequence output.
    """

    def __init__(self, chars: str, img_h: int = 60, img_w: int = 160):
        """
        Args:
            chars: charset string (without the CTC blank)
            img_h: input image height
            img_w: input image width
        """
        super().__init__()
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        self.num_classes = len(chars) + 1  # +1 for the CTC blank

        # ---- ResNet-lite backbone ----
        self.backbone = nn.Sequential(
            # stage 1: 1 -> 32, H/2, W unchanged
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),

            # stage 2: 32 -> 64, H/2, W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # stage 3: residual block 64 -> 64
            ResidualBlock(64),

            # stage 4: 64 -> 128, H/2, W/2
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # stage 5: residual block 128 -> 128
            ResidualBlock(128),

            # stage 6: 128 -> 128, H/2, W unchanged
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
        )

        # adaptively squeeze the height to 1, keep the width
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))

        # ---- RNN sequence modeling ----
        self.rnn_input_size = 128
        self.rnn_hidden = 128
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )

        # ---- output layer ----
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale images

        Returns:
            logits: (seq_len, batch, num_classes) in CTC (T, B, C) layout
        """
        conv = self.backbone(x)           # (B, 128, H', W')
        conv = self.adaptive_pool(conv)   # (B, 128, 1, W')
        conv = conv.squeeze(2)            # (B, 128, W')
        conv = conv.permute(0, 2, 1)      # (B, W', 128)

        rnn_out, _ = self.rnn(conv)       # (B, W', 256)
        logits = self.fc(rnn_out)         # (B, W', num_classes)
        logits = logits.permute(1, 0, 2)  # (T, B, C)
        return logits

    @property
    def seq_len(self) -> int:
        """CTC sequence length = input width after the two W/2 pools."""
        return self.img_w // 4

    # ----------------------------------------------------------
    # Greedy CTC decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """
        Greedy CTC decoding.

        Args:
            logits: (T, B, C) raw model outputs

        Returns:
            list of decoded strings
        """
        preds = logits.argmax(dim=2).permute(1, 0)  # (B, T)
        results = []
        for pred in preds:
            chars = []
            prev = -1
            for idx in pred.tolist():
                if idx != 0 and idx != prev:
                    chars.append(self.chars[idx - 1])
                prev = idx
            results.append("".join(chars))
        return results
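One consequence of greedy CTC worth checking against the sizes above: the model can emit at most one character per timestep, and repeated characters need a blank between them, so `seq_len` must be at least the label's minimum CTC length. A small torch-free sketch (the `"AAB8"` label is hypothetical):

```python
def ctc_min_timesteps(label: str) -> int:
    """Minimum CTC input length: one step per character plus a blank between equal neighbours."""
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats

seq_len = 160 // 4  # 3D model: width 160 through two W/2 pools -> 40 timesteps
label = "AAB8"      # hypothetical label with one repeated character
assert ctc_min_timesteps(label) <= seq_len
print(ctc_min_timesteps(label))  # -> 5 (4 characters + 1 separating blank)
```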
25
pyproject.toml
Normal file
@@ -0,0 +1,25 @@
[project]
name = "captchbreaker"
version = "0.1.0"
description = "Multi-model captcha recognition system - two-level routing model + expert models architecture"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.0.0",
    "torchvision>=0.15.0",
    "onnx>=1.14.0",
    "onnxscript>=0.6.0",
    "onnxruntime>=1.15.0",
    "pillow>=10.0.0",
    "numpy>=1.24.0",
    "tqdm>=4.65.0",
]

[project.optional-dependencies]
server = [
    "fastapi>=0.100.0",
    "uvicorn>=0.23.0",
    "python-multipart>=0.0.6",
]

[project.scripts]
captcha = "cli:main"
3
tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Test package.
"""
10
training/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""
Training script package.

- dataset.py: CRNNDataset / CaptchaDataset shared dataset classes
- train_utils.py: shared CTC training logic (train_ctc_model)
- train_normal.py: trains normal character recognition (LiteCRNN - normal)
- train_math.py: trains math expression recognition (LiteCRNN - math)
- train_3d.py: trains 3D stereoscopic recognition (ThreeDCNN)
- train_classifier.py: trains the routing classifier (CaptchaClassifier)
"""
159
training/dataset.py
Normal file
@@ -0,0 +1,159 @@
"""
Shared Dataset classes.

Two datasets are provided:
- CaptchaDataset: for classifier training (image → class label)
- CRNNDataset: for CRNN/CTC recognition training (image → encoded character sequence)

Filename convention: {label}_{anything}.png
- classifier: label may be anything; the containing subdirectory name is the class
- recognizers: label is the annotation itself (e.g. "A3B8" or "3+8")
"""

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from config import AUGMENT_CONFIG


# ============================================================
# Augmentation / inference transform factories
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Data augmentation transform used during training."""
    aug = AUGMENT_CONFIG
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=aug["degrees"],
            translate=aug["translate"],
            scale=aug["scale"],
        ),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
    ])


def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Validation / inference transform (no augmentation)."""
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


# ============================================================
# CRNN / CTC recognition dataset
# ============================================================
class CRNNDataset(Dataset):
    """
    CTC recognition dataset.

    Reads {label}_{xxx}.png files from the given directories and
    encodes each label into an integer sequence (the CTC target).
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: list of data directories (all their .png files are merged)
            chars: charset string (without the CTC blank)
            transform: image preprocessing/augmentation
        """
        self.chars = chars
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # blank=0
        self.transform = transform

        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                label = f.stem.rsplit("_", 1)[0]  # "A3B8_000001" -> "A3B8"
                self.samples.append((str(f), label))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)

        # encode the label as an integer sequence
        target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Custom collate: stack images into a tensor, concatenate targets into one 1D tensor."""
        import torch
        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        targets_flat = torch.IntTensor([idx for t in targets for idx in t])
        return images, targets_flat, target_lengths, list(labels)


# ============================================================
# Classifier dataset
# ============================================================
class CaptchaDataset(Dataset):
    """
    Classifier training dataset.

    Each subdirectory is named after a class (e.g. "normal", "math", "3d");
    every .png file inside it belongs to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: root directory containing subfolders named after the classes
            class_names: list of class names (the order defines the label indices)
            transform: image preprocessing/augmentation
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform

        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), self.class_to_idx[cls_name]))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
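Following the `{label}_{xxx}.png` convention and the blank=0 encoding used by `CRNNDataset` above, the stem-to-target path can be sketched with plain lists (the charset and filename stems here are hypothetical):

```python
chars = "0123456789+-x="                                # hypothetical math charset
char_to_idx = {c: i + 1 for i, c in enumerate(chars)}   # index 0 is reserved for the CTC blank

stems = ["3+8_000001", "12x4_000002"]                   # hypothetical file stems
labels = [s.rsplit("_", 1)[0] for s in stems]           # strip the trailing counter
targets = [[char_to_idx[c] for c in lab] for lab in labels]

# what collate_fn then does with the targets: per-sample lengths + one flat 1D sequence
target_lengths = [len(t) for t in targets]
targets_flat = [i for t in targets for i in t]

print(labels)          # -> ['3+8', '12x4']
print(targets_flat)    # -> [4, 11, 9, 2, 3, 13, 5]
print(target_lengths)  # -> [3, 4]
```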
40
training/train_3d.py
Normal file
@@ -0,0 +1,40 @@
"""
Trains the 3D stereoscopic captcha recognition model (ThreeDCNN).

Usage: python -m training.train_3d
"""

from config import (
    THREED_CHARS,
    IMAGE_SIZE,
    SYNTHETIC_3D_DIR,
    REAL_3D_DIR,
)
from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN
from training.train_utils import train_ctc_model


def main():
    img_h, img_w = IMAGE_SIZE["3d"]
    model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the 3D stereoscopic captcha recognition model (ThreeDCNN)")
    print(f"  charset: {THREED_CHARS} ({len(THREED_CHARS)} characters)")
    print(f"  input size: {img_h}×{img_w}")
    print("=" * 60)

    train_ctc_model(
        model_name="threed",
        model=model,
        chars=THREED_CHARS,
        synthetic_dir=SYNTHETIC_3D_DIR,
        real_dir=REAL_3D_DIR,
        generator_cls=ThreeDCaptchaGenerator,
        config_key="threed",
    )


if __name__ == "__main__":
    main()
232
training/train_classifier.py
Normal file
@@ -0,0 +1,232 @@
"""
Trains the routing classifier (CaptchaClassifier).

Mixes samples from every captcha type to train a classifier that
distinguishes normal / math / 3d.
Data source: the data/classifier/ directory (organized into per-type subdirectories)

Usage: python -m training.train_classifier
"""

import shutil
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CAPTCHA_TYPES,
    NUM_CAPTCHA_TYPES,
    IMAGE_SIZE,
    TRAIN_CONFIG,
    CLASSIFIER_DIR,
    SYNTHETIC_NORMAL_DIR,
    SYNTHETIC_MATH_DIR,
    SYNTHETIC_3D_DIR,
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform


def _prepare_classifier_data():
    """
    Prepares the classifier training data.

    Strategy: symlink / copy from each type's synthetic data directory into
    data/classifier/{type}/, taking the same number per class to keep the
    classes balanced.
    If a type's synthetic data does not exist yet, it is generated first.
    """
    cfg = TRAIN_CONFIG["classifier"]
    per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES

    # per type: (class name, synthetic dir, generator class)
    type_info = [
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
    ]

    for cls_name, syn_dir, gen_cls in type_info:
        syn_dir = Path(syn_dir)
        existing = sorted(syn_dir.glob("*.png"))

        # generate more synthetic data if there is not enough
        if len(existing) < per_class:
            print(f"[data] not enough {cls_name} synthetic data ({len(existing)}/{per_class}), generating...")
            gen = gen_cls()
            gen.generate_dataset(per_class, str(syn_dir))
            existing = sorted(syn_dir.glob("*.png"))

        # stage into the classifier directory
        cls_dir = CLASSIFIER_DIR / cls_name
        cls_dir.mkdir(parents=True, exist_ok=True)
        already = len(list(cls_dir.glob("*.png")))
        if already >= per_class:
            print(f"[data] {cls_name} classifier data already prepared: {already} images")
            continue

        # clear, then relink
        for f in cls_dir.glob("*.png"):
            f.unlink()

        selected = existing[:per_class]
        for f in tqdm(selected, desc=f"preparing {cls_name}", leave=False):
            dst = cls_dir / f.name
            # use symlinks to save space; fall back to copying
            try:
                dst.symlink_to(f.resolve())
            except OSError:
                shutil.copy2(f, dst)

        print(f"[data] {cls_name} classifier data ready: {len(selected)} images")


def main():
    cfg = TRAIN_CONFIG["classifier"]
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()

    print("=" * 60)
    print("Training the routing classifier (CaptchaClassifier)")
    print(f"  classes: {CAPTCHA_TYPES}")
    print(f"  input size: {img_h}×{img_w}")
    print("=" * 60)

    # ---- 1. prepare data ----
    _prepare_classifier_data()

    # ---- 2. build datasets ----
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # the validation set uses no augmentation
    val_ds_clean = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=2, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=2, pin_memory=True,
    )

    print(f"[data] train: {train_size}  val: {val_size}")

    # ---- 3. model / optimizer / scheduler ----
    model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / "classifier.pth"

    # ---- 4. training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. validation ----
        model.eval()
        correct = 0
        total_val = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images)
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total_val += labels.size(0)

        val_acc = correct / max(total_val, 1)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={val_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. save the best model ----
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "class_names": CAPTCHA_TYPES,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f"  → saved best model acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. export ONNX ----
    print(f"\n[training done] best accuracy: {best_acc:.4f}")
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    onnx_path = ONNX_DIR / "classifier.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    print(f"[ONNX] export complete: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")

    return best_acc


if __name__ == "__main__":
    main()
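The symlink-with-copy-fallback step in `_prepare_classifier_data` can be isolated into a small sketch (`link_or_copy` is a hypothetical helper, exercised on a throwaway temp directory):

```python
import shutil
import tempfile
from pathlib import Path

def link_or_copy(src: Path, dst: Path) -> str:
    """Prefer a symlink to save disk space; fall back to a real copy when symlinks fail."""
    try:
        dst.symlink_to(src.resolve())
        return "symlink"
    except OSError:
        shutil.copy2(src, dst)
        return "copy"

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "A3B8_000001.png"
    src.write_bytes(b"\x89PNG")  # stand-in file contents
    how = link_or_copy(src, Path(tmp) / "staged.png")
    assert (Path(tmp) / "staged.png").read_bytes() == b"\x89PNG"

print(how)  # "symlink" on most systems, "copy" where symlinks are unavailable
```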
40
training/train_math.py
Normal file
@@ -0,0 +1,40 @@
"""
Trains the math expression recognition model (LiteCRNN - math mode).

Usage: python -m training.train_math
"""

from config import (
    MATH_CHARS,
    IMAGE_SIZE,
    SYNTHETIC_MATH_DIR,
    REAL_MATH_DIR,
)
from generators.math_gen import MathCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model


def main():
    img_h, img_w = IMAGE_SIZE["math"]
    model = LiteCRNN(chars=MATH_CHARS, img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the math expression recognition model (LiteCRNN - math)")
    print(f"  charset: {MATH_CHARS} ({len(MATH_CHARS)} characters)")
    print(f"  input size: {img_h}×{img_w}")
    print("=" * 60)

    train_ctc_model(
        model_name="math",
        model=model,
        chars=MATH_CHARS,
        synthetic_dir=SYNTHETIC_MATH_DIR,
        real_dir=REAL_MATH_DIR,
        generator_cls=MathCaptchaGenerator,
        config_key="math",
    )


if __name__ == "__main__":
    main()
40
training/train_normal.py
Normal file
@@ -0,0 +1,40 @@
"""
Trains the normal character recognition model (LiteCRNN - normal mode).

Usage: python -m training.train_normal
"""

from config import (
    NORMAL_CHARS,
    IMAGE_SIZE,
    SYNTHETIC_NORMAL_DIR,
    REAL_NORMAL_DIR,
)
from generators.normal_gen import NormalCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model


def main():
    img_h, img_w = IMAGE_SIZE["normal"]
    model = LiteCRNN(chars=NORMAL_CHARS, img_h=img_h, img_w=img_w)

    print("=" * 60)
    print("Training the normal character recognition model (LiteCRNN - normal)")
    print(f"  charset: {NORMAL_CHARS} ({len(NORMAL_CHARS)} characters)")
    print(f"  input size: {img_h}×{img_w}")
    print("=" * 60)

    train_ctc_model(
        model_name="normal",
        model=model,
        chars=NORMAL_CHARS,
        synthetic_dir=SYNTHETIC_NORMAL_DIR,
        real_dir=REAL_NORMAL_DIR,
        generator_cls=NormalCaptchaGenerator,
        config_key="normal",
    )


if __name__ == "__main__":
    main()
232
training/train_utils.py
Normal file
@@ -0,0 +1,232 @@
"""
Shared CTC training logic.

Provides train_ctc_model(), shared by train_normal / train_math / train_3d.
Responsibilities:
1. Check the synthetic data and invoke the generator automatically if it is missing
2. Build the Dataset / DataLoader (including mixed-in real data)
3. CTC training loop + cosine scheduler
4. Log: epoch, loss, whole-sample accuracy, character-level accuracy
5. Save the best model to checkpoints/
6. Export ONNX when training finishes
"""

import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from config import (
    CHECKPOINTS_DIR,
    ONNX_DIR,
    ONNX_CONFIG,
    TRAIN_CONFIG,
    IMAGE_SIZE,
    get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform


# ============================================================
# Accuracy computation
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]):
    """Returns (whole-sample accuracy, character-level accuracy)."""
    total_samples = len(preds)
    correct_samples = 0
    total_chars = 0
    correct_chars = 0

    for pred, label in zip(preds, labels):
        if pred == label:
            correct_samples += 1
        # character level: position-by-position over the longer of the two strings
        max_len = max(len(pred), len(label))
        if max_len == 0:
            continue
        for i in range(max_len):
            total_chars += 1
            if i < len(pred) and i < len(label) and pred[i] == label[i]:
                correct_chars += 1

    sample_acc = correct_samples / max(total_samples, 1)
    char_acc = correct_chars / max(total_chars, 1)
    return sample_acc, char_acc
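A worked example of the two metrics computed by `_calc_accuracy` above, re-deriving the counts inline with hypothetical predictions: one of two samples matches exactly, and 5 of the 6 character positions (counted over the longer string of each pair) agree:

```python
preds, labels = ["A3B8", "XY"], ["A3B8", "XZ"]  # hypothetical predictions vs. ground truth

# whole-sample accuracy: exact string matches
sample_acc = sum(p == l for p, l in zip(preds, labels)) / len(preds)

# character-level accuracy: compare position by position over the longer string
matches = [i < len(p) and i < len(l) and p[i] == l[i]
           for p, l in zip(preds, labels)
           for i in range(max(len(p), len(l)))]
char_acc = sum(matches) / len(matches)

print(sample_acc, round(char_acc, 4))  # -> 0.5 0.8333
```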
||||
|
||||
|
||||
# ============================================================
|
||||
# ONNX 导出
|
||||
# ============================================================
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
dummy = torch.randn(1, 1, img_h, img_w)
|
||||
torch.onnx.export(
|
||||
model.cpu(),
|
||||
dummy,
|
||||
str(onnx_path),
|
||||
opset_version=ONNX_CONFIG["opset_version"],
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamic_axes={"input": {0: "batch"}, "output": {1: "batch"}}
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心训练函数
|
||||
# ============================================================
|
||||
def train_ctc_model(
|
||||
model_name: str,
|
||||
model: nn.Module,
|
||||
chars: str,
|
||||
synthetic_dir: str | Path,
|
||||
real_dir: str | Path,
|
||||
generator_cls,
|
||||
config_key: str,
|
||||
):
|
||||
"""
|
||||
通用 CTC 训练流程。
|
||||
|
||||
Args:
|
||||
model_name: 模型名称 (用于保存文件: normal / math / threed)
|
||||
model: PyTorch 模型实例 (LiteCRNN 或 ThreeDCNN)
|
||||
chars: 字符集字符串
|
||||
synthetic_dir: 合成数据目录
|
||||
real_dir: 真实数据目录
|
||||
generator_cls: 生成器类 (用于自动生成数据)
|
||||
config_key: TRAIN_CONFIG 中的键名
|
||||
"""
|
||||
cfg = TRAIN_CONFIG[config_key]
|
||||
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
|
||||
device = get_device()
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
|
||||
    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    if real_path.exists() and list(real_path.glob("*.png")):
        data_dirs.append(str(real_path))
        print(f"[data] mixing in real data: {len(list(real_path.glob('*.png')))} images")

    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)

    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

    # The validation set uses an augmentation-free transform
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )

    print(f"[data] train: {train_size}  val: {val_size}")

    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    best_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"

    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets, target_lengths, _ in pbar:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            logits = model(images)  # (T, B, C)
            T, B, C = logits.shape
            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)

            log_probs = logits.log_softmax(2)
            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, _, _, labels in val_loader:
                images = images.to(device)
                logits = model(images)
                preds = model.greedy_decode(logits)
                all_preds.extend(preds)
                all_labels.extend(labels)

        sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # ---- 6. Save the best model ----
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "chars": chars,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f"  → saved best model acc={best_acc:.4f} {ckpt_path}")

    # ---- 7. Export ONNX ----
    print(f"\n[training done] best accuracy: {best_acc:.4f}")
    # Reload the best weights before exporting
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w)

    return best_acc
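The validation step above relies on `model.greedy_decode`. As a reference for what that decode is expected to produce, here is a minimal pure-Python sketch of greedy CTC decoding: take the argmax class per timestep, collapse consecutive repeats, and drop the blank. The function name `greedy_ctc_decode`, the nested-list input (the real models operate on tensors), and the `chars[idx - 1]` offset are assumptions consistent with this loop's `CTCLoss(blank=0)` convention.

```python
def greedy_ctc_decode(logits, chars):
    """logits: nested list of shape (T, B, C); returns a list of B strings.

    Mirrors CTCLoss(blank=0): class 0 is the CTC blank, chars[i] maps to class i + 1.
    """
    T = len(logits)
    B = len(logits[0])
    results = []
    for b in range(B):
        prev = -1
        out = []
        for t in range(T):
            scores = logits[t][b]
            idx = max(range(len(scores)), key=scores.__getitem__)  # argmax over classes
            if idx != prev and idx != 0:  # collapse repeats, skip the blank
                out.append(chars[idx - 1])
            prev = idx
        results.append("".join(out))
    return results
```

Note that a blank between two identical classes keeps both characters ("a", blank, "a" decodes to "aa"), which is why CTC can emit doubled letters despite the repeat-collapse step.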