AdaptiveAvgPool2d with None dimensions can cause issues with some ONNX runtimes. Use AvgPool2d with kernel=(img_h//16, 1) to achieve the same height-to-1 reduction with full ONNX compatibility. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
144 lines
4.8 KiB
Python
144 lines
4.8 KiB
Python
"""
|
||
轻量 CRNN 模型 (Convolutional Recurrent Neural Network)
|
||
|
||
用于普通字符验证码和算式验证码的 OCR 识别。
|
||
两种模式通过不同的字符集和输入尺寸区分,共享同一网络架构。
|
||
|
||
架构: CNN 特征提取 → 序列映射 → BiLSTM → 全连接 → CTC 解码
|
||
CTC 输出长度 = 特征图宽度 (经过若干次宽度方向 pool 后)
|
||
CTC blank 位于 index 0,字符从 index 1 开始映射。
|
||
|
||
- normal 模式: 输入 1×40×120, 字符集 30 字符, 体积 < 2MB
|
||
- math 模式: 输入 1×40×160, 字符集 16 字符, 体积 < 2MB
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class LiteCRNN(nn.Module):
|
||
"""
|
||
轻量 CRNN + CTC。
|
||
|
||
CNN 部分对高度做 4 次 pool (40→20→10→5→1 via AdaptivePool),
|
||
宽度做 2 次 pool (保留足够序列长度给 CTC)。
|
||
RNN 部分使用单层 BiLSTM。
|
||
"""
|
||
|
||
def __init__(self, chars: str, img_h: int = 40, img_w: int = 120):
|
||
"""
|
||
Args:
|
||
chars: 字符集字符串 (不含 CTC blank)
|
||
img_h: 输入图片高度
|
||
img_w: 输入图片宽度
|
||
"""
|
||
super().__init__()
|
||
self.chars = chars
|
||
self.img_h = img_h
|
||
self.img_w = img_w
|
||
# CTC 类别数 = 字符数 + 1 (blank at index 0)
|
||
self.num_classes = len(chars) + 1
|
||
|
||
# ---- CNN 特征提取 ----
|
||
self.cnn = nn.Sequential(
|
||
# block 1: 1 -> 32, H/2, W不变
|
||
nn.Conv2d(1, 32, kernel_size=3, padding=1),
|
||
nn.BatchNorm2d(32),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变
|
||
|
||
# block 2: 32 -> 64, H/2, W/2
|
||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||
nn.BatchNorm2d(64),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(2, 2), # H/2, W/2
|
||
|
||
# block 3: 64 -> 128, H/2, W不变
|
||
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
||
nn.BatchNorm2d(128),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(kernel_size=(2, 1)), # H/2, W不变
|
||
|
||
# block 4: 128 -> 128, H/2, W/2
|
||
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||
nn.BatchNorm2d(128),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(2, 2), # H/2, W/2
|
||
)
|
||
|
||
# 经过 4 次高度 pool 后高度 = img_h/16 (向下取整)
|
||
# 例: 40→20→10→5→2
|
||
# 用固定 kernel 的 AvgPool 把剩余高度压到 1 (ONNX 兼容)
|
||
self._pool_h = img_h // 16 # 40→2, 若有余数也向下取整
|
||
self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1))
|
||
|
||
# ---- RNN 序列建模 ----
|
||
self.rnn_input_size = 128
|
||
self.rnn_hidden = 96
|
||
self.rnn = nn.LSTM(
|
||
input_size=self.rnn_input_size,
|
||
hidden_size=self.rnn_hidden,
|
||
num_layers=1,
|
||
batch_first=True,
|
||
bidirectional=True,
|
||
)
|
||
|
||
# ---- 输出层 ----
|
||
self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
Args:
|
||
x: (batch, 1, H, W) 灰度图
|
||
|
||
Returns:
|
||
logits: (seq_len, batch, num_classes)
|
||
即 CTC 所需的 (T, B, C) 格式
|
||
"""
|
||
# CNN
|
||
conv = self.cnn(x) # (B, 128, H', W')
|
||
conv = self.height_pool(conv) # (B, 128, 1, W')
|
||
conv = conv.squeeze(2) # (B, 128, W')
|
||
conv = conv.permute(0, 2, 1) # (B, W', 128) — batch_first 序列
|
||
|
||
# RNN
|
||
rnn_out, _ = self.rnn(conv) # (B, W', 256)
|
||
|
||
# FC
|
||
logits = self.fc(rnn_out) # (B, W', num_classes)
|
||
logits = logits.permute(1, 0, 2) # (T, B, C) — CTC 格式
|
||
|
||
return logits
|
||
|
||
@property
|
||
def seq_len(self) -> int:
|
||
"""根据输入宽度计算 CTC 序列长度 (特征图宽度)。"""
|
||
# 宽度经过 2 次 /2 的 pool
|
||
return self.img_w // 4
|
||
|
||
# ----------------------------------------------------------
|
||
# CTC 贪心解码
|
||
# ----------------------------------------------------------
|
||
def greedy_decode(self, logits: torch.Tensor) -> list[str]:
|
||
"""
|
||
CTC 贪心解码。
|
||
|
||
Args:
|
||
logits: (T, B, C) 模型原始输出
|
||
|
||
Returns:
|
||
解码后的字符串列表,长度 = batch size
|
||
"""
|
||
# (T, B, C) -> (B, T)
|
||
preds = logits.argmax(dim=2).permute(1, 0) # (B, T)
|
||
results = []
|
||
for pred in preds:
|
||
chars = []
|
||
prev = -1
|
||
for idx in pred.tolist():
|
||
if idx != 0 and idx != prev: # 0 = blank
|
||
chars.append(self.chars[idx - 1]) # 字符从 index 1 开始
|
||
prev = idx
|
||
results.append("".join(chars))
|
||
return results
|