73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
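"""
Dispatch classifier model.

A lightweight CNN classifier used to decide which captcha type an image
belongs to (normal / math / 3d). The captcha types differ a lot visually,
so the classification task is easy.

Architecture: 4 convolution layers + global average pooling (GAP) + FC
Input:  grayscale image, 1 x 64 x 128
Output: raw logits over ``num_types`` classes (no softmax is applied inside
        the model; apply softmax to obtain a probability distribution)
Size target: < 500 KB
"""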
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class CaptchaClassifier(nn.Module):
    """Lightweight CNN that classifies the captcha type of a grayscale image.

    Architecture: four convolution blocks (3x3 Conv + BatchNorm + ReLU +
    2x2 MaxPool), followed by global average pooling and a single linear
    layer that emits one score per class.

    Args:
        num_types: Number of captcha classes, i.e. the width of the output
            layer. Stored on the instance as ``self.num_types``.
    """

    def __init__(self, num_types: int = 5):
        super().__init__()
        self.num_types = num_types

        # Channel widths: input -> block 1 -> block 2 -> block 3 -> block 4.
        # Every block halves height and width, so the documented 64x128 input
        # shrinks 64x128 -> 32x64 -> 16x32 -> 8x16 -> 4x8.
        widths = (1, 16, 32, 64, 64)

        layers = []
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]

        # Keep the Sequential flat (indices 0..15) so state_dict keys such as
        # "features.0.weight" stay stable for existing checkpoints.
        self.features = nn.Sequential(*layers)

        # Global average pooling -> (batch, 64, 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Linear(widths[-1], num_types)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Args:
            x: Grayscale images shaped (batch, 1, 64, 128).

        Returns:
            Logits shaped (batch, num_types). These are raw scores; no
            softmax is applied here.
        """
        x = self.features(x)
        x = self.gap(x)            # (B, 64, 1, 1)
        # flatten() instead of view(B, -1): no contiguity requirement and no
        # ambiguous -1 inference.
        x = torch.flatten(x, 1)    # (B, 64)
        return self.classifier(x)  # (B, num_types)
|