Pytorch batchnorm2d:“运行时错误:running_mean 应包含 1 个元素而不是 64”

问题描述 投票:0回答:1

类似的问题都没有解决。所以请不要标记为重复

Pytorch BatchNorm2d 需要 N C H W 格式的输入 哪里

N = Batchsize
C = Channels
H = Height
W = Width

正如文档中所示:https://pytorch.org/docs/stable/ generated/torch.nn.BatchNorm2d.html

如果我们使用随机张量测试它,我们会得到一个错误:

import torch
        
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
        
torch.nn.BatchNorm2d(h)(torch.rand(n,c,h,w))

以下代码“有效”,但输入格式为“NHWC”

import torch

n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width

x = torch.rand(n,h,w,c)

x = torch.nn.BatchNorm2d(h)(x)
python python-3.x pytorch torch batch-normalization
1个回答
0
投票

这里的问题是,如果您更改

N
C
H
W
变量的值,您实际上并没有更改 PyTorch 开发人员设置的内部内存格式;这只是一个变量名称,即,如果您在
(n,h,c,w)
中提供输入,如上所述,则内部为
N->N
H->C
(H 将是通道数,而不是您所想的高度)、
C->H
W->W

回到问题,输入数据中的通道数应与

nn.BatchNorm2d
中的通道数匹配。

在您的情况下,您设置的通道数为 1,但 BatchNorm 期望用户提供 64 个通道。要解决此问题,您可以按照以下示例操作:

示例:

import torch
n, c, h, w = 32, 64, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(h)(x)

import torch
n, c, h, w = 32, 1, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(c)(x)

我希望这对您有帮助。谢谢!

© www.soinside.com 2019 - 2024. All rights reserved.