""" 测试推理流水线组件。 - math_eval: 加减乘除正确性 + 异常输入 - CTC greedy decode (构造 logits) - SlideSolver (合成图 → OpenCV 检测) - generate_slide_track 轨迹合理性 """ import math from pathlib import Path import numpy as np import pytest from PIL import Image from inference.math_eval import eval_captcha_math from inference.model_metadata import write_model_metadata import inference.pipeline as pipeline_module from inference.pipeline import CaptchaPipeline # ============================================================ # math_eval 测试 # ============================================================ class TestMathEval: def test_addition(self): assert eval_captcha_math("3+8=?") == "11" assert eval_captcha_math("12+5") == "17" assert eval_captcha_math("0+0=?") == "0" def test_subtraction(self): assert eval_captcha_math("15-7=?") == "8" assert eval_captcha_math("20-20") == "0" def test_multiplication(self): assert eval_captcha_math("12×3=?") == "36" assert eval_captcha_math("5*4") == "20" assert eval_captcha_math("6x7") == "42" assert eval_captcha_math("6X7") == "42" def test_division(self): assert eval_captcha_math("20÷4=?") == "5" assert eval_captcha_math("9÷3") == "3" def test_division_by_zero(self): with pytest.raises(ValueError, match="除数为零"): eval_captcha_math("5÷0=?") def test_invalid_expression(self): with pytest.raises(ValueError, match="无法解析"): eval_captcha_math("abc") def test_with_spaces(self): assert eval_captcha_math("3 + 8 = ?") == "11" # ============================================================ # CTC greedy decode 测试 # ============================================================ class TestCTCGreedyDecode: """Test the static _ctc_greedy_decode method from CaptchaPipeline.""" def test_simple_decode(self): chars = "ABC" # index 0=blank, 1=A, 2=B, 3=C T = 6 C = 4 # blank + 3 chars logits = np.full((T, 1, C), -10.0, dtype=np.float32) # Spell out "AB": A, A, blank, B, B, B logits[0, 0, 1] = 10.0 # A logits[1, 0, 1] = 10.0 # A (dup, collapsed) logits[2, 0, 0] = 10.0 # blank logits[3, 0, 2] = 10.0 # B logits[4, 0, 2] = 10.0 # B (dup) logits[5, 0, 2] = 10.0 # B (dup) result = CaptchaPipeline._ctc_greedy_decode(logits, chars) assert result == "AB" def test_all_blank(self): chars = "ABC" T = 5 C = 4 logits = np.full((T, 1, C), -10.0, dtype=np.float32) for t in range(T): logits[t, 0, 0] = 10.0 result = CaptchaPipeline._ctc_greedy_decode(logits, chars) assert result == "" def test_repeated_chars_with_blank_separator(self): chars = "ABC" T = 5 C = 4 logits = np.full((T, 1, C), -10.0, dtype=np.float32) # Spell "AA": A, blank, A, blank, blank logits[0, 0, 1] = 10.0 # A logits[1, 0, 0] = 10.0 # blank logits[2, 0, 1] = 10.0 # A logits[3, 0, 0] = 10.0 # blank logits[4, 0, 0] = 10.0 # blank result = CaptchaPipeline._ctc_greedy_decode(logits, chars) assert result == "AA" class _FakeInput: name = "input" class _FakeSession: def __init__(self, path, *args, **kwargs): self.model_name = Path(path).name def get_inputs(self): return [_FakeInput()] def run(self, output_names, feed_dict): if self.model_name == "classifier.onnx": return [np.array([[0.1, 0.9]], dtype=np.float32)] if self.model_name == "normal.onnx": logits = np.full((2, 1, 4), -10.0, dtype=np.float32) logits[0, 0, 2] = 10.0 logits[1, 0, 0] = 10.0 return [logits] if self.model_name == "threed_rotate.onnx": return [np.array([[0.25]], dtype=np.float32)] raise AssertionError(f"unexpected fake session: {self.model_name}") class _FakeSessionOptions: def __init__(self): self.inter_op_num_threads = 0 self.intra_op_num_threads = 0 class _FakeOrt: SessionOptions = _FakeSessionOptions InferenceSession = _FakeSession class TestPipelineMetadata: def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch): (tmp_path / "classifier.onnx").touch() write_model_metadata( tmp_path / "classifier.onnx", { "model_name": "classifier", "task": "classifier", "class_names": ["math", "normal"], "input_shape": [1, 64, 128], }, ) monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt) pipeline = CaptchaPipeline(models_dir=tmp_path) captcha_type = pipeline.classify(Image.new("RGB", (32, 32), color="white")) assert captcha_type == "normal" def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch): (tmp_path / "normal.onnx").touch() write_model_metadata( tmp_path / "normal.onnx", { "model_name": "normal", "task": "ctc", "chars": "XYZ", "input_shape": [1, 40, 120], }, ) monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt) pipeline = CaptchaPipeline(models_dir=tmp_path) result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="normal") assert result["raw"] == "Y" assert result["result"] == "Y" def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch): (tmp_path / "threed_rotate.onnx").touch() write_model_metadata( tmp_path / "threed_rotate.onnx", { "model_name": "threed_rotate", "task": "regression", "label_range": [100, 200], "input_shape": [1, 80, 80], }, ) monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt) pipeline = CaptchaPipeline(models_dir=tmp_path) result = pipeline.solve(Image.new("RGB", (32, 32), color="white"), captcha_type="3d_rotate") assert result["raw"] == "125.0" assert result["result"] == "125" # ============================================================ # SlideSolver 测试 # ============================================================ class TestSlideSolver: def test_solve_with_synthetic_image(self): """Generate a synthetic slide image and verify the solver detects a gap.""" try: import cv2 except ImportError: pytest.skip("OpenCV not installed") from generators.slide_gen import SlideDataGenerator from solvers.slide_solver import SlideSolver gen = SlideDataGenerator(seed=42) img, label = gen.generate() expected_gap_x = int(label) solver = SlideSolver() result = solver.solve(img) assert "gap_x" in result assert "gap_x_percent" in result assert "confidence" in result assert "method" in result assert isinstance(result["gap_x"], int) assert 0.0 <= result["gap_x_percent"] <= 1.0 # ============================================================ # generate_slide_track 测试 # ============================================================ class TestSlideTrack: def test_track_basic(self): from utils.slide_utils import generate_slide_track track = generate_slide_track(100, seed=42) assert isinstance(track, list) assert len(track) >= 10 def test_track_point_structure(self): from utils.slide_utils import generate_slide_track track = generate_slide_track(150, seed=0) for pt in track: assert "x" in pt assert "y" in pt assert "t" in pt def test_track_starts_at_origin(self): from utils.slide_utils import generate_slide_track track = generate_slide_track(100, seed=1) assert track[0]["x"] == 0.0 or abs(track[0]["x"]) < 1e-6 def test_track_ends_near_distance(self): from utils.slide_utils import generate_slide_track distance = 120 track = generate_slide_track(distance, seed=2) final_x = track[-1]["x"] assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}" def test_track_time_increases(self): from utils.slide_utils import generate_slide_track track = generate_slide_track(100, seed=3) for i in range(1, len(track)): assert track[i]["t"] >= track[i - 1]["t"] def test_track_y_has_jitter(self): from utils.slide_utils import generate_slide_track track = generate_slide_track(200, seed=4) y_vals = [pt["y"] for pt in track] # At least some y values should be non-zero (jitter) assert any(abs(y) > 0 for y in y_vals)