Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -27,9 +27,16 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
labels_cover_tokens,
)
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
@@ -72,7 +79,14 @@ def _calc_accuracy(preds: list[str], labels: list[str]):
# ============================================================
# ONNX 导出
# ============================================================
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
def _export_onnx(
model: nn.Module,
model_name: str,
img_h: int,
img_w: int,
*,
chars: str,
):
"""导出模型为 ONNX 格式。"""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
@@ -88,6 +102,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": model_name,
"task": "ctc",
"chars": chars,
"input_shape": [1, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -124,13 +147,36 @@ def train_ctc_model(
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
dataset_spec = build_dataset_spec(
generator_cls,
config_key=config_key,
config_snapshot={
"generate_config": GENERATE_CONFIG[config_key],
"chars": chars,
"image_size": IMAGE_SIZE[config_key],
},
)
validator = None
if config_key == "math":
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
validator = lambda files: labels_cover_tokens(files, required_ops)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=generator_cls,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
validator=validator,
adopt_if_missing=config_key in {"normal", "math"},
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -176,16 +222,25 @@ def train_ctc_model(
# ---- 3.5 断点续训 ----
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
best_acc = ckpt.get("best_acc", 0.0)
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_acc={best_acc:.4f}"
)
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
if dataset_state["refreshed"]:
print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
else:
if ckpt_data_spec_hash is None:
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
model.load_state_dict(ckpt["model_state_dict"])
best_acc = ckpt.get("best_acc", 0.0)
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_acc={best_acc:.4f}"
)
# ---- 4. 训练循环 ----
for epoch in range(start_epoch, cfg["epochs"] + 1):
@@ -249,6 +304,7 @@ def train_ctc_model(
"chars": chars,
"best_acc": best_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
@@ -257,6 +313,6 @@ def train_ctc_model(
# 加载最佳权重再导出
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
_export_onnx(model, model_name, img_h, img_w, chars=chars)
return best_acc