188 lines
6.4 KiB
Python
188 lines
6.4 KiB
Python
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
|
|
import inference.fun_captcha as fun_module
|
|
from inference.fun_captcha import FunCaptchaRollballPipeline
|
|
from inference.model_metadata import write_model_metadata
|
|
from models.fun_captcha_siamese import FunCaptchaSiamese
|
|
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform
|
|
|
|
|
|
def _build_rollball_image(path: Path, answer_idx: int = 2):
    """Save a synthetic 800x400 rollball challenge image to ``path``.

    The top row holds four solid-colour 200x200 candidate tiles side by side.
    The bottom-left 200x200 square is the reference tile, painted in the
    colour of candidate ``answer_idx``. The rest of the canvas stays light grey.
    """
    palette = [
        (255, 80, 80),
        (80, 255, 80),
        (80, 80, 255),
        (255, 220, 80),
    ]
    tile_px = 200
    canvas = Image.new("RGB", (800, 400), color=(245, 245, 245))

    # Candidate strip along the top edge.
    for slot, rgb in enumerate(palette):
        swatch = Image.new("RGB", (tile_px, tile_px), color=rgb)
        canvas.paste(swatch, (slot * tile_px, 0))

    # Reference tile directly below the first candidate.
    reference_tile = Image.new("RGB", (tile_px, tile_px), color=palette[answer_idx])
    canvas.paste(reference_tile, (0, tile_px))
    canvas.save(path)
|
|
|
|
|
|
class TestFunCaptchaChallengeDataset:
    """Checks how a rollball challenge image is sliced into model inputs."""

    def test_dataset_splits_candidates_and_reference(self, tmp_path):
        # NOTE(review): the "2_" filename prefix presumably carries the answer
        # label read by the dataset; confirm against the loader.
        image_file = tmp_path / "2_demo.png"
        _build_rollball_image(image_file, answer_idx=2)

        rollball_task = FUN_CAPTCHA_TASKS["4_3d_rollball_animals"]
        rollball_size = IMAGE_SIZE["funcaptcha_rollball_animals"]
        ds = FunCaptchaChallengeDataset(
            dirs=[tmp_path],
            task_config=rollball_task,
            transform=build_val_rgb_transform(*rollball_size),
        )

        cand_batch, ref_tensor, label = ds[0]
        # Four candidate crops plus one reference crop, each 3x48x48.
        assert cand_batch.shape == (4, 3, 48, 48)
        assert ref_tensor.shape == (3, 48, 48)
        assert int(label.item()) == 2
|
|
|
|
|
|
class TestFunCaptchaSiamese:
    """Shape and size checks for the ``FunCaptchaSiamese`` network."""

    def test_forward_shape(self):
        """A batch of candidate/reference pairs yields one score per pair."""
        model = FunCaptchaSiamese()
        model.eval()
        candidate = torch.randn(5, 3, 48, 48)
        reference = torch.randn(5, 3, 48, 48)
        # Inference-only shape check: do not build an autograd graph.
        with torch.no_grad():
            out = model(candidate, reference)
        assert out.shape == (5, 1)

    def test_param_count_reasonable(self):
        """Guard against accidental growth of the model's parameter budget."""
        model = FunCaptchaSiamese()
        n = sum(p.numel() for p in model.parameters())
        assert n < 450_000, f"too many params: {n}"
|
|
|
|
|
|
class _FakeSessionOptions:
    """Stand-in for onnxruntime session options: only the two thread-count fields."""

    def __init__(self):
        # Both counts start at 0; code under test may overwrite them.
        self.inter_op_num_threads = self.intra_op_num_threads = 0
|
|
|
|
|
|
class _FakeInput:
    """Mimics an entry of ``InferenceSession.get_inputs()``: a name and an optional shape."""

    def __init__(self, name, shape=None):
        self.name, self.shape = name, shape
|
|
|
|
|
|
class _FakeSession:
    """Stand-in for an onnxruntime ``InferenceSession`` that returns canned logits.

    Every row scores 0.1, except row 2, which scores 0.95 when the batch has
    at least three rows. The most recent ``feed_dict`` is kept on
    ``last_feed_dict`` so tests can inspect what the pipeline sent.
    """

    def __init__(self, path, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        self.path = path
        self.last_feed_dict = None

    def get_inputs(self):
        # No shape is reported for either input.
        return [_FakeInput(input_name) for input_name in ("candidate", "reference")]

    def run(self, output_names, feed_dict):
        self.last_feed_dict = feed_dict
        first_tensor = next(iter(feed_dict.values()))
        rows = first_tensor.shape[0]
        scores = np.full((rows, 1), 0.1, dtype=np.float32)
        if rows > 2:
            # Make candidate 2 the unique best match.
            scores[2, 0] = 0.95
        return [scores]
|
|
|
|
|
|
class _FakeOrt:
    """Fake ``onnxruntime`` module returned by the patched ``_try_import_ort``.

    Provides only ``SessionOptions`` and ``InferenceSession``.
    """

    SessionOptions = _FakeSessionOptions
    InferenceSession = _FakeSession
|
|
|
|
|
|
class _Batch1FakeSession(_FakeSession):
    """Fake session whose inputs are fixed to batch size 1.

    ``get_inputs`` reports the shape ``[1, 3, 48, 48]`` for both inputs.
    ``run`` fails with ``AssertionError`` unless both the ``candidate`` and
    ``reference`` arrays are exactly ``(1, 3, 48, 48)``, and counts its
    invocations in ``run_calls``.
    """

    def __init__(self, path, *args, **kwargs):
        super().__init__(path, *args, **kwargs)
        self.run_calls = 0

    def get_inputs(self):
        # A single shape list is shared by both inputs.
        fixed_shape = [1, 3, 48, 48]
        return [
            _FakeInput(input_name, shape=fixed_shape)
            for input_name in ("candidate", "reference")
        ]

    def run(self, output_names, feed_dict):
        self.run_calls += 1
        # Fetch both inputs before checking either, so a missing key wins.
        pair = (feed_dict["candidate"], feed_dict["reference"])
        for tensor in pair:
            assert tensor.shape == (1, 3, 48, 48)
        return super().run(output_names, feed_dict)
|
|
|
|
|
|
class _Batch1FakeOrt:
    """Fake ``onnxruntime`` module whose sessions only accept batch-size-1 feeds."""

    SessionOptions = _FakeSessionOptions
    InferenceSession = _Batch1FakeSession
|
|
|
|
|
|
class TestFunCaptchaPipeline:
    """End-to-end tests for ``FunCaptchaRollballPipeline`` with onnxruntime faked out."""

    def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
        """A model plus metadata in ``models_dir`` drives the answer and preprocessing."""
        # Keep the test hermetic: a FUNCAPTCHA_ROLLBALL_MODEL_PATH exported in
        # the developer's shell or CI must not redirect the pipeline away from
        # the model written into tmp_path below.
        monkeypatch.delenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", raising=False)

        model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
        model_path.touch()  # empty placeholder; _FakeSession only stores the path
        write_model_metadata(
            model_path,
            {
                "model_name": "funcaptcha_rollball_animals",
                "task": "funcaptcha_siamese",
                "preprocess": "rgb_centered",
                "question": "4_3d_rollball_animals",
                "num_candidates": 4,
                "tile_size": [200, 200],
                "reference_box": [0, 200, 200, 400],
                "answer_index_base": 0,
                "input_shape": [3, 48, 48],
            },
        )
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        # The drawn answer is 1, but _FakeSession always scores candidate 2
        # highest. Expecting 2 shows the result comes from the model scores.
        sample_path = tmp_path / "1_demo.png"
        _build_rollball_image(sample_path, answer_idx=1)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
        result = pipeline.solve(sample_path)
        assert result["question"] == "4_3d_rollball_animals"
        assert result["objects"] == [2]
        assert result["result"] == "2"
        assert len(result["scores"]) == 4
        assert pipeline.preprocess_mode == "rgb_centered"

    def test_pipeline_uses_external_model_env_without_metadata(self, tmp_path, monkeypatch):
        """FUNCAPTCHA_ROLLBALL_MODEL_PATH is used when ``models_dir`` has no model."""
        external_model = tmp_path / "external_rollball.onnx"
        external_model.touch()
        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        # Uniform grey image: every channel of every pixel is 128.
        image = Image.new("RGB", (800, 400), color=(128, 128, 128))
        sample_path = tmp_path / "0_demo.png"
        image.save(sample_path)

        empty_models_dir = tmp_path / "missing_models"
        pipeline = FunCaptchaRollballPipeline(models_dir=empty_models_dir)
        result = pipeline.solve(sample_path)

        assert result["objects"] == [2]
        assert pipeline.model_path == external_model
        # Without metadata the expected mode is "rgb_255": the check below
        # confirms pixels are only divided by 255 (no centering).
        assert pipeline.preprocess_mode == "rgb_255"
        candidate = pipeline.session.last_feed_dict["candidate"]
        assert candidate.shape == (4, 3, 48, 48)
        assert candidate[0, 0, 0, 0] == pytest.approx(128 / 255.0, abs=1e-6)

    def test_pipeline_handles_external_fixed_batch_model(self, tmp_path, monkeypatch):
        """A model whose inputs are pinned to batch size 1 is run once per candidate."""
        external_model = tmp_path / "external_rollball.onnx"
        external_model.touch()
        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _Batch1FakeOrt)

        sample_path = tmp_path / "0_demo.png"
        _build_rollball_image(sample_path, answer_idx=0)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path / "missing_models")
        result = pipeline.solve(sample_path)

        # Every batch-1 call scores 0.1, so all four candidates tie. The
        # expected answer 0 is the first of the tied candidates.
        assert result["objects"] == [0]
        assert pipeline.session.run_calls == 4
|