Pytorch 累积前一批的张量

问题描述 投票:0回答:1

我一直在尝试实现一个用于图像分割的 trasnformer-UNet 混合模型。每当我尝试训练模型时,它都会出现内存不足的情况。最初,我认为这是由于模型的大小所致,并尝试减少批量大小、注意力头数量、变压器层数等参数。

所有这些步骤只是推迟了不可避免的内存不足的情况。我什至尝试过使用云 GPU,但仍然没有成功。这是pytorch的内存快照工具的截图:

enter image description here

(我通过调用

torch.cuda.empty_cache
清空缓存)

我怀疑某些张量被保留,因为我使用列表来实现跳过连接(必要的张量已附加到列表中)。

这是代码链接

这是 pickle 文件(内存快照转储)的链接

python deep-learning pytorch
1个回答
0
投票

看起来您的代码正在构建每个批次之间未重置的跳过连接的列表

class convolutionalEncoder(torch.nn.Module):
    
    class conv_block(torch.nn.Module):
 
        ...
 
    def __init__(self, device) -> None:
        super(convolutionalEncoder, self).__init__()
        self.conv_block_list = []
        self.skip_conn = []
        for i in range(num_skip_conn):
            self.conv_block_list.append(self.conv_block(filters[i], device))
 
    def forward(self, X):
 
        for i in range(num_skip_conn):
            X = self.conv_block_list[i](X)
            self.skip_conn.append(X)
 
        return self.skip_conn

self.skip_conn
应该只是一个普通列表,而不是批次之间保留的属性。发生的情况是每个批次的张量都被添加到
self.skip_conn
,本质上将整个数据集存储在该列表中,直到您出现为止。

每次只需将其替换为新的正常列表即可

    def forward(self, X):
 
        skip_conn = []
        for i in range(num_skip_conn):
            X = self.conv_block_list[i](X)
           skip_conn.append(X)
 
        return skip_conn
© www.soinside.com 2019 - 2024. All rights reserved.