Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

123
tests/test_funcaptcha.py Normal file
View File

@@ -0,0 +1,123 @@
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from config import FUN_CAPTCHA_TASKS, IMAGE_SIZE
import inference.fun_captcha as fun_module
from inference.fun_captcha import FunCaptchaRollballPipeline
from inference.model_metadata import write_model_metadata
from models.fun_captcha_siamese import FunCaptchaSiamese
from training.dataset import FunCaptchaChallengeDataset, build_val_rgb_transform
def _build_rollball_image(path: Path, answer_idx: int = 2) -> None:
    """Write a synthetic rollball challenge image to *path*.

    The canvas is 800x400 on a light-grey background. The top row holds four
    200x200 solid-colour tiles placed side by side; a fifth 200x200 tile at
    (0, 200) repeats the colour of the tile at ``answer_idx``.

    Args:
        path: Destination file for the generated image.
        answer_idx: Index into the four tile colours whose colour is copied
            into the tile pasted at (0, 200).
    """
    colors = [
        (255, 80, 80),
        (80, 255, 80),
        (80, 80, 255),
        (255, 220, 80),
    ]
    image = Image.new("RGB", (800, 400), color=(245, 245, 245))

    # Top row: one tile per colour, laid out left to right.
    for idx, color in enumerate(colors):
        tile = Image.new("RGB", (200, 200), color=color)
        image.paste(tile, (idx * 200, 0))

    # Bottom-left tile: same colour as the tile selected by answer_idx.
    reference = Image.new("RGB", (200, 200), color=colors[answer_idx])
    image.paste(reference, (0, 200))
    image.save(path)
class TestFunCaptchaChallengeDataset:
    """Tests for ``FunCaptchaChallengeDataset`` sample loading."""

    def test_dataset_splits_candidates_and_reference(self, tmp_path):
        """One sample yields four candidates, one reference and the answer."""
        # NOTE(review): the "2_" file-name prefix matches answer_idx=2;
        # presumably the dataset reads the label from the name - confirm.
        sample_path = tmp_path / "2_demo.png"
        _build_rollball_image(sample_path, answer_idx=2)

        dataset = FunCaptchaChallengeDataset(
            dirs=[tmp_path],
            task_config=FUN_CAPTCHA_TASKS["4_3d_rollball_animals"],
            transform=build_val_rgb_transform(
                *IMAGE_SIZE["funcaptcha_rollball_animals"]
            ),
        )

        candidates, reference, answer_idx = dataset[0]

        assert candidates.shape == (4, 3, 48, 48)
        assert reference.shape == (3, 48, 48)
        assert int(answer_idx.item()) == 2
class TestFunCaptchaSiamese:
    """Tests for the ``FunCaptchaSiamese`` model."""

    def test_forward_shape(self):
        """A batch of five candidate/reference pairs yields a (5, 1) output."""
        model = FunCaptchaSiamese()
        model.eval()
        candidate = torch.randn(5, 3, 48, 48)
        reference = torch.randn(5, 3, 48, 48)

        # Only the output shape is checked, so skip autograd bookkeeping.
        with torch.no_grad():
            out = model(candidate, reference)

        assert out.shape == (5, 1)

    def test_param_count_reasonable(self):
        """The model stays below 450,000 parameters."""
        model = FunCaptchaSiamese()
        n = sum(p.numel() for p in model.parameters())
        assert n < 450_000, f"too many params: {n}"
class _FakeSessionOptions:
    """Stand-in for the ``SessionOptions`` attribute of the fake ORT module.

    Exposes the two thread-count attributes, both initialised to 0.
    """

    def __init__(self):
        self.inter_op_num_threads = 0
        self.intra_op_num_threads = 0
class _FakeInput:
    """Minimal input descriptor carrying only a ``name`` attribute."""

    def __init__(self, name):
        self.name = name
class _FakeSession:
    """Stand-in for the ``InferenceSession`` attribute of the fake ORT module.

    Produces fixed logits so a test can predict which index scores highest.
    """

    def __init__(self, path, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        self.path = path

    def get_inputs(self):
        """Return descriptors for the two inputs, candidate then reference."""
        return [_FakeInput("candidate"), _FakeInput("reference")]

    def run(self, output_names, feed_dict):
        """Return a one-element list holding a ``(batch, 1)`` float32 array.

        Every row is 0.1, except row 2 which is 0.95 when the batch has at
        least three rows. The batch size is the first dimension of the first
        value in ``feed_dict``; ``output_names`` is ignored.
        """
        batch_size = next(iter(feed_dict.values())).shape[0]
        logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
        if batch_size >= 3:
            logits[2, 0] = 0.95
        return [logits]
class _FakeOrt:
    """Namespace object returned in place of the real ORT module in tests.

    Provides only the ``SessionOptions`` and ``InferenceSession`` attributes,
    each bound to the corresponding fake class in this file.
    """

    SessionOptions = _FakeSessionOptions
    InferenceSession = _FakeSession
class TestFunCaptchaPipeline:
    """Tests for ``FunCaptchaRollballPipeline`` using a fake ORT backend."""

    def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
        """The pipeline reports the index that the session scores highest.

        The sample image is built with answer 1, but ``_FakeSession.run``
        always gives row 2 the highest logit, so the expected answer is 2.
        """
        # An empty file stands in for the ONNX model.
        model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
        model_path.touch()
        write_model_metadata(
            model_path,
            {
                "model_name": "funcaptcha_rollball_animals",
                "task": "funcaptcha_siamese",
                "question": "4_3d_rollball_animals",
                "num_candidates": 4,
                "tile_size": [200, 200],
                "reference_box": [0, 200, 200, 400],
                "answer_index_base": 0,
                "input_shape": [3, 48, 48],
            },
        )
        # Replace the ORT import hook so the fake session is used.
        monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)

        sample_path = tmp_path / "1_demo.png"
        _build_rollball_image(sample_path, answer_idx=1)

        pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path)
        result = pipeline.solve(sample_path)

        assert result["question"] == "4_3d_rollball_animals"
        assert result["objects"] == [2]
        assert result["result"] == "2"
        assert len(result["scores"]) == 4