我一直在尝试实现一个用于图像分割的 trasnformer-UNet 混合模型。每当我尝试训练模型时,它都会出现内存不足的情况。最初,我认为这是由于模型的大小所致,并尝试减少批量大小、注意力头数量、变压器层数等参数。
所有这些步骤只是推迟了不可避免的内存不足的情况。我什至尝试过使用云 GPU,但仍然没有成功。这是pytorch的内存快照工具的截图:
(我通过调用
torch.cuda.empty_cache
清空缓存)
我怀疑某些张量被保留,因为我使用列表来实现跳过连接(必要的张量已附加到列表中)。
看起来您的代码正在构建每个批次之间未重置的跳过连接的列表
class convolutionalEncoder(torch.nn.Module):
class conv_block(torch.nn.Module):
...
def __init__(self, device) -> None:
super(convolutionalEncoder, self).__init__()
self.conv_block_list = []
self.skip_conn = []
for i in range(num_skip_conn):
self.conv_block_list.append(self.conv_block(filters[i], device))
def forward(self, X):
for i in range(num_skip_conn):
X = self.conv_block_list[i](X)
self.skip_conn.append(X)
return self.skip_conn
self.skip_conn
应该只是一个普通列表,而不是批次之间保留的属性。发生的情况是每个批次的张量都被添加到 self.skip_conn
,本质上将整个数据集存储在该列表中,直到您出现为止。
每次只需将其替换为新的正常列表即可
def forward(self, X):
skip_conn = []
for i in range(num_skip_conn):
X = self.conv_block_list[i](X)
skip_conn.append(X)
return skip_conn