如何将keras中的参数设置为不可训练?

问题描述 投票:0回答:5

我是 Keras 新手,正在构建一个模型。我想在训练前面的层时冻结模型最后几层的权重。我尝试将横向模型的可训练属性设置为 False,但似乎不起作用。这是代码和模型摘要:

opt = optimizers.Adam(1e-3)
domain_layers = self._build_domain_regressor()
domain_layers.trainble = False
feature_extrator = self._build_common()
img_inputs = Input(shape=(160, 160, 3))
conv_out = feature_extrator(img_inputs)
domain_label = domain_layers(conv_out)
self.domain_regressor = Model(img_inputs, domain_label)
self.domain_regressor.compile(optimizer = opt, loss='binary_crossentropy', metrics=['accuracy'])
self.domain_regressor.summary()

模型总结:model summary.

如您所见,

model_1
是可训练的。但根据代码,它被设置为不可训练。

python keras deep-learning
5个回答
40
投票

您可以简单地为图层属性分配一个布尔值

trainable

model.layers[n].trainable = False

您可以直观地看到哪一层是可训练的:

for l in model.layers:
    print(l.name, l.trainable)

您也可以通过模型定义来传递它:

frozen_layer = Dense(32, trainable=False)

来自 Keras 文档

“冻结”一个层意味着将其从训练中排除,即它的 权重永远不会更新。这在以下情况下很有用 微调模型,或使用固定嵌入进行文本输入。
您可以将可训练参数(布尔值)传递给层构造函数以 将图层设置为不可训练。 此外,您可以将图层的可训练属性设置为 True 或 实例化后为假。要使其生效,您需要 修改可训练属性后,在模型上调用compile()。


12
投票

“trainble”一词有一个拼写错误(缺少“a”)。可悲的是,keras 没有警告我该模型没有“trainble”属性。问题可以结束了。


9
投票

尽管原始问题的解决方案是一个拼写错误修复,但让我添加一些有关 keras 可训练的信息。

现代 Keras 包含以下工具来查看和操作可训练状态:

  • tf.keras.Layer._get_trainable_state()
    函数 - 打印字典,其中键是模型组件,值是布尔值。请注意,
    tf.keras.Model
    也是
    tf.Keras.Layer
  • tf.keras.Layer.trainable
    属性 - 操纵各个层的可训练状态。

所以典型的动作如下:

# Print current trainable map:
print(model._get_trainable_state())

# Set every layer to be non-trainable:
for k,v in model._get_trainable_state().items():
    k.trainable = False

# Don't forget to re-compile the model
model.compile(...)

1
投票

更改代码中的最后 3 行:

last_few_layers = 20 #number of the last few layers to freeze
self.domain_regressor = Model(img_inputs, domain_label)
for layer in model.layers[:-last_few_layers]:
    layer.trainable = False
self.domain_regressor.compile(optimizer = opt, loss='binary_crossentropy', metrics=['accuracy'])

0
投票

我知道使用 trainable = False,我可以冻结该层的所有权重。但我想为每一层的组件(kernel、recurrent_kernel 和bias)工作。我想将特定层的内核和 recurrent_kernel 冻结为可训练 = False 和可训练 = True。我该怎么做?我尝试了以下代码,但出现错误。谁能向我建议如何使用标准 Keras 使 kernel 和 recurrent_kernel 可训练 = False?

# transfer model layer
lstm_layer = modelTL.layers[0]  

# kernel non-trainable
lstm_layer.cell.kernel.trainable = False

错误:
lstm_layer.cell.kernel.trainable = False 属性错误:无法设置属性

© www.soinside.com 2019 - 2024. All rights reserved.