Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -55,6 +55,33 @@ def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
])
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
"""RGB 模型训练时数据增强 transform。"""
aug = AUGMENT_CONFIG
return transforms.Compose([
transforms.Resize((img_h, img_w)),
transforms.RandomAffine(
degrees=aug["degrees"],
translate=aug["translate"],
scale=aug["scale"],
),
transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
])
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
"""RGB 模型验证 / 推理时 transform。"""
return transforms.Compose([
transforms.Resize((img_h, img_w)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
@@ -276,3 +303,82 @@ class RotateSolverDataset(Dataset):
img = self.transform(img)
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
class FunCaptchaChallengeDataset(Dataset):
"""
FunCaptcha 专项 challenge 数据集。
输入为整张 challenge 图片,文件名标签表示正确候选索引:
`{answer_index}_{anything}.png/jpg/jpeg`
每个样本会被裁成:
- `candidates`: (K, C, H, W)
- `reference`: (C, H, W)
- `answer_idx`: LongTensor 标量
"""
def __init__(
self,
dirs: list[str | Path],
task_config: dict,
transform: transforms.Compose | None = None,
):
import warnings
self.transform = transform
self.tile_w, self.tile_h = task_config["tile_size"]
self.reference_box = tuple(task_config["reference_box"])
self.num_candidates = int(task_config["num_candidates"])
self.answer_index_base = int(task_config.get("answer_index_base", 0))
self.samples: list[tuple[str, int]] = [] # (路径, 0-based answer_idx)
for d in dirs:
d = Path(d)
if not d.exists():
continue
for pattern in ("*.png", "*.jpg", "*.jpeg"):
for f in sorted(d.glob(pattern)):
raw_label = f.stem.rsplit("_", 1)[0]
try:
answer_idx = int(raw_label) - self.answer_index_base
except ValueError:
continue
if not (0 <= answer_idx < self.num_candidates):
warnings.warn(
f"FunCaptcha 标签越界: file={f} label={raw_label} "
f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
stacklevel=2,
)
continue
self.samples.append((str(f), answer_idx))
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
import torch
path, answer_idx = self.samples[idx]
image = Image.open(path).convert("RGB")
candidates = []
for i in range(self.num_candidates):
left = i * self.tile_w
box = (left, 0, left + self.tile_w, self.tile_h)
candidate = image.crop(box)
if self.transform:
candidate = self.transform(candidate)
candidates.append(candidate)
reference = image.crop(self.reference_box)
if self.transform:
reference = self.transform(reference)
return (
torch.stack(candidates, dim=0),
reference,
torch.tensor(answer_idx, dtype=torch.long),
)