Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -29,7 +29,9 @@ from config import (
get_device,
)
from generators.rotate_solver_gen import RotateSolverDataGenerator
from inference.model_metadata import write_model_metadata
from models.rotation_regressor import RotationRegressor
from training.data_fingerprint import build_dataset_spec, ensure_synthetic_dataset
from training.dataset import RotateSolverDataset
@@ -85,6 +87,15 @@ def _export_onnx(model: nn.Module, img_h: int, img_w: int):
if ONNX_CONFIG["dynamic_batch"]
else None,
)
write_model_metadata(
onnx_path,
{
"model_name": "rotation_regressor",
"task": "rotation_solver",
"output_encoding": "sin_cos",
"input_shape": [3, img_h, img_w],
},
)
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
@@ -107,13 +118,30 @@ def main():
# ---- 1. 检查 / 生成合成数据 ----
syn_path = ROTATE_SOLVER_DATA_DIR
syn_path.mkdir(parents=True, exist_ok=True)
existing = list(syn_path.glob("*.png"))
if len(existing) < cfg["synthetic_samples"]:
print(f"[数据] 合成数据不足 ({len(existing)}/{cfg['synthetic_samples']}),开始生成...")
gen = RotateSolverDataGenerator()
gen.generate_dataset(cfg["synthetic_samples"], str(syn_path))
dataset_spec = build_dataset_spec(
RotateSolverDataGenerator,
config_key="rotate_solver",
config_snapshot={
"solver_config": SOLVER_CONFIG["rotate"],
"train_config": {
"synthetic_samples": cfg["synthetic_samples"],
},
},
)
dataset_state = ensure_synthetic_dataset(
syn_path,
generator_cls=RotateSolverDataGenerator,
spec=dataset_spec,
gen_count=cfg["synthetic_samples"],
exact_count=cfg["synthetic_samples"],
)
if dataset_state["refreshed"]:
print(f"[数据] 合成数据已刷新: {dataset_state['sample_count']}")
elif dataset_state["adopted"]:
print(f"[数据] 现有合成数据已采纳并写入指纹: {dataset_state['sample_count']}")
else:
print(f"[数据] 合成数据已就绪: {len(existing)}")
print(f"[数据] 合成数据已就绪: {dataset_state['sample_count']}")
current_data_spec_hash = dataset_state["manifest"]["spec_hash"]
# ---- 2. 构建数据集 ----
data_dirs = [str(syn_path)]
@@ -229,6 +257,7 @@ def main():
"best_mae": best_mae,
"best_tol_acc": best_tol_acc,
"epoch": epoch,
"synthetic_data_spec_hash": current_data_spec_hash,
}, ckpt_path)
print(f" → 保存最佳模型 tol_acc={best_tol_acc:.4f} MAE={best_mae:.2f}° {ckpt_path}")