Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -12,4 +12,5 @@
- train_classifier.py: 训练调度分类器 (CaptchaClassifier)
- train_slide.py: 训练滑块缺口检测 (GapDetectorCNN)
- train_rotate_solver.py: 训练旋转角度回归 (RotationRegressor)
- train_funcaptcha_rollball.py: 训练 FunCaptcha 专项 Siamese 模型
"""

View File

@@ -0,0 +1,226 @@
"""
合成数据集指纹与清单辅助工具。
用于识别“样本数量足够但生成规则已变化”的情况,避免静默复用过期数据。
"""
from __future__ import annotations
import hashlib
import inspect
import json
from pathlib import Path
from typing import Callable
MANIFEST_NAME = ".dataset_meta.json"
def _stable_json(data: dict) -> str:
return json.dumps(data, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
def _sha256_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _source_hash(obj) -> str:
    """Hash the source text of *obj*; fall back to ``repr(obj)`` when
    the source is unavailable (builtins, dynamically created objects)."""
    try:
        text = inspect.getsource(obj)
    except (OSError, TypeError):
        # No retrievable source — use the repr as a stable-enough stand-in.
        text = repr(obj)
    return _sha256_text(text)
def dataset_manifest_path(dataset_dir: str | Path) -> Path:
    """Return the manifest file location inside *dataset_dir*."""
    return Path(dataset_dir).joinpath(MANIFEST_NAME)
def dataset_spec_hash(spec: dict) -> str:
    """Stable hash of a dataset spec: canonical JSON, then SHA-256."""
    canonical = _stable_json(spec)
    return _sha256_text(canonical)
def build_dataset_spec(
    generator_cls,
    *,
    config_key: str,
    config_snapshot: dict,
) -> dict:
    """Build a dataset spec dict whose hash is stable across runs.

    The spec records which generator produced the data (including a hash
    of its source code) and a snapshot of the generation config, so any
    rule change invalidates previously generated samples.
    """
    generator_name = f"{generator_cls.__module__}.{generator_cls.__name__}"
    return {
        "config_key": config_key,
        "generator": generator_name,
        "generator_source_hash": _source_hash(generator_cls),
        "config_snapshot": config_snapshot,
    }
def load_dataset_manifest(dataset_dir: str | Path) -> dict | None:
    """Load the manifest stored in *dataset_dir*, or ``None`` when absent."""
    manifest_file = dataset_manifest_path(dataset_dir)
    if not manifest_file.exists():
        return None
    return json.loads(manifest_file.read_text(encoding="utf-8"))
def write_dataset_manifest(
    dataset_dir: str | Path,
    *,
    spec: dict,
    sample_count: int,
    adopted_existing: bool,
) -> dict:
    """Persist the dataset manifest for *dataset_dir* and return it.

    `adopted_existing` marks manifests written for pre-existing data that
    was accepted as-is rather than freshly generated.
    """
    manifest = {
        "version": 1,
        "spec": spec,
        "spec_hash": dataset_spec_hash(spec),
        "sample_count": sample_count,
        "adopted_existing": adopted_existing,
    }
    target = dataset_manifest_path(dataset_dir)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as fh:
        json.dump(manifest, fh, ensure_ascii=True, indent=2, sort_keys=True)
        fh.write("\n")
    return manifest
def labels_cover_tokens(files: list[Path], required_tokens: tuple[str, ...]) -> bool:
    """Check that each required token appears in at least one filename label.

    The label is the file stem with its trailing ``_<suffix>`` stripped;
    a token counts as covered when it occurs as a substring of any label.
    """
    remaining = set(required_tokens)
    for path in files:
        if not remaining:
            break
        label = path.stem.rsplit("_", 1)[0]
        remaining -= {token for token in remaining if token in label}
    return not remaining
def _count_matches(count: int, *, exact_count: int | None, min_count: int | None) -> bool:
if exact_count is not None and count != exact_count:
return False
if min_count is not None and count < min_count:
return False
return True
def _dataset_valid(
    files: list[Path],
    *,
    exact_count: int | None,
    min_count: int | None,
    validator: Callable[[list[Path]], bool] | None,
) -> bool:
    """Combined check: sample count constraints plus optional content validator."""
    if not _count_matches(len(files), exact_count=exact_count, min_count=min_count):
        return False
    # Only run the (potentially expensive) validator once counts pass.
    return validator is None or validator(files)
def clear_generated_dataset(dataset_dir: str | Path) -> None:
    """Delete every generated PNG sample and the manifest in *dataset_dir*."""
    root = Path(dataset_dir)
    for sample in root.glob("*.png"):
        sample.unlink()
    manifest_file = dataset_manifest_path(root)
    if manifest_file.exists():
        manifest_file.unlink()
def ensure_synthetic_dataset(
    dataset_dir: str | Path,
    *,
    generator_cls,
    spec: dict,
    gen_count: int,
    exact_count: int | None = None,
    min_count: int | None = None,
    validator: Callable[[list[Path]], bool] | None = None,
    adopt_if_missing: bool = False,
) -> dict:
    """
    Ensure the synthetic data in *dataset_dir* matches the current
    generation rules described by *spec*; regenerate otherwise.

    Args:
        dataset_dir: Directory holding generated ``*.png`` samples.
        generator_cls: Class exposing ``generate_dataset(count, dir)``.
        spec: Stable-hashable description of the generation rules.
        gen_count: Number of samples to generate on refresh.
        exact_count / min_count: Optional sample-count constraints.
        validator: Optional extra check over the sample files.
        adopt_if_missing: When True and valid data exists without a
            manifest, fingerprint it in place instead of regenerating.

    Returns:
        dict with keys ``manifest``, ``sample_count``, ``refreshed``, ``adopted``.

    Raises:
        RuntimeError: If the freshly generated dataset still fails checks.
    """
    dataset_dir = Path(dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)
    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    # Count constraint first; run the validator only when counts pass.
    counts_ok = _count_matches(sample_count, exact_count=exact_count, min_count=min_count)
    data_ok = counts_ok and (validator is None or validator(files))
    manifest = load_dataset_manifest(dataset_dir)
    spec_hash = dataset_spec_hash(spec)
    manifest_ok = (
        manifest is not None
        and manifest.get("spec_hash") == spec_hash
        and manifest.get("sample_count") == sample_count
    )
    if manifest_ok and data_ok:
        # Data, manifest and spec all agree — nothing to do.
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": False,
        }
    if manifest is None and adopt_if_missing and data_ok:
        # Valid pre-existing data without a fingerprint: adopt it as-is.
        manifest = write_dataset_manifest(
            dataset_dir,
            spec=spec,
            sample_count=sample_count,
            adopted_existing=True,
        )
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": True,
        }
    # Stale or invalid: wipe and regenerate from scratch.
    clear_generated_dataset(dataset_dir)
    gen = generator_cls()
    gen.generate_dataset(gen_count, str(dataset_dir))
    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    if not _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    ):
        raise RuntimeError(
            f"生成后的数据集不符合要求: {dataset_dir} (count={sample_count})"
        )
    manifest = write_dataset_manifest(
        dataset_dir,
        spec=spec,
        sample_count=sample_count,
        adopted_existing=False,
    )
    return {
        "manifest": manifest,
        "sample_count": sample_count,
        "refreshed": True,
        "adopted": False,
    }

View File

@@ -55,6 +55,33 @@ def build_val_transform(img_h: int, img_w: int) -> transforms.Compose:
])
def build_train_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Data-augmentation transform pipeline for training RGB models."""
    cfg = AUGMENT_CONFIG
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.RandomAffine(
            degrees=cfg["degrees"],
            translate=cfg["translate"],
            scale=cfg["scale"],
        ),
        transforms.ColorJitter(brightness=cfg["brightness"], contrast=cfg["contrast"]),
        transforms.GaussianBlur(cfg["blur_kernel"], sigma=cfg["blur_sigma"]),
        transforms.ToTensor(),
        # Map [0,1] to [-1,1] per channel.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        # RandomErasing operates on tensors, so it comes after ToTensor.
        transforms.RandomErasing(p=cfg["erasing_prob"], scale=cfg["erasing_scale"]),
    ]
    return transforms.Compose(steps)
def build_val_rgb_transform(img_h: int, img_w: int) -> transforms.Compose:
    """Deterministic transform for RGB model validation / inference."""
    steps = [
        transforms.Resize((img_h, img_w)),
        transforms.ToTensor(),
        # Same normalization as training: [0,1] -> [-1,1] per channel.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
# ============================================================
# CRNN / CTC 识别用数据集
# ============================================================
@@ -276,3 +303,82 @@ class RotateSolverDataset(Dataset):
img = self.transform(img)
return img, torch.tensor([sin_val, cos_val], dtype=torch.float32)
class FunCaptchaChallengeDataset(Dataset):
    """
    FunCaptcha challenge dataset (one full challenge image per sample).

    The filename encodes the correct candidate index:
        `{answer_index}_{anything}.png/jpg/jpeg`
    Each sample is cropped into:
        - `candidates`: (K, C, H, W)
        - `reference`: (C, H, W)
        - `answer_idx`: LongTensor scalar
    """
    def __init__(
        self,
        dirs: list[str | Path],
        task_config: dict,
        transform: transforms.Compose | None = None,
    ):
        import warnings
        self.transform = transform
        self.tile_w, self.tile_h = task_config["tile_size"]
        self.reference_box = tuple(task_config["reference_box"])
        self.num_candidates = int(task_config["num_candidates"])
        # Filenames may label answers starting at 0 or 1; normalize to 0-based.
        self.answer_index_base = int(task_config.get("answer_index_base", 0))
        # Each entry: (image path, 0-based answer index).
        self.samples: list[tuple[str, int]] = []
        patterns = ("*.png", "*.jpg", "*.jpeg")
        for directory in dirs:
            directory = Path(directory)
            if not directory.exists():
                continue
            for pattern in patterns:
                for file_path in sorted(directory.glob(pattern)):
                    raw_label = file_path.stem.rsplit("_", 1)[0]
                    try:
                        answer_idx = int(raw_label) - self.answer_index_base
                    except ValueError:
                        # Filename carries no numeric label — skip it.
                        continue
                    if 0 <= answer_idx < self.num_candidates:
                        self.samples.append((str(file_path), answer_idx))
                        continue
                    warnings.warn(
                        f"FunCaptcha 标签越界: file={file_path} label={raw_label} "
                        f"expect=[{self.answer_index_base}, {self.answer_index_base + self.num_candidates - 1}]",
                        stacklevel=2,
                    )
    def __len__(self) -> int:
        return len(self.samples)
    def __getitem__(self, idx: int):
        import torch
        path, answer_idx = self.samples[idx]
        image = Image.open(path).convert("RGB")
        # Candidate tiles sit side by side along the top of the image.
        tiles = []
        for tile_no in range(self.num_candidates):
            x0 = tile_no * self.tile_w
            tile = image.crop((x0, 0, x0 + self.tile_w, self.tile_h))
            tiles.append(self.transform(tile) if self.transform else tile)
        reference = image.crop(self.reference_box)
        if self.transform:
            reference = self.transform(reference)
        answer = torch.tensor(answer_idx, dtype=torch.long)
        return torch.stack(tiles, dim=0), reference, answer

View File

@@ -23,6 +23,11 @@ from config import (
NUM_CAPTCHA_TYPES,
IMAGE_SIZE,
TRAIN_CONFIG,
GENERATE_CONFIG,
NORMAL_CHARS,
MATH_CHARS,
THREED_CHARS,
REGRESSION_RANGE,
CLASSIFIER_DIR,
SYNTHETIC_NORMAL_DIR,
SYNTHETIC_MATH_DIR,
@@ -40,7 +45,13 @@ from generators.math_gen import MathCaptchaGenerator
from generators.threed_gen import ThreeDCaptchaGenerator
from generators.threed_rotate_gen import ThreeDRotateGenerator
from generators.threed_slider_gen import ThreeDSliderGenerator
from inference.model_metadata import write_model_metadata
from models.classifier import CaptchaClassifier
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
labels_cover_tokens,
)
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
@@ -63,27 +74,55 @@ def _prepare_classifier_data():
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
]
chars_map = {
"normal": NORMAL_CHARS,
"math": MATH_CHARS,
"3d_text": THREED_CHARS,
}
for cls_name, syn_dir, gen_cls in type_info:
syn_dir = Path(syn_dir)
existing = sorted(syn_dir.glob("*.png"))
config_snapshot = {
"generate_config": GENERATE_CONFIG[cls_name],
"image_size": IMAGE_SIZE[cls_name],
}
if cls_name in chars_map:
config_snapshot["chars"] = chars_map[cls_name]
if cls_name in REGRESSION_RANGE:
config_snapshot["label_range"] = REGRESSION_RANGE[cls_name]
# 如果合成数据不够,生成一些
if len(existing) < per_class:
print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
gen = gen_cls()
gen.generate_dataset(per_class, str(syn_dir))
existing = sorted(syn_dir.glob("*.png"))
validator = None
if cls_name == "math":
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
validator = lambda files, tokens=required_ops: labels_cover_tokens(files, tokens)
dataset_state = ensure_synthetic_dataset(
syn_dir,
generator_cls=gen_cls,
spec=build_dataset_spec(
gen_cls,
config_key=cls_name,
config_snapshot=config_snapshot,
),
gen_count=per_class,
min_count=per_class,
validator=validator,
adopt_if_missing=cls_name in {"normal", "math"},
)
if dataset_state["refreshed"]:
print(f"[数据] {cls_name} 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] {cls_name} 合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] {cls_name} 合成数据已就绪: {dataset_state['sample_count']}")
existing = sorted(syn_dir.glob("*.png"))
# 复制到 classifier 目录
cls_dir = CLASSIFIER_DIR / cls_name
cls_dir.mkdir(parents=True, exist_ok=True)
already = len(list(cls_dir.glob("*.png")))
if already >= per_class:
print(f"[数据] {cls_name} 分类器数据已就绪: {already}")
continue
# 清空后重新链接
# classifier 数据是派生目录,每次重建以对齐当前源数据与指纹状态。
for f in cls_dir.glob("*.png"):
f.unlink()
@@ -239,6 +278,15 @@ def main():
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": "classifier",
"task": "classifier",
"class_names": list(CAPTCHA_TYPES),
"input_shape": [1, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
return best_acc

View File

@@ -0,0 +1,248 @@
"""
训练 FunCaptcha `4_3d_rollball_animals` 专项 Siamese 模型。
数据格式:
data/real/funcaptcha/4_3d_rollball_animals/
0_xxx.png
1_xxx.jpg
2_xxx.jpeg
文件名前缀表示正确候选索引。
"""
from __future__ import annotations
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from config import (
CHECKPOINTS_DIR,
FUN_CAPTCHA_TASKS,
IMAGE_SIZE,
RANDOM_SEED,
TRAIN_CONFIG,
get_device,
)
from inference.export_onnx import _load_and_export
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import (
FunCaptchaChallengeDataset,
build_train_rgb_transform,
build_val_rgb_transform,
)
QUESTION = "4_3d_rollball_animals"
def _set_seed():
    """Seed every RNG in play (python, numpy, torch, cuda) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
def _flatten_pairs(
candidates: torch.Tensor,
reference: torch.Tensor,
answer_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, num_candidates, channels, img_h, img_w = candidates.shape
references = reference.unsqueeze(1).expand(-1, num_candidates, -1, -1, -1)
targets = F.one_hot(answer_idx, num_classes=num_candidates).float()
return (
candidates.reshape(batch_size * num_candidates, channels, img_h, img_w),
references.reshape(batch_size * num_candidates, channels, img_h, img_w),
targets.reshape(batch_size * num_candidates, 1),
)
def _evaluate(
    model: FunCaptchaSiamese,
    loader: DataLoader,
    device: torch.device,
) -> tuple[float, float]:
    """Return (challenge-level accuracy, per-pair binary accuracy) on *loader*."""
    model.eval()
    hits = {"challenge": 0, "pair": 0}
    totals = {"challenge": 0, "pair": 0}
    with torch.no_grad():
        for candidates, reference, answer_idx in loader:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)
            bsz, k = candidates.size(0), candidates.size(1)
            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference).view(bsz, k)
            # Challenge-level: argmax over the K candidates must hit the answer.
            hits["challenge"] += (logits.argmax(dim=1) == answer_idx).sum().item()
            totals["challenge"] += bsz
            # Pair-level: every sigmoid(logit) thresholded at 0.5.
            pair_preds = (torch.sigmoid(logits) >= 0.5).float()
            target_matrix = pair_targets.view(bsz, k)
            hits["pair"] += (pair_preds == target_matrix).sum().item()
            totals["pair"] += target_matrix.numel()
    return (
        hits["challenge"] / max(totals["challenge"], 1),
        hits["pair"] / max(totals["pair"], 1),
    )
def main(question: str = QUESTION):
    """Train the FunCaptcha Siamese model for *question*.

    Loads challenge images from the task's data directory, trains with
    BCE-with-logits over candidate/reference pairs, checkpoints the best
    challenge-level accuracy, and finally exports the model via
    `_load_and_export`.

    Returns:
        Best challenge-level validation accuracy reached.

    Raises:
        FileNotFoundError: When no training samples exist.
        ValueError: When there are too few samples to split train/val.
    """
    # ---- configuration ----
    task_cfg = FUN_CAPTCHA_TASKS[question]
    cfg = TRAIN_CONFIG["funcaptcha_rollball_animals"]
    img_h, img_w = IMAGE_SIZE["funcaptcha_rollball_animals"]
    device = get_device()
    data_dir = Path(task_cfg["data_dir"])
    ckpt_name = task_cfg["checkpoint_name"]
    ckpt_path = CHECKPOINTS_DIR / f"{ckpt_name}.pth"
    _set_seed()
    print("=" * 60)
    print(f"训练 FunCaptcha 专项模型 ({question})")
    print(f" 数据目录: {data_dir}")
    print(f" 候选数: {task_cfg['num_candidates']}")
    print(f" 输入尺寸: {img_h}×{img_w}")
    print("=" * 60)
    # ---- datasets ----
    train_transform = build_train_rgb_transform(img_h, img_w)
    val_transform = build_val_rgb_transform(img_h, img_w)
    full_dataset = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=train_transform,
    )
    total = len(full_dataset)
    if total == 0:
        raise FileNotFoundError(
            f"未找到任何 FunCaptcha 训练样本,请先准备数据: {data_dir}"
        )
    val_size = max(1, int(total * cfg["val_split"]))
    train_size = total - val_size
    if train_size <= 0:
        raise ValueError(f"FunCaptcha 数据量过少,至少需要 2 张样本: {data_dir}")
    train_ds, val_ds = random_split(full_dataset, [train_size, val_size])
    # Validation must NOT see train-time augmentation: build a clean copy
    # with the val transform and restrict it to the val split's samples.
    val_ds_clean = FunCaptchaChallengeDataset(
        dirs=[data_dir],
        task_config=task_cfg,
        transform=val_transform,
    )
    val_ds_clean.samples = [full_dataset.samples[i] for i in val_ds.indices]
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg["batch_size"],
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds_clean,
        batch_size=cfg["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    print(f"[数据] 训练: {train_size} 验证: {val_size}")
    # ---- model / optimizer ----
    model = FunCaptchaSiamese(in_channels=task_cfg["channels"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    # Only 1 of K pairs is positive, so weight positives by K-1 to balance.
    pos_weight = torch.tensor([task_cfg["num_candidates"] - 1], dtype=torch.float32, device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    best_acc = 0.0
    start_epoch = 1
    # ---- resume from checkpoint if present ----
    if ckpt_path.exists():
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model_state_dict"])
        best_acc = float(ckpt.get("best_acc", 0.0))
        start_epoch = int(ckpt.get("epoch", 0)) + 1
        # Fast-forward the LR scheduler to the resumed epoch.
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"[续训] 从 epoch {start_epoch} 继续, best_acc={best_acc:.4f}")
    # ---- training loop ----
    for epoch in range(start_epoch, cfg["epochs"] + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg['epochs']}", leave=False)
        for candidates, reference, answer_idx in pbar:
            candidates = candidates.to(device)
            reference = reference.to(device)
            answer_idx = answer_idx.to(device)
            pair_candidates, pair_reference, pair_targets = _flatten_pairs(
                candidates, reference, answer_idx
            )
            logits = model(pair_candidates, pair_reference)
            loss = criterion(logits, pair_targets)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to stabilize training on small datasets.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        challenge_acc, pair_acc = _evaluate(model, val_loader, device)
        lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch:3d}/{cfg['epochs']} "
            f"loss={avg_loss:.4f} "
            f"acc={challenge_acc:.4f} "
            f"pair_acc={pair_acc:.4f} "
            f"lr={lr:.6f}"
        )
        # Save whenever challenge accuracy ties or improves on the best.
        if challenge_acc >= best_acc:
            best_acc = challenge_acc
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "epoch": epoch,
                    "question": question,
                    "num_candidates": task_cfg["num_candidates"],
                    "tile_size": list(task_cfg["tile_size"]),
                    "reference_box": list(task_cfg["reference_box"]),
                    "answer_index_base": task_cfg["answer_index_base"],
                    "input_shape": [task_cfg["channels"], img_h, img_w],
                },
                ckpt_path,
            )
            print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
    print(f"\n[训练完成] 最佳 challenge acc: {best_acc:.4f}")
    # Export the last-saved best checkpoint to the ONNX artifact.
    _load_and_export(task_cfg["artifact_name"])
    return best_acc
if __name__ == "__main__":
    # Allow direct CLI execution of this training script.
    main()

View File

@@ -26,11 +26,16 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
REGRESSION_RANGE,
SOLVER_CONFIG,
SOLVER_REGRESSION_RANGE,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
def _set_seed(seed: int = RANDOM_SEED):
@@ -65,7 +70,14 @@ def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
return float(np.mean(diff))
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
def _export_onnx(
model: nn.Module,
model_name: str,
img_h: int,
img_w: int,
*,
label_range: tuple[int, int] | tuple[float, float],
):
"""导出模型为 ONNX 格式。"""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
@@ -81,6 +93,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": model_name,
"task": "regression",
"label_range": list(label_range),
"input_shape": [1, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -120,13 +141,37 @@ def train_regression_model(
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
config_snapshot = {
"image_size": IMAGE_SIZE[config_key],
"label_range": label_range,
}
if config_key in GENERATE_CONFIG:
config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
elif config_key == "slide_cnn":
config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
config_snapshot["train_config"] = cfg
dataset_spec = build_dataset_spec(
generator_cls,
config_key=config_key,
config_snapshot=config_snapshot,
)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=generator_cls,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -181,17 +226,26 @@ def train_regression_model(
# ---- 3.5 断点续训 ----
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
best_mae = ckpt.get("best_mae", float("inf"))
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
)
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
if dataset_state["refreshed"]:
print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
else:
if ckpt_data_spec_hash is None:
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
model.load_state_dict(ckpt["model_state_dict"])
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
best_mae = ckpt.get("best_mae", float("inf"))
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
)
# ---- 4. 训练循环 ----
for epoch in range(start_epoch, cfg["epochs"] + 1):
@@ -268,6 +322,7 @@ def train_regression_model(
"best_mae": best_mae,
"best_tol_acc": best_tol_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
@@ -275,6 +330,6 @@ def train_regression_model(
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
_export_onnx(model, model_name, img_h, img_w, label_range=label_range)
return best_tol_acc

View File

@@ -29,7 +29,9 @@ from config import (
get_device,
)
from generators.rotate_solver_gen import RotateSolverDataGenerator
from inference.model_metadata import write_model_metadata
from models.rotation_regressor import RotationRegressor
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
from training.dataset import RotateSolverDataset
@@ -85,6 +87,15 @@ def _export_onnx(model: nn.Module, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": "rotation_regressor",
"task": "rotation_solver",
"output_encoding": "sin_cos",
"input_shape": [3, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -107,13 +118,30 @@ def main():
# ---- 1. 检查 / 生成合成数据 ----
syn_path = ROTATE_SOLVER_DATA_DIR
syn_path.mkdir(parents=True, exist_ok=True)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = RotateSolverDataGenerator()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
dataset_spec = build_dataset_spec(
RotateSolverDataGenerator,
config_key="rotate_solver",
config_snapshot={
"solver_config": SOLVER_CONFIG["rotate"],
"train_config": {
"synthetic_samples": cfg["synthetic_samples"],
},
},
)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=RotateSolverDataGenerator,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -229,6 +257,7 @@ def main():
"best_mae": best_mae,
"best_tol_acc": best_tol_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")

View File

@@ -27,9 +27,16 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
labels_cover_tokens,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
@@ -72,7 +79,14 @@ def _calc_accuracy(preds: list[str], labels: list[str]):
# ============================================================
# ONNX 导出
# ============================================================
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
def _export_onnx(
model: nn.Module,
model_name: str,
img_h: int,
img_w: int,
*,
chars: str,
):
"""导出模型为 ONNX 格式。"""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
@@ -88,6 +102,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": model_name,
"task": "ctc",
"chars": chars,
"input_shape": [1, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -124,13 +147,36 @@ def train_ctc_model(
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
dataset_spec = build_dataset_spec(
generator_cls,
config_key=config_key,
config_snapshot={
"generate_config": GENERATE_CONFIG[config_key],
"chars": chars,
"image_size": IMAGE_SIZE[config_key],
},
)
validator = None
if config_key == "math":
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
validator = lambda files: labels_cover_tokens(files, required_ops)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=generator_cls,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
validator=validator,
adopt_if_missing=config_key in {"normal", "math"},
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -176,16 +222,25 @@ def train_ctc_model(
# ---- 3.5 断点续训 ----
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
best_acc = ckpt.get("best_acc", 0.0)
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_acc={best_acc:.4f}"
)
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
if dataset_state["refreshed"]:
print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
else:
if ckpt_data_spec_hash is None:
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
model.load_state_dict(ckpt["model_state_dict"])
best_acc = ckpt.get("best_acc", 0.0)
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_acc={best_acc:.4f}"
)
# ---- 4. 训练循环 ----
for epoch in range(start_epoch, cfg["epochs"] + 1):
@@ -249,6 +304,7 @@ def train_ctc_model(
"chars": chars,
"best_acc": best_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
@@ -257,6 +313,6 @@ def train_ctc_model(
# 加载最佳权重再导出
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
_export_onnx(model, model_name, img_h, img_w, chars=chars)
return best_acc