有一个神经网络将两个变量作为输入:
net(x, t)
,其中x
通常是d-dim,而t
是标量。神经网络输出长度为 d
的向量。 x
和t
可能是批次,所以x
的形状是(b, d)
,t
是(b, 1)
,输出是(b,d)
。我需要找到
d out/dt
。它应该是 d 暗向量(或 (batch, d)
);的导数
d out/dx
x
,神经网络输出的散度梯度,它仍然应该是(batch, d)
向量由于神经网络不输出标量,我认为 Jax grad 在这里不会有帮助。我知道如何执行我在 torch 中描述的操作,但我对 JAX 完全陌生。我非常感谢您对这个问题的帮助!
有一个例子:
import jaxlib
import jax
from jax import numpy as jnp
import flax.linen as nn
from flax.training import train_state
class NN(nn.Module):
hid_dim : int # Number of hidden neurons
output_dim : int # Number of output neurons
@nn.compact
def __call__(self, x, t):
out = jnp.hstack((x, t))
out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
out = nn.tanh(nn.Dense(features=self.hid_dim)(out))
out = nn.Dense(features=self.output_dim)(out)
return out
d = 3
batch_size = 10
net = NN(hid_dim=100, output_dim=d)
rng_nn, rng_inp1, rng_inp2 = jax.random.split(jax.random.PRNGKey(100), 3)
inp_x = jax.random.normal(rng_inp1, (1, d)) # batch, d
inp_t = jax.random.normal(rng_inp2, (1, 1))
params_net = net.init(rng_nn, inp_x, inp_t)
x = jax.random.normal(rng_inp2, (batch_size, d)) # batch, d
t = jax.random.normal(rng_inp1, (batxh_size, 1))
out_net = net.apply(params_net, x, t)
optimizer = optax.adam(1e-3)
model_state = train_state.TrainState.create(apply_fn=net.apply,
params= params_net,
tx=optimizer)
我想根据神经网络输出根据其输入的一些导数来计算 $L_2$ 损失。例如,我想要
d f/dx
或 d f/dt
,其中 f
是神经网络。还有 x 的散度梯度。我想它会是这样的
def find_derivatives(net, params, X, t):
d_dt = lambda net, params, X, t: jax.jvp(lambda time: net(params, X, t), (t, ), (jnp.ones_like(t), ))
d_dx = lambda net, params, X, t: jax.jvp(lambda X: net(params, X, t), (Xs_all, ), (jnp.ones_like(X), ))
out_f, df_dt = d_dt(net.apply, params, X, t)
d_ddx = lambda net, params, X, t: d_dx(lambda params, X, t: d_dx(net, params, X, t)[1], params, X, t)
df_dx, df_ddx = d_ddx(net.apply, params, X, t)
return out_f, df_dt, df_dx, df_ddx
out_f, df_dt, df_dx, df_ddx = find_derivatives(net, params_net, x, t)
我会避免在这里使用
jax.jvp
,因为这是一个较低级别的 API。您可以使用 jax.jacobian
计算雅可比行列式(因为您的函数有多个输出),并使用 vmap
进行批处理。例如:
df_dx = jax.vmap(
jax.jacobian(net.apply, argnums=1),
in_axes=(None, 0, 0)
)(params_net, x, t)
print(df_dx.shape) # (10, 3, 3)
df_dt = jax.vmap(
jax.jacobian(net.apply, argnums=2),
in_axes=(None,0, 0)
)(params_net, x, t).reshape(10, 3)
print(df_dt.shape) # (10, 3)
这里
df_dx
是 3 维输出向量相对于 3 维 x 输入向量的批量雅可比行列式,df_dt
是 3 维输出向量相对于 3 维 x 输入向量的批量梯度输入t
。