"""
Generic Dataset classes.

Two datasets are provided:
    - CaptchaDataset: classifier training (image -> class label)
    - CRNNDataset:    CRNN/CTC recognition training (image -> encoded char sequence)

Filename convention: {label}_{anything}.png
    - Classifier: the label part of the filename may be anything; the name of
      the sub-directory containing the file is the class.
    - Recognizer: the label part is the annotation text itself
      (e.g. "A3B8" or "3+8").
"""

import os
import warnings
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from config import AUGMENT_CONFIG


# ============================================================
# Transform factories (training augmentation / inference)
# ============================================================

def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the training-time transform with data augmentation.

    Augmentation parameters are read from ``AUGMENT_CONFIG`` (keys:
    degrees, translate, scale, brightness, contrast, blur_kernel,
    blur_sigma, erasing_prob, erasing_scale).

    Args:
        img_h: Target image height after resizing.
        img_w: Target image width after resizing.

    Returns:
        A ``Compose`` producing a 1-channel tensor normalized with
        mean 0.5 / std 0.5.
    """
    aug = AUGMENT_CONFIG
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=aug["degrees"],
            translate=aug["translate"],
            scale=aug["scale"],
        ),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        # RandomErasing works on tensors, so it must come after ToTensor.
        transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
    ])


def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the validation / inference transform (no augmentation).

    Args:
        img_h: Target image height after resizing.
        img_w: Target image width after resizing.

    Returns:
        A ``Compose`` producing a 1-channel tensor normalized with
        mean 0.5 / std 0.5, matching ``build_train_transform`` minus
        the random augmentations.
    """
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def _load_rgb(path: str) -> Image.Image:
    """Open the image at *path*, decode it and return an RGB copy.

    The ``with`` block guarantees the underlying file handle is released,
    even if decoding fails; ``convert`` returns a fully loaded image that
    does not depend on the closed file.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


# ============================================================
# CRNN / CTC recognition dataset
# ============================================================

class CRNNDataset(Dataset):
    """
    CTC recognition dataset.

    Reads ``{label}_{xxx}.png`` files from the given directories and
    encodes each label as a list of integer indices (the CTC target).
    Index 0 is reserved for the CTC blank; ``chars[i]`` maps to ``i + 1``.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: Data directories; the ``.png`` files of all existing
                directories are merged. Missing directories are skipped.
            chars: Character set string (excluding the CTC blank).
            transform: Image preprocessing / augmentation, or ``None``.
        """
        self.chars = chars
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # blank=0
        self.transform = transform
        self.samples: list[tuple[str, str]] = []  # (file path, label text)

        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                # "A3B8_000001" -> "A3B8"; rsplit keeps underscores inside the label.
                label = f.stem.rsplit("_", 1)[0]
                self.samples.append((str(f), label))

        self._warn_unknown_chars()

    def _warn_unknown_chars(self) -> None:
        """Warn once if any label contains characters outside ``chars``.

        ``__getitem__`` drops such characters from the CTC target, so the
        encoded target would silently disagree with the returned label text.
        """
        unknown: set[str] = set()
        n_bad = 0
        first_bad = None
        for path, label in self.samples:
            missing = set(label).difference(self.char_to_idx)
            if missing:
                unknown |= missing
                n_bad += 1
                if first_bad is None:
                    first_bad = path
        if n_bad:
            warnings.warn(
                f"{n_bad} of {len(self.samples)} samples contain characters not in "
                f"the charset {sorted(unknown)!r}; these characters are dropped from "
                f"the CTC targets (first offending file: {first_bad})",
                stacklevel=3,
            )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, target, label)`` for sample *idx*.

        ``image`` is the transformed image (a PIL RGB image if no transform
        is set), ``target`` is the list of charset indices (unknown
        characters skipped) and ``label`` is the raw label text.
        """
        path, label = self.samples[idx]
        img = _load_rgb(path)
        if self.transform is not None:
            img = self.transform(img)

        # Encode the label as integer indices; unknown characters are skipped
        # (reported once by _warn_unknown_chars at construction time).
        target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Custom collate for CTC training.

        Stacks images into one tensor (requires a transform that yields
        equally-sized tensors), concatenates all targets into a flat 1-D
        int32 tensor and returns the per-sample target lengths alongside.

        Returns:
            ``(images, targets_flat, target_lengths, labels)``.
        """
        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.int32)
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)


# ============================================================
# Classifier dataset
# ============================================================

class CaptchaDataset(Dataset):
    """
    Classifier training dataset.

    Each sub-directory name is a class name (e.g. "normal", "math", "3d");
    every ``.png`` file inside it belongs to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: Root directory containing one sub-folder per class name.
            class_names: Class names; their order defines the label index.
                Classes whose sub-folder does not exist are skipped.
            transform: Image preprocessing / augmentation, or ``None``.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform
        self.samples: list[tuple[str, int]] = []  # (file path, class index)

        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), self.class_to_idx[cls_name]))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, class_index)`` for sample *idx*."""
        path, label = self.samples[idx]
        img = _load_rgb(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label