Replace AdaptiveAvgPool2d with fixed-kernel AvgPool2d for ONNX compatibility
AdaptiveAvgPool2d with None dimensions can cause issues with some ONNX runtimes. Use AvgPool2d with kernel=(img_h//16, 1) to achieve the same height-to-1 reduction with full ONNX compatibility. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -88,8 +88,10 @@ class ThreeDCNN(nn.Module):
|
||||
nn.MaxPool2d(kernel_size=(2, 1)),
|
||||
)
|
||||
|
||||
# 高度方向自适应压到 1,宽度保持
|
||||
self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None))
|
||||
# 高度方向压到 1,宽度保持 (用固定 kernel 替代 Adaptive,ONNX 兼容)
|
||||
# 60→30→15→7→3 (4次高度pool后)
|
||||
self._pool_h = img_h // 16 # 60//16=3
|
||||
self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1))
|
||||
|
||||
# ---- RNN 序列建模 ----
|
||||
self.rnn_input_size = 128
|
||||
@@ -115,7 +117,7 @@ class ThreeDCNN(nn.Module):
|
||||
logits: (seq_len, batch, num_classes) CTC 格式 (T, B, C)
|
||||
"""
|
||||
conv = self.backbone(x) # (B, 128, H', W')
|
||||
conv = self.adaptive_pool(conv) # (B, 128, 1, W')
|
||||
conv = self.height_pool(conv) # (B, 128, 1, W')
|
||||
conv = conv.squeeze(2) # (B, 128, W')
|
||||
conv = conv.permute(0, 2, 1) # (B, W', 128)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user