Files
CaptchBreaker/training/dataset.py

385 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
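Generic Dataset classes.
This module provides the following datasets:
- CaptchaDataset: classifier training (image -> class label)
- CRNNDataset: CRNN/CTC recognition training (image -> encoded character sequence)
- RegressionDataset: regression-model training (image -> numeric label in [0, 1])
- RotateSolverDataset: rotation-solver training (image -> (sin, cos) of the angle)
- FunCaptchaChallengeDataset: FunCaptcha challenge training (image -> candidate tiles, reference tile, answer index)
File-name convention: {label}_{anything}.png
- Classifier: label may be any characters; the containing sub-directory name is the class
- Recognizer: label is the annotated text (e.g. "A3B8", "3+8")
- Regressor: label is a number (e.g. "135", "87")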
"""
import os
import warnings
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from config import AUGMENT_CONFIG
# ============================================================
# 增强 / 推理 transform 工厂函数
# ============================================================
def build_train_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmenting transform pipeline for grayscale training images.

    All augmentation strengths are read from the module-level
    ``AUGMENT_CONFIG`` mapping.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that applies, in order: grayscale conversion,
        resize, random affine, brightness/contrast jitter, Gaussian blur,
        tensor conversion, normalisation with mean 0.5 / std 0.5, and random
        erasing.
    """
    cfg = AUGMENT_CONFIG
    to_gray = transforms.Grayscale()
    resize = transforms.Resize((img_h, img_w))
    affine = transforms.RandomAffine(
        degrees=cfg["degrees"],
        translate=cfg["translate"],
        scale=cfg["scale"],
    )
    jitter = transforms.ColorJitter(
        brightness=cfg["brightness"],
        contrast=cfg["contrast"],
    )
    blur = transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5], [0.5])
    # Random erasing is the last step, after the image has become a tensor.
    erase = transforms.RandomErasing(
        p=cfg["erasing_prob"],
        scale=cfg["erasing_scale"],
    )
    return transforms.Compose(
        [to_gray, resize, affine, jitter, blur, to_tensor, normalize, erase]
    )
def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmentation-free transform for grayscale validation / inference.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that applies, in order: grayscale conversion,
        resize, tensor conversion, and normalisation with mean 0.5 / std 0.5.
    """
    steps = [transforms.Grayscale()]
    steps.append(transforms.Resize((img_h, img_w)))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.5], [0.5]))
    return transforms.Compose(steps)
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmenting transform pipeline for RGB-model training images.

    Same steps as the grayscale training pipeline minus the grayscale
    conversion, with three-channel normalisation. Augmentation strengths are
    read from the module-level ``AUGMENT_CONFIG`` mapping.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that applies, in order: resize, random affine,
        brightness/contrast jitter, Gaussian blur, tensor conversion,
        per-channel normalisation with mean 0.5 / std 0.5, and random erasing.
    """
    cfg = AUGMENT_CONFIG
    resize = transforms.Resize((img_h, img_w))
    affine = transforms.RandomAffine(
        degrees=cfg["degrees"],
        translate=cfg["translate"],
        scale=cfg["scale"],
    )
    jitter = transforms.ColorJitter(
        brightness=cfg["brightness"],
        contrast=cfg["contrast"],
    )
    blur = transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    # Random erasing is the last step, after the image has become a tensor.
    erase = transforms.RandomErasing(
        p=cfg["erasing_prob"],
        scale=cfg["erasing_scale"],
    )
    return transforms.Compose(
        [resize, affine, jitter, blur, to_tensor, normalize, erase]
    )
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the augmentation-free transform for RGB-model validation / inference.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` that applies, in order: resize, tensor
        conversion, and per-channel normalisation with mean 0.5 / std 0.5.
    """
    steps = [transforms.Resize((img_h, img_w))]
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    return transforms.Compose(steps)
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
class CRNNDataset(Dataset):
    """Dataset for CRNN / CTC text recognition.

    Collects ``{label}_{suffix}.png`` files from the given directories and
    encodes every label as a list of integer indices (the CTC target).
    Index 0 is reserved for the CTC blank, so characters map to
    ``1..len(chars)``.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        chars: str,
        transform: transforms.Compose | None = None,
    ):
        """Index all ``.png`` files found directly inside ``dirs``.

        Args:
            dirs: Data directories; the ``.png`` files of all of them are
                merged. Directories that do not exist are skipped.
            chars: Character set as a string (without the CTC blank).
            transform: Optional preprocessing / augmentation applied to each
                image in ``__getitem__``.
        """
        self.chars = chars
        self.char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # blank=0
        self.transform = transform
        self.samples: list[tuple[str, str]] = []  # (file path, label text)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                # "A3B8_000001" -> "A3B8": everything before the last "_"
                label = f.stem.rsplit("_", 1)[0]
                self.samples.append((str(f), label))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load one sample.

        Returns:
            ``(img, target, label)`` where ``img`` is the RGB image (after
            ``transform`` if one was given), ``target`` is the list of
            character indices and ``label`` is the raw label text. Characters
            missing from the character set are dropped from ``target`` with a
            warning.
        """
        path, label = self.samples[idx]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        # Encode the label as an integer sequence.
        target = []
        for c in label:
            if c in self.char_to_idx:
                target.append(self.char_to_idx[c])
            else:
                warnings.warn(
                    f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
                    stacklevel=2,
                )
        return img, target, label

    @staticmethod
    def collate_fn(batch):
        """Collate samples into a CTC-style batch.

        Args:
            batch: Iterable of ``(img, target, label)`` tuples as produced by
                ``__getitem__``; the images must be tensors so they can be
                stacked.

        Returns:
            ``(images, targets_flat, target_lengths, labels)``: the stacked
            images, all targets concatenated into one 1-D int32 tensor, the
            per-sample target lengths as an int32 tensor, and the label
            strings as a list.
        """
        import torch

        images, targets, labels = zip(*batch)
        images = torch.stack(images, 0)
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.IntTensor(list) constructor; dtype (int32) is unchanged.
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.int32)
        targets_flat = torch.tensor(
            [idx for t in targets for idx in t], dtype=torch.int32
        )
        return images, targets_flat, target_lengths, list(labels)
# ============================================================
# 分类器用数据集
# ============================================================
class CaptchaDataset(Dataset):
    """Dataset for training the captcha-type classifier.

    Each sub-directory of the root is named after a class (e.g. "normal",
    "math", "3d_text"); every ``.png`` file inside it belongs to that class.
    """

    def __init__(
        self,
        root_dir: str | Path,
        class_names: list[str],
        transform: transforms.Compose | None = None,
    ):
        """Index the ``.png`` files of every class directory.

        Args:
            root_dir: Root directory containing one sub-folder per class name.
            class_names: Class names; a name's position in this list is its
                label index. Classes whose folder does not exist are skipped.
            transform: Optional preprocessing / augmentation applied to each
                image in ``__getitem__``.
        """
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}
        self.transform = transform
        self.samples: list[tuple[str, int]] = []  # (file path, class index)
        root = Path(root_dir)
        for cls_name in class_names:
            cls_dir = root / cls_name
            if not cls_dir.exists():
                continue
            cls_idx = self.class_to_idx[cls_name]  # same for the whole folder
            self.samples.extend(
                (str(f), cls_idx) for f in sorted(cls_dir.glob("*.png"))
            )

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(img, label)``: the RGB image (transformed if a transform
        was given) and its integer class index."""
        path, label = self.samples[idx]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
# ============================================================
# 回归模型用数据集
# ============================================================
class RegressionDataset(Dataset):
    """Dataset for the regression models (3d_rotate / 3d_slider).

    Reads ``{value}_{suffix}.png`` files, parses ``value`` as a float and
    min-max normalises it into [0, 1] using ``label_range``.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        label_range: tuple[float, float],
        transform: transforms.Compose | None = None,
    ):
        """Index all ``.png`` files whose label is a finite number.

        Args:
            dirs: Data directories; directories that do not exist are skipped.
            label_range: ``(min_val, max_val)`` range of the raw labels.
            transform: Optional preprocessing / augmentation applied to each
                image in ``__getitem__``.

        Files whose label cannot be parsed as a float, or parses to NaN or
        +/-inf, are skipped. Values outside ``label_range`` are clamped to
        0.0 / 1.0.
        """
        import math

        self.label_range = label_range
        self.lo, self.hi = label_range
        self.transform = transform
        self.samples: list[tuple[str, float]] = []  # (file path, normalised label)

        # Loop-invariant denominator; the floor avoids dividing by zero when
        # hi == lo.
        span = max(self.hi - self.lo, 1e-6)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                raw_label = f.stem.rsplit("_", 1)[0]
                try:
                    value = float(raw_label)
                except ValueError:
                    continue
                # float() also accepts "nan" / "inf"; clamping would silently
                # turn those into 0.0 or 1.0, so treat them as unparseable.
                if not math.isfinite(value):
                    continue
                # Normalise into [0, 1].
                norm = (value - self.lo) / span
                norm = max(0.0, min(1.0, norm))
                self.samples.append((str(f), norm))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(img, target)``: the RGB image (transformed if a transform
        was given) and a float32 tensor of shape ``(1,)`` holding the
        normalised label."""
        import torch

        path, label = self.samples[idx]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor([label], dtype=torch.float32)
# ============================================================
# 旋转求解器用数据集 (sin/cos 编码)
# ============================================================
class RotateSolverDataset(Dataset):
    """Dataset for the rotation solver.

    Reads ``{angle}_{suffix}.png`` files and turns the angle (in degrees)
    into a ``(sin, cos)`` regression target. Images are loaded as RGB; no
    grayscale conversion is done here.
    """

    def __init__(
        self,
        dirs: list[str | Path],
        transform: transforms.Compose | None = None,
    ):
        """Index all ``.png`` files whose label is a finite angle.

        Args:
            dirs: Data directories; directories that do not exist are skipped.
            transform: Optional preprocessing / augmentation (RGB) applied to
                each image in ``__getitem__``.

        Files whose label cannot be parsed as a float, or parses to NaN or
        +/-inf, are skipped.
        """
        import math

        self.transform = transform
        self.samples: list[tuple[str, float, float]] = []  # (path, sin, cos)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for f in sorted(d.glob("*.png")):
                raw_label = f.stem.rsplit("_", 1)[0]
                try:
                    angle = float(raw_label)
                except ValueError:
                    continue
                # float() also accepts "inf" / "nan": math.sin(inf) raises
                # ValueError and NaN would yield NaN targets, so skip both.
                if not math.isfinite(angle):
                    continue
                rad = math.radians(angle)
                self.samples.append((str(f), math.sin(rad), math.cos(rad)))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(img, target)``: the RGB image (transformed if a transform
        was given) and a float32 tensor ``[sin, cos]`` of shape ``(2,)``."""
        import torch

        path, sin_val, cos_val = self.samples[idx]
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
class FunCaptchaChallengeDataset(Dataset):
    """Dataset for FunCaptcha challenge images.

    The input is the whole challenge image; the file-name label is the index
    of the correct candidate: ``{answer_index}_{anything}.png/jpg/jpeg``.

    Each sample is cut into:
    - ``candidates``: (K, C, H, W)
    - ``reference``: (C, H, W)
    - ``answer_idx``: scalar LongTensor
    """

    def __init__(
        self,
        dirs: list[str | Path],
        task_config: dict,
        transform: transforms.Compose | None = None,
    ):
        """Index all challenge images with a valid answer label.

        Args:
            dirs: Data directories; directories that do not exist are skipped.
            task_config: Mapping with the keys ``"tile_size"`` (unpacked as
                ``(tile_w, tile_h)``), ``"reference_box"`` (crop box of the
                reference tile), ``"num_candidates"`` and, optionally,
                ``"answer_index_base"`` (default 0), which is subtracted from
                the file-name label to obtain a 0-based index.
            transform: Optional per-tile preprocessing / augmentation. When
                omitted, tiles are converted with ``transforms.ToTensor()``.

        Files whose label is not an integer are skipped silently; labels
        outside the valid candidate range are skipped with a warning.
        """
        self.transform = transform
        self.tile_w, self.tile_h = task_config["tile_size"]
        self.reference_box = tuple(task_config["reference_box"])
        self.num_candidates = int(task_config["num_candidates"])
        self.answer_index_base = int(task_config.get("answer_index_base", 0))
        self.samples: list[tuple[str, int]] = []  # (path, 0-based answer_idx)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue
            for pattern in ("*.png", "*.jpg", "*.jpeg"):
                for f in sorted(d.glob(pattern)):
                    raw_label = f.stem.rsplit("_", 1)[0]
                    try:
                        answer_idx = int(raw_label) - self.answer_index_base
                    except ValueError:
                        continue
                    if not (0 <= answer_idx < self.num_candidates):
                        warnings.warn(
                            f"FunCaptcha 标签越界: file={f} label={raw_label} "
                            f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
                            stacklevel=2,
                        )
                        continue
                    self.samples.append((str(f), answer_idx))

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Cut one challenge image into candidate tiles and the reference tile.

        Candidate ``i`` is the box ``(i * tile_w, 0, (i + 1) * tile_w, tile_h)``;
        the reference tile is cropped with ``reference_box``.

        Returns:
            ``(candidates, reference, answer_idx)``: the stacked candidate
            tiles, the reference tile, and the 0-based answer index as a
            scalar long tensor.
        """
        import torch

        path, answer_idx = self.samples[idx]
        # torch.stack below requires tensors, so the default transform=None
        # used to raise TypeError on PIL tiles; fall back to plain ToTensor.
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        # Candidates are transformed in index order, then the reference, so
        # random transforms consume randomness in a fixed order.
        candidates = []
        for i in range(self.num_candidates):
            left = i * self.tile_w
            box = (left, 0, left + self.tile_w, self.tile_h)
            candidates.append(to_tensor(image.crop(box)))
        reference = to_tensor(image.crop(self.reference_box))
        return (
            torch.stack(candidates, dim=0),
            reference,
            torch.tensor(answer_idx, dtype=torch.long),
        )