如何仅冻结一些带有绑定嵌入的嵌入索引?

问题描述 投票:0回答:1

我在 Is it possible to freeze only certain embedding Weights in the embedding layer in pytorch? 中找到了一种仅冻结嵌入层的某些索引的好方法。 然而,虽然将其包含在 BERT 模型中,但我找不到绑定这些嵌入的方法。有人可以帮助我吗?

HF Transformers 使用解码器矩阵中嵌入层权重的副本。但是,如果我的嵌入层是 nn.Module 而不是 nn.Embedding,则无法执行此操作。用自定义模块替换解码器层并没有执行权重转置。

python pytorch nlp bert-language-model word-embedding
1个回答
0
投票

不可能直接这样做:Pytorch 处理每个张量的梯度计算,因此整个张量权重要么计算梯度,要么不计算梯度。你基本上有 3 个解决方案:

  1. 制作一个自定义

    Embedding
    模块,将冻结和非冻结权重存储在两个不同的矩阵中,如链接答案中的建议。这太复杂了,因为您还必须修改解码器层才能使用更改后的数据表示。

  2. 正如 dankal444 的评论所建议的那样,保留要冻结的嵌入的副本,并在每次优化之后将它们重置为副本的值。

  3. 正如 Karl 的评论所建议的那样,像平常一样进行向后传递,但然后将要冻结的嵌入的梯度归零。您可以为此创建一个 backwards hook,如下所示:


#for example
embeddings_to_keep = torch.tensor([1, 4, 5], dtype=torch.int64, device=device)

def set_grads_to_zero_hook(grad):
   grad = grad.clone()
   grad[embeddings_to_keep] = 0.
   return grad

embedding_module.weight.register_hook(set_grads_to_zero_hook)

© www.soinside.com 2019 - 2024. All rights reserved.