类似的问题都没有解决。所以请不要标记为重复
Pytorch BatchNorm2d 需要 N C H W 格式的输入 哪里
N = Batchsize
C = Channels
H = Height
W = Width
正如文档中所示:https://pytorch.org/docs/stable/ generated/torch.nn.BatchNorm2d.html
如果我们使用随机张量测试它,我们会得到一个错误:
import torch
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
torch.nn.BatchNorm2d(h)(torch.rand(n,c,h,w))
以下代码“有效”,但输入格式为“NHWC”
import torch
n = 32 # N = Batch size
c = 1 # C = Channels
h = 64 # H = Height
w = 512 # W = Width
x = torch.rand(n,h,w,c)
x = torch.nn.BatchNorm2d(h)(x)
这里的问题是,如果您更改
N
、C
、H
或 W
变量的值,您实际上并没有更改 PyTorch 开发人员设置的内部内存格式;这只是一个变量名称,即,如果您在 (n,h,c,w)
中提供输入,如上所述,则内部为 N->N
、H->C
(H 将是通道数,而不是您所想的高度)、C->H
和W->W
。
回到问题,输入数据中的通道数应与
nn.BatchNorm2d
中的通道数匹配。
在您的情况下,您设置的通道数为 1,但 BatchNorm 期望用户提供 64 个通道。要解决此问题,您可以按照以下示例操作:
示例:
import torch
n, c, h, w = 32, 64, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(h)(x)
和
import torch
n, c, h, w = 32, 1, 64, 512
x = torch.rand(n,c,h,w)
x = torch.nn.BatchNorm2d(c)(x)
我希望这对您有帮助。谢谢!