Align task API and add FunCaptcha support
This commit is contained in:
@@ -8,11 +8,15 @@
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.model_metadata import write_model_metadata
|
||||
import inference.pipeline as pipeline_module
|
||||
from inference.pipeline import CaptchaPipeline
|
||||
|
||||
|
||||
@@ -97,6 +101,96 @@ class TestCTCGreedyDecode:
|
||||
assert result == "AA"
|
||||
|
||||
|
||||
class _FakeInput:
|
||||
name = "input"
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, path, *args, **kwargs):
|
||||
self.model_name = Path(path).name
|
||||
|
||||
def get_inputs(self):
|
||||
return [_FakeInput()]
|
||||
|
||||
def run(self, output_names, feed_dict):
|
||||
if self.model_name == "classifier.onnx":
|
||||
return [np.array([[0.1, 0.9]], dtype=np.float32)]
|
||||
if self.model_name == "normal.onnx":
|
||||
logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
|
||||
logits[0, 0, 2] = 10.0
|
||||
logits[1, 0, 0] = 10.0
|
||||
return [logits]
|
||||
if self.model_name == "threed_rotate.onnx":
|
||||
return [np.array([[0.25]], dtype=np.float32)]
|
||||
raise AssertionError(f"unexpected fake session: {self.model_name}")
|
||||
|
||||
|
||||
class _FakeSessionOptions:
|
||||
def __init__(self):
|
||||
self.inter_op_num_threads = 0
|
||||
self.intra_op_num_threads = 0
|
||||
|
||||
|
||||
class _FakeOrt:
|
||||
SessionOptions = _FakeSessionOptions
|
||||
InferenceSession = _FakeSession
|
||||
|
||||
|
||||
class TestPipelineMetadata:
|
||||
def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "classifier.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "classifier.onnx",
|
||||
{
|
||||
"model_name": "classifier",
|
||||
"task": "classifier",
|
||||
"class_names": ["math", "normal"],
|
||||
"input_shape": [1, 64, 128],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
captcha_type = pipeline.classify(Image.new("RGB", (32, 32), color="white"))
|
||||
assert captcha_type == "normal"
|
||||
|
||||
def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "normal.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "normal.onnx",
|
||||
{
|
||||
"model_name": "normal",
|
||||
"task": "ctc",
|
||||
"chars": "XYZ",
|
||||
"input_shape": [1, 40, 120],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="normal")
|
||||
assert result["raw"] == "Y"
|
||||
assert result["result"] == "Y"
|
||||
|
||||
def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "threed_rotate.onnx").touch()
|
||||
write_model_metadata(
|
||||
tmp_path / "threed_rotate.onnx",
|
||||
{
|
||||
"model_name": "threed_rotate",
|
||||
"task": "regression",
|
||||
"label_range": [100, 200],
|
||||
"input_shape": [1, 80, 80],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
|
||||
|
||||
pipeline = CaptchaPipeline(models_dir=tmp_path)
|
||||
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="3d_rotate")
|
||||
assert result["raw"] == "125.0"
|
||||
assert result["result"] == "125"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SlideSolver 测试
|
||||
# ============================================================
|
||||
|
||||
Reference in New Issue
Block a user