160 lines
5.2 KiB
Python
160 lines
5.2 KiB
Python
"""
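Generic Dataset classes.

Two datasets are provided:
- CaptchaDataset: for classifier training (image -> class label)
- CRNNDataset: for CRNN/CTC recognition training (image -> encoded character sequence)

File name convention: {label}_{anything}.png
- Classifier: label may be any characters; the name of the containing sub-directory is the class
- Recognizer: label is the annotated text (e.g. "A3B8" or "3+8")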
"""
|
||
|
||
import os
|
||
from pathlib import Path
|
||
|
||
from PIL import Image
|
||
from torch.utils.data import Dataset
|
||
from torchvision import transforms
|
||
|
||
from config import AUGMENT_CONFIG
|
||
|
||
|
||
# ============================================================
# Augmentation / inference transform factory functions
# ============================================================
|
||
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmenting transform pipeline used during training.

    Args:
        img_h: Target image height handed to ``Resize``.
        img_w: Target image width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that, in order: converts to grayscale,
        resizes to ``(img_h, img_w)``, applies a random affine transform,
        colour jitter and Gaussian blur, converts to a tensor, normalizes
        with mean 0.5 / std 0.5, and finally applies random erasing.
        Every augmentation strength is read from ``AUGMENT_CONFIG``.
    """
    cfg = AUGMENT_CONFIG

    # Geometry and photometric augmentations come before ToTensor.
    steps = [
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
    ]
    steps.append(
        transforms.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
        )
    )
    steps.append(
        transforms.ColorJitter(
            brightness=cfg["brightness"],
            contrast=cfg["contrast"],
        )
    )
    steps.append(
        transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    )

    # Tensor conversion and normalization.
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.5], [0.5]))

    # Random erasing is placed after the tensor conversion.
    steps.append(
        transforms.RandomErasing(
            p=cfg["erasing_prob"],
            scale=cfg["erasing_scale"],
        )
    )
    return transforms.Compose(steps)
|
||
|
||
|
||
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the transform pipeline for validation / inference (no augmentation).

    Args:
        img_h: Target image height handed to ``Resize``.
        img_w: Target image width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that converts to grayscale, resizes to
        ``(img_h, img_w)``, converts to a tensor and normalizes with
        mean 0.5 / std 0.5.
    """
    target_size = (img_h, img_w)
    pipeline = [
        transforms.Grayscale(),
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(pipeline)
|
||
|
||
|
||
# ============================================================
# Dataset for CRNN / CTC recognition
# ============================================================
|
||
class CRNNDataset(Dataset):
    """Dataset for CRNN / CTC recognition training.

    Reads ``{label}_{suffix}.png`` files from one or more directories and
    encodes each label as a list of integer character indices (the CTC
    target). Index 0 is reserved for the CTC blank, so the characters of
    ``chars`` map to ``1..len(chars)``.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """Collect the samples found in every directory of ``dirs``.

        Args:
            dirs: Data directories; the ``*.png`` files of all of them are
                merged. Directories that do not exist are skipped silently.
            chars: Character set string (without the CTC blank).
            transform: Optional image preprocessing / augmentation.
        """
        self.chars = chars
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # blank=0
        self.transform = transform

        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            # Sorted so the sample order is deterministic across runs.
            for f in sorted(d.glob("*.png")):
                # Everything before the LAST underscore is the label:
                # "A3B8_000001" -> "A3B8". A stem without an underscore is
                # used as the label unchanged.
                label = f.stem.rsplit("_", 1)[0]
                self.samples.append((str(f), label))

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            A tuple ``(img, target, label)`` where ``img`` is the RGB image
            (passed through ``self.transform`` when one is set), ``target``
            is the list of character indices and ``label`` is the raw label
            text taken from the file name.
        """
        path, label = self.samples[idx]
        # Use a context manager so the underlying file handle is closed
        # deterministically; convert() returns an independent in-memory copy.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        # Compare against None explicitly: a callable transform that happens
        # to define __len__/__bool__ must not be skipped by a truthiness test.
        if self.transform is not None:
            img = self.transform(img)

        # Encode the label as an integer sequence. Characters that are not
        # part of the character set are dropped, so ``target`` may be shorter
        # than ``label`` (or even empty).
        target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Collate ``(img, target, label)`` samples into a batch.

        Args:
            batch: Sequence of items as returned by ``__getitem__``; the
                images must already be tensors so they can be stacked.

        Returns:
            A tuple ``(images, targets_flat, target_lengths, labels)``:
            the images stacked along a new first dimension, all targets
            concatenated into one 1-D int32 tensor, a 1-D int32 tensor with
            the length of each target, and the list of raw label strings.
        """
        import torch

        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor(
            [len(t) for t in targets], dtype=torch.int32
        )
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)
|
||
|
||
|
||
# ============================================================
# Dataset for the classifier
# ============================================================
|
||
class CaptchaDataset(Dataset):
    """Dataset for classifier training.

    Each sub-directory of the root is named after a class (e.g. "normal",
    "math", "3d"), and every ``.png`` file inside it belongs to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """Collect the samples of every class directory under ``root_dir``.

        Args:
            root_dir: Root directory containing one sub-folder per class name.
            class_names: List of class names; the position in the list is the
                label index. Classes whose folder does not exist are skipped
                silently.
            transform: Optional image preprocessing / augmentation.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform

        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            # Sorted so the sample order is deterministic across runs.
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), self.class_to_idx[cls_name]))

    def __len__(self) -> int:
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            A tuple ``(img, label)`` where ``img`` is the RGB image (passed
            through ``self.transform`` when one is set) and ``label`` is the
            integer class index.
        """
        path, label = self.samples[idx]
        # Use a context manager so the underlying file handle is closed
        # deterministically; convert() returns an independent in-memory copy.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        # Compare against None explicitly: a callable transform that happens
        # to define __len__/__bool__ must not be skipped by a truthiness test.
        if self.transform is not None:
            img = self.transform(img)
        return img, label
|