Align task API and add FunCaptcha support
This commit is contained in:
@@ -12,4 +12,5 @@
|
||||
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
|
||||
- train_slide.py: 训练滑块缺口检测 (GapDetectorCNN)
|
||||
- train_rotate_solver.py: 训练旋转角度回归 (RotationRegressor)
|
||||
- train_funcaptcha_rollball.py: 训练 FunCaptcha 专项 Siamese 模型
|
||||
"""
|
||||
|
||||
226
training/data_fingerprint.py
Normal file
226
training/data_fingerprint.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
合成数据集指纹与清单辅助工具。
|
||||
|
||||
用于识别“样本数量足够但生成规则已变化”的情况,避免静默复用过期数据。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
MANIFEST_NAME = ".dataset_meta.json"
|
||||
|
||||
|
||||
def _stable_json(data: dict) -> str:
|
||||
return json.dumps(data, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _sha256_text(text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _source_hash(obj) -> str:
    """Hash the source code of *obj*, falling back to ``repr()`` when unavailable.

    Builtins and C-extension objects raise from ``inspect.getsource``; their
    repr is hashed instead so every object still yields a fingerprint.
    """
    try:
        text = inspect.getsource(obj)
    except (OSError, TypeError):
        text = repr(obj)
    return _sha256_text(text)
|
||||
|
||||
|
||||
def dataset_manifest_path(dataset_dir: str | Path) -> Path:
    """Location of the dataset manifest file inside *dataset_dir*."""
    return Path(dataset_dir).joinpath(MANIFEST_NAME)
|
||||
|
||||
|
||||
def dataset_spec_hash(spec: dict) -> str:
    """Stable SHA-256 fingerprint of a dataset spec dict."""
    canonical = _stable_json(spec)
    return _sha256_text(canonical)
|
||||
|
||||
|
||||
def build_dataset_spec(
    generator_cls,
    *,
    config_key: str,
    config_snapshot: dict,
) -> dict:
    """Build a dataset spec dict whose JSON form can be hashed stably.

    The spec ties together the generator's identity, a hash of its source
    code, and the configuration snapshot — so any change to generation
    rules produces a different fingerprint.
    """
    generator_name = f"{generator_cls.__module__}.{generator_cls.__name__}"
    spec = {
        "config_key": config_key,
        "generator": generator_name,
        "generator_source_hash": _source_hash(generator_cls),
        "config_snapshot": config_snapshot,
    }
    return spec
|
||||
|
||||
|
||||
def load_dataset_manifest(dataset_dir: str | Path) -> dict | None:
    """Read the manifest for *dataset_dir*; return ``None`` when it does not exist."""
    manifest_file = dataset_manifest_path(dataset_dir)
    if not manifest_file.exists():
        return None
    with manifest_file.open("r", encoding="utf-8") as fh:
        return json.load(fh)
|
||||
|
||||
|
||||
def write_dataset_manifest(
    dataset_dir: str | Path,
    *,
    spec: dict,
    sample_count: int,
    adopted_existing: bool,
) -> dict:
    """Persist a manifest describing the dataset currently on disk.

    Records the spec, its hash, the sample count, and whether pre-existing
    data was adopted as-is. Returns the manifest dict that was written.
    """
    manifest = {
        "version": 1,
        "spec": spec,
        "spec_hash": dataset_spec_hash(spec),
        "sample_count": sample_count,
        "adopted_existing": adopted_existing,
    }
    target = dataset_manifest_path(dataset_dir)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as fh:
        json.dump(manifest, fh, ensure_ascii=True, indent=2, sort_keys=True)
        fh.write("\n")
    return manifest
|
||||
|
||||
|
||||
def labels_cover_tokens(files: list[Path], required_tokens: tuple[str, ...]) -> bool:
    """Return True when every required token appears in at least one file's label.

    A file's label is its stem with the trailing ``_suffix`` removed
    (e.g. ``"1+2_0007.png"`` → label ``"1+2"``).
    """
    missing = set(required_tokens)
    for path in files:
        if not missing:
            break
        label = path.stem.rsplit("_", 1)[0]
        missing = {token for token in missing if token not in label}
    return not missing
|
||||
|
||||
|
||||
def _count_matches(count: int, *, exact_count: int | None, min_count: int | None) -> bool:
|
||||
if exact_count is not None and count != exact_count:
|
||||
return False
|
||||
if min_count is not None and count < min_count:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dataset_valid(
    files: list[Path],
    *,
    exact_count: int | None,
    min_count: int | None,
    validator: Callable[[list[Path]], bool] | None,
) -> bool:
    """True when the file count satisfies the constraints and the optional validator passes."""
    if not _count_matches(len(files), exact_count=exact_count, min_count=min_count):
        return False
    return True if validator is None else validator(files)
|
||||
|
||||
|
||||
def clear_generated_dataset(dataset_dir: str | Path) -> None:
    """Delete every generated PNG and the manifest under *dataset_dir*."""
    root = Path(dataset_dir)
    for png in root.glob("*.png"):
        png.unlink()
    manifest_file = dataset_manifest_path(root)
    if manifest_file.exists():
        manifest_file.unlink()
|
||||
|
||||
|
||||
def ensure_synthetic_dataset(
    dataset_dir: str | Path,
    *,
    generator_cls,
    spec: dict,
    gen_count: int,
    exact_count: int | None = None,
    min_count: int | None = None,
    validator: Callable[[list[Path]], bool] | None = None,
    adopt_if_missing: bool = False,
) -> dict:
    """
    Ensure the synthetic data on disk matches the current generation rules.

    The decision proceeds in three stages:
      1. If the on-disk manifest matches the current spec hash / sample
         count and the files pass the count + validator checks, reuse as-is.
      2. Otherwise, if there is no manifest at all but the existing files
         look valid and ``adopt_if_missing`` is set, adopt them and write a
         manifest (``adopted_existing=True``) without regenerating.
      3. Otherwise, wipe the directory and regenerate ``gen_count`` samples
         with ``generator_cls``, then re-validate and write a fresh manifest.

    Returns:
        {
            "manifest": dict,        # manifest now describing the dataset
            "sample_count": int,     # PNGs present after the call
            "refreshed": bool,       # True when data was regenerated
            "adopted": bool,         # True when existing data was adopted
        }

    Raises:
        RuntimeError: if the freshly generated dataset still fails validation.
    """
    dataset_dir = Path(dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    counts_ok = _count_matches(sample_count, exact_count=exact_count, min_count=min_count)
    # NOTE(review): _dataset_valid already re-checks the counts, so
    # validator_ok implies counts_ok; the separate flag is kept for clarity.
    validator_ok = _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    )
    manifest = load_dataset_manifest(dataset_dir)
    spec_hash = dataset_spec_hash(spec)

    # Manifest is trusted only when both the spec hash and the recorded
    # sample count still match what is on disk.
    manifest_ok = (
        manifest is not None
        and manifest.get("spec_hash") == spec_hash
        and manifest.get("sample_count") == sample_count
    )

    # Stage 1: everything consistent — reuse the existing dataset.
    if manifest_ok and counts_ok and validator_ok:
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": False,
        }

    # Stage 2: no manifest, but existing files look valid — adopt them.
    if manifest is None and adopt_if_missing and counts_ok and validator_ok:
        manifest = write_dataset_manifest(
            dataset_dir,
            spec=spec,
            sample_count=sample_count,
            adopted_existing=True,
        )
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": True,
        }

    # Stage 3: regenerate from scratch.
    clear_generated_dataset(dataset_dir)
    gen = generator_cls()
    gen.generate_dataset(gen_count, str(dataset_dir))

    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    if not _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    ):
        raise RuntimeError(
            f"生成后的数据集不符合要求: {dataset_dir} (count={sample_count})"
        )
    manifest = write_dataset_manifest(
        dataset_dir,
        spec=spec,
        sample_count=sample_count,
        adopted_existing=False,
    )
    return {
        "manifest": manifest,
        "sample_count": sample_count,
        "refreshed": True,
        "adopted": False,
    }
|
||||
@@ -55,6 +55,33 @@ def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
|
||||
])
|
||||
|
||||
|
||||
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Training-time augmentation pipeline for RGB models.

    Applies geometric jitter, color jitter, blur, normalization to [-1, 1],
    and random erasing, driven by the module-level ``AUGMENT_CONFIG``.
    """
    aug = AUGMENT_CONFIG
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=aug["degrees"],
            translate=aug["translate"],
            scale=aug["scale"],
        ),
        transforms.ColorJitter(brightness=aug["brightness"], contrast=aug["contrast"]),
        transforms.GaussianBlur(aug["blur_kernel"], sigma=aug["blur_sigma"]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        transforms.RandomErasing(p=aug["erasing_prob"], scale=aug["erasing_scale"]),
    ]
    return transforms.Compose(steps)
|
||||
|
||||
|
||||
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Deterministic transform for RGB validation / inference (no augmentation)."""
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CRNN / CTC 识别用数据集
|
||||
# ============================================================
|
||||
@@ -276,3 +303,82 @@ class RotateSolverDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
|
||||
|
||||
|
||||
class FunCaptchaChallengeDataset(Dataset):
    """FunCaptcha challenge dataset for the dedicated Siamese model.

    Each sample is one whole challenge screenshot whose filename encodes
    the correct candidate index: ``{answer_index}_{anything}.png/jpg/jpeg``.

    ``__getitem__`` slices the screenshot into:
      - ``candidates``: (K, C, H, W) tensor of horizontally-tiled candidates
      - ``reference``: (C, H, W) crop taken from ``reference_box``
      - ``answer_idx``: scalar ``LongTensor`` (always 0-based)
    """

    def __init__(
        self,
        dirs: list[str | Path],
        task_config: dict,
        transform: transforms.Compose | None = None,
    ):
        import warnings

        self.transform = transform
        self.tile_w, self.tile_h = task_config["tile_size"]
        self.reference_box = tuple(task_config["reference_box"])
        self.num_candidates = int(task_config["num_candidates"])
        # Filenames may label answers 0- or 1-based; normalize to 0-based.
        self.answer_index_base = int(task_config.get("answer_index_base", 0))

        # Collected (file path, 0-based answer index) pairs.
        self.samples: list[tuple[str, int]] = []
        for directory in map(Path, dirs):
            if not directory.exists():
                continue

            for pattern in ("*.png", "*.jpg", "*.jpeg"):
                for image_file in sorted(directory.glob(pattern)):
                    prefix = image_file.stem.rsplit("_", 1)[0]
                    try:
                        normalized = int(prefix) - self.answer_index_base
                    except ValueError:
                        # Filename prefix is not an index — skip silently.
                        continue

                    if 0 <= normalized < self.num_candidates:
                        self.samples.append((str(image_file), normalized))
                    else:
                        warnings.warn(
                            f"FunCaptcha 标签越界: file={image_file} label={prefix} "
                            f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
                            stacklevel=2,
                        )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        import torch

        path, answer_idx = self.samples[idx]
        image = Image.open(path).convert("RGB")

        # Candidate tiles are laid out left-to-right across the screenshot.
        tiles = []
        for position in range(self.num_candidates):
            x0 = position * self.tile_w
            tile = image.crop((x0, 0, x0 + self.tile_w, self.tile_h))
            tiles.append(self.transform(tile) if self.transform else tile)

        reference = image.crop(self.reference_box)
        if self.transform:
            reference = self.transform(reference)

        return (
            torch.stack(tiles, dim=0),
            reference,
            torch.tensor(answer_idx, dtype=torch.long),
        )
|
||||
|
||||
@@ -23,6 +23,11 @@ from config import (
|
||||
NUM_CAPTCHA_TYPES,
|
||||
IMAGE_SIZE,
|
||||
TRAIN_CONFIG,
|
||||
GENERATE_CONFIG,
|
||||
NORMAL_CHARS,
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
REGRESSION_RANGE,
|
||||
CLASSIFIER_DIR,
|
||||
SYNTHETIC_NORMAL_DIR,
|
||||
SYNTHETIC_MATH_DIR,
|
||||
@@ -40,7 +45,13 @@ from generators.math_gen import MathCaptchaGenerator
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.classifier import CaptchaClassifier
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
labels_cover_tokens,
|
||||
)
|
||||
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
@@ -63,27 +74,55 @@ def _prepare_classifier_data():
|
||||
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
|
||||
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
|
||||
]
|
||||
chars_map = {
|
||||
"normal": NORMAL_CHARS,
|
||||
"math": MATH_CHARS,
|
||||
"3d_text": THREED_CHARS,
|
||||
}
|
||||
|
||||
for cls_name, syn_dir, gen_cls in type_info:
|
||||
syn_dir = Path(syn_dir)
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
config_snapshot = {
|
||||
"generate_config": GENERATE_CONFIG[cls_name],
|
||||
"image_size": IMAGE_SIZE[cls_name],
|
||||
}
|
||||
if cls_name in chars_map:
|
||||
config_snapshot["chars"] = chars_map[cls_name]
|
||||
if cls_name in REGRESSION_RANGE:
|
||||
config_snapshot["label_range"] = REGRESSION_RANGE[cls_name]
|
||||
|
||||
# 如果合成数据不够,生成一些
|
||||
if len(existing) < per_class:
|
||||
print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
|
||||
gen = gen_cls()
|
||||
gen.generate_dataset(per_class, str(syn_dir))
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
validator = None
|
||||
if cls_name == "math":
|
||||
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
|
||||
validator = lambda files, tokens=required_ops: labels_cover_tokens(files, tokens)
|
||||
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_dir,
|
||||
generator_cls=gen_cls,
|
||||
spec=build_dataset_spec(
|
||||
gen_cls,
|
||||
config_key=cls_name,
|
||||
config_snapshot=config_snapshot,
|
||||
),
|
||||
gen_count=per_class,
|
||||
min_count=per_class,
|
||||
validator=validator,
|
||||
adopt_if_missing=cls_name in {"normal", "math"},
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] {cls_name} 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] {cls_name} 合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] {cls_name} 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
|
||||
# 复制到 classifier 目录
|
||||
cls_dir = CLASSIFIER_DIR / cls_name
|
||||
cls_dir.mkdir(parents=True, exist_ok=True)
|
||||
already = len(list(cls_dir.glob("*.png")))
|
||||
if already >= per_class:
|
||||
print(f"[数据] {cls_name} 分类器数据已就绪: {already} 张")
|
||||
continue
|
||||
|
||||
# 清空后重新链接
|
||||
# classifier 数据是派生目录,每次重建以对齐当前源数据与指纹状态。
|
||||
for f in cls_dir.glob("*.png"):
|
||||
f.unlink()
|
||||
|
||||
@@ -239,6 +278,15 @@ def main():
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": "classifier",
|
||||
"task": "classifier",
|
||||
"class_names": list(CAPTCHA_TYPES),
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
return best_acc
|
||||
|
||||
248
training/train_funcaptcha_rollball.py
Normal file
248
training/train_funcaptcha_rollball.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。
|
||||
|
||||
数据格式:
|
||||
data/real/funcaptcha/4_3d_rollball_animals/
|
||||
0_xxx.png
|
||||
1_xxx.jpg
|
||||
2_xxx.jpeg
|
||||
|
||||
文件名前缀表示正确候选索引。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import (
|
||||
CHECKPOINTS_DIR,
|
||||
FUN_CAPTCHA_TASKS,
|
||||
IMAGE_SIZE,
|
||||
RANDOM_SEED,
|
||||
TRAIN_CONFIG,
|
||||
get_device,
|
||||
)
|
||||
from inference.export_onnx import _load_and_export
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
from training.dataset import (
|
||||
FunCaptchaChallengeDataset,
|
||||
build_train_rgb_transform,
|
||||
build_val_rgb_transform,
|
||||
)
|
||||
|
||||
|
||||
QUESTION = "4_3d_rollball_animals"
|
||||
|
||||
|
||||
def _set_seed():
    """Seed every RNG used during training so runs are reproducible."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
|
||||
|
||||
|
||||
def _flatten_pairs(
|
||||
candidates: torch.Tensor,
|
||||
reference: torch.Tensor,
|
||||
answer_idx: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_candidates, channels, img_h, img_w = candidates.shape
|
||||
references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
|
||||
targets = F.one_hot(answer_idx, num_classes=num_candidates).float()
|
||||
return (
|
||||
candidates.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||||
references.reshape(batch_size * num_candidates, channels, img_h, img_w),
|
||||
targets.reshape(batch_size * num_candidates, 1),
|
||||
)
|
||||
|
||||
|
||||
def _evaluate(
    model: FunCaptchaSiamese,
    loader: DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Evaluate *model* on *loader*; return (challenge_acc, pair_acc).

    Challenge accuracy counts a sample correct when the argmax over its K
    candidate logits equals the labelled answer index. Pair accuracy scores
    every (candidate, reference) pair independently at a 0.5 sigmoid
    threshold.
    """
    model.eval()
    challenge_correct = 0
    challenge_total = 0
    pair_correct = 0
    pair_total = 0

    with torch.no_grad():
        for candidates, reference, answer_idx in loader:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)

            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            # Fold the flat per-pair logits back into (batch, num_candidates).
            logits = model(pair_candidates, pair_reference).view(candidates.size(0), candidates.size(1))
            preds = logits.argmax(dim=1)
            challenge_correct += (preds == answer_idx).sum().item()
            challenge_total += answer_idx.size(0)

            # Per-pair binary accuracy at the 0.5 probability threshold.
            pair_probs = torch.sigmoid(logits)
            pair_preds = (pair_probs >= 0.5).float()
            target_matrix = pair_targets.view(candidates.size(0), candidates.size(1))
            pair_correct += (pair_preds == target_matrix).sum().item()
            pair_total += target_matrix.numel()

    # max(..., 1) guards against division by zero on an empty loader.
    return (
        challenge_correct / max(challenge_total, 1),
        pair_correct / max(pair_total, 1),
    )
|
||||
|
||||
|
||||
def main(question: str = QUESTION):
    """Train the FunCaptcha Siamese model for *question* and export it to ONNX.

    Loads challenge images from the task's data dir, trains with a
    per-pair BCE objective (positive pairs up-weighted by K-1), tracks the
    best challenge-level accuracy, saves a checkpoint on improvement, and
    exports the final artifact. Returns the best challenge accuracy.
    """
    task_cfg = FUN_CAPTCHA_TASKS[question]
    cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
    img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
    device = get_device()
    data_dir = Path(task_cfg["data_dir"])
    ckpt_name = task_cfg["checkpoint_name"]
    ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"

    _set_seed()

    print("=" * 60)
    print(f"训练 FunCaptcha 专项模型 ({question})")
    print(f" 数据目录: {data_dir}")
    print(f" 候选数: {task_cfg['num_candidates']}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)

    train_transform = build_train_rgb_transform(img_h, img_w)
    val_transform = build_val_rgb_transform(img_h, img_w)

    full_dataset = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise FileNotFoundError(
            f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
        )

    val_size = max(1, int(total * cfg["val_split"]))
    train_size = total - val_size
    if train_size <= 0:
        raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")

    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation uses the same split indices but a deterministic transform
    # (no augmentation), via a clean dataset sharing the underlying files.
    val_ds_clean = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    print(f"[数据] 训练: {train_size} 验证: {val_size}")

    model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # Each challenge yields 1 positive and K-1 negative pairs; up-weight
    # positives to balance the BCE loss.
    pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    best_acc = 0.0
    start_epoch = 1

    # Resume from an existing checkpoint, fast-forwarding the LR scheduler.
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = float(ckpt.get("best_acc", 0.0))
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")

    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for candidates, reference, answer_idx in pbar:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)

            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference)
            loss = criterion(logits, pair_targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        challenge_acc, pair_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]

        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={challenge_acc:.4f} "
            f"pair_acc={pair_acc:.4f} "
            f"lr={lr:.6f}"
        )

        # >= keeps the newest weights on ties.
        if challenge_acc >= best_acc:
            best_acc = challenge_acc
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "question": question,
                    "num_candidates": task_cfg["num_candidates"],
                    "tile_size": list(task_cfg["tile_size"]),
                    "reference_box": list(task_cfg["reference_box"]),
                    "answer_index_base": task_cfg["answer_index_base"],
                    "input_shape": [task_cfg["channels"], img_h, img_w],
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")

    print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
    _load_and_export(task_cfg["artifact_name"])
    return best_acc
|
||||
|
||||
|
||||
# Script entry point: train with the default question.
if __name__ == "__main__":
    main()
|
||||
@@ -26,11 +26,16 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
GENERATE_CONFIG,
|
||||
REGRESSION_RANGE,
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
|
||||
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
@@ -65,7 +70,14 @@ def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
|
||||
return float(np.mean(diff))
|
||||
|
||||
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
def _export_onnx(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
img_h: int,
|
||||
img_w: int,
|
||||
*,
|
||||
label_range: tuple[int, int] | tuple[float, float],
|
||||
):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
@@ -81,6 +93,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": model_name,
|
||||
"task": "regression",
|
||||
"label_range": list(label_range),
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -120,13 +141,37 @@ def train_regression_model(
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
config_snapshot = {
|
||||
"image_size": IMAGE_SIZE[config_key],
|
||||
"label_range": label_range,
|
||||
}
|
||||
if config_key in GENERATE_CONFIG:
|
||||
config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
|
||||
elif config_key == "slide_cnn":
|
||||
config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
|
||||
config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
config_snapshot["train_config"] = cfg
|
||||
|
||||
dataset_spec = build_dataset_spec(
|
||||
generator_cls,
|
||||
config_key=config_key,
|
||||
config_snapshot=config_snapshot,
|
||||
)
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=generator_cls,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -181,17 +226,26 @@ def train_regression_model(
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
|
||||
|
||||
if dataset_state["refreshed"]:
|
||||
print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
|
||||
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
|
||||
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
|
||||
else:
|
||||
if ckpt_data_spec_hash is None:
|
||||
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
@@ -268,6 +322,7 @@ def train_regression_model(
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
|
||||
|
||||
@@ -275,6 +330,6 @@ def train_regression_model(
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
_export_onnx(model, model_name, img_h, img_w, label_range=label_range)
|
||||
|
||||
return best_tol_acc
|
||||
|
||||
@@ -29,7 +29,9 @@ from config import (
|
||||
get_device,
|
||||
)
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
|
||||
from training.dataset import RotateSolverDataset
|
||||
|
||||
|
||||
@@ -85,6 +87,15 @@ def _export_onnx(model: nn.Module, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": "rotation_regressor",
|
||||
"task": "rotation_solver",
|
||||
"output_encoding": "sin_cos",
|
||||
"input_shape": [3, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -107,13 +118,30 @@ def main():
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = ROTATE_SOLVER_DATA_DIR
|
||||
syn_path.mkdir(parents=True, exist_ok=True)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = RotateSolverDataGenerator()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
dataset_spec = build_dataset_spec(
|
||||
RotateSolverDataGenerator,
|
||||
config_key="rotate_solver",
|
||||
config_snapshot={
|
||||
"solver_config": SOLVER_CONFIG["rotate"],
|
||||
"train_config": {
|
||||
"synthetic_samples": cfg["synthetic_samples"],
|
||||
},
|
||||
},
|
||||
)
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=RotateSolverDataGenerator,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -229,6 +257,7 @@ def main():
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
|
||||
|
||||
|
||||
@@ -27,9 +27,16 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
GENERATE_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
labels_cover_tokens,
|
||||
)
|
||||
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
@@ -72,7 +79,14 @@ def _calc_accuracy(preds: list[str], labels: list[str]):
|
||||
# ============================================================
|
||||
# ONNX 导出
|
||||
# ============================================================
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
def _export_onnx(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
img_h: int,
|
||||
img_w: int,
|
||||
*,
|
||||
chars: str,
|
||||
):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
@@ -88,6 +102,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": model_name,
|
||||
"task": "ctc",
|
||||
"chars": chars,
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -124,13 +147,36 @@ def train_ctc_model(
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
dataset_spec = build_dataset_spec(
|
||||
generator_cls,
|
||||
config_key=config_key,
|
||||
config_snapshot={
|
||||
"generate_config": GENERATE_CONFIG[config_key],
|
||||
"chars": chars,
|
||||
"image_size": IMAGE_SIZE[config_key],
|
||||
},
|
||||
)
|
||||
validator = None
|
||||
if config_key == "math":
|
||||
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
|
||||
validator = lambda files: labels_cover_tokens(files, required_ops)
|
||||
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=generator_cls,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
validator=validator,
|
||||
adopt_if_missing=config_key in {"normal", "math"},
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -176,16 +222,25 @@ def train_ctc_model(
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
|
||||
|
||||
if dataset_state["refreshed"]:
|
||||
print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
|
||||
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
|
||||
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
|
||||
else:
|
||||
if ckpt_data_spec_hash is None:
|
||||
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
@@ -249,6 +304,7 @@ def train_ctc_model(
|
||||
"chars": chars,
|
||||
"best_acc": best_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
|
||||
|
||||
@@ -257,6 +313,6 @@ def train_ctc_model(
|
||||
# 加载最佳权重再导出
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
_export_onnx(model, model_name, img_h, img_w, chars=chars)
|
||||
|
||||
return best_acc
|
||||
|
||||
Reference in New Issue
Block a user