在张量流中实现KL预热:回调中的tf.keras.backend.variable在一段时间内是不稳定的

问题描述 投票:0回答:1

我正在尝试在TensorFlow中实现带有KL预热的Variational AutoEncoder的变体(论文here)。这个想法是,在训练开始时,应在指定的时期内线性增加损失的KL项。我尝试的方式是使用一个回调,该回调在每次新的时期开始时都在K.variable中设置一个值,因为当前的时期数超过了预热的期望范围(例如,如果预热设置为持续10次在第6阶段,损失中的KL项应乘以0.6)。我还在KL中添加了add_metric()(作为图层的子类实现),以在训练期间控制kl_rate。问题是变量的值不稳定!它在每个新纪元开始接近所需的值,但是在每次迭代时它都会缓慢衰减,从而使过程不太容易控制。你知道我在做什么错吗?我也不确定这是回调本身(以及随后的实际使用值)还是所报告指标的问题。谢谢!

进口:

import tensorflow.keras.backend as K

回调(self.kl_warmup是模型类的参数,该参数设置为整​​数,对应于应增加kl速率的时期数):

kl_beta = K.variable(1.0, name="kl_beta")
if self.kl_warmup:

    kl_warmup_callback = LambdaCallback(
        on_epoch_begin=lambda epoch, logs: K.set_value(
            kl_beta, K.min([epoch / self.kl_warmup, 1])
        )
    )

z_mean, z_log_sigma = KLDivergenceLayer(beta=kl_beta)([z_mean, z_log_sigma])

KL层:

class KLDivergenceLayer(Layer):

""" Identity transform layer that adds KL divergence
to the final model loss.
"""

def __init__(self, beta=1.0, *args, **kwargs):
    self.is_placeholder = True
    self.beta = beta
    super(KLDivergenceLayer, self).__init__(*args, **kwargs)

def get_config(self):
    config = super().get_config().copy()
    config.update({"beta": self.beta})
    return config

def call(self, inputs, **kwargs):
    mu, log_var = inputs
    kL_batch = -0.5 * K.sum(1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)

    self.add_loss(self.beta * K.mean(kL_batch), inputs=inputs)
    self.add_metric(self.beta, aggregation="mean", name="kl_rate")

    return inputs

模型实例(整个模型构建在一个类中,该类返回编码器,生成器,全值和kl_rate回调):

encoder, generator, vae, kl_warmup_callback = SEQ_2_SEQ_VAE(pttest.shape,
                                                               loss='ELBO',
                                                               kl_warmup_epochs=10).build()

fit()调用:

history = vae.fit(x=pttrain, y=pttrain, epochs=100, batch_size=512, verbose=1,
              validation_data=(pttest, pttest),
              callbacks=[tensorboard_callback, kl_warmup_callback])

训练过程的摘要(请注意,kl_rate应该为零,并且已关闭):enter image description here

张量板上历时kl_rate的屏幕截图(跨度设置为10历时;在10历时后,它应该达到1,但收敛到约0.9)]

enter image description here

python keras deep-learning callback tensorflow2.0
1个回答
0
投票

[我经过更多研究后最终发现了它。

kl_beta._trainable = False

没有窍门:)谢谢!

© www.soinside.com 2019 - 2024. All rights reserved.