""" 轻量 CRNN 模型 (Convolutional Recurrent Neural Network) 用于普通字符验证码和算式验证码的 OCR 识别。 两种模式通过不同的字符集和输入尺寸区分,共享同一网络架构。 架构: CNN 特征提取 → 序列映射 → BiLSTM → 全连接 → CTC 解码 CTC 输出长度 = 特征图宽度 (经过若干次宽度方向 pool 后) CTC blank 位于 index 0,字符从 index 1 开始映射。 - normal 模式: 输入 1×40×120, 字符集 30 字符, 体积 < 2MB - math 模式: 输入 1×40×160, 字符集 16 字符, 体积 < 2MB """ import torch import torch.nn as nn class LiteCRNN(nn.Module): """ 轻量 CRNN + CTC。 CNN 部分对高度做 4 次 pool (40→20→10→5→1 via AdaptivePool), 宽度做 2 次 pool (保留足够序列长度给 CTC)。 RNN 部分使用单层 BiLSTM。 """ def __init__(self, chars: str, img_h: int = 40, img_w: int = 120): """ Args: chars: 字符集字符串 (不含 CTC blank) img_h: 输入图片高度 img_w: 输入图片宽度 """ super().__init__() self.chars = chars self.img_h = img_h self.img_w = img_w # CTC 类别数 = 字符数 + 1 (blank at index 0) self.num_classes = len(chars) + 1 # ---- CNN 特征提取 ---- self.cnn = nn.Sequential( # block 1: 1 -> 32, H/2, W不变 nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变 # block 2: 32 -> 64, H/2, W/2 nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # H/2, W/2 # block 3: 64 -> 128, H/2, W不变 nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变 # block 4: 128 -> 128, H/2, W/2 nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # H/2, W/2 ) # 经过 4 次高度 pool 后高度 = img_h/16 (向下取整) # 例: 40→20→10→5→2 # 用固定 kernel 的 AvgPool 把剩余高度压到 1 (ONNX 兼容) self._pool_h = img_h // 16 # 40→2, 若有余数也向下取整 self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1)) # ---- RNN 序列建模 ---- self.rnn_input_size = 128 self.rnn_hidden = 96 self.rnn = nn.LSTM( input_size=self.rnn_input_size, hidden_size=self.rnn_hidden, num_layers=1, batch_first=True, bidirectional=True, ) # ---- 输出层 ---- self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (batch, 1, H, W) 灰度图 Returns: logits: (seq_len, batch, num_classes) 即 CTC 所需的 (T, B, C) 格式 """ # CNN conv = self.cnn(x) # (B, 128, H', W') conv = self.height_pool(conv) # (B, 128, 1, W') conv = conv.squeeze(2) # (B, 128, W') conv = conv.permute(0, 2, 1) # (B, W', 128) — batch_first 序列 # RNN rnn_out, _ = self.rnn(conv) # (B, W', 256) # FC logits = self.fc(rnn_out) # (B, W', num_classes) logits = logits.permute(1, 0, 2) # (T, B, C) — CTC 格式 return logits @property def seq_len(self) -> int: """根据输入宽度计算 CTC 序列长度 (特征图宽度)。""" # 宽度经过 2 次 /2 的 pool return self.img_w // 4 # ---------------------------------------------------------- # CTC 贪心解码 # ---------------------------------------------------------- def greedy_decode(self, logits: torch.Tensor) -> list[str]: """ CTC 贪心解码。 Args: logits: (T, B, C) 模型原始输出 Returns: 解码后的字符串列表,长度 = batch size """ # (T, B, C) -> (B, T) preds = logits.argmax(dim=2).permute(1, 0) # (B, T) results = [] for pred in preds: chars = [] prev = -1 for idx in pred.tolist(): if idx != 0 and idx != prev: # 0 = blank chars.append(self.chars[idx - 1]) # 字符从 index 1 开始 prev = idx results.append("".join(chars)) return results