Align task API and add FunCaptcha support

This commit is contained in:
Hua
2026-03-12 19:32:59 +08:00
parent ef9518deeb
commit bc6776979e
33 changed files with 3446 additions and 672 deletions

View File

@@ -0,0 +1,72 @@
"""
FunCaptcha 专项 Siamese 模型。
用于 `4_3d_rollball_animals` 这类 challenge
- 输入 1: 候选块 candidate (RGB)
- 输入 2: 参考块 reference (RGB)
- 输出: 单个匹配 logit值越大表示越可能为正确候选
"""
from __future__ import annotations
import torch
import torch.nn as nn
class _SharedEncoder(nn.Module):
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 96, kernel_size=3, padding=1),
nn.BatchNorm2d(96),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(96, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
)
self.proj = nn.Linear(128, embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = x.view(x.size(0), -1)
return self.proj(x)
class FunCaptchaSiamese(nn.Module):
"""
共享编码器 + 特征对比头。
输出 raw logits训练时配合 `BCEWithLogitsLoss` 使用。
"""
def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
super().__init__()
self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
self.head = nn.Sequential(
nn.Linear(embedding_dim * 4, 128),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1),
nn.Linear(128, 1),
)
def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
candidate_feat = self.encoder(candidate)
reference_feat = self.encoder(reference)
diff = torch.abs(candidate_feat - reference_feat)
prod = candidate_feat * reference_feat
features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
return self.head(features)