So I have the following U-Net architecture that is causing problems:
```python
import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder1 = self.double_conv(in_channels, 64)
        self.encoder2 = self.down(64, 128)
        self.encoder3 = self.down(128, 256)
        self.encoder4 = self.down(256, 512)
        self.bottleneck = self.double_conv(512, 1024)
        self.decoder4 = self.up(1024, 512)
        self.decoder3 = self.up(512, 256)
        self.decoder2 = self.up(256, 128)
        self.decoder1 = self.up(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)  # SAME convolution/padding

    def double_conv(self, in_channels, out_channels):  # conv block
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def up(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)  # Output: [1, 64, 256, 256]
        print("enc1.shape", enc1.shape)
        enc2 = self.encoder2(enc1)  # Output: [1, 128, 128, 128]
        print("enc2.shape", enc2.shape)
        enc3 = self.encoder3(enc2)  # Output: [1, 256, 64, 64]
        print("enc3.shape", enc3.shape)
        enc4 = self.encoder4(enc3)  # Output: [1, 512, 32, 32]
        print("enc4.shape", enc4.shape)
        bottleneck_output = self.bottleneck(enc4)  # Output: [1, 1024, 32, 32]
        print("bottleneck_output", bottleneck_output.shape)

        # Decoder
        dec4 = self.decoder4(bottleneck_output)  # Output: [1, 512, 64, 64]
        print(dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)  # skip connection, concatenate: [1, 1024, 64, 64]
        dec4 = self.double_conv(1024, 512)(dec4)  # Corrected input channels to 1024
        dec3 = self.decoder3(dec4)  # Output: [1, 256, 128, 128]
        dec3 = torch.cat((dec3, enc3), dim=1)  # Concatenate: [1, 512, 128, 128]
        dec3 = self.double_conv(512, 256)(dec3)  # Corrected input channels to 512
        dec2 = self.decoder2(dec3)  # Output: [1, 128, 256, 256]
        dec2 = torch.cat((dec2, enc2), dim=1)  # Concatenate: [1, 256, 256, 256]
        dec2 = self.double_conv(256, 128)(dec2)  # Corrected input channels to 256
        dec1 = self.decoder1(dec2)  # Output: [1, 64, 512, 512]
        dec1 = torch.cat((dec1, enc1), dim=1)  # Concatenate: [1, 128, 512, 512]
        dec1 = self.double_conv(128, 64)(dec1)  # Corrected input channels to 128
        return self.final_conv(dec1)  # Output: [1, 1, 512, 512]
```
When executing it in a main method via
```python
unet = UNet(in_channels=1, out_channels=1)
sample_input = torch.randn(1, 1, 256, 256)
output = unet(sample_input)
```
I get:
```
enc1.shape torch.Size([1, 64, 256, 256])
enc2.shape torch.Size([1, 128, 128, 128])
enc3.shape torch.Size([1, 256, 64, 64])
enc4.shape torch.Size([1, 512, 32, 32])
bottleneck_output torch.Size([1, 1024, 32, 32])
```
and then the following error:
```
---> 55 dec4 = self.decoder4(bottleneck_output)

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels instead
```
So the problem is apparently the shape of `bottleneck_output` with its 1024 channels, which `decoder4` does not seem to accept, or something along those lines.
I have tried matching the dimensions and other things, such as aligning the features, but nothing has worked so far. Printing the output shapes has not really helped either. Thanks for any hints.
Your problem is in the definition of the `up` method:
```python
def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(in_channels, out_channels),
    )
```
`ConvTranspose2d` outputs a tensor with `out_channels` channels, but the `double_conv` that follows it expects an input tensor with `in_channels` channels.
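You can reproduce the mismatch in isolation. A minimal standalone sketch of the current `up(1024, 512)` block (the same sizes as your `decoder4`, with `double_conv` written out inline; `broken_up` is just an illustrative name):

```python
import torch
import torch.nn as nn

# The current up(1024, 512): the transposed conv maps 1024 -> 512 channels,
# but the following double_conv still starts with Conv2d(1024, 512, ...).
broken_up = nn.Sequential(
    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),  # -> [1, 512, 64, 64]
    nn.Conv2d(1024, 512, kernel_size=3, padding=1),          # expects 1024 input channels
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 1024, 32, 32)  # same shape as bottleneck_output
broken_up(x)  # RuntimeError: ... expected input[1, 512, 64, 64] to have 1024 channels ...
```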
You should probably use something like:
```python
def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(out_channels, out_channels),  # NOTE CHANGE HERE
    )
```
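With that change the channel counts line up. As a quick standalone sanity check (again using the `decoder4` sizes, 1024 → 512; `fixed_up` is just an illustrative name), something like this should run without the channel error:

```python
import torch
import torch.nn as nn

# up(1024, 512) after the fix, with double_conv(512, 512) written out inline.
fixed_up = nn.Sequential(
    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),  # [1, 1024, 32, 32] -> [1, 512, 64, 64]
    nn.Conv2d(512, 512, kernel_size=3, padding=1),           # now matches the 512 channels above
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)

x = torch.randn(1, 1024, 32, 32)
print(fixed_up(x).shape)  # torch.Size([1, 512, 64, 64])
```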