Initialize repository
This commit is contained in:
18
models/__init__.py
Normal file
18
models/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
模型定义包
|
||||
|
||||
提供三种模型:
|
||||
- CaptchaClassifier: 调度分类器 (轻量 CNN, < 500KB)
|
||||
- LiteCRNN: 轻量 CRNN (普通字符 + 算式, < 2MB)
|
||||
- ThreeDCNN: 3D 验证码专用模型 (ResNet-lite + BiLSTM, < 5MB)
|
||||
"""
|
||||
|
||||
from models.classifier import CaptchaClassifier
|
||||
from models.lite_crnn import LiteCRNN
|
||||
from models.threed_cnn import ThreeDCNN
|
||||
|
||||
__all__ = [
|
||||
"CaptchaClassifier",
|
||||
"LiteCRNN",
|
||||
"ThreeDCNN",
|
||||
]
|
||||
72
models/classifier.py
Normal file
72
models/classifier.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
调度分类器模型
|
||||
|
||||
轻量 CNN 分类器,用于判断验证码类型 (normal / math / 3d)。
|
||||
不同类型验证码视觉差异大,分类任务简单。
|
||||
|
||||
架构: 4 层卷积 + GAP + FC
|
||||
输入: 灰度图 1×64×128
|
||||
输出: softmax 概率分布 (num_types 个类别)
|
||||
体积目标: < 500KB
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CaptchaClassifier(nn.Module):
    """Lightweight CNN that predicts the captcha type.

    Four Conv-BN-ReLU-MaxPool stages, then global average pooling and a
    single fully connected layer producing one logit per captcha type.
    """

    def __init__(self, num_types: int = 3):
        super().__init__()
        self.num_types = num_types

        def _stage(c_in: int, c_out: int) -> list[nn.Module]:
            # One Conv-BN-ReLU-MaxPool stage; each stage halves both
            # spatial dimensions (64x128 -> 4x8 after four stages).
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]

        layers: list[nn.Module] = []
        for c_in, c_out in ((1, 16), (16, 32), (32, 64), (64, 64)):
            layers.extend(_stage(c_in, c_out))
        self.features = nn.Sequential(*layers)

        # Global average pooling -> (batch, 64, 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Linear(64, num_types)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, 64, 128) grayscale images

        Returns:
            logits: (batch, num_types) raw scores, no softmax applied
        """
        feats = self.features(x)
        pooled = self.gap(feats)            # (B, 64, 1, 1)
        flat = torch.flatten(pooled, 1)     # (B, 64)
        return self.classifier(flat)        # (B, num_types)
|
||||
141
models/lite_crnn.py
Normal file
141
models/lite_crnn.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
轻量 CRNN 模型 (Convolutional Recurrent Neural Network)
|
||||
|
||||
用于普通字符验证码和算式验证码的 OCR 识别。
|
||||
两种模式通过不同的字符集和输入尺寸区分,共享同一网络架构。
|
||||
|
||||
架构: CNN 特征提取 → 序列映射 → BiLSTM → 全连接 → CTC 解码
|
||||
CTC 输出长度 = 特征图宽度 (经过若干次宽度方向 pool 后)
|
||||
CTC blank 位于 index 0,字符从 index 1 开始映射。
|
||||
|
||||
- normal 模式: 输入 1×40×120, 字符集 30 字符, 体积 < 2MB
|
||||
- math 模式: 输入 1×40×160, 字符集 16 字符, 体积 < 2MB
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LiteCRNN(nn.Module):
    """Lightweight CRNN with CTC output.

    The CNN pools the height four times (an adaptive pool then squeezes
    whatever remains down to 1), while the width is pooled only twice so
    enough time steps survive for CTC. Sequence modelling is a single
    bidirectional LSTM layer.
    """

    def __init__(self, chars: str, img_h: int = 40, img_w: int = 120):
        """
        Args:
            chars: character set (CTC blank excluded)
            img_h: input image height
            img_w: input image width
        """
        super().__init__()
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        # CTC class count = characters + 1 (blank occupies index 0)
        self.num_classes = len(chars) + 1

        # ---- CNN feature extractor ----
        self.cnn = nn.Sequential(
            # stage 1: 1 -> 32, halve H, keep W
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),

            # stage 2: 32 -> 64, halve H and W
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # stage 3: 64 -> 128, halve H, keep W
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),

            # stage 4: 128 -> 128, halve H and W
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # Height after four pools is img_h / 16 (e.g. 40 -> 2, and may not
        # divide evenly); the adaptive pool squeezes it down to one row.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))  # (B, 128, 1, W')

        # ---- recurrent sequence model ----
        self.rnn_input_size = 128
        self.rnn_hidden = 96
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # ---- output projection ----
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale images

        Returns:
            logits: (seq_len, batch, num_classes), i.e. the (T, B, C)
            layout expected by CTC loss.
        """
        feats = self.adaptive_pool(self.cnn(x))    # (B, 128, 1, W')
        seq = feats.squeeze(2).permute(0, 2, 1)    # (B, W', 128) batch-first
        rnn_out, _ = self.rnn(seq)                 # (B, W', 2*hidden)
        logits = self.fc(rnn_out)                  # (B, W', num_classes)
        return logits.permute(1, 0, 2)             # (T, B, C) for CTC

    @property
    def seq_len(self) -> int:
        """CTC sequence length (feature-map width for the given input)."""
        # the width passes through two /2 pooling stages
        return self.img_w // 4

    # ----------------------------------------------------------
    # CTC greedy decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """Greedy CTC decode.

        Args:
            logits: raw model output, shape (T, B, C)

        Returns:
            list of decoded strings, one per batch element
        """
        best = logits.argmax(dim=2).t()  # (T, B) -> (B, T)
        decoded: list[str] = []
        for row in best:
            path = row.tolist()
            # collapse consecutive repeats, then drop blanks (index 0);
            # character indices are shifted by one because of the blank
            text = "".join(
                self.chars[cur - 1]
                for pos, cur in enumerate(path)
                if cur != 0 and (pos == 0 or cur != path[pos - 1])
            )
            decoded.append(text)
        return decoded
|
||||
155
models/threed_cnn.py
Normal file
155
models/threed_cnn.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
3D 立体验证码专用模型
|
||||
|
||||
采用更深的 CNN backbone(类 ResNet 残差块)+ CRNN 序列建模,
|
||||
以更强的特征提取能力处理透视变形和阴影效果。
|
||||
|
||||
架构: ResNet-lite backbone → 自适应池化 → BiLSTM → FC → CTC
|
||||
输入: 灰度图 1×60×160
|
||||
体积目标: < 5MB
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Basic residual block: Conv-BN-ReLU-Conv-BN plus identity shortcut."""

    def __init__(self, channels: int):
        super().__init__()
        # bias is omitted because each conv is immediately followed by BN
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        y = F.relu(self.bn1(self.conv1(x)), inplace=True)
        y = self.bn2(self.conv2(y))
        # add the skip connection before the final activation
        return F.relu(y + shortcut, inplace=True)
|
||||
|
||||
|
||||
class ThreeDCNN(nn.Module):
    """Recognition model dedicated to 3D captchas.

    The backbone stacks five conv stages (two of them residual blocks)
    with channels growing 1 -> 32 -> 64 -> 64(res) -> 128 -> 128(res).
    Height is pooled away (an adaptive pool normalises the remainder)
    while the width survives as the CTC sequence; a 2-layer BiLSTM plus
    a linear head then emits the CTC logits.
    """

    def __init__(self, chars: str, img_h: int = 60, img_w: int = 160):
        """
        Args:
            chars: character set (CTC blank excluded)
            img_h: input image height
            img_w: input image width
        """
        super().__init__()
        self.chars = chars
        self.img_h = img_h
        self.img_w = img_w
        self.num_classes = len(chars) + 1  # +1 for CTC blank

        # ---- ResNet-lite backbone ----
        self.backbone = nn.Sequential(
            # stage 1: 1 -> 32, halve H, keep W
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),

            # stage 2: 32 -> 64, halve H and W
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # stage 3: residual block at 64 channels
            ResidualBlock(64),

            # stage 4: 64 -> 128, halve H and W
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            # stage 5: residual block at 128 channels
            ResidualBlock(128),

            # stage 6: 128 -> 128, halve H, keep W
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1)),
        )

        # squeeze the remaining height down to a single row, keep width
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))

        # ---- recurrent sequence model ----
        self.rnn_input_size = 128
        self.rnn_hidden = 128
        self.rnn = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.rnn_hidden,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2,
        )

        # ---- output projection ----
        self.fc = nn.Linear(self.rnn_hidden * 2, self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, 1, H, W) grayscale images

        Returns:
            logits: (seq_len, batch, num_classes) in CTC (T, B, C) layout
        """
        feats = self.adaptive_pool(self.backbone(x))  # (B, 128, 1, W')
        seq = feats.squeeze(2).permute(0, 2, 1)       # (B, W', 128)
        rnn_out, _ = self.rnn(seq)                    # (B, W', 2*hidden)
        logits = self.fc(rnn_out)                     # (B, W', num_classes)
        return logits.permute(1, 0, 2)                # (T, B, C)

    @property
    def seq_len(self) -> int:
        """CTC sequence length: input width after the two W/2 pools."""
        return self.img_w // 4

    # ----------------------------------------------------------
    # CTC greedy decoding
    # ----------------------------------------------------------
    def greedy_decode(self, logits: torch.Tensor) -> list[str]:
        """Greedy CTC decode.

        Args:
            logits: raw model output, shape (T, B, C)

        Returns:
            list of decoded strings, one per batch element
        """
        best = logits.argmax(dim=2).t()  # (T, B) -> (B, T)
        decoded: list[str] = []
        for row in best:
            path = row.tolist()
            # collapse consecutive repeats, then drop blanks (index 0)
            text = "".join(
                self.chars[cur - 1]
                for pos, cur in enumerate(path)
                if cur != 0 and (pos == 0 or cur != path[pos - 1])
            )
            decoded.append(text)
        return decoded
|
||||
Reference in New Issue
Block a user