Align task API and add FunCaptcha support
This commit is contained in:
109
tests/test_data_fingerprint.py
Normal file
109
tests/test_data_fingerprint.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Tests for synthetic-dataset fingerprinting and the ONNX metadata helper logic.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from inference.model_metadata import load_model_metadata, write_model_metadata
|
||||
from training.data_fingerprint import (
|
||||
build_dataset_spec,
|
||||
ensure_synthetic_dataset,
|
||||
load_dataset_manifest,
|
||||
)
|
||||
|
||||
|
||||
class AdoptableGenerator:
    """Test double used as ``generator_cls`` in the dataset-adoption test.

    Every file it writes is named ``kept_<index>.png``.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` placeholder PNGs into ``output_dir``.

        The directory (and any missing parents) is created first. Each image
        is a 2x2 mode-"L" picture filled with the value 255.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        for sample_no in range(num_samples):
            placeholder = Image.new("L", (2, 2), color=255)
            placeholder.save(target / f"kept_{sample_no:06d}.png")
|
||||
|
||||
|
||||
class RefreshingGenerator:
    """Test double used as ``generator_cls`` in the dataset-refresh test.

    Every file it writes is named ``20÷4_<index>.png``, so each stem contains
    the ``÷`` character that the refresh test's validator looks for.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` placeholder PNGs into ``output_dir``.

        The directory (and any missing parents) is created first. Each image
        is a 2x2 mode-"L" picture filled with the value 255.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        for sample_no in range(num_samples):
            placeholder = Image.new("L", (2, 2), color=255)
            placeholder.save(target / f"20÷4_{sample_no:06d}.png")
|
||||
|
||||
|
||||
class TestDataFingerprint:
    """Checks how ``ensure_synthetic_dataset`` treats a pre-populated directory."""

    def test_adopt_existing_dataset_when_manifest_missing(self, tmp_path):
        """Existing images with no manifest are adopted rather than regenerated."""
        dataset_dir = tmp_path / "normal"
        dataset_dir.mkdir()
        # Seed the directory with two images before the helper runs.
        for position, prefix in enumerate(("AB12", "CD34")):
            seed_image = Image.new("L", (2, 2), color=255)
            seed_image.save(dataset_dir / f"{prefix}_{position:06d}.png")

        spec = build_dataset_spec(
            AdoptableGenerator,
            config_key="normal",
            config_snapshot={"chars": "ABCD1234"},
        )
        state = ensure_synthetic_dataset(
            dataset_dir,
            generator_cls=AdoptableGenerator,
            spec=spec,
            gen_count=2,
            exact_count=2,
            adopt_if_missing=True,
        )

        manifest = load_dataset_manifest(dataset_dir)
        assert state["adopted"] is True
        assert state["refreshed"] is False
        assert manifest is not None
        assert manifest["adopted_existing"] is True
        assert manifest["sample_count"] == 2

    def test_refresh_dataset_when_validator_fails(self, tmp_path):
        """A dataset whose files fail the validator is regenerated, not adopted."""
        dataset_dir = tmp_path / "math"
        dataset_dir.mkdir()
        # Neither seeded stem contains "÷", so the validator below is False
        # for these files.
        for position, prefix in enumerate(("1+1", "2-1")):
            seed_image = Image.new("L", (2, 2), color=255)
            seed_image.save(dataset_dir / f"{prefix}_{position:06d}.png")

        spec = build_dataset_spec(
            RefreshingGenerator,
            config_key="math",
            config_snapshot={"operators": ["+", "-", "÷"]},
        )
        state = ensure_synthetic_dataset(
            dataset_dir,
            generator_cls=RefreshingGenerator,
            spec=spec,
            gen_count=2,
            exact_count=2,
            validator=lambda paths: any("÷" in p.stem for p in paths),
            adopt_if_missing=True,
        )

        files = sorted(dataset_dir.glob("*.png"))
        manifest = load_dataset_manifest(dataset_dir)
        assert state["refreshed"] is True
        assert state["adopted"] is False
        assert manifest is not None
        assert manifest["adopted_existing"] is False
        assert len(files) == 2
        # Only RefreshingGenerator output should remain after the refresh.
        assert all("÷" in p.stem for p in files)
|
||||
|
||||
|
||||
class TestModelMetadata:
    """Round-trip check for ``write_model_metadata`` / ``load_model_metadata``."""

    def test_write_and_load_model_metadata(self, tmp_path):
        """Metadata written for a model path can be loaded back from that path."""
        model_path = tmp_path / "normal.onnx"
        # An empty placeholder file stands in for a real ONNX model.
        model_path.touch()

        payload = {
            "model_name": "normal",
            "task": "ctc",
            "chars": "ABC",
            "input_shape": [1, 40, 120],
        }
        write_model_metadata(model_path, payload)

        metadata = load_model_metadata(model_path)
        assert metadata is not None
        # "version" is not part of the payload written above, yet the loaded
        # metadata is expected to carry it with value 1.
        assert metadata["version"] == 1
        assert metadata["chars"] == "ABC"
        assert metadata["task"] == "ctc"
|
||||
Reference in New Issue
Block a user