Why am I getting an error about the number of channels in my U-Net?


I'm trying to implement a U-Net by following a tutorial, and I've run into a problem with the expected number of input channels.

This is the error message I get:

RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[15, 3, 300, 300] to have 1 channels, but got 3 channels instead
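The numbers in the message can be read directly. A Conv2d weight has shape [out_channels, in_channels / groups, kernel_h, kernel_w], so [64, 1, 3, 3] is a layer built as Conv2d(1, 64, 3), while the input [15, 3, 300, 300] is a batch of 15 three-channel 300×300 images. A minimal reproduction, assuming only PyTorch and using a smaller height and width to keep it cheap:

```python
import torch
import torch.nn as nn

# Same layer as the first Conv2d in double_conv(1, 64): 1 input channel, 64 filters, 3x3 kernel.
conv = nn.Conv2d(1, 64, 3, 1)
print(tuple(conv.weight.shape))  # (64, 1, 3, 3): [out_channels, in_channels / groups, kH, kW]

# A batch laid out like the failing one, (N, C, H, W) with C = 3, but with smaller H and W.
rgb_batch = torch.randn(15, 3, 32, 32)
err_msg = None
try:
    conv(rgb_batch)
except RuntimeError as err:
    err_msg = str(err)
print(err_msg)  # the same channel-mismatch message as in the question

# The same layer accepts a single-channel batch.
gray_batch = torch.randn(15, 1, 32, 32)
print(tuple(conv(gray_batch).shape))  # (15, 64, 30, 30)
```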

Here is my double_conv function:

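import torch.nn as nn

def double_conv(in_c, out_c):
    # Two unpadded 3x3 convolutions (stride 1), each followed by BatchNorm and ReLU.
    # The first Conv2d only accepts inputs with exactly in_c channels.
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, 1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, 3, 1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )
    return conv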
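Because both convolutions use kernel size 3, stride 1 and no padding, each one trims one pixel from every border, so a double_conv block maps (N, in_c, H, W) to (N, out_c, H − 4, W − 4); a 300×300 input comes out as 296×296. A quick check, reusing the double_conv definition from the question on a small single-channel input:

```python
import torch
import torch.nn as nn

def double_conv(in_c, out_c):
    # Same block as in the question: two unpadded 3x3 convs with BatchNorm + ReLU.
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, 1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, 3, 1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )

def double_conv_out_size(size):
    # Each unpadded 3x3 conv with stride 1 removes 2 pixels (one per border).
    return size - 2 - 2

block = double_conv(1, 64)
x = torch.randn(2, 1, 64, 64)     # (N, C, H, W) with C == in_c
y = block(x)
print(tuple(y.shape))             # (2, 64, 60, 60)
print(double_conv_out_size(300))  # 296
```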

This is where each down-convolution is defined:

self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.down_conv_1 = double_conv(1,64)  # first block: expects single-channel (grayscale) input
self.down_conv_2 = double_conv(64,128)
self.down_conv_3 = double_conv(128,256)
self.down_conv_4 = double_conv(256,512)
self.down_conv_5 = double_conv(512,1024)
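Since down_conv_1 is built with double_conv(1, 64), the whole network accepts only single-channel images. One way to catch a mismatch earlier, with a clearer message, is to compare the batch's channel dimension against the in_channels attribute of the first Conv2d before calling the model. The helper check_input_channels is hypothetical, and the down_conv_1 below is a stand-in containing only the first conv of the block:

```python
import torch
import torch.nn as nn

def check_input_channels(first_conv: nn.Conv2d, batch: torch.Tensor) -> None:
    # Hypothetical helper: fail with a descriptive error before the forward pass.
    expected = first_conv.in_channels
    got = batch.shape[1]  # batches are laid out as (N, C, H, W)
    if got != expected:
        raise ValueError(
            f"model expects {expected}-channel images but the batch has {got} channels; "
            "build the first double_conv with a matching in_c or convert the images"
        )

# Stand-in for self.down_conv_1 = double_conv(1, 64); index 0 of the Sequential is the first Conv2d.
down_conv_1 = nn.Sequential(nn.Conv2d(1, 64, 3, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

check_input_channels(down_conv_1[0], torch.randn(4, 1, 32, 32))  # passes silently

mismatch_message = None
try:
    check_input_channels(down_conv_1[0], torch.randn(4, 3, 32, 32))
except ValueError as err:
    mismatch_message = str(err)
print(mismatch_message)
```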

Here is my forward function:

def forward(self, image):
    # encoder (x1, x3, x5 and x7 are kept for the skip connections)
    x1 = self.down_conv_1(image)  #
    x2 = self.max_pool_2x2(x1)
    x3 = self.down_conv_2(x2)  #
    x4 = self.max_pool_2x2(x3)
    x5 = self.down_conv_3(x4)  #
    x6 = self.max_pool_2x2(x5)
    x7 = self.down_conv_4(x6)  #
    x8 = self.max_pool_2x2(x7)
    x9 = self.down_conv_5(x8)

    # decoder: upsample, crop the matching encoder output, concatenate on the channel dim
    x = self.up_trans_1(x9)
    y = crop_img(x7, x)
    x = self.up_conv_1(torch.cat([x,y], 1))

    x = self.up_trans_2(x)
    y = crop_img(x5, x)
    x = self.up_conv_2(torch.cat([x,y], 1))

    x = self.up_trans_3(x)
    y = crop_img(x3, x)
    x = self.up_conv_3(torch.cat([x,y], 1))

    x = self.up_trans_4(x)
    y = crop_img(x1, x)
    x = self.up_conv_4(torch.cat([x,y], 1))

    x = self.out(x)
    return x
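The question does not show crop_img. Assuming it center-crops the encoder feature map to the spatial size of the upsampled decoder tensor (the usual approach when the convolutions are unpadded), a hypothetical reconstruction could look like this; the 30×30 to 22×22 sizes are only illustrative:

```python
import torch

def crop_img(tensor, target_tensor):
    # Hypothetical reconstruction: center-crop `tensor` (N, C, H, W) to the
    # spatial size of `target_tensor`; batch and channel dims are untouched.
    _, _, h, w = tensor.shape
    _, _, th, tw = target_tensor.shape
    top = (h - th) // 2
    left = (w - tw) // 2
    return tensor[:, :, top:top + th, left:left + tw]

skip = torch.arange(30 * 30, dtype=torch.float32).reshape(1, 1, 30, 30)  # encoder feature map
up = torch.zeros(1, 1, 22, 22)                                           # upsampled decoder tensor
cropped = crop_img(skip, up)
print(tuple(cropped.shape))                          # (1, 1, 22, 22)
print(torch.equal(cropped, skip[:, :, 4:26, 4:26]))  # True
merged = torch.cat([up, cropped], 1)                 # concatenate on the channel dim
print(tuple(merged.shape))                           # (1, 2, 22, 22)
```

torch.cat([x, y], 1) joins tensors along dimension 1, the channel dimension, so the two inputs must already agree in N, H and W; that is what the crop guarantees. It also means each up_conv has to accept the sum of both channel counts.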

At first I thought the problem was the second Conv2d changing the number of input channels, but changing it made no difference.
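Changing the second Conv2d cannot help, because the weight in the message, [64, 1, 3, 3], belongs to a layer with 1 input channel and 64 output channels: the first Conv2d of down_conv_1. The second Conv2d in that block is Conv2d(64, 64, 3, 1). A generic way to locate the offending layer is to search the model for a weight with the reported shape. TinyEncoder and find_layers_with_weight_shape below are hypothetical names; the model is a stand-in with only the first two down-convolutions:

```python
import torch.nn as nn

def double_conv(in_c, out_c):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, 3, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
    )

class TinyEncoder(nn.Module):
    # Hypothetical stand-in for the question's model, first two levels only.
    def __init__(self):
        super().__init__()
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)

def find_layers_with_weight_shape(model, shape):
    # Return the names of submodules whose `weight` parameter has the given shape.
    return [
        name
        for name, module in model.named_modules()
        if isinstance(getattr(module, "weight", None), nn.Parameter)
        and tuple(module.weight.shape) == tuple(shape)
    ]

matches = find_layers_with_weight_shape(TinyEncoder(), (64, 1, 3, 3))
print(matches)  # ['down_conv_1.0']
```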

Full traceback:

Traceback (most recent call last):
  File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\train.py", line 148, in <module>
    main()
  File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\train.py", line 113, in main
    check_accuracy(val_loader, model, device=DEVICE)
  File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\utils.py", line 68, in check_accuracy
    preds = torch.sigmoid(model(x))
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\model.py", line 76, in forward
    x1 = self.down_conv_1(image)  #
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
    input = module(input)
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "E:\Anaconda\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[15, 3, 300, 300] to have 1 channels, but got 3 channels instead

Process finished with exit code 1
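The traceback shows the failure happens in the very first layer, when check_accuracy passes a validation batch x of shape [15, 3, 300, 300] to the model; the decoder is never reached. There are two ways to make the shapes agree: build the first down-convolution for three channels, or feed single-channel images. The sketch below takes in_channels as a constructor argument. The question does not show the decoder layers, so their layout here is an assumption (ConvTranspose2d with kernel 2 and stride 2, then double_conv, then a final 1×1 conv); crop_img is a hypothetical center crop, and to_grayscale is a hypothetical helper using ITU-R BT.601 luma weights for the second option:

```python
import torch
import torch.nn as nn

def double_conv(in_c, out_c):
    # Two unpadded 3x3 convs with BatchNorm + ReLU, as in the question.
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, 3, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
    )

def crop_img(tensor, target_tensor):
    # Hypothetical center crop of `tensor` to the spatial size of `target_tensor`.
    _, _, h, w = tensor.shape
    _, _, th, tw = target_tensor.shape
    top, left = (h - th) // 2, (w - tw) // 2
    return tensor[:, :, top:top + th, left:left + tw]

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # The first block now matches the channel count of the input images.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        # Assumed decoder layout: 2x upsampling, then double_conv on [upsampled, skip].
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)
        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)
        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)
        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, image):
        x1 = self.down_conv_1(image)
        x3 = self.down_conv_2(self.max_pool_2x2(x1))
        x5 = self.down_conv_3(self.max_pool_2x2(x3))
        x7 = self.down_conv_4(self.max_pool_2x2(x5))
        x = self.down_conv_5(self.max_pool_2x2(x7))
        for up_trans, up_conv, skip in (
            (self.up_trans_1, self.up_conv_1, x7),
            (self.up_trans_2, self.up_conv_2, x5),
            (self.up_trans_3, self.up_conv_3, x3),
            (self.up_trans_4, self.up_conv_4, x1),
        ):
            x = up_trans(x)
            x = up_conv(torch.cat([x, crop_img(skip, x)], 1))
        return self.out(x)

def to_grayscale(batch):
    # Hypothetical alternative: collapse an (N, 3, H, W) RGB batch to (N, 1, H, W) luma.
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=batch.dtype, device=batch.device)
    return (batch * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)

rgb = torch.randn(1, 3, 188, 188)  # small size that still survives all the unpadded convs
with torch.no_grad():
    out_rgb = UNet(in_channels=3)(rgb)                 # option 1: 3-channel first block
    out_gray = UNet(in_channels=1)(to_grayscale(rgb))  # option 2: grayscale input
print(tuple(out_rgb.shape), tuple(out_gray.shape))     # (1, 1, 4, 4) (1, 1, 4, 4)
```

If the images are loaded with PIL, Image.open(path).convert("L") performs the grayscale conversion at load time instead. Whichever option is chosen, the in_c passed to the first double_conv must equal C in the (N, C, H, W) batches the loaders produce; since the error surfaces in check_accuracy on val_loader, it may also be worth confirming that the training and validation datasets load images the same way.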
Tags: deep-learning, pytorch, unet-neural-network
© www.soinside.com 2019 - 2024. All rights reserved.