- Add 57 unit tests covering generators, models, and pipeline components - Implement FastAPI HTTP service (server.py) with POST /solve and GET /health - Add checkpoint resume (断点续训) to both CTC and regression training utils - Fix device mismatch bug in CTC training (targets/input_lengths on GPU) - Add pytest dev dependency to pyproject.toml - Update .gitignore with data/solver/, data/real/, *.log - Remove PyCharm template main.py - Update training/__init__.py docs for solver training scripts Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
178 lines
5.6 KiB
Python
178 lines
5.6 KiB
Python
"""
|
|
测试所有模型前向传播和输出形状。
|
|
|
|
每种模型构造 → forward → 验证输出 shape。
|
|
"""
|
|
|
|
import torch
|
|
import pytest
|
|
|
|
from config import NORMAL_CHARS, MATH_CHARS, THREED_CHARS, IMAGE_SIZE, SOLVER_CONFIG
|
|
from models.classifier import CaptchaClassifier
|
|
from models.lite_crnn import LiteCRNN
|
|
from models.threed_cnn import ThreeDCNN
|
|
from models.regression_cnn import RegressionCNN
|
|
from models.gap_detector import GapDetectorCNN
|
|
from models.rotation_regressor import RotationRegressor
|
|
|
|
|
|
class TestCaptchaClassifier:
    """Shape and size checks for the captcha-type classifier."""

    # Number of captcha types the classifier is constructed with in these tests.
    NUM_TYPES = 5

    def setup_method(self):
        self.model = CaptchaClassifier(num_types=self.NUM_TYPES)
        self.model.eval()

    def _forward(self, batch_size):
        """Run one inference pass on random 1-channel input of the classifier size."""
        h, w = IMAGE_SIZE["classifier"]
        x = torch.randn(batch_size, 1, h, w)
        # Shape checks need no gradients; skip building the autograd graph.
        with torch.no_grad():
            return self.model(x)

    def test_output_shape(self):
        out = self._forward(2)
        assert out.shape == (2, self.NUM_TYPES)

    def test_single_batch(self):
        out = self._forward(1)
        assert out.shape == (1, self.NUM_TYPES)

    def test_param_count_reasonable(self):
        n = sum(p.numel() for p in self.model.parameters())
        # Design target is < 500 KB of float32 weights (~125K params); the
        # 200K cap (~800 KB) deliberately leaves headroom for small tweaks.
        assert n < 200_000, f"too many params: {n}"
|
|
|
|
|
|
class TestLiteCRNN:
    """Shape, decoding and size checks for the lightweight CRNN recognizer."""

    def setup_method(self):
        # Build the model for exactly the input size the tests feed it, instead
        # of relying on the constructor's default img_h/img_w happening to
        # match IMAGE_SIZE["normal"].
        self.h, self.w = IMAGE_SIZE["normal"]
        self.model = LiteCRNN(chars=NORMAL_CHARS, img_h=self.h, img_w=self.w)
        self.model.eval()

    def _forward(self, batch_size):
        """Run one inference pass on random 1-channel input of the 'normal' size."""
        x = torch.randn(batch_size, 1, self.h, self.w)
        # Shape checks need no gradients; skip building the autograd graph.
        with torch.no_grad():
            return self.model(x)

    def test_output_shape(self):
        out = self._forward(2)
        num_classes = len(NORMAL_CHARS) + 1  # +1 for blank
        # Output is laid out as (seq_len, batch, classes); assumes the
        # backbone downsamples the width by 4.
        seq_len = self.w // 4
        assert out.shape == (seq_len, 2, num_classes)

    def test_greedy_decode(self):
        logits = self._forward(1)
        decoded = self.model.greedy_decode(logits)
        assert isinstance(decoded, list)
        assert len(decoded) == 1
        assert isinstance(decoded[0], str)

    def test_param_count_reasonable(self):
        n = sum(p.numel() for p in self.model.parameters())
        # Design target is < 2 MB of float32 weights (~500K params); the
        # 600K cap deliberately leaves headroom for small tweaks.
        assert n < 600_000, f"too many params: {n}"

    def test_math_mode(self):
        h, w = IMAGE_SIZE["math"]
        model = LiteCRNN(chars=MATH_CHARS, img_h=h, img_w=w)
        model.eval()
        x = torch.randn(1, 1, h, w)
        with torch.no_grad():
            out = model(x)
        num_classes = len(MATH_CHARS) + 1  # +1 for blank
        seq_len = w // 4
        assert out.shape == (seq_len, 1, num_classes)
|
|
|
|
|
|
class TestThreeDCNN:
    """Shape, decoding and size checks for the 3D-text recognizer."""

    def setup_method(self):
        self.h, self.w = IMAGE_SIZE["3d_text"]
        self.model = ThreeDCNN(chars=THREED_CHARS, img_h=self.h, img_w=self.w)
        self.model.eval()

    def _forward(self, batch_size):
        """Run one inference pass on random 1-channel input of the '3d_text' size."""
        x = torch.randn(batch_size, 1, self.h, self.w)
        # Shape checks need no gradients; skip building the autograd graph.
        with torch.no_grad():
            return self.model(x)

    def test_output_shape(self):
        out = self._forward(2)
        num_classes = len(THREED_CHARS) + 1  # +1 for blank
        # Output is laid out as (seq_len, batch, classes); assumes the
        # backbone downsamples the width by 4.
        seq_len = self.w // 4
        assert out.shape == (seq_len, 2, num_classes)

    def test_greedy_decode(self):
        logits = self._forward(1)
        decoded = self.model.greedy_decode(logits)
        assert isinstance(decoded, list)
        assert len(decoded) == 1
        # Same contract as LiteCRNN.greedy_decode: one decoded string per sample.
        assert isinstance(decoded[0], str)

    def test_param_count_reasonable(self):
        n = sum(p.numel() for p in self.model.parameters())
        # Design target is < 5 MB of float32 weights (~1.25M params); the
        # 1.5M cap deliberately leaves headroom for small tweaks.
        assert n < 1_500_000, f"too many params: {n}"
|
|
|
|
|
|
class TestRegressionCNN:
    """Shape, range and size checks for the scalar regression CNN."""

    @staticmethod
    def _build(size_key):
        """Construct an eval-mode RegressionCNN sized for IMAGE_SIZE[size_key].

        Returns:
            (model, h, w) so callers can build matching inputs.
        """
        h, w = IMAGE_SIZE[size_key]
        model = RegressionCNN(img_h=h, img_w=w)
        model.eval()
        return model, h, w

    def _check_forward(self, size_key):
        """Forward a batch of 2 random inputs and validate shape and range."""
        model, h, w = self._build(size_key)
        x = torch.randn(2, 1, h, w)
        # Shape checks need no gradients; skip building the autograd graph.
        with torch.no_grad():
            out = model(x)
        assert out.shape == (2, 1)
        # Sigmoid head: every configuration must produce values in [0, 1].
        assert out.min() >= 0.0
        assert out.max() <= 1.0

    def test_3d_rotate_shape(self):
        self._check_forward("3d_rotate")

    def test_3d_slider_shape(self):
        # Same model class as 3d_rotate, so the same [0, 1] range applies.
        self._check_forward("3d_slider")

    def test_param_count_reasonable(self):
        model, _, _ = self._build("3d_rotate")
        n = sum(p.numel() for p in model.parameters())
        # Design target is ~1 MB of float32 weights (~250K params); the
        # 400K cap deliberately leaves headroom for small tweaks.
        assert n < 400_000, f"too many params: {n}"
|
|
|
|
|
|
class TestGapDetectorCNN:
    """Checks for the slide-captcha gap detector: output shape, [0, 1] range, size."""

    def setup_method(self):
        # Cache the configured input size so every test uses the same value.
        self.input_hw = SOLVER_CONFIG["slide"]["cnn_input_size"]
        img_h, img_w = self.input_hw
        net = GapDetectorCNN(img_h=img_h, img_w=img_w)
        net.eval()
        self.model = net

    def test_output_shape(self):
        img_h, img_w = self.input_hw
        pred = self.model(torch.randn(2, 1, img_h, img_w))
        assert pred.shape == (2, 1)
        # One scalar per sample, expected to lie in [0, 1].
        assert pred.min() >= 0.0
        assert pred.max() <= 1.0

    def test_param_count_reasonable(self):
        total = 0
        for param in self.model.parameters():
            total += param.numel()
        assert total < 400_000, f"too many params: {total}"
|
|
|
|
|
|
class TestRotationRegressor:
    """Checks for the rotation regressor: (sin, cos) output shape, tanh range, size."""

    def setup_method(self):
        # Cache the configured input size so every test uses the same value.
        self.input_hw = SOLVER_CONFIG["rotate"]["input_size"]
        rows, cols = self.input_hw
        self.model = RotationRegressor(img_h=rows, img_w=cols)
        self.model.eval()

    def _random_rgb(self, batch):
        """Random input batch with 3 channels, since this model takes RGB images."""
        rows, cols = self.input_hw
        return torch.randn(batch, 3, rows, cols)

    def test_output_shape(self):
        pred = self.model(self._random_rgb(2))
        # Two outputs per sample, interpreted as (sin, cos) of the angle.
        assert pred.shape == (2, 2)

    def test_output_range_tanh(self):
        pred = self.model(self._random_rgb(4))
        assert pred.min() >= -1.0
        assert pred.max() <= 1.0

    def test_param_count_reasonable(self):
        total = sum(param.numel() for param in self.model.parameters())
        # Target is ~2 MB of float32 weights (~500K params); the 600K cap
        # deliberately leaves headroom for small tweaks.
        assert total < 600_000, f"too many params: {total}"
|