"""Rotation-angle regression model (RotationRegressor).

Predicts the correct rotation angle for rotation-style CAPTCHAs. The angle is
encoded as (sin θ, cos θ) so that 0° and 360° map to the same target, avoiding
the wrap-around discontinuity. Input is RGB; output is (sin θ, cos θ), each
component squashed into [-1, 1] by a final Tanh.

Architecture:
    4 × [Conv3x3 (no bias) → BatchNorm → ReLU → MaxPool 2x2]
        channels 3 → 32 → 64 → 128 → 256
    AdaptiveAvgPool2d(1) → FC(256→128) → ReLU → FC(128→2) → Tanh

About 422K parameters (~1.7 MB as float32).
"""
import torch
import torch.nn as nn


class RotationRegressor(nn.Module):
    """Regress a rotation angle as a (sin θ, cos θ) pair.

    Designed for RGB input of 3x128x128, but because the head uses global
    average pooling, any spatial size works as long as H and W are each at
    least 16 (four 2x2 max-pools are applied).

    The two outputs are not constrained to lie on the unit circle, so decode
    them with atan2 at inference time, e.g.
    ``torch.rad2deg(torch.atan2(out[:, 0], out[:, 1]))``.

    Args:
        img_h: Nominal input height; stored on the instance but not read by
            ``forward``.
        img_w: Nominal input width; stored on the instance but not read by
            ``forward``.
    """

    def __init__(self, img_h: int = 128, img_w: int = 128):
        super().__init__()
        self.img_h = img_h
        self.img_w = img_w

        # Channel progression across the four conv stages.
        widths = (3, 32, 64, 128, 256)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # One stage halves H and W. The conv has no bias because the
            # BatchNorm that follows applies its own learned shift.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
        # A flat Sequential keeps state_dict keys as features.0, features.1, ...
        self.features = nn.Sequential(*layers)

        # Global average pooling makes the head independent of input size.
        self.pool = nn.AdaptiveAvgPool2d(1)

        head_width = 128
        self.regressor = nn.Sequential(
            nn.Linear(widths[-1], head_width),
            nn.ReLU(inplace=True),
            nn.Linear(head_width, 2),
            nn.Tanh(),  # bounds each component to [-1, 1], the range of sin/cos
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: RGB batch of shape (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, 2) holding (sin θ, cos θ), each in [-1, 1].
        """
        pooled = self.pool(self.features(x))  # (B, 256, 1, 1)
        return self.regressor(torch.flatten(pooled, 1))  # (B, 256) -> (B, 2)