"""
FunCaptcha-specific Siamese model.

Intended for challenges such as `4_3d_rollball_animals`:
- Input 1: candidate tile (RGB)
- Input 2: reference tile (RGB)
- Output: a single match logit; the larger the value, the more likely the
  candidate is the correct one
"""

from __future__ import annotations

import torch
import torch.nn as nn


class _SharedEncoder(nn.Module):
    """Convolutional encoder applied to both the candidate and the reference.

    Four Conv-BatchNorm-ReLU stages (32, 64, 96, 128 channels); the first three
    are each followed by a 2x2 max-pool and the last by global average pooling,
    so the result no longer depends on the spatial size left after pooling.
    A final linear layer projects the 128 pooled channels to `embedding_dim`.

    Args:
        in_channels: Number of channels of the input images.
        embedding_dim: Size of the embedding vector produced per image.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        # NOTE: the layer order inside this Sequential determines the
        # state_dict keys (features.0, features.1, ...); keep it stable so
        # existing checkpoints continue to load.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        # 128 matches the channel count of the last convolution above.
        self.proj = nn.Linear(128, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into embeddings of shape (batch, embedding_dim)."""
        x = self.features(x)
        # (batch, 128, 1, 1) -> (batch, 128); flatten also copes with
        # non-contiguous memory layouts where a plain `view` could fail.
        x = torch.flatten(x, 1)
        return self.proj(x)


class FunCaptchaSiamese(nn.Module):
    """Shared encoder plus a feature-comparison head.

    Produces raw logits; pair it with `BCEWithLogitsLoss` during training.

    Args:
        in_channels: Number of channels of both input images.
        embedding_dim: Size of the per-image embedding produced by the encoder.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
        # The head consumes [candidate, reference, |difference|, product],
        # hence four times the embedding size.
        self.head = nn.Sequential(
            nn.Linear(embedding_dim * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(128, 1),
        )

    def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Score how well each candidate matches its reference.

        Args:
            candidate: Batch of candidate images.
            reference: Batch of reference images. Its batch size must equal
                the candidate batch size, or one of the two may have batch
                size 1, in which case it is broadcast against the other
                (e.g. one reference scored against N candidates).

        Returns:
            Tensor of shape (batch, 1) holding one raw match logit per pair.

        Raises:
            RuntimeError: If the two batch sizes differ and neither is 1.
        """
        candidate_feat = self.encoder(candidate)
        reference_feat = self.encoder(reference)

        # Align batch dimensions up front; without this the concatenation
        # below rejects a single reference paired with several candidates,
        # even though the element-wise terms would broadcast fine.
        candidate_feat, reference_feat = torch.broadcast_tensors(
            candidate_feat, reference_feat
        )

        diff = torch.abs(candidate_feat - reference_feat)
        prod = candidate_feat * reference_feat
        features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
        return self.head(features)