Initialize repository
This commit is contained in:
159
training/dataset.py
Normal file
159
training/dataset.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
通用 Dataset 类
|
||||
|
||||
提供两种数据集:
|
||||
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
|
||||
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
|
||||
|
||||
文件名格式约定: {label}_{任意}.png
|
||||
- 分类器: label 可为任意字符,所在子目录名即为类别
|
||||
- 识别器: label 即标注内容 (如 "A3B8" 或 "3+8")
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
from config import AUGMENT_CONFIG
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 增强 / 推理 transform 工厂函数
|
||||
# ============================================================
|
||||
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
"""训练时数据增强 transform。"""
|
||||
aug = AUGMENT_CONFIG
|
||||
return transforms.Compose([
|
||||
transforms.Grayscale(),
|
||||
transforms.Resize((img_h, img_w)),
|
||||
transforms.RandomAffine(
|
||||
degrees=aug["degrees"],
|
||||
translate=aug["translate"],
|
||||
scale=aug["scale"],
|
||||
),
|
||||
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
|
||||
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
|
||||
])
|
||||
|
||||
|
||||
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
"""验证 / 推理时 transform (无增强)。"""
|
||||
return transforms.Compose([
|
||||
transforms.Grayscale(),
|
||||
transforms.Resize((img_h, img_w)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
])
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CRNN / CTC 识别用数据集
|
||||
# ============================================================
|
||||
class CRNNDataset(Dataset):
|
||||
"""
|
||||
CTC 识别数据集。
|
||||
|
||||
从目录中读取 {label}_{xxx}.png 文件,
|
||||
将 label 编码为整数序列 (CTC target)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirs: list[str | Path],
|
||||
chars: str,
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dirs: 数据目录列表 (会合并所有目录下的 .png 文件)
|
||||
chars: 字符集字符串 (不含 CTC blank)
|
||||
transform: 图片预处理/增强
|
||||
"""
|
||||
self.chars = chars
|
||||
self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)} # blank=0
|
||||
self.transform = transform
|
||||
|
||||
self.samples: list[tuple[str, str]] = [] # (文件路径, 标签文本)
|
||||
for d in dirs:
|
||||
d = Path(d)
|
||||
if not d.exists():
|
||||
continue
|
||||
for f in sorted(d.glob("*.png")):
|
||||
label = f.stem.rsplit("_", 1)[0] # "A3B8_000001" -> "A3B8"
|
||||
self.samples.append((str(f), label))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
path, label = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
|
||||
# 编码标签为整数序列
|
||||
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
|
||||
return img, target, label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
"""自定义 collate: 图片堆叠为 tensor,标签拼接为 1D tensor。"""
|
||||
import torch
|
||||
images, targets, labels = zip(*batch)
|
||||
images = torch.stack(images, 0)
|
||||
target_lengths = torch.IntTensor([len(t) for t in targets])
|
||||
targets_flat = torch.IntTensor([idx for t in targets for idx in t])
|
||||
return images, targets_flat, target_lengths, list(labels)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 分类器用数据集
|
||||
# ============================================================
|
||||
class CaptchaDataset(Dataset):
|
||||
"""
|
||||
分类器训练数据集。
|
||||
|
||||
每个子目录名为类别名 (如 "normal", "math", "3d"),
|
||||
目录内所有 .png 文件属于该类。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root_dir: str | Path,
|
||||
class_names: list[str],
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
root_dir: 根目录,包含以类别名命名的子文件夹
|
||||
class_names: 类别名列表 (顺序即标签索引)
|
||||
transform: 图片预处理/增强
|
||||
"""
|
||||
self.class_names = class_names
|
||||
self.class_to_idx = {c: i for i, c in enumerate(class_names)}
|
||||
self.transform = transform
|
||||
|
||||
self.samples: list[tuple[str, int]] = [] # (文件路径, 类别索引)
|
||||
root = Path(root_dir)
|
||||
for cls_name in class_names:
|
||||
cls_dir = root / cls_name
|
||||
if not cls_dir.exists():
|
||||
continue
|
||||
for f in sorted(cls_dir.glob("*.png")):
|
||||
self.samples.append((str(f), self.class_to_idx[cls_name]))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
path, label = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
Reference in New Issue
Block a user