I'm trying to train a model with two outputs via gradient descent, so my cost function returns two errors. What is the typical way to handle this?
I've seen this problem mentioned all over the place, but I haven't come across a satisfying solution yet.
Here is a toy example that reproduces my problem:
import jax.numpy as jnp
from jax import jit, random, grad
import optax

@jit
def my_model(forz, params):
    a, b = params
    a_vect = a + forz**b
    b_vect = b + forz**a
    return a_vect, b_vect * 50.

@jit
def rmse(predictions, targets):
    return jnp.sqrt(jnp.mean((predictions - targets) ** 2))

@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a, loss_b

grad_myloss = jit(grad(my_loss, argnums=1))

# synthetic true data
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))
true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)

# Train
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)
for i in range(1000):
    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)
I understand that the two errors must somehow be aggregated into a single one, with some kind of normalization of each loss (my output vectors have incomparable units):
@jit
def normalized_rmse(predictions, targets):
    std_dev_targets = jnp.std(targets)
    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse / std_dev_targets

@jit
def my_loss_single(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = normalized_rmse(sim_a, true_a)
    loss_b = normalized_rmse(sim_b, true_b)
    return jnp.sqrt(loss_a ** 2 + loss_b ** 2)
Or should I somehow use the Jacobian (jacrev) instead?
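As a side note on the jacrev idea: unlike jax.grad, which requires a scalar-valued function, jax.jacrev does accept a function with a tuple output and returns one gradient per output. A minimal sketch (the two_losses toy function below is made up for illustration, not part of the original code):

```python
import jax.numpy as jnp
from jax import jacrev

def two_losses(params):
    # two independent scalar losses as a tuple
    a, b = params
    return a ** 2, (b - 1.0) ** 2

# jacrev differentiates each output separately,
# yielding a tuple of per-loss gradients
grads = jacrev(two_losses)(jnp.array([3.0, 2.0]))
print(grads[0])  # gradient of the first loss:  [6., 0.]
print(grads[1])  # gradient of the second loss: [0., 2.]
```

This gives you both gradient vectors, but it does not by itself tell the optimizer how to combine them into a single update.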
optax can only optimize scalar-valued loss functions, so you need to decide what single-valued loss makes sense for your particular problem. Given the RMS form of the individual losses, a reasonable choice might be the sum of squares:
@jit
def my_loss(forz, params, true_a, true_b):
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    return loss_a ** 2 + loss_b ** 2
With this change, your code executes without errors.