Align task API and add FunCaptcha support
This commit is contained in:
@@ -55,6 +55,33 @@ def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
])
|
||||
|
||||
|
||||
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the training-time augmentation pipeline for RGB models.

    Augmentation strengths are read from the module-level ``AUGMENT_CONFIG``
    mapping (keys: ``degrees``, ``translate``, ``scale``, ``brightness``,
    ``contrast``, ``blur_kernel``, ``blur_sigma``, ``erasing_prob``,
    ``erasing_scale``).

    Args:
        img_h: Target image height passed to ``Resize``.
        img_w: Target image width passed to ``Resize``.

    Returns:
        A ``transforms.Compose`` applying, in order: resize, random affine,
        color jitter, Gaussian blur, tensor conversion, 3-channel
        normalization (mean 0.5 / std 0.5), and random erasing.
    """
    cfg = AUGMENT_CONFIG

    resize = transforms.Resize((img_h, img_w))
    affine = transforms.RandomAffine(
        degrees=cfg["degrees"],
        translate=cfg["translate"],
        scale=cfg["scale"],
    )
    jitter = transforms.ColorJitter(
        brightness=cfg["brightness"],
        contrast=cfg["contrast"],
    )
    blur = transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"])
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    # Random erasing stays last: it runs on the already-normalized tensor.
    erase = transforms.RandomErasing(
        p=cfg["erasing_prob"],
        scale=cfg["erasing_scale"],
    )

    return transforms.Compose(
        [resize, affine, jitter, blur, to_tensor, normalize, erase]
    )
|
||||
|
||||
|
||||
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Build the deterministic validation / inference pipeline for RGB models.

    Args:
        img_h: Target image height passed to ``Resize``.
        img_w: Target image width passed to ``Resize``.

    Returns:
        A ``transforms.Compose`` applying, in order: resize, tensor
        conversion, and 3-channel normalization (mean 0.5 / std 0.5).
        No random augmentation is included.
    """
    resize = transforms.Resize((img_h, img_w))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    return transforms.Compose([resize, to_tensor, normalize])
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CRNN / CTC 识别用数据集
|
||||
# ============================================================
|
||||
@@ -276,3 +303,82 @@ class RotateSolverDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
|
||||
|
||||
|
||||
class FunCaptchaChallengeDataset(Dataset):
    """
    Dataset of whole FunCaptcha challenge images.

    Each sample is one full challenge image whose file name encodes the
    index of the correct candidate:

        `{answer_index}_{anything}.png/jpg/jpeg`

    Only the text before the FIRST underscore is parsed as the label, so the
    `{anything}` part may itself contain underscores.

    Every sample is cropped into (shapes assume `transform` returns a
    (C, H, W) tensor per crop):

    - `candidates`: (K, C, H, W)
    - `reference`: (C, H, W)
    - `answer_idx`: scalar LongTensor
    """

    # Glob patterns scanned, in this order, inside every directory.
    _IMAGE_PATTERNS = ("*.png", "*.jpg", "*.jpeg")

    def __init__(
        self,
        dirs: list[str | Path],
        task_config: dict,
        transform: transforms.Compose | None = None,
    ):
        """
        Index all labelled challenge images found in `dirs`.

        Args:
            dirs: Directories to scan (non-recursively). Directories that do
                not exist are skipped.
            task_config: Mapping with the keys
                `tile_size` (unpacked as ``(tile_w, tile_h)``),
                `reference_box` (crop box handed to ``Image.crop``),
                `num_candidates` (number of candidate tiles, K) and the
                optional `answer_index_base` (value subtracted from the file
                name label to obtain a 0-based index; defaults to 0).
            transform: Optional callable applied to every cropped candidate
                and to the reference crop. ``__getitem__`` stacks the
                candidates with ``torch.stack``, so in practice the transform
                must return tensors.

        Files whose label is not an integer are skipped silently; files whose
        label falls outside the valid candidate range are skipped with a
        warning.
        """
        import warnings

        self.transform = transform
        self.tile_w, self.tile_h = task_config["tile_size"]
        self.reference_box = tuple(task_config["reference_box"])
        self.num_candidates = int(task_config["num_candidates"])
        self.answer_index_base = int(task_config.get("answer_index_base", 0))

        self.samples: list[tuple[str, int]] = []  # (path, 0-based answer_idx)
        for d in dirs:
            d = Path(d)
            if not d.exists():
                continue

            for pattern in self._IMAGE_PATTERNS:
                for f in sorted(d.glob(pattern)):
                    # Take the text before the FIRST underscore. Splitting
                    # from the right would reject "3_a_b.png" and, because
                    # int() accepts "_" as a digit separator, would read
                    # "1_0_x.png" as label 10 instead of 1.
                    raw_label = f.stem.split("_", 1)[0]
                    try:
                        answer_idx = int(raw_label) - self.answer_index_base
                    except ValueError:
                        continue

                    if not (0 <= answer_idx < self.num_candidates):
                        warnings.warn(
                            f"FunCaptcha 标签越界: file={f} label={raw_label} "
                            f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
                            stacklevel=2,
                        )
                        continue

                    self.samples.append((str(f), answer_idx))

    def __len__(self) -> int:
        """Return the number of indexed challenge images."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Load one challenge image and split it into candidates and reference.

        Candidate ``i`` is the tile ``(i * tile_w, 0, (i + 1) * tile_w, tile_h)``,
        i.e. the candidates are read left to right along the top edge of the
        image. The reference is the crop given by `reference_box`.

        Returns:
            A tuple ``(candidates, reference, answer_idx)`` where
            `candidates` is the stacked candidate crops, `reference` is the
            (optionally transformed) reference crop and `answer_idx` is a
            scalar ``torch.long`` tensor.
        """
        import torch

        path, answer_idx = self.samples[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as the RGB copy has been decoded.
        with Image.open(path) as raw:
            image = raw.convert("RGB")

        # Transform order (candidates left to right, then the reference) is
        # kept fixed so random transforms consume RNG state deterministically.
        candidates = []
        for i in range(self.num_candidates):
            left = i * self.tile_w
            candidate = image.crop((left, 0, left + self.tile_w, self.tile_h))
            if self.transform is not None:
                candidate = self.transform(candidate)
            candidates.append(candidate)

        reference = image.crop(self.reference_box)
        if self.transform is not None:
            reference = self.transform(reference)

        return (
            torch.stack(candidates, dim=0),
            reference,
            torch.tensor(answer_idx, dtype=torch.long),
        )
|
||||
|
||||
Reference in New Issue
Block a user