Align task API and add FunCaptcha support
This commit is contained in:
226
training/data_fingerprint.py
Normal file
226
training/data_fingerprint.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
合成数据集指纹与清单辅助工具。
|
||||
|
||||
用于识别“样本数量足够但生成规则已变化”的情况,避免静默复用过期数据。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
MANIFEST_NAME = ".dataset_meta.json"
|
||||
|
||||
|
||||
def _stable_json(data: dict) -> str:
|
||||
return json.dumps(data, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _sha256_text(text: str) -> str:
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _source_hash(obj) -> str:
|
||||
try:
|
||||
source = inspect.getsource(obj)
|
||||
except (OSError, TypeError):
|
||||
source = repr(obj)
|
||||
return _sha256_text(source)
|
||||
|
||||
|
||||
def dataset_manifest_path(dataset_dir: str | Path) -> Path:
    """Return the path of the manifest file inside *dataset_dir*."""
    return Path(dataset_dir).joinpath(MANIFEST_NAME)
|
||||
|
||||
|
||||
def dataset_spec_hash(spec: dict) -> str:
    """Return a stable hex SHA-256 digest of the spec dict."""
    # Canonical JSON (sorted keys, compact separators, ASCII-only) makes the
    # digest independent of dict insertion order.
    canonical = json.dumps(spec, ensure_ascii=True, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def build_dataset_spec(
    generator_cls,
    *,
    config_key: str,
    config_snapshot: dict,
) -> dict:
    """Build a dataset spec dict that can be hashed deterministically."""
    qualified_name = f"{generator_cls.__module__}.{generator_cls.__name__}"
    spec = {
        "config_key": config_key,
        "generator": qualified_name,
        # Hashing the generator's source catches rule changes even when the
        # configuration snapshot is unchanged.
        "generator_source_hash": _source_hash(generator_cls),
        "config_snapshot": config_snapshot,
    }
    return spec
|
||||
|
||||
|
||||
def load_dataset_manifest(dataset_dir: str | Path) -> dict | None:
    """Load the manifest for *dataset_dir*, or return None if absent."""
    manifest_path = dataset_manifest_path(dataset_dir)
    if not manifest_path.exists():
        return None
    return json.loads(manifest_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def write_dataset_manifest(
    dataset_dir: str | Path,
    *,
    spec: dict,
    sample_count: int,
    adopted_existing: bool,
) -> dict:
    """Write a manifest describing the current dataset and return it.

    The manifest records the spec, its hash, the number of samples, and
    whether pre-existing files were adopted rather than regenerated.
    """
    manifest = {
        "version": 1,
        "spec": spec,
        "spec_hash": dataset_spec_hash(spec),
        "sample_count": sample_count,
        "adopted_existing": adopted_existing,
    }
    manifest_path = dataset_manifest_path(dataset_dir)
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(manifest, ensure_ascii=True, indent=2, sort_keys=True)
    # Trailing newline keeps the file friendly to line-based diff tools.
    manifest_path.write_text(serialized + "\n", encoding="utf-8")
    return manifest
|
||||
|
||||
|
||||
def labels_cover_tokens(files: list[Path], required_tokens: tuple[str, ...]) -> bool:
    """Check that every required token appears in at least one file-name label.

    The label is the file stem with its trailing ``_<suffix>`` part removed;
    a token counts as covered when it occurs as a substring of any label.
    """
    missing = set(required_tokens)
    for file_path in files:
        if not missing:
            break
        label = file_path.stem.rsplit("_", 1)[0]
        missing = {token for token in missing if token not in label}
    return not missing
|
||||
|
||||
|
||||
def _count_matches(count: int, *, exact_count: int | None, min_count: int | None) -> bool:
|
||||
if exact_count is not None and count != exact_count:
|
||||
return False
|
||||
if min_count is not None and count < min_count:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _dataset_valid(
|
||||
files: list[Path],
|
||||
*,
|
||||
exact_count: int | None,
|
||||
min_count: int | None,
|
||||
validator: Callable[[list[Path]], bool] | None,
|
||||
) -> bool:
|
||||
counts_ok = _count_matches(len(files), exact_count=exact_count, min_count=min_count)
|
||||
if not counts_ok:
|
||||
return False
|
||||
if validator is None:
|
||||
return True
|
||||
return validator(files)
|
||||
|
||||
|
||||
def clear_generated_dataset(dataset_dir: str | Path) -> None:
    """Delete all generated PNG samples and the manifest in *dataset_dir*."""
    directory = Path(dataset_dir)
    for sample in directory.glob("*.png"):
        sample.unlink()
    manifest_path = dataset_manifest_path(directory)
    if manifest_path.exists():
        manifest_path.unlink()
|
||||
|
||||
|
||||
def ensure_synthetic_dataset(
    dataset_dir: str | Path,
    *,
    generator_cls,
    spec: dict,
    gen_count: int,
    exact_count: int | None = None,
    min_count: int | None = None,
    validator: Callable[[list[Path]], bool] | None = None,
    adopt_if_missing: bool = False,
) -> dict:
    """
    Ensure the synthetic data matches the current generation rules.

    Existing data is reused when the on-disk manifest matches *spec* and the
    PNG samples satisfy the count/validator constraints; otherwise the
    directory is cleared and regenerated via ``generator_cls``.

    Returns:
        {
            "manifest": dict,
            "sample_count": int,
            "refreshed": bool,   # True when data was regenerated
            "adopted": bool,     # True when existing files were adopted as-is
        }
    """
    dataset_dir = Path(dataset_dir)
    dataset_dir.mkdir(parents=True, exist_ok=True)

    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    counts_ok = _count_matches(sample_count, exact_count=exact_count, min_count=min_count)
    validator_ok = _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    )
    manifest = load_dataset_manifest(dataset_dir)
    spec_hash = dataset_spec_hash(spec)

    # A manifest is trusted only when it records the same spec hash AND the
    # same number of samples currently on disk.
    manifest_ok = (
        manifest is not None
        and manifest.get("spec_hash") == spec_hash
        and manifest.get("sample_count") == sample_count
    )

    if manifest_ok and counts_ok and validator_ok:
        # Fast path: data already matches the current generation rules.
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": False,
        }

    # Optionally adopt valid pre-existing files that have no manifest yet
    # (e.g. data generated before manifests were introduced).
    if manifest is None and adopt_if_missing and counts_ok and validator_ok:
        manifest = write_dataset_manifest(
            dataset_dir,
            spec=spec,
            sample_count=sample_count,
            adopted_existing=True,
        )
        return {
            "manifest": manifest,
            "sample_count": sample_count,
            "refreshed": False,
            "adopted": True,
        }

    # Stale or invalid data: wipe everything and regenerate from scratch.
    clear_generated_dataset(dataset_dir)
    gen = generator_cls()
    gen.generate_dataset(gen_count, str(dataset_dir))

    # Re-scan and re-validate what the generator actually produced.
    files = sorted(dataset_dir.glob("*.png"))
    sample_count = len(files)
    if not _dataset_valid(
        files,
        exact_count=exact_count,
        min_count=min_count,
        validator=validator,
    ):
        raise RuntimeError(
            f"生成后的数据集不符合要求: {dataset_dir} (count={sample_count})"
        )
    manifest = write_dataset_manifest(
        dataset_dir,
        spec=spec,
        sample_count=sample_count,
        adopted_existing=False,
    )
    return {
        "manifest": manifest,
        "sample_count": sample_count,
        "refreshed": True,
        "adopted": False,
    }
|
||||
Reference in New Issue
Block a user