"""Tests for the FunCaptcha rollball dataset, Siamese model and ONNX pipeline."""

from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image

from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
import inference.fun_captcha as fun_module
from inference.fun_captcha import FunCaptchaRollballPipeline
from inference.model_metadata import write_model_metadata
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform

# Candidate position the fake sessions score highest, plus the two logit
# values they emit. Expectations below are derived from these so the fakes
# and the assertions cannot drift apart.
_FAKE_BEST_INDEX = 2
_FAKE_BASE_SCORE = 0.1
_FAKE_BEST_SCORE = 0.95


def _build_rollball_image(path: Path, answer_idx: int = 2):
    """Write a synthetic 800x400 rollball challenge image to *path*.

    The top row holds four 200x200 solid-colour candidate tiles at
    x = 0, 200, 400, 600. The bottom-left 200x200 tile (at (0, 200)) is the
    reference and is painted with the colour of candidate ``answer_idx``.
    Everything else is light grey (245, 245, 245).
    """
    colors = [
        (255, 80, 80),
        (80, 255, 80),
        (80, 80, 255),
        (255, 220, 80),
    ]
    image = Image.new("RGB", (800, 400), color=(245, 245, 245))
    for idx, color in enumerate(colors):
        tile = Image.new("RGB", (200, 200), color=color)
        image.paste(tile, (idx * 200, 0))
    reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
    image.paste(reference, (0, 200))
    image.save(path)


class TestFunCaptchaChallengeDataset:
    """Dataset-level checks for splitting a challenge image into model inputs."""

    def test_dataset_splits_candidates_and_reference(self, tmp_path):
        # The leading "2_" in the filename is the label the dataset is
        # expected to parse as the answer index.
        sample_path = tmp_path / "2_demo.png"
        _build_rollball_image(sample_path, answer_idx=2)

        dataset = FunCaptchaChallengeDataset(
            dirs=[tmp_path],
            task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
            transform=build_val_rgb_transform(*IMAGE_SIZE["funcaptcha_rollball_animals"]),
        )
        candidates, reference, answer_idx = dataset[0]

        assert candidates.shape == (4, 3, 48, 48)
        assert reference.shape == (3, 48, 48)
        assert int(answer_idx.item()) == 2


class TestFunCaptchaSiamese:
    """Shape and size sanity checks for the Siamese scoring network."""

    def test_forward_shape(self):
        model = FunCaptchaSiamese()
        model.eval()
        candidate = torch.randn(5, 3, 48, 48)
        reference = torch.randn(5, 3, 48, 48)
        out = model(candidate, reference)
        assert out.shape == (5, 1)

    def test_param_count_reasonable(self):
        model = FunCaptchaSiamese()
        n = sum(p.numel() for p in model.parameters())
        assert n < 450_000, f"too many params: {n}"


class _FakeSessionOptions:
    """Stand-in for ``onnxruntime.SessionOptions`` with the thread knobs."""

    def __init__(self):
        self.inter_op_num_threads = 0
        self.intra_op_num_threads = 0


class _FakeInput:
    """Stand-in for an ONNX input descriptor (``name`` and optional ``shape``)."""

    def __init__(self, name, shape=None):
        self.name = name
        self.shape = shape


class _FakeSession:
    """Fake ``InferenceSession`` accepting any batch size.

    ``run`` records the feed dict and returns ``(batch, 1)`` float32 logits
    filled with ``_FAKE_BASE_SCORE``. Row ``_FAKE_BEST_INDEX`` gets
    ``_FAKE_BEST_SCORE`` whenever the batch is large enough to contain it.
    The image content is ignored.
    """

    def __init__(self, path, *args, **kwargs):
        self.path = path
        self.last_feed_dict = None

    def get_inputs(self):
        return [_FakeInput("candidate"), _FakeInput("reference")]

    def run(self, output_names, feed_dict):
        self.last_feed_dict = feed_dict
        batch_size = next(iter(feed_dict.values())).shape[0]
        logits = np.full((batch_size, 1), _FAKE_BASE_SCORE, dtype=np.float32)
        if batch_size > _FAKE_BEST_INDEX:
            logits[_FAKE_BEST_INDEX, 0] = _FAKE_BEST_SCORE
        return [logits]


class _FakeOrt:
    """Fake ``onnxruntime`` module exposing the dynamic-batch fake session."""

    SessionOptions = _FakeSessionOptions
    InferenceSession = _FakeSession


class _Batch1FakeSession(_FakeSession):
    """Fake session whose declared inputs are fixed to batch size 1.

    Each ``run`` call must receive exactly one candidate/reference pair. The
    returned logit depends on the call's position: call ``_FAKE_BEST_INDEX``
    (0-based) scores highest, the rest score ``_FAKE_BASE_SCORE``.
    Previously every call scored the same, so any answer-0 expectation passed
    through argmax tie-breaking alone. Now the pipeline only returns the
    expected index if it keeps per-call scores aligned with candidates.
    NOTE(review): this assumes the pipeline calls ``run`` once per candidate
    in candidate order.
    """

    def __init__(self, path, *args, **kwargs):
        super().__init__(path, *args, **kwargs)
        self.run_calls = 0

    def get_inputs(self):
        shape = [1, 3, 48, 48]
        return [_FakeInput("candidate", shape=shape), _FakeInput("reference", shape=shape)]

    def run(self, output_names, feed_dict):
        call_index = self.run_calls
        self.run_calls += 1
        assert feed_dict["candidate"].shape == (1, 3, 48, 48)
        assert feed_dict["reference"].shape == (1, 3, 48, 48)
        self.last_feed_dict = feed_dict
        score = _FAKE_BEST_SCORE if call_index == _FAKE_BEST_INDEX else _FAKE_BASE_SCORE
        return [np.full((1, 1), score, dtype=np.float32)]


class _Batch1FakeOrt:
    """Fake ``onnxruntime`` module exposing the fixed-batch-1 fake session."""

    SessionOptions = _FakeSessionOptions
    InferenceSession = _Batch1FakeSession


class TestFunCaptchaPipeline:
    """End-to-end pipeline tests with ``onnxruntime`` replaced by fakes."""

    def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
        model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
        model_path.touch()
        write_model_metadata(
            model_path,
            {
                "model_name": "funcaptcha_rollball_animals",
                "task": "funcaptcha_siamese",
                "preprocess": "rgb_centered",
                "question": "4_3d_rollball_animals",
                "num_candidates": 4,
                "tile_size": [200, 200],
                "reference_box": [0, 200, 200, 400],
                "answer_index_base": 0,
                "input_shape": [3, 48, 48],
            },
        )
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        # The image's true answer (1) deliberately differs from the index the
        # fake model favours, so the assertions prove the result comes from
        # the model scores.
        sample_path = tmp_path / "1_demo.png"
        _build_rollball_image(sample_path, answer_idx=1)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
        result = pipeline.solve(sample_path)

        assert result["question"] == "4_3d_rollball_animals"
        assert result["objects"] == [_FAKE_BEST_INDEX]
        assert result["result"] == str(_FAKE_BEST_INDEX)
        assert len(result["scores"]) == 4
        assert pipeline.preprocess_mode == "rgb_centered"

    def test_pipeline_uses_external_model_env_without_metadata(self, tmp_path, monkeypatch):
        external_model = tmp_path / "external_rollball.onnx"
        external_model.touch()
        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        # Uniform grey image: every pixel of every input should be 128 / 255
        # after the default "rgb_255" preprocessing.
        image = Image.new("RGB", (800, 400), color=(128, 128, 128))
        sample_path = tmp_path / "0_demo.png"
        image.save(sample_path)

        empty_models_dir = tmp_path / "missing_models"
        pipeline = FunCaptchaRollballPipeline(models_dir=empty_models_dir)
        result = pipeline.solve(sample_path)

        assert result["objects"] == [_FAKE_BEST_INDEX]
        assert pipeline.model_path == external_model
        assert pipeline.preprocess_mode == "rgb_255"
        candidate = pipeline.session.last_feed_dict["candidate"]
        assert candidate.shape == (4, 3, 48, 48)
        assert candidate[0, 0, 0, 0] == pytest.approx(128 / 255.0, abs=1e-6)

    def test_pipeline_handles_external_fixed_batch_model(self, tmp_path, monkeypatch):
        external_model = tmp_path / "external_rollball.onnx"
        external_model.touch()
        monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _Batch1FakeOrt)

        sample_path = tmp_path / "0_demo.png"
        _build_rollball_image(sample_path, answer_idx=0)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path / "missing_models")
        result = pipeline.solve(sample_path)

        # Index 0 would also be what argmax tie-breaking yields. Expecting the
        # fake's favoured index instead shows the four single-pair calls were
        # scored and reassembled in candidate order.
        assert result["objects"] == [_FAKE_BEST_INDEX]
        assert len(result["scores"]) == 4
        assert pipeline.session.run_calls == 4