我有以下简单的自动编码器:
class Autoencoder(nn.Module):
def __init__(self, input_shape, model_config):
super().__init__()
output_features = model_config["output_features"]
encode2_size = model_config["encode2_size"]
encode3_size = model_config["encode3_size"]
# define encoding layers
self.encode1 = nn.Linear(input_shape, output_features)
self.encode2 = nn.Linear(output_features, encode2_size)
self.encode3 = nn.Linear(encode2_size, encode3_size)
# define decoding layers
self.decode1 = nn.Linear(encode3_size, encode2_size)
self.decode2 = nn.Linear(encode2_size, output_features)
self.decode3 = nn.Linear(output_features, input_shape)
def encode(self, x: torch.Tensor):
x = relu(self.encode1(x))
x = relu(self.encode2(x))
x = relu(self.encode3(x))
return x
def decode(self, x: torch.Tensor):
x = relu(self.decode1(x))
x = relu(self.decode2(x))
x = relu(self.decode3(x))
return x
def forward(self, x: torch.Tensor):
x = self.encode(x)
x = self.decode(x)
return x
我正在使用如下所示的骰子损失:
class DiceLoss(nn.Module):
"""
The formula: 2*|X ∩ Y|/(|X|*|Y|)
"""
def __init__(self, weight=None, size_average=True):
super(DiceLoss, self).__init__()
def forward(self, inputs, targets, smooth=1):
# comment out if your model contains a sigmoid acitvation
inputs = torch.sigmoid(inputs)
# flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
return 1 - dice
对于培训,我有这样的东西:
train_sampler = DistributedSampler(data_tensor, num_replicas=hvd.size(), rank=hvd.rank())
train_dataloader = DataLoader(data_tensor, batch_size=batch_size, shuffle=False, sampler=train_sampler)
# autoencoder params
epochs = model_config["epochs"]
net = Autoencoder(embedding_dim, model_config)
# loss function and optimizer
loss_function = DiceLoss()
optimizer = optim.Adagrad(net.parameters(), lr=model_config["lr"], weight_decay=model_config["weight_decay"])
for epoch in range(epochs):
for i, batch in enumerate(train_dataloader):
net.zero_grad()
# Pass batch through
output = net(batch)
# Get Loss + Backprop
loss = loss_function(output, batch)
loss.backward(retain_graph=True)
optimizer.step()
没有 retain_graph=True 我得到错误:
RuntimeError:尝试第二次向后遍历图形(或在张量已被释放后直接访问已保存的张量)。当您调用 .backward() 或 autograd.grad() 时,图形的已保存中间值将被释放。如果需要第二次向后遍历图形或者需要在向后调用后访问保存的张量,请指定 retain_graph=True。
我明白这个问题。为了减少内存使用,在 .backward() 调用期间,所有中间结果在不再需要时被删除。因此,如果您尝试再次调用 .backward() ,则中间结果不存在并且无法执行向后传递(并且您会收到错误消息)。您可以调用 .backward(retain_graph=True) 进行反向传递,不会删除中间结果,这样您就可以再次调用 .backward() 。除了最后一次调用 backward 之外的所有调用都应该有 retain_graph=True 选项。
如果我使用 retain_graph=True,我会遇到 OOM 问题。如何更改此处的代码以避免 retain_graph=True?谢谢。