Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -26,11 +26,16 @@ from config import (
ONNX_CONFIG,
TRAIN_CONFIG,
IMAGE_SIZE,
GENERATE_CONFIG,
REGRESSION_RANGE,
SOLVER_CONFIG,
SOLVER_REGRESSION_RANGE,
RANDOM_SEED,
get_device,
)
from inference.model_metadata import write_model_metadata
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
def _set_seed(seed: int = RANDOM_SEED):
@@ -65,7 +70,14 @@ def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
return float(np.mean(diff))
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
def _export_onnx(
model: nn.Module,
model_name: str,
img_h: int,
img_w: int,
*,
label_range: tuple[int, int] | tuple[float, float],
):
"""导出模型为 ONNX 格式。"""
model.eval()
onnx_path = ONNX_DIR / f"{model_name}.onnx"
@@ -81,6 +93,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": model_name,
"task": "regression",
"label_range": list(label_range),
"input_shape": [1, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -120,13 +141,37 @@ def train_regression_model(
# ---- 1. 检查 / 生成合成数据 ----
syn_path = Path(synthetic_dir)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = generator_cls()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
config_snapshot = {
"image_size": IMAGE_SIZE[config_key],
"label_range": label_range,
}
if config_key in GENERATE_CONFIG:
config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
elif config_key == "slide_cnn":
config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
config_snapshot["train_config"] = cfg
dataset_spec = build_dataset_spec(
generator_cls,
config_key=config_key,
config_snapshot=config_snapshot,
)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=generator_cls,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -181,17 +226,26 @@ def train_regression_model(
# ---- 3.5 断点续训 ----
if ckpt_path.exists():
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
best_mae = ckpt.get("best_mae", float("inf"))
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
)
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
if dataset_state["refreshed"]:
print("[续训] 合成数据已刷新,忽略旧 checkpoint从 epoch 1 重新训练")
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
else:
if ckpt_data_spec_hash is None:
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
model.load_state_dict(ckpt["model_state_dict"])
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
best_mae = ckpt.get("best_mae", float("inf"))
start_epoch = ckpt.get("epoch", 0) + 1
# 快进 scheduler 到对应 epoch
for _ in range(start_epoch - 1):
scheduler.step()
print(
f"[续训] 从 epoch {start_epoch} 继续, "
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
)
# ---- 4. 训练循环 ----
for epoch in range(start_epoch, cfg["epochs"] + 1):
@@ -268,6 +322,7 @@ def train_regression_model(
"best_mae": best_mae,
"best_tol_acc": best_tol_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
@@ -275,6 +330,6 @@ def train_regression_model(
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
_export_onnx(model, model_name, img_h, img_w)
_export_onnx(model, model_name, img_h, img_w, label_range=label_range)
return best_tol_acc