"""
测试推理流水线组件。

- math_eval: 加减乘除正确性 + 异常输入
- CTC greedy decode (构造 logits)
- CaptchaPipeline 读取模型元数据 (class_names / chars / label_range, 伪造 ONNX 会话)
- SlideSolver (合成图 → OpenCV 检测)
- generate_slide_track 轨迹合理性
"""
|
||
|
||
import math
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import pytest
|
||
from PIL import Image
|
||
|
||
from inference.math_eval import eval_captcha_math
|
||
from inference.model_metadata import write_model_metadata
|
||
import inference.pipeline as pipeline_module
|
||
from inference.pipeline import CaptchaPipeline
|
||
|
||
|
||
# ============================================================
|
||
# math_eval 测试
|
||
# ============================================================
|
||
class TestMathEval:
    """Check ``eval_captcha_math`` on each operator and on malformed input."""

    def test_addition(self):
        cases = (("3+8=?", "11"), ("12+5", "17"), ("0+0=?", "0"))
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_subtraction(self):
        cases = (("15-7=?", "8"), ("20-20", "0"))
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_multiplication(self):
        # "×", "*", "x" and "X" are all expected to act as the multiplication sign.
        cases = (("12×3=?", "36"), ("5*4", "20"), ("6x7", "42"), ("6X7", "42"))
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_division(self):
        cases = (("20÷4=?", "5"), ("9÷3", "3"))
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_division_by_zero(self):
        with pytest.raises(ValueError, match="除数为零"):
            eval_captcha_math("5÷0=?")

    def test_invalid_expression(self):
        with pytest.raises(ValueError, match="无法解析"):
            eval_captcha_math("abc")

    def test_with_spaces(self):
        answer = eval_captcha_math("3 + 8 = ?")
        assert answer == "11"
|
||
|
||
|
||
# ============================================================
|
||
# CTC greedy decode 测试
|
||
# ============================================================
|
||
class TestCTCGreedyDecode:
    """Unit tests for the static ``CaptchaPipeline._ctc_greedy_decode`` method.

    Logits are built with shape ``(T, 1, C)``: ``T`` steps, a batch of one, and
    ``C`` classes, where class 0 is the blank and class ``k`` maps to
    ``chars[k - 1]``.
    """

    @staticmethod
    def _one_hot_logits(path, num_classes):
        """Return float32 logits of shape ``(len(path), 1, num_classes)``.

        Every entry is -10.0 except the class chosen by ``path`` at each step,
        which is set to 10.0, so the per-step argmax follows ``path``.
        """
        logits = np.full((len(path), 1, num_classes), -10.0, dtype=np.float32)
        for step, cls in enumerate(path):
            logits[step, 0, cls] = 10.0
        return logits

    def test_simple_decode(self):
        chars = "ABC"  # index 0=blank, 1=A, 2=B, 3=C
        num_classes = len(chars) + 1  # blank + 3 chars
        # A, A (collapsed), blank, B, B (dup), B (dup)  ->  "AB"
        logits = self._one_hot_logits([1, 1, 0, 2, 2, 2], num_classes)
        assert CaptchaPipeline._ctc_greedy_decode(logits, chars) == "AB"

    def test_all_blank(self):
        chars = "ABC"
        logits = self._one_hot_logits([0] * 5, 4)
        assert CaptchaPipeline._ctc_greedy_decode(logits, chars) == ""

    def test_repeated_chars_with_blank_separator(self):
        chars = "ABC"
        # A, blank, A, blank, blank  ->  "AA" (the blank keeps the two A's apart)
        logits = self._one_hot_logits([1, 0, 1, 0, 0], 4)
        assert CaptchaPipeline._ctc_greedy_decode(logits, chars) == "AA"
|
||
|
||
|
||
class _FakeInput:
    """Minimal stand-in for an entry of ``InferenceSession.get_inputs()``.

    Only the ``name`` attribute is provided.
    """

    name: str = "input"
|
||
|
||
|
||
class _FakeSession:
    """Fake inference session returning canned outputs chosen by model file name.

    The file name of ``path`` selects what ``run`` returns:

    * ``classifier.onnx``    -> ``[[0.1, 0.9]]`` (class index 1 scores highest)
    * ``normal.onnx``        -> logits of shape (2, 1, 4): class 2 at step 0,
      class 0 (blank) at step 1
    * ``threed_rotate.onnx`` -> ``[[0.25]]``

    Any other model name makes ``run`` raise ``AssertionError``.
    """

    def __init__(self, path, *args, **kwargs):
        # Whatever else the pipeline passes (presumably session options /
        # providers) is accepted and ignored.
        self.model_name = Path(path).name

    def get_inputs(self):
        return [_FakeInput()]

    def run(self, output_names, feed_dict):
        # ``output_names`` and ``feed_dict`` are ignored; the output depends only
        # on which model file this session was created for.
        responders = {
            "classifier.onnx": self._classifier_outputs,
            "normal.onnx": self._ctc_outputs,
            "threed_rotate.onnx": self._regression_outputs,
        }
        responder = responders.get(self.model_name)
        if responder is None:
            raise AssertionError(f"unexpected fake session: {self.model_name}")
        return responder()

    @staticmethod
    def _classifier_outputs():
        """Class scores for a two-class classifier."""
        return [np.array([[0.1, 0.9]], dtype=np.float32)]

    @staticmethod
    def _ctc_outputs():
        """A fresh (T=2, batch=1, C=4) logits array: class 2, then blank."""
        logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
        logits[0, 0, 2] = 10.0
        logits[1, 0, 0] = 10.0
        return [logits]

    @staticmethod
    def _regression_outputs():
        """A single regression value of 0.25."""
        return [np.array([[0.25]], dtype=np.float32)]
|
||
|
||
|
||
class _FakeSessionOptions:
    """Fake session-options object with both thread-count attributes set to 0.

    As a plain object it also accepts any other attribute the caller assigns.
    """

    def __init__(self):
        self.inter_op_num_threads = self.intra_op_num_threads = 0
|
||
|
||
|
||
class _FakeOrt:
    """Replacement for the runtime module returned by ``_try_import_ort``.

    The tests monkeypatch ``pipeline_module._try_import_ort`` to return this class
    (presumably standing in for ``onnxruntime``). It exposes only the two
    factories below.
    """

    # Session factory: ``InferenceSession(path, ...)`` yields a ``_FakeSession``.
    InferenceSession = _FakeSession
    # Options factory: ``SessionOptions()`` yields a ``_FakeSessionOptions``.
    SessionOptions = _FakeSessionOptions
|
||
|
||
|
||
class TestPipelineMetadata:
    """``CaptchaPipeline`` should honour the metadata recorded for each model file."""

    @staticmethod
    def _build_pipeline(models_dir, monkeypatch, model_file, metadata):
        """Prepare ``models_dir`` with one model file and return a fresh pipeline.

        Creates an empty ``model_file`` and records ``metadata`` for it via
        ``write_model_metadata``. Patches ``_try_import_ort`` so that
        ``_FakeOrt`` is used instead of a real runtime.
        """
        model_path = models_dir / model_file
        model_path.touch()
        write_model_metadata(model_path, metadata)
        monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
        return CaptchaPipeline(models_dir=models_dir)

    @staticmethod
    def _blank_image():
        """A small all-white RGB image; the fake sessions ignore the pixels."""
        return Image.new("RGB", (32, 32), color="white")

    def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "classifier.onnx",
            {
                "model_name": "classifier",
                "task": "classifier",
                "class_names": ["math", "normal"],
                "input_shape": [1, 64, 128],
            },
        )
        # The fake classifier scores [0.1, 0.9]; with this class order the
        # expected label is "normal".
        assert pipeline.classify(self._blank_image()) == "normal"

    def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "normal.onnx",
            {
                "model_name": "normal",
                "task": "ctc",
                "chars": "XYZ",
                "input_shape": [1, 40, 120],
            },
        )
        # The fake CTC logits pick class 2 then blank; with chars "XYZ"
        # (class k -> chars[k - 1]) that decodes to "Y".
        result = pipeline.solve(self._blank_image(), captcha_type="normal")
        assert result["raw"] == "Y"
        assert result["result"] == "Y"

    def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "threed_rotate.onnx",
            {
                "model_name": "threed_rotate",
                "task": "regression",
                "label_range": [100, 200],
                "input_shape": [1, 80, 80],
            },
        )
        # The fake regressor outputs 0.25; mapped into [100, 200] that is 125.
        result = pipeline.solve(self._blank_image(), captcha_type="3d_rotate")
        assert result["raw"] == "125.0"
        assert result["result"] == "125"
|
||
|
||
|
||
# ============================================================
|
||
# SlideSolver 测试
|
||
# ============================================================
|
||
class TestSlideSolver:
    """Smoke test for ``SlideSolver`` on an image from ``SlideDataGenerator``."""

    def test_solve_with_synthetic_image(self):
        """Generate a synthetic slide image and check the solver's result schema.

        The test is skipped when OpenCV (``cv2``) cannot be imported. Only the
        result keys, the type of ``gap_x`` and the range of ``gap_x_percent``
        are checked. The detected gap is not compared with the generator's label.
        """
        # importorskip replaces the manual try/except, whose ``cv2`` binding was
        # otherwise unused. It skips (rather than errors) when OpenCV is missing.
        pytest.importorskip("cv2", reason="OpenCV not installed")

        from generators.slide_gen import SlideDataGenerator
        from solvers.slide_solver import SlideSolver

        gen = SlideDataGenerator(seed=42)
        img, _label = gen.generate()
        # NOTE(review): the generator's label (presumably the true gap x) is not
        # compared with result["gap_x"]. Add a tolerance check once an acceptable
        # detection error on synthetic images has been established.

        result = SlideSolver().solve(img)

        for key in ("gap_x", "gap_x_percent", "confidence", "method"):
            assert key in result, f"missing key {key!r} in solver result"
        assert isinstance(result["gap_x"], int)
        assert 0.0 <= result["gap_x_percent"] <= 1.0
|
||
|
||
|
||
# ============================================================
|
||
# generate_slide_track 测试
|
||
# ============================================================
|
||
class TestSlideTrack:
    """Sanity checks on the drag track produced by ``generate_slide_track``.

    Each point is expected to expose ``x``, ``y`` and ``t`` entries.
    """

    def test_track_basic(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=42)
        assert isinstance(track, list)
        assert len(track) >= 10

    def test_track_point_structure(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(150, seed=0)
        # Without this guard an empty track would make the loop pass vacuously.
        assert track, "generate_slide_track returned an empty track"
        for pt in track:
            for key in ("x", "y", "t"):
                assert key in pt, f"point {pt!r} is missing {key!r}"

    def test_track_starts_at_origin(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=1)
        # A single absolute-tolerance check; it already covers x == 0.0 exactly.
        assert math.isclose(track[0]["x"], 0.0, abs_tol=1e-6), track[0]

    def test_track_ends_near_distance(self):
        from utils.slide_utils import generate_slide_track

        distance = 120
        track = generate_slide_track(distance, seed=2)
        final_x = track[-1]["x"]
        assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}"

    def test_track_time_increases(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=3)
        # Timestamps must be non-decreasing between consecutive points.
        for prev, curr in zip(track, track[1:]):
            assert curr["t"] >= prev["t"], f"time went backwards: {prev} -> {curr}"

    def test_track_y_has_jitter(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(200, seed=4)
        # At least some y values should be non-zero (jitter).
        assert any(abs(pt["y"]) > 0 for pt in track)
|