I actually want to estimate a normalizing flow with a Gaussian mixture as its base distribution, so I am somewhat tied to torch. However, you can reproduce my error in code by just estimating a mixture of Gaussians in torch. My code is below:
import numpy as np
import torch
from torch import optim
import torch.distributions as D

device = torch.device("cpu")  # run on CPU; swap in "cuda" if available

weights = torch.ones(8, requires_grad=True).to(device)
means = torch.tensor(np.random.randn(8, 2), requires_grad=True).to(device)
stdevs = torch.tensor(np.abs(np.random.randn(8, 2)), requires_grad=True).to(device)

# the mixture is built once, outside of the training loop
mix = D.Categorical(weights)
comp = D.Independent(D.Normal(means, stdevs), 1)
gmm = D.MixtureSameFamily(mix, comp)

optimizer1 = optim.SGD([weights, means, stdevs], lr=0.001, momentum=0.9)

num_iter = 10001
for i in range(num_iter):
    x = torch.randn(5000, 2)  # this can be arbitrary x samples
    loss2 = -gmm.log_prob(x).mean()
    optimizer1.zero_grad()
    loss2.backward()
    optimizer1.step()
    print(i, loss2.item())
The error I get is:
0
8.089411823514835
Traceback (most recent call last):
File "/home/cameron/AnacondaProjects/gmm.py", line 183, in <module>
loss2.backward()
File "/home/cameron/anaconda3/envs/torch/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/cameron/anaconda3/envs/torch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
As you can see, the model runs for exactly one iteration before failing.
There is an ordering problem in your code: you create the Gaussian mixture model outside of the training loop. The first backward() call then frees the part of the graph that was built when the mixture was constructed, which is what raises the error above. Moreover, when the loss is computed on later iterations, the mixture still tries to use the initial parameter values it captured at definition time, but optimizer1.step() has already modified those values in place. So even if you set loss2.backward(retain_graph=True), you will still get an error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
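To see this concretely, here is a minimal sketch of that variant (my own reproduction, not from the original post). With the mixture still built outside the loop, the second iteration fails even though the graph was retained:

import torch
import torch.distributions as D

weights = torch.ones(8, requires_grad=True)
means = torch.randn(8, 2, requires_grad=True)
stdevs = torch.rand(8, 2, requires_grad=True)

# mixture still built once, outside the loop
gmm = D.MixtureSameFamily(
    D.Categorical(weights),
    D.Independent(D.Normal(means, stdevs), 1),
)
optimizer = torch.optim.SGD([weights, means, stdevs], lr=0.001)

for i in range(2):
    loss = -gmm.log_prob(torch.randn(100, 2)).mean()
    optimizer.zero_grad()
    # retain_graph keeps the graph alive, but step() has already modified
    # weights/means/stdevs in place, so the i=1 backward raises
    # "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    loss.backward(retain_graph=True)
    optimizer.step()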
The solution to this problem is simply to create a new Gaussian mixture model whenever the parameters are updated. This sample code runs as expected:
import numpy as np
import torch
from torch import optim
import torch.distributions as D

weights = torch.ones(8, requires_grad=True)
means = torch.tensor(np.random.randn(8, 2), requires_grad=True)
stdevs = torch.tensor(np.abs(np.random.randn(8, 2)), requires_grad=True)

parameters = [weights, means, stdevs]
optimizer1 = optim.SGD(parameters, lr=0.001, momentum=0.9)

num_iter = 10001
for i in range(num_iter):
    # rebuild the mixture from the current parameter values each iteration
    mix = D.Categorical(weights)
    comp = D.Independent(D.Normal(means, stdevs), 1)
    gmm = D.MixtureSameFamily(mix, comp)

    optimizer1.zero_grad()
    x = torch.randn(5000, 2)  # this can be arbitrary x samples
    loss2 = -gmm.log_prob(x).mean()
    loss2.backward()
    optimizer1.step()

    print(i, loss2.item())
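One caveat worth noting (my addition, not part of the original answer): D.Categorical(weights) interprets weights as probabilities and D.Normal requires positive scales, so unconstrained gradient steps can eventually push either outside its valid range. A sketch of an unconstrained parameterization that avoids this, using logits and log standard deviations:

logits = torch.zeros(8, requires_grad=True)         # unconstrained mixture weights
means = torch.randn(8, 2, requires_grad=True)
log_stdevs = torch.zeros(8, 2, requires_grad=True)  # exp() keeps scales positive

optimizer1 = optim.SGD([logits, means, log_stdevs], lr=0.001, momentum=0.9)

for i in range(num_iter):
    mix = D.Categorical(logits=logits)
    comp = D.Independent(D.Normal(means, log_stdevs.exp()), 1)
    gmm = D.MixtureSameFamily(mix, comp)
    # ... same loss / zero_grad / backward / step as above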
While this works, I find the other answers unsatisfying, since a good implementation should not require redundant re-instantiation at every step.
I explain the problem in more depth here. Unfortunately, these classes are not set up for optimization out of the box. However, by directly assigning mix.logits.requires_grad = True and passing mix.logits to the optimizer, we can get the classes to optimize as expected without continual re-instantiation. Here is my reference implementation: https://github.com/kylesayrs/GMMPytorch
from typing import Iterator

import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

# NOTE: make_random_scale_trils is a helper defined in the repo linked above;
# it returns random lower-triangular Cholesky factors of shape
# (num_components, num_dims, num_dims)


class GmmFull(torch.nn.Module):
    def __init__(
        self,
        num_components: int,
        num_dims: int,
        radius: float = 1.0,
    ):
        super().__init__()
        self.num_components = num_components
        self.num_dims = num_dims
        self.radius = radius

        # learnable parameters (excluding self.mixture.logits)
        self.mus = torch.nn.Parameter(torch.rand(num_components, num_dims).uniform_(-radius, radius))
        self.scale_tril = torch.nn.Parameter(make_random_scale_trils(num_components, num_dims))

        # mixture and components
        self.mixture = Categorical(logits=torch.zeros(num_components, ))
        self.components = MultivariateNormal(self.mus, scale_tril=self.scale_tril)
        self.mixture_model = MixtureSameFamily(self.mixture, self.components)

        # workaround, see https://github.com/pytorch/pytorch/issues/114417
        self.mixture.logits.requires_grad = True

    def component_parameters(self) -> Iterator[torch.nn.Parameter]:
        return iter([self.mus, self.scale_tril])

    def mixture_parameters(self) -> Iterator[torch.nn.Parameter]:
        return iter([self.mixture.logits])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nll_loss = -1 * self.mixture_model.log_prob(x).mean()

        # detect singularity collapse and reset the parameters
        if nll_loss.isnan():
            with torch.no_grad():
                self.mixture.logits.uniform_(0, 1)
                self.mus.data.uniform_(-self.radius, self.radius)
                self.scale_tril.data = make_random_scale_trils(self.num_components, self.num_dims)
            nll_loss = -1 * self.mixture_model.log_prob(x).mean()

        return nll_loss

Beware of singularities (see the singularity mitigation section of the README). For most datasets they are almost certain to occur, so they must be handled, typically by resetting the parameters as the forward method above does.
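For completeness, here is a hypothetical fitting loop for GmmFull (my own sketch, not from the repo, which ships its own fitting utilities). It assumes a data tensor x of shape (n, num_dims) and uses the two parameter iterators so the mixture logits can get their own learning rate:

# hypothetical usage; learning rates and step count are illustrative
model = GmmFull(num_components=8, num_dims=2)
optimizer = torch.optim.Adam([
    {"params": list(model.component_parameters()), "lr": 1e-2},
    {"params": list(model.mixture_parameters()), "lr": 1e-1},
])

x = torch.randn(5000, 2)  # stand-in data of shape (n, num_dims)
for step in range(1000):
    optimizer.zero_grad()
    loss = model(x)  # negative log-likelihood, with singularity reset built in
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(step, loss.item())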