我正在尝试在我的神经网络模型中实现 TFBertModel 或 TFDistilBertModel (还有其他层,例如密集层和批规范层)。我的理解是,两个 BERT 模型中都有隐藏层,我想解冻最后两层以进行微调。
但是,在我运行以下代码之后:
# Load the TensorFlow variant of DistilBERT; `config` is assumed to be a
# DistilBertConfig defined earlier — TODO confirm against the full script.
model = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)
# Keras sees the whole pretrained transformer as a single layer, which is why
# this prints 1 rather than the number of hidden transformer blocks.
print(len(model.layers))
它表明只有一层,我认为它指的是整个 TFDistilBertModel。有没有办法让我解冻和冻结一些隐藏层以微调模型?
首先,我们需要访问图层/参数及其名称,以便我们知道要冻结/解冻的内容:
from transformers import AutoModel

model = AutoModel.from_pretrained('distilbert-base-uncased')

# List every trainable parameter and whether it currently receives gradients.
# NOTE: model.named_parameters() already walks the module tree recursively,
# so wrapping it in an outer `for m in model.modules():` loop (as the original
# did) would print each parameter once per enclosing module — many duplicates.
for name, params in model.named_parameters():
    print(name, params.requires_grad)
[出]:
embeddings.word_embeddings.weight True
embeddings.position_embeddings.weight True
embeddings.LayerNorm.weight True
embeddings.LayerNorm.bias True
transformer.layer.0.attention.q_lin.weight True
transformer.layer.0.attention.q_lin.bias True
transformer.layer.0.attention.k_lin.weight True
transformer.layer.0.attention.k_lin.bias True
transformer.layer.0.attention.v_lin.weight True
transformer.layer.0.attention.v_lin.bias True
transformer.layer.0.attention.out_lin.weight True
transformer.layer.0.attention.out_lin.bias True
transformer.layer.0.sa_layer_norm.weight True
transformer.layer.0.sa_layer_norm.bias True
transformer.layer.0.ffn.lin1.weight True
transformer.layer.0.ffn.lin1.bias True
transformer.layer.0.ffn.lin2.weight True
transformer.layer.0.ffn.lin2.bias True
transformer.layer.0.output_layer_norm.weight True
transformer.layer.0.output_layer_norm.bias True
transformer.layer.1.attention.q_lin.weight True
transformer.layer.1.attention.q_lin.bias True
transformer.layer.1.attention.k_lin.weight True
transformer.layer.1.attention.k_lin.bias True
transformer.layer.1.attention.v_lin.weight True
transformer.layer.1.attention.v_lin.bias True
transformer.layer.1.attention.out_lin.weight True
transformer.layer.1.attention.out_lin.bias True
transformer.layer.1.sa_layer_norm.weight True
transformer.layer.1.sa_layer_norm.bias True
transformer.layer.1.ffn.lin1.weight True
transformer.layer.1.ffn.lin1.bias True
transformer.layer.1.ffn.lin2.weight True
transformer.layer.1.ffn.lin2.bias True
transformer.layer.1.output_layer_norm.weight True
transformer.layer.1.output_layer_norm.bias True
transformer.layer.2.attention.q_lin.weight True
transformer.layer.2.attention.q_lin.bias True
transformer.layer.2.attention.k_lin.weight True
transformer.layer.2.attention.k_lin.bias True
transformer.layer.2.attention.v_lin.weight True
transformer.layer.2.attention.v_lin.bias True
transformer.layer.2.attention.out_lin.weight True
transformer.layer.2.attention.out_lin.bias True
transformer.layer.2.sa_layer_norm.weight True
transformer.layer.2.sa_layer_norm.bias True
transformer.layer.2.ffn.lin1.weight True
transformer.layer.2.ffn.lin1.bias True
transformer.layer.2.ffn.lin2.weight True
transformer.layer.2.ffn.lin2.bias True
transformer.layer.2.output_layer_norm.weight True
transformer.layer.2.output_layer_norm.bias True
transformer.layer.3.attention.q_lin.weight True
transformer.layer.3.attention.q_lin.bias True
transformer.layer.3.attention.k_lin.weight True
transformer.layer.3.attention.k_lin.bias True
transformer.layer.3.attention.v_lin.weight True
transformer.layer.3.attention.v_lin.bias True
transformer.layer.3.attention.out_lin.weight True
transformer.layer.3.attention.out_lin.bias True
transformer.layer.3.sa_layer_norm.weight True
transformer.layer.3.sa_layer_norm.bias True
transformer.layer.3.ffn.lin1.weight True
transformer.layer.3.ffn.lin1.bias True
transformer.layer.3.ffn.lin2.weight True
transformer.layer.3.ffn.lin2.bias True
transformer.layer.3.output_layer_norm.weight True
transformer.layer.3.output_layer_norm.bias True
transformer.layer.4.attention.q_lin.weight True
transformer.layer.4.attention.q_lin.bias True
transformer.layer.4.attention.k_lin.weight True
transformer.layer.4.attention.k_lin.bias True
transformer.layer.4.attention.v_lin.weight True
transformer.layer.4.attention.v_lin.bias True
transformer.layer.4.attention.out_lin.weight True
transformer.layer.4.attention.out_lin.bias True
transformer.layer.4.sa_layer_norm.weight True
transformer.layer.4.sa_layer_norm.bias True
transformer.layer.4.ffn.lin1.weight True
transformer.layer.4.ffn.lin1.bias True
transformer.layer.4.ffn.lin2.weight True
transformer.layer.4.ffn.lin2.bias True
transformer.layer.4.output_layer_norm.weight True
transformer.layer.4.output_layer_norm.bias True
transformer.layer.5.attention.q_lin.weight True
transformer.layer.5.attention.q_lin.bias True
transformer.layer.5.attention.k_lin.weight True
transformer.layer.5.attention.k_lin.bias True
transformer.layer.5.attention.v_lin.weight True
transformer.layer.5.attention.v_lin.bias True
transformer.layer.5.attention.out_lin.weight True
transformer.layer.5.attention.out_lin.bias True
transformer.layer.5.sa_layer_norm.weight True
transformer.layer.5.sa_layer_norm.bias True
transformer.layer.5.ffn.lin1.weight True
transformer.layer.5.ffn.lin1.bias True
transformer.layer.5.ffn.lin2.weight True
transformer.layer.5.ffn.lin2.bias True
transformer.layer.5.output_layer_norm.weight True
transformer.layer.5.output_layer_norm.bias True
然后,如果您有一组想要冻结的层名前缀(注意代码中用的是集合而非列表),例如:
from transformers import AutoModel

model = AutoModel.from_pretrained('distilbert-base-uncased')

# Prefixes of parameter names to freeze (here: the last two transformer blocks).
to_freeze = {"transformer.layer.4", "transformer.layer.5"}

# Single recursive pass over all parameters; the original's extra
# `for m in model.modules():` loop would have visited each parameter
# once per enclosing module.
for name, params in model.named_parameters():
    # BUG FIX: the original `any(prefix for prefix in tofreeze and
    # name.startswith(prefix))` iterates over `(tofreeze and
    # name.startswith(prefix))` and raises NameError on `prefix`.
    # The intended test is: does the name start with any frozen prefix?
    if any(name.startswith(prefix) for prefix in to_freeze):
        params.requires_grad = False
    print(name, params.requires_grad)