我发现 numpy 数组索引与 ndrrray 和形状为
(1,)
的 PyTorch 张量的工作方式不同,并且想知道为什么。请看下面的案例:
import numpy as np
import torch as th
x = np.arange(10)
y = x[np.array([1])]
z = x[th.tensor([1])]
print(y, z)
y
将是 array[2]
,而 z
只是 2
。到底有什么区别?
请注意,单个元素的整数张量可以转换为索引:
>>> torch.tensor([1]).__index__()
1
>>> torch.tensor([1, 2]).__index__()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: only integer tensors of a single element can be converted to an index
当传入的索引是张量时,
ndarray
无法识别它,因此它尝试调用其__index__
方法。如果转换成功,则被视为整数:
if (PyLong_CheckExact(obj) || !PyArray_Check(obj)) {
npy_intp ind = PyArray_PyIntAsIntp(obj); // it calls PyNumber_Index() internally
if (error_converting(ind)) {
PyErr_Clear();
}
else {
index_type |= HAS_INTEGER;
indices[curr_idx].object = NULL;
indices[curr_idx].value = ind;
indices[curr_idx].type = HAS_INTEGER;
used_ndim += 1;
new_ndim += 0;
curr_idx += 1;
continue;
}
}