我是 Keras 新手,正在构建一个模型。我想在训练前面的层时冻结模型最后几层的权重。我尝试将横向模型的可训练属性设置为 False,但似乎不起作用。这是代码和模型摘要:
opt = optimizers.Adam(1e-3)
domain_layers = self._build_domain_regressor()
domain_layers.trainble = False
feature_extrator = self._build_common()
img_inputs = Input(shape=(160, 160, 3))
conv_out = feature_extrator(img_inputs)
domain_label = domain_layers(conv_out)
self.domain_regressor = Model(img_inputs, domain_label)
self.domain_regressor.compile(optimizer = opt, loss='binary_crossentropy', metrics=['accuracy'])
self.domain_regressor.summary()
如您所见,
model_1
是可训练的。但根据代码,它被设置为不可训练。
您可以简单地为图层属性分配一个布尔值
trainable
。
model.layers[n].trainable = False
您可以直观地看到哪一层是可训练的:
for l in model.layers:
print(l.name, l.trainable)
您也可以通过模型定义来传递它:
frozen_layer = Dense(32, trainable=False)
来自 Keras 文档:
“冻结”一个层意味着将其从训练中排除,即它的 权重永远不会更新。这在以下情况下很有用 微调模型,或使用固定嵌入进行文本输入。
您可以将可训练参数(布尔值)传递给层构造函数以 将图层设置为不可训练。 此外,您可以将图层的可训练属性设置为 True 或 实例化后为假。要使其生效,您需要 修改可训练属性后,在模型上调用compile()。
“trainble”一词有一个拼写错误(缺少“a”)。可悲的是,keras 没有警告我该模型没有“trainble”属性。问题可以结束了。
尽管原始问题的解决方案是一个拼写错误修复,但让我添加一些有关 keras 可训练的信息。
现代 Keras 包含以下工具来查看和操作可训练状态:
tf.keras.Layer._get_trainable_state()
函数 - 打印字典,其中键是模型组件,值是布尔值。请注意,tf.keras.Model
也是 tf.Keras.Layer
。tf.keras.Layer.trainable
属性 - 操纵各个层的可训练状态。所以典型的动作如下:
# Print current trainable map:
print(model._get_trainable_state())
# Set every layer to be non-trainable:
for k,v in model._get_trainable_state().items():
k.trainable = False
# Don't forget to re-compile the model
model.compile(...)
更改代码中的最后 3 行:
last_few_layers = 20 #number of the last few layers to freeze
self.domain_regressor = Model(img_inputs, domain_label)
for layer in model.layers[:-last_few_layers]:
layer.trainable = False
self.domain_regressor.compile(optimizer = opt, loss='binary_crossentropy', metrics=['accuracy'])
我知道使用 trainable = False,我可以冻结该层的所有权重。但我想为每一层的组件(kernel、recurrent_kernel 和bias)工作。我想将特定层的内核和 recurrent_kernel 冻结为可训练 = False 和可训练 = True。我该怎么做?我尝试了以下代码,但出现错误。谁能向我建议如何使用标准 Keras 使 kernel 和 recurrent_kernel 可训练 = False?
# transfer model layer
lstm_layer = modelTL.layers[0]
# kernel non-trainable
lstm_layer.cell.kernel.trainable = False
错误:
lstm_layer.cell.kernel.trainable = False
属性错误:无法设置属性