我正在尝试在 PyTorch 中训练一个 UNet,它的输入是一个 3D 直方图(128 个箱子,64 x 64 传感器,所以 128 x 64 x 64),它的输出是一个 2D 图像(64x64)。我使用以下代码执行此操作:
`
class conv_block(nn.Module):
def __init__(self, in_channels, out_channels):
super(conv_block,self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self,inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
class encoder_block(nn.Module):
def __init__(self,in_channels,out_channels):
super(encoder_block,self).__init__()
self.conv = conv_block(in_channels,out_channels)
self.pool = nn.MaxPool2d((2,2))
def forward(self,inputs):
x = self.conv(inputs)
p = self.pool(x)
return x,p
class decoder_block(nn.Module):
def __init__(self, in_c, out_c):
super(decoder_block,self).__init__()
self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
self.conv = conv_block(out_c+out_c, out_c)
def forward(self, inputs, skip):
x = self.up(inputs)
x = torch.cat([x, skip], axis=1)
x = self.conv(x)
return x
class unet(nn.Module):
def __init__(self):
super(unet,self).__init__()
self.e1 = encoder_block(128, 256)
self.e2 = encoder_block(256, 512)
self.e3 = encoder_block(512, 1024)
self.e4 = encoder_block(1024, 2048)
self.b = conv_block(2048, 4096)
self.d1 = decoder_block(4096, 2048)
self.d2 = decoder_block(2048, 1024)
self.d3 = decoder_block(1024, 512)
self.d4 = decoder_block(512, 256)
self.cast = nn.Conv2d(256, 128, kernel_size=1, padding=0)
self.output = nn.Conv2d(128,1,kernel_size=1,padding=0)
def forward(self, inputs):
s1, p1 = self.e1(inputs)
#print(s1.size())
s2, p2 = self.e2(p1)
#print(s2.size())
s3, p3 = self.e3(p2)
#print(s3.size())
s4, p4 = self.e4(p3)
#print(s4.size())
b = self.b(p4)
d1 = self.d1(b, s4)
d2 = self.d2(d1, s3)
d3 = self.d3(d2, s2)
d4 = self.d4(d3, s1)
cast = self.cast(d4)
output = self.output(cast)
return output
`
这只是标准的 UNet(编码器,然后是 conv,然后是解码器块),但是我在最后添加了一个“cast”和“output”层,使图像成为 64 x 64。
你会注意到这使用了 Conv2D。如果输入是 [B x 128 x 64 x 64](B 是批量大小),则输出是 [B x 1 x 64 x 64]。然而,我的数据的性质是输入数据在时间上是相关的,所以使用 Conv3D 会比 Conv2D 更好。
我试图用 3d 对应物替换 Conv2D、BatchNorm2D、ConvTranspose2d 和 MaxPool2d 步骤。然而,无论我如何改变内核大小、填充、步幅等,我都会不断收到尺寸错误
有人可以帮我算出我需要在 UNet 中更改什么,以使其进行 3D 卷积而不是 2D 卷积吗?
谢谢,非常感谢!