我正在尝试按照教程实施 U-Net,我遇到了预期输入通道数量的问题。
这是我收到的错误消息:
RuntimeError:给定 groups=1,权重大小 [64, 1, 3, 3],预期输入 [15, 3, 300, 300] 有 1 个通道,但得到的是 3 个通道。
这是我的
double_conv
功能:
def double_conv(in_c, out_c):
conv = nn.Sequential(
nn.Conv2d(in_c, out_c,3,1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
nn.Conv2d(out_c, out_c,3,1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
)
return conv
这是定义每个向下卷积:
self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.down_conv_1 = double_conv(1,64)
self.down_conv_2 = double_conv(64,128)
self.down_conv_3 = double_conv(128,256)
self.down_conv_4 = double_conv(256,512)
self.down_conv_5 = double_conv(512,1024)
这是我的前向函数:
def forward(self, image):
#encoder
x1 = self.down_conv_1(image) #
x2 = self.max_pool_2x2(x1)
x3 = self.down_conv_2(x2) #
x4 = self.max_pool_2x2(x3)
x5 = self.down_conv_3(x4) #
x6 = self.max_pool_2x2(x5)
x7 = self.down_conv_4(x6) #
x8 = self.max_pool_2x2(x7)
x9 = self.down_conv_5(x8)
#decoder
x = self.up_trans_1(x9)
y = crop_img(x7, x)
x = self.up_conv_1(torch.cat([x,y], 1))
x = self.up_trans_2(x)
y = crop_img(x5, x)
x = self.up_conv_2(torch.cat([x,y], 1))
x = self.up_trans_3(x)
y = crop_img(x3, x)
x = self.up_conv_3(torch.cat([x,y], 1))
x = self.up_trans_4(x)
y = crop_img(x1, x)
x = self.up_conv_4(torch.cat([x,y], 1))
x = self.out(x)
return x
我一开始以为是第2个Conv2d改变输入通道的问题,但这并没有做任何改变。
完整的错误回溯:
Traceback (most recent call last):
File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\train.py", line 148, in <module>
main()
File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\train.py", line 113, in main
check_accuracy(val_loader, model, device=DEVICE)
File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\utils.py", line 68, in check_accuracy
preds = torch.sigmoid(model(x))
File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\rosha\Desktop\Final Year Project\Code\U-Net\model.py", line 76, in forward
x1 = self.down_conv_1(image) #
File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "E:\Anaconda\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
input = module(input)
File "E:\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "E:\Anaconda\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "E:\Anaconda\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[15, 3, 300, 300] to have 1 channels, but got 3 channels instead
Process finished with exit code 1