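"""
Generic Dataset classes.

Provides the following datasets:
- CaptchaDataset: classifier training (image -> class label)
- CRNNDataset: CRNN/CTC recognition training (image -> encoded character sequence)
- RegressionDataset: regression training (image -> numeric label in [0, 1])
- RotateSolverDataset: rotation-solver training (image -> (sin, cos) of the angle)
- FunCaptchaChallengeDataset: FunCaptcha challenges (image -> candidates, reference, answer index)

File name convention: {label}_{anything}.png
- Classifier: label may be any text; the sub-directory name is the class
- Recognizer: label is the annotated content (e.g. "A3B8" or "3+8")
- Regressor: label is a number (e.g. "135" or "87")
"""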
|
||
|
||
import os
|
||
import warnings
|
||
from pathlib import Path
|
||
|
||
from PIL import Image
|
||
from torch.utils.data import Dataset
|
||
from torchvision import transforms
|
||
|
||
from config import AUGMENT_CONFIG
|
||
|
||
|
||
# ============================================================
# Augmentation / inference transform factory functions
# ============================================================
|
||
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmenting transform pipeline for single-channel training.

    Augmentation strengths come from the module-level ``AUGMENT_CONFIG``
    mapping (keys ``degrees``, ``translate``, ``scale``, ``brightness``,
    ``contrast``, ``blur_kernel``, ``blur_sigma``, ``erasing_prob`` and
    ``erasing_scale``).

    Args:
        img_h: Target height handed to ``Resize``.
        img_w: Target width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that converts to grayscale, resizes, applies
        random affine / colour jitter / Gaussian blur, converts to a tensor,
        normalizes with mean 0.5 and std 0.5, and finally applies random
        erasing.
    """
    cfg = AUGMENT_CONFIG

    # PIL-level stages: deterministic preprocessing first, then augmentation.
    stages = [
        transforms.Grayscale(),
        transforms.Resize((img_h, img_w)),
    ]
    stages.append(
        transforms.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
        )
    )
    stages.append(
        transforms.ColorJitter(
            brightness=cfg["brightness"],
            contrast=cfg["contrast"],
        )
    )
    stages.append(
        transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    )

    # Tensor-level stages: RandomErasing runs after ToTensor / Normalize.
    stages.append(transforms.ToTensor())
    stages.append(transforms.Normalize([0.5], [0.5]))
    stages.append(
        transforms.RandomErasing(p=cfg["erasing_prob"], scale=cfg["erasing_scale"])
    )
    return transforms.Compose(stages)
|
||
|
||
|
||
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the deterministic transform for validation / inference.

    No augmentation is applied.

    Args:
        img_h: Target height handed to ``Resize``.
        img_w: Target width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that converts to grayscale, resizes, converts
        to a tensor and normalizes with mean 0.5 and std 0.5.
    """
    to_gray = transforms.Grayscale()
    resize = transforms.Resize((img_h, img_w))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])
    return transforms.Compose([to_gray, resize, to_tensor, normalize])
|
||
|
||
|
||
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmenting transform pipeline for RGB-model training.

    Same stages as the single-channel training pipeline but without the
    grayscale conversion, and with three-channel normalization. Augmentation
    strengths come from the module-level ``AUGMENT_CONFIG`` mapping.

    Args:
        img_h: Target height handed to ``Resize``.
        img_w: Target width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that resizes, applies random affine / colour
        jitter / Gaussian blur, converts to a tensor, normalizes each channel
        with mean 0.5 and std 0.5, and finally applies random erasing.
    """
    cfg = AUGMENT_CONFIG

    stages = [transforms.Resize((img_h, img_w))]
    stages.append(
        transforms.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
        )
    )
    stages.append(
        transforms.ColorJitter(
            brightness=cfg["brightness"],
            contrast=cfg["contrast"],
        )
    )
    stages.append(
        transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    )

    # Tensor-level stages: RandomErasing runs after ToTensor / Normalize.
    stages.append(transforms.ToTensor())
    stages.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    stages.append(
        transforms.RandomErasing(p=cfg["erasing_prob"], scale=cfg["erasing_scale"])
    )
    return transforms.Compose(stages)
|
||
|
||
|
||
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the deterministic RGB transform for validation / inference.

    Args:
        img_h: Target height handed to ``Resize``.
        img_w: Target width handed to ``Resize``.

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor and
        normalizes each of the three channels with mean 0.5 and std 0.5.
    """
    resize = transforms.Resize((img_h, img_w))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    return transforms.Compose([resize, to_tensor, normalize])
|
||
|
||
|
||
# ============================================================
# Dataset for CRNN / CTC recognition
# ============================================================
|
||
class CRNNDataset(Dataset):
    """Dataset for CRNN/CTC text recognition.

    Collects ``{label}_{suffix}.png`` files from one or more directories.
    The label is the part of the file stem before the last underscore and is
    encoded into a list of integer indices used as the CTC target.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """Index every ``*.png`` file found directly inside ``dirs``.

        Args:
            dirs: Data directories; the ``*.png`` files of all of them are
                merged. Directories that do not exist are skipped silently.
            chars: Character set string (without the CTC blank).
            transform: Optional preprocessing / augmentation applied to each
                loaded image.
        """
        self.chars = chars
        # Index 0 is reserved for the CTC blank, so characters start at 1.
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}
        self.transform = transform

        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                # "A3B8_000001" -> "A3B8"; only the last "_" splits, so the
                # label itself may contain underscores.
                label = f.stem.rsplit("_", 1)[0]
                self.samples.append((str(f), label))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Position in ``self.samples``.

        Returns:
            ``(img, target, label)`` where ``img`` is the RGB image after
            ``transform`` (or the PIL image itself when no transform is set),
            ``target`` is the list of character indices and ``label`` is the
            raw label text. Characters missing from the character set are
            dropped from ``target`` with a warning.
        """
        path, label = self.samples[idx]
        # convert() returns an independent image, so the file handle can be
        # released right away, also when decoding fails.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        # Encode the label as a sequence of integer indices.
        target = []
        for c in label:
            if c in self.char_to_idx:
                target.append(self.char_to_idx[c])
            else:
                warnings.warn(
                    f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
                    stacklevel=2,
                )
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Collate ``(img, target, label)`` samples into a batch.

        Args:
            batch: Sequence of items as returned by ``__getitem__``; the
                images must be tensors of identical shape.

        Returns:
            ``(images, targets_flat, target_lengths, labels)``: the images
            stacked along a new first dimension, all targets concatenated
            into one 1-D int32 tensor, the per-sample target lengths as an
            int32 tensor, and the label texts as a list.
        """
        import torch

        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        target_lengths = torch.tensor(
            [len(t) for t in targets], dtype=torch.int32
        )
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)
|
||
|
||
|
||
# ============================================================
# Dataset for the classifier
# ============================================================
|
||
class CaptchaDataset(Dataset):
    """Dataset for classifier training.

    Each sub-directory of the root is named after a class (e.g. "normal",
    "math", "3d_text"), and every ``.png`` file inside it belongs to that
    class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """Index the ``*.png`` files of every class directory.

        Args:
            root_dir: Root directory containing one sub-folder per class.
            class_names: Class names; the position in the list is the label
                index. Classes whose folder does not exist are skipped
                silently.
            transform: Optional preprocessing / augmentation applied to each
                loaded image.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform

        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            cls_idx = self.class_to_idx[cls_name]
            for f in sorted(cls_dir.glob("*.png")):
                self.samples.append((str(f), cls_idx))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Position in ``self.samples``.

        Returns:
            ``(img, label)`` where ``img`` is the RGB image after
            ``transform`` (or the PIL image itself when no transform is set)
            and ``label`` is the integer class index.
        """
        path, label = self.samples[idx]
        # convert() returns an independent image, so the file handle can be
        # released right away, also when decoding fails.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
|
||
|
||
|
||
# ============================================================
# Dataset for regression models
# ============================================================
|
||
class RegressionDataset(Dataset):
    """Dataset for regression models (3d_rotate / 3d_slider).

    Reads ``{value}_{suffix}.png`` files from the given directories, parses
    ``value`` as a float and normalizes it to [0, 1].
    """

    def __init__(
        self,
        dirs: list[str | Path],
        label_range: tuple[float, float],
        transform: transforms.Compose | None = None,
    ):
        """Index every ``*.png`` file whose label parses as a finite number.

        Args:
            dirs: Data directories. Directories that do not exist are skipped
                silently.
            label_range: ``(min_val, max_val)`` range of the raw labels.
                Values outside the range are clamped to 0 or 1 after
                normalization.
            transform: Optional preprocessing / augmentation applied to each
                loaded image.
        """
        import math

        self.label_range = label_range
        self.lo, self.hi = label_range
        self.transform = transform

        # Floor the span so a degenerate range (hi <= lo) cannot divide by
        # zero or flip the sign.
        span = max(self.hi - self.lo, 1e-6)

        self.samples: list[tuple[str, float]] = []  # (file path, normalized label)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                raw_label = f.stem.rsplit("_", 1)[0]
                try:
                    value = float(raw_label)
                except ValueError:
                    continue
                # float() also accepts "nan" / "inf"; NaN would slip through
                # the clamp below and come out as 1.0, so treat non-finite
                # labels like unparseable ones.
                if not math.isfinite(value):
                    continue
                # Normalize to [0, 1].
                norm = (value - self.lo) / span
                norm = max(0.0, min(1.0, norm))
                self.samples.append((str(f), norm))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Position in ``self.samples``.

        Returns:
            ``(img, target)`` where ``img`` is the RGB image after
            ``transform`` (or the PIL image itself when no transform is set)
            and ``target`` is a float32 tensor of shape ``(1,)`` holding the
            normalized label.
        """
        import torch

        path, label = self.samples[idx]
        # convert() returns an independent image, so the file handle can be
        # released right away, also when decoding fails.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor([label], dtype=torch.float32)
|
||
|
||
|
||
# ============================================================
# Dataset for the rotation solver (sin/cos encoding)
# ============================================================
|
||
class RotateSolverDataset(Dataset):
    """Dataset for the rotation solver.

    Reads ``{angle}_{suffix}.png`` files from the given directories and
    converts the angle (in degrees) into a ``(sin θ, cos θ)`` target.
    Images are loaded as RGB; no grayscale conversion is done here.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        transform: transforms.Compose | None = None,
    ):
        """Index every ``*.png`` file whose label parses as a finite angle.

        Args:
            dirs: Data directories. Directories that do not exist are skipped
                silently.
            transform: Optional preprocessing / augmentation (RGB) applied to
                each loaded image.
        """
        import math

        self.transform = transform
        self.samples: list[tuple[str, float, float]] = []  # (path, sin, cos)

        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                raw_label = f.stem.rsplit("_", 1)[0]
                try:
                    angle = float(raw_label)
                except ValueError:
                    continue
                # float() also accepts "inf" / "nan": math.sin(inf) raises
                # ValueError and NaN would yield NaN targets, so skip them
                # like any other unparseable label.
                if not math.isfinite(angle):
                    continue
                rad = math.radians(angle)
                self.samples.append((str(f), math.sin(rad), math.cos(rad)))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Args:
            idx: Position in ``self.samples``.

        Returns:
            ``(img, target)`` where ``img`` is the RGB image after
            ``transform`` (or the PIL image itself when no transform is set)
            and ``target`` is a float32 tensor ``[sin θ, cos θ]``.
        """
        import torch

        path, sin_val, cos_val = self.samples[idx]
        # convert() returns an independent image, so the file handle can be
        # released right away, also when decoding fails.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
|
||
|
||
|
||
class FunCaptchaChallengeDataset(Dataset):
    """Dataset of FunCaptcha challenge images.

    The input is the whole challenge image; the label in the file name is the
    index of the correct candidate:
    `{answer_index}_{anything}.png/jpg/jpeg`

    Each sample is cut into:
    - `candidates`: (K, C, H, W)
    - `reference`: (C, H, W)
    - `answer_idx`: scalar LongTensor
    """

    def __init__(
        self,
        dirs: list[str | Path],
        task_config: dict,
        transform: transforms.Compose | None = None,
    ):
        """Index the challenge images with a valid answer index.

        Args:
            dirs: Data directories. Directories that do not exist are skipped
                silently. Within each directory files are collected per
                extension (``.png``, then ``.jpg``, then ``.jpeg``), sorted
                inside each extension group.
            task_config: Mapping with the keys ``"tile_size"`` (``(width,
                height)`` of one candidate tile), ``"reference_box"`` (crop
                box of the reference image), ``"num_candidates"`` and,
                optionally, ``"answer_index_base"`` (value of the first
                answer index in file names, default 0).
            transform: Optional transform applied to every cropped tile. When
                omitted, tiles are converted with ``transforms.ToTensor()``
                so that they can still be stacked.
        """
        self.transform = transform
        self.tile_w, self.tile_h = task_config["tile_size"]
        self.reference_box = tuple(task_config["reference_box"])
        self.num_candidates = int(task_config["num_candidates"])
        self.answer_index_base = int(task_config.get("answer_index_base", 0))

        self.samples: list[tuple[str, int]] = []  # (path, 0-based answer_idx)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue

            for pattern in ("*.png", "*.jpg", "*.jpeg"):
                for f in sorted(d.glob(pattern)):
                    raw_label = f.stem.rsplit("_", 1)[0]
                    try:
                        answer_idx = int(raw_label) - self.answer_index_base
                    except ValueError:
                        continue

                    # Skip labels that point outside the candidate strip.
                    if not (0 <= answer_idx < self.num_candidates):
                        warnings.warn(
                            f"FunCaptcha 标签越界: file={f} label={raw_label} "
                            f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
                            stacklevel=2,
                        )
                        continue

                    self.samples.append((str(f), answer_idx))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one challenge and cut it into candidates and reference.

        Candidate ``i`` is the tile whose left edge is at ``i * tile_w`` and
        which spans rows ``0 .. tile_h``; the reference is cropped with
        ``reference_box``.

        Args:
            idx: Position in ``self.samples``.

        Returns:
            ``(candidates, reference, answer_idx)``: the transformed
            candidate tiles stacked along a new first dimension, the
            transformed reference tile, and the 0-based answer index as a
            scalar ``torch.long`` tensor.
        """
        import torch

        path, answer_idx = self.samples[idx]
        # torch.stack needs tensors, so fall back to a plain tensor
        # conversion when no transform was configured.
        convert = self.transform if self.transform else transforms.ToTensor()

        # crop() is lazy, so every tile is materialised (by the transform)
        # while the source image is still open.
        with Image.open(path) as raw:
            image = raw.convert("RGB")

        candidates = []
        for i in range(self.num_candidates):
            left = i * self.tile_w
            box = (left, 0, left + self.tile_w, self.tile_h)
            candidates.append(convert(image.crop(box)))

        reference = convert(image.crop(self.reference_box))

        return (
            torch.stack(candidates, dim=0),
            reference,
            torch.tensor(answer_idx, dtype=torch.long),
        )
|