我的目标是建立一个生成对抗网络,该网络生成分类变量的真实外观序列,类似于[1]。为了使用生成器生成分类序列,我需要使用Gumbel_Softmax激活来确保反向传播仍然有效。我只能在tfp.distributions.RelaxedOneHotCategorical中找到无法在Tensorflow 2.1中找到预定义的Gumbel_softmax激活函数,该函数应该可以解决我的问题。
在我的示例中,我想生成一个二进制变量序列。您能给我一个如何在tensorflow功能API中实现此功能的代码示例。
也许您可以从当前代码中掌握我的目标:
generator():
inputs = Input(latent_dim,)
x = Dense(t_steps* no_states, activation='relu')(inputs)
x = Reshape((t_steps, no_states))(x)
x = tfpl.RelaxedOneHotCategorical(temperature=t, logits=no_states, Batch_shape=t_steps)
outputs=x
noise = Input(shape=(latent_dim,))
inp = model(noise)
return Model(noise, inp)
[1] GANS用于Kusner等人的具有Gumbel-softmax分布的离散元素序列。 2016
y * tau
而不是y / tau
。https://github.com/gugarosa/nalp/blob/master/nalp/models/layers/gumbel_softmax.py