UNet PyTorch size mismatch

Problem description

So I have the following U-Net architecture, which is causing the problem:

class UNet(nn.Module): 
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder1 = self.double_conv(in_channels, 64)
        self.encoder2 = self.down(64, 128)
        self.encoder3 = self.down(128, 256)
        self.encoder4 = self.down(256, 512)
        self.bottleneck = self.double_conv(512, 1024)
        self.decoder4 = self.up(1024, 512)
        self.decoder3 = self.up(512, 256)
        self.decoder2 = self.up(256, 128)
        self.decoder1 = self.up(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) # SAME convolution/padding

    def double_conv(self, in_channels, out_channels): # Convo Block
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def down(self, in_channels, out_channels):
        return nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def up(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            self.double_conv(in_channels, out_channels),
        )

    def forward(self, x):
        # Encoder
        enc1 = self.encoder1(x)  # Output: [1, 64, 256, 256]
        print("enc1.shape",enc1.shape)
        enc2 = self.encoder2(enc1)  # Output: [1, 128, 128, 128]
        print("enc2.shape",enc2.shape)
        enc3 = self.encoder3(enc2)  # Output: [1, 256, 64, 64]
        print("enc3.shape",enc3.shape)
        enc4 = self.encoder4(enc3)  # Output: [1, 512, 32, 32]
        print("enc4.shape",enc4.shape)
        bottleneck_output = self.bottleneck(enc4)  # Output: [1, 1024, 32, 32]
        print("bottleneck_output",bottleneck_output.shape)
        
        # Decoder
        dec4 = self.decoder4(bottleneck_output)  # Output: [1, 512, 64, 64]
        print(dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)  # skip connect, Concatenate: [1, 1024, 64, 64]
        dec4 = self.double_conv(1024, 512)(dec4)  # Corrected input channels to 1024

        dec3 = self.decoder3(dec4)  # Output: [1, 256, 128, 128]
        dec3 = torch.cat((dec3, enc3), dim=1)  # Concatenate: [1, 512, 128, 128]
        dec3 = self.double_conv(512, 256)(dec3)  # Corrected input channels to 512

        dec2 = self.decoder2(dec3)  # Output: [1, 128, 256, 256]
        dec2 = torch.cat((dec2, enc2), dim=1)  # Concatenate: [1, 256, 256, 256]
        dec2 = self.double_conv(256, 128)(dec2)  # Corrected input channels to 256

        dec1 = self.decoder1(dec2)  # Output: [1, 64, 512, 512]
        dec1 = torch.cat((dec1, enc1), dim=1)  # Concatenate: [1, 128, 512, 512]
        dec1 = self.double_conv(128, 64)(dec1)  # Corrected input channels to 128

        return self.final_conv(dec1)  # Output: [1, 1, 512, 512]

When I execute this in a main method via

unet = UNet(in_channels=1, out_channels=1)
sample_input = torch.randn(1, 1, 256, 256)
output = unet(sample_input)

I get:

enc1.shape torch.Size([1, 64, 256, 256])
enc2.shape torch.Size([1, 128, 128, 128])
enc3.shape torch.Size([1, 256, 64, 64])
enc4.shape torch.Size([1, 512, 32, 32])
bottleneck_output torch.Size([1, 1024, 32, 32])

followed by this error:

---> 55 dec4 = self.decoder4(bottleneck_output)

RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels instead

So the problem is obviously the shape of bottleneck_output with its 1024 channels, but decoder4 doesn't seem to recognize it, or something like that.

I tried matching the dimensions and other things, such as aligning the features, but nothing has worked so far. Printing the output shapes didn't really help either. Thanks for any hints.

pytorch unet-neural-network
1 Answer

0 votes

Your problem lies in the definition of the up method:

def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(in_channels, out_channels),
    )

ConvTranspose2d outputs a tensor with out_channels channels, but double_conv expects an input tensor with in_channels channels.
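
As a quick sanity check (a minimal standalone sketch using the shapes from your printouts, not the actual model), you can reproduce the mismatch by passing a tensor shaped like bottleneck_output through the two layers that decoder4 = self.up(1024, 512) builds:

import torch
import torch.nn as nn

x = torch.randn(1, 1024, 32, 32)  # same shape as bottleneck_output

# First layer of up(1024, 512): the transposed conv already maps 1024 -> 512 channels
upsample = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
y = upsample(x)
print(y.shape)  # torch.Size([1, 512, 64, 64])

# First layer of double_conv(1024, 512): still expects 1024 input channels
conv = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
conv(y)  # RuntimeError: expected input[1, 512, 64, 64] to have 1024 channels, but got 512 channels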

You should probably use something like this instead:

def up(self, in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        self.double_conv(out_channels, out_channels),  # NOTE CHANGE HERE
    )
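
With that change, the first decoder stage accepts the bottleneck output. A minimal sketch of just that stage (mirroring your up and double_conv outside the class, not the full model):

import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

# decoder4 = up(1024, 512) with the corrected double_conv input channels
decoder4 = nn.Sequential(
    nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
    double_conv(512, 512),
)

out = decoder4(torch.randn(1, 1024, 32, 32))  # same shape as bottleneck_output
print(out.shape)  # torch.Size([1, 512, 64, 64])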
