Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

10
training/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
训练脚本包
- dataset.py: CRNNDataset / CaptchaDataset 通用数据集类
- train_utils.py: CTC 训练通用逻辑 (train_ctc_model)
- train_normal.py: 训练普通字符识别 (LiteCRNN - normal)
- train_math.py: 训练算式识别 (LiteCRNN - math)
- train_3d.py: 训练 3D 立体识别 (ThreeDCNN)
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
"""

159
training/dataset.py Normal file
View File

@@ -0,0 +1,159 @@
"""
通用 Dataset 类
提供两种数据集:
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
文件名格式约定: {label}_{任意}.png
- 分类器: label 可为任意字符,所在子目录名即为类别
- 识别器: label 即标注内容 (如 "A3B8"、"3+8")
"""
import os
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import AUGMENT_CONFIG
# ============================================================
# 增强 / 推理 transform 工厂函数
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the training-time augmentation pipeline.

    Augmentation strengths are read from the project-level AUGMENT_CONFIG.
    """
    cfg = AUGMENT_CONFIG
    steps = [
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
        ),
        transforms.ColorJitter(brightness=cfg["brightness"], contrast=cfg["contrast"]),
        transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        # RandomErasing operates on tensors, so it must come after ToTensor.
        transforms.RandomErasing(p=cfg["erasing_prob"], scale=cfg["erasing_scale"]),
    ]
    return transforms.Compose(steps)
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the deterministic transform used for validation / inference (no augmentation)."""
    steps = [
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
class CRNNDataset(Dataset):
    """
    Dataset for CTC-based sequence recognition.

    Scans directories for ``{label}_{anything}.png`` files and encodes each
    label string as a sequence of integer class indices (CTC targets, with
    index 0 reserved for the CTC blank).
    """
    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: data directories; all .png files found are merged
            chars: character set string (must NOT include the CTC blank)
            transform: optional preprocessing / augmentation pipeline
        """
        self.chars = chars
        # Index 0 is the CTC blank, so real characters start at 1.
        self.char_to_idx = {ch: pos + 1 for pos, ch in enumerate(chars)}
        self.transform = transform
        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for directory in dirs:
            directory = Path(directory)
            if not directory.exists():
                continue
            for png in sorted(directory.glob("*.png")):
                # "A3B8_000001" -> "A3B8": drop the trailing "_suffix" part.
                self.samples.append((str(png), png.stem.rsplit("_", 1)[0]))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        # Encode the label; characters outside the charset are silently dropped.
        target = [self.char_to_idx[ch] for ch in label if ch in self.char_to_idx]
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Custom collate: stack images; concatenate targets into one flat 1D tensor."""
        import torch
        images, targets, labels = zip(*batch)
        stacked = torch.stack(images, 0)
        lengths = torch.IntTensor([len(t) for t in targets])
        flat = torch.IntTensor([i for seq in targets for i in seq])
        return stacked, flat, lengths, list(labels)
# ============================================================
# 分类器用数据集
# ============================================================
class CaptchaDataset(Dataset):
    """
    Classification dataset.

    Each sub-directory of the root is one class (e.g. "normal", "math",
    "3d"); every .png file inside it belongs to that class.
    """
    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: root folder containing one sub-folder per class name
            class_names: class names; list order defines the label index
            transform: optional preprocessing / augmentation pipeline
        """
        self.class_names = class_names
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}
        self.transform = transform
        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for name in class_names:
            sub = root / name
            if not sub.exists():
                continue
            label = self.class_to_idx[name]
            for png in sorted(sub.glob("*.png")):
                self.samples.append((str(png), label))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

40
training/train_3d.py Normal file
View File

@@ -0,0 +1,40 @@
"""
训练 3D 立体验证码识别模型 (ThreeDCNN)
用法: python -m training.train_3d
"""
from config import (
THREED_CHARS,
IMAGE_SIZE,
SYNTHETIC_3D_DIR,
REAL_3D_DIR,
)
from generators.threed_gen import ThreeDCaptchaGenerator
from models.threed_cnn import ThreeDCNN
from training.train_utils import train_ctc_model
def main():
    """Entry point: train the 3D stereoscopic captcha recognizer (ThreeDCNN)."""
    img_h, img_w = IMAGE_SIZE["3d"]
    banner = "=" * 60
    print(banner)
    print("训练 3D 立体验证码识别模型 (ThreeDCNN)")
    print(f" 字符集: {THREED_CHARS} ({len(THREED_CHARS)} 字符)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(banner)
    # Delegate the shared CTC training / export pipeline to train_utils.
    train_ctc_model(
        model_name="threed",
        model=ThreeDCNN(chars=THREED_CHARS, img_h=img_h, img_w=img_w),
        chars=THREED_CHARS,
        synthetic_dir=SYNTHETIC_3D_DIR,
        real_dir=REAL_3D_DIR,
        generator_cls=ThreeDCaptchaGenerator,
        config_key="threed",
    )


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,232 @@
"""
训练调度分类器 (CaptchaClassifier)
从各类型验证码数据中混合采样,训练分类器区分 normal / math / 3d。
数据来源: data/classifier/ 目录 (按类型子目录组织)
用法: python -m training.train_classifier
"""
import os
import shutil
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CAPTCHA_TYPES,
NUM_CAPTCHA_TYPES,
IMAGE_SIZE,
TRAIN_CONFIG,
CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR,
SYNTHETIC_3D_DIR,
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
get_device,
)
from generators.normal_gen import NormalCaptchaGenerator
from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from models.classifier import CaptchaClassifier
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
def _prepare_classifier_data():
    """
    Prepare balanced training data for the classifier.

    For every captcha type, link (or copy) an equal number of synthetic
    samples into data/classifier/{type}/ so the classes stay balanced.
    Missing synthetic data is generated on the fly first.
    """
    cfg = TRAIN_CONFIG["classifier"]
    per_class = cfg["synthetic_samples"] // NUM_CAPTCHA_TYPES
    # Per type: (class name, synthetic data dir, generator class).
    sources = (
        ("normal", SYNTHETIC_NORMAL_DIR, NormalCaptchaGenerator),
        ("math", SYNTHETIC_MATH_DIR, MathCaptchaGenerator),
        ("3d", SYNTHETIC_3D_DIR, ThreeDCaptchaGenerator),
    )
    for cls_name, syn_dir, gen_cls in sources:
        syn_dir = Path(syn_dir)
        existing = sorted(syn_dir.glob("*.png"))
        if len(existing) < per_class:
            # Not enough synthetic samples yet: generate, then re-scan.
            print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
            gen_cls().generate_dataset(per_class, str(syn_dir))
            existing = sorted(syn_dir.glob("*.png"))
        cls_dir = CLASSIFIER_DIR / cls_name
        cls_dir.mkdir(parents=True, exist_ok=True)
        already = len(list(cls_dir.glob("*.png")))
        if already >= per_class:
            print(f"[数据] {cls_name} 分类器数据已就绪: {already}")
            continue
        # Start from a clean slate, then re-link.
        for stale in cls_dir.glob("*.png"):
            stale.unlink()
        selected = existing[:per_class]
        for src in tqdm(selected, desc=f"准备 {cls_name}", leave=False):
            dst = cls_dir / src.name
            # Symlink to save disk space; fall back to a real copy when the
            # platform / filesystem refuses symlinks (OSError).
            try:
                dst.symlink_to(src.resolve())
            except OSError:
                shutil.copy2(src, dst)
        print(f"[数据] {cls_name} 分类器数据就绪: {len(selected)}")
def main():
    """Train the dispatch classifier (CaptchaClassifier) and export ONNX.

    Pipeline: prepare balanced data -> train with augmentation ->
    validate on a clean (augmentation-free) split -> keep the best
    checkpoint -> export the best weights to ONNX.

    Returns:
        Best validation accuracy reached during training.
    """
    cfg = TRAIN_CONFIG["classifier"]
    img_h, img_w = IMAGE_SIZE["classifier"]
    device = get_device()
    print("=" * 60)
    print("训练调度分类器 (CaptchaClassifier)")
    print(f" 类别: {CAPTCHA_TYPES}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)
    # ---- 1. Prepare balanced per-class data ----
    _prepare_classifier_data()
    # ---- 2. Build datasets ----
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)
    full_dataset = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=train_transform,
    )
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation uses the same sample split but without augmentation:
    # re-wrap the validation indices in a dataset using val_transform.
    val_ds_clean = CaptchaDataset(
        root_dir=CLASSIFIER_DIR,
        class_names=CAPTCHA_TYPES,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=2, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=2, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Model / optimizer / scheduler ----
    model = CaptchaClassifier(num_types=NUM_CAPTCHA_TYPES).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / "classifier.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        correct = 0
        total_val = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                logits = model(images)
                preds = logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total_val += labels.size(0)
        val_acc = correct / max(total_val, 1)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={val_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save best model ----
        # ">=" (not ">") so epoch 1 always writes a checkpoint; otherwise a
        # run that never beats 0.0 accuracy would leave ckpt_path missing
        # and the torch.load below would crash. Matches train_ctc_model.
        if val_acc >= best_acc:
            best_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "class_names": CAPTCHA_TYPES,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    onnx_path = ONNX_DIR / "classifier.onnx"
    # Classifier input/output are both batch-first: (B, 1, H, W) -> (B, num_types).
    dummy = torch.randn(1, 1, img_h, img_w)
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
        if ONNX_CONFIG["dynamic_batch"]
        else None,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
    return best_acc


if __name__ == "__main__":
    main()

40
training/train_math.py Normal file
View File

@@ -0,0 +1,40 @@
"""
训练算式识别模型 (LiteCRNN - math 模式)
用法: python -m training.train_math
"""
from config import (
MATH_CHARS,
IMAGE_SIZE,
SYNTHETIC_MATH_DIR,
REAL_MATH_DIR,
)
from generators.math_gen import MathCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model
def main():
    """Entry point: train the arithmetic-expression recognizer (LiteCRNN, math mode)."""
    img_h, img_w = IMAGE_SIZE["math"]
    banner = "=" * 60
    print(banner)
    print("训练算式识别模型 (LiteCRNN - math)")
    print(f" 字符集: {MATH_CHARS} ({len(MATH_CHARS)} 字符)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(banner)
    # Delegate the shared CTC training / export pipeline to train_utils.
    train_ctc_model(
        model_name="math",
        model=LiteCRNN(chars=MATH_CHARS, img_h=img_h, img_w=img_w),
        chars=MATH_CHARS,
        synthetic_dir=SYNTHETIC_MATH_DIR,
        real_dir=REAL_MATH_DIR,
        generator_cls=MathCaptchaGenerator,
        config_key="math",
    )


if __name__ == "__main__":
    main()

40
training/train_normal.py Normal file
View File

@@ -0,0 +1,40 @@
"""
训练普通字符识别模型 (LiteCRNN - normal 模式)
用法: python -m training.train_normal
"""
from config import (
NORMAL_CHARS,
IMAGE_SIZE,
SYNTHETIC_NORMAL_DIR,
REAL_NORMAL_DIR,
)
from generators.normal_gen import NormalCaptchaGenerator
from models.lite_crnn import LiteCRNN
from training.train_utils import train_ctc_model
def main():
    """Entry point: train the plain-character recognizer (LiteCRNN, normal mode)."""
    img_h, img_w = IMAGE_SIZE["normal"]
    banner = "=" * 60
    print(banner)
    print("训练普通字符识别模型 (LiteCRNN - normal)")
    print(f" 字符集: {NORMAL_CHARS} ({len(NORMAL_CHARS)} 字符)")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print(banner)
    # Delegate the shared CTC training / export pipeline to train_utils.
    train_ctc_model(
        model_name="normal",
        model=LiteCRNN(chars=NORMAL_CHARS, img_h=img_h, img_w=img_w),
        chars=NORMAL_CHARS,
        synthetic_dir=SYNTHETIC_NORMAL_DIR,
        real_dir=REAL_NORMAL_DIR,
        generator_cls=NormalCaptchaGenerator,
        config_key="normal",
    )


if __name__ == "__main__":
    main()

232
training/train_utils.py Normal file
View File

@@ -0,0 +1,232 @@
"""
CTC 训练通用逻辑
提供 train_ctc_model() 函数,被 train_normal / train_math / train_3d 共用。
职责:
1. 检查合成数据,不存在则自动调用生成器
2. 构建 Dataset / DataLoader含真实数据混合
3. CTC 训练循环 + cosine scheduler
4. 输出日志: epoch, loss, 整体准确率, 字符级准确率
5. 保存最佳模型到 checkpoints/
6. 训练结束导出 ONNX
"""
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
ONNX_DIR,
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
get_device,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
# ============================================================
# 准确率计算
# ============================================================
def _calc_accuracy(preds: list[str], labels: list[str]):
"""返回 (整体准确率, 字符级准确率)。"""
total_samples = len(preds)
correct_samples = 0
total_chars = 0
correct_chars = 0
for pred, label in zip(preds, labels):
if pred == label:
correct_samples += 1
# 字符级: 逐位比较 (取较短长度)
max_len = max(len(pred), len(label))
if max_len == 0:
continue
for i in range(max_len):
total_chars += 1
if i < len(pred) and i < len(label) and pred[i] == label[i]:
correct_chars += 1
sample_acc = correct_samples / max(total_samples, 1)
char_acc = correct_chars / max(total_chars, 1)
return sample_acc, char_acc
# ============================================================
# ONNX 导出
# ============================================================
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
    """Export the model to ``onnx/{model_name}.onnx``.

    NOTE(review): ``model.cpu()`` moves the caller's model off its training
    device as a side effect; acceptable here since export runs post-training.
    """
    model.eval()
    onnx_path = ONNX_DIR / f"{model_name}.onnx"
    dummy = torch.randn(1, 1, img_h, img_w)
    # Recognition output is (T, B, C) per the training loop, so the batch
    # axis of the OUTPUT is dim 1 while the input batch axis is dim 0.
    axes = None
    if ONNX_CONFIG["dynamic_batch"]:
        axes = {"input": {0: "batch"}, "output": {1: "batch"}}
    torch.onnx.export(
        model.cpu(),
        dummy,
        str(onnx_path),
        opset_version=ONNX_CONFIG["opset_version"],
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=axes,
    )
    print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
# ============================================================
# 核心训练函数
# ============================================================
def train_ctc_model(
    model_name: str,
    model: nn.Module,
    chars: str,
    synthetic_dir: str | Path,
    real_dir: str | Path,
    generator_cls,
    config_key: str,
):
    """
    Generic CTC training pipeline shared by train_normal / train_math / train_3d.

    Steps: ensure synthetic data exists (generating it if short), build
    train/val loaders (mixing real data when present), run the CTC training
    loop with a cosine LR schedule, keep the best checkpoint, then export
    the best weights to ONNX.

    Args:
        model_name: name used for saved files (normal / math / threed)
        model: PyTorch model instance (LiteCRNN or ThreeDCNN)
        chars: character set string (CTC blank excluded)
        synthetic_dir: synthetic data directory
        real_dir: real (human-labelled) data directory
        generator_cls: generator class used to auto-create synthetic data
        config_key: key into TRAIN_CONFIG

    Returns:
        Best validation sequence accuracy reached during training.
    """
    cfg = TRAIN_CONFIG[config_key]
    # IMAGE_SIZE uses "3d" as its key while TRAIN_CONFIG uses "threed".
    img_h, img_w = IMAGE_SIZE[config_key if config_key != "threed" else "3d"]
    device = get_device()
    # ---- 1. Check / generate synthetic data ----
    syn_path = Path(synthetic_dir)
    existing = list(syn_path.glob("*.png"))
    if len(existing) < cfg["synthetic_samples"]:
        print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
        gen = generator_cls()
        gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
    else:
        print(f"[数据] 合成数据已就绪: {len(existing)}")
    # ---- 2. Build datasets ----
    data_dirs = [str(syn_path)]
    real_path = Path(real_dir)
    if real_path.exists() and list(real_path.glob("*.png")):
        data_dirs.append(str(real_path))
        print(f"[数据] 混合真实数据: {len(list(real_path.glob('*.png')))}")
    train_transform = build_train_transform(img_h, img_w)
    val_transform = build_val_transform(img_h, img_w)
    full_dataset = CRNNDataset(dirs=data_dirs, chars=chars, transform=train_transform)
    total = len(full_dataset)
    val_size = int(total * cfg["val_split"])
    train_size = total - val_size
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation uses the same split indices but an augmentation-free transform.
    val_ds_clean = CRNNDataset(dirs=data_dirs, chars=chars, transform=val_transform)
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds, batch_size=cfg["batch_size"], shuffle=True,
        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean, batch_size=cfg["batch_size"], shuffle=False,
        num_workers=2, collate_fn=CRNNDataset.collate_fn, pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- 3. Optimizer / scheduler / loss ----
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # zero_infinity guards against inf losses from targets longer than inputs.
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
    best_acc = 0.0
    ckpt_path = CHECKPOINTS_DIR / f"{model_name}.pth"
    # ---- 4. Training loop ----
    for epoch in range(1, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for images, targets, target_lengths, _ in pbar:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)
            logits = model(images)  # (T, B, C) — time-major, as CTCLoss expects
            T, B, C = logits.shape
            # Every sample in the batch uses the full output length T.
            input_lengths = torch.full((B,), T, dtype=torch.int32, device=device)
            log_probs = logits.log_softmax(2)
            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients — CTC training is prone to occasional spikes.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        # ---- 5. Validation ----
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for images, _, _, labels in val_loader:
                images = images.to(device)
                logits = model(images)
                preds = model.greedy_decode(logits)
                all_preds.extend(preds)
                all_labels.extend(labels)
        sample_acc, char_acc = _calc_accuracy(all_preds, all_labels)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={sample_acc:.4f} "
            f"char_acc={char_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # ---- 6. Save best model ----
        # ">=" so epoch 1 always writes a checkpoint; the export below
        # relies on ckpt_path existing.
        if sample_acc >= best_acc:
            best_acc = sample_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "chars": chars,
                "best_acc": best_acc,
                "epoch": epoch,
            }, ckpt_path)
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
    # ---- 7. Export ONNX ----
    print(f"\n[训练完成] 最佳准确率: {best_acc:.4f}")
    # Reload the best weights before exporting.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    _export_onnx(model, model_name, img_h, img_w)
    return best_acc