From 90d6423551d8683d0fc85981fe3590df1d5fda77 Mon Sep 17 00:00:00 2001 From: Hua Date: Wed, 11 Mar 2026 13:58:41 +0800 Subject: [PATCH] Replace AdaptiveAvgPool2d with fixed-kernel AvgPool2d for ONNX compatibility AdaptiveAvgPool2d with a None output dimension can cause issues with some ONNX runtimes. Use AvgPool2d with kernel_size=(img_h // 16, 1) instead: after the four 2x height poolings the feature map is img_h // 16 rows tall (floor division), so this fixed kernel achieves the same height-to-1 reduction for ONNX compatibility. Note that, unlike the adaptive pool, this is equivalent only when the input height matches the img_h passed to the constructor, and it requires img_h >= 16 so that the kernel height is at least 1. Co-Authored-By: Claude Opus 4.6 --- models/lite_crnn.py | 10 ++++++---- models/threed_cnn.py | 8 +++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/models/lite_crnn.py b/models/lite_crnn.py index 1b99c0c..60217e5 100644 --- a/models/lite_crnn.py +++ b/models/lite_crnn.py @@ -66,9 +66,11 @@ class LiteCRNN(nn.Module): nn.MaxPool2d(2, 2), # H/2, W/2 ) - # 经过 4 次高度 pool: img_h / 16 (如 40 → 2, 不够整除时用自适应) - # 用 AdaptiveAvgPool 把高度压到 1 - self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None)) # (B, 128, 1, W') + # 经过 4 次高度 pool 后高度 = img_h/16 (向下取整) + # 例: 40→20→10→5→2 + # 用固定 kernel 的 AvgPool 把剩余高度压到 1 (ONNX 兼容) + self._pool_h = img_h // 16 # 40→2, 若有余数也向下取整 + self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1)) # ---- RNN 序列建模 ---- self.rnn_input_size = 128 @@ -95,7 +97,7 @@ class LiteCRNN(nn.Module): """ # CNN conv = self.cnn(x) # (B, 128, H', W') - conv = self.adaptive_pool(conv) # (B, 128, 1, W') + conv = self.height_pool(conv) # (B, 128, 1, W') conv = conv.squeeze(2) # (B, 128, W') conv = conv.permute(0, 2, 1) # (B, W', 128) — batch_first 序列 diff --git a/models/threed_cnn.py b/models/threed_cnn.py index 3f374b3..8fa76d4 100644 --- a/models/threed_cnn.py +++ b/models/threed_cnn.py @@ -88,8 +88,10 @@ class ThreeDCNN(nn.Module): nn.MaxPool2d(kernel_size=(2, 1)), ) - # 高度方向自适应压到 1,宽度保持 - self.adaptive_pool = nn.AdaptiveAvgPool2d((1, None)) + # 高度方向压到 1,宽度保持 (用固定 kernel 替代 Adaptive,ONNX 兼容) + # 60→30→15→7→3 (4次高度pool后) + self._pool_h = img_h // 16 # 60//16=3 + self.height_pool = nn.AvgPool2d(kernel_size=(self._pool_h, 1)) # ---- RNN 序列建模 ---- 
self.rnn_input_size = 128 @@ -115,7 +117,7 @@ class ThreeDCNN(nn.Module): logits: (seq_len, batch, num_classes) CTC 格式 (T, B, C) """ conv = self.backbone(x) # (B, 128, H', W') - conv = self.adaptive_pool(conv) # (B, 128, 1, W') + conv = self.height_pool(conv) # (B, 128, 1, W') conv = conv.squeeze(2) # (B, 128, W') conv = conv.permute(0, 2, 1) # (B, W', 128)