73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
"""
FunCaptcha 专项 Siamese 模型。

用于 `4_3d_rollball_animals` 这类 challenge：

- 输入 1: 候选块 candidate (RGB)
- 输入 2: 参考块 reference (RGB)
- 输出: 单个匹配 logit，值越大表示越可能为正确候选
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class _SharedEncoder(nn.Module):
    """Convolutional tower shared by both branches of the Siamese model.

    Four stages widen the channels in_channels -> 32 -> 64 -> 96 -> 128. Each
    stage is Conv3x3 (padding 1) + BatchNorm + ReLU. The first three stages end
    in a 2x2 max-pool and the last in a global average pool. The pooled
    128-d vector is linearly projected to ``embedding_dim``.

    Because of the three 2x2 max-pools, H and W must be at least 8 so no
    spatial side is reduced to zero.

    Args:
        in_channels: Number of channels in the input image (3 for RGB).
        embedding_dim: Size of the returned embedding vector.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()

        widths = (in_channels, 32, 64, 96, 128)
        final_stage = len(widths) - 2
        layers: list[nn.Module] = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            # Intermediate stages halve H and W. The last stage collapses the
            # remaining H x W to 1 x 1.
            if stage == final_stage:
                layers.append(nn.AdaptiveAvgPool2d(1))
            else:
                layers.append(nn.MaxPool2d(2, 2))

        # The layers are appended in the same order as an explicit listing, so
        # Sequential indices (and state_dict keys such as "features.12.weight")
        # stay stable for saved checkpoints.
        self.features = nn.Sequential(*layers)
        self.proj = nn.Linear(widths[-1], embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into embeddings.

        Args:
            x: Image batch, expected as (N, in_channels, H, W).

        Returns:
            Embeddings of shape (N, embedding_dim).
        """
        pooled = self.features(x)  # (N, 128, 1, 1)
        flat = pooled.view(pooled.size(0), -1)  # (N, 128)
        return self.proj(flat)
|
||
|
||
|
||
class FunCaptchaSiamese(nn.Module):
    """Siamese matcher: a shared encoder followed by a feature-comparison head.

    Both inputs are embedded by the same ``_SharedEncoder`` instance, so the
    branches share weights. The embeddings ``c`` and ``r`` are combined as
    ``[c, r, |c - r|, c * r]``. A small MLP maps that vector to one raw logit
    per pair. No sigmoid is applied, so train with ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Number of channels of each input image (3 for RGB).
        embedding_dim: Size of each branch's embedding.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
        # Input width is 4 * embedding_dim: the four vectors concatenated in forward().
        self.head = nn.Sequential(
            nn.Linear(embedding_dim * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(128, 1),
        )

    def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Score how well each candidate matches its reference.

        Args:
            candidate: Candidate image batch, expected as (N, in_channels, H, W).
            reference: Reference image batch with the same batch size as
                ``candidate``, or batch size 1. A single reference is scored
                against every candidate (and vice versa).

        Returns:
            Raw match logits of shape (B, 1), where B is the broadcast batch
            size. Larger values mean a more likely match.

        Raises:
            RuntimeError: If the two batch sizes differ and neither is 1.
        """
        # Keep two separate encoder calls: in training mode BatchNorm normalizes
        # over each call's batch, so merging the branches into one batch would
        # change the results.
        candidate_feat = self.encoder(candidate)
        reference_feat = self.encoder(reference)

        # Broadcast a batch-1 embedding against the other branch. Without this,
        # the elementwise ops broadcast but torch.cat(dim=1) fails on the
        # mismatched batch sizes. Equal batch sizes pass through unchanged.
        candidate_feat, reference_feat = torch.broadcast_tensors(candidate_feat, reference_feat)

        diff = torch.abs(candidate_feat - reference_feat)
        prod = candidate_feat * reference_feat
        features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
        return self.head(features)
|