Expand 3D captcha into three subtypes: 3d_text, 3d_rotate, 3d_slider
Split the single "3d" captcha type into three independent expert models: - 3d_text: 3D perspective text OCR (renamed from old "3d", CTC-based ThreeDCNN) - 3d_rotate: rotation angle regression (new RegressionCNN, circular loss) - 3d_slider: slider offset regression (new RegressionCNN, SmoothL1 loss) CAPTCHA_TYPES expanded from 3 to 5 classes. Classifier samples updated to 50000 (10000 per class). New generators, model, dataset, training utilities, and full pipeline/export/CLI support for all subtypes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
"""
|
||||
通用 Dataset 类
|
||||
|
||||
提供两种数据集:
|
||||
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
|
||||
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
|
||||
提供三种数据集:
|
||||
- CaptchaDataset: 用于分类器训练 (图片 → 类别标签)
|
||||
- CRNNDataset: 用于 CRNN/CTC 识别训练 (图片 → 字符序列编码)
|
||||
- RegressionDataset: 用于回归模型训练 (图片 → 数值标签 [0,1])
|
||||
|
||||
文件名格式约定: {label}_{任意}.png
|
||||
- 分类器: label 可为任意字符,所在子目录名即为类别
|
||||
- 识别器: label 即标注内容 (如 "A3B8" 或 "3+8")
|
||||
- 回归器: label 为数值 (如 "135" 或 "87")
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
@@ -98,7 +101,15 @@ class CRNNDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
|
||||
# 编码标签为整数序列
|
||||
target = [self.char_to_idx[c] for c in label if c in self.char_to_idx]
|
||||
target = []
|
||||
for c in label:
|
||||
if c in self.char_to_idx:
|
||||
target.append(self.char_to_idx[c])
|
||||
else:
|
||||
warnings.warn(
|
||||
f"标签 '{label}' 含字符集外字符 '{c}',已跳过 (文件: {path})",
|
||||
stacklevel=2,
|
||||
)
|
||||
return img, target, label
|
||||
|
||||
@staticmethod
|
||||
@@ -119,7 +130,7 @@ class CaptchaDataset(Dataset):
|
||||
"""
|
||||
分类器训练数据集。
|
||||
|
||||
每个子目录名为类别名 (如 "normal", "math", "3d"),
|
||||
每个子目录名为类别名 (如 "normal", "math", "3d_text"),
|
||||
目录内所有 .png 文件属于该类。
|
||||
"""
|
||||
|
||||
@@ -157,3 +168,59 @@ class CaptchaDataset(Dataset):
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, label
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 回归模型用数据集
|
||||
# ============================================================
|
||||
class RegressionDataset(Dataset):
|
||||
"""
|
||||
回归模型数据集 (3d_rotate / 3d_slider)。
|
||||
|
||||
从目录中读取 {value}_{xxx}.png 文件,
|
||||
将 value 解析为浮点数并归一化到 [0, 1]。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dirs: list[str | Path],
|
||||
label_range: tuple[float, float],
|
||||
transform: transforms.Compose | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dirs: 数据目录列表
|
||||
label_range: (min_val, max_val) 标签原始范围
|
||||
transform: 图片预处理/增强
|
||||
"""
|
||||
self.label_range = label_range
|
||||
self.lo, self.hi = label_range
|
||||
self.transform = transform
|
||||
|
||||
self.samples: list[tuple[str, float]] = [] # (文件路径, 归一化标签)
|
||||
for d in dirs:
|
||||
d = Path(d)
|
||||
if not d.exists():
|
||||
continue
|
||||
for f in sorted(d.glob("*.png")):
|
||||
raw_label = f.stem.rsplit("_", 1)[0]
|
||||
try:
|
||||
value = float(raw_label)
|
||||
except ValueError:
|
||||
continue
|
||||
# 归一化到 [0, 1]
|
||||
norm = (value - self.lo) / max(self.hi - self.lo, 1e-6)
|
||||
norm = max(0.0, min(1.0, norm))
|
||||
self.samples.append((str(f), norm))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
import torch
|
||||
path, label = self.samples[idx]
|
||||
img = Image.open(path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([label], dtype=torch.float32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user