Files
CaptchBreaker/tests/test_pipeline.py
Hua 788ddcae1a Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 19:05:47 +08:00

177 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tests for the inference pipeline components.

- math_eval: correctness of addition/subtraction/multiplication/division,
  plus rejection of invalid input
- CTC greedy decode (using hand-constructed logits)
- SlideSolver (synthetic image -> OpenCV detection)
- generate_slide_track: plausibility of the generated trajectory
"""
import math
import numpy as np
import pytest
from inference.math_eval import eval_captcha_math
from inference.pipeline import CaptchaPipeline
# ============================================================
# math_eval 测试
# ============================================================
class TestMathEval:
    """Tests for ``eval_captcha_math``.

    Covers the four arithmetic operators, the accepted spellings of the
    multiplication sign, whitespace tolerance, and the ValueError raised for
    division by zero and unparseable input.
    """

    @staticmethod
    def _assert_all(cases):
        """Assert ``eval_captcha_math(expr) == expected`` for every (expr, expected) pair."""
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_addition(self):
        self._assert_all([("3+8=?", "11"), ("12+5", "17"), ("0+0=?", "0")])

    def test_subtraction(self):
        self._assert_all([("15-7=?", "8"), ("20-20", "0")])

    def test_multiplication(self):
        # "×", "*", "x" and "X" are all expected to mean multiplication.
        self._assert_all(
            [("12×3=?", "36"), ("5*4", "20"), ("6x7", "42"), ("6X7", "42")]
        )

    def test_division(self):
        self._assert_all([("20÷4=?", "5"), ("9÷3", "3")])

    def test_division_by_zero(self):
        # The error message is expected to contain "除数为零" ("divisor is zero").
        with pytest.raises(ValueError, match="除数为零"):
            eval_captcha_math("5÷0=?")

    def test_invalid_expression(self):
        # The error message is expected to contain "无法解析" ("cannot parse").
        with pytest.raises(ValueError, match="无法解析"):
            eval_captcha_math("abc")

    def test_with_spaces(self):
        # Spaces around operands, the operator and "=?" should not matter.
        spaced = "3 + 8 = ?"
        assert eval_captcha_math(spaced) == "11"
# ============================================================
# CTC greedy decode 测试
# ============================================================
class TestCTCGreedyDecode:
    """Tests for the static ``CaptchaPipeline._ctc_greedy_decode`` method.

    Each case builds logits of shape (T, 1, C) in which exactly one class per
    time step scores 10.0 and every other class scores -10.0, so the greedy
    (argmax) path is fully determined by the chosen class sequence.
    Class 0 is the CTC blank; class i (i >= 1) stands for ``chars[i - 1]``.
    """

    _CHARS = "ABC"
    _NUM_CLASSES = len(_CHARS) + 1  # blank + one class per character
    _BLANK, _A, _B = 0, 1, 2

    @staticmethod
    def _logits_for_path(path, num_classes):
        """Build float32 logits of shape (len(path), 1, num_classes) peaking at path[t]."""
        logits = np.full((len(path), 1, num_classes), -10.0, dtype=np.float32)
        for step, cls in enumerate(path):
            logits[step, 0, cls] = 10.0
        return logits

    def test_simple_decode(self):
        # A, A, blank, B, B, B -> adjacent duplicates collapse, blank dropped: "AB".
        path = [self._A, self._A, self._BLANK, self._B, self._B, self._B]
        logits = self._logits_for_path(path, self._NUM_CLASSES)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, self._CHARS)
        assert decoded == "AB"

    def test_all_blank(self):
        # Five blank steps decode to the empty string.
        logits = self._logits_for_path([self._BLANK] * 5, self._NUM_CLASSES)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, self._CHARS)
        assert decoded == ""

    def test_repeated_chars_with_blank_separator(self):
        # A, blank, A, blank, blank -> a blank between repeats keeps both: "AA".
        path = [self._A, self._BLANK, self._A, self._BLANK, self._BLANK]
        logits = self._logits_for_path(path, self._NUM_CLASSES)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, self._CHARS)
        assert decoded == "AA"
# ============================================================
# SlideSolver 测试
# ============================================================
class TestSlideSolver:
    """Smoke tests for the OpenCV-based ``SlideSolver``."""

    def test_solve_with_synthetic_image(self):
        """Run SlideSolver on a synthetic slide image and check the result schema.

        Verifies that the result exposes ``gap_x`` (an int), ``gap_x_percent``
        (within [0, 1]), ``confidence`` and ``method``. Detection accuracy
        against the generator's label is NOT asserted here.
        """
        # A plain try/except catches every ImportError, including an installed
        # cv2 that fails to load its native libraries, without the deprecation
        # warning pytest >= 8.2 emits from importorskip in that situation.
        try:
            import cv2  # noqa: F401  (imported only to probe availability)
        except ImportError:
            pytest.skip("OpenCV not installed")

        from generators.slide_gen import SlideDataGenerator
        from solvers.slide_solver import SlideSolver

        gen = SlideDataGenerator(seed=42)
        # NOTE(review): the label (presumably the ground-truth gap x) is unused;
        # asserting accuracy would need an agreed pixel tolerance — TODO decide.
        img, _label = gen.generate()

        result = SlideSolver().solve(img)

        for key in ("gap_x", "gap_x_percent", "confidence", "method"):
            assert key in result, f"solver result is missing key {key!r}"
        assert isinstance(result["gap_x"], int)
        assert 0.0 <= result["gap_x_percent"] <= 1.0
# ============================================================
# generate_slide_track 测试
# ============================================================
class TestSlideTrack:
    """Plausibility checks for the trajectory returned by ``generate_slide_track``.

    Each point is expected to be a mapping with "x", "y" and "t" entries.
    """

    def test_track_basic(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=42)
        assert isinstance(track, list)
        assert len(track) >= 10

    def test_track_point_structure(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(150, seed=0)
        # Without this guard an empty track would pass the loop vacuously.
        assert track, "expected a non-empty track"
        for pt in track:
            assert "x" in pt
            assert "y" in pt
            assert "t" in pt

    def test_track_starts_at_origin(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=1)
        # A tolerance check already accepts an exact 0.0, so one check suffices.
        assert track[0]["x"] == pytest.approx(0.0, abs=1e-6)

    def test_track_ends_near_distance(self):
        from utils.slide_utils import generate_slide_track

        distance = 120
        track = generate_slide_track(distance, seed=2)
        final_x = track[-1]["x"]
        assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}"

    def test_track_time_increases(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(100, seed=3)
        # Fewer than two points would make the ordering check vacuous.
        assert len(track) >= 2, "need at least two points to check time ordering"
        # Timestamps must be non-decreasing; equal consecutive values are allowed.
        for index, (prev, cur) in enumerate(zip(track, track[1:]), start=1):
            assert cur["t"] >= prev["t"], f"t decreased at point {index}"

    def test_track_y_has_jitter(self):
        from utils.slide_utils import generate_slide_track

        track = generate_slide_track(200, seed=4)
        y_vals = [pt["y"] for pt in track]
        # At least some y values should be non-zero (vertical jitter).
        assert any(abs(y) > 0 for y in y_vals)