"""
Generic Dataset classes.

Three datasets are provided:
  - CaptchaDataset:    classifier training (image -> class label)
  - CRNNDataset:       CRNN/CTC recognition training (image -> encoded character sequence)
  - RegressionDataset: regression training (image -> numeric label in [0, 1])

File name convention: {label}_{anything}.png
  - Classifier: label may be any text; the containing sub-directory name is the class
  - Recognizer: label is the annotated text (e.g. "A3B8" or "3+8")
  - Regressor:  label is a number (e.g. "135" or "87")
"""

import math
import os
import warnings
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from config import AUGMENT_CONFIG


# ============================================================
# Shared private helpers
# ============================================================
def _load_rgb(path: str) -> Image.Image:
    """Load the image at *path* as RGB and release the underlying file handle.

    ``convert`` returns a new, fully decoded image, so the source image can be
    closed as soon as the conversion finishes (or fails).
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def _label_from_stem(path: Path) -> str:
    """Return the label part of a ``{label}_{suffix}`` file stem.

    Only the last underscore separates label and suffix, so the label itself
    may contain underscores ("A3B8_000001" -> "A3B8"). A stem without any
    underscore is returned unchanged.
    """
    return path.stem.rsplit("_", 1)[0]


# ============================================================
# Augmentation / inference transform factories
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the training-time transform with data augmentation.

    Augmentation parameters are read from ``AUGMENT_CONFIG`` (keys: degrees,
    translate, scale, brightness, contrast, blur_kernel, blur_sigma,
    erasing_prob, erasing_scale).

    Args:
        img_h: Target image height passed to ``Resize``.
        img_w: Target image width passed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that converts to grayscale, resizes, applies
        random affine / color jitter / Gaussian blur, converts to a tensor,
        normalizes with mean 0.5 and std 0.5, and finally applies random
        erasing.
    """
    aug = AUGMENT_CONFIG
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=aug["degrees"],
            translate=aug["translate"],
            scale=aug["scale"],
        ),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
    ])


def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the validation / inference transform (no augmentation).

    Args:
        img_h: Target image height passed to ``Resize``.
        img_w: Target image width passed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that converts to grayscale, resizes, converts
        to a tensor, and normalizes with mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


# ============================================================
# Dataset for CRNN / CTC recognition
# ============================================================
class CRNNDataset(Dataset):
    """
    CTC recognition dataset.

    Reads {label}_{xxx}.png files from the given directories and encodes each
    label as a sequence of integers (the CTC target).
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: Data directories; the .png files of all of them are merged.
                  Directories that do not exist are skipped.
            chars: Character set string (without the CTC blank).
            transform: Image preprocessing / augmentation.
        """
        self.chars = chars
        # Index 0 is reserved for the CTC blank, so characters start at 1.
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
        self.transform = transform

        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            # Sorted so the sample order is deterministic within a directory.
            self.samples.extend(
                (str(f), _label_from_stem(f)) for f in sorted(d.glob("*.png"))
            )

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, target, label)`` for sample *idx*.

        ``image`` is the RGB image after ``transform`` (if any), ``target`` is
        the list of character indices for the label, and ``label`` is the raw
        label text. Characters missing from the character set are dropped
        from ``target`` and reported with a warning.
        """
        path, label = self.samples[idx]
        img = _load_rgb(path)
        if self.transform:
            img = self.transform(img)

        # Encode the label as an integer sequence.
        target = []
        for c in label:
            if c in self.char_to_idx:
                target.append(self.char_to_idx[c])
            else:
                warnings.warn(
                    f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
                    stacklevel=2,
                )
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Custom collate: stack images, concatenate targets into one 1D tensor.

        Args:
            batch: Iterable of ``(image, target, label)`` items as produced by
                   ``__getitem__``; images must already be tensors of equal
                   shape so they can be stacked.

        Returns:
            ``(images, targets_flat, target_lengths, labels)`` where
            ``targets_flat`` and ``target_lengths`` are int32 tensors and
            ``labels`` is a list of the raw label strings.
        """
        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.int32)
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)


# ============================================================
# Dataset for the classifier
# ============================================================
class CaptchaDataset(Dataset):
    """
    Classifier training dataset.

    Every sub-directory is named after a class (e.g. "normal", "math",
    "3d_text"), and all .png files inside it belong to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: Root directory containing one sub-folder per class name.
                      Class folders that do not exist are skipped.
            class_names: Class names; their order defines the label indices.
            transform: Image preprocessing / augmentation.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform

        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            cls_idx = self.class_to_idx[cls_name]
            self.samples.extend(
                (str(f), cls_idx) for f in sorted(cls_dir.glob("*.png"))
            )

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, class_index)`` for sample *idx*."""
        path, label = self.samples[idx]
        img = _load_rgb(path)
        if self.transform:
            img = self.transform(img)
        return img, label


# ============================================================
# Dataset for regression models
# ============================================================
class RegressionDataset(Dataset):
    """
    Regression dataset (3d_rotate / 3d_slider).

    Reads {value}_{xxx}.png files from the given directories, parses value as
    a float and normalizes it to [0, 1].
    """

    def __init__(
        self,
        dirs: list[str | Path],
        label_range: tuple[float, float],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: Data directories. Directories that do not exist are skipped.
            label_range: (min_val, max_val) range of the raw labels.
            transform: Image preprocessing / augmentation.
        """
        self.label_range = label_range
        self.lo, self.hi = label_range
        self.transform = transform

        # Guard against a zero (or negative) range to avoid dividing by zero.
        span = max(self.hi - self.lo, 1e-6)

        self.samples: list[tuple[str, float]] = []  # (file path, normalized label)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                try:
                    value = float(_label_from_stem(f))
                except ValueError:
                    # File name does not start with a number: not a sample.
                    continue
                # float() also accepts "nan" / "inf"; clamping would silently
                # turn those into a valid-looking label, so skip them too.
                if not math.isfinite(value):
                    continue
                # Normalize to [0, 1], clamping values outside label_range.
                norm = (value - self.lo) / span
                norm = max(0.0, min(1.0, norm))
                self.samples.append((str(f), norm))

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` where label is a float32 tensor of shape (1,)."""
        path, label = self.samples[idx]
        img = _load_rgb(path)
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor([label], dtype=torch.float32)