Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -8,11 +8,15 @@
"""
import math
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
from inference.math_eval import eval_captcha_math
from inference.model_metadata import write_model_metadata
import inference.pipeline as pipeline_module
from inference.pipeline import CaptchaPipeline
@@ -97,6 +101,96 @@ class TestCTCGreedyDecode:
assert result == "AA"
class _FakeInput:
name = "input"
class _FakeSession:
def __init__(self, path, *args, **kwargs):
self.model_name = Path(path).name
def get_inputs(self):
return [_FakeInput()]
def run(self, output_names, feed_dict):
if self.model_name == "classifier.onnx":
return [np.array([[0.1, 0.9]], dtype=np.float32)]
if self.model_name == "normal.onnx":
logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
logits[0, 0, 2] = 10.0
logits[1, 0, 0] = 10.0
return [logits]
if self.model_name == "threed_rotate.onnx":
return [np.array([[0.25]], dtype=np.float32)]
raise AssertionError(f"unexpected fake session: {self.model_name}")
class _FakeSessionOptions:
def __init__(self):
self.inter_op_num_threads = 0
self.intra_op_num_threads = 0
class _FakeOrt:
SessionOptions = _FakeSessionOptions
InferenceSession = _FakeSession
class TestPipelineMetadata:
def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
(tmp_path / "classifier.onnx").touch()
write_model_metadata(
tmp_path / "classifier.onnx",
{
"model_name": "classifier",
"task": "classifier",
"class_names": ["math", "normal"],
"input_shape": [1, 64, 128],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
captcha_type = pipeline.classify(Image.new("RGB", (32, 32), color="white"))
assert captcha_type == "normal"
def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
(tmp_path / "normal.onnx").touch()
write_model_metadata(
tmp_path / "normal.onnx",
{
"model_name": "normal",
"task": "ctc",
"chars": "XYZ",
"input_shape": [1, 40, 120],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="normal")
assert result["raw"] == "Y"
assert result["result"] == "Y"
def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
(tmp_path / "threed_rotate.onnx").touch()
write_model_metadata(
tmp_path / "threed_rotate.onnx",
{
"model_name": "threed_rotate",
"task": "regression",
"label_range": [100, 200],
"input_shape": [1, 80, 80],
},
)
monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
pipeline = CaptchaPipeline(models_dir=tmp_path)
result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="3d_rotate")
assert result["raw"] == "125.0"
assert result["result"] == "125"
# ============================================================
# SlideSolver 测试
# ============================================================