Files
CaptchBreaker/training/dataset.py
2026-03-10 18:47:29 +08:00

160 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
通用 Dataset 类
提供两种数据集:
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
文件名格式约定: {label}_{任意}.png
- 分类器: label 可为任意字符,所在子目录名即为类别
- 识别器: label 即标注内容 (如 "A3B8", "3+8")
"""
import os
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import AUGMENT_CONFIG
# ============================================================
# 增强 / 推理 transform 工厂函数
# ============================================================
def build_train_transform(
    img_h: int,
    img_w: int,
    *,
    aug_config: dict | None = None,
) -> transforms.Compose:
    """Build the training-time transform with data augmentation.

    Pipeline: grayscale -> resize to (img_h, img_w) -> random affine ->
    brightness/contrast jitter -> Gaussian blur -> tensor ->
    normalize (mean 0.5, std 0.5) -> random erasing.

    Args:
        img_h: output image height.
        img_w: output image width.
        aug_config: augmentation hyper-parameters. Defaults to
            ``config.AUGMENT_CONFIG``; pass a different mapping to tune
            augmentation (e.g. per model) without editing the global config.
            Keys read: "degrees", "translate", "scale", "brightness",
            "contrast", "blur_kernel", "blur_sigma", "erasing_prob",
            "erasing_scale".

    Returns:
        The composed transform.
    """
    # Resolved at call time, exactly like the previous direct global lookup.
    aug = AUGMENT_CONFIG if aug_config is None else aug_config
    return transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=aug["degrees"],
            translate=aug["translate"],
            scale=aug["scale"],
        ),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        # RandomErasing operates on tensors, so it must stay after ToTensor;
        # its default fill value 0 corresponds to mid-gray after Normalize above.
        transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
    ])
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the deterministic transform used for validation and inference.

    No random augmentation is applied: the image is converted to grayscale,
    resized to (img_h, img_w), turned into a tensor and normalized with
    mean 0.5 / std 0.5 (the same normalization as the training transform).
    """
    target_size = (img_h, img_w)
    steps = [
        transforms.Grayscale(),
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
class CRNNDataset(Dataset):
    """CTC recognition dataset.

    Reads ``{label}_{xxx}.png`` files from one or more directories and
    encodes each label as a sequence of integer class indices (the CTC
    target). Index 0 is reserved for the CTC blank, so the i-th character
    of ``chars`` maps to ``i + 1``.

    Label characters that are not in ``chars`` are skipped during encoding,
    so such a target is shorter than its label text; a warning is emitted
    at construction time when any such samples are found.
    """

    def __init__(
        self,
        dirs: list[str | Path] | str | Path,
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            dirs: a data directory or a list of directories; the .png files
                of all existing directories are merged. Missing directories
                are skipped.
            chars: character set (without the CTC blank).
            transform: image preprocessing / augmentation, applied to the
                RGB PIL image in ``__getitem__``.
        """
        import warnings

        # A bare path would otherwise be iterated character by character.
        if isinstance(dirs, (str, Path)):
            dirs = [dirs]

        self.chars = chars
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # blank=0
        self.transform = transform
        self.samples: list[tuple[str, str]] = []  # (file path, label text)

        for d in map(Path, dirs):
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                label = f.stem.rsplit("_", 1)[0]  # "A3B8_000001" -> "A3B8"
                self.samples.append((str(f), label))

        # Make label/charset mismatches visible instead of silently training
        # on targets that no longer match their label text.
        unencodable = [
            path for path, text in self.samples
            if any(c not in self.char_to_idx for c in text)
        ]
        if unencodable:
            warnings.warn(
                f"CRNNDataset: {len(unencodable)} of {len(self.samples)} samples "
                f"contain characters outside the charset; those characters are "
                f"dropped from the CTC targets (e.g. {unencodable[0]})",
                stacklevel=2,
            )

    def __len__(self) -> int:
        return len(self.samples)

    def _encode(self, label: str) -> list[int]:
        """Map label text to CTC class indices, skipping unknown characters."""
        lookup = self.char_to_idx
        return [lookup[c] for c in label if c in lookup]

    def __getitem__(self, idx: int):
        """Return ``(image, target, label)`` for sample ``idx``.

        ``image`` is the transformed image (the RGB PIL image itself when no
        transform is set), ``target`` the encoded index list and ``label``
        the raw label text parsed from the file name.
        """
        path, label = self.samples[idx]
        # The context manager closes the file even if decoding raises.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, self._encode(label), label

    @staticmethod
    def collate_fn(batch):
        """Collate ``(image, target, label)`` samples for CTC training.

        Images are stacked into one tensor along a new dim 0. All targets are
        concatenated into a flat 1-D int32 tensor, and the per-sample lengths
        go into ``target_lengths`` (the concatenated-target layout accepted by
        ``torch.nn.CTCLoss``).

        Returns:
            (images, targets_flat, target_lengths, labels)
        """
        import torch

        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.int32)
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)
# ============================================================
# 分类器用数据集
# ============================================================
class CaptchaDataset(Dataset):
    """Classifier training dataset.

    Every sub-directory of ``root_dir`` named after an entry of
    ``class_names`` (e.g. "normal", "math", "3d") holds the .png images of
    that class; a sample's label is the class name's position in
    ``class_names``.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """
        Args:
            root_dir: root directory containing one sub-folder per class name.
            class_names: list of class names; the order defines label indices.
            transform: image preprocessing / augmentation, applied to the RGB
                PIL image in ``__getitem__``.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform
        self.samples: list[tuple[str, int]] = []  # (file path, class index)

        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue  # a class without a folder contributes no samples
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), self.class_to_idx[cls_name]))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(image, class_index)`` for sample ``idx``.

        ``image`` is the transformed image (the RGB PIL image itself when no
        transform is set).
        """
        path, label = self.samples[idx]
        # The context manager closes the file even if decoding raises.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label