JAX 与 NumPy 数组索引 - 越界行为

问题描述 投票:0回答:0

给出以下使用 NumPy (np) 和 JAX 的 jax.numpy 模块 (jnp) 的 Python 代码

import numpy as np
import jax.numpy as jnp

x = np.arange(10)
#print(x[13]) # this will throw an error
x = jnp.array(x)
# indexing for single index
print(x[13]) # this will not throw an error and returns 9
print(x[-13]) # this will not throw an error and returns 0

# indexing for range of values
print(x[13:15]) # this will not throw an error but returns and empty array
print(x[-13:-15]) # this will not throw an error but returns and empty array 

我想知道 JAX 以上述方式表现的原因是什么(当为单个索引建立索引时返回 9 表示上限超出范围,返回 0 表示下限超出范围。在为一系列索引编制索引时返回空数组)而不是抛出一个我们会在 NumPy 中收到“IndexError: index out of bounds”错误?

python-3.x numpy multidimensional-array jax
© www.soinside.com 2019 - 2024. All rights reserved.