给出以下使用 NumPy (np) 和 JAX 的 jax.numpy 模块 (jnp) 的 Python 代码
import numpy as np
import jax.numpy as jnp
x = np.arange(10)
#print(x[13]) # this will throw an error
x = jnp.array(x)
# indexing for single index
print(x[13]) # this will not throw an error and returns 9
print(x[-13]) # this will not throw an error and returns 0
# indexing for range of values
print(x[13:15]) # this will not throw an error but returns and empty array
print(x[-13:-15]) # this will not throw an error but returns and empty array
我想知道 JAX 以上述方式表现的原因是什么(当为单个索引建立索引时返回 9 表示上限超出范围,返回 0 表示下限超出范围。在为一系列索引编制索引时返回空数组)而不是抛出一个我们会在 NumPy 中收到“IndexError: index out of bounds”错误?