I am trying to build a custom loss function in Keras 3 that will be used with the jax or torch backend. I want to mask out, from both y_pred and y_true, every index where y_true equals a specific value, and pass only the remaining values to a given loss function.
But every time I try to fit a model with this loss under the jax or torch backend, it crashes, essentially saying that it cannot take the indices or apply the mask, because doing so would require access to the concrete values of the tensor.
I tried two approaches:
import keras
from keras import Loss, ops


class NanValueLossA(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use = loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # Boolean-mask the positions where y_true equals the sentinel value.
        valid_mask = ops.not_equal(y_true, self.nan_value)
        return self.loss_to_use(y_true[valid_mask], y_pred[valid_mask])
class NanValueLossB(Loss):
    def __init__(
        self,
        loss_to_use=None,
        nan_value=None,
        name="nan_value_loss",
        **kwargs,
    ):
        self.nan_value = nan_value
        self.loss_to_use = loss_to_use
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # Same idea, but via explicit indices instead of boolean indexing.
        valid_mask = ops.not_equal(y_true, self.nan_value)
        valid_indices = ops.where(valid_mask)
        masked_y_pred = ops.take(y_pred, valid_indices)
        masked_y_true = ops.take(y_true, valid_indices)
        return self.loss_to_use(masked_y_true, masked_y_pred)
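For reference, this is roughly how I wire the loss into training. The model, data, and nan_value below are placeholders for illustration only, not my real setup:
import numpy as np
import keras

# Toy model and data, only to show how the loss is attached; the real ones are omitted.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(
    optimizer="adam",
    loss=NanValueLossA(loss_to_use=keras.losses.mean_squared_error, nan_value=-1.0),
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
y[::4] = -1.0  # entries the loss is supposed to ignore
model.fit(x, y, epochs=1)  # this call triggers the errors below under jax/torch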
I have tried both variants with jax and with torch. I have also tried several other approaches, but the problem is always the same. Here are the errors:
NanValueLossA with the torch backend:
File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
x = x.to(device)
^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
With the jax backend:
File "c:....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 6976, in _expand_bool_indices
raise errors.NonConcreteBooleanIndexError(abstract_i)
jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[32,1,128,128,1])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
NanValueLossB with the torch backend:
File "c:\....\Lib\site-packages\keras\src\backend\torch\core.py", line 162, in convert_to_tensor
x = x.to(device)
^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
With the jax backend:
File "C:....\advanced_losses.py", line 651, in call
valid_indices = ops.where(valid_mask)
^^^^^^^^^^^^^^^^^^^^^
File "....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 1946, in where
return nonzero(condition, size=size, fill_value=fill_value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".....\Lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2378, in nonzero
calculated_size = core.concrete_dim_or_error(calculated_size,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
The error occurred while tracing the function wrapped_fn at c:.....\Lib\site-packages\keras\src\backend\jax\core.py:153 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
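The only shape-static variant I have come up with so far is something like the sketch below: it avoids boolean indexing entirely by zeroing out the invalid positions and normalizing by the number of valid entries. I have not verified it on my real setup, and it only works for element-wise losses like MSE, so it does not solve the general case of passing an arbitrary loss_to_use:
from keras import ops

def masked_mse(y_true, y_pred, nan_value=-1.0):
    # Keep every shape static: instead of gathering the valid entries,
    # replace the invalid positions with zeros before reducing.
    valid_mask = ops.not_equal(y_true, nan_value)
    diff = ops.where(valid_mask, y_true - y_pred, 0.0)
    n_valid = ops.sum(ops.cast(valid_mask, "float32"))
    # Average the squared error over valid entries only (guard against an all-masked batch).
    return ops.sum(ops.square(diff)) / ops.maximum(n_valid, 1.0)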
Before Keras 3, I used a TensorFlow-based loss function and it worked, but now I want something that also works with torch. Here is my TensorFlow implementation:
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


def nan_mean_squared_error_loss(nan_value=np.nan):
    # Create a loss function
    def loss(y_true, y_pred):
        # if y_true.shape != y_pred.shape:
        #     y_true = y_true[:, :1]
        indices = tf.where(tf.not_equal(y_true, nan_value))
        return tf.keras.losses.mean_squared_error(
            tf.gather_nd(y_true, indices), tf.gather_nd(y_pred, indices)
        )

    # Return a function
    return loss
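Standalone, the TensorFlow version does what I want (toy values, just to illustrate the behaviour):
y_true = tf.constant([[1.0], [-1.0], [3.0]])
y_pred = tf.constant([[1.5], [9.9], [2.0]])
loss_fn = nan_mean_squared_error_loss(nan_value=-1.0)
print(loss_fn(y_true, y_pred).numpy())  # MSE over the two valid rows; the -1.0 row is ignored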
Give up, bro... this seems unsolvable!