如何在 JAX 中使用具有多输出(向量值)损失函数的梯度下降来训练模型?

问题描述 投票:0回答:1

我正在尝试训练一个具有两个梯度下降输出的模型。因此,我的成本函数返回两个错误。处理这个问题的典型方法是什么?

我到处都看到提到这个问题,但我还没有想出一个令人满意的解决方案。

这是一个重现我的问题的玩具示例:

from jax import jit, random, grad
import optax


@jit
def my_model(forz, params):
    a, b = params

    a_vect = a + forz**b
    b_vect = b + forz**a

    return a_vect, b_vect*50.


@jit
def rmse(predictions, targets):

    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse


@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a, loss_b


grad_myloss = jit(grad(my_loss, argnums=1))

# synthetic true data
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))

true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)

# Train
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)

for i in range(1000):

    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)

我知道这两个错误必须以某种方式聚合为一个错误,以实现某种损失的标准化(我的输出向量具有不可比较的单位),

@jit
def normalized_rmse(predictions, targets):
   std_dev_targets = jnp.std(targets)
   rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
   return rmse/std_dev_targets


@jit
def my_loss_single(forz, params, true_a, true_b):

   sim_a, sim_b = my_model(forz, params)

   loss_a = normalized_rmse(sim_a, true_a)
   loss_b = normalized_rmse(sim_b, true_b)

   return jnp.sqrt((loss_a ** 2) + (loss_b * 2)) 

或者我应该以某种方式使用雅可比矩阵(

jacrev
)?

python machine-learning loss-function jax
1个回答
0
投票
与大多数优化框架一样,

optax
只能优化单值损失函数。您应该决定什么单值损失对您的特定问题有意义。考虑到个人损失的 RMS 形式,一个不错的选择可能是平方和:

@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a ** 2 + loss_b ** 2

进行此更改后,您的代码执行时不会出现错误。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.