Align task API and add FunCaptcha support
This commit is contained in:
@@ -26,11 +26,16 @@ from config import (
|
||||
ONNX_CONFIG,
|
||||
TRAIN_CONFIG,
|
||||
IMAGE_SIZE,
|
||||
GENERATE_CONFIG,
|
||||
REGRESSION_RANGE,
|
||||
SOLVER_CONFIG,
|
||||
SOLVER_REGRESSION_RANGE,
|
||||
RANDOM_SEED,
|
||||
get_device,
|
||||
)
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from training.dataset import RegressionDataset, build_train_transform, build_val_transform
|
||||
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
|
||||
|
||||
|
||||
def _set_seed(seed: int = RANDOM_SEED):
|
||||
@@ -65,7 +70,14 @@ def _circular_mae(pred: np.ndarray, target: np.ndarray) -> float:
|
||||
return float(np.mean(diff))
|
||||
|
||||
|
||||
def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
def _export_onnx(
|
||||
model: nn.Module,
|
||||
model_name: str,
|
||||
img_h: int,
|
||||
img_w: int,
|
||||
*,
|
||||
label_range: tuple[int, int] | tuple[float, float],
|
||||
):
|
||||
"""导出模型为 ONNX 格式。"""
|
||||
model.eval()
|
||||
onnx_path = ONNX_DIR / f"{model_name}.onnx"
|
||||
@@ -81,6 +93,15 @@ def _export_onnx(model: nn.Module, model_name: str, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": model_name,
|
||||
"task": "regression",
|
||||
"label_range": list(label_range),
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -120,13 +141,37 @@ def train_regression_model(
|
||||
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = Path(synthetic_dir)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = generator_cls()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
config_snapshot = {
|
||||
"image_size": IMAGE_SIZE[config_key],
|
||||
"label_range": label_range,
|
||||
}
|
||||
if config_key in GENERATE_CONFIG:
|
||||
config_snapshot["generate_config"] = GENERATE_CONFIG[config_key]
|
||||
elif config_key == "slide_cnn":
|
||||
config_snapshot["solver_config"] = SOLVER_CONFIG["slide"]
|
||||
config_snapshot["solver_regression_range"] = SOLVER_REGRESSION_RANGE["slide"]
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
config_snapshot["train_config"] = cfg
|
||||
|
||||
dataset_spec = build_dataset_spec(
|
||||
generator_cls,
|
||||
config_key=config_key,
|
||||
config_snapshot=config_snapshot,
|
||||
)
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=generator_cls,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -181,17 +226,26 @@ def train_regression_model(
|
||||
# ---- 3.5 断点续训 ----
|
||||
if ckpt_path.exists():
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
ckpt_data_spec_hash = ckpt.get("synthetic_data_spec_hash")
|
||||
|
||||
if dataset_state["refreshed"]:
|
||||
print("[续训] 合成数据已刷新,忽略旧 checkpoint,从 epoch 1 重新训练")
|
||||
elif ckpt_data_spec_hash is not None and ckpt_data_spec_hash != current_data_spec_hash:
|
||||
print("[续训] checkpoint 与当前合成数据指纹不一致,从 epoch 1 重新训练")
|
||||
else:
|
||||
if ckpt_data_spec_hash is None:
|
||||
print("[续训] 旧 checkpoint 缺少数据指纹,沿用现有权重继续训练")
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
best_tol_acc = ckpt.get("best_tol_acc", 0.0)
|
||||
best_mae = ckpt.get("best_mae", float("inf"))
|
||||
start_epoch = ckpt.get("epoch", 0) + 1
|
||||
# 快进 scheduler 到对应 epoch
|
||||
for _ in range(start_epoch - 1):
|
||||
scheduler.step()
|
||||
print(
|
||||
f"[续训] 从 epoch {start_epoch} 继续, "
|
||||
f"best_tol_acc={best_tol_acc:.4f}, best_mae={best_mae:.2f}"
|
||||
)
|
||||
|
||||
# ---- 4. 训练循环 ----
|
||||
for epoch in range(start_epoch, cfg["epochs"] + 1):
|
||||
@@ -268,6 +322,7 @@ def train_regression_model(
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f} {ckpt_path}")
|
||||
|
||||
@@ -275,6 +330,6 @@ def train_regression_model(
|
||||
print(f"\n[训练完成] 最佳容差准确率: {best_tol_acc:.4f} 最佳 MAE: {best_mae:.2f}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
_export_onnx(model, model_name, img_h, img_w)
|
||||
_export_onnx(model, model_name, img_h, img_w, label_range=label_range)
|
||||
|
||||
return best_tol_acc
|
||||
|
||||
Reference in New Issue
Block a user