Files
CaptchBreaker/models/lite_crnn.py
Hua 90d6423551 Replace AdaptiveAvgPool2d with fixed-kernel AvgPool2d for ONNX compatibility
AdaptiveAvgPool2d with None dimensions can cause issues with some ONNX
runtimes. Use AvgPool2d with kernel=(img_h//16, 1) to achieve the same
height-to-1 reduction with full ONNX compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 13:58:41 +08:00

144 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
轻量 CRNN 模型 (Convolutional Recurrent Neural Network)
用于普通字符验证码和算式验证码的 OCR 识别。
两种模式通过不同的字符集和输入尺寸区分,共享同一网络架构。
架构: CNN 特征提取 → 序列映射 → BiLSTM → 全连接 → CTC 解码
CTC 输出长度 = 特征图宽度 (经过若干次宽度方向 pool 后)
CTC blank 位于 index 0,字符从 index 1 开始映射。
- normal 模式: 输入 1×40×120, 字符集 30 字符, 体积 < 2MB
- math 模式: 输入 1×40×160, 字符集 16 字符, 体积 < 2MB
"""
import torch
import torch.nn as nn
class LiteCRNN(nn.Module):
    """Lightweight CRNN with a CTC head for captcha OCR.

    Pipeline: CNN feature extractor -> fixed-kernel height pooling ->
    single-layer BiLSTM -> linear classifier emitting per-timestep logits.

    The CNN halves the height four times (e.g. 40 -> 20 -> 10 -> 5 -> 2, with
    floor division at every step) and the width twice, so enough horizontal
    positions remain to serve as CTC time steps. The remaining height
    (``img_h // 16``) is collapsed to 1 by an ``AvgPool2d`` whose kernel is
    fixed at construction time, for ONNX compatibility.

    Class index 0 is the CTC blank; ``chars[i]`` maps to class ``i + 1``.
    """

    # Total down-sampling factors of the CNN stack (4 height pools, 2 width
    # pools, each by a factor of 2).
    _HEIGHT_STRIDE = 16
    _WIDTH_STRIDE = 4

    def __init__(self, chars: str, img_h: int = 40, img_w: int = 120):
        """Build the network.

        Args:
            chars: Character set, excluding the CTC blank.
            img_h: Input image height in pixels.
            img_w: Input image width in pixels.

        Raises:
            ValueError: If ``img_h`` or ``img_w`` is too small to survive the
                CNN's pooling stages.
        """
        # Fail early: a smaller image would give the height pool a kernel of
        # 0 (or an empty sequence) and only blow up later inside forward().
        if img_h < self._HEIGHT_STRIDE:
            raise ValueError(
                f"img_h must be >= {self._HEIGHT_STRIDE}, got {img_h}"
            )
        if img_w < self._WIDTH_STRIDE:
            raise ValueError(
                f"img_w must be >= {self._WIDTH_STRIDE}, got {img_w}"
            )

        super().__init__()
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        # Number of CTC classes = number of characters + 1 (blank at index 0).
        self.num_classes = len(chars) + 1

        # ---- CNN feature extractor ----
        # Layer order is kept stable so state_dict keys of existing
        # checkpoints continue to match.
        self.cnn = nn.Sequential(
            # block 1: 1 -> 32 channels, H/2, W unchanged
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
            # block 2: 32 -> 64 channels, H/2, W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # block 3: 64 -> 128 channels, H/2, W unchanged
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
            # block 4: 128 -> 128 channels, H/2, W/2
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # After four height pools the feature map height is img_h // 16
        # (e.g. 40 -> 2). A fixed-kernel AvgPool squeezes that down to 1.
        self._pool_h = img_h // self._HEIGHT_STRIDE
        self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1))

        # ---- RNN sequence model ----
        self.rnn_input_size = 128
        self.rnn_hidden = 96
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # ---- Output layer ----
        # Bidirectional LSTM concatenates both directions: 2 * rnn_hidden.
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of grayscale images.

        Args:
            x: Tensor of shape (batch, 1, H, W). H should equal the ``img_h``
                given to the constructor so that the fixed height pool
                reduces the feature map height to exactly 1.

        Returns:
            Raw logits of shape (seq_len, batch, num_classes), i.e. the
            (T, B, C) layout expected by CTC.
        """
        features = self.cnn(x)                   # (B, 128, H', W')
        features = self.height_pool(features)    # (B, 128, 1, W')
        features = features.squeeze(2)           # (B, 128, W')
        sequence = features.permute(0, 2, 1)     # (B, W', 128), batch-first

        rnn_out, _ = self.rnn(sequence)          # (B, W', 2 * rnn_hidden)

        logits = self.fc(rnn_out)                # (B, W', num_classes)
        return logits.permute(1, 0, 2)           # (T, B, C) for CTC

    @property
    def seq_len(self) -> int:
        """CTC sequence length (feature map width) for the configured width."""
        # The width passes through two stride-2 pools.
        return self.img_w // self._WIDTH_STRIDE

    # ----------------------------------------------------------
    # CTC greedy decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """Decode logits with CTC greedy (best-path) decoding.

        Takes the arg-max class at every time step, collapses consecutive
        repeats, and drops blanks (class 0).

        Args:
            logits: Raw model output of shape (T, B, C).

        Returns:
            One decoded string per batch element.
        """
        # (T, B, C) -> (T, B) -> (B, T)
        best_path = logits.argmax(dim=2).permute(1, 0)
        decoded = []
        for sample in best_path:
            pieces = []
            prev = -1  # sentinel that never equals a valid class index
            for idx in sample.tolist():
                # Emit only when the class is not blank and differs from the
                # previous step; a blank in between lets a character repeat.
                if idx != 0 and idx != prev:
                    pieces.append(self.chars[idx - 1])  # classes start at 1
                prev = idx
            decoded.append("".join(pieces))
        return decoded