对使用非恒等向量 (JAX) 评估向量雅可比积感到困惑

问题描述 投票:0回答:0

当用于 VJP 的向量是非同一行向量时,我对评估向量-雅可比积的含义感到困惑。我的问题与矢量值函数有关,而不是像损失这样的标量函数。我将展示一个使用 Python 和 JAX 的具体示例,但这是一个关于反向模式自动微分的非常普遍的问题。

考虑这个简单的向量值函数,它的雅可比行列式很容易分析地写下来:

from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vjp, jacrev

# Define a vector-valued function (3 inputs --> 2 outputs) 
def vector_func(args):
    x,y,z = args
    a = 2*x**2 + 3*y**2 + 4*z**2
    b = 4*x*y*z
    return jnp.array([a, b])

# Define the inputs
x = 2.0
y = 3.0
z = 4.0

# Compute the vector-Jacobian product at the fiducial input point (x,y,z)
val, func_vjp = vjp(vector_func, (x, y, z))

print(val) 
# [99,96]

# now evaluate the function returned by vjp along with basis row vectors to pull out gradient of 1st and 2nd output components 
v1 = jnp.array([1.0, 0.0])  # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., first row of Jacobian
v2 = jnp.array([0.0, 1.0])  # pulls out the gradient of the 1st component wrt the 3 inputs, i.e., second row of Jacobian 

gradient1 = func_vjp(v1)
print(gradient1)
# [8, 18, 32]

gradient2 = func_vjp(v2)
print(gradient2)
# [48,32,24]

这对我来说很有意义——我们分别将 [1,0] 和 [0,1] 馈送到 vjp_func 以分别获得在我们的基准点 (x,y,z) 评估的雅可比行列式的第一行和第二行=(2,3,4).

但是现在如果我们给 vjp_func 一个像 [2,0] 这样的非恒等行向量呢?这是在询问如何需要扰动基准 (x,y,z) 以使输出的第一个分量加倍吗?如果是这样,有没有办法通过在扰动参数值处评估 vector_func 来看到这一点?

我试过但我不确定:

# suppose I want to know what perturbations in (x,y,z) cause a doubling of the first output and no change in second output component 
print(func_vjp(jnp.array([2.0,0.0])))
# [16,36,64] 

### Attempts to use the output of vjp_func to verify that val becomes [99*2, 96]
### none of these work

print(vector_func([16,36,64]))
# [20784, 147456]

print(vector_func([x*16,y*36,z*64])
# [299184., 3538944.]

我在使用 func_vjp 的输出修改基准参数 (x,y,z) 并将其反馈给 vector_func 以验证这些参数扰动确实将原始输出的第一个分量加倍并保留第二个分量时做错了什么不变?

julia derivative jax automatic-differentiation autodiff
© www.soinside.com 2019 - 2024. All rights reserved.