I am attempting a PyTorch implementation of the Lottery Ticket Hypothesis.
For that, I want to freeze the weights in the model that are zero. Is the following a correct way to implement it?
for name, p in model.named_parameters():
    if 'weight' in name:
        tensor = p.data.cpu().numpy()
        grad_tensor = p.grad.data.cpu().numpy()
        grad_tensor = np.where(tensor == 0, 0, grad_tensor)
        p.grad.data = torch.from_numpy(grad_tensor).to(device)
What you have there will work if you do it after loss.backward() and before optimizer.step(), but it seems a bit convoluted: there is no need to round-trip through NumPy to zero out gradients. Moreover, if your weights are floating-point values, comparing them to exactly zero is potentially a bad idea; we can introduce an epsilon to account for that. That said, I think the following is a cleaner solution:
# locate zero-valued weights before the training loop
EPS = 1e-6
locked_masks = {n: torch.abs(w) < EPS
                for n, w in model.named_parameters()
                if n.endswith('weight')}
...
for ...:   # training loop
    ...
    # update optimizer
    optimizer.zero_grad()
    loss.backward()
    # zero the gradients of interest
    for n, w in model.named_parameters():
        if w.grad is not None and n in locked_masks:
            w.grad[locked_masks[n]] = 0
    optimizer.step()
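A minimal, self-contained sketch of this masked-gradient approach, so you can verify that the pruned weights really stay at zero across an update. The tiny `nn.Linear` model, the hand-pruned `weight[0, :2]` entries, and the use of plain SGD are all illustrative stand-ins, not part of the original code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# tiny stand-in model: a single linear layer
model = nn.Linear(4, 3)

# mimic lottery-ticket pruning by zeroing a couple of weights by hand
with torch.no_grad():
    model.weight[0, :2] = 0.0

# build masks of (near-)zero weights once, before training
EPS = 1e-6
locked_masks = {n: torch.abs(w) < EPS
                for n, w in model.named_parameters()
                if n.endswith('weight')}

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# one training step
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()

# zero the gradients of the locked weights
for n, w in model.named_parameters():
    if w.grad is not None and n in locked_masks:
        w.grad[locked_masks[n]] = 0

optimizer.step()

# the pruned weights are still exactly zero after the update
assert torch.all(model.weight[0, :2] == 0)
```

Note that this guarantee holds for plain SGD because a zero gradient yields a zero update; with optimizers that carry state such as momentum or Adam, the masks must be built before any steps are taken, otherwise stale statistics could still move the locked weights.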