"""
|
|
合成数据集指纹与清单辅助工具。
|
|
|
|
用于识别“样本数量足够但生成规则已变化”的情况,避免静默复用过期数据。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import inspect
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
MANIFEST_NAME = ".dataset_meta.json"
|
|
|
|
|
|
def _stable_json(data: dict) -> str:
|
|
return json.dumps(data, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
|
|
|
|
|
|
def _sha256_text(text: str) -> str:
|
|
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _source_hash(obj) -> str:
    """Fingerprint the source code of *obj*, falling back to its repr.

    ``inspect.getsource`` raises OSError when the source file is
    unavailable and TypeError for objects with no source (built-ins);
    in both cases the repr is hashed so a fingerprint always exists.
    """
    try:
        text = inspect.getsource(obj)
    except (OSError, TypeError):
        text = repr(obj)
    return _sha256_text(text)
|
|
|
|
|
|
def dataset_manifest_path(dataset_dir: str | Path) -> Path:
    """Return the manifest file location inside *dataset_dir*."""
    return Path(dataset_dir).joinpath(MANIFEST_NAME)
|
|
|
|
|
|
def dataset_spec_hash(spec: dict) -> str:
    """Hash *spec* via its canonical JSON serialization."""
    canonical = _stable_json(spec)
    return _sha256_text(canonical)
|
|
|
|
|
|
def build_dataset_spec(
    generator_cls,
    *,
    config_key: str,
    config_snapshot: dict,
) -> dict:
    """Build a dataset spec dict with a stable hashable structure.

    The spec records the generator's fully-qualified name, a hash of its
    source code, and a snapshot of the relevant config, so that any rule
    change invalidates previously generated data.
    """
    qualified_name = f"{generator_cls.__module__}.{generator_cls.__name__}"
    spec = {
        "config_key": config_key,
        "generator": qualified_name,
        "generator_source_hash": _source_hash(generator_cls),
        "config_snapshot": config_snapshot,
    }
    return spec
|
|
|
|
|
|
def load_dataset_manifest(dataset_dir: str | Path) -> dict | None:
    """Read the manifest from *dataset_dir*; return None when absent."""
    manifest_file = dataset_manifest_path(dataset_dir)
    if manifest_file.exists():
        with manifest_file.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    return None
|
|
|
|
|
|
def write_dataset_manifest(
    dataset_dir: str | Path,
    *,
    spec: dict,
    sample_count: int,
    adopted_existing: bool,
) -> dict:
    """Persist a manifest describing the dataset currently on disk.

    Records the spec, its hash, the sample count, and whether existing
    files were adopted rather than regenerated. Returns the manifest.
    """
    manifest = {
        "version": 1,
        "spec": spec,
        "spec_hash": dataset_spec_hash(spec),
        "sample_count": sample_count,
        "adopted_existing": adopted_existing,
    }
    target = dataset_manifest_path(dataset_dir)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as handle:
        json.dump(manifest, handle, ensure_ascii=True, indent=2, sort_keys=True)
        handle.write("\n")
    return manifest
|
|
|
|
|
|
def labels_cover_tokens(files: list[Path], required_tokens: tuple[str, ...]) -> bool:
    """Check whether filename labels collectively cover every token.

    A file's label is its stem with the trailing ``_<suffix>`` removed.
    Returns True as soon as each required token has appeared (as a
    substring) in at least one label.
    """
    missing = set(required_tokens)
    if not missing:
        return True

    for file_path in files:
        label = file_path.stem.rsplit("_", 1)[0]
        missing = {token for token in missing if token not in label}
        if not missing:
            return True
    return False
|
|
|
|
|
|
def _count_matches(count: int, *, exact_count: int | None, min_count: int | None) -> bool:
|
|
if exact_count is not None and count != exact_count:
|
|
return False
|
|
if min_count is not None and count < min_count:
|
|
return False
|
|
return True
|
|
|
|
|
|
def _dataset_valid(
    files: list[Path],
    *,
    exact_count: int | None,
    min_count: int | None,
    validator: Callable[[list[Path]], bool] | None,
) -> bool:
    """Check the file list against count constraints, then the validator.

    The optional *validator* only runs once the counts pass.
    """
    if not _count_matches(len(files), exact_count=exact_count, min_count=min_count):
        return False
    return validator is None or validator(files)
|
|
|
|
|
|
def clear_generated_dataset(dataset_dir: str | Path) -> None:
    """Delete all generated ``*.png`` samples and the manifest, if any.

    Non-recursive: only PNGs directly inside *dataset_dir* are removed.
    """
    dataset_dir = Path(dataset_dir)
    for sample in dataset_dir.glob("*.png"):
        sample.unlink()
    # missing_ok replaces the original exists()/unlink() pair: it avoids
    # a TOCTOU race and a redundant stat call (Path.unlink, Python 3.8+).
    dataset_manifest_path(dataset_dir).unlink(missing_ok=True)
|
|
|
|
|
|
def ensure_synthetic_dataset(
    dataset_dir: str | Path,
    *,
    generator_cls,
    spec: dict,
    gen_count: int,
    exact_count: int | None = None,
    min_count: int | None = None,
    validator: Callable[[list[Path]], bool] | None = None,
    adopt_if_missing: bool = False,
) -> dict:
    """Ensure the synthetic data on disk matches the current generation rules.

    Reuses the existing dataset when the manifest's spec hash and sample
    count match what is on disk and all count/validator checks pass;
    otherwise wipes and regenerates via ``generator_cls``.

    Returns:
        {
            "manifest": dict,
            "sample_count": int,
            "refreshed": bool,   # True when data was regenerated
            "adopted": bool,     # True when existing files were adopted
        }

    Raises:
        RuntimeError: if the freshly generated dataset still fails the
            count/validator checks.
    """
    dataset_dir = Path(dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    # Take stock of what is currently on disk.
    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    counts_ok = _count_matches(sample_count, exact_count=exact_count, min_count=min_count)
    validator_ok = _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    )
    manifest = load_dataset_manifest(dataset_dir)
    spec_hash = dataset_spec_hash(spec)

    # The manifest is trustworthy only if it exists, matches the current
    # spec hash, and its recorded sample count equals the on-disk count.
    manifest_ok = (
        manifest is not None
        and manifest.get("spec_hash") == spec_hash
        and manifest.get("sample_count") == sample_count
    )

    # Fast path: existing data matches the current rules — reuse as-is.
    if manifest_ok and counts_ok and validator_ok:
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": False,
        }

    # Optionally adopt valid pre-existing files that merely lack a manifest
    # (e.g. data generated before manifests were introduced).
    if manifest is None and adopt_if_missing and counts_ok and validator_ok:
        manifest = write_dataset_manifest(
            dataset_dir,
            spec=spec,
            sample_count=sample_count,
            adopted_existing=True,
        )
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": True,
        }

    # Otherwise regenerate from scratch: wipe stale samples first so the
    # new dataset is not mixed with leftovers.
    clear_generated_dataset(dataset_dir)
    gen = generator_cls()
    gen.generate_dataset(gen_count, str(dataset_dir))

    # Re-scan and validate the freshly generated output before trusting it.
    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    if not _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    ):
        raise RuntimeError(
            f"生成后的数据集不符合要求: {dataset_dir} (count={sample_count})"
        )
    manifest = write_dataset_manifest(
        dataset_dir,
        spec=spec,
        sample_count=sample_count,
        adopted_existing=False,
    )
    return {
        "manifest": manifest,
        "sample_count": sample_count,
        "refreshed": True,
        "adopted": False,
    }
|