Files
CaptchBreaker/models/threed_cnn.py
Hua 90d6423551 Replace AdaptiveAvgPool2d with fixed-kernel AvgPool2d for ONNX compatibility
AdaptiveAvgPool2d with None dimensions can cause issues with some ONNX
runtimes. Use AvgPool2d with kernel=(img_h//16, 1) to achieve the same
height-to-1 reduction with full ONNX compatibility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 13:58:41 +08:00

158 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
3D 立体验证码专用模型
采用更深的 CNN backbone类 ResNet 残差块)+ CRNN 序列建模,
以更强的特征提取能力处理透视变形和阴影效果。
架构: ResNet-lite backbone → 自适应池化 → BiLSTM → FC → CTC
输入: 灰度图 1×60×160
体积目标: < 5MB
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Lightweight residual block: Conv-BN-ReLU-Conv-BN plus identity shortcut."""

    def __init__(self, channels: int):
        super().__init__()
        # Both 3x3 convs preserve channel count and spatial size (padding=1).
        # bias=False because each conv is followed by a BatchNorm that
        # supplies the affine shift.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        y = self.bn1(self.conv1(x))
        y = F.relu(y, inplace=True)
        y = self.bn2(self.conv2(y))
        # Add the identity shortcut before the final activation.
        return F.relu(y + shortcut, inplace=True)
class ThreeDCNN(nn.Module):
    """
    Recognition model specialised for 3D captchas.

    The backbone is 5 conv stages (including 2 residual blocks) with channels
    growing 1 -> 32 -> 64 -> 64(res) -> 128 -> 128(res).  The height is
    halved four times by max-pooling and then collapsed to 1 with a
    fixed-kernel AvgPool2d (ONNX friendly); the width is kept as the CTC
    sequence dimension.  A BiLSTM + FC head produces CTC logits.
    """

    def __init__(self, chars: str, img_h: int = 60, img_w: int = 160):
        """
        Args:
            chars: character set string (without the CTC blank).
            img_h: input image height; must be >= 16 so that the four
                height-halving pools leave at least one row.
            img_w: input image width.

        Raises:
            ValueError: if img_h < 16 (the height pooling kernel would be 0).
        """
        super().__init__()
        if img_h < 16:
            raise ValueError(f"img_h must be >= 16, got {img_h}")
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        self.num_classes = len(chars) + 1  # +1 for CTC blank (index 0)
        # ---- ResNet-lite backbone ----
        self.backbone = nn.Sequential(
            # stage 1: 1 -> 32, H/2, W unchanged
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
            # stage 2: 32 -> 64, H/2, W/2
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # stage 3: residual block 64 -> 64
            ResidualBlock(64),
            # stage 4: 64 -> 128, H/2, W/2
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # stage 5: residual block 128 -> 128
            ResidualBlock(128),
            # stage 6: 128 -> 128, H/2, W unchanged
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
        )
        # Collapse the remaining height to 1 while keeping the width.
        # Fixed kernel instead of AdaptiveAvgPool2d for ONNX compatibility.
        # Four floor-halvings of the height equal img_h // 16
        # (e.g. 60 -> 30 -> 15 -> 7 -> 3 == 60 // 16).
        self._pool_h = img_h // 16
        self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1))
        # ---- RNN sequence modelling ----
        self.rnn_input_size = 128
        self.rnn_hidden = 128
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,  # applied between the two LSTM layers (train only)
        )
        # ---- output projection ----
        # *2 because the BiLSTM concatenates forward and backward states.
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: grayscale images of shape (batch, 1, H, W).

        Returns:
            logits in CTC layout (T, B, C) = (seq_len, batch, num_classes).
        """
        conv = self.backbone(x)        # (B, 128, H', W')
        conv = self.height_pool(conv)  # (B, 128, 1, W')
        conv = conv.squeeze(2)         # (B, 128, W')
        conv = conv.permute(0, 2, 1)   # (B, W', 128) — width as time axis
        rnn_out, _ = self.rnn(conv)    # (B, W', 256)
        logits = self.fc(rnn_out)      # (B, W', num_classes)
        logits = logits.permute(1, 0, 2)  # (T, B, C) as expected by CTCLoss
        return logits

    @property
    def seq_len(self) -> int:
        """CTC sequence length: input width after the two W/2 max-pools."""
        return self.img_w // 4

    # ----------------------------------------------------------
    # CTC greedy decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """
        CTC greedy (best-path) decoding: collapse consecutive repeats,
        then drop blanks (index 0).

        Args:
            logits: raw model output of shape (T, B, C).

        Returns:
            List of decoded strings, one per batch element.
        """
        preds = logits.argmax(dim=2).permute(1, 0)  # (B, T)
        results = []
        for pred in preds:
            chars = []
            prev = -1
            for idx in pred.tolist():
                if idx != 0 and idx != prev:
                    chars.append(self.chars[idx - 1])  # -1: skip the blank slot
                prev = idx
            results.append("".join(chars))
        return results