Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -8,6 +8,7 @@
- RegressionCNN: 回归 CNN (3D 旋转 + 滑块, ~1MB)
- GapDetectorCNN: 滑块缺口检测 CNN (~1MB)
- RotationRegressor: 旋转角度回归 sin/cos 编码 (~2MB)
- FunCaptchaSiamese: FunCaptcha 专项 Siamese 匹配模型
"""
from models.classifier import CaptchaClassifier
@@ -16,6 +17,7 @@ from models.threed_cnn import ThreeDCNN
from models.regression_cnn import RegressionCNN
from models.gap_detector import GapDetectorCNN
from models.rotation_regressor import RotationRegressor
from models.fun_captcha_siamese import FunCaptchaSiamese
__all__ = [
"CaptchaClassifier",
@@ -24,4 +26,5 @@ __all__ = [
"RegressionCNN",
"GapDetectorCNN",
"RotationRegressor",
"FunCaptchaSiamese",
]

View File

@@ -22,7 +22,7 @@ class CaptchaClassifier(nn.Module):
→ 全局平均池化 → 全连接 → 输出类别数。
"""
def __init__(self, num_types: int = 3):
def __init__(self, num_types: int = 5):
super().__init__()
self.num_types = num_types

View File

@@ -0,0 +1,72 @@
"""
FunCaptcha 专项 Siamese 模型。
用于 `4_3d_rollball_animals` 这类 challenge
- 输入 1: 候选块 candidate (RGB)
- 输入 2: 参考块 reference (RGB)
- 输出: 单个匹配 logit值越大表示越可能为正确候选
"""
from __future__ import annotations
import torch
import torch.nn as nn
class _SharedEncoder(nn.Module):
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 96, kernel_size=3, padding=1),
nn.BatchNorm2d(96),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(96, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
)
self.proj = nn.Linear(128, embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = x.view(x.size(0), -1)
return self.proj(x)
class FunCaptchaSiamese(nn.Module):
"""
共享编码器 + 特征对比头。
输出 raw logits训练时配合 `BCEWithLogitsLoss` 使用。
"""
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
super().__init__()
self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
self.head = nn.Sequential(
nn.Linear(embedding_dim * 4, 128),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1),
nn.Linear(128, 1),
)
def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
candidate_feat = self.encoder(candidate)
reference_feat = self.encoder(reference)
diff = torch.abs(candidate_feat - reference_feat)
prod = candidate_feat * reference_feat
features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
return self.head(features)