根据另一个数组中的值更新 JAX 数组

问题描述 投票:0回答:1

我有一个像这样的 Jax 数组 X:

[[[0. 0. 0.]
 [0. 0. 0.]]

 [[0. 0. 0.]
 [0. 0. 0.]]

 [[0. 0. 0.]
 [0. 0. 0.]]]

如何将此数组的值设置为 1,其索引由数组 Y 给出:

[[[1 2]
 [1 2]]

 [[0 2]
 [0 1]]

 [[1 0]
 [1 0]]]

所需输出:

([[[0., 1., 1.],
        [0., 1., 1.]],

       [[1., 0., 1.],
        [1., 1., 0.]],

       [[1., 1., 0.],
        [1., 1., 0.]]]
jax
1个回答
0
投票

有几种方法可以解决这个问题。首先让我们定义数组:

import jax
import jax.numpy as jnp

x = jnp.zeros((3, 2, 3))
indices = jnp.array([[[1, 2],
                      [1, 2]],
                     [[0, 2],
                      [0, 1]],
                     [[1, 0],
                      [1, 0]]])

实现此目的的一种方法是使用典型的 numpy 风格的索引广播。它可能看起来像这样:

i = jnp.arange(3).reshape(3, 1, 1)
j = jnp.arange(2).reshape(2, 1)
x = x.at[i, j, indices].set(1)
print(x)
[[[0. 1. 1.]
  [0. 1. 1.]]

 [[1. 0. 1.]
  [1. 1. 0.]]

 [[1. 1. 0.]
  [1. 1. 0.]]]

另一种选择是使用双

vmap
转换来计算批量索引:

f = jax.vmap(jax.vmap(lambda x, i: x.at[i].set(1)))
print(f(x, indices))
[[[0. 1. 1.]
  [0. 1. 1.]]

 [[1. 0. 1.]
  [1. 1. 0.]]

 [[1. 1. 0.]
  [1. 1. 0.]]]
© www.soinside.com 2019 - 2024. All rights reserved.