我正在尝试为我的简单 3 层前馈模型实现稀疏连接的权重矩阵。为此,我为每个层实现了一个具有一定百分比的零的掩码,其想法是,我希望在每个优化器步骤之后将同一组权重归零,以便我的层不完全连接。但我在这方面遇到了麻烦,因为当我对掩码与权重矩阵进行元素相乘时,权重在后续的向后传递中停止变化。为了查看是否是我的遮罩导致了问题,我只是将权重矩阵与标量 1.0 相乘,这又重现了问题。这里可能发生什么?我检查了一下,梯度仍然被计算出来。只是损失不再下降,权重也没有改变。进行这种乘法是否会以某种方式将权重与图表断开?
我的型号:
class TSP(nn.Module):
def __init__(self, input_size, hidden_size):
super(TSP, self).__init__()
self.sc1 = nn.Linear(input_size, hidden_size)
self.sc2 = nn.Linear(hidden_size, input_size)
torch.nn.init.normal_(self.sc1.weight, mean=0, std=0.1)
torch.nn.init.normal_(self.sc2.weight, mean=0, std=0.1)
def forward(self, x):
x = torch.relu(self.sc1(x))
x = (self.sc2(x))
return x
def predict_hidden(self, x):
x = torch.relu(self.sc1(x))
return x
要重现此问题,所需要做的就是以下操作,并且权重停止更新:
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
model.sc2.weight = nn.Parameter(1. * model.sc2.weight)
当你跑步时
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
model.sc2.weight = nn.Parameter(1. * model.sc2.weight)
你不是“乘以标量”。您正在创建一个全新的对象 (
nn.Parameter(1. * model.sc1.weight)
) 并将其分配给 .weight
属性。
我假设您正在使用标准 pytorch 优化器更新模型,例如:
model = TSP(...)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
当您运行
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
时,您会在 model.sc1.weight
中创建一个全新的对象,但优化器仍然引用旧对象。
您可以按如下方式验证这一点:
# data pointer of weight
model.sc1.weight.data_ptr()
> 124805056
# data pointer of weight in the optimizer
opt.param_groups[0]['params'][0].data_ptr()
> 124805056
# now create new weight object
model.sc1.weight = nn.Parameter(1. * model.sc1.weight)
# data pointer of model weight has changed
model.sc1.weight.data_ptr()
> 139582720
# data pointer of optimizer has not
opt.param_groups[0]['params'][0].data_ptr()
> 124805056
为了避免这种情况,请更新对象而不是创建新对象
# data pointer of weight
model.sc1.weight.data_ptr()
> 124805056
# data pointer of weight in the optimizer
opt.param_groups[0]['params'][0].data_ptr()
> 124805056
# update data of weight tensor with in-place operation
model.sc1.weight.data.mul_(2.)
# weight and optimizer still have same data pointer
model.sc1.weight.data_ptr()
> 124805056
opt.param_groups[0]['params'][0].data_ptr()
> 124805056