Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

14
.gitignore vendored Normal file

@@ -0,0 +1,14 @@
.venv/
__pycache__/
*.py[cod]
.idea/
.claude/
data/synthetic/
data/classifier/
checkpoints/
onnx_models/
.DS_Store

28
AGENTS.md Normal file

@@ -0,0 +1,28 @@
# Repository Guidelines
## Project Structure & Module Organization
Use `cli.py` as the main entrypoint and keep shared settings in `config.py`. `generators/` builds synthetic captchas, `models/` contains the classifier and expert OCR models, `training/` owns datasets and training scripts, and `inference/` contains the ONNX pipeline, export code, and math post-processing. Runtime artifacts live in `data/`, `checkpoints/`, and `onnx_models/`.
## Build, Test, and Development Commands
Use `uv` for environment and dependency management.
- `uv sync` installs the base runtime dependencies from `pyproject.toml`.
- `uv sync --extra server` installs HTTP service dependencies.
- `uv run captcha generate --type normal --num 1000` generates synthetic training data.
- `uv run captcha train --model normal` trains one model; `uv run captcha train --all` runs the full order: `normal -> math -> 3d -> classifier`.
- `uv run captcha export --all` exports all trained models to ONNX.
- `uv run captcha predict image.png` runs auto-routing inference; add `--type normal` to skip classification.
- `uv run captcha predict-dir ./test_images` runs batch inference on a directory.
- `uv run captcha serve --port 8080` starts the optional HTTP API when `server.py` is implemented.
## Coding Style & Naming Conventions
Target Python 3.10+ and follow existing style: 4-space indentation, snake_case for functions/modules, PascalCase for classes, and short docstrings on public entrypoints. Keep captcha-type ids exactly `normal`, `math`, `3d`, and `classifier`. Preserve the design rules from `CLAUDE.md`: float32 training/export, CPU-safe ops, and greedy CTC decoding unless the pipeline is intentionally redesigned. `normal` uses the locally configured charset, which currently includes confusable characters; math captchas must be recognized as strings and then evaluated in `inference/math_eval.py`.
## Data & Testing Guidelines
Synthetic generator output should use `{label}_{index:06d}.png`; real labeled samples should use `{label}_{anything}.png`. Save best checkpoints to `checkpoints/` and export matching ONNX files to `onnx_models/`. Use `pytest`, place tests under `tests/` as `test_<feature>.py`, and run them with `uv run pytest`. For model, data, or routing changes, add a fast smoke test for shapes, decoding, CLI behavior, or pipeline routing.
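As an illustration of the filename convention above, a fast smoke test might parse labels like this. `label_from_filename` is a hypothetical helper (not part of the repo) and assumes labels never contain an underscore:

```python
def label_from_filename(name: str) -> str:
    """The label is everything before the last underscore in '{label}_{anything}.png'."""
    stem = name.rsplit(".", 1)[0]       # drop the extension
    return stem.rsplit("_", 1)[0]       # drop the index/suffix part


def test_label_from_filename():
    assert label_from_filename("A3B8_000042.png") == "A3B8"
    assert label_from_filename("3+8_sample1.png") == "3+8"
```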
## Commit & Pull Request Guidelines
Git history is not available in this workspace snapshot, so use short imperative commit subjects such as `Add classifier export smoke test`. Keep pull requests focused, describe affected modules, list the commands you ran, and attach sample outputs when prediction behavior changes.
## Documentation Sync
Do not commit large generated datasets unless explicitly required. When a change affects project structure, commands, config, architecture, artifact paths, supported captcha types, or workflow rules, update `AGENTS.md` and `CLAUDE.md` in the same patch.

391
CLAUDE.md Normal file

@@ -0,0 +1,391 @@
# CLAUDE.md - Multi-Model Captcha Recognition System (CaptchaBreaker)
## Project Overview
Build a local captcha recognition system with a two-stage **dispatcher + expert models** architecture. The dispatcher classifies the captcha type; the expert models do the actual recognition. All models are designed to be lightweight and are ultimately exported to ONNX for deployment.
## Tech Stack
- Python 3.10+
- uv (package management; dependencies defined in pyproject.toml)
- PyTorch 2.x (training)
- ONNX + ONNXRuntime (inference and deployment)
- Pillow (image processing)
- FastAPI (optional HTTP recognition service)
## Project Structure
```
captcha-breaker/
├── CLAUDE.md
├── pyproject.toml           # project config and dependencies (managed by uv)
├── config.py                # global config (charsets, image sizes, paths, ...)
├── data/
│   ├── synthetic/           # synthetic training data (auto-generated, not in git)
│   │   ├── normal/          # plain character captchas
│   │   ├── math/            # arithmetic captchas
│   │   └── 3d/              # 3D-style captchas
│   ├── real/                # real captcha samples (manually labeled)
│   │   ├── normal/
│   │   ├── math/
│   │   └── 3d/
│   └── classifier/          # dispatcher training data (mix of all types)
├── generators/
│   ├── __init__.py
│   ├── base.py              # generator base class
│   ├── normal_gen.py        # plain character captcha generator
│   ├── math_gen.py          # arithmetic captcha generator (e.g. 3+8=?)
│   └── threed_gen.py        # 3D captcha generator
├── models/
│   ├── __init__.py
│   ├── lite_crnn.py         # lightweight CRNN (plain characters and arithmetic)
│   ├── classifier.py        # dispatcher classification model
│   └── threed_cnn.py        # dedicated 3D captcha model (deeper CNN)
├── training/
│   ├── __init__.py
│   ├── train_classifier.py  # train the dispatcher
│   ├── train_normal.py      # train plain character recognition
│   ├── train_math.py        # train arithmetic recognition
│   ├── train_3d.py          # train 3D recognition
│   └── dataset.py           # shared Dataset class
├── inference/
│   ├── __init__.py
│   ├── pipeline.py          # core inference pipeline (dispatch + recognize)
│   ├── export_onnx.py       # PyTorch → ONNX export script
│   └── math_eval.py         # arithmetic evaluation module
├── checkpoints/             # model files produced by training
│   ├── classifier.pth
│   ├── normal.pth
│   ├── math.pth
│   └── threed.pth
├── onnx_models/             # exported ONNX models
│   ├── classifier.onnx
│   ├── normal.onnx
│   ├── math.onnx
│   └── threed.onnx
├── server.py                # FastAPI inference service (optional)
├── cli.py                   # command-line entrypoint
└── tests/
    ├── test_generators.py
    ├── test_models.py
    └── test_pipeline.py
```
## Core Architecture
### Inference Pipeline
```
input image → preprocess → dispatcher → route to expert model → postprocess → result
                       ┌────────┼────────┐
                       ▼        ▼        ▼
                    normal     math      3d
                    (CRNN)    (CRNN)    (CNN)
                       │        │        │
                       ▼        ▼        ▼
                    "A3B8"  "3+8=?"→11  "X9K2"
```
### Dispatcher Classifier (classifier.py)
- Task: image classification, i.e. decide which captcha type an image belongs to
- Architecture: lightweight CNN, 3-4 conv layers + global average pooling + fully connected head
- Input: grayscale 1x64x128
- Output: softmax probability distribution; number of classes = number of captcha types
- Requirements: 99%+ accuracy, < 5ms inference
- Model size target: < 500KB
```python
class CaptchaClassifier(nn.Module):
    """
    Lightweight classifier; a few conv layers are enough to tell the types
    apart. The types differ a lot visually (presence of operators, 3D
    effects, ...), so classification is easy.
    """
    def __init__(self, num_types=3):
        super().__init__()
        # 4 conv layers -> GAP -> FC
        # Conv2d(1,16) -> Conv2d(16,32) -> Conv2d(32,64) -> Conv2d(64,64)
        # AdaptiveAvgPool2d(1) -> Linear(64, num_types)
        def block(cin, cout):
            return nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1),
                                 nn.ReLU(), nn.MaxPool2d(2))
        self.features = nn.Sequential(block(1, 16), block(16, 32),
                                      block(32, 64), block(64, 64))
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                  nn.Linear(64, num_types))

    def forward(self, x):
        return self.head(self.features(x))
```
### Plain Character Expert (lite_crnn.py - normal mode)
- Task: recognize colored character captchas (mixed digits and letters)
- Architecture: CRNN + CTC
- Charset: `0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ` (36 chars; keeps confusable characters, trained per the local config)
- Input: grayscale 1x40x120
- Output: character sequence via greedy CTC decoding
- Captcha traits: light background, colored characters, light noise lines, tilted characters
- Model size target: < 2MB
### Arithmetic Expert (lite_crnn.py - math mode)
- Task: recognize arithmetic captchas and compute the result
- Architecture: reuses CRNN + CTC with a different charset
- Charset: `0123456789+-×÷=?` (digits + operators)
- Input: grayscale 1x40x160 (expressions are usually wider)
- Output: the recognized expression string, then handed to math_eval.py for evaluation
- Two steps: (1) OCR → "3+8=?" (2) regex parse and evaluate → 11
- Model size target: < 2MB
```python
# math_eval.py core logic
import re

def eval_captcha_math(expr: str) -> str:
    """
    Parse and evaluate a captcha arithmetic expression.
    Supports +, -, ×, ÷ on one- and two-digit operands.
    Input:  "3+8=?", "12×3=?", "15-7=?"
    Output: "11",    "36",     "8"
    Extract numbers and the operator with a regex; never use eval().
    """
    m = re.match(r"\s*(\d+)\s*([+\-×÷])\s*(\d+)", expr)
    if m is None:
        raise ValueError(f"unparseable expression: {expr}")
    a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
    if op == "+":
        return str(a + b)
    if op == "-":
        return str(a - b)
    if op == "×":
        return str(a * b)
    return str(a // b)  # '÷' is only generated for exact divisions
```
### 3D Expert (threed_cnn.py)
- Task: recognize captchas with 3D perspective/shadow effects
- Architecture: deeper CNN + CRNN, or a ResNet-lite backbone
- Input: grayscale 1x60x160
- Needs stronger feature extraction to handle perspective distortion and shadows
- Model size target: < 5MB
## Data Generator Spec
### Base Class (base.py)
```python
class BaseCaptchaGenerator:
    def generate(self, text=None) -> tuple[Image.Image, str]:
        """Generate one captcha; return (image, label text)."""
        raise NotImplementedError

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Batch generation; filename format: {label}_{index:06d}.png"""
        pass
```
### Plain Character Generator (normal_gen.py)
Target style to mimic:
- light random background (each RGB channel 230-255)
- each character in a random dark color (blue/red/green/purple/brown, ...)
- 4-5 characters
- ±15° random rotation per character
- 2-5 light noise lines
- a few noise dots
- optional slight Gaussian blur
### Arithmetic Generator (math_gen.py)
- Generates images of the form `A op B = ?`
- A, B range: integers 1-30
- op: +, -, × (÷ only when the division is exact)
- Ensure the result is a non-negative integer
- Label format: `3+8` (store the expression itself, not the result)
- Visual style: consistent with the target arithmetic captchas
### 3D Generator (threed_gen.py)
- Simulate 3D perspective with Pillow affine transforms
- Add a shadow effect
- Characters have depth and tilt
- Label: the plain character content
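The perspective simulation above boils down to computing the 8 coefficients that Pillow's `Image.transform(size, Image.PERSPECTIVE, coeffs)` expects. The helper below is an illustrative sketch, not part of the spec:

```python
import numpy as np

def perspective_coeffs(src_quad, dst_quad):
    """Coefficients for PIL's Image.transform(size, Image.PERSPECTIVE, coeffs):
    maps each output corner in dst_quad back to the source corner in src_quad."""
    A, b = [], []
    for (x, y), (u, v) in zip(dst_quad, src_quad):
        A.append([x, y, 1, 0, 0, 0, -u * x, -u * y])
        A.append([0, 0, 0, x, y, 1, -v * x, -v * y])
        b.extend([u, v])
    # 8 equations in the 8 unknowns a..h of the projective transform
    return np.linalg.solve(np.asarray(A, dtype=float), np.asarray(b, dtype=float))
```

Feeding a slightly skewed destination quad into `img.transform(img.size, Image.PERSPECTIVE, coeffs, resample=Image.BICUBIC)` then produces the tilted, depth-like look.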
## Training Spec
### Shared Training Config
```python
# defined in config.py
TRAIN_CONFIG = {
    'classifier': {
        'epochs': 30,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 30000,  # 10000 per class
    },
    'normal': {
        'epochs': 50,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 60000,
        'loss': 'CTCLoss',
    },
    'math': {
        'epochs': 50,
        'batch_size': 128,
        'lr': 1e-3,
        'scheduler': 'cosine',
        'synthetic_samples': 60000,
        'loss': 'CTCLoss',
    },
    'threed': {
        'epochs': 80,
        'batch_size': 64,
        'lr': 5e-4,
        'scheduler': 'cosine',
        'synthetic_samples': 80000,
        'loss': 'CTCLoss',
    },
}
```
### Training Script Requirements
Every training script must:
1. Check whether synthetic data exists and invoke the generator automatically if not
2. Support mixing in real data (when data/real/{type}/ contains files)
3. Use data augmentation: RandomAffine, ColorJitter, GaussianBlur, RandomErasing
4. Log training progress: epoch, loss, whole-sample accuracy, per-character accuracy
5. Save the best model to checkpoints/
6. Automatically export ONNX to onnx_models/ when training finishes
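Requirement 5, combined with the checkpoint layout mandated in the constraints later in this file (model_state_dict, chars, best_acc, epoch), can be sketched framework-free. `update_best` is an illustrative helper, not part of the spec:

```python
def update_best(state_dict, chars, val_acc, epoch, best_acc=None):
    """Return (checkpoint_or_None, new_best_acc); a checkpoint dict is built
    only when validation accuracy improves, in the mandated layout."""
    if best_acc is not None and val_acc <= best_acc:
        return None, best_acc
    checkpoint = {
        "model_state_dict": state_dict,
        "chars": chars,
        "best_acc": val_acc,
        "epoch": epoch,
    }
    return checkpoint, val_acc
```

A training loop would call this once per epoch and pass any non-None result to `torch.save(checkpoint, "checkpoints/<model>.pth")`.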
### Data Augmentation Strategy
```python
# training-time augmentation
train_augment = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((H, W)),
    transforms.RandomAffine(degrees=8, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.GaussianBlur(3, sigma=(0.1, 0.5)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.RandomErasing(p=0.15, scale=(0.01, 0.05)),
])
```
## Inference Pipeline (pipeline.py)
```python
class CaptchaPipeline:
    """
    Core inference pipeline.
    Loads the dispatcher and all expert models (ONNX format).
    Exposes a single solve(image) interface.
    """
    def __init__(self, models_dir='onnx_models/'):
        """Load all ONNX models via onnxruntime.InferenceSession."""
        pass

    def preprocess(self, image: Image.Image, target_size: tuple) -> np.ndarray:
        """Preprocess: resize, grayscale, normalize, convert to numpy."""
        pass

    def classify(self, image: Image.Image) -> str:
        """Dispatch classification; returns 'normal' / 'math' / '3d'."""
        pass

    def solve(self, image) -> str:
        """
        Full recognition flow:
        1. classify the captcha type
        2. route to the matching expert model
        3. postprocess (arithmetic captchas are evaluated)
        4. return the final answer string
        image: PIL.Image, file path, or bytes
        """
        pass
```
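The experts feed their raw outputs through greedy CTC decoding (mandated in the constraints below): argmax per timestep, collapse consecutive repeats, drop blanks. A minimal sketch, assuming blank is index 0 and class i maps to charset[i-1]:

```python
import numpy as np

def ctc_greedy_decode(logits, charset, blank=0):
    """logits: (T, num_classes) array. Take the argmax per timestep,
    collapse consecutive repeats, then drop blank symbols."""
    best = logits.argmax(axis=1)
    out, prev = [], blank
    for idx in best:
        if idx != blank and idx != prev:
            out.append(charset[idx - 1])
        prev = int(idx)
    return "".join(out)
```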
## ONNX Export (export_onnx.py)
```python
def export_model(model, model_name, input_shape, onnx_dir='onnx_models/'):
    """
    Export a single model to ONNX.
    - use opset_version=18
    - enable dynamic_axes for dynamic batch size
    - verify inference parity with onnxruntime after export
    - optional: simplify the model with onnxsim
    """
    pass

def export_all():
    """Export classifier, normal, math, and threed in turn."""
    pass
```
## CLI Entrypoint (cli.py)
```bash
# install dependencies
uv sync                 # core dependencies
uv sync --extra server  # plus HTTP service dependencies

# generate training data
uv run python cli.py generate --type normal --num 60000
uv run python cli.py generate --type math --num 60000
uv run python cli.py generate --type 3d --num 80000
uv run python cli.py generate --type classifier --num 30000

# train models
uv run python cli.py train --model classifier
uv run python cli.py train --model normal
uv run python cli.py train --model math
uv run python cli.py train --model 3d
uv run python cli.py train --all        # train everything in dependency order

# export ONNX
uv run python cli.py export --all

# inference
uv run python cli.py predict image.png                 # auto-classify + recognize
uv run python cli.py predict image.png --type normal   # skip classification
uv run python cli.py predict-dir ./test_images/        # batch recognition

# start the HTTP service (install the server extra first)
uv run python cli.py serve --port 8080
```
## HTTP Service (server.py, optional)
```python
# FastAPI service exposing a REST API
# POST /solve - upload an image, get the recognition result
# request: multipart/form-data, field name "image"
# response: {"type": "normal", "result": "A3B8", "confidence": 0.95, "time_ms": 45}
```
## Key Constraints and Notes
1. **Train all models in float32 and export ONNX without quantization**; get accuracy right first
2. **Use greedy CTC decoding everywhere**; beam search is unnecessary, greedy is enough for captchas
3. **Charsets are defined centrally in config.py**: normal currently keeps confusable characters; 3d keeps using the de-confused charset
4. **Arithmetic recognition is two-step**: OCR the string first, then evaluate it with rules; never have the model output the numeric answer directly
5. **Generator random seeds**: set a seed when generating data so runs are reproducible
6. **Real data filename format**: `{label}_{anything}.png`, where the label part is the annotation
7. **Checkpoint format**: PyTorch checkpoint containing model_state_dict, chars, best_acc, epoch
8. **No GPU-only features**; training and inference must also work on CPU (just slower)
9. **Adding types**: a new captcha type only needs (1) a generator, (2) an expert model, (3) one more dispatcher class plus retraining
10. **Documentation sync**: any change to project structure, config, architecture, etc. must update the matching parts of CLAUDE.md so docs and code stay consistent
## Target Metrics
| Model | Accuracy target | Inference latency | Model size |
|------|-----------|---------|---------|
| dispatcher | > 99% | < 5ms | < 500KB |
| plain characters | > 95% | < 30ms | < 2MB |
| arithmetic | > 93% | < 30ms | < 2MB |
| 3D | > 85% | < 50ms | < 5MB |
| full pipeline | - | < 80ms | < 10MB total |
## Development Order
1. Implement config.py and generators/ first
2. Implement all model definitions in models/
3. Implement the shared dataset class in training/dataset.py
4. Train in order: normal → math → 3d → classifier
5. Implement inference/pipeline.py and export_onnx.py
6. Implement the unified cli.py entrypoint
7. Optional: the server.py HTTP service
8. Write tests/

228
cli.py Normal file

@@ -0,0 +1,228 @@
"""
CaptchaBreaker 命令行入口
用法:
python cli.py generate --type normal --num 60000
python cli.py train --model normal
python cli.py train --all
python cli.py export --all
python cli.py predict image.png
python cli.py predict image.png --type normal
python cli.py predict-dir ./test_images/
python cli.py serve --port 8080
"""
import argparse
import sys
from pathlib import Path
def cmd_generate(args):
"""生成训练数据。"""
from config import (
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
CLASSIFIER_DIR, TRAIN_CONFIG, CAPTCHA_TYPES, NUM_CAPTCHA_TYPES,
)
from generators import NormalCaptchaGenerator, MathCaptchaGenerator, ThreeDCaptchaGenerator
gen_map = {
"normal": (NormalCaptchaGenerator, SYNTHETIC_NORMAL_DIR),
"math": (MathCaptchaGenerator, SYNTHETIC_MATH_DIR),
"3d": (ThreeDCaptchaGenerator, SYNTHETIC_3D_DIR),
}
captcha_type = args.type
num = args.num
if captcha_type == "classifier":
# 分类器数据: 各类型各生成 num // num_types
per_class = num // NUM_CAPTCHA_TYPES
print(f"生成分类器训练数据: 每类 {per_class}")
for cls_name in CAPTCHA_TYPES:
gen_cls, out_dir = gen_map[cls_name]
cls_dir = CLASSIFIER_DIR / cls_name
cls_dir.mkdir(parents=True, exist_ok=True)
gen = gen_cls()
gen.generate_dataset(per_class, str(cls_dir))
elif captcha_type in gen_map:
gen_cls, out_dir = gen_map[captcha_type]
print(f"生成 {captcha_type} 数据: {num} 张 → {out_dir}")
gen = gen_cls()
gen.generate_dataset(num, str(out_dir))
else:
print(f"未知类型: {captcha_type} 可选: normal, math, 3d, classifier")
sys.exit(1)
def cmd_train(args):
"""训练模型。"""
if args.all:
# 按依赖顺序: normal → math → 3d → classifier
print("按顺序训练全部模型: normal → math → 3d → classifier\n")
from training.train_normal import main as train_normal
from training.train_math import main as train_math
from training.train_3d import main as train_3d
from training.train_classifier import main as train_classifier
train_normal()
print("\n")
train_math()
print("\n")
train_3d()
print("\n")
train_classifier()
return
model = args.model
if model == "normal":
from training.train_normal import main as train_fn
elif model == "math":
from training.train_math import main as train_fn
elif model == "3d":
from training.train_3d import main as train_fn
elif model == "classifier":
from training.train_classifier import main as train_fn
else:
print(f"未知模型: {model} 可选: normal, math, 3d, classifier")
sys.exit(1)
train_fn()
def cmd_export(args):
"""导出 ONNX 模型。"""
from inference.export_onnx import export_all, _load_and_export
if args.all:
export_all()
elif args.model:
_load_and_export(args.model)
else:
print("请指定 --all 或 --model <name>")
sys.exit(1)
def cmd_predict(args):
"""单张图片推理。"""
from inference.pipeline import CaptchaPipeline
image_path = args.image
if not Path(image_path).exists():
print(f"文件不存在: {image_path}")
sys.exit(1)
pipeline = CaptchaPipeline()
result = pipeline.solve(image_path, captcha_type=args.type)
print(f"文件: {image_path}")
print(f"类型: {result['type']}")
print(f"识别: {result['raw']}")
print(f"结果: {result['result']}")
print(f"耗时: {result['time_ms']:.1f} ms")
def cmd_predict_dir(args):
"""批量目录推理。"""
from inference.pipeline import CaptchaPipeline
dir_path = Path(args.directory)
if not dir_path.is_dir():
print(f"目录不存在: {dir_path}")
sys.exit(1)
pipeline = CaptchaPipeline()
images = sorted(dir_path.glob("*.png")) + sorted(dir_path.glob("*.jpg"))
if not images:
print(f"目录中未找到图片: {dir_path}")
sys.exit(1)
print(f"批量识别: {len(images)} 张图片\n")
print(f"{'文件名':<30} {'类型':<8} {'结果':<15} {'耗时(ms)':>8}")
print("-" * 65)
total_ms = 0.0
for img_path in images:
result = pipeline.solve(str(img_path), captcha_type=args.type)
total_ms += result["time_ms"]
print(
f"{img_path.name:<30} {result['type']:<8} "
f"{result['result']:<15} {result['time_ms']:>8.1f}"
)
print("-" * 65)
print(f"总计: {len(images)} 张 平均: {total_ms / len(images):.1f} ms 总耗时: {total_ms:.1f} ms")
def cmd_serve(args):
"""启动 HTTP 服务。"""
try:
from server import create_app
except ImportError:
# server.py 尚未实现或缺少依赖
print("HTTP 服务需要 FastAPI 和 uvicorn。")
print("安装: uv sync --extra server")
print("并确保 server.py 已实现。")
sys.exit(1)
import uvicorn
app = create_app()
uvicorn.run(app, host=args.host, port=args.port)
def main():
parser = argparse.ArgumentParser(
prog="captcha-breaker",
description="验证码识别多模型系统 - 调度模型 + 多专家模型",
)
subparsers = parser.add_subparsers(dest="command", help="子命令")
# ---- generate ----
p_gen = subparsers.add_parser("generate", help="生成训练数据")
p_gen.add_argument("--type", required=True, help="验证码类型: normal, math, 3d, classifier")
p_gen.add_argument("--num", type=int, required=True, help="生成数量")
# ---- train ----
p_train = subparsers.add_parser("train", help="训练模型")
p_train.add_argument("--model", help="模型名: normal, math, 3d, classifier")
p_train.add_argument("--all", action="store_true", help="按依赖顺序训练全部模型")
# ---- export ----
p_export = subparsers.add_parser("export", help="导出 ONNX 模型")
p_export.add_argument("--model", help="模型名: normal, math, 3d, classifier, threed")
p_export.add_argument("--all", action="store_true", help="导出全部模型")
# ---- predict ----
p_pred = subparsers.add_parser("predict", help="识别单张验证码")
p_pred.add_argument("image", help="图片路径")
p_pred.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d")
# ---- predict-dir ----
p_pdir = subparsers.add_parser("predict-dir", help="批量识别目录中的验证码")
p_pdir.add_argument("directory", help="图片目录路径")
p_pdir.add_argument("--type", default=None, help="指定类型跳过分类: normal, math, 3d")
# ---- serve ----
p_serve = subparsers.add_parser("serve", help="启动 HTTP 识别服务")
p_serve.add_argument("--host", default="0.0.0.0", help="监听地址 (默认 0.0.0.0)")
p_serve.add_argument("--port", type=int, default=8080, help="监听端口 (默认 8080)")
args = parser.parse_args()
if args.command is None:
parser.print_help()
sys.exit(0)
cmd_map = {
"generate": cmd_generate,
"train": cmd_train,
"export": cmd_export,
"predict": cmd_predict,
"predict-dir": cmd_predict_dir,
"serve": cmd_serve,
}
cmd_map[args.command](args)
if __name__ == "__main__":
main()

195
config.py Normal file

@@ -0,0 +1,195 @@
"""
全局配置 - 验证码识别多模型系统 (CaptchaBreaker)
定义字符集、图片尺寸、路径、训练超参等所有全局常量。
"""
import os
from pathlib import Path
# ============================================================
# 项目根目录
# ============================================================
PROJECT_ROOT = Path(__file__).resolve().parent
# ============================================================
# 数据目录
# ============================================================
DATA_DIR = PROJECT_ROOT / "data"
SYNTHETIC_DIR = DATA_DIR / "synthetic"
REAL_DIR = DATA_DIR / "real"
CLASSIFIER_DIR = DATA_DIR / "classifier"
# 合成数据子目录
SYNTHETIC_NORMAL_DIR = SYNTHETIC_DIR / "normal"
SYNTHETIC_MATH_DIR = SYNTHETIC_DIR / "math"
SYNTHETIC_3D_DIR = SYNTHETIC_DIR / "3d"
# 真实数据子目录
REAL_NORMAL_DIR = REAL_DIR / "normal"
REAL_MATH_DIR = REAL_DIR / "math"
REAL_3D_DIR = REAL_DIR / "3d"
# ============================================================
# 模型输出目录
# ============================================================
CHECKPOINTS_DIR = PROJECT_ROOT / "checkpoints"
ONNX_DIR = PROJECT_ROOT / "onnx_models"
# 确保关键目录存在
for _dir in [
SYNTHETIC_NORMAL_DIR, SYNTHETIC_MATH_DIR, SYNTHETIC_3D_DIR,
REAL_NORMAL_DIR, REAL_MATH_DIR, REAL_3D_DIR,
CLASSIFIER_DIR, CHECKPOINTS_DIR, ONNX_DIR,
]:
_dir.mkdir(parents=True, exist_ok=True)
# ============================================================
# 字符集定义
# ============================================================
# 普通字符验证码: 按当前本地配置保留易混淆字符,覆盖完整数字 + 大写字母
NORMAL_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# 算式验证码: 数字 + 运算符
MATH_CHARS = "0123456789+-×÷=?"
# 3D 验证码: 继续使用去掉易混淆字符的精简字符集
THREED_CHARS = "23456789ABCDEFGHJKMNPQRSTUVWXYZ"
# 验证码类型列表 (调度分类器输出)
CAPTCHA_TYPES = ["normal", "math", "3d"]
NUM_CAPTCHA_TYPES = len(CAPTCHA_TYPES)
# ============================================================
# 图片尺寸配置 (H, W)
# ============================================================
IMAGE_SIZE = {
"classifier": (64, 128), # 调度分类器输入
"normal": (40, 120), # 普通字符识别
"math": (40, 160), # 算式识别 (更宽)
"3d": (60, 160), # 3D 立体识别
}
# ============================================================
# 验证码生成参数
# ============================================================
GENERATE_CONFIG = {
"normal": {
"char_count_range": (4, 5), # 字符数量: 4-5 个
"bg_color_range": (230, 255), # 浅色背景 RGB 各通道
"rotation_range": (-15, 15), # 字符旋转角度
"noise_line_range": (2, 5), # 干扰线数量
"noise_point_num": 100, # 噪点数量
"blur_radius": 0.8, # 高斯模糊半径
"image_size": (120, 40), # 生成图片尺寸 (W, H)
},
"math": {
"operand_range": (1, 30), # 操作数范围
"operators": ["+", "-", "×"], # 支持的运算符 (除法只生成能整除的)
"image_size": (160, 40), # 生成图片尺寸 (W, H)
"bg_color_range": (230, 255),
"rotation_range": (-10, 10),
"noise_line_range": (2, 4),
},
"3d": {
"char_count_range": (4, 5),
"image_size": (160, 60), # 生成图片尺寸 (W, H)
"shadow_offset": (3, 3), # 阴影偏移
"perspective_intensity": 0.3, # 透视变换强度
},
}
# ============================================================
# 训练配置
# ============================================================
TRAIN_CONFIG = {
"classifier": {
"epochs": 30,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 30000, # 每类 10000
"val_split": 0.1, # 验证集比例
},
"normal": {
"epochs": 50,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "CTCLoss",
"val_split": 0.1,
},
"math": {
"epochs": 50,
"batch_size": 128,
"lr": 1e-3,
"scheduler": "cosine",
"synthetic_samples": 60000,
"loss": "CTCLoss",
"val_split": 0.1,
},
"threed": {
"epochs": 80,
"batch_size": 64,
"lr": 5e-4,
"scheduler": "cosine",
"synthetic_samples": 80000,
"loss": "CTCLoss",
"val_split": 0.1,
},
}
# ============================================================
# 数据增强参数 (训练时使用)
# ============================================================
AUGMENT_CONFIG = {
"degrees": 8, # RandomAffine 旋转范围
"translate": (0.05, 0.05), # 平移范围
"scale": (0.95, 1.05), # 缩放范围
"brightness": 0.3, # ColorJitter 亮度
"contrast": 0.3, # ColorJitter 对比度
"blur_kernel": 3, # GaussianBlur 核大小
"blur_sigma": (0.1, 0.5), # GaussianBlur sigma
"erasing_prob": 0.15, # RandomErasing 概率
"erasing_scale": (0.01, 0.05), # RandomErasing 面积比
}
# ============================================================
# ONNX 导出配置
# ============================================================
ONNX_CONFIG = {
"opset_version": 18,
"dynamic_batch": True, # 支持动态 batch size
}
# ============================================================
# 推理配置
# ============================================================
INFERENCE_CONFIG = {
"default_models_dir": str(ONNX_DIR),
"normalize_mean": 0.5,
"normalize_std": 0.5,
}
# ============================================================
# 随机种子 (保证数据生成可复现)
# ============================================================
RANDOM_SEED = 42
# ============================================================
# 设备配置 (优先 GPU回退 CPU)
# 延迟导入 torch避免仅使用生成器时必须安装 torch
# ============================================================
def get_device():
"""返回可用的 torch 设备,优先 GPU。"""
import torch
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============================================================
# 服务配置 (可选 HTTP 服务)
# ============================================================
SERVER_CONFIG = {
"host": "0.0.0.0",
"port": 8080,
}

0
data/real/3d/.gitkeep Normal file

0
data/real/math/.gitkeep Normal file

20
generators/__init__.py Normal file

@@ -0,0 +1,20 @@
"""
数据生成器包
提供三种验证码类型的数据生成器:
- NormalCaptchaGenerator: 普通字符验证码
- MathCaptchaGenerator: 算式验证码
- ThreeDCaptchaGenerator: 3D 立体验证码
"""
from generators.base import BaseCaptchaGenerator
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
__all__ = [
"BaseCaptchaGenerator",
"NormalCaptchaGenerator",
"MathCaptchaGenerator",
"ThreeDCaptchaGenerator",
]

61
generators/base.py Normal file

@@ -0,0 +1,61 @@
"""
验证码生成器基类
所有验证码生成器继承此基类,实现 generate() 方法。
基类提供通用的 generate_dataset() 批量生成能力。
"""
import os
import random
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from config import RANDOM_SEED
class BaseCaptchaGenerator:
"""验证码生成器基类。"""
def __init__(self, seed: int = RANDOM_SEED):
"""
初始化生成器。
Args:
seed: 随机种子,保证数据生成可复现。
"""
self.seed = seed
self.rng = random.Random(seed)
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
"""
生成一张验证码图片。
Args:
text: 指定标签文本。为 None 时随机生成。
Returns:
(图片, 标签文本)
"""
raise NotImplementedError
def generate_dataset(self, num_samples: int, output_dir: str) -> None:
"""
批量生成验证码数据集。
文件名格式: {label}_{index:06d}.png
Args:
num_samples: 生成数量。
output_dir: 输出目录路径。
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 重置随机种子,保证每次批量生成结果一致
self.rng = random.Random(self.seed)
for i in tqdm(range(num_samples), desc=f"Generating → {output_path.name}"):
img, label = self.generate()
filename = f"{label}_{i:06d}.png"
img.save(output_path / filename)

186
generators/math_gen.py Normal file

@@ -0,0 +1,186 @@
"""
算式验证码生成器
生成形如 A op B = ? 的算式图片:
- A, B 范围: 1-30 的整数
- op: +, -, × (除法只生成能整除的)
- 确保结果为非负整数
- 标签格式: "3+8" (存储算式本身,不存结果)
- 视觉风格: 浅色背景、深色字符、干扰线
"""
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import GENERATE_CONFIG
from generators.base import BaseCaptchaGenerator
# 字体
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSansMono-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationMono-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
]
# 深色调色板
_DARK_COLORS = [
(0, 0, 180),
(180, 0, 0),
(0, 130, 0),
(130, 0, 130),
(120, 60, 0),
(0, 0, 0),
(50, 50, 150),
]
# 运算符显示映射(用于渲染)
_OP_DISPLAY = {
"+": "+",
"-": "-",
"×": "×",
"÷": "÷",
}
class MathCaptchaGenerator(BaseCaptchaGenerator):
"""算式验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["math"]
self.width, self.height = self.cfg["image_size"]
self.operators = self.cfg["operators"]
self.op_lo, self.op_hi = self.cfg["operand_range"]
# 预加载可用字体
self._fonts: list[str] = []
for p in _FONT_PATHS:
try:
ImageFont.truetype(p, 20)
self._fonts.append(p)
except OSError:
continue
if not self._fonts:
raise RuntimeError("未找到任何可用字体,无法生成验证码")
# ----------------------------------------------------------
# 公共接口
# ----------------------------------------------------------
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
# 1. 生成算式
if text is None:
a, op, b = self._random_expression(rng)
text = f"{a}{op}{b}"
else:
a, op, b = self._parse_expression(text)
# 显示文本: "3+8=?"
display = f"{a}{_OP_DISPLAY.get(op, op)}{b}=?"
# 2. 浅色背景
bg_lo, bg_hi = self.cfg["bg_color_range"]
bg = tuple(rng.randint(bg_lo, bg_hi) for _ in range(3))
img = Image.new("RGB", (self.width, self.height), bg)
# 3. 绘制算式文本
self._draw_expression(img, display, rng)
# 4. 干扰线
self._draw_noise_lines(img, rng)
# 5. 轻微模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.6))
return img, text
# ----------------------------------------------------------
# 私有方法
# ----------------------------------------------------------
def _random_expression(self, rng: random.Random) -> tuple[int, str, int]:
"""随机生成一个合法算式 (a, op, b),确保结果为非负整数。"""
while True:
op = rng.choice(self.operators)
a = rng.randint(self.op_lo, self.op_hi)
b = rng.randint(self.op_lo, self.op_hi)
if op == "+":
return a, op, b
elif op == "-":
# 确保 a >= b结果非负
if a < b:
a, b = b, a
return a, op, b
elif op == "×":
# 限制乘积不过大,保持合理
if a * b <= 900:
return a, op, b
elif op == "÷":
# 只生成能整除的
if b != 0 and a % b == 0:
return a, op, b
@staticmethod
def _parse_expression(text: str) -> tuple[int, str, int]:
"""解析标签文本,如 '3+8' -> (3, '+', 8)。"""
for op in ["×", "÷", "+", "-"]:
if op in text:
parts = text.split(op, 1)
return int(parts[0]), op, int(parts[1])
raise ValueError(f"无法解析算式: {text}")
def _draw_expression(self, img: Image.Image, display: str, rng: random.Random) -> None:
"""将算式文本绘制到图片上,每个字符单独渲染并带轻微旋转。"""
n = len(display)
slot_w = self.width // n
font_size = int(min(slot_w * 0.85, self.height * 0.65))
font_size = max(font_size, 14)
for i, ch in enumerate(display):
font_path = rng.choice(self._fonts)
# 对于 × 等特殊符号,某些字体可能不支持,回退到 DejaVu
try:
font = ImageFont.truetype(font_path, font_size)
bbox = font.getbbox(ch)
if bbox[2] - bbox[0] <= 0:
raise ValueError
except (OSError, ValueError):
font = ImageFont.truetype(self._fonts[0], font_size)
bbox = font.getbbox(ch)
color = rng.choice(_DARK_COLORS)
cw = bbox[2] - bbox[0] + 4
ch_h = bbox[3] - bbox[1] + 4
char_img = Image.new("RGBA", (cw, ch_h), (0, 0, 0, 0))
ImageDraw.Draw(char_img).text((-bbox[0] + 2, -bbox[1] + 2), ch, fill=color, font=font)
# 轻微旋转
angle = rng.randint(*self.cfg["rotation_range"])
char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)
x = slot_w * i + (slot_w - char_img.width) // 2
y = (self.height - char_img.height) // 2 + rng.randint(-2, 2)
x = max(0, min(x, self.width - char_img.width))
y = max(0, min(y, self.height - char_img.height))
img.paste(char_img, (x, y), char_img)
def _draw_noise_lines(self, img: Image.Image, rng: random.Random) -> None:
"""绘制浅色干扰线。"""
draw = ImageDraw.Draw(img)
lo, hi = self.cfg["noise_line_range"]
num = rng.randint(lo, hi)
for _ in range(num):
x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)
color = tuple(rng.randint(150, 220) for _ in range(3))
draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))

154
generators/normal_gen.py Normal file

@@ -0,0 +1,154 @@
"""
普通字符验证码生成器
生成风格:
- 浅色随机背景 (RGB 各通道 230-255)
- 每个字符随机深色 (蓝/红/绿/紫/棕等)
- 字符数量 4-5 个
- 字符有 ±15° 随机旋转
- 2-5 条浅色干扰线
- 少量噪点
- 可选轻微高斯模糊
"""
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import GENERATE_CONFIG, NORMAL_CHARS
from generators.base import BaseCaptchaGenerator
# 系统可用字体列表(粗体/常规混合,增加多样性)
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSansMono-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationMono-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
"/usr/share/fonts/gnu-free/FreeMonoBold.otf",
]
# 深色调色板 (R, G, B)
_DARK_COLORS = [
(0, 0, 180), # 蓝
(180, 0, 0), # 红
(0, 130, 0), # 绿
(130, 0, 130), # 紫
(120, 60, 0), # 棕
(0, 100, 100), # 青
(80, 80, 0), # 橄榄
(0, 0, 0), # 黑
(100, 0, 50), # 暗玫红
(50, 50, 150), # 钢蓝
]
class NormalCaptchaGenerator(BaseCaptchaGenerator):
    """Generator for normal character captchas."""

    def __init__(self, seed: int | None = None):
        from config import RANDOM_SEED
        super().__init__(seed=seed if seed is not None else RANDOM_SEED)
        self.cfg = GENERATE_CONFIG["normal"]
        self.chars = NORMAL_CHARS
        self.width, self.height = self.cfg["image_size"]
        # Preload the usable fonts
        self._fonts: list[str] = []
        for p in _FONT_PATHS:
            try:
                ImageFont.truetype(p, 20)
                self._fonts.append(p)
            except OSError:
                continue
        if not self._fonts:
            raise RuntimeError("No usable font found; cannot generate captchas")

    # ----------------------------------------------------------
    # Public interface
    # ----------------------------------------------------------
    def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
        rng = self.rng
        # 1. Random text
        if text is None:
            length = rng.randint(*self.cfg["char_count_range"])
            text = "".join(rng.choices(self.chars, k=length))
        # 2. Light background
        bg_lo, bg_hi = self.cfg["bg_color_range"]
        bg = tuple(rng.randint(bg_lo, bg_hi) for _ in range(3))
        img = Image.new("RGB", (self.width, self.height), bg)
        # 3. Draw characters one by one (rotate, then paste)
        self._draw_text(img, text, rng)
        # 4. Noise lines
        self._draw_noise_lines(img, rng)
        # 5. Noise points
        self._draw_noise_points(img, rng)
        # 6. Slight Gaussian blur
        if self.cfg["blur_radius"] > 0:
            img = img.filter(ImageFilter.GaussianBlur(radius=self.cfg["blur_radius"]))
        return img, text

    # ----------------------------------------------------------
    # Private helpers
    # ----------------------------------------------------------
    def _draw_text(self, img: Image.Image, text: str, rng: random.Random) -> None:
        """Rotate each character and paste it onto the canvas."""
        n = len(text)
        # Horizontal width available to each character
        slot_w = self.width // n
        font_size = int(min(slot_w * 0.9, self.height * 0.7))
        font_size = max(font_size, 12)
        for i, ch in enumerate(text):
            font_path = rng.choice(self._fonts)
            font = ImageFont.truetype(font_path, font_size)
            color = rng.choice(_DARK_COLORS)
            # Render the character on a temporary transparent layer
            bbox = font.getbbox(ch)
            cw = bbox[2] - bbox[0] + 4
            ch_h = bbox[3] - bbox[1] + 4
            char_img = Image.new("RGBA", (cw, ch_h), (0, 0, 0, 0))
            ImageDraw.Draw(char_img).text((-bbox[0] + 2, -bbox[1] + 2), ch, fill=color, font=font)
            # Random rotation
            angle = rng.randint(*self.cfg["rotation_range"])
            char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)
            # Paste position
            x = slot_w * i + (slot_w - char_img.width) // 2
            y = (self.height - char_img.height) // 2 + rng.randint(-3, 3)
            x = max(0, min(x, self.width - char_img.width))
            y = max(0, min(y, self.height - char_img.height))
            img.paste(char_img, (x, y), char_img)

    def _draw_noise_lines(self, img: Image.Image, rng: random.Random) -> None:
        """Draw light-colored noise lines."""
        draw = ImageDraw.Draw(img)
        lo, hi = self.cfg["noise_line_range"]
        num = rng.randint(lo, hi)
        for _ in range(num):
            x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
            x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)
            color = tuple(rng.randint(150, 220) for _ in range(3))
            draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))

    def _draw_noise_points(self, img: Image.Image, rng: random.Random) -> None:
        """Draw noise points."""
        draw = ImageDraw.Draw(img)
        for _ in range(self.cfg["noise_point_num"]):
            x = rng.randint(0, self.width - 1)
            y = rng.randint(0, self.height - 1)
            color = tuple(rng.randint(0, 200) for _ in range(3))
            draw.point((x, y), fill=color)
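As a side note, the per-character layout arithmetic in `_draw_text` can be sketched standalone. This is a minimal, hedged sketch; the 120×40 canvas and 4-character label are assumed example values, not necessarily the project's configured defaults.

```python
def layout_slots(width: int, height: int, n_chars: int) -> tuple[int, int]:
    """Return (slot_width, font_size) the way _draw_text derives them."""
    slot_w = width // n_chars                         # horizontal budget per glyph
    font_size = int(min(slot_w * 0.9, height * 0.7))  # fit both the slot and the canvas height
    return slot_w, max(font_size, 12)                 # never drop below 12 px

print(layout_slots(120, 40, 4))  # (30, 27): min(30*0.9, 40*0.7) = 27
print(layout_slots(120, 40, 6))  # (20, 18): longer labels get smaller glyphs
```

The `max(..., 12)` floor keeps very long labels legible at the cost of some overlap between slots.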

generators/threed_gen.py Normal file

@@ -0,0 +1,211 @@
"""
3D 立体验证码生成器
生成具有 3D 透视/阴影效果的验证码:
- 使用仿射变换模拟 3D 透视
- 添加阴影效果 (偏移的深色副本)
- 字符有深度感和倾斜
- 渐变背景增强立体感
- 标签: 纯字符内容
"""
import math
import random
from PIL import Image, ImageDraw, ImageFilter, ImageFont
from config import GENERATE_CONFIG, THREED_CHARS
from generators.base import BaseCaptchaGenerator
# 字体 (粗体效果更好渲染 3D)
_FONT_PATHS = [
"/usr/share/fonts/TTF/DejaVuSans-Bold.ttf",
"/usr/share/fonts/TTF/DejaVuSerif-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSans-Bold.ttf",
"/usr/share/fonts/liberation/LiberationSerif-Bold.ttf",
"/usr/share/fonts/gnu-free/FreeSansBold.otf",
]
# 前景色 — 鲜艳、对比度高
_FRONT_COLORS = [
(220, 50, 50), # 红
(50, 100, 220), # 蓝
(30, 160, 30), # 绿
(200, 150, 0), # 金
(180, 50, 180), # 紫
(0, 160, 160), # 青
(220, 100, 0), # 橙
]
class ThreeDCaptchaGenerator(BaseCaptchaGenerator):
"""3D 立体验证码生成器。"""
def __init__(self, seed: int | None = None):
from config import RANDOM_SEED
super().__init__(seed=seed if seed is not None else RANDOM_SEED)
self.cfg = GENERATE_CONFIG["3d"]
self.chars = THREED_CHARS
self.width, self.height = self.cfg["image_size"]
# 预加载可用字体
self._fonts: list[str] = []
for p in _FONT_PATHS:
try:
ImageFont.truetype(p, 20)
self._fonts.append(p)
except OSError:
continue
if not self._fonts:
raise RuntimeError("未找到任何可用字体,无法生成验证码")
# ----------------------------------------------------------
# 公共接口
# ----------------------------------------------------------
def generate(self, text: str | None = None) -> tuple[Image.Image, str]:
rng = self.rng
# 1. 随机文本
if text is None:
length = rng.randint(*self.cfg["char_count_range"])
text = "".join(rng.choices(self.chars, k=length))
# 2. 渐变背景 (增强立体感)
img = self._gradient_background(rng)
# 3. 逐字符绘制 (阴影 + 透视 + 前景)
self._draw_3d_text(img, text, rng)
# 4. 干扰线 (较粗、有深度感)
self._draw_depth_lines(img, rng)
# 5. 轻微高斯模糊
img = img.filter(ImageFilter.GaussianBlur(radius=0.7))
return img, text
# ----------------------------------------------------------
# 私有方法
# ----------------------------------------------------------
def _gradient_background(self, rng: random.Random) -> Image.Image:
"""生成从上到下的浅色渐变背景。"""
img = Image.new("RGB", (self.width, self.height))
draw = ImageDraw.Draw(img)
# 随机两个浅色
c1 = tuple(rng.randint(200, 240) for _ in range(3))
c2 = tuple(rng.randint(180, 220) for _ in range(3))
for y in range(self.height):
ratio = y / max(self.height - 1, 1)
r = int(c1[0] + (c2[0] - c1[0]) * ratio)
g = int(c1[1] + (c2[1] - c1[1]) * ratio)
b = int(c1[2] + (c2[2] - c1[2]) * ratio)
draw.line([(0, y), (self.width, y)], fill=(r, g, b))
return img
def _draw_3d_text(self, img: Image.Image, text: str, rng: random.Random) -> None:
"""逐字符绘制 3D 效果: 阴影层 + 透视变换 + 前景层。"""
n = len(text)
slot_w = self.width // n
font_size = int(min(slot_w * 0.8, self.height * 0.65))
font_size = max(font_size, 16)
shadow_dx, shadow_dy = self.cfg["shadow_offset"]
for i, ch in enumerate(text):
font_path = rng.choice(self._fonts)
font = ImageFont.truetype(font_path, font_size)
front_color = rng.choice(_FRONT_COLORS)
# 阴影色: 对应前景色的暗化版本
shadow_color = tuple(max(0, c - 80) for c in front_color)
# 渲染单字符
bbox = font.getbbox(ch)
cw = bbox[2] - bbox[0] + 8
ch_h = bbox[3] - bbox[1] + 8
pad = max(shadow_dx, shadow_dy) + 4 # 额外空间给阴影
canvas_w = cw + pad * 2
canvas_h = ch_h + pad * 2
# --- 阴影层 ---
shadow_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
ImageDraw.Draw(shadow_img).text(
(-bbox[0] + pad + shadow_dx, -bbox[1] + pad + shadow_dy),
ch, fill=shadow_color + (180,), font=font
)
# --- 前景层 ---
front_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
ImageDraw.Draw(front_img).text(
(-bbox[0] + pad, -bbox[1] + pad),
ch, fill=front_color + (255,), font=font
)
# 合并: 先阴影后前景
char_img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
char_img = Image.alpha_composite(char_img, shadow_img)
char_img = Image.alpha_composite(char_img, front_img)
# 透视变换 (仿射)
char_img = self._perspective_transform(char_img, rng)
# 随机旋转
angle = rng.randint(-20, 20)
char_img = char_img.rotate(angle, resample=Image.BICUBIC, expand=True)
# 粘贴到画布
x = slot_w * i + (slot_w - char_img.width) // 2
y = (self.height - char_img.height) // 2 + rng.randint(-4, 4)
x = max(0, min(x, self.width - char_img.width))
y = max(0, min(y, self.height - char_img.height))
img.paste(char_img, (x, y), char_img)
def _perspective_transform(self, img: Image.Image, rng: random.Random) -> Image.Image:
"""对单个字符图片施加仿射变换模拟 3D 透视。"""
w, h = img.size
intensity = self.cfg["perspective_intensity"]
# 随机 shear / scale 参数
shear_x = rng.uniform(-intensity, intensity)
shear_y = rng.uniform(-intensity * 0.5, intensity * 0.5)
scale_x = rng.uniform(1.0 - intensity * 0.3, 1.0 + intensity * 0.3)
scale_y = rng.uniform(1.0 - intensity * 0.3, 1.0 + intensity * 0.3)
# 仿射变换矩阵 (a, b, c, d, e, f) -> (x', y') = (a*x+b*y+c, d*x+e*y+f)
# Pillow transform 需要逆变换系数
a = scale_x
b = shear_x
d = shear_y
e = scale_y
# 计算偏移让中心不变
c = (1 - a) * w / 2 - b * h / 2
f = -d * w / 2 + (1 - e) * h / 2
return img.transform(
(w, h), Image.AFFINE,
(a, b, c, d, e, f),
resample=Image.BICUBIC
)
def _draw_depth_lines(self, img: Image.Image, rng: random.Random) -> None:
"""绘制有深度感的干扰线 (较粗、带阴影)。"""
draw = ImageDraw.Draw(img)
num = rng.randint(2, 4)
for _ in range(num):
x1, y1 = rng.randint(0, self.width), rng.randint(0, self.height)
x2, y2 = rng.randint(0, self.width), rng.randint(0, self.height)
# 阴影线
shadow_color = tuple(rng.randint(80, 130) for _ in range(3))
dx, dy = self.cfg["shadow_offset"]
draw.line([(x1 + dx, y1 + dy), (x2 + dx, y2 + dy)],
fill=shadow_color, width=rng.randint(2, 3))
# 前景线
color = tuple(rng.randint(120, 200) for _ in range(3))
draw.line([(x1, y1), (x2, y2)], fill=color, width=rng.randint(1, 2))
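The offset formulas in `_perspective_transform` are chosen so the glyph's center maps back onto itself under the affine map. A quick standalone check of that invariant (the shear/scale values below are arbitrary example numbers, not ones the generator would necessarily draw):

```python
def center_offsets(a: float, b: float, d: float, e: float, w: int, h: int) -> tuple[float, float]:
    # Same offsets _perspective_transform computes: solve (x', y') = (cx, cy) at the center
    c = (1 - a) * w / 2 - b * h / 2
    f = -d * w / 2 + (1 - e) * h / 2
    return c, f

a, b, d, e, w, h = 1.1, 0.2, -0.1, 0.95, 40, 60
c, f = center_offsets(a, b, d, e, w, h)
cx, cy = w / 2, h / 2
# The center is a fixed point of (x, y) -> (a*x + b*y + c, d*x + e*y + f)
assert abs(a * cx + b * cy + c - cx) < 1e-9
assert abs(d * cx + e * cy + f - cy) < 1e-9
```

Pinning the center keeps sheared glyphs from drifting out of their slots before pasting.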

inference/__init__.py Normal file

@@ -0,0 +1,18 @@
"""
推理包
- pipeline.py: CaptchaPipeline 核心推理流水线
- export_onnx.py: PyTorch → ONNX 导出
- math_eval.py: 算式计算模块
"""
from inference.pipeline import CaptchaPipeline
from inference.math_eval import eval_captcha_math
from inference.export_onnx import export_model, export_all
__all__ = [
"CaptchaPipeline",
"eval_captcha_math",
"export_model",
"export_all",
]

inference/export_onnx.py Normal file

@@ -0,0 +1,121 @@
"""
ONNX 导出脚本
从 checkpoints/ 加载训练好的 PyTorch 模型,导出为 ONNX 格式到 onnx_models/。
支持逐个导出或一次导出全部。
"""
import torch
import torch.nn as nn
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
IMAGE_SIZE,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
NUM_CAPTCHA_TYPES,
)
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
def export_model(
model: nn.Module,
model_name: str,
input_shape: tuple,
onnx_dir: str | None = None,
):
"""
导出单个模型为 ONNX。
Args:
model: 已加载权重的 PyTorch 模型
model_name: 模型名 (classifier / normal / math / threed)
input_shape: 输入形状 (C, H, W)
onnx_dir: 输出目录 (默认使用 config.ONNX_DIR)
"""
from pathlib import Path
out_dir = Path(onnx_dir) if onnx_dir else ONNX_DIR
out_dir.mkdir(parents=True, exist_ok=True)
onnx_path = out_dir / f"{model_name}.onnx"
model.eval()
model.cpu()
dummy = torch.randn(1, *input_shape)
# 分类器和识别器的 dynamic_axes 不同
if model_name == "classifier":
dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
else:
# CTC 模型: output shape = (T, B, C)
dynamic_axes = {"input": {0: "batch"}, "output": {1: "batch"}}
torch.onnx.export(
model,
dummy,
str(onnx_path),
opset_version=ONNX_CONFIG["opset_version"],
input_names=["input"],
output_names=["output"],
dynamic_axes=dynamic_axes if ONNX_CONFIG["dynamic_batch"] else None,
)
size_kb = onnx_path.stat().st_size / 1024
print(f"[ONNX] 导出完成: {onnx_path} ({size_kb:.1f} KB)")
def _load_and_export(model_name: str):
"""从 checkpoint 加载模型并导出 ONNX。"""
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
if not ckpt_path.exists():
print(f"[跳过] {model_name}: checkpoint 不存在 ({ckpt_path})")
return
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(f"[加载] {model_name}: epoch={ckpt.get('epoch', '?')} acc={ckpt.get('best_acc', '?')}")
if model_name == "classifier":
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES)
h, w = IMAGE_SIZE["classifier"]
input_shape = (1, h, w)
elif model_name == "normal":
chars = ckpt.get("chars", NORMAL_CHARS)
h, w = IMAGE_SIZE["normal"]
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "math":
chars = ckpt.get("chars", MATH_CHARS)
h, w = IMAGE_SIZE["math"]
model = LiteCRNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w)
elif model_name == "threed":
chars = ckpt.get("chars", THREED_CHARS)
h, w = IMAGE_SIZE["3d"]
model = ThreeDCNN(chars=chars, img_h=h, img_w=w)
input_shape = (1, h, w)
else:
print(f"[错误] 未知模型: {model_name}")
return
model.load_state_dict(ckpt["model_state_dict"])
export_model(model, model_name, input_shape)
def export_all():
"""依次导出 classifier, normal, math, threed 四个模型。"""
print("=" * 50)
print("导出全部 ONNX 模型")
print("=" * 50)
for name in ["classifier", "normal", "math", "threed"]:
_load_and_export(name)
print("\n全部导出完成。")
if __name__ == "__main__":
export_all()

inference/math_eval.py Normal file

@@ -0,0 +1,66 @@
"""
算式计算模块
解析并计算验证码中的算式表达式。
用正则提取数字和运算符,不使用 eval()。
支持: 加减乘除,个位到两位数运算。
"""
import re
# 匹配: 数字 运算符 数字 (后面可能跟 =? 等)
_EXPR_PATTERN = re.compile(
r"(\d+)\s*([+\-×÷xX*])\s*(\d+)"
)
# 运算符归一化映射
_OP_MAP = {
"+": "+",
"-": "-",
"×": "×",
"÷": "÷",
"x": "×",
"X": "×",
"*": "×",
}
def eval_captcha_math(expr: str) -> str:
"""
解析并计算验证码算式。
支持: 加减乘除,个位到两位数运算。
输入: "3+8=?""12×3=?""15-7=?""3+8"
输出: "11""36""8"
用正则提取数字和运算符,不使用 eval()。
Raises:
ValueError: 无法解析表达式
"""
match = _EXPR_PATTERN.search(expr)
if not match:
raise ValueError(f"无法解析算式: {expr!r}")
a = int(match.group(1))
op_raw = match.group(2)
b = int(match.group(3))
op = _OP_MAP.get(op_raw, op_raw)
if op == "+":
result = a + b
elif op == "-":
result = a - b
elif op == "×":
result = a * b
elif op == "÷":
if b == 0:
raise ValueError(f"除数为零: {expr!r}")
result = a // b
else:
raise ValueError(f"不支持的运算符: {op!r} 原式: {expr!r}")
return str(result)
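The parse-then-dispatch flow of `math_eval.py` can be sketched in a standalone snippet; this mirrors the module's regex and operator normalization rather than importing it, and keeps the same integer-division choice for `÷`.

```python
import re

# Mirrors _EXPR_PATTERN: number, operator (with x/X/* aliases for ×), number
pattern = re.compile(r"(\d+)\s*([+\-×÷xX*])\s*(\d+)")

def solve(expr: str) -> str:
    m = pattern.search(expr)
    if m is None:
        raise ValueError(expr)
    a, op, b = int(m.group(1)), m.group(2), int(m.group(3))
    op = {"x": "×", "X": "×", "*": "×"}.get(op, op)  # normalize multiplication signs
    ops = {"+": a + b, "-": a - b, "×": a * b, "÷": a // b if b else None}
    return str(ops[op])

print(solve("12×3=?"))  # 36
print(solve("3+8"))     # 11
print(solve("15-7=?"))  # 8
```

Because the pattern uses `search`, trailing `=?` decorations and OCR noise around the expression are tolerated for free.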

inference/pipeline.py Normal file

@@ -0,0 +1,231 @@
"""
核心推理流水线
加载调度模型和所有专家模型 (ONNX 格式),提供统一的 solve(image) 接口。
推理流程:
输入图片 → 预处理 → 调度分类器 → 路由到专家模型 → CTC 解码 → 后处理 → 输出
对算式类型,解码后还会调用 math_eval 计算结果。
"""
import io
import time
from pathlib import Path
import numpy as np
from PIL import Image
from config import (
CAPTCHA_TYPES,
IMAGE_SIZE,
INFERENCE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
)
from inference.math_eval import eval_captcha_math
def _try_import_ort():
"""延迟导入 onnxruntime给出友好错误提示。"""
try:
import onnxruntime as ort
return ort
except ImportError:
raise ImportError(
"推理需要 onnxruntime请安装: uv pip install onnxruntime"
)
class CaptchaPipeline:
"""
核心推理流水线。
加载调度模型和所有专家模型 (ONNX 格式)。
提供统一的 solve(image) 接口。
"""
def __init__(self, models_dir: str | None = None):
"""
初始化加载所有 ONNX 模型。
Args:
models_dir: ONNX 模型目录,默认使用 config 中的路径
"""
ort = _try_import_ort()
self.models_dir = Path(models_dir or INFERENCE_CONFIG["default_models_dir"])
self.mean = INFERENCE_CONFIG["normalize_mean"]
self.std = INFERENCE_CONFIG["normalize_std"]
# 字符集映射
self._chars = {
"normal": NORMAL_CHARS,
"math": MATH_CHARS,
"3d": THREED_CHARS,
}
# 专家模型名 → ONNX 文件名
self._model_files = {
"classifier": "classifier.onnx",
"normal": "normal.onnx",
"math": "math.onnx",
"3d": "threed.onnx",
}
# 加载所有可用模型
opts = ort.SessionOptions()
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 2
self._sessions: dict[str, "ort.InferenceSession"] = {}
for name, fname in self._model_files.items():
path = self.models_dir / fname
if path.exists():
self._sessions[name] = ort.InferenceSession(
str(path), sess_options=opts,
providers=["CPUExecutionProvider"],
)
loaded = list(self._sessions.keys())
if not loaded:
raise FileNotFoundError(
f"未找到任何 ONNX 模型,请先训练并导出模型到 {self.models_dir}"
)
# ----------------------------------------------------------
# 公共接口
# ----------------------------------------------------------
def preprocess(self, image: Image.Image, target_size: tuple[int, int]) -> np.ndarray:
"""
图片预处理: resize, grayscale, normalize, 转 numpy。
Args:
image: PIL Image
target_size: (H, W)
Returns:
(1, 1, H, W) float32 ndarray
"""
h, w = target_size
img = image.convert("L").resize((w, h), Image.BILINEAR)
arr = np.array(img, dtype=np.float32) / 255.0
arr = (arr - self.mean) / self.std
return arr.reshape(1, 1, h, w)
def classify(self, image: Image.Image) -> str:
"""
调度分类,返回类型名: 'normal' / 'math' / '3d'
Raises:
RuntimeError: 分类器模型未加载
"""
if "classifier" not in self._sessions:
raise RuntimeError("分类器模型未加载,请先训练并导出 classifier.onnx")
inp = self.preprocess(image, IMAGE_SIZE["classifier"])
session = self._sessions["classifier"]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (1, num_types)
idx = int(np.argmax(logits, axis=1)[0])
return CAPTCHA_TYPES[idx]
def solve(
self,
image,
captcha_type: str | None = None,
) -> dict:
"""
完整识别流程。
Args:
image: PIL.Image 或文件路径 (str/Path) 或 bytes
captcha_type: 指定类型可跳过分类 ('normal'/'math'/'3d')
Returns:
dict: {
"type": str, # 验证码类型
"raw": str, # OCR 原始识别结果
"result": str, # 最终答案 (算式型为计算结果)
"time_ms": float, # 推理耗时 (毫秒)
}
"""
t0 = time.perf_counter()
# 1. 解析输入
img = self._load_image(image)
# 2. 分类
if captcha_type is None:
captcha_type = self.classify(img)
# 3. 路由到专家模型
if captcha_type not in self._sessions:
raise RuntimeError(
f"专家模型 '{captcha_type}' 未加载,请先训练并导出对应 ONNX 模型"
)
size_key = captcha_type # "normal"/"math"/"3d"
inp = self.preprocess(img, IMAGE_SIZE[size_key])
session = self._sessions[captcha_type]
input_name = session.get_inputs()[0].name
logits = session.run(None, {input_name: inp})[0] # (T, 1, C)
# 4. CTC 贪心解码
chars = self._chars[captcha_type]
raw_text = self._ctc_greedy_decode(logits, chars)
# 5. 后处理
if captcha_type == "math":
try:
result = eval_captcha_math(raw_text)
except ValueError:
result = raw_text # 解析失败则返回原始文本
else:
result = raw_text
elapsed = (time.perf_counter() - t0) * 1000
return {
"type": captcha_type,
"raw": raw_text,
"result": result,
"time_ms": round(elapsed, 2),
}
# ----------------------------------------------------------
# 私有方法
# ----------------------------------------------------------
@staticmethod
def _load_image(image) -> Image.Image:
"""将多种输入类型统一转为 PIL Image。"""
if isinstance(image, Image.Image):
return image
if isinstance(image, (str, Path)):
return Image.open(image).convert("RGB")
if isinstance(image, bytes):
return Image.open(io.BytesIO(image)).convert("RGB")
raise TypeError(f"不支持的图片输入类型: {type(image)}")
@staticmethod
def _ctc_greedy_decode(logits: np.ndarray, chars: str) -> str:
"""
CTC 贪心解码 (numpy 版本)。
Args:
logits: (T, B, C) ONNX 输出
chars: 字符集 (不含 blank, blank=index 0)
Returns:
解码后的字符串
"""
# 取 batch=0
preds = np.argmax(logits[:, 0, :], axis=1) # (T,)
decoded = []
prev = -1
for idx in preds:
if idx != 0 and idx != prev:
decoded.append(chars[idx - 1])
prev = idx
return "".join(decoded)

main.py Normal file

@@ -0,0 +1,16 @@
# This is a sample Python script.
# Press Shift+F10 to run it, or replace it with your own code.
# Double-press Shift to search everywhere for classes, files, tool windows, actions, and settings.

def print_hi(name):
    # Use a breakpoint in the line below to debug the script.
    print(f'Hi, {name}')  # Press Ctrl+8 to toggle the breakpoint.

# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print_hi('PyCharm')

# See https://www.jetbrains.com/help/pycharm/ for PyCharm help

models/__init__.py Normal file

@@ -0,0 +1,18 @@
"""
模型定义包
提供三种模型:
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
- ThreeDCNN: 3D 验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
"""
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
__all__ = [
"CaptchaClassifier",
"LiteCRNN",
"ThreeDCNN",
]

models/classifier.py Normal file

@@ -0,0 +1,72 @@
"""
调度分类器模型
轻量 CNN 分类器,用于判断验证码类型 (normal / math / 3d)。
不同类型验证码视觉差异大,分类任务简单。
架构: 4 层卷积 + GAP + FC
输入: 灰度图 1×64×128
输出: softmax 概率分布 (num_types 个类别)
体积目标: < 500KB
"""
import torch
import torch.nn as nn
class CaptchaClassifier(nn.Module):
"""
轻量分类器。
4 层卷积 (每层 Conv + BN + ReLU + MaxPool)
→ 全局平均池化 → 全连接 → 输出类别数。
"""
def __init__(self, num_types: int = 3):
super().__init__()
self.num_types = num_types
self.features = nn.Sequential(
# block 1: 1 -> 16, 64x128 -> 32x64
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 2: 16 -> 32, 32x64 -> 16x32
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 3: 32 -> 64, 16x32 -> 8x16
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# block 4: 64 -> 64, 8x16 -> 4x8
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
)
# 全局平均池化 → 输出 (batch, 64, 1, 1)
self.gap = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(64, num_types)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, 1, 64, 128) 灰度图
Returns:
logits: (batch, num_types) 未经 softmax 的原始输出
"""
x = self.features(x)
x = self.gap(x) # (B, 64, 1, 1)
x = x.view(x.size(0), -1) # (B, 64)
x = self.classifier(x) # (B, num_types)
return x
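Each of the four blocks ends in `MaxPool2d(2, 2)`, so the spatial dimensions halve four times before GAP. A torch-free sketch of that shape trace (the 64×128 input size matches the classifier's documented input):

```python
def trace_pools(h: int, w: int, n_pools: int = 4) -> tuple[int, int]:
    # Each conv block ends in MaxPool2d(2, 2), halving both height and width
    for _ in range(n_pools):
        h, w = h // 2, w // 2
    return h, w

print(trace_pools(64, 128))  # (4, 8): GAP then averages this down to one 64-channel vector
```

This is why the final `Linear` layer needs only 64 inputs regardless of the remaining 4×8 spatial extent.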

models/lite_crnn.py Normal file

@@ -0,0 +1,141 @@
"""
轻量 CRNN 模型 (Convolutional Recurrent Neural Network)
用于普通字符验证码和算式验证码的 OCR 识别。
两种模式通过不同的字符集和输入尺寸区分,共享同一网络架构。
架构: CNN 特征提取 → 序列映射 → BiLSTM → 全连接 → CTC 解码
CTC 输出长度 = 特征图宽度 (经过若干次宽度方向 pool 后)
CTC blank 位于 index 0字符从 index 1 开始映射。
- normal 模式: 输入 1×40×120, 字符集 30 字符, 体积 < 2MB
- math 模式: 输入 1×40×160, 字符集 16 字符, 体积 < 2MB
"""
import torch
import torch.nn as nn
class LiteCRNN(nn.Module):
"""
轻量 CRNN + CTC。
CNN 部分对高度做 4 次 pool (40→20→10→5→1 via AdaptivePool)
宽度做 2 次 pool (保留足够序列长度给 CTC)。
RNN 部分使用单层 BiLSTM。
"""
def __init__(self, chars: str, img_h: int = 40, img_w: int = 120):
"""
Args:
chars: 字符集字符串 (不含 CTC blank)
img_h: 输入图片高度
img_w: 输入图片宽度
"""
super().__init__()
self.chars = chars
self.img_h = img_h
self.img_w = img_w
# CTC 类别数 = 字符数 + 1 (blank at index 0)
self.num_classes = len(chars) + 1
# ---- CNN 特征提取 ----
self.cnn = nn.Sequential(
# block 1: 1 -> 32, H/2, W不变
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变
# block 2: 32 -> 64, H/2, W/2
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # H/2, W/2
# block 3: 64 -> 128, H/2, W不变
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变
# block 4: 128 -> 128, H/2, W/2
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # H/2, W/2
)
# 经过 4 次高度 pool: img_h / 16 (如 40 → 2, 不够整除时用自适应)
# 用 AdaptiveAvgPool 把高度压到 1
self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None)) # (B, 128, 1, W')
# ---- RNN 序列建模 ----
self.rnn_input_size = 128
self.rnn_hidden = 96
self.rnn = nn.LSTM(
input_size=self.rnn_input_size,
hidden_size=self.rnn_hidden,
num_layers=1,
batch_first=True,
bidirectional=True,
)
# ---- 输出层 ----
self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, 1, H, W) 灰度图
Returns:
logits: (seq_len, batch, num_classes)
即 CTC 所需的 (T, B, C) 格式
"""
# CNN
conv = self.cnn(x) # (B, 128, H', W')
conv = self.adaptive_pool(conv) # (B, 128, 1, W')
conv = conv.squeeze(2) # (B, 128, W')
conv = conv.permute(0, 2, 1) # (B, W', 128) — batch_first 序列
# RNN
rnn_out, _ = self.rnn(conv) # (B, W', 256)
# FC
logits = self.fc(rnn_out) # (B, W', num_classes)
logits = logits.permute(1, 0, 2) # (T, B, C) — CTC 格式
return logits
@property
def seq_len(self) -> int:
"""根据输入宽度计算 CTC 序列长度 (特征图宽度)。"""
# 宽度经过 2 次 /2 的 pool
return self.img_w // 4
# ----------------------------------------------------------
# CTC 贪心解码
# ----------------------------------------------------------
def greedy_decode(self, logits: torch.Tensor) -> list[str]:
"""
CTC 贪心解码。
Args:
logits: (T, B, C) 模型原始输出
Returns:
解码后的字符串列表,长度 = batch size
"""
# (T, B, C) -> (B, T)
preds = logits.argmax(dim=2).permute(1, 0) # (B, T)
results = []
for pred in preds:
chars = []
prev = -1
for idx in pred.tolist():
if idx != 0 and idx != prev: # 0 = blank
chars.append(self.chars[idx - 1]) # 字符从 index 1 开始
prev = idx
results.append("".join(chars))
return results
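Since only two of the four pooling layers halve the width, the CTC time dimension is simply `img_w // 4`. A quick torch-free check for the two documented input widths:

```python
def ctc_seq_len(img_w: int) -> int:
    # Two of the four MaxPool layers halve the width; the (2, 1) pools leave it alone
    return img_w // 4

print(ctc_seq_len(120), ctc_seq_len(160))  # 30 40
```

30 time steps for a 4- to 5-character normal label (and 40 for a math expression) leaves ample room for the blanks and repeats CTC alignment needs.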

models/threed_cnn.py Normal file

@@ -0,0 +1,155 @@
"""
3D 立体验证码专用模型
采用更深的 CNN backbone类 ResNet 残差块)+ CRNN 序列建模,
以更强的特征提取能力处理透视变形和阴影效果。
架构: ResNet-lite backbone → 自适应池化 → BiLSTM → FC → CTC
输入: 灰度图 1×60×160
体积目标: < 5MB
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
"""简化残差块: Conv-BN-ReLU-Conv-BN + shortcut。"""
def __init__(self, channels: int):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
out = F.relu(self.bn1(self.conv1(x)), inplace=True)
out = self.bn2(self.conv2(out))
out = F.relu(out + residual, inplace=True)
return out
class ThreeDCNN(nn.Module):
"""
3D 验证码识别专用模型。
backbone 使用 5 层卷积(含 2 个残差块),通道数逐步增长:
1 → 32 → 64 → 64(res) → 128 → 128(res)
高度通过 pool 压缩后再用自适应池化归一,宽度保留序列长度。
之后接 BiLSTM + FC 做 CTC 序列输出。
"""
def __init__(self, chars: str, img_h: int = 60, img_w: int = 160):
"""
Args:
chars: 字符集字符串 (不含 CTC blank)
img_h: 输入图片高度
img_w: 输入图片宽度
"""
super().__init__()
self.chars = chars
self.img_h = img_h
self.img_w = img_w
self.num_classes = len(chars) + 1 # +1 for CTC blank
# ---- ResNet-lite backbone ----
self.backbone = nn.Sequential(
# stage 1: 1 -> 32, H/2, W不变
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1)),
# stage 2: 32 -> 64, H/2, W/2
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# stage 3: 残差块 64 -> 64
ResidualBlock(64),
# stage 4: 64 -> 128, H/2, W/2
nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# stage 5: 残差块 128 -> 128
ResidualBlock(128),
# stage 6: 128 -> 128, H/2, W不变
nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1)),
)
# 高度方向自适应压到 1宽度保持
self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))
# ---- RNN 序列建模 ----
self.rnn_input_size = 128
self.rnn_hidden = 128
self.rnn = nn.LSTM(
input_size=self.rnn_input_size,
hidden_size=self.rnn_hidden,
num_layers=2,
batch_first=True,
bidirectional=True,
dropout=0.2,
)
# ---- 输出层 ----
self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (batch, 1, H, W) 灰度图
Returns:
logits: (seq_len, batch, num_classes) CTC 格式 (T, B, C)
"""
conv = self.backbone(x) # (B, 128, H', W')
conv = self.adaptive_pool(conv) # (B, 128, 1, W')
conv = conv.squeeze(2) # (B, 128, W')
conv = conv.permute(0, 2, 1) # (B, W', 128)
rnn_out, _ = self.rnn(conv) # (B, W', 256)
logits = self.fc(rnn_out) # (B, W', num_classes)
logits = logits.permute(1, 0, 2) # (T, B, C)
return logits
@property
def seq_len(self) -> int:
"""CTC 序列长度 = 输入宽度经过 2 次 W/2 pool 后的宽度。"""
return self.img_w // 4
# ----------------------------------------------------------
# CTC 贪心解码
# ----------------------------------------------------------
def greedy_decode(self, logits: torch.Tensor) -> list[str]:
"""
CTC 贪心解码。
Args:
logits: (T, B, C) 模型原始输出
Returns:
解码后的字符串列表
"""
preds = logits.argmax(dim=2).permute(1, 0) # (B, T)
results = []
for pred in preds:
chars = []
prev = -1
for idx in pred.tolist():
if idx != 0 and idx != prev:
chars.append(self.chars[idx - 1])
prev = idx
results.append("".join(chars))
return results

pyproject.toml Normal file

@@ -0,0 +1,25 @@
[project]
name = "captchbreaker"
version = "0.1.0"
description = "Multi-model captcha recognition system: a two-tier routing-model + expert-model architecture"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.0.0",
    "torchvision>=0.15.0",
    "onnx>=1.14.0",
    "onnxscript>=0.6.0",
    "onnxruntime>=1.15.0",
    "pillow>=10.0.0",
    "numpy>=1.24.0",
    "tqdm>=4.65.0",
]

[project.optional-dependencies]
server = [
    "fastapi>=0.100.0",
    "uvicorn>=0.23.0",
    "python-multipart>=0.0.6",
]

[project.scripts]
captcha = "cli:main"

tests/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
测试包
"""

training/__init__.py Normal file

@@ -0,0 +1,10 @@
"""
训练脚本包
- dataset.py: CRNNDataset / CaptchaDataset 通用数据集类
- train_utils.py: CTC 训练通用逻辑 (train_ctc_model)
- train_normal.py: 训练普通字符识别 (LiteCRNN - normal)
- train_math.py: 训练算式识别 (LiteCRNN - math)
- train_3d.py: 训练 3D 立体识别 (ThreeDCNN)
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
"""

training/dataset.py Normal file

@@ -0,0 +1,159 @@
"""
通用 Dataset 类
提供两种数据集:
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
文件名格式约定: {label}_{任意}.png
- 分类器: label 可为任意字符,所在子目录名即为类别
- 识别器: label 即标注内容 (如 "A3B8""3+8")
"""
import os
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import AUGMENT_CONFIG
# ============================================================
# 增强 / 推理 transform 工厂函数
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
"""训练时数据增强 transform。"""
aug = AUGMENT_CONFIG
return transforms.Compose([
transforms.Grayscale(),
transforms.Resize((img_h, img_w)),
transforms.RandomAffine(
degrees=aug["degrees"],
translate=aug["translate"],
scale=aug["scale"],
),
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
])
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
"""验证 / 推理时 transform (无增强)。"""
return transforms.Compose([
transforms.Grayscale(),
transforms.Resize((img_h, img_w)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
class CRNNDataset(Dataset):
"""
CTC 识别数据集。
从目录中读取 {label}_{xxx}.png 文件,
将 label 编码为整数序列 (CTC target)。
"""
def __init__(
self,
dirs: list[str | Path],
chars: str,
transform: transforms.Compose | None = None,
):
"""
Args:
dirs: 数据目录列表 (会合并所有目录下的 .png 文件)
chars: 字符集字符串 (不含 CTC blank)
transform: 图片预处理/增强
"""
self.chars = chars
self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)} # blank=0
self.transform = transform
self.samples: list[tuple[str, str]] = [] # (文件路径, 标签文本)
for d in dirs:
d = Path(d)
if not d.exists():
continue
for f in sorted(d.glob("*.png")):
label = f.stem.rsplit("_", 1)[0] # "A3B8_000001" -> "A3B8"
self.samples.append((str(f), label))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
path, label = self.samples[idx]
img = Image.open(path).convert("RGB")
if self.transform:
img = self.transform(img)
# 编码标签为整数序列
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
return img, target, label
@staticmethod
def collate_fn(batch):
"""自定义 collate: 图片堆叠为 tensor标签拼接为 1D tensor。"""
import torch
images, targets, labels = zip(*batch)
images = torch.stack(images, 0)
target_lengths = torch.IntTensor([len(t) for t in targets])
targets_flat = torch.IntTensor([idx for t in targets for idx in t])
return images, targets_flat, target_lengths, list(labels)
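The collate step flattens variable-length targets into one 1-D array plus a lengths vector, which is the layout PyTorch's `CTCLoss` consumes. A torch-free sketch with plain lists (the example target sequences are made up):

```python
def collate_targets(targets: list[list[int]]) -> tuple[list[int], list[int]]:
    # Mirror CRNNDataset.collate_fn: record each length, then concatenate everything
    lengths = [len(t) for t in targets]
    flat = [idx for t in targets for idx in t]
    return flat, lengths

flat, lengths = collate_targets([[1, 5, 2], [7, 7], [3]])
print(flat)     # [1, 5, 2, 7, 7, 3]
print(lengths)  # [3, 2, 1]
```

The lengths vector is what lets the loss recover the per-sample boundaries from the flat concatenation, so labels of different lengths batch without padding.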
# ============================================================
# Dataset for the classifier
# ============================================================
class CaptchaDataset(Dataset):
    """
    Classifier training dataset.
    Each subdirectory is named after a class (e.g. "normal", "math", "3d");
    every .png file inside it belongs to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: root directory containing one subfolder per class name
            class_names: class name list (the order defines the label index)
            transform: image preprocessing / augmentation
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform
        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), self.class_to_idx[cls_name]))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
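The `{label}_{anything}.png` filename convention relies on `rsplit("_", 1)` so that underscores inside the label itself survive; only the final suffix is stripped. A standalone sketch with hypothetical filenames:

```python
from pathlib import Path

def label_of(filename: str) -> str:
    # "A3B8_000001.png" -> stem "A3B8_000001" -> keep everything before the last "_"
    return Path(filename).stem.rsplit("_", 1)[0]

print(label_of("A3B8_000001.png"))  # A3B8
print(label_of("3+8=?_42.png"))     # 3+8=?
```

Splitting from the right is what makes the convention safe for math labels or any annotation that happens to contain an underscore.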

training/train_3d.py Normal file

@@ -0,0 +1,40 @@
"""
训练 3D 立体验证码识别模型 (ThreeDCNN)
用法: python -m training.train_3d
"""
from config import (
THREED_CHARS,
IMAGE_SIZE,
SYNTHETIC_3D_DIR,
REAL_3D_DIR,
)
from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN
from training.train_utils import train_ctc_model
def main():
img_h, img_w = IMAGE_SIZE["3d"]
model = ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w)
print("=" * 60)
print("训练 3D 立体验证码识别模型 (ThreeDCNN)")
print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)")
print(f" 输入尺寸: {img_h}×{img_w}")
print("=" * 60)
train_ctc_model(
model_name="threed",
model=model,
chars=THREED_CHARS,
synthetic_dir=SYNTHETIC_3D_DIR,
real_dir=REAL_3D_DIR,
generator_cls=ThreeDCaptchaGenerator,
config_key="threed",
)
if __name__ == "__main__":
main()

training/train_classifier.py Normal file

@@ -0,0 +1,232 @@
"""
训练调度分类器 (CaptchaClassifier)
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d。
数据来源: data/classifier/ 目录 (按类型子目录组织)
用法: python -m training.train_classifier
"""
import os
import shutil
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CAPTCHA_TYPES,
NUM_CAPTCHA_TYPES,
IMAGE_SIZE,
TRAIN_CONFIG,
CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
def _prepare_classifier_data():
"""
准备分类器训练数据。
策略:从各类型的合成数据目录中软链接 / 复制到 data/classifier/{type}/ 下,
每类取相同数量,保证类别平衡。
如果各类型合成数据不存在,先自动生成。
"""
cfg = TRAIN_CONFIG["classifier"]
per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES
# 各类型: (类名, 合成目录, 生成器类)
type_info = [
("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
]
for cls_name, syn_dir, gen_cls in type_info:
syn_dir = Path(syn_dir)
existing = sorted(syn_dir.glob("*.png"))
        # Generate more samples when the synthetic data falls short
        if len(existing) < per_class:
            print(f"[data] {cls_name} synthetic data insufficient ({len(existing)}/{per_class}); generating...")
gen = gen_cls()
gen.generate_dataset(per_class, str(syn_dir))
existing = sorted(syn_dir.glob("*.png"))
        # Link/copy into the classifier directory
cls_dir = CLASSIFIER_DIR / cls_name
cls_dir.mkdir(parents=True, exist_ok=True)
already = len(list(cls_dir.glob("*.png")))
        if already >= per_class:
            print(f"[data] {cls_name} classifier data already prepared: {already}")
            continue
        # Clear the directory, then relink
for f in cls_dir.glob("*.png"):
f.unlink()
selected = existing[:per_class]
        for f in tqdm(selected, desc=f"preparing {cls_name}", leave=False):
dst = cls_dir / f.name
            # Symlink to save space; fall back to copying on failure
try:
dst.symlink_to(f.resolve())
except OSError:
shutil.copy2(f, dst)
        print(f"[data] {cls_name} classifier data ready: {len(selected)}")
def main():
cfg = TRAIN_CONFIG["classifier"]
img_h, img_w = IMAGE_SIZE["classifier"]
device = get_device()
print("=" * 60)
    print("Training routing classifier (CaptchaClassifier)")
    print(f"  Classes: {CAPTCHA_TYPES}")
    print(f"  Input size: {img_h}×{img_w}")
    print("=" * 60)
    # ---- 1. Prepare data ----
_prepare_classifier_data()
    # ---- 2. Build datasets ----
train_transform = build_train_transform(img_h, img_w)
val_transform = build_val_transform(img_h, img_w)
full_dataset = CaptchaDataset(
root_dir=CLASSIFIER_DIR,
class_names=CAPTCHA_TYPES,
transform=train_transform,
)
total = len(full_dataset)
val_size = int(total * cfg["val_split"])
train_size = total - val_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation set without augmentation
val_ds_clean = CaptchaDataset(
root_dir=CLASSIFIER_DIR,
class_names=CAPTCHA_TYPES,
transform=val_transform,
)
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, pin_memory=True,
)
    print(f"[data] train: {train_size}  val: {val_size}")
    # ---- 3. Model / optimizer / scheduler ----
model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
criterion = nn.CrossEntropyLoss()
best_acc = 0.0
ckpt_path = CHECKPOINTS_DIR / "classifier.pth"
    # ---- 4. Training loop ----
for epoch in range(1, cfg["epochs"] + 1):
model.train()
total_loss = 0.0
num_batches = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, labels in pbar:
images = images.to(device)
labels = labels.to(device)
logits = model(images)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
num_batches += 1
pbar.set_postfix(loss=f"{loss.item():.4f}")
scheduler.step()
avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
model.eval()
correct = 0
total_val = 0
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
labels = labels.to(device)
logits = model(images)
preds = logits.argmax(dim=1)
correct += (preds == labels).sum().item()
total_val += labels.size(0)
val_acc = correct / max(total_val, 1)
lr = scheduler.get_last_lr()[0]
print(
f"Epoch {epoch:3d}/{cfg['epochs']} "
f"loss={avg_loss:.4f} "
f"acc={val_acc:.4f} "
f"lr={lr:.6f}"
)
        # ---- 6. Save the best model ----
if val_acc > best_acc:
best_acc = val_acc
torch.save({
"model_state_dict": model.state_dict(),
"class_names": CAPTCHA_TYPES,
"best_acc": best_acc,
"epoch": epoch,
}, ckpt_path)
            print(f"  → saved best model acc={best_acc:.4f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[training done] best accuracy: {best_acc:.4f}")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
onnx_path = ONNX_DIR / "classifier.onnx"
dummy = torch.randn(1, 1, img_h, img_w)
torch.onnx.export(
model.cpu(),
dummy,
str(onnx_path),
opset_version=ONNX_CONFIG["opset_version"],
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
if ONNX_CONFIG["dynamic_batch"]
else None,
)
    print(f"[ONNX] export complete: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
return best_acc
if __name__ == "__main__":
main()
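The classifier trained above is what lets `captcha predict` auto-route an image to the right expert model. The dispatch step itself lives in `inference/` (outside this diff); as a minimal hypothetical sketch — the names below are illustrative, not the real API — it amounts to an argmax over the class probabilities followed by a lookup:

```python
CAPTCHA_TYPES = ["normal", "math", "3d"]  # must match the order in config.CAPTCHA_TYPES

def route(class_probs: list[float], recognizers: dict) -> str:
    """Pick the most probable type and call its recognizer.

    Hypothetical dispatch sketch: `recognizers` maps a type id to a callable
    that runs the matching expert model on the already-loaded image.
    """
    best = max(range(len(class_probs)), key=lambda i: class_probs[i])
    return recognizers[CAPTCHA_TYPES[best]]()
```

A `--type` flag on the CLI would simply bypass the classifier and index `recognizers` directly.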

training/train_math.py Normal file

@@ -0,0 +1,40 @@
"""
Train the arithmetic-expression recognition model (LiteCRNN, math mode).
Usage: python -m training.train_math
"""
from config import (
MATH_CHARS,
IMAGE_SIZE,
SYNTHETIC_MATH_DIR,
REAL_MATH_DIR,
)
from generators.math_gen import MathCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model
def main():
img_h, img_w = IMAGE_SIZE["math"]
model = LiteCRNN(chars=MATH_CHARS, img_h=img_h, img_w=img_w)
print("=" * 60)
    print("Training arithmetic-expression recognition model (LiteCRNN - math)")
    print(f"  Charset: {MATH_CHARS} ({len(MATH_CHARS)} chars)")
    print(f"  Input size: {img_h}×{img_w}")
print("=" * 60)
train_ctc_model(
model_name="math",
model=model,
chars=MATH_CHARS,
synthetic_dir=SYNTHETIC_MATH_DIR,
real_dir=REAL_MATH_DIR,
generator_cls=MathCaptchaGenerator,
config_key="math",
)
if __name__ == "__main__":
main()

training/train_normal.py Normal file

@@ -0,0 +1,40 @@
"""
Train the plain character recognition model (LiteCRNN, normal mode).
Usage: python -m training.train_normal
"""
from config import (
NORMAL_CHARS,
IMAGE_SIZE,
SYNTHETIC_NORMAL_DIR,
REAL_NORMAL_DIR,
)
from generators.normal_gen import NormalCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model
def main():
img_h, img_w = IMAGE_SIZE["normal"]
model = LiteCRNN(chars=NORMAL_CHARS, img_h=img_h, img_w=img_w)
print("=" * 60)
    print("Training plain character recognition model (LiteCRNN - normal)")
    print(f"  Charset: {NORMAL_CHARS} ({len(NORMAL_CHARS)} chars)")
    print(f"  Input size: {img_h}×{img_w}")
print("=" * 60)
train_ctc_model(
model_name="normal",
model=model,
chars=NORMAL_CHARS,
synthetic_dir=SYNTHETIC_NORMAL_DIR,
real_dir=REAL_NORMAL_DIR,
generator_cls=NormalCaptchaGenerator,
config_key="normal",
)
if __name__ == "__main__":
main()

training/train_utils.py Normal file

@@ -0,0 +1,232 @@
"""
Shared CTC training logic.

Provides train_ctc_model(), shared by train_normal / train_math / train_3d.
Responsibilities:
1. Check the synthetic data; call the generator automatically when it is missing
2. Build the Dataset / DataLoader (with real data mixed in)
3. CTC training loop + cosine scheduler
4. Log epoch, loss, sample-level accuracy, and character-level accuracy
5. Save the best model to checkpoints/
6. Export ONNX when training finishes
"""
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
# ============================================================
# Accuracy computation
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]):
    """Return (sample-level accuracy, character-level accuracy)."""
total_samples = len(preds)
correct_samples = 0
total_chars = 0
correct_chars = 0
for pred, label in zip(preds, labels):
if pred == label:
correct_samples += 1
        # Character level: compare position by position up to the longer length
max_len = max(len(pred), len(label))
if max_len == 0:
continue
for i in range(max_len):
total_chars += 1
if i < len(pred) and i < len(label) and pred[i] == label[i]:
correct_chars += 1
sample_acc = correct_samples / max(total_samples, 1)
char_acc = correct_chars / max(total_chars, 1)
return sample_acc, char_acc
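For reference, here is how the two metrics differ on a toy batch; the helper below restates `_calc_accuracy` inline so it can run standalone:

```python
def calc_accuracy(preds: list[str], labels: list[str]) -> tuple[float, float]:
    """Standalone restatement of _calc_accuracy above, for illustration."""
    correct_samples = sum(p == t for p, t in zip(preds, labels))
    total_chars = correct_chars = 0
    for pred, label in zip(preds, labels):
        max_len = max(len(pred), len(label))
        total_chars += max_len
        correct_chars += sum(
            1 for i in range(max_len)
            if i < len(pred) and i < len(label) and pred[i] == label[i]
        )
    return (correct_samples / max(len(preds), 1),
            correct_chars / max(total_chars, 1))

# "AB12" matches exactly; "XY" vs "XZ" shares 1 of 2 positions:
# sample accuracy 1/2, character accuracy (4 + 1) / (4 + 2) = 5/6.
sample_acc, char_acc = calc_accuracy(["AB12", "XY"], ["AB12", "XZ"])
```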
# ============================================================
# ONNX export
# ============================================================
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
    """Export the model to ONNX format."""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
dummy = torch.randn(1, 1, img_h, img_w)
torch.onnx.export(
model.cpu(),
dummy,
str(onnx_path),
opset_version=ONNX_CONFIG["opset_version"],
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {1: "batch"}}
if ONNX_CONFIG["dynamic_batch"]
else None,
)
    print(f"[ONNX] export complete: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
# ============================================================
# Core training function
# ============================================================
def train_ctc_model(
model_name: str,
model: nn.Module,
chars: str,
synthetic_dir: str | Path,
real_dir: str | Path,
generator_cls,
config_key: str,
):
    """
    Generic CTC training pipeline.

    Args:
        model_name: model name (used for saved files: normal / math / threed)
        model: PyTorch model instance (LiteCRNN or ThreeDCNN)
        chars: charset string
        synthetic_dir: synthetic data directory
        real_dir: real data directory
        generator_cls: generator class (used to auto-generate data)
        config_key: key into TRAIN_CONFIG
    """
cfg = TRAIN_CONFIG[config_key]
img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
device = get_device()
    # ---- 1. Check / generate synthetic data ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[data] synthetic data insufficient ({len(existing)}/{cfg['synthetic_samples']}); generating...")
        gen = generator_cls()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[data] synthetic data ready: {len(existing)}")
    # ---- 2. Build datasets ----
data_dirs = [str(syn_path)]
real_path = Path(real_dir)
if real_path.exists() and list(real_path.glob("*.png")):
data_dirs.append(str(real_path))
        print(f"[data] mixing in real data: {len(list(real_path.glob('*.png')))}")
train_transform = build_train_transform(img_h, img_w)
val_transform = build_val_transform(img_h, img_w)
full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
total = len(full_dataset)
val_size = int(total * cfg["val_split"])
train_size = total - val_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation set uses the augmentation-free transform
val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
train_loader = DataLoader(
train_ds, batch_size=cfg["batch_size"], shuffle=True,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
val_loader = DataLoader(
val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
)
    print(f"[data] train: {train_size}  val: {val_size}")
    # ---- 3. Optimizer / scheduler / loss ----
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
best_acc = 0.0
ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # ---- 4. Training loop ----
for epoch in range(1, cfg["epochs"] + 1):
model.train()
total_loss = 0.0
num_batches = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
for images, targets, target_lengths, _ in pbar:
images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
logits = model(images) # (T, B, C)
T, B, C = logits.shape
input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
log_probs = logits.log_softmax(2)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
total_loss += loss.item()
num_batches += 1
pbar.set_postfix(loss=f"{loss.item():.4f}")
scheduler.step()
avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for images, _, _, labels in val_loader:
images = images.to(device)
logits = model(images)
preds = model.greedy_decode(logits)
all_preds.extend(preds)
all_labels.extend(labels)
sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
lr = scheduler.get_last_lr()[0]
print(
f"Epoch {epoch:3d}/{cfg['epochs']} "
f"loss={avg_loss:.4f} "
f"acc={sample_acc:.4f} "
f"char_acc={char_acc:.4f} "
f"lr={lr:.6f}"
)
        # ---- 6. Save the best model ----
if sample_acc >= best_acc:
best_acc = sample_acc
torch.save({
"model_state_dict": model.state_dict(),
"chars": chars,
"best_acc": best_acc,
"epoch": epoch,
}, ckpt_path)
            print(f"  → saved best model acc={best_acc:.4f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[training done] best accuracy: {best_acc:.4f}")
    # Reload the best weights before exporting
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
return best_acc
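The validation loop above relies on `model.greedy_decode`, defined on the model classes outside this hunk. Per the repository's greedy-CTC-decoding rule, the collapse step for one sample's best path — assuming blank = class 0 and `chars[i]` mapped to class `i + 1`, matching `nn.CTCLoss(blank=0)` above — reduces to:

```python
def greedy_ctc_collapse(best_path: list[int], chars: str) -> str:
    """Collapse a per-timestep argmax path per CTC rules (illustrative sketch).

    Merges consecutive repeats, drops blanks (class 0), and maps the
    remaining class indices back to characters.
    """
    out = []
    prev = 0  # treat "before the sequence" as blank
    for idx in best_path:
        if idx != 0 and idx != prev:
            out.append(chars[idx - 1])
        prev = idx
    return "".join(out)

# e.g. the path [1, 1, 0, 1, 2, 2] over chars "AB" decodes to "AAB":
# the repeated 1s collapse, and the blank separates the two A emissions.
```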

uv.lock generated Normal file

File diff suppressed because it is too large