如何在自定义 keras 层中拥有未跟踪的权重?

问题描述 投票:0回答:1

我想创建一个自定义 keras 层(VQVAE 模型的代码本。)在训练时,我希望有一个

tf.Variable
来跟踪每个代码的使用情况,以便我可以重新启动未使用的代码。 所以我创建了我的 Codebook 图层,如下所示...

class Codebook(layers.Layer): 
     def __init__(self, num_codes, code_reset_limit = None, **kwargs): 
         super().__init__(**kwargs) 
         self.num_codes = num_codes 
         self.code_reset_limit = code_reset_limit 
         if self.code_reset_limit: 
             self.code_counter = tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False) 
     def build(self, input_shape): 
         self.codes = self.add_weight(name = 'codes',  
                                      shape = (self.num_codes, input_shape[-1]), 
                                      initializer = 'random_uniform',  
                                      trainable = True) 
         super().build(input_shape) 
                                                                                                             

我遇到的问题是

Layer
类找到成员变量
self.code_counter
并将其添加到与图层一起保存的权重列表中。 它还期望在加载权重时出现
self.code_counter
,但当我在推理模式下运行时,情况并非如此。 我怎样才能让 keras 不跟踪我的层中的变量。 我不希望它持续存在或成为
layers.weights
的一部分。

python tensorflow keras
1个回答
0
投票

我的答案有点晚了,但我也遇到了同样的问题,并且遇到了没有答案的问题。现在,我找到了适用于 Keras 2 和 Keras 3 的答案,因此我在这里分享给遇到同样问题的其他人。

为了防止 TensorFlow 和 Keras 跟踪变量,需要将变量封装在 TensorFlow 和 Keras 在跟踪模块中不处理的类中。 Keras 3 自动跟踪的类列表为:

keras.Variable
list
dict
tuple
NamedTuple
(参见此处)。对于 keras 2,对象列表不太容易找到,但似乎包括
tf.Variable
(请参阅当前问题)、
dict
list

在我的上下文中适用于 keras.Variable 和 tf.Variable 的解决方案是创建封装变量的数据类。这是 TensorFlow 和 keras 2 的设置。

import tensorflow as tf
from dataclasses import dataclass

@dataclass
class Container:
    data: tf.Variable

然后像这样使用

 ...
 if self.code_reset_limit: 
     self.code_counter = Container(data=tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False) )
...
 

对于 Keras 3 这变成了

import keras
from dataclasses import dataclass

@dataclass
class Container:
    data: keras.Variable

© www.soinside.com 2019 - 2024. All rights reserved.