Initialize repository

This commit is contained in:
Hua
2026-03-10 18:47:29 +08:00
commit 760b80ee5e
32 changed files with 4343 additions and 0 deletions

159
training/dataset.py Normal file
View File

@@ -0,0 +1,159 @@
"""
通用 Dataset 类
提供两种数据集:
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
文件名格式约定: {label}_{任意}.png
- 分类器: label 可为任意字符,所在子目录名即为类别
- 识别器: label 即标注内容 (如 "A3B8""3+8")
"""
import os
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import AUGMENT_CONFIG
# ============================================================
# 增强 / 推理 transform 工厂函数
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
"""训练时数据增强 transform。"""
aug = AUGMENT_CONFIG
return transforms.Compose([
transforms.Grayscale(),
transforms.Resize((img_h, img_w)),
transforms.RandomAffine(
degrees=aug["degrees"],
translate=aug["translate"],
scale=aug["scale"],
),
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
])
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
"""验证 / 推理时 transform (无增强)。"""
return transforms.Compose([
transforms.Grayscale(),
transforms.Resize((img_h, img_w)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
class CRNNDataset(Dataset):
"""
CTC 识别数据集。
从目录中读取 {label}_{xxx}.png 文件,
将 label 编码为整数序列 (CTC target)。
"""
def __init__(
self,
dirs: list[str | Path],
chars: str,
transform: transforms.Compose | None = None,
):
"""
Args:
dirs: 数据目录列表 (会合并所有目录下的 .png 文件)
chars: 字符集字符串 (不含 CTC blank)
transform: 图片预处理/增强
"""
self.chars = chars
self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)} # blank=0
self.transform = transform
self.samples: list[tuple[str, str]] = [] # (文件路径, 标签文本)
for d in dirs:
d = Path(d)
if not d.exists():
continue
for f in sorted(d.glob("*.png")):
label = f.stem.rsplit("_", 1)[0] # "A3B8_000001" -> "A3B8"
self.samples.append((str(f), label))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
path, label = self.samples[idx]
img = Image.open(path).convert("RGB")
if self.transform:
img = self.transform(img)
# 编码标签为整数序列
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
return img, target, label
@staticmethod
def collate_fn(batch):
"""自定义 collate: 图片堆叠为 tensor标签拼接为 1D tensor。"""
import torch
images, targets, labels = zip(*batch)
images = torch.stack(images, 0)
target_lengths = torch.IntTensor([len(t) for t in targets])
targets_flat = torch.IntTensor([idx for t in targets for idx in t])
return images, targets_flat, target_lengths, list(labels)
# ============================================================
# 分类器用数据集
# ============================================================
class CaptchaDataset(Dataset):
"""
分类器训练数据集。
每个子目录名为类别名 (如 "normal", "math", "3d")
目录内所有 .png 文件属于该类。
"""
def __init__(
self,
root_dir: str | Path,
class_names: list[str],
transform: transforms.Compose | None = None,
):
"""
Args:
root_dir: 根目录,包含以类别名命名的子文件夹
class_names: 类别名列表 (顺序即标签索引)
transform: 图片预处理/增强
"""
self.class_names = class_names
self.class_to_idx = {c: i for i, c in enumerate(class_names)}
self.transform = transform
self.samples: list[tuple[str, int]] = [] # (文件路径, 类别索引)
root = Path(root_dir)
for cls_name in class_names:
cls_dir = root / cls_name
if not cls_dir.exists():
continue
for f in sorted(cls_dir.glob("*.png")):
self.samples.append((str(f), self.class_to_idx[cls_name]))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
path, label = self.samples[idx]
img = Image.open(path).convert("RGB")
if self.transform:
img = self.transform(img)
return img, label