Align task API and add FunCaptcha support
This commit is contained in:
@@ -23,6 +23,11 @@ from config import (
|
||||
NUM_CAPTCHA_TYPES,
|
||||
IMAGE_SIZE,
|
||||
TRAIN_CONFIG,
|
||||
GENERATE_CONFIG,
|
||||
NORMAL_CHARS,
|
||||
MATH_CHARS,
|
||||
THREED_CHARS,
|
||||
REGRESSION_RANGE,
|
||||
CLASSIFIER_DIR,
|
||||
SYNTHETIC_NORMAL_DIR,
|
||||
SYNTHETIC_MATH_DIR,
|
||||
@@ -40,7 +45,13 @@ from generators.math_gen import MathCaptchaGenerator
|
||||
from generators.threed_gen import ThreeDCaptchaGenerator
|
||||
from generators.threed_rotate_gen import ThreeDRotateGenerator
|
||||
from generators.threed_slider_gen import ThreeDSliderGenerator
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.classifier import CaptchaClassifier
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
labels_cover_tokens,
|
||||
)
|
||||
from training.dataset import CaptchaDataset, build_train_transform, build_val_transform
|
||||
|
||||
|
||||
@@ -63,27 +74,55 @@ def _prepare_classifier_data():
|
||||
("3d_rotate", SYNTHETIC_3D_ROTATE_DIR, ThreeDRotateGenerator),
|
||||
("3d_slider", SYNTHETIC_3D_SLIDER_DIR, ThreeDSliderGenerator),
|
||||
]
|
||||
chars_map = {
|
||||
"normal": NORMAL_CHARS,
|
||||
"math": MATH_CHARS,
|
||||
"3d_text": THREED_CHARS,
|
||||
}
|
||||
|
||||
for cls_name, syn_dir, gen_cls in type_info:
|
||||
syn_dir = Path(syn_dir)
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
config_snapshot = {
|
||||
"generate_config": GENERATE_CONFIG[cls_name],
|
||||
"image_size": IMAGE_SIZE[cls_name],
|
||||
}
|
||||
if cls_name in chars_map:
|
||||
config_snapshot["chars"] = chars_map[cls_name]
|
||||
if cls_name in REGRESSION_RANGE:
|
||||
config_snapshot["label_range"] = REGRESSION_RANGE[cls_name]
|
||||
|
||||
# 如果合成数据不够,生成一些
|
||||
if len(existing) < per_class:
|
||||
print(f"[数据] {cls_name} 合成数据不足 ({len(existing)}/{per_class}),开始生成...")
|
||||
gen = gen_cls()
|
||||
gen.generate_dataset(per_class, str(syn_dir))
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
validator = None
|
||||
if cls_name == "math":
|
||||
required_ops = tuple(GENERATE_CONFIG["math"]["operators"])
|
||||
validator = lambda files, tokens=required_ops: labels_cover_tokens(files, tokens)
|
||||
|
||||
dataset_state = ensure_synthetic_dataset(
|
||||
syn_dir,
|
||||
generator_cls=gen_cls,
|
||||
spec=build_dataset_spec(
|
||||
gen_cls,
|
||||
config_key=cls_name,
|
||||
config_snapshot=config_snapshot,
|
||||
),
|
||||
gen_count=per_class,
|
||||
min_count=per_class,
|
||||
validator=validator,
|
||||
adopt_if_missing=cls_name in {"normal", "math"},
|
||||
)
|
||||
if dataset_state["refreshed"]:
|
||||
print(f"[数据] {cls_name} 合成数据已刷新: {dataset_state['sample_count']} 张")
|
||||
elif dataset_state["adopted"]:
|
||||
print(f"[数据] {cls_name} 合成数据已采纳并写入指纹: {dataset_state['sample_count']} 张")
|
||||
else:
|
||||
print(f"[数据] {cls_name} 合成数据已就绪: {dataset_state['sample_count']} 张")
|
||||
|
||||
existing = sorted(syn_dir.glob("*.png"))
|
||||
|
||||
# 复制到 classifier 目录
|
||||
cls_dir = CLASSIFIER_DIR / cls_name
|
||||
cls_dir.mkdir(parents=True, exist_ok=True)
|
||||
already = len(list(cls_dir.glob("*.png")))
|
||||
if already >= per_class:
|
||||
print(f"[数据] {cls_name} 分类器数据已就绪: {already} 张")
|
||||
continue
|
||||
|
||||
# 清空后重新链接
|
||||
# classifier 数据是派生目录,每次重建以对齐当前源数据与指纹状态。
|
||||
for f in cls_dir.glob("*.png"):
|
||||
f.unlink()
|
||||
|
||||
@@ -239,6 +278,15 @@ def main():
|
||||
if ONNX_CONFIG["dynamic_batch"]
|
||||
else None,
|
||||
)
|
||||
write_model_metadata(
|
||||
onnx_path,
|
||||
{
|
||||
"model_name": "classifier",
|
||||
"task": "classifier",
|
||||
"class_names": list(CAPTCHA_TYPES),
|
||||
"input_shape": [1, img_h, img_w],
|
||||
},
|
||||
)
|
||||
print(f"[ONNX] 导出完成: {onnx_path} ({onnx_path.stat().st_size / 1024:.1f} KB)")
|
||||
|
||||
return best_acc
|
||||
|
||||
Reference in New Issue
Block a user