Align task API and add FunCaptcha support
This commit is contained in:
123
tests/test_funcaptcha.py
Normal file
123
tests/test_funcaptcha.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
|
||||
import inference.fun_captcha as fun_module
|
||||
from inference.fun_captcha import FunCaptchaRollballPipeline
|
||||
from inference.model_metadata import write_model_metadata
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform
|
||||
|
||||
|
||||
def _build_rollball_image(path: Path, answer_idx: int = 2):
|
||||
colors = [
|
||||
(255, 80, 80),
|
||||
(80, 255, 80),
|
||||
(80, 80, 255),
|
||||
(255, 220, 80),
|
||||
]
|
||||
image = Image.new("RGB", (800, 400), color=(245, 245, 245))
|
||||
for idx, color in enumerate(colors):
|
||||
tile = Image.new("RGB", (200, 200), color=color)
|
||||
image.paste(tile, (idx * 200, 0))
|
||||
|
||||
reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
|
||||
image.paste(reference, (0, 200))
|
||||
image.save(path)
|
||||
|
||||
|
||||
class TestFunCaptchaChallengeDataset:
|
||||
def test_dataset_splits_candidates_and_reference(self, tmp_path):
|
||||
sample_path = tmp_path / "2_demo.png"
|
||||
_build_rollball_image(sample_path, answer_idx=2)
|
||||
|
||||
dataset = FunCaptchaChallengeDataset(
|
||||
dirs=[tmp_path],
|
||||
task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
|
||||
transform=build_val_rgb_transform(*IMAGE_SIZE["funcaptcha_rollball_animals"]),
|
||||
)
|
||||
|
||||
candidates, reference, answer_idx = dataset[0]
|
||||
assert candidates.shape == (4, 3, 48, 48)
|
||||
assert reference.shape == (3, 48, 48)
|
||||
assert int(answer_idx.item()) == 2
|
||||
|
||||
|
||||
class TestFunCaptchaSiamese:
|
||||
def test_forward_shape(self):
|
||||
model = FunCaptchaSiamese()
|
||||
model.eval()
|
||||
candidate = torch.randn(5, 3, 48, 48)
|
||||
reference = torch.randn(5, 3, 48, 48)
|
||||
out = model(candidate, reference)
|
||||
assert out.shape == (5, 1)
|
||||
|
||||
def test_param_count_reasonable(self):
|
||||
model = FunCaptchaSiamese()
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
assert n < 450_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class _FakeSessionOptions:
|
||||
def __init__(self):
|
||||
self.inter_op_num_threads = 0
|
||||
self.intra_op_num_threads = 0
|
||||
|
||||
|
||||
class _FakeInput:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.path = path
|
||||
|
||||
def get_inputs(self):
|
||||
return [_FakeInput("candidate"), _FakeInput("reference")]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
batch_size = next(iter(feed_dict.values())).shape[0]
|
||||
logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
|
||||
if batch_size >= 3:
|
||||
logits[2, 0] = 0.95
|
||||
return [logits]
|
||||
|
||||
|
||||
class _FakeOrt:
|
||||
SessionOptions = _FakeSessionOptions
|
||||
InferenceSession = _FakeSession
|
||||
|
||||
|
||||
class TestFunCaptchaPipeline:
|
||||
def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
|
||||
model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
|
||||
model_path.touch()
|
||||
write_model_metadata(
|
||||
model_path,
|
||||
{
|
||||
"model_name": "funcaptcha_rollball_animals",
|
||||
"task": "funcaptcha_siamese",
|
||||
"question": "4_3d_rollball_animals",
|
||||
"num_candidates": 4,
|
||||
"tile_size": [200, 200],
|
||||
"reference_box": [0, 200, 200, 400],
|
||||
"answer_index_base": 0,
|
||||
"input_shape": [3, 48, 48],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
sample_path = tmp_path / "1_demo.png"
|
||||
_build_rollball_image(sample_path, answer_idx=1)
|
||||
|
||||
pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(sample_path)
|
||||
assert result["question"] == "4_3d_rollball_animals"
|
||||
assert result["objects"] == [2]
|
||||
assert result["result"] == "2"
|
||||
assert len(result["scores"]) == 4
|
||||
Reference in New Issue
Block a user