Support external FunCaptcha ONNX fallback
This commit is contained in:
@@ -69,18 +69,21 @@ class _FakeSessionOptions:
|
||||
|
||||
|
||||
class _FakeInput:
|
||||
def __init__(self, name):
|
||||
def __init__(self, name, shape=None):
|
||||
self.name = name
|
||||
self.shape = shape
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.path = path
|
||||
self.last_feed_dict = None
|
||||
|
||||
def get_inputs(self):
|
||||
return [_FakeInput("candidate"), _FakeInput("reference")]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
self.last_feed_dict = feed_dict
|
||||
batch_size = next(iter(feed_dict.values())).shape[0]
|
||||
logits = np.full((batch_size, 1), 0.1, dtype=np.float32)
|
||||
if batch_size >= 3:
|
||||
@@ -93,6 +96,29 @@ class _FakeOrt:
|
||||
InferenceSession = _FakeSession
|
||||
|
||||
|
||||
class _Batch1FakeSession(_FakeSession):
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
super().__init__(path, *args, **kwargs)
|
||||
self.run_calls = 0
|
||||
|
||||
def get_inputs(self):
|
||||
shape = [1, 3, 48, 48]
|
||||
return [_FakeInput("candidate", shape=shape), _FakeInput("reference", shape=shape)]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
self.run_calls += 1
|
||||
candidate = feed_dict["candidate"]
|
||||
reference = feed_dict["reference"]
|
||||
assert candidate.shape == (1, 3, 48, 48)
|
||||
assert reference.shape == (1, 3, 48, 48)
|
||||
return super().run(output_names, feed_dict)
|
||||
|
||||
|
||||
class _Batch1FakeOrt:
|
||||
SessionOptions = _FakeSessionOptions
|
||||
InferenceSession = _Batch1FakeSession
|
||||
|
||||
|
||||
class TestFunCaptchaPipeline:
|
||||
def test_pipeline_returns_best_object_index(self, tmp_path, monkeypatch):
|
||||
model_path = tmp_path / "funcaptcha_rollball_animals.onnx"
|
||||
@@ -102,6 +128,7 @@ class TestFunCaptchaPipeline:
|
||||
{
|
||||
"model_name": "funcaptcha_rollball_animals",
|
||||
"task": "funcaptcha_siamese",
|
||||
"preprocess": "rgb_centered",
|
||||
"question": "4_3d_rollball_animals",
|
||||
"num_candidates": 4,
|
||||
"tile_size": [200, 200],
|
||||
@@ -121,3 +148,40 @@ class TestFunCaptchaPipeline:
|
||||
assert result["objects"] == [2]
|
||||
assert result["result"] == "2"
|
||||
assert len(result["scores"]) == 4
|
||||
assert pipeline.preprocess_mode == "rgb_centered"
|
||||
|
||||
def test_pipeline_uses_external_model_env_without_metadata(self, tmp_path, monkeypatch):
|
||||
external_model = tmp_path / "external_rollball.onnx"
|
||||
external_model.touch()
|
||||
monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
|
||||
monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
image = Image.new("RGB", (800, 400), color=(128, 128, 128))
|
||||
sample_path = tmp_path / "0_demo.png"
|
||||
image.save(sample_path)
|
||||
|
||||
empty_models_dir = tmp_path / "missing_models"
|
||||
pipeline = FunCaptchaRollballPipeline(models_dir=empty_models_dir)
|
||||
result = pipeline.solve(sample_path)
|
||||
|
||||
assert result["objects"] == [2]
|
||||
assert pipeline.model_path == external_model
|
||||
assert pipeline.preprocess_mode == "rgb_255"
|
||||
candidate = pipeline.session.last_feed_dict["candidate"]
|
||||
assert candidate.shape == (4, 3, 48, 48)
|
||||
assert candidate[0, 0, 0, 0] == pytest.approx(128 / 255.0, abs=1e-6)
|
||||
|
||||
def test_pipeline_handles_external_fixed_batch_model(self, tmp_path, monkeypatch):
|
||||
external_model = tmp_path / "external_rollball.onnx"
|
||||
external_model.touch()
|
||||
monkeypatch.setenv("FUNCAPTCHA_ROLLBALL_MODEL_PATH", str(external_model))
|
||||
monkeypatch.setattr(fun_module, "_try_import_ort", lambda: _Batch1FakeOrt)
|
||||
|
||||
sample_path = tmp_path / "0_demo.png"
|
||||
_build_rollball_image(sample_path, answer_idx=0)
|
||||
|
||||
pipeline = FunCaptchaRollballPipeline(models_dir=tmp_path / "missing_models")
|
||||
result = pipeline.solve(sample_path)
|
||||
|
||||
assert result["objects"] == [0]
|
||||
assert pipeline.session.run_calls == 4
|
||||
|
||||
Reference in New Issue
Block a user