Add tests, server, resume training, and project cleanup
- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
173
tests/test_generators.py
Normal file
173
tests/test_generators.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
测试所有验证码生成器。
|
||||
|
||||
每种生成器 generate() 1 张 → 验证返回类型、图片尺寸、标签格式。
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from config import GENERATE_CONFIG, NORMAL_CHARS, MATH_CHARS, THREED_CHARS, SOLVER_CONFIG
|
||||
from generators import (
|
||||
NormalCaptchaGenerator,
|
||||
MathCaptchaGenerator,
|
||||
ThreeDCaptchaGenerator,
|
||||
ThreeDRotateGenerator,
|
||||
ThreeDSliderGenerator,
|
||||
SlideDataGenerator,
|
||||
RotateSolverDataGenerator,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalCaptchaGenerator:
    """Smoke tests for ``NormalCaptchaGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.cfg = GENERATE_CONFIG["normal"]
        self.gen = NormalCaptchaGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size equals the configured ``image_size`` pair."""
        image, _ = self.gen.generate()
        width, height = self.cfg["image_size"]
        assert image.size == (width, height)

    def test_label_chars_in_charset(self):
        """The label has at least 4 characters, all from NORMAL_CHARS."""
        _, label = self.gen.generate()
        assert len(label) >= 4
        for ch in label:
            assert ch in NORMAL_CHARS, f"char {ch!r} not in NORMAL_CHARS"

    def test_generate_with_text(self):
        """An explicit ``text`` argument comes back unchanged as the label."""
        _, label = self.gen.generate(text="AB12")
        assert label == "AB12"
|
||||
|
||||
|
||||
class TestMathCaptchaGenerator:
    """Smoke tests for ``MathCaptchaGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.gen = MathCaptchaGenerator(seed=0)
        self.cfg = GENERATE_CONFIG["math"]

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        img, label = self.gen.generate()
        assert isinstance(img, Image.Image)
        assert isinstance(label, str)

    def test_image_size(self):
        """The image size equals the configured ``image_size`` pair."""
        img, _ = self.gen.generate()
        w, h = self.cfg["image_size"]
        assert img.size == (w, h)

    def test_label_is_expression(self):
        """Label should be like '3+8' (expression without =? and without result)."""
        _, label = self.gen.generate()
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, so a label such as "3+8\n" would wrongly pass.
        assert re.fullmatch(r"\d+[+\-×÷]\d+", label), f"unexpected label format: {label!r}"
|
||||
|
||||
|
||||
class TestThreeDCaptchaGenerator:
    """Smoke tests for ``ThreeDCaptchaGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.cfg = GENERATE_CONFIG["3d_text"]
        self.gen = ThreeDCaptchaGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size equals the configured ``image_size`` pair."""
        image, _ = self.gen.generate()
        width, height = self.cfg["image_size"]
        assert image.size == (width, height)

    def test_label_chars_in_charset(self):
        """The label has at least 4 characters, all from THREED_CHARS."""
        _, label = self.gen.generate()
        assert len(label) >= 4
        for ch in label:
            assert ch in THREED_CHARS, f"char {ch!r} not in THREED_CHARS"
|
||||
|
||||
|
||||
class TestThreeDRotateGenerator:
    """Smoke tests for ``ThreeDRotateGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.cfg = GENERATE_CONFIG["3d_rotate"]
        self.gen = ThreeDRotateGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size equals the configured ``image_size`` pair."""
        image, _ = self.gen.generate()
        width, height = self.cfg["image_size"]
        assert image.size == (width, height)

    def test_label_is_angle(self):
        """The label parses as an integer angle within 0..359 inclusive."""
        _, label = self.gen.generate()
        angle = int(label)
        assert angle in range(360)
|
||||
|
||||
|
||||
class TestThreeDSliderGenerator:
    """Smoke tests for ``ThreeDSliderGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.cfg = GENERATE_CONFIG["3d_slider"]
        self.gen = ThreeDSliderGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size equals the configured ``image_size`` pair."""
        image, _ = self.gen.generate()
        width, height = self.cfg["image_size"]
        assert image.size == (width, height)

    def test_label_is_offset(self):
        """The label parses as an integer inside the configured ``gap_x_range``."""
        _, label = self.gen.generate()
        offset = int(label)
        lower, upper = self.cfg["gap_x_range"]
        # Both bounds are inclusive.
        assert lower <= offset and offset <= upper
|
||||
|
||||
|
||||
class TestSlideDataGenerator:
    """Smoke tests for ``SlideDataGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.gen = SlideDataGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size matches the slide solver's ``cnn_input_size``."""
        image, _ = self.gen.generate()
        # The config entry is unpacked as (h, w); the comparison below swaps
        # it to (w, h) to match ``image.size``.
        height, width = SOLVER_CONFIG["slide"]["cnn_input_size"]
        assert image.size == (width, height)

    def test_label_is_numeric(self):
        """The label parses as a non-negative integer."""
        _, label = self.gen.generate()
        assert int(label) >= 0
|
||||
|
||||
|
||||
class TestRotateSolverDataGenerator:
    """Smoke tests for ``RotateSolverDataGenerator`` built with a fixed seed."""

    def setup_method(self):
        # pytest calls this before every test method in the class.
        self.gen = RotateSolverDataGenerator(seed=0)

    def test_generate_returns_image_and_label(self):
        """generate() yields a (PIL image, str label) pair."""
        image, text = self.gen.generate()
        assert isinstance(image, Image.Image)
        assert isinstance(text, str)

    def test_image_size(self):
        """The image size matches the rotate solver's ``input_size``."""
        image, _ = self.gen.generate()
        # The config entry is unpacked as (h, w); the comparison below swaps
        # it to (w, h) to match ``image.size``.
        height, width = SOLVER_CONFIG["rotate"]["input_size"]
        assert image.size == (width, height)

    def test_label_is_angle(self):
        """The label parses as an integer angle within 0..359 inclusive."""
        _, label = self.gen.generate()
        angle = int(label)
        assert angle in range(360)
|
||||
177
tests/test_models.py
Normal file
177
tests/test_models.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
测试所有模型前向传播和输出形状。
|
||||
|
||||
每种模型构造 → forward → 验证输出 shape。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from config import NORMAL_CHARS, MATH_CHARS, THREED_CHARS, IMAGE_SIZE, SOLVER_CONFIG
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
|
||||
|
||||
class TestCaptchaClassifier:
    """Forward-pass shape and parameter-count checks for ``CaptchaClassifier``."""

    def setup_method(self):
        # Build a 5-type classifier and put it in eval mode before each test.
        model = CaptchaClassifier(num_types=5)
        model.eval()
        self.model = model

    def _forward_random(self, batch_size):
        """Run the model on a random single-channel batch of classifier size."""
        height, width = IMAGE_SIZE["classifier"]
        batch = torch.randn(batch_size, 1, height, width)
        return self.model(batch)

    def test_output_shape(self):
        """A batch of 2 produces a (2, 5) output."""
        assert self._forward_random(2).shape == (2, 5)

    def test_single_batch(self):
        """A batch of 1 produces a (1, 5) output."""
        assert self._forward_random(1).shape == (1, 5)

    def test_param_count_reasonable(self):
        """The total parameter count stays below the asserted ceiling."""
        n = sum(param.numel() for param in self.model.parameters())
        # Stated size target: < 500KB ≈ 125K float32 params.
        # NOTE(review): the asserted ceiling (200K) is looser than that
        # target — confirm which figure is intended.
        assert n < 200_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class TestLiteCRNN:
    """Forward-pass, decoding and parameter-count checks for ``LiteCRNN``."""

    def setup_method(self):
        # Default-sized model over NORMAL_CHARS, in eval mode, before each test.
        model = LiteCRNN(chars=NORMAL_CHARS)
        model.eval()
        self.model = model

    def test_output_shape(self):
        """Output is (width // 4, batch, len(chars) + 1) for the normal size."""
        height, width = IMAGE_SIZE["normal"]
        batch = torch.randn(2, 1, height, width)
        out = self.model(batch)
        expected_classes = len(NORMAL_CHARS) + 1  # +1 for blank
        expected_steps = width // 4
        assert out.shape == (expected_steps, 2, expected_classes)

    def test_greedy_decode(self):
        """greedy_decode returns a list holding one string for a batch of 1."""
        height, width = IMAGE_SIZE["normal"]
        logits = self.model(torch.randn(1, 1, height, width))
        texts = self.model.greedy_decode(logits)
        assert isinstance(texts, list)
        assert len(texts) == 1
        assert isinstance(texts[0], str)

    def test_param_count_reasonable(self):
        """The total parameter count stays below the asserted ceiling."""
        n = sum(param.numel() for param in self.model.parameters())
        # Stated size target: < 2MB ≈ 500K float32 params.
        # NOTE(review): the asserted ceiling (600K) is looser than that
        # target — confirm which figure is intended.
        assert n < 600_000, f"too many params: {n}"

    def test_math_mode(self):
        """A model built for MATH_CHARS at the math image size has matching shape."""
        height, width = IMAGE_SIZE["math"]
        math_model = LiteCRNN(chars=MATH_CHARS, img_h=height, img_w=width)
        math_model.eval()
        out = math_model(torch.randn(1, 1, height, width))
        expected_classes = len(MATH_CHARS) + 1
        expected_steps = width // 4
        assert out.shape == (expected_steps, 1, expected_classes)
|
||||
|
||||
|
||||
class TestThreeDCNN:
    """Forward-pass, decoding and parameter-count checks for ``ThreeDCNN``."""

    def setup_method(self):
        # Model sized for the "3d_text" images, in eval mode, before each test.
        height, width = IMAGE_SIZE["3d_text"]
        model = ThreeDCNN(chars=THREED_CHARS, img_h=height, img_w=width)
        model.eval()
        self.model = model

    def test_output_shape(self):
        """Output is (width // 4, batch, len(chars) + 1)."""
        height, width = IMAGE_SIZE["3d_text"]
        out = self.model(torch.randn(2, 1, height, width))
        expected_classes = len(THREED_CHARS) + 1
        expected_steps = width // 4
        assert out.shape == (expected_steps, 2, expected_classes)

    def test_greedy_decode(self):
        """greedy_decode returns a one-element list for a batch of 1."""
        height, width = IMAGE_SIZE["3d_text"]
        logits = self.model(torch.randn(1, 1, height, width))
        texts = self.model.greedy_decode(logits)
        assert isinstance(texts, list)
        assert len(texts) == 1

    def test_param_count_reasonable(self):
        """The total parameter count stays below the asserted ceiling."""
        n = sum(param.numel() for param in self.model.parameters())
        # Stated size target: < 5MB ≈ 1.25M float32 params.
        # NOTE(review): the asserted ceiling (1.5M) is looser than that
        # target — confirm which figure is intended.
        assert n < 1_500_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class TestRegressionCNN:
    """Output-shape, range and parameter-count checks for ``RegressionCNN``."""

    @staticmethod
    def _build(size_key):
        """Return ``(model, height, width)`` for the given IMAGE_SIZE key."""
        height, width = IMAGE_SIZE[size_key]
        return RegressionCNN(img_h=height, img_w=width), height, width

    def test_3d_rotate_shape(self):
        """At the "3d_rotate" size the output is (batch, 1) with values in [0, 1]."""
        model, height, width = self._build("3d_rotate")
        model.eval()
        out = model(torch.randn(2, 1, height, width))
        assert out.shape == (2, 1)
        # Output should be sigmoid [0, 1]
        assert out.min() >= 0.0
        assert out.max() <= 1.0

    def test_3d_slider_shape(self):
        """At the "3d_slider" size the output is (batch, 1)."""
        model, height, width = self._build("3d_slider")
        model.eval()
        out = model(torch.randn(2, 1, height, width))
        assert out.shape == (2, 1)

    def test_param_count_reasonable(self):
        """The total parameter count stays below the asserted ceiling."""
        model, _, _ = self._build("3d_rotate")
        n = sum(param.numel() for param in model.parameters())
        # Stated size target: ~1MB ≈ 250K float32 params.
        # NOTE(review): the asserted ceiling (400K) is looser than that
        # target — confirm which figure is intended.
        assert n < 400_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class TestGapDetectorCNN:
    """Output-shape, range and parameter-count checks for ``GapDetectorCNN``."""

    def setup_method(self):
        # Model sized from the slide solver config, in eval mode, before each test.
        height, width = SOLVER_CONFIG["slide"]["cnn_input_size"]
        model = GapDetectorCNN(img_h=height, img_w=width)
        model.eval()
        self.model = model

    def test_output_shape(self):
        """A batch of 2 gives a (2, 1) output with every value in [0, 1]."""
        height, width = SOLVER_CONFIG["slide"]["cnn_input_size"]
        out = self.model(torch.randn(2, 1, height, width))
        assert out.shape == (2, 1)
        assert out.min() >= 0.0
        assert out.max() <= 1.0

    def test_param_count_reasonable(self):
        """The total parameter count stays below 400K."""
        n = sum(param.numel() for param in self.model.parameters())
        assert n < 400_000, f"too many params: {n}"
|
||||
|
||||
|
||||
class TestRotationRegressor:
    """Output-shape, range and parameter-count checks for ``RotationRegressor``."""

    def setup_method(self):
        # Model sized from the rotate solver config, in eval mode, before each test.
        height, width = SOLVER_CONFIG["rotate"]["input_size"]
        model = RotationRegressor(img_h=height, img_w=width)
        model.eval()
        self.model = model

    def _forward_random(self, batch_size):
        """Run the model on a random 3-channel batch of the configured size."""
        height, width = SOLVER_CONFIG["rotate"]["input_size"]
        return self.model(torch.randn(batch_size, 3, height, width))  # RGB, 3 channels

    def test_output_shape(self):
        """A batch of 2 gives a (2, 2) output."""
        out = self._forward_random(2)
        assert out.shape == (2, 2)  # (sin, cos)

    def test_output_range_tanh(self):
        """Every output value lies in [-1, 1]."""
        out = self._forward_random(4)
        assert out.min() >= -1.0
        assert out.max() <= 1.0

    def test_param_count_reasonable(self):
        """The total parameter count stays below the asserted ceiling."""
        n = sum(param.numel() for param in self.model.parameters())
        # Stated size target: ~2MB ≈ 500K float32 params.
        # NOTE(review): the asserted ceiling (600K) is looser than that
        # target — confirm which figure is intended.
        assert n < 600_000, f"too many params: {n}"
|
||||
176
tests/test_pipeline.py
Normal file
176
tests/test_pipeline.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
测试推理流水线组件。
|
||||
|
||||
- math_eval: 加减乘除正确性 + 异常输入
|
||||
- CTC greedy decode (构造 logits)
|
||||
- SlideSolver (合成图 → OpenCV 检测)
|
||||
- generate_slide_track 轨迹合理性
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from inference.math_eval import eval_captcha_math
|
||||
from inference.pipeline import CaptchaPipeline
|
||||
|
||||
|
||||
# ============================================================
|
||||
# math_eval 测试
|
||||
# ============================================================
|
||||
class TestMathEval:
    """Checks ``eval_captcha_math`` on the four operators and on bad input."""

    @staticmethod
    def _check_all(cases):
        """Assert each (expression, expected result string) pair, in order."""
        for expression, expected in cases:
            assert eval_captcha_math(expression) == expected

    def test_addition(self):
        """Addition, with and without the trailing "=?"."""
        self._check_all([("3+8=?", "11"), ("12+5", "17"), ("0+0=?", "0")])

    def test_subtraction(self):
        """Subtraction, including a zero result."""
        self._check_all([("15-7=?", "8"), ("20-20", "0")])

    def test_multiplication(self):
        """Multiplication written with ×, *, x or X."""
        self._check_all([("12×3=?", "36"), ("5*4", "20"), ("6x7", "42"), ("6X7", "42")])

    def test_division(self):
        """Exact division written with ÷."""
        self._check_all([("20÷4=?", "5"), ("9÷3", "3")])

    def test_division_by_zero(self):
        """A zero divisor raises ValueError with the division-by-zero message."""
        with pytest.raises(ValueError, match="除数为零"):
            eval_captcha_math("5÷0=?")

    def test_invalid_expression(self):
        """Unparseable text raises ValueError with the cannot-parse message."""
        with pytest.raises(ValueError, match="无法解析"):
            eval_captcha_math("abc")

    def test_with_spaces(self):
        """Spaces between tokens are accepted."""
        self._check_all([("3 + 8 = ?", "11")])
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CTC greedy decode 测试
|
||||
# ============================================================
|
||||
class TestCTCGreedyDecode:
    """Test the static _ctc_greedy_decode method from CaptchaPipeline."""

    @staticmethod
    def _peaked_logits(path, num_classes):
        """Build (T, 1, C) float32 logits peaking at ``path[t]`` for each step t.

        Every entry is -10.0 except the chosen class at each step, which is 10.0.
        """
        logits = np.full((len(path), 1, num_classes), -10.0, dtype=np.float32)
        for step, class_index in enumerate(path):
            logits[step, 0, class_index] = 10.0
        return logits

    def test_simple_decode(self):
        """Duplicates collapse: A, A, blank, B, B, B decodes to "AB"."""
        chars = "ABC"  # index 0=blank, 1=A, 2=B, 3=C
        # 6 time steps, 4 classes (blank + 3 chars).
        logits = self._peaked_logits([1, 1, 0, 2, 2, 2], 4)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, chars)
        assert decoded == "AB"

    def test_all_blank(self):
        """A sequence of only blanks decodes to the empty string."""
        chars = "ABC"
        logits = self._peaked_logits([0] * 5, 4)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, chars)
        assert decoded == ""

    def test_repeated_chars_with_blank_separator(self):
        """A blank between equal characters keeps both: A, blank, A decodes to "AA"."""
        chars = "ABC"
        # Spell "AA": A, blank, A, blank, blank
        logits = self._peaked_logits([1, 0, 1, 0, 0], 4)
        decoded = CaptchaPipeline._ctc_greedy_decode(logits, chars)
        assert decoded == "AA"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# SlideSolver 测试
|
||||
# ============================================================
|
||||
class TestSlideSolver:
    """Structural check of ``SlideSolver.solve`` on a generated slide image."""

    def test_solve_with_synthetic_image(self):
        """Run the solver on one generated image and check the result's fields."""
        # Skip the test when OpenCV cannot be imported.
        try:
            import cv2
        except ImportError:
            pytest.skip("OpenCV not installed")

        from generators.slide_gen import SlideDataGenerator
        from solvers.slide_solver import SlideSolver

        image, label = SlideDataGenerator(seed=42).generate()
        # NOTE(review): the ground-truth offset is parsed here but never
        # compared with the solver output, so detection accuracy is not tested.
        expected_gap_x = int(label)

        result = SlideSolver().solve(image)

        for key in ("gap_x", "gap_x_percent", "confidence", "method"):
            assert key in result
        assert isinstance(result["gap_x"], int)
        assert 0.0 <= result["gap_x_percent"] <= 1.0
|
||||
|
||||
|
||||
# ============================================================
|
||||
# generate_slide_track 测试
|
||||
# ============================================================
|
||||
class TestSlideTrack:
    """Sanity checks on the point list returned by ``generate_slide_track``."""

    @staticmethod
    def _track(distance, seed):
        """Return ``generate_slide_track(distance, seed=seed)``.

        The import is deferred to call time so that a failure importing
        ``utils.slide_utils`` affects only the tests in this class.
        """
        from utils.slide_utils import generate_slide_track

        return generate_slide_track(distance, seed=seed)

    def test_track_basic(self):
        """The track is a list with at least 10 points."""
        track = self._track(100, seed=42)
        assert isinstance(track, list)
        assert len(track) >= 10

    def test_track_point_structure(self):
        """Every point carries "x", "y" and "t" keys."""
        for pt in self._track(150, seed=0):
            assert "x" in pt
            assert "y" in pt
            assert "t" in pt

    def test_track_starts_at_origin(self):
        """The first point's x is zero within 1e-6."""
        track = self._track(100, seed=1)
        # The tolerance check already covers an exact 0.0, so no separate
        # equality test is needed.
        assert abs(track[0]["x"]) < 1e-6

    def test_track_ends_near_distance(self):
        """The last point's x is within 1.0 of the requested distance."""
        distance = 120
        track = self._track(distance, seed=2)
        final_x = track[-1]["x"]
        assert abs(final_x - distance) < 1.0, f"final x={final_x}, expected ~{distance}"

    def test_track_time_increases(self):
        """Timestamps never decrease from one point to the next."""
        track = self._track(100, seed=3)
        # Pair each point with its successor; equal timestamps are allowed.
        for earlier, later in zip(track, track[1:]):
            assert later["t"] >= earlier["t"]

    def test_track_y_has_jitter(self):
        """At least one point has a non-zero y value."""
        track = self._track(200, seed=4)
        # At least some y values should be non-zero (jitter)
        assert any(abs(pt["y"]) > 0 for pt in track)
|
||||
Reference in New Issue
Block a user