Add tests, server, resume training, and project cleanup

- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 19:05:47 +08:00
parent 9b5f29083e
commit 788ddcae1a
11 changed files with 786 additions and 21 deletions

176
tests/test_pipeline.py Normal file
View File

@@ -0,0 +1,176 @@
"""
测试推理流水线组件。
- math_eval: 加减乘除正确性 + 异常输入
- CTC greedy decode (构造 logits)
- SlideSolver (合成图 → OpenCV 检测)
- generate_slide_track 轨迹合理性
"""
import math
import numpy as np
import pytest
from inference.math_eval import eval_captcha_math
from inference.pipeline import CaptchaPipeline
# ============================================================
# math_eval 测试
# ============================================================
class TestMathEval:
def test_addition(self):
assert eval_captcha_math("3+8=?") == "11"
assert eval_captcha_math("12+5") == "17"
assert eval_captcha_math("0+0=?") == "0"
def test_subtraction(self):
assert eval_captcha_math("15-7=?") == "8"
assert eval_captcha_math("20-20") == "0"
def test_multiplication(self):
assert eval_captcha_math("12×3=?") == "36"
assert eval_captcha_math("5*4") == "20"
assert eval_captcha_math("6x7") == "42"
assert eval_captcha_math("6X7") == "42"
def test_division(self):
assert eval_captcha_math("20÷4=?") == "5"
assert eval_captcha_math("9÷3") == "3"
def test_division_by_zero(self):
with pytest.raises(ValueError, match="除数为零"):
eval_captcha_math("5÷0=?")
def test_invalid_expression(self):
with pytest.raises(ValueError, match="无法解析"):
eval_captcha_math("abc")
def test_with_spaces(self):
assert eval_captcha_math("3 + 8 = ?") == "11"
# ============================================================
# CTC greedy decode 测试
# ============================================================
class TestCTCGreedyDecode:
"""Test the static _ctc_greedy_decode method from CaptchaPipeline."""
def test_simple_decode(self):
chars = "ABC" # index 0=blank, 1=A, 2=B, 3=C
T = 6
C = 4 # blank + 3 chars
logits = np.full((T, 1, C), -10.0, dtype=np.float32)
# Spell out "AB": A, A, blank, B, B, B
logits[0, 0, 1] = 10.0 # A
logits[1, 0, 1] = 10.0 # A (dup, collapsed)
logits[2, 0, 0] = 10.0 # blank
logits[3, 0, 2] = 10.0 # B
logits[4, 0, 2] = 10.0 # B (dup)
logits[5, 0, 2] = 10.0 # B (dup)
result = CaptchaPipeline._ctc_greedy_decode(logits, chars)
assert result == "AB"
def test_all_blank(self):
chars = "ABC"
T = 5
C = 4
logits = np.full((T, 1, C), -10.0, dtype=np.float32)
for t in range(T):
logits[t, 0, 0] = 10.0
result = CaptchaPipeline._ctc_greedy_decode(logits, chars)
assert result == ""
def test_repeated_chars_with_blank_separator(self):
chars = "ABC"
T = 5
C = 4
logits = np.full((T, 1, C), -10.0, dtype=np.float32)
# Spell "AA": A, blank, A, blank, blank
logits[0, 0, 1] = 10.0 # A
logits[1, 0, 0] = 10.0 # blank
logits[2, 0, 1] = 10.0 # A
logits[3, 0, 0] = 10.0 # blank
logits[4, 0, 0] = 10.0 # blank
result = CaptchaPipeline._ctc_greedy_decode(logits, chars)
assert result == "AA"
# ============================================================
# SlideSolver 测试
# ============================================================
class TestSlideSolver:
def test_solve_with_synthetic_image(self):
"""Generate a synthetic slide image and verify the solver detects a gap."""
try:
import cv2
except ImportError:
pytest.skip("OpenCV not installed")
from generators.slide_gen import SlideDataGenerator
from solvers.slide_solver import SlideSolver
gen = SlideDataGenerator(seed=42)
img, label = gen.generate()
expected_gap_x = int(label)
solver = SlideSolver()
result = solver.solve(img)
assert "gap_x" in result
assert "gap_x_percent" in result
assert "confidence" in result
assert "method" in result
assert isinstance(result["gap_x"], int)
assert 0.0 <= result["gap_x_percent"] <= 1.0
# ============================================================
# generate_slide_track 测试
# ============================================================
class TestSlideTrack:
def test_track_basic(self):
from utils.slide_utils import generate_slide_track
track = generate_slide_track(100, seed=42)
assert isinstance(track, list)
assert len(track) >= 10
def test_track_point_structure(self):
from utils.slide_utils import generate_slide_track
track = generate_slide_track(150, seed=0)
for pt in track:
assert "x" in pt
assert "y" in pt
assert "t" in pt
def test_track_starts_at_origin(self):
from utils.slide_utils import generate_slide_track
track = generate_slide_track(100, seed=1)
assert track[0]["x"] == 0.0 or abs(track[0]["x"]) < 1e-6
def test_track_ends_near_distance(self):
from utils.slide_utils import generate_slide_track
distance = 120
track = generate_slide_track(distance, seed=2)
final_x = track[-1]["x"]
assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}"
def test_track_time_increases(self):
from utils.slide_utils import generate_slide_track
track = generate_slide_track(100, seed=3)
for i in range(1, len(track)):
assert track[i]["t"] >= track[i - 1]["t"]
def test_track_y_has_jitter(self):
from utils.slide_utils import generate_slide_track
track = generate_slide_track(200, seed=4)
y_vals = [pt["y"] for pt in track]
# At least some y values should be non-zero (jitter)
assert any(abs(y) > 0 for y in y_vals)