我有一个像这样的 Jax 数组 X:
[[[0. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]]
[[0. 0. 0.]
[0. 0. 0.]]]
如何将此数组的值设置为 1,其索引由数组 Y 给出:
[[[1 2]
[1 2]]
[[0 2]
[0 1]]
[[1 0]
[1 0]]]
所需输出:
([[[0., 1., 1.],
[0., 1., 1.]],
[[1., 0., 1.],
[1., 1., 0.]],
[[1., 1., 0.],
[1., 1., 0.]]]
有几种方法可以解决这个问题。首先让我们定义数组:
import jax
import jax.numpy as jnp
x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2],
[1, 2]],
[[0, 2],
[0, 1]],
[[1, 0],
[1, 0]]])
实现此目的的一种方法是使用典型的 numpy 风格的索引广播。它可能看起来像这样:
i = jnp.arange(3).reshape(3, 1, 1)
j = jnp.arange(2).reshape(2, 1)
x = x.at[i, j, indices].set(1)
print(x)
[[[0. 1. 1.]
[0. 1. 1.]]
[[1. 0. 1.]
[1. 1. 0.]]
[[1. 1. 0.]
[1. 1. 0.]]]
另一种选择是使用双
vmap
转换来计算批量索引:
f = jax.vmap(jax.vmap(lambda x, i: x.at[i].set(1)))
print(f(x, indices))
[[[0. 1. 1.]
[0. 1. 1.]]
[[1. 0. 1.]
[1. 1. 0.]]
[[1. 1. 0.]
[1. 1. 0.]]]