""" 测试所有模型前向传播和输出形状。 每种模型构造 → forward → 验证输出 shape。 """ import torch import pytest from config import NORMAL_CHARS, MATH_CHARS, THREED_CHARS, IMAGE_SIZE, SOLVER_CONFIG from models.classifier import CaptchaClassifier from models.lite_crnn import LiteCRNN from models.threed_cnn import ThreeDCNN from models.regression_cnn import RegressionCNN from models.gap_detector import GapDetectorCNN from models.rotation_regressor import RotationRegressor class TestCaptchaClassifier: def setup_method(self): self.model = CaptchaClassifier(num_types=5) self.model.eval() def test_output_shape(self): h, w = IMAGE_SIZE["classifier"] x = torch.randn(2, 1, h, w) out = self.model(x) assert out.shape == (2, 5) def test_single_batch(self): h, w = IMAGE_SIZE["classifier"] x = torch.randn(1, 1, h, w) out = self.model(x) assert out.shape == (1, 5) def test_param_count_reasonable(self): n = sum(p.numel() for p in self.model.parameters()) # Should be < 500KB ≈ 125K float32 params assert n < 200_000, f"too many params: {n}" class TestLiteCRNN: def setup_method(self): self.model = LiteCRNN(chars=NORMAL_CHARS) self.model.eval() def test_output_shape(self): h, w = IMAGE_SIZE["normal"] x = torch.randn(2, 1, h, w) out = self.model(x) num_classes = len(NORMAL_CHARS) + 1 # +1 for blank seq_len = w // 4 assert out.shape == (seq_len, 2, num_classes) def test_greedy_decode(self): h, w = IMAGE_SIZE["normal"] x = torch.randn(1, 1, h, w) logits = self.model(x) decoded = self.model.greedy_decode(logits) assert isinstance(decoded, list) assert len(decoded) == 1 assert isinstance(decoded[0], str) def test_param_count_reasonable(self): n = sum(p.numel() for p in self.model.parameters()) # Should be < 2MB ≈ 500K float32 params assert n < 600_000, f"too many params: {n}" def test_math_mode(self): h, w = IMAGE_SIZE["math"] model = LiteCRNN(chars=MATH_CHARS, img_h=h, img_w=w) model.eval() x = torch.randn(1, 1, h, w) out = model(x) num_classes = len(MATH_CHARS) + 1 seq_len = w // 4 assert out.shape == (seq_len, 1, num_classes) class TestThreeDCNN: def setup_method(self): h, w = IMAGE_SIZE["3d_text"] self.model = ThreeDCNN(chars=THREED_CHARS, img_h=h, img_w=w) self.model.eval() def test_output_shape(self): h, w = IMAGE_SIZE["3d_text"] x = torch.randn(2, 1, h, w) out = self.model(x) num_classes = len(THREED_CHARS) + 1 seq_len = w // 4 assert out.shape == (seq_len, 2, num_classes) def test_greedy_decode(self): h, w = IMAGE_SIZE["3d_text"] x = torch.randn(1, 1, h, w) logits = self.model(x) decoded = self.model.greedy_decode(logits) assert isinstance(decoded, list) assert len(decoded) == 1 def test_param_count_reasonable(self): n = sum(p.numel() for p in self.model.parameters()) # Should be < 5MB ≈ 1.25M float32 params assert n < 1_500_000, f"too many params: {n}" class TestRegressionCNN: def test_3d_rotate_shape(self): h, w = IMAGE_SIZE["3d_rotate"] model = RegressionCNN(img_h=h, img_w=w) model.eval() x = torch.randn(2, 1, h, w) out = model(x) assert out.shape == (2, 1) # Output should be sigmoid [0, 1] assert out.min() >= 0.0 assert out.max() <= 1.0 def test_3d_slider_shape(self): h, w = IMAGE_SIZE["3d_slider"] model = RegressionCNN(img_h=h, img_w=w) model.eval() x = torch.randn(2, 1, h, w) out = model(x) assert out.shape == (2, 1) def test_param_count_reasonable(self): h, w = IMAGE_SIZE["3d_rotate"] model = RegressionCNN(img_h=h, img_w=w) n = sum(p.numel() for p in model.parameters()) # Should be ~1MB ≈ 250K float32 params assert n < 400_000, f"too many params: {n}" class TestGapDetectorCNN: def setup_method(self): h, w = SOLVER_CONFIG["slide"]["cnn_input_size"] self.model = GapDetectorCNN(img_h=h, img_w=w) self.model.eval() def test_output_shape(self): h, w = SOLVER_CONFIG["slide"]["cnn_input_size"] x = torch.randn(2, 1, h, w) out = self.model(x) assert out.shape == (2, 1) assert out.min() >= 0.0 assert out.max() <= 1.0 def test_param_count_reasonable(self): n = sum(p.numel() for p in self.model.parameters()) assert n < 400_000, f"too many params: {n}" class TestRotationRegressor: def setup_method(self): h, w = SOLVER_CONFIG["rotate"]["input_size"] self.model = RotationRegressor(img_h=h, img_w=w) self.model.eval() def test_output_shape(self): h, w = SOLVER_CONFIG["rotate"]["input_size"] x = torch.randn(2, 3, h, w) # RGB, 3 channels out = self.model(x) assert out.shape == (2, 2) # (sin, cos) def test_output_range_tanh(self): h, w = SOLVER_CONFIG["rotate"]["input_size"] x = torch.randn(4, 3, h, w) out = self.model(x) assert out.min() >= -1.0 assert out.max() <= 1.0 def test_param_count_reasonable(self): n = sum(p.numel() for p in self.model.parameters()) # Should be ~2MB ≈ 500K float32 params assert n < 600_000, f"too many params: {n}"