Trouble fixing DDPM code with a UNet in PyTorch


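I am learning deep learning and have an assignment to train a DDPM with a UNet on the MNIST handwritten digit dataset. Three ipynb files (model, unet, and train_mnist) and a PDF of the UNet diagram are saved in my GitHub. I filled in the missing parts marked #Your Code Here# in "model.ipynb" and "unet.ipynb".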

When I run "unet.ipynb", the following error is raised.

NotImplementedError                       Traceback (most recent call last)
<ipython-input-21-36a7bf177d9a> in <cell line: 60>()
     62     t=torch.randint(0,1000,(3,))
     63     model=Unet(1000,128)
---> 64     y=model(x,t)
     65     print(y.shape)

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    350         registered hooks while the latter silently ignores them.
    351     """
--> 352     raise NotImplementedError(f'Module [{type(self).__name__}] is missing the required "forward" function')
    353 
    354 

NotImplementedError: Module [ModuleList] is missing the required "forward" function

Here is the UNet architecture:

class Unet(nn.Module):

    def __init__(self,timesteps,time_embedding_dim,in_channels=3,out_channels=2,base_dim=32,dim_mults=[2,4,8,16]):
        super().__init__()
        assert isinstance(dim_mults,(list,tuple))
        assert base_dim%2==0

        channels=self._cal_channels(base_dim,dim_mults)

        self.init_conv=ConvBnSiLu(in_channels,base_dim,3,1,1)
        self.time_embedding=nn.Embedding(timesteps,time_embedding_dim)

        self.encoder_blocks=nn.ModuleList([EncoderBlock(c[0],c[1],time_embedding_dim) for c in channels])
        self.decoder_blocks=nn.ModuleList([DecoderBlock(c[1],c[0],time_embedding_dim) for c in channels[::-1]])

        self.mid_block=nn.Sequential(*[ResidualBottleneck(channels[-1][1],channels[-1][1]) for i in range(2)],
                                        ResidualBottleneck(channels[-1][1],channels[-1][1]//2))

        self.final_conv=nn.Conv2d(in_channels=channels[0][0]//2,out_channels=out_channels,kernel_size=1)

    def forward(self,x,t=None):
        '''
            Implement the data flow of the UNet architecture
        '''
        # ---------- **** ---------- #
        # YOUR CODE HERE
        t = self.time_embedding

        #initial conv
        x1 = self.init_conv(x)
        #Down
        x2 = self.encoder_blocks(x1,t)
        x3 = self.encoder_blocks(x2[0],t)
        x4 = self.encoder_blocks(x3[0],t)
        x5 = self.encoder_blocks(x4[0],t)
        #Middle
        x6 = self.mid_block(x5[0])
        #Up
        x = self.decoder_blocks(x6,x5[1],t)
        x = self.decoder_blocks(x,x4[1],t)
        x = self.decoder_blocks(x,x3[1],t)
        x = self.decoder_blocks(x,x2[1],t)
        x = self.decoder_blocks(x,x1[1],t)
        #final
        x = self.final_conv(x)

        # ---------- **** ---------- #
        return x


    def _cal_channels(self,base_dim,dim_mults):
        dims=[base_dim*x for x in dim_mults]
        dims.insert(0,base_dim)
        channels=[]
        for i in range(len(dims)-1):
            channels.append((dims[i],dims[i+1])) # in_channel, out_channel

        return channels

if __name__=="__main__":
    x=torch.randn(3,3,224,224)
    t=torch.randint(0,1000,(3,))
    model=Unet(1000,128)
    y=model(x,t)
    print(y.shape)

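I tried using Poe and Perplexity, but their suggestions introduced more errors and made things worse. I also looked at the PyTorch documentation but could not work out how to fix the code. Can anyone spot the problem and explain how to implement the UNet? Happy to discuss. Thanks.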

pytorch unet-neural-network
1 Answer

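Your encoder_blocks and decoder_blocks are nn.ModuleList objects. Unlike nn.Sequential, a ModuleList is only a container of modules and has no forward method, so calling self.encoder_blocks(x1,t) raises the NotImplementedError shown in the traceback. You need to iterate over the modules in the list manually: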

x1 = self.init_conv(x)
for module in self.encoder_blocks:
    x1,_ = module(x1,t)
# Same for decoder ...
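To make the "same for decoder" step concrete, here is a minimal sketch of the full loop. It assumes, based on how the question indexes x2[0] and x2[1], that each encoder block returns an (output, skip) pair and that each decoder block takes (x, skip, t). ToyEncoderBlock, ToyDecoderBlock, and ToyUnet are hypothetical stand-ins, not the assignment's real blocks, and the mid block is left out for brevity. The sketch also calls self.time_embedding(t) to get the embedding tensor, whereas the question's forward assigns the module itself to t.

```python
import torch
import torch.nn as nn


class ToyEncoderBlock(nn.Module):
    """Hypothetical stand-in: returns (downsampled output, skip tensor)."""

    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.down = nn.MaxPool2d(2)

    def forward(self, x, t):
        # Add the projected time embedding to every spatial position.
        h = self.conv(x) + self.t_proj(t)[:, :, None, None]
        return self.down(h), h


class ToyDecoderBlock(nn.Module):
    """Hypothetical stand-in: upsample, concatenate the skip, convolve."""

    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_ch * 2, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)

    def forward(self, x, skip, t):
        h = torch.cat([self.up(x), skip], dim=1)
        return self.conv(h) + self.t_proj(t)[:, :, None, None]


class ToyUnet(nn.Module):
    def __init__(self, timesteps=1000, t_dim=16, in_ch=1,
                 channels=((8, 16), (16, 32))):
        super().__init__()
        self.init_conv = nn.Conv2d(in_ch, channels[0][0], 3, padding=1)
        self.time_embedding = nn.Embedding(timesteps, t_dim)
        self.encoder_blocks = nn.ModuleList(
            [ToyEncoderBlock(i, o, t_dim) for i, o in channels])
        self.decoder_blocks = nn.ModuleList(
            [ToyDecoderBlock(o, i, t_dim) for i, o in channels[::-1]])
        self.final_conv = nn.Conv2d(channels[0][0], in_ch, 1)

    def forward(self, x, t):
        t = self.time_embedding(t)  # call the layer; do not assign the module
        x = self.init_conv(x)

        skips = []
        for block in self.encoder_blocks:  # iterate; ModuleList is not callable
            x, skip = block(x, t)
            skips.append(skip)

        # Decoder consumes the skips in reverse (deepest first).
        for block, skip in zip(self.decoder_blocks, reversed(skips)):
            x = block(x, skip, t)

        return self.final_conv(x)
```

Collecting the skip tensors in a list and walking it in reverse keeps forward independent of how many levels dim_mults defines, so there is no need to name x2 through x5 by hand. The channel counts in the real blocks may differ from this toy version, so the shapes need to be checked against the actual EncoderBlock and DecoderBlock definitions.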

The nn.ModuleList documentation shows a similar example: https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
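The difference between the two containers can be reproduced in a few lines. The following is a small illustration with arbitrary Linear and ReLU layers (not taken from the question): calling the ModuleList directly fails in the same way as in the traceback, while iterating over it gives the same result as the equivalent nn.Sequential.

```python
import torch
import torch.nn as nn

# Arbitrary example layers, shared by both containers.
layers = [nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)]
as_list = nn.ModuleList(layers)
as_seq = nn.Sequential(*layers)

x = torch.randn(3, 4)

# Calling a ModuleList directly raises NotImplementedError,
# because it defines no forward method.
try:
    as_list(x)
    list_is_callable = True
except NotImplementedError:
    list_is_callable = False

# nn.Sequential chains its modules for you.
out_seq = as_seq(x)

# With a ModuleList, the chaining has to be written out by hand.
out_list = x
for layer in as_list:
    out_list = layer(out_list)

print(list_is_callable, torch.equal(out_seq, out_list))
```

nn.Sequential only fits when every module takes a single input and passes a single output to the next one. The encoder and decoder blocks here also need t and the skip tensors, which is presumably why the assignment uses ModuleList for them.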
