"""Tests for the FunCaptcha rollball dataset, Siamese model and inference pipeline."""

from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image

from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
import inference.fun_captcha as fun_module
from inference.fun_captcha import FunCaptchaRollballPipeline
from inference.model_metadata import write_model_metadata
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform


def _build_rollball_image(path: Path, answer_idx: int = 2) -> None:
    """Write a synthetic 800x400 (width x height) rollball challenge image to *path*.

    The top row holds four solid-colour 200x200 candidate tiles pasted at
    x = 0, 200, 400, 600. A 200x200 reference tile is pasted at (0, 200), i.e.
    the box ``(0, 200, 200, 400)``, and repeats the colour of candidate
    ``answer_idx``. The rest of the canvas is light grey.
    """
    colors = [
        (255, 80, 80),
        (80, 255, 80),
        (80, 80, 255),
        (255, 220, 80),
    ]
    image = Image.new("RGB", (800, 400), color=(245, 245, 245))
    for idx, color in enumerate(colors):
        tile = Image.new("RGB", (200, 200), color=color)
        image.paste(tile, (idx * 200, 0))
    reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
    image.paste(reference, (0, 200))
    image.save(path)


class TestFunCaptchaChallengeDataset:
    def test_dataset_splits_candidates_and_reference(self, tmp_path):
        """A single sample yields 4 candidate crops, 1 reference crop and its label.

        The ``2_`` filename prefix and the reference colour both point at
        candidate 2; presumably the dataset reads the label from the prefix
        (TODO confirm against ``FunCaptchaChallengeDataset``).
        """
        sample_path = tmp_path / "2_demo.png"
        _build_rollball_image(sample_path, answer_idx=2)

        dataset = FunCaptchaChallengeDataset(
            dirs=[tmp_path],
            task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
            transform=build_val_rgb_transform(*IMAGE_SIZE["funcaptcha_rollball_animals"]),
        )
        candidates, reference, answer_idx = dataset[0]

        # 48x48 assumes IMAGE_SIZE["funcaptcha_rollball_animals"] resolves to 48x48.
        assert candidates.shape == (4, 3, 48, 48)
        assert reference.shape == (3, 48, 48)
        assert int(answer_idx.item()) == 2


class TestFunCaptchaSiamese:
    def test_forward_shape(self):
        """A batch of candidate/reference pairs produces one logit per pair."""
        model = FunCaptchaSiamese()
        model.eval()
        candidate = torch.randn(5, 3, 48, 48)
        reference = torch.randn(5, 3, 48, 48)
        # Shape check only: no autograd graph is needed.
        with torch.no_grad():
            out = model(candidate, reference)
        assert out.shape == (5, 1)

    def test_param_count_reasonable(self):
        """Keep the model small enough for lightweight deployment."""
        model = FunCaptchaSiamese()
        n = sum(p.numel() for p in model.parameters())
        assert n < 450_000, f"too many params: {n}"


class _FakeSessionOptions:
    """Stand-in for ``onnxruntime.SessionOptions``; only the thread-count attributes."""

    def __init__(self):
        self.inter_op_num_threads = 0
        self.intra_op_num_threads = 0


class _FakeInput:
    """Stand-in for an ONNX Runtime input descriptor; only ``name`` is provided."""

    def __init__(self, name):
        self.name = name


class _FakeSession:
    """Stand-in for ``onnxruntime.InferenceSession`` with a fixed winning row.

    ``run`` returns a single ``(batch_size, 1)`` float32 logits array: ``0.1``
    for every row except ``best_index``, which gets ``0.95`` whenever the
    batch is large enough to contain that row.
    """

    # Row that receives the high logit; the pipeline test derives its expected
    # answer from this value so the two cannot drift apart.
    best_index = 2

    def __init__(self, path, *args, **kwargs):
        self.path = path

    def get_inputs(self):
        return [_FakeInput("candidate"), _FakeInput("reference")]

    def run(self, output_names, feed_dict):
        # Batch size is taken from whichever input array is fed first.
        batch_size = next(iter(feed_dict.values())).shape[0]
        logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
        if batch_size > self.best_index:
            logits[self.best_index, 0] = 0.95
        return [logits]


class _FakeOrt:
    """Stand-in for the ``onnxruntime`` module returned by ``_try_import_ort``."""

    SessionOptions = _FakeSessionOptions
    InferenceSession = _FakeSession


class TestFunCaptchaPipeline:
    def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
        """The pipeline reports the candidate the (fake) model scores highest.

        The sample image's reference matches candidate 1 and the file is named
        ``1_demo.png``. The fake session, however, scores candidate
        ``_FakeSession.best_index`` (2) highest. The expected answer is
        therefore the one driven by the model scores.
        """
        model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
        # Empty placeholder: the fake session never reads the model file.
        model_path.touch()
        write_model_metadata(
            model_path,
            {
                "model_name": "funcaptcha_rollball_animals",
                "task": "funcaptcha_siamese",
                "question": "4_3d_rollball_animals",
                "num_candidates": 4,
                "tile_size": [200, 200],
                "reference_box": [0, 200, 200, 400],
                "answer_index_base": 0,
                "input_shape": [3, 48, 48],
            },
        )
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        sample_path = tmp_path / "1_demo.png"
        _build_rollball_image(sample_path, answer_idx=1)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
        result = pipeline.solve(sample_path)

        # answer_index_base is 0, so the reported index equals the fake's row.
        best = _FakeSession.best_index
        assert result["question"] == "4_3d_rollball_animals"
        assert result["objects"] == [best]
        assert result["result"] == str(best)
        assert len(result["scores"]) == 4