# CaptchBreaker/models/fun_captcha_siamese.py
"""
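Siamese model dedicated to FunCaptcha.
Intended for challenges such as `4_3d_rollball_animals`:
- Input 1: candidate patch `candidate` (RGB)
- Input 2: reference patch `reference` (RGB)
- Output: a single match logit; the larger the value, the more likely the candidate is the correct one.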
"""
from __future__ import annotations
import torch
import torch.nn as nn
class _SharedEncoder(nn.Module):
    """Convolutional encoder that maps a batch of images to one embedding each.

    ``FunCaptchaSiamese`` runs both of its inputs through a single instance of
    this module, so candidate and reference embeddings share weights.

    Architecture: four Conv3x3 -> BatchNorm -> ReLU stages with 32, 64, 96 and
    128 channels, a 2x2 max-pool after each of the first three stages, then
    global average pooling and a linear projection to ``embedding_dim``.

    Args:
        in_channels: number of channels in the input images (3 for RGB).
        embedding_dim: size of the output embedding vector.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        # The position of each layer determines its ``features.<i>`` key in the
        # state_dict; reordering layers would break loading existing weights.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Global average pool: (N, 128, H', W') -> (N, 128, 1, 1).
            nn.AdaptiveAvgPool2d(1),
        )
        self.proj = nn.Linear(128, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: 4-D image batch ``(N, in_channels, H, W)``. H and W should each
                be at least 8: the three 2x2 max-pools floor-halve each side,
                so smaller inputs shrink to an empty feature map.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        x = self.features(x)
        # flatten(start_dim=1) instead of view(x.size(0), -1): view cannot
        # infer the -1 dimension when N == 0, so an empty batch used to raise.
        x = torch.flatten(x, start_dim=1)
        return self.proj(x)
class FunCaptchaSiamese(nn.Module):
    """Siamese matcher for FunCaptcha challenges such as ``4_3d_rollball_animals``.

    A single shared encoder embeds both the candidate and the reference image.
    A small MLP head then scores the pair from the concatenation
    ``[candidate, reference, |candidate - reference|, candidate * reference]``.

    The output is raw logits (no sigmoid applied); use ``BCEWithLogitsLoss``
    when training.

    Args:
        in_channels: number of channels in each input image (3 for RGB).
        embedding_dim: size of each image embedding produced by the encoder.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 128):
        super().__init__()
        # One encoder module for both inputs, so the two branches share weights.
        self.encoder = _SharedEncoder(in_channels=in_channels, embedding_dim=embedding_dim)
        # Input width is 4 * embedding_dim: the four feature groups built in forward().
        self.head = nn.Sequential(
            nn.Linear(embedding_dim * 4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(128, 1),
        )

    def forward(self, candidate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Score how well each candidate matches its reference.

        Args:
            candidate: batch of candidate images.
            reference: batch of reference images. Batch sizes may be equal, or
                either side may have batch size 1, in which case that single
                image is compared against every image on the other side.

        Returns:
            Tensor of shape ``(batch, 1)`` holding raw match logits; a larger
            value means the candidate is more likely the correct one.

        Raises:
            RuntimeError: if the batch sizes differ and neither is 1.
        """
        candidate_feat = self.encoder(candidate)
        reference_feat = self.encoder(reference)
        # torch.cat does not broadcast, so expand a batch-1 side explicitly.
        # This is a no-op when the batch sizes already match.
        candidate_feat, reference_feat = torch.broadcast_tensors(candidate_feat, reference_feat)
        diff = torch.abs(candidate_feat - reference_feat)
        prod = candidate_feat * reference_feat
        features = torch.cat([candidate_feat, reference_feat, diff, prod], dim=1)
        return self.head(features)