在Keras输出自定义损失

问题描述 投票:0回答:1

我知道在Keras中有很多关于处理自定义丢失功能的问题,但即使经过3小时的谷歌搜索,我也无法回答这个问题。

这是我的问题的一个非常简单的例子。我意识到这个例子毫无意义但是为了简单起见我提供它,我显然需要实现更复杂的东西。

from keras.backend import binary_crossentropy
from keras.backend import mean
def custom_loss(y_true, y_pred):

    zeros = tf.zeros_like(y_true)
    index_of_zeros = tf.where(tf.equal(zeros, y_true))
    ones = tf.ones_like(y_true)
    index_of_ones = tf.where(tf.equal(ones, y_true))

    zero = tf.gather(y_pred, index_of_zeros)
    one = tf.gather(y_pred, index_of_ones)

    loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
    loss_1 = binary_crossentropy(tf.ones_like(one), one)

    return mean(tf.concat([loss_0, loss_1], axis=0))

我不明白为什么在两类数据集上使用上述损失函数训练网络不会产生与内置binary-crossentropy损失函数训练相同的结果。谢谢!

编辑:我根据下面的评论编辑了代码片段以包含均值。然而,我仍然得到相同的行为。

python tensorflow keras loss-function cross-entropy
1个回答
0
投票

我终于弄明白了。当形状“未知”时,tf.where函数的表现非常不同。要修复上面的代码段,只需在声明函数后立即插入以下行:

y_pred = tf.reshape(y_pred, [-1])
y_true = tf.reshape(y_true, [-1])
© www.soinside.com 2019 - 2024. All rights reserved.