Files
CaptchBreaker/tests/test_pipeline.py

271 lines
8.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
测试推理流水线组件。
- math_eval: 加减乘除正确性 + 异常输入
- CTC greedy decode (构造 logits)
- SlideSolver (合成图 → OpenCV 检测)
- generate_slide_track 轨迹合理性
"""
import math
from pathlib import Path
import numpy as np
import pytest
from PIL import Image
from inference.math_eval import eval_captcha_math
from inference.model_metadata import write_model_metadata
import inference.pipeline as pipeline_module
from inference.pipeline import CaptchaPipeline
# ============================================================
# math_eval 测试
# ============================================================
class TestMathEval:
    """Arithmetic results and error reporting of ``eval_captcha_math``."""

    def test_addition(self):
        for expr, want in (("3+8=?", "11"), ("12+5", "17"), ("0+0=?", "0")):
            assert eval_captcha_math(expr) == want

    def test_subtraction(self):
        for expr, want in (("15-7=?", "8"), ("20-20", "0")):
            assert eval_captcha_math(expr) == want

    def test_multiplication(self):
        # "×", "*", "x" and "X" are all expected to act as the multiplication sign.
        cases = (("12×3=?", "36"), ("5*4", "20"), ("6x7", "42"), ("6X7", "42"))
        for expr, want in cases:
            assert eval_captcha_math(expr) == want

    def test_division(self):
        for expr, want in (("20÷4=?", "5"), ("9÷3", "3")):
            assert eval_captcha_math(expr) == want

    def test_division_by_zero(self):
        # The ValueError message must contain "除数为零" ("divisor is zero").
        with pytest.raises(ValueError, match="除数为零"):
            eval_captcha_math("5÷0=?")

    def test_invalid_expression(self):
        # The ValueError message must contain "无法解析" ("cannot parse").
        with pytest.raises(ValueError, match="无法解析"):
            eval_captcha_math("abc")

    def test_with_spaces(self):
        # Whitespace around operands, operator and "=" must be tolerated.
        assert eval_captcha_math("3 + 8 = ?") == "11"
# ============================================================
# CTC greedy decode 测试
# ============================================================
class TestCTCGreedyDecode:
    """Exercise the static ``CaptchaPipeline._ctc_greedy_decode`` on hand-built logits."""

    # Alphabet shared by every case. Class index 0 is the CTC blank,
    # so 1 -> "A", 2 -> "B", 3 -> "C".
    CHARS = "ABC"
    NUM_CLASSES = len(CHARS) + 1

    @staticmethod
    def _peaked_logits(path, num_classes):
        """Return float32 logits of shape (len(path), 1, num_classes).

        Every entry is -10.0 except ``logits[t, 0, path[t]]``, which is 10.0,
        so the greedy argmax at step ``t`` is ``path[t]``.
        """
        logits = np.full((len(path), 1, num_classes), -10.0, dtype=np.float32)
        for step, cls in enumerate(path):
            logits[step, 0, cls] = 10.0
        return logits

    def test_simple_decode(self):
        # A, A, blank, B, B, B: consecutive repeats collapse -> "AB".
        logits = self._peaked_logits([1, 1, 0, 2, 2, 2], self.NUM_CLASSES)
        assert CaptchaPipeline._ctc_greedy_decode(logits, self.CHARS) == "AB"

    def test_all_blank(self):
        # Five blank steps decode to the empty string.
        logits = self._peaked_logits([0] * 5, self.NUM_CLASSES)
        assert CaptchaPipeline._ctc_greedy_decode(logits, self.CHARS) == ""

    def test_repeated_chars_with_blank_separator(self):
        # A, blank, A, blank, blank: the blank keeps the two A's distinct -> "AA".
        logits = self._peaked_logits([1, 0, 1, 0, 0], self.NUM_CLASSES)
        assert CaptchaPipeline._ctc_greedy_decode(logits, self.CHARS) == "AA"
class _FakeInput:
name = "input"
class _FakeSession:
def __init__(self, path, *args, **kwargs):
self.model_name = Path(path).name
def get_inputs(self):
return [_FakeInput()]
def run(self, output_names, feed_dict):
if self.model_name == "classifier.onnx":
return [np.array([[0.1, 0.9]], dtype=np.float32)]
if self.model_name == "normal.onnx":
logits = np.full((2, 1, 4), -10.0, dtype=np.float32)
logits[0, 0, 2] = 10.0
logits[1, 0, 0] = 10.0
return [logits]
if self.model_name == "threed_rotate.onnx":
return [np.array([[0.25]], dtype=np.float32)]
raise AssertionError(f"unexpected fake session: {self.model_name}")
class _FakeSessionOptions:
def __init__(self):
self.inter_op_num_threads = 0
self.intra_op_num_threads = 0
class _FakeOrt:
    """Returned in place of the real runtime by the patched ``_try_import_ort``.

    Exposes only the two attributes the pipeline tests rely on.
    """

    InferenceSession = _FakeSession
    SessionOptions = _FakeSessionOptions
class TestPipelineMetadata:
    """CaptchaPipeline must take per-model settings from the model metadata."""

    @staticmethod
    def _build_pipeline(tmp_path, monkeypatch, model_file, metadata):
        """Create ``model_file`` plus its metadata in ``tmp_path`` and build a pipeline.

        The runtime import is patched to return ``_FakeOrt``, so no real model
        is loaded.
        """
        model_path = tmp_path / model_file
        # An empty file suffices: the fake session only looks at the file name.
        model_path.touch()
        write_model_metadata(model_path, metadata)
        monkeypatch.setattr(pipeline_module, "_try_import_ort", lambda: _FakeOrt)
        return CaptchaPipeline(models_dir=tmp_path)

    @staticmethod
    def _blank_image():
        """Return a small all-white RGB image (the fake sessions ignore the input)."""
        return Image.new("RGB", (32, 32), color="white")

    def test_classifier_uses_metadata_class_order(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "classifier.onnx",
            {
                "model_name": "classifier",
                "task": "classifier",
                "class_names": ["math", "normal"],
                "input_shape": [1, 64, 128],
            },
        )
        # The fake classifier scores index 1 highest -> class_names[1].
        assert pipeline.classify(self._blank_image()) == "normal"

    def test_solve_uses_ctc_chars_metadata(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "normal.onnx",
            {
                "model_name": "normal",
                "task": "ctc",
                "chars": "XYZ",
                "input_shape": [1, 40, 120],
            },
        )
        result = pipeline.solve(self._blank_image(), captcha_type="normal")
        # The fake CTC output peaks at class 2 -> second metadata char "Y".
        assert result["raw"] == "Y"
        assert result["result"] == "Y"

    def test_solve_uses_regression_label_range_metadata(self, tmp_path, monkeypatch):
        pipeline = self._build_pipeline(
            tmp_path,
            monkeypatch,
            "threed_rotate.onnx",
            {
                "model_name": "threed_rotate",
                "task": "regression",
                "label_range": [100, 200],
                "input_shape": [1, 80, 80],
            },
        )
        result = pipeline.solve(self._blank_image(), captcha_type="3d_rotate")
        # Fake output 0.25 within label_range [100, 200] is expected to give 125.
        assert result["raw"] == "125.0"
        assert result["result"] == "125"
# ============================================================
# SlideSolver 测试
# ============================================================
class TestSlideSolver:
    """SlideSolver on a synthetic slide image (skipped when OpenCV is unavailable)."""

    def test_solve_with_synthetic_image(self):
        """Generate a synthetic slide image and check the solver's result schema.

        NOTE(review): the generator's label (presumably the ground-truth gap x)
        is not compared against ``result["gap_x"]``. Only the result keys, the
        type of ``gap_x`` and the range of ``gap_x_percent`` are asserted.
        """
        # Skip cleanly when cv2 cannot be imported, without binding an unused name.
        pytest.importorskip("cv2")
        from generators.slide_gen import SlideDataGenerator
        from solvers.slide_solver import SlideSolver

        gen = SlideDataGenerator(seed=42)
        img, _label = gen.generate()
        result = SlideSolver().solve(img)

        for key in ("gap_x", "gap_x_percent", "confidence", "method"):
            assert key in result
        assert isinstance(result["gap_x"], int)
        assert 0.0 <= result["gap_x_percent"] <= 1.0
# ============================================================
# generate_slide_track 测试
# ============================================================
class TestSlideTrack:
    """Sanity checks on trajectories produced by ``generate_slide_track``."""

    @staticmethod
    def _track(distance, seed):
        """Generate a slide track for ``distance`` with a fixed ``seed``."""
        # Imported here rather than at module level, so an import failure fails
        # only these tests instead of breaking collection of the whole module.
        from utils.slide_utils import generate_slide_track

        return generate_slide_track(distance, seed=seed)

    def test_track_basic(self):
        track = self._track(100, seed=42)
        assert isinstance(track, list)
        assert len(track) >= 10

    def test_track_point_structure(self):
        track = self._track(150, seed=0)
        # Guard against a vacuous pass: the loop checks nothing on an empty track.
        assert track
        for pt in track:
            for key in ("x", "y", "t"):
                assert key in pt, f"missing {key!r} in {pt!r}"

    def test_track_starts_at_origin(self):
        track = self._track(100, seed=1)
        # The tolerance check already covers an exact 0.0.
        assert abs(track[0]["x"]) < 1e-6

    def test_track_ends_near_distance(self):
        distance = 120
        track = self._track(distance, seed=2)
        final_x = track[-1]["x"]
        assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}"

    def test_track_time_increases(self):
        track = self._track(100, seed=3)
        assert track
        # Timestamps must be non-decreasing between consecutive points.
        for prev, cur in zip(track, track[1:]):
            assert cur["t"] >= prev["t"], f"time went backwards: {prev} -> {cur}"

    def test_track_y_has_jitter(self):
        track = self._track(200, seed=4)
        # At least some y values should be non-zero (jitter).
        assert any(abs(pt["y"]) > 0 for pt in track)