在 JAX 中根据其输入批量获取神经网络的导数

问题描述 投票:0回答:1

有一个神经网络将两个变量作为输入:

net(x, t)
,其中
x
通常是d-dim,而
t
是标量。神经网络输出长度为
d
的向量。
x
t
可能是批次,所以
x
的形状是
(b, d)
t
(b, 1)
,输出是
(b,d)
。我需要找到

  • 神经网络输出的导数
    d out/dt
    。它应该是 d 暗向量(或
    (batch, d)
    );
  • 神经网络
    的导数
    d out/dx
  • 根据
    x
    ,神经网络输出的散度梯度,它仍然应该是
    (batch, d)
    向量

由于神经网络不输出标量,我认为 Jax grad 在这里不会有帮助。我知道如何执行我在 torch 中描述的操作,但我对 JAX 完全陌生。我非常感谢您对这个问题的帮助!

有一个例子:

import jaxlib
import jax
from jax import numpy as jnp
import flax.linen as nn
from flax.training import train_state



class NN(nn.Module):
    hid_dim : int # Number of hidden neurons
    output_dim : int # Number of output neurons

    @nn.compact  
    def __call__(self, x, t):
        out = jnp.hstack((x, t))
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
        out = nn.Dense(features=self.output_dim)(out)
        return out

d = 3
batch_size = 10
net = NN(hid_dim=100, output_dim=d)

rng_nn, rng_inp1, rng_inp2 = jax.random.split(jax.random.PRNGKey(100), 3)
inp_x = jax.random.normal(rng_inp1, (1, d)) # batch, d
inp_t = jax.random.normal(rng_inp2, (1, 1))
params_net = net.init(rng_nn, inp_x, inp_t)

x = jax.random.normal(rng_inp2, (batch_size, d)) # batch, d
t = jax.random.normal(rng_inp1, (batxh_size, 1))

out_net = net.apply(params_net, x, t)

optimizer = optax.adam(1e-3)

model_state = train_state.TrainState.create(apply_fn=net.apply,
                                            params= params_net,
                                            tx=optimizer)

我想根据神经网络输出根据其输入的一些导数来计算 $L_2$ 损失。例如,我想要

d f/dx
d f/dt
,其中
f
是神经网络。还有 x 的散度梯度。我想它会是这样的

def find_derivatives(net, params, X, t):
    d_dt = lambda net, params, X, t: jax.jvp(lambda time: net(params, X, t), (t, ), (jnp.ones_like(t), ))
    d_dx = lambda net, params, X, t: jax.jvp(lambda X: net(params, X, t), (Xs_all, ), (jnp.ones_like(X), ))
    out_f, df_dt = d_dt(net.apply, params, X, t)

    d_ddx = lambda net, params, X, t: d_dx(lambda params, X, t: d_dx(net, params, X, t)[1], params, X, t)
    df_dx, df_ddx = d_ddx(net.apply, params, X, t)
    
    return out_f, df_dt, df_dx, df_ddx


out_f, df_dt, df_dx, df_ddx = find_derivatives(net, params_net, x, t)
python deep-learning jax autograd flax
1个回答
0
投票

我会避免在这里使用

jax.jvp
,因为这是一个较低级别的 API。您可以使用
jax.jacobian
计算雅可比行列式(因为您的函数有多个输出),并使用
vmap
进行批处理。例如:

df_dx = jax.vmap(
    jax.jacobian(net.apply, argnums=1),
    in_axes=(None, 0, 0)
  )(params_net, x, t)
print(df_dx.shape)  # (10, 3, 3)

df_dt = jax.vmap(
    jax.jacobian(net.apply, argnums=2),
    in_axes=(None,0, 0)
  )(params_net, x, t).reshape(10, 3)
print(df_dt.shape)  # (10, 3)

这里

df_dx
是 3 维输出向量相对于 3 维 x 输入向量的批量雅可比行列式,
df_dt
是 3 维输出向量相对于 3 维 x 输入向量的批量梯度输入
t

© www.soinside.com 2019 - 2024. All rights reserved.