优化两个 4D 张量的像素值损失

问题描述 投票:0回答:1

我一直在尝试实现一个损失函数(使用tensorflow/keras)来根据我发现有用的特定论文来预测方向图。作者通过预测每个像素(在输出的每个通道上)的正弦和余弦值,然后使用以下函数获得距离测量来实现此目的

θ^(1+δ) = (arccos(cosα · cosβ + sinα · sinβ))^(1+δ)

为了完整起见,其梯度为 (1 + δ) · θ^δ,其中 δ=0.2

鉴于 alpha 是 y_true,beta 是 y_pred,并且这些张量具有形状(批量、高度、宽度、通道),我设计了一个使用嵌套 for 循环的实现,它可能有效,但尚未优化,我不确定 Keras 是否由于我对 ML 的经验很少,因此能够反向传播它。

我想知道是否有比当前代码(如下)更好的优化方法来实现这一点,因为我在任何地方都找不到它,所以我在这里写下我的第一个问题。 min 和 max 函数用于将值剪辑到区间 [10^-6, 1-10^-6]

def angle_distance_loss(y_true,y_pred):
    """
    Lproposed = (arccos(cosα · cosβ + sinα · sinβ))^(1+δ)
    """

    batch, height, width, channels = y_true.shape
    cos_c = 0
    sin_c = 1
    for batch_i in range(batch):
        for h_j in range(height):
            for w_k in range(width):
                yt_cos = y_true[batch_i][h_j][w_k][cos_c]
                yt_sin = y_true[batch_i][h_j][w_k][sin_c]

                yp_cos = y_pred[batch_i][h_j][w_k][cos_c]
                yp_sin = y_pred[batch_i][h_j][w_k][sin_c]

                l += math.acos(max(10**-6, min((yt_cos * yp_cos + yt_sin * yp_sin), 1-10**-6))) ** (1.2)
                
    return l / batch * width * height

欢迎对此提出任何意见:)

numpy tensorflow keras optimization tensor
1个回答
0
投票

这看起来是一个好的开始。但正如您所提到的,嵌套循环和您使用的函数并不是真正有效。相反,您应该使用 TensorFlow 的内置向量运算:

  • tf.clip_by_value
    而不是
    min
    /
    max
  • tf.acos
    而不是
    math.acos
  • tf.reduce_sum
    tf.reduce_mean
    而不是总结和 除以尺寸

这些操作是完全可微的,因此反向传播应该没有问题。

这是一个更优化的实现:

def angle_distance_loss(y_true: tf.Tensor, y_pred: tf.Tensor, delta: float = 0.2, epsilon: float = 1e-6) -> tf.Tensor:
    """
    Lproposed = (arccos(cosα · cosβ + sinα · sinβ))^(1+δ)
    """
    dot_product = tf.clip_by_value(
        tf.reduce_sum(y_true * y_pred, axis=-1), 
        clip_value_min=epsilon, 
        clip_value_max=1-epsilon
    )  # Compute the dot product of the cosine and sine components
    angle_distance = tf.acos(dot_product) ** (1 + delta)  # Compute the angle distance
    return tf.reduce_mean(angle_distance)  # Return the mean loss over the batch

我没有对此进行大量测试,因此在依赖它之前,您应该进行验证(例如,通过同时计算两个实现并通过调试或打印检查相等性)。

© www.soinside.com 2019 - 2024. All rights reserved.