Freezing and unfreezing certain layers in TFDistilBertModel

Problem description · votes: 0 · answers: 1

I am trying to use a TFBertModel or TFDistilBertModel inside my neural network model (which also contains other layers, such as Dense and batch-normalization layers). My understanding is that both BERT models contain hidden transformer layers, and I would like to unfreeze the last two of them for fine-tuning.

However, after I run the following code:

from transformers import TFDistilBertModel

# config is a DistilBertConfig created elsewhere in my code.
model = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)
print(len(model.layers))

it shows that there is only one layer, which I believe refers to the entire TFDistilBertModel wrapped as a single Keras layer. Is there a way for me to unfreeze and freeze some of the hidden layers so that I can fine-tune the model?

python tensorflow machine-learning deep-learning huggingface-transformers
1 Answer

0 votes

First, we need access to the layers/parameters and their names, so that we know what to freeze/unfreeze:

from transformers import AutoModel

model = AutoModel.from_pretrained('distilbert-base-uncased')

# Print every parameter name and whether it is currently trainable.
for name, params in model.named_parameters():
  print(name, params.requires_grad)

[out]:

embeddings.word_embeddings.weight True
embeddings.position_embeddings.weight True
embeddings.LayerNorm.weight True
embeddings.LayerNorm.bias True
transformer.layer.0.attention.q_lin.weight True
transformer.layer.0.attention.q_lin.bias True
transformer.layer.0.attention.k_lin.weight True
transformer.layer.0.attention.k_lin.bias True
transformer.layer.0.attention.v_lin.weight True
transformer.layer.0.attention.v_lin.bias True
transformer.layer.0.attention.out_lin.weight True
transformer.layer.0.attention.out_lin.bias True
transformer.layer.0.sa_layer_norm.weight True
transformer.layer.0.sa_layer_norm.bias True
transformer.layer.0.ffn.lin1.weight True
transformer.layer.0.ffn.lin1.bias True
transformer.layer.0.ffn.lin2.weight True
transformer.layer.0.ffn.lin2.bias True
transformer.layer.0.output_layer_norm.weight True
transformer.layer.0.output_layer_norm.bias True
transformer.layer.1.attention.q_lin.weight True
transformer.layer.1.attention.q_lin.bias True
transformer.layer.1.attention.k_lin.weight True
transformer.layer.1.attention.k_lin.bias True
transformer.layer.1.attention.v_lin.weight True
transformer.layer.1.attention.v_lin.bias True
transformer.layer.1.attention.out_lin.weight True
transformer.layer.1.attention.out_lin.bias True
transformer.layer.1.sa_layer_norm.weight True
transformer.layer.1.sa_layer_norm.bias True
transformer.layer.1.ffn.lin1.weight True
transformer.layer.1.ffn.lin1.bias True
transformer.layer.1.ffn.lin2.weight True
transformer.layer.1.ffn.lin2.bias True
transformer.layer.1.output_layer_norm.weight True
transformer.layer.1.output_layer_norm.bias True
transformer.layer.2.attention.q_lin.weight True
transformer.layer.2.attention.q_lin.bias True
transformer.layer.2.attention.k_lin.weight True
transformer.layer.2.attention.k_lin.bias True
transformer.layer.2.attention.v_lin.weight True
transformer.layer.2.attention.v_lin.bias True
transformer.layer.2.attention.out_lin.weight True
transformer.layer.2.attention.out_lin.bias True
transformer.layer.2.sa_layer_norm.weight True
transformer.layer.2.sa_layer_norm.bias True
transformer.layer.2.ffn.lin1.weight True
transformer.layer.2.ffn.lin1.bias True
transformer.layer.2.ffn.lin2.weight True
transformer.layer.2.ffn.lin2.bias True
transformer.layer.2.output_layer_norm.weight True
transformer.layer.2.output_layer_norm.bias True
transformer.layer.3.attention.q_lin.weight True
transformer.layer.3.attention.q_lin.bias True
transformer.layer.3.attention.k_lin.weight True
transformer.layer.3.attention.k_lin.bias True
transformer.layer.3.attention.v_lin.weight True
transformer.layer.3.attention.v_lin.bias True
transformer.layer.3.attention.out_lin.weight True
transformer.layer.3.attention.out_lin.bias True
transformer.layer.3.sa_layer_norm.weight True
transformer.layer.3.sa_layer_norm.bias True
transformer.layer.3.ffn.lin1.weight True
transformer.layer.3.ffn.lin1.bias True
transformer.layer.3.ffn.lin2.weight True
transformer.layer.3.ffn.lin2.bias True
transformer.layer.3.output_layer_norm.weight True
transformer.layer.3.output_layer_norm.bias True
transformer.layer.4.attention.q_lin.weight True
transformer.layer.4.attention.q_lin.bias True
transformer.layer.4.attention.k_lin.weight True
transformer.layer.4.attention.k_lin.bias True
transformer.layer.4.attention.v_lin.weight True
transformer.layer.4.attention.v_lin.bias True
transformer.layer.4.attention.out_lin.weight True
transformer.layer.4.attention.out_lin.bias True
transformer.layer.4.sa_layer_norm.weight True
transformer.layer.4.sa_layer_norm.bias True
transformer.layer.4.ffn.lin1.weight True
transformer.layer.4.ffn.lin1.bias True
transformer.layer.4.ffn.lin2.weight True
transformer.layer.4.ffn.lin2.bias True
transformer.layer.4.output_layer_norm.weight True
transformer.layer.4.output_layer_norm.bias True
transformer.layer.5.attention.q_lin.weight True
transformer.layer.5.attention.q_lin.bias True
transformer.layer.5.attention.k_lin.weight True
transformer.layer.5.attention.k_lin.bias True
transformer.layer.5.attention.v_lin.weight True
transformer.layer.5.attention.v_lin.bias True
transformer.layer.5.attention.out_lin.weight True
transformer.layer.5.attention.out_lin.bias True
transformer.layer.5.sa_layer_norm.weight True
transformer.layer.5.sa_layer_norm.bias True
transformer.layer.5.ffn.lin1.weight True
transformer.layer.5.ffn.lin1.bias True
transformer.layer.5.ffn.lin2.weight True
transformer.layer.5.ffn.lin2.bias True
transformer.layer.5.output_layer_norm.weight True
transformer.layer.5.output_layer_norm.bias True
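
The names follow the pattern transformer.layer.<block>.<sublayer>, so distilbert-base-uncased has six transformer blocks (layer.0 through layer.5) on top of the embeddings. As a quick sanity check, here is a minimal sketch (reusing the model loaded above) that groups the parameter names by block prefix:

from collections import Counter

# Count parameters per transformer block; each block in
# distilbert-base-uncased should contribute 16 parameters.
block_counts = Counter(
    ".".join(name.split(".")[:3])
    for name, _ in model.named_parameters()
    if name.startswith("transformer.layer.")
)
print(block_counts)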

Then, if you have a list of layer-name prefixes that you want to freeze, for example:


from transformers import AutoModel

model = AutoModel.from_pretrained('distilbert-base-uncased')

# Prefixes of the parameter names to freeze.
tofreeze = {"transformer.layer.4", "transformer.layer.5"}

for name, params in model.named_parameters():
  # Freeze a parameter if its name starts with any of the chosen prefixes.
  if any(name.startswith(prefix) for prefix in tofreeze):
    params.requires_grad = False
  print(name, params.requires_grad)
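
Note that the snippets above use the PyTorch AutoModel. If you stay with the Keras TFDistilBertModel from the question, the same idea can be expressed through the trainable flag on the nested transformer blocks. Below is a minimal sketch, assuming the usual model.distilbert.transformer.layer attribute path (which may differ slightly across transformers versions), that keeps only the last two blocks trainable, as asked in the question:

from transformers import TFDistilBertModel

tf_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')

# Freeze the embeddings and all but the last two transformer blocks,
# leaving blocks 4 and 5 trainable for fine-tuning.
tf_model.distilbert.embeddings.trainable = False
for block in tf_model.distilbert.transformer.layer[:-2]:
    block.trainable = False

# Sanity check; remember to (re)compile the enclosing Keras model after
# changing trainable flags so they take effect during training.
print(len(tf_model.trainable_weights), len(tf_model.non_trainable_weights))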

