之前用 keras 用習慣了

要用 pytorch 訓練時後就卡住了

找一下發現在 pytorch 中 使用 3D Global average pooling 要自定義層出來

import torch
import torch.nn as nn

# 創建一個虛擬的三維特徵圖,假設大小為 (batch_size, channels, depth, height, width)
input_data = torch.randn(1, 512, 4, 4, 4)  # 假設特徵圖大小為 4x4x4,通道數為 512

# 定義 3D 全局平均池化
class GlobalAvgPool3d(nn.Module):
    def __init__(self):
        super(GlobalAvgPool3d, self).__init__()

    def forward(self, x):
        # 在特徵圖的空間尺寸上執行平均池化
        return torch.mean(x, dim=(2, 3, 4))  # 在深度、高度和寬度上進行平均池化

# 創建 GlobalAvgPool3d 實例
global_avg_pool = GlobalAvgPool3d()

# 將輸入數據通過全局平均池化層
output_data = global_avg_pool(input_data)

# 輸出數據的形狀
print(output_data.shape)  # torch.Size([1, 512])

在 keras 中就有現成的

keras.layers.GlobalAveragePooling3D(data_format=None, keepdims=False, **kwargs)
雖然 pytorch 好像比較紅
但是 keras 確實比較方便用

而 3D Global average pooling 通常在哪邊使用呢?
通常在分類模型使用
因為特徵出來後可能是好幾維的陣列
但是要做分類之前要降維才好

而 Global average 就是直接針對通道降維