Align task API and add FunCaptcha support
This commit is contained in:
@@ -27,9 +27,16 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
GENERATE_CONFIG,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
labels_cover_tokens,
|
||||
)
|
||||
from training.dataset import CRNNDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
@@ -72,7 +79,14 @@ def _calc_accuracy(preds: list[str], labels: list[str]):
|
||||
# ============================================================
|
||||
# ONNX 导出
|
||||
# ============================================================
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
def _export_onnx(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
img_h: int,
|
||||
img_w: int,
|
||||
*,
|
||||
chars: str,
|
||||
):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
@@ -88,6 +102,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": model_name,
|
||||
"task": "ctc",
|
||||
"chars": chars,
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -124,13 +147,36 @@ def train_ctc_model(
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
dataset_spec = build_dataset_spec(
|
||||
generator_cls,
|
||||
config_key=config_key,
|
||||
config_snapshot={
|
||||
"generate_config": GENERATE_CONFIG[config_key],
|
||||
"chars": chars,
|
||||
"image_size": IMAGE_SIZE[config_key],
|
||||
},
|
||||
)
|
||||
validator = None
|
||||
if config_key == "math":
|
||||
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
|
||||
validator = lambda files: labels_cover_tokens(files, required_ops)
|
||||
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=generator_cls,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
validator=validator,
|
||||
adopt_if_missing=config_key in {"normal", "math"},
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -176,16 +222,25 @@ def train_ctc_model(
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
|
||||
|
||||
if dataset_state["refreshed"]:
|
||||
print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
|
||||
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
|
||||
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
|
||||
else:
|
||||
if ckpt_data_spec_hash is None:
|
||||
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_acc = ckpt.get("best_acc", 0.0)
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_acc={best_acc:.4f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
@@ -249,6 +304,7 @@ def train_ctc_model(
|
||||
"chars": chars,
|
||||
"best_acc": best_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 acc={best_acc:.4f} {ckpt_path}")
|
||||
|
||||
@@ -257,6 +313,6 @@ def train_ctc_model(
|
||||
# 加载最佳权重再导出
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
_export_onnx(model, model_name, img_h, img_w, chars=chars)
|
||||
|
||||
return best_acc
|
||||
|
||||
Reference in New Issue
Block a user