Keras 3 custom loss function to mask NaN values

Problem description (votes: 0, answers: 1)

I am trying to build a custom loss function in Keras 3 that works with the jax or torch backend. The idea is to mask out, from both y_true and y_pred, every index where y_true equals a specific sentinel value, and to pass the remaining values on to a given loss function.

But every time I try to fit a model with this loss on the jax or torch backend, it crashes, saying more or less that it cannot do the indexing or masking, because that would require access to the concrete values of the tensor.

I tried two variants:

import keras
from keras import Loss, ops



class NanValueLossA(Loss):
    """Variant A: drop the sentinel entries via boolean indexing."""

    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use = loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # Keep only the positions where y_true is not the sentinel value.
        valid_mask = ops.not_equal(y_true, self.nan_value)
        return self.loss_to_use(y_true[valid_mask], y_pred[valid_mask])


class NanValueLossB(Loss):
    """Variant B: gather the valid entries by index."""

    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use = loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # Indices of all positions where y_true is not the sentinel value.
        valid_mask = ops.not_equal(y_true, self.nan_value)
        valid_indices = ops.where(valid_mask)
        masked_y_pred = ops.take(y_pred, valid_indices)
        masked_y_true = ops.take(y_true, valid_indices)

        return self.loss_to_use(masked_y_true, masked_y_pred)
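
For illustration, an instantiation of either variant could look roughly like this (the wrapped MeanSquaredError and the sentinel value -999.0 are placeholders, not from the original setup):

# Illustrative only: the wrapped loss and the sentinel value -999.0
# are placeholders.
masked_loss = NanValueLossA(
    loss_to_use=keras.losses.MeanSquaredError(),
    nan_value=-999.0,
)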

I have tried both variants on the jax and torch backends, plus several other approaches, but the problem is always the same. Here are the errors:

NanValueLossA, torch backend:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

jax backend:

  File "c:....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6976, in _expand_bool_indices
    raise errors.NonConcreteBooleanIndexError(abstract_i)
jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[32,1,128,128,1])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
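
This failure is not specific to Keras: indexing with a traced boolean mask fails inside any jitted JAX function, because the shape of the result would depend on the data. A minimal standalone sketch of the same error (the helper name, the data, and the sentinel -999.0 are made up for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def masked_mse(y_true, y_pred, nan_value):
    # Boolean mask whose number of True entries is only known at run time.
    valid_mask = jnp.not_equal(y_true, nan_value)
    # Indexing with a traced boolean mask raises
    # jax.errors.NonConcreteBooleanIndexError, because the result shape
    # would be data-dependent.
    return jnp.mean((y_true[valid_mask] - y_pred[valid_mask]) ** 2)

y_true = jnp.array([1.0, -999.0, 3.0])
y_pred = jnp.array([1.5, 0.0, 2.5])
masked_mse(y_true, y_pred, -999.0)  # raises NonConcreteBooleanIndexError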

NanValueLossB, torch backend:

  File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
    x = x.to(device)
        ^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

jax backend:

  File "C:....\advanced_losses.py", line 651, in call
    valid_indices = ops.where(valid_mask)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1946, in where
    return nonzero(condition, size=size, fill_value=fill_value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2378, in nonzero
    calculated_size = core.concrete_dim_or_error(calculated_size,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at c:.....\Lib\site-packages\keras\src\backend\jax\core.py:153 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
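
The second variant hits the same root cause from a different direction: single-argument jnp.where falls back to jnp.nonzero, whose output length is data-dependent, so it needs a statically known size under jit. A minimal sketch of that failure (again with a made-up helper name, data, and sentinel):

import jax
import jax.numpy as jnp

@jax.jit
def gather_valid(y_true, nan_value):
    valid_mask = jnp.not_equal(y_true, nan_value)
    # Single-argument where() falls back to nonzero(); without a static
    # `size` argument the number of returned indices is data-dependent,
    # so tracing raises jax.errors.ConcretizationTypeError.
    (valid_indices,) = jnp.where(valid_mask)
    return jnp.take(y_true, valid_indices)

gather_valid(jnp.array([1.0, -999.0, 3.0]), -999.0)  # raises ConcretizationTypeError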

Before Keras 3 I used a TensorFlow-based loss function and it worked, but now I want something that also works with torch. Here is my TensorFlow implementation:

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K




def nan_mean_squared_error_loss(nan_value=np.nan):
    # Create a loss function
    def loss(y_true, y_pred):
        # if y_true.shape != y_pred.shape:
        #    y_true = y_true[:, :1]
        indices = tf.where(tf.not_equal(y_true, nan_value))
        return tf.keras.losses.mean_squared_error(
            tf.gather_nd(y_true, indices), tf.gather_nd(y_pred, indices)
        )

    # Return a function
    return loss
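
For reference, the factory above can be used roughly like this under the TensorFlow backend (the tiny model, the random data, and the sentinel value -999.0 are placeholders, not part of the original setup):

import numpy as np
import tensorflow as tf

# Illustrative only: model, data, and the sentinel -999.0 are placeholders.
model = tf.keras.Sequential([tf.keras.Input((4,)), tf.keras.layers.Dense(1)])
model.compile(
    optimizer="adam",
    loss=nan_mean_squared_error_loss(nan_value=-999.0),
)

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")
y[::2] = -999.0  # mark half of the targets as "missing" with the sentinel
model.fit(x, y, epochs=1, verbose=0)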
Tags: python, keras, pytorch, loss-function, keras-3
1 Answer
0 votes

Give it up, brother... there seems to be no solution!
