Files
CaptchBreaker/models/classifier.py
2026-03-10 18:47:29 +08:00

73 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
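"""
Dispatch classifier model.

Lightweight CNN classifier that decides the captcha type (normal / math / 3d).
The captcha types look very different from one another, so the classification
task is easy.

Architecture: 4 conv layers + GAP + FC
Input:  grayscale image, 1x64x128
Output: raw logits over num_types classes (apply softmax to obtain a
        probability distribution)
Size target: < 500KB
"""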
import torch
import torch.nn as nn
class CaptchaClassifier(nn.Module):
    """
    Lightweight captcha-type classifier.

    Four convolution stages, each ``Conv2d(3x3, padding=1) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2, 2)``, followed by global average pooling and one
    fully connected layer that maps the 64 pooled channels to ``num_types``
    logits.

    Args:
        num_types: number of output classes (default 3).
    """

    def __init__(self, num_types: int = 3):
        super().__init__()
        self.num_types = num_types

        # (in_channels, out_channels) per stage. Each stage's MaxPool halves H
        # and W: 64x128 -> 32x64 -> 16x32 -> 8x16 -> 4x8 for a 1x64x128 input.
        stages = ((1, 16), (16, 32), (32, 64), (64, 64))
        layers = []
        for c_in, c_out in stages:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]
        # One flat Sequential keeps the submodule indices (features.0 ...
        # features.15), so state_dict keys of existing checkpoints still match.
        self.features = nn.Sequential(*layers)

        # Global average pooling -> (batch, 64, 1, 1), independent of input H/W.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(stages[-1][1], num_types)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: grayscale batch, nominally (batch, 1, 64, 128). Because of the
                global average pooling, other sizes also work as long as H and
                W are each at least 16 (they must survive four 2x2 pools).

        Returns:
            Raw logits of shape (batch, num_types); softmax is NOT applied.
        """
        x = self.features(x)
        x = self.gap(x)  # (B, 64, 1, 1)
        # flatten rather than view(x.size(0), -1): with an empty batch the -1
        # is ambiguous and view() raises, while flatten returns (0, 64).
        x = torch.flatten(x, start_dim=1)  # (B, 64)
        return self.classifier(x)  # (B, num_types)