"""
测试合成数据指纹与 ONNX metadata 辅助逻辑。

Tests for the synthetic-dataset fingerprinting and ONNX model-metadata helpers.
"""
|
|
|
|
from pathlib import Path
|
|
|
|
from PIL import Image
|
|
|
|
from inference.model_metadata import load_model_metadata, write_model_metadata
|
|
from training.data_fingerprint import (
|
|
build_dataset_spec,
|
|
ensure_synthetic_dataset,
|
|
load_dataset_manifest,
|
|
)
|
|
|
|
|
|
class AdoptableGenerator:
    """Generator stub used by the dataset-adoption test.

    It emits blank images named ``kept_<index>.png``.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` white 2x2 grayscale PNGs into ``output_dir``.

        The target directory (and any missing parents) is created first, so
        calling this with ``num_samples == 0`` just ensures the directory exists.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        for sample_index in range(num_samples):
            blank = Image.new("L", (2, 2), color=255)
            blank.save(target / f"kept_{sample_index:06d}.png")
|
|
|
|
|
|
class RefreshingGenerator:
    """Generator stub whose sample names contain ``÷``.

    The refresh test's validator looks for that character in each file stem,
    so output from this generator is what makes the refreshed dataset valid.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` white 2x2 grayscale PNGs into ``output_dir``.

        The target directory (and any missing parents) is created first, so
        calling this with ``num_samples == 0`` just ensures the directory exists.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        for sample_index in range(num_samples):
            blank = Image.new("L", (2, 2), color=255)
            blank.save(target / f"20÷4_{sample_index:06d}.png")
|
|
|
|
|
|
class TestDataFingerprint:
    """Tests for the adopt/refresh decisions made by ``ensure_synthetic_dataset``."""

    @staticmethod
    def _seed_samples(dataset_dir, labels):
        """Write one blank 2x2 grayscale PNG per label; return the file names written.

        Files are named ``<label>_<index:06d>.png``, where ``index`` is the
        label's position in ``labels``.
        """
        names = []
        for idx, label in enumerate(labels):
            name = f"{label}_{idx:06d}.png"
            Image.new("L", (2, 2), color=255).save(dataset_dir / name)
            names.append(name)
        return names

    def test_adopt_existing_dataset_when_manifest_missing(self, tmp_path):
        """A manifest-less directory holding ``exact_count`` samples is adopted, not regenerated."""
        dataset_dir = tmp_path / "normal"
        dataset_dir.mkdir()
        seeded = self._seed_samples(dataset_dir, ["AB12", "CD34"])

        state = ensure_synthetic_dataset(
            dataset_dir,
            generator_cls=AdoptableGenerator,
            spec=build_dataset_spec(
                AdoptableGenerator,
                config_key="normal",
                config_snapshot={"chars": "ABCD1234"},
            ),
            gen_count=2,
            exact_count=2,
            adopt_if_missing=True,
        )

        manifest = load_dataset_manifest(dataset_dir)
        assert state["adopted"] is True
        assert state["refreshed"] is False
        assert manifest is not None
        assert manifest["adopted_existing"] is True
        assert manifest["sample_count"] == 2
        # The flags alone do not prove adoption: also check that the seeded
        # samples are still the ones on disk. Regeneration via
        # AdoptableGenerator would have produced ``kept_*.png`` files instead.
        assert sorted(p.name for p in dataset_dir.glob("*.png")) == sorted(seeded)

    def test_refresh_dataset_when_validator_fails(self, tmp_path):
        """Existing samples rejected by the validator are replaced with regenerated ones."""
        dataset_dir = tmp_path / "math"
        dataset_dir.mkdir()
        # Neither seeded stem contains "÷", so the validator below rejects them.
        self._seed_samples(dataset_dir, ["1+1", "2-1"])

        state = ensure_synthetic_dataset(
            dataset_dir,
            generator_cls=RefreshingGenerator,
            spec=build_dataset_spec(
                RefreshingGenerator,
                config_key="math",
                config_snapshot={"operators": ["+", "-", "÷"]},
            ),
            gen_count=2,
            exact_count=2,
            validator=lambda files: any("÷" in path.stem for path in files),
            adopt_if_missing=True,
        )

        files = sorted(dataset_dir.glob("*.png"))
        manifest = load_dataset_manifest(dataset_dir)
        assert state["refreshed"] is True
        assert state["adopted"] is False
        assert manifest is not None
        assert manifest["adopted_existing"] is False
        # Exactly the regenerated samples remain; the rejected seeds are gone.
        assert len(files) == 2
        assert all("÷" in path.stem for path in files)
|
|
|
|
|
|
class TestModelMetadata:
    """Round-trip tests for the model metadata write/load helpers."""

    def test_write_and_load_model_metadata(self, tmp_path):
        """Metadata written for a model path is read back, with ``version`` set to 1."""
        onnx_file = tmp_path / "normal.onnx"
        # An empty placeholder file is all the helpers are exercised against here.
        onnx_file.touch()

        payload = {
            "model_name": "normal",
            "task": "ctc",
            "chars": "ABC",
            "input_shape": [1, 40, 120],
        }
        write_model_metadata(onnx_file, payload)

        loaded = load_model_metadata(onnx_file)
        assert loaded is not None
        assert loaded["version"] == 1
        assert loaded["chars"] == "ABC"
        assert loaded["task"] == "ctc"
|