Files
CaptchBreaker/training/data_fingerprint.py

227 lines
6.0 KiB
Python

"""
合成数据集指纹与清单辅助工具。
用于识别“样本数量足够但生成规则已变化”的情况,避免静默复用过期数据。
"""
from __future__ import annotations
import hashlib
import inspect
import json
from pathlib import Path
from typing import Callable
MANIFEST_NAME = ".dataset_meta.json"
def _stable_json(data: dict) -> str:
return json.dumps(data, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
def _sha256_text(text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _source_hash(obj) -> str:
try:
source = inspect.getsource(obj)
except (OSError, TypeError):
source = repr(obj)
return _sha256_text(source)
def dataset_manifest_path(dataset_dir: str | Path) -> Path:
    """Return the manifest file location inside *dataset_dir*."""
    base = Path(dataset_dir)
    return base.joinpath(MANIFEST_NAME)
def dataset_spec_hash(spec: dict) -> str:
    """SHA-256 hex digest of the canonical JSON form of *spec*.

    Canonical form = sorted keys, compact separators, ASCII-only — so two
    dicts with equal content always hash identically.
    """
    canonical = json.dumps(
        spec,
        ensure_ascii=True,
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def build_dataset_spec(
    generator_cls,
    *,
    config_key: str,
    config_snapshot: dict,
) -> dict:
    """Build a dataset spec dict whose content hashes stably.

    The spec captures the generator's fully-qualified name, a hash of its
    source code, and a snapshot of the generation config — so any change to
    the generation rules changes the spec hash.
    """
    qualified_name = ".".join([generator_cls.__module__, generator_cls.__name__])
    spec = {
        "config_key": config_key,
        "generator": qualified_name,
        "generator_source_hash": _source_hash(generator_cls),
        "config_snapshot": config_snapshot,
    }
    return spec
def load_dataset_manifest(dataset_dir: str | Path) -> dict | None:
    """Load the dataset manifest for *dataset_dir*.

    Returns None when the manifest is absent OR unparseable: a corrupt or
    truncated manifest (e.g. from an interrupted write) previously raised
    json.JSONDecodeError and crashed callers; treating it as missing lets
    ensure_synthetic_dataset() regenerate the data instead.
    """
    path = dataset_manifest_path(dataset_dir)
    if not path.exists():
        return None
    try:
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        # Corrupt metadata is indistinguishable from stale metadata for our
        # purposes — force a refresh rather than propagate the parse error.
        return None
def write_dataset_manifest(
    dataset_dir: str | Path,
    *,
    spec: dict,
    sample_count: int,
    adopted_existing: bool,
) -> dict:
    """Write the manifest describing the dataset currently on disk.

    Records the spec, its hash, the sample count, and whether pre-existing
    files were adopted (vs. freshly generated). Returns the manifest dict.
    """
    manifest = {
        "version": 1,
        "spec": spec,
        "spec_hash": dataset_spec_hash(spec),
        "sample_count": sample_count,
        "adopted_existing": adopted_existing,
    }
    target = dataset_manifest_path(dataset_dir)
    target.parent.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(manifest, ensure_ascii=True, indent=2, sort_keys=True)
    with target.open("w", encoding="utf-8") as f:
        f.write(rendered)
        f.write("\n")
    return manifest
def labels_cover_tokens(files: list[Path], required_tokens: tuple[str, ...]) -> bool:
    """Check that every required token appears in at least one file's label.

    The label is the filename stem with its trailing "_<suffix>" part removed
    (e.g. "abcd_0001.png" -> "abcd"). Returns True as soon as every token has
    been seen; an empty token tuple is trivially covered.
    """
    missing = set(required_tokens)
    if not missing:
        return True
    for path in files:
        label = path.stem.rsplit("_", 1)[0]
        missing = {token for token in missing if token not in label}
        if not missing:
            return True
    # Loop exhausted with tokens still unseen (including the empty-files case).
    return False
def _count_matches(count: int, *, exact_count: int | None, min_count: int | None) -> bool:
if exact_count is not None and count != exact_count:
return False
if min_count is not None and count < min_count:
return False
return True
def _dataset_valid(
files: list[Path],
*,
exact_count: int | None,
min_count: int | None,
validator: Callable[[list[Path]], bool] | None,
) -> bool:
counts_ok = _count_matches(len(files), exact_count=exact_count, min_count=min_count)
if not counts_ok:
return False
if validator is None:
return True
return validator(files)
def clear_generated_dataset(dataset_dir: str | Path) -> None:
    """Delete all generated PNG samples and the manifest in *dataset_dir*.

    Uses ``unlink(missing_ok=True)`` throughout: the previous
    ``exists()``-then-``unlink()`` pair (and bare unlink in the glob loop)
    could raise FileNotFoundError if a file disappeared between the check /
    listing and the deletion.
    """
    dataset_dir = Path(dataset_dir)
    for path in dataset_dir.glob("*.png"):
        path.unlink(missing_ok=True)
    dataset_manifest_path(dataset_dir).unlink(missing_ok=True)
def ensure_synthetic_dataset(
    dataset_dir: str | Path,
    *,
    generator_cls,
    spec: dict,
    gen_count: int,
    exact_count: int | None = None,
    min_count: int | None = None,
    validator: Callable[[list[Path]], bool] | None = None,
    adopt_if_missing: bool = False,
) -> dict:
    """Ensure the synthetic dataset on disk matches the current generation rules.

    Decision order:
      1. Manifest matches the current spec hash and sample count, and the
         files pass validation -> reuse as-is.
      2. No manifest, but the files pass validation and ``adopt_if_missing``
         is set -> adopt the existing files and write a manifest.
      3. Otherwise -> wipe and regenerate via ``generator_cls().generate_dataset``.

    Args:
        dataset_dir: directory holding the ``*.png`` samples and manifest.
        generator_cls: class instantiated with no args; must expose
            ``generate_dataset(count, out_dir)``.
        spec: spec dict (see ``build_dataset_spec``) identifying the rules.
        gen_count: sample count requested when regenerating.
        exact_count / min_count: optional size constraints on the dataset.
        validator: optional extra check over the sorted file list.
        adopt_if_missing: allow adopting valid files that lack a manifest.

    Returns:
        ``{"manifest": dict, "sample_count": int, "refreshed": bool, "adopted": bool}``

    Raises:
        RuntimeError: if the dataset is still invalid after regeneration.
    """
    dataset_dir = Path(dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    # _dataset_valid already folds in the exact/min count checks, so a single
    # call replaces the former duplicated _count_matches + _dataset_valid pair
    # (validator_ok implied counts_ok, making the extra AND redundant).
    dataset_ok = _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    )

    manifest = load_dataset_manifest(dataset_dir)
    spec_hash = dataset_spec_hash(spec)
    manifest_ok = (
        manifest is not None
        and manifest.get("spec_hash") == spec_hash
        and manifest.get("sample_count") == sample_count
    )

    # Case 1: rules unchanged and data valid -> reuse without touching disk.
    if manifest_ok and dataset_ok:
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": False,
        }

    # Case 2: valid data with no manifest -> adopt it if the caller allows.
    if manifest is None and adopt_if_missing and dataset_ok:
        manifest = write_dataset_manifest(
            dataset_dir,
            spec=spec,
            sample_count=sample_count,
            adopted_existing=True,
        )
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": True,
        }

    # Case 3: stale or invalid -> wipe and regenerate from scratch.
    clear_generated_dataset(dataset_dir)
    generator = generator_cls()
    generator.generate_dataset(gen_count, str(dataset_dir))

    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    if not _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    ):
        raise RuntimeError(
            f"生成后的数据集不符合要求: {dataset_dir} (count={sample_count})"
        )

    manifest = write_dataset_manifest(
        dataset_dir,
        spec=spec,
        sample_count=sample_count,
        adopted_existing=False,
    )
    return {
        "manifest": manifest,
        "sample_count": sample_count,
        "refreshed": True,
        "adopted": False,
    }