我想创建一个自定义 keras 层(VQVAE 模型的代码本。)在训练时,我希望有一个
tf.Variable
来跟踪每个代码的使用情况,以便我可以重新启动未使用的代码。 所以我创建了我的 Codebook 图层,如下所示...
class Codebook(layers.Layer):
def __init__(self, num_codes, code_reset_limit = None, **kwargs):
super().__init__(**kwargs)
self.num_codes = num_codes
self.code_reset_limit = code_reset_limit
if self.code_reset_limit:
self.code_counter = tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False)
def build(self, input_shape):
self.codes = self.add_weight(name = 'codes',
shape = (self.num_codes, input_shape[-1]),
initializer = 'random_uniform',
trainable = True)
super().build(input_shape)
我遇到的问题是
Layer
类找到成员变量 self.code_counter
并将其添加到与图层一起保存的权重列表中。 它还期望在加载权重时出现 self.code_counter
,但当我在推理模式下运行时,情况并非如此。 我怎样才能让 keras 不跟踪我的层中的变量。 我不希望它持续存在或成为 layers.weights
的一部分。
我的答案有点晚了,但我也遇到了同样的问题,并且遇到了没有答案的问题。现在,我找到了适用于 Keras 2 和 Keras 3 的答案,因此我在这里分享给遇到同样问题的其他人。
为了防止 TensorFlow 和 Keras 跟踪变量,需要将变量封装在 TensorFlow 和 Keras 在跟踪模块中不处理的类中。 Keras 3 自动跟踪的类列表为:
keras.Variable
、list
、dict
、tuple
和 NamedTuple
(参见此处)。对于 keras 2,对象列表不太容易找到,但似乎包括 tf.Variable
(请参阅当前问题)、dict
和 list
。
在我的上下文中适用于 keras.Variable 和 tf.Variable 的解决方案是创建封装变量的数据类。这是 TensorFlow 和 keras 2 的设置。
import tensorflow as tf
from dataclasses import dataclass
@dataclass
class Container:
data: tf.Variable
然后像这样使用
...
if self.code_reset_limit:
self.code_counter = Container(data=tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False) )
...
对于 Keras 3 这变成了
import keras
from dataclasses import dataclass
@dataclass
class Container:
data: keras.Variable