Align task API and add FunCaptcha support
This commit is contained in:
@@ -29,7 +29,9 @@ from config import (
|
||||
get_device,
|
||||
)
|
||||
from generators.rotate_solver_gen import RotateSolverDataGenerator
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
|
||||
from training.dataset import RotateSolverDataset
|
||||
|
||||
|
||||
@@ -85,6 +87,15 @@ def _export_onnx(model: nn.Module, img_h: int, img_w: int):
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": "rotation_regressor",
|
||||
"task": "rotation_solver",
|
||||
"output_encoding": "sin_cos",
|
||||
"input_shape": [3, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
|
||||
@@ -107,13 +118,30 @@ def main():
|
||||
# ---- 1. 检查 / 生成合成数据 ----
|
||||
syn_path = ROTATE_SOLVER_DATA_DIR
|
||||
syn_path.mkdir(parents=True, exist_ok=True)
|
||||
existing = list(syn_path.glob("*.png"))
|
||||
if len(existing) < cfg["synthetic_samples"]:
|
||||
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
|
||||
gen = RotateSolverDataGenerator()
|
||||
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
|
||||
dataset_spec = build_dataset_spec(
|
||||
RotateSolverDataGenerator,
|
||||
config_key="rotate_solver",
|
||||
config_snapshot={
|
||||
"solver_config": SOLVER_CONFIG["rotate"],
|
||||
"train_config": {
|
||||
"synthetic_samples": cfg["synthetic_samples"],
|
||||
},
|
||||
},
|
||||
)
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_path,
|
||||
generator_cls=RotateSolverDataGenerator,
|
||||
spec=dataset_spec,
|
||||
gen_count=cfg["synthetic_samples"],
|
||||
exact_count=cfg["synthetic_samples"],
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] 合成数据已就绪: {len(existing)} 张")
|
||||
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
|
||||
|
||||
# ---- 2. 构建数据集 ----
|
||||
data_dirs = [str(syn_path)]
|
||||
@@ -229,6 +257,7 @@ def main():
|
||||
"best_mae": best_mae,
|
||||
"best_tol_acc": best_tol_acc,
|
||||
"epoch": epoch,
|
||||
"synthetic_data_spec_hash": current_data_spec_hash,
|
||||
}, ckpt_path)
|
||||
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user