"""Model dedicated to 3D (perspective-rendered) CAPTCHAs.

Uses a deeper CNN backbone (ResNet-style residual blocks) followed by CRNN
sequence modelling, giving stronger feature extraction for perspective
distortion and shading effects.

Architecture: ResNet-lite backbone -> fixed-kernel height pooling -> BiLSTM
-> FC -> CTC
Input: grayscale image 1x60x160 (default size)
Size target: < 5MB
"""

from itertools import groupby

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Simplified residual block: Conv-BN-ReLU-Conv-BN + identity shortcut.

    Both convolutions are 3x3 with padding=1 and keep the channel count, so
    the output has the same shape as the input and the shortcut needs no
    projection.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        # The in-place ReLU acts on the fresh sum tensor, never on ``x`` itself.
        out = F.relu(out + residual, inplace=True)
        return out


class ThreeDCNN(nn.Module):
    """Recognition model for 3D CAPTCHAs.

    Backbone: four plain conv stages interleaved with two residual blocks,
    channels 1 -> 32 -> 64 -> 64 (res) -> 128 -> 128 (res) -> 128.
    Height is halved four times (H // 16) and then collapsed to 1 with a
    fixed-kernel average pool; width is halved twice (W // 4) and kept as the
    sequence axis. A 2-layer BiLSTM + linear head produce CTC logits.
    """

    def __init__(self, chars: str, img_h: int = 60, img_w: int = 160):
        """
        Args:
            chars: character set string (excluding the CTC blank). Character
                ``chars[i]`` is class index ``i + 1``; index 0 is the blank.
            img_h: input image height. Must be >= 16 because the backbone
                halves the height four times.
            img_w: input image width. Must be >= 4 because the backbone
                halves the width twice.

        Raises:
            ValueError: if ``img_h < 16`` or ``img_w < 4`` (the pooling
                stages would otherwise fail only at forward time).
        """
        super().__init__()
        if img_h < 16:
            raise ValueError(
                f"img_h must be >= 16 (got {img_h}): the backbone halves the height 4 times"
            )
        if img_w < 4:
            raise ValueError(
                f"img_w must be >= 4 (got {img_w}): the backbone halves the width 2 times"
            )
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        self.num_classes = len(chars) + 1  # +1 for the CTC blank

        # ---- ResNet-lite backbone ----
        self.backbone = nn.Sequential(
            # stage 1: 1 -> 32, H/2, W unchanged
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
            # stage 2: 32 -> 64, H/2, W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # stage 3: residual block 64 -> 64
            ResidualBlock(64),
            # stage 4: 64 -> 128, H/2, W/2
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # stage 5: residual block 128 -> 128
            ResidualBlock(128),
            # stage 6: 128 -> 128, H/2, W unchanged
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
        )

        # Collapse height to 1 while keeping width. A fixed kernel is used
        # instead of AdaptiveAvgPool2d so the graph exports cleanly to ONNX.
        # Four successive floor-halvings equal one floor division by 16,
        # e.g. 60 -> 30 -> 15 -> 7 -> 3 == 60 // 16.
        self._pool_h = img_h // 16
        self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1))

        # ---- RNN sequence modelling ----
        self.rnn_input_size = 128
        self.rnn_hidden = 128
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )

        # ---- Output layer ----
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale images. H must equal ``img_h``
                because ``height_pool`` uses a kernel sized for ``img_h``.

        Returns:
            logits: (T, B, num_classes) in CTC layout, with T = W // 4.
        """
        conv = self.backbone(x)  # (B, 128, H // 16, W // 4)
        conv = self.height_pool(conv)  # (B, 128, 1, W')
        conv = conv.squeeze(2)  # (B, 128, W')
        conv = conv.permute(0, 2, 1)  # (B, W', 128)

        rnn_out, _ = self.rnn(conv)  # (B, W', 256)
        logits = self.fc(rnn_out)  # (B, W', num_classes)
        logits = logits.permute(1, 0, 2)  # (T, B, C)
        return logits

    @property
    def seq_len(self) -> int:
        """CTC sequence length: ``img_w`` after the two stride-2 width pools.

        Assumes inputs have width ``img_w``.
        """
        return self.img_w // 4

    # ----------------------------------------------------------
    # CTC greedy decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """CTC greedy (best-path) decoding.

        Takes the argmax class per time step, collapses consecutive repeats,
        then drops blanks (index 0). A blank between two identical labels
        therefore keeps both of them.

        Args:
            logits: (T, B, C) raw model output.

        Returns:
            One decoded string per batch element.
        """
        blank = 0
        # Convert the whole (B, T) best path to Python ints in one call.
        best_paths = logits.argmax(dim=2).permute(1, 0).tolist()
        return [
            "".join(self.chars[label - 1] for label, _ in groupby(path) if label != blank)
            for path in best_paths
        ]