Files
CaptchBreaker/tests/test_data_fingerprint.py

110 lines
3.6 KiB
Python

"""
测试合成数据指纹与 ONNX metadata 辅助逻辑。
"""
from pathlib import Path
from PIL import Image
from inference.model_metadata import load_model_metadata, write_model_metadata
from training.data_fingerprint import (
build_dataset_spec,
ensure_synthetic_dataset,
load_dataset_manifest,
)
class AdoptableGenerator:
    """Minimal generator stand-in passed as ``generator_cls`` in the adoption test.

    Every image it writes is a 2x2 mode-``"L"`` PNG filled with 255 and named
    ``kept_<index>.png``, where the index is zero-padded to six digits.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` placeholder PNGs into ``output_dir``.

        The directory (including missing parents) is created when absent;
        an already existing directory is reused as-is.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        file_names = [f"kept_{number:06d}.png" for number in range(num_samples)]
        for file_name in file_names:
            blank = Image.new("L", (2, 2), color=255)
            blank.save(target / file_name)
class RefreshingGenerator:
    """Minimal generator stand-in passed as ``generator_cls`` in the refresh test.

    Every image it writes is a 2x2 mode-``"L"`` PNG filled with 255 and named
    ``20÷4_<index>.png`` (six-digit zero-padded index), so each file stem
    contains the ``÷`` character that the refresh test's validator looks for.
    """

    def generate_dataset(self, num_samples: int, output_dir: str):
        """Write ``num_samples`` placeholder PNGs into ``output_dir``.

        The directory (including missing parents) is created when absent;
        an already existing directory is reused as-is.
        """
        target = Path(output_dir)
        target.mkdir(parents=True, exist_ok=True)
        file_names = [f"20÷4_{number:06d}.png" for number in range(num_samples)]
        for file_name in file_names:
            blank = Image.new("L", (2, 2), color=255)
            blank.save(target / file_name)
class TestDataFingerprint:
    """Tests for ``ensure_synthetic_dataset`` and the manifest it leaves behind."""

    @staticmethod
    def _seed_images(folder, labels):
        """Write one 2x2 placeholder PNG per label as ``<label>_<index>.png``."""
        for position, text in enumerate(labels):
            placeholder = Image.new("L", (2, 2), color=255)
            placeholder.save(folder / f"{text}_{position:06d}.png")

    def test_adopt_existing_dataset_when_manifest_missing(self, tmp_path):
        """Pre-existing images with no manifest are reported as adopted.

        The directory is seeded with exactly two PNGs (matching
        ``exact_count``) and nothing else, then handed to
        ``ensure_synthetic_dataset`` with ``adopt_if_missing=True``.
        """
        normal_dir = tmp_path / "normal"
        normal_dir.mkdir()
        self._seed_images(normal_dir, ["AB12", "CD34"])

        normal_spec = build_dataset_spec(
            AdoptableGenerator,
            config_key="normal",
            config_snapshot={"chars": "ABCD1234"},
        )
        outcome = ensure_synthetic_dataset(
            normal_dir,
            generator_cls=AdoptableGenerator,
            spec=normal_spec,
            gen_count=2,
            exact_count=2,
            adopt_if_missing=True,
        )
        recorded = load_dataset_manifest(normal_dir)

        assert outcome["adopted"] is True
        assert outcome["refreshed"] is False
        assert recorded is not None
        assert recorded["adopted_existing"] is True
        assert recorded["sample_count"] == 2

    def test_refresh_dataset_when_validator_fails(self, tmp_path):
        """A dataset rejected by the validator is reported as refreshed.

        The seeded file names contain only ``+`` and ``-``, so a validator
        that requires a ``÷`` sample rejects them even though
        ``adopt_if_missing=True`` is set.
        """
        math_dir = tmp_path / "math"
        math_dir.mkdir()
        self._seed_images(math_dir, ["1+1", "2-1"])

        def has_division_sample(files):
            # Accept the dataset only when at least one stem contains "÷".
            return any("÷" in path.stem for path in files)

        math_spec = build_dataset_spec(
            RefreshingGenerator,
            config_key="math",
            config_snapshot={"operators": ["+", "-", "÷"]},
        )
        outcome = ensure_synthetic_dataset(
            math_dir,
            generator_cls=RefreshingGenerator,
            spec=math_spec,
            gen_count=2,
            exact_count=2,
            validator=has_division_sample,
            adopt_if_missing=True,
        )
        png_files = sorted(math_dir.glob("*.png"))
        recorded = load_dataset_manifest(math_dir)

        assert outcome["refreshed"] is True
        assert outcome["adopted"] is False
        assert recorded is not None
        assert recorded["adopted_existing"] is False
        # Only freshly generated "÷" samples should remain after the refresh.
        assert len(png_files) == 2
        assert all("÷" in path.stem for path in png_files)
class TestModelMetadata:
    """Round-trip test for ``write_model_metadata`` / ``load_model_metadata``."""

    def test_write_and_load_model_metadata(self, tmp_path):
        """Metadata written for a model path can be loaded back from that path.

        The ``.onnx`` file is an empty placeholder created with ``touch()``.
        ``version`` is not part of the payload passed in, so the assertion on
        it checks a value that the helpers themselves are expected to supply.
        """
        onnx_file = tmp_path / "normal.onnx"
        onnx_file.touch()

        payload = {
            "model_name": "normal",
            "task": "ctc",
            "chars": "ABC",
            "input_shape": [1, 40, 120],
        }
        write_model_metadata(onnx_file, payload)
        loaded = load_model_metadata(onnx_file)

        assert loaded is not None
        assert loaded["version"] == 1
        assert loaded["chars"] == "ABC"
        assert loaded["task"] == "ctc"