Align task API and add FunCaptcha support
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
|
||||
- GapDetectorCNN: 滑块缺口检测 CNN (~1MB)
|
||||
- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB)
|
||||
- FunCaptchaSiamese: FunCaptcha 专项 Siamese 匹配模型
|
||||
"""
|
||||
|
||||
from models.classifier import CaptchaClassifier
|
||||
@@ -16,6 +17,7 @@ from models.threed_cnn import ThreeDCNN
|
||||
from models.regression_cnn import RegressionCNN
|
||||
from models.gap_detector import GapDetectorCNN
|
||||
from models.rotation_regressor import RotationRegressor
|
||||
from models.fun_captcha_siamese import FunCaptchaSiamese
|
||||
|
||||
__all__ = [
|
||||
"CaptchaClassifier",
|
||||
@@ -24,4 +26,5 @@ __all__ = [
|
||||
"RegressionCNN",
|
||||
"GapDetectorCNN",
|
||||
"RotationRegressor",
|
||||
"FunCaptchaSiamese",
|
||||
]
|
||||
|
||||
@@ -22,7 +22,7 @@ class CaptchaClassifier(nn.Module):
|
||||
→ 全局平均池化 → 全连接 → 输出类别数。
|
||||
"""
|
||||
|
||||
def __init__(self, num_types: int = 3):
|
||||
def __init__(self, num_types: int = 5):
|
||||
super().__init__()
|
||||
self.num_types = num_types
|
||||
|
||||
|
||||
72
models/fun_captcha_siamese.py
Normal file
72
models/fun_captcha_siamese.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
FunCaptcha 专项 Siamese 模型。
|
||||
|
||||
用于 `4_3d_rollball_animals` 这类 challenge:
|
||||
- 输入 1: 候选块 candidate (RGB)
|
||||
- 输入 2: 参考块 reference (RGB)
|
||||
- 输出: 单个匹配 logit,值越大表示越可能为正确候选
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class _SharedEncoder(nn.Module):
|
||||
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
|
||||
super().__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
nn.Conv2d(64, 96, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(96),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(2, 2),
|
||||
|
||||
nn.Conv2d(96, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
)
|
||||
self.proj = nn.Linear(128, embedding_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.features(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FunCaptchaSiamese(nn.Module):
|
||||
"""
|
||||
共享编码器 + 特征对比头。
|
||||
|
||||
输出 raw logits,训练时配合 `BCEWithLogitsLoss` 使用。
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
|
||||
super().__init__()
|
||||
self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(embedding_dim * 4, 128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(128, 1),
|
||||
)
|
||||
|
||||
def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
|
||||
candidate_feat = self.encoder(candidate)
|
||||
reference_feat = self.encoder(reference)
|
||||
|
||||
diff = torch.abs(candidate_feat - reference_feat)
|
||||
prod = candidate_feat * reference_feat
|
||||
features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
|
||||
return self.head(features)
|
||||
Reference in New Issue
Block a user