Add tests, server, resume training, and project cleanup

- Add 57 unit tests covering generators, models, and pipeline components
- Implement FastAPI HTTP service (server.py) with POST /solve and GET /health
- Add checkpoint resume (断点续训) to both CTC and regression training utils
- Fix device mismatch bug in CTC training (targets/input_lengths on GPU)
- Add pytest dev dependency to pyproject.toml
- Update .gitignore with data/solver/, data/real/, *.log
- Remove PyCharm template main.py
- Update training/__init__.py docs for solver training scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Hua
2026-03-11 19:05:47 +08:00
parent 9b5f29083e
commit 788ddcae1a
11 changed files with 786 additions and 21 deletions

177
tests/test_models.py Normal file
View File

@@ -0,0 +1,177 @@
"""
测试所有模型前向传播和输出形状。
每种模型构造 → forward → 验证输出 shape。
"""
import torch
import pytest
from config import NORMAL_CHARS, MATH_CHARS, THREED_CHARS, IMAGE_SIZE, SOLVER_CONFIG
from models.classifier import CaptchaClassifier
from models.lite_crnn import LiteCRNN
from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
class TestCaptchaClassifier:
def setup_method(self):
self.model = CaptchaClassifier(num_types=5)
self.model.eval()
def test_output_shape(self):
h, w = IMAGE_SIZE["classifier"]
x = torch.randn(2, 1, h, w)
out = self.model(x)
assert out.shape == (2, 5)
def test_single_batch(self):
h, w = IMAGE_SIZE["classifier"]
x = torch.randn(1, 1, h, w)
out = self.model(x)
assert out.shape == (1, 5)
def test_param_count_reasonable(self):
n = sum(p.numel() for p in self.model.parameters())
# Should be < 500KB ≈ 125K float32 params
assert n < 200_000, f"too many params: {n}"
class TestLiteCRNN:
def setup_method(self):
self.model = LiteCRNN(chars=NORMAL_CHARS)
self.model.eval()
def test_output_shape(self):
h, w = IMAGE_SIZE["normal"]
x = torch.randn(2, 1, h, w)
out = self.model(x)
num_classes = len(NORMAL_CHARS) + 1 # +1 for blank
seq_len = w // 4
assert out.shape == (seq_len, 2, num_classes)
def test_greedy_decode(self):
h, w = IMAGE_SIZE["normal"]
x = torch.randn(1, 1, h, w)
logits = self.model(x)
decoded = self.model.greedy_decode(logits)
assert isinstance(decoded, list)
assert len(decoded) == 1
assert isinstance(decoded[0], str)
def test_param_count_reasonable(self):
n = sum(p.numel() for p in self.model.parameters())
# Should be < 2MB ≈ 500K float32 params
assert n < 600_000, f"too many params: {n}"
def test_math_mode(self):
h, w = IMAGE_SIZE["math"]
model = LiteCRNN(chars=MATH_CHARS, img_h=h, img_w=w)
model.eval()
x = torch.randn(1, 1, h, w)
out = model(x)
num_classes = len(MATH_CHARS) + 1
seq_len = w // 4
assert out.shape == (seq_len, 1, num_classes)
class TestThreeDCNN:
def setup_method(self):
h, w = IMAGE_SIZE["3d_text"]
self.model = ThreeDCNN(chars=THREED_CHARS, img_h=h, img_w=w)
self.model.eval()
def test_output_shape(self):
h, w = IMAGE_SIZE["3d_text"]
x = torch.randn(2, 1, h, w)
out = self.model(x)
num_classes = len(THREED_CHARS) + 1
seq_len = w // 4
assert out.shape == (seq_len, 2, num_classes)
def test_greedy_decode(self):
h, w = IMAGE_SIZE["3d_text"]
x = torch.randn(1, 1, h, w)
logits = self.model(x)
decoded = self.model.greedy_decode(logits)
assert isinstance(decoded, list)
assert len(decoded) == 1
def test_param_count_reasonable(self):
n = sum(p.numel() for p in self.model.parameters())
# Should be < 5MB ≈ 1.25M float32 params
assert n < 1_500_000, f"too many params: {n}"
class TestRegressionCNN:
def test_3d_rotate_shape(self):
h, w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=h, img_w=w)
model.eval()
x = torch.randn(2, 1, h, w)
out = model(x)
assert out.shape == (2, 1)
# Output should be sigmoid [0, 1]
assert out.min() >= 0.0
assert out.max() <= 1.0
def test_3d_slider_shape(self):
h, w = IMAGE_SIZE["3d_slider"]
model = RegressionCNN(img_h=h, img_w=w)
model.eval()
x = torch.randn(2, 1, h, w)
out = model(x)
assert out.shape == (2, 1)
def test_param_count_reasonable(self):
h, w = IMAGE_SIZE["3d_rotate"]
model = RegressionCNN(img_h=h, img_w=w)
n = sum(p.numel() for p in model.parameters())
# Should be ~1MB ≈ 250K float32 params
assert n < 400_000, f"too many params: {n}"
class TestGapDetectorCNN:
def setup_method(self):
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
self.model = GapDetectorCNN(img_h=h, img_w=w)
self.model.eval()
def test_output_shape(self):
h, w = SOLVER_CONFIG["slide"]["cnn_input_size"]
x = torch.randn(2, 1, h, w)
out = self.model(x)
assert out.shape == (2, 1)
assert out.min() >= 0.0
assert out.max() <= 1.0
def test_param_count_reasonable(self):
n = sum(p.numel() for p in self.model.parameters())
assert n < 400_000, f"too many params: {n}"
class TestRotationRegressor:
def setup_method(self):
h, w = SOLVER_CONFIG["rotate"]["input_size"]
self.model = RotationRegressor(img_h=h, img_w=w)
self.model.eval()
def test_output_shape(self):
h, w = SOLVER_CONFIG["rotate"]["input_size"]
x = torch.randn(2, 3, h, w) # RGB, 3 channels
out = self.model(x)
assert out.shape == (2, 2) # (sin, cos)
def test_output_range_tanh(self):
h, w = SOLVER_CONFIG["rotate"]["input_size"]
x = torch.randn(4, 3, h, w)
out = self.model(x)
assert out.min() >= -1.0
assert out.max() <= 1.0
def test_param_count_reasonable(self):
n = sum(p.numel() for p in self.model.parameters())
# Should be ~2MB ≈ 500K float32 params
assert n < 600_000, f"too many params: {n}"