我注意到,当我将 0 和 1 的二维数组传递到 Numba njit 函数中,对其进行整形,然后将其转换为 np.int32 或 numba.int32 时,打印时生成的数组是不同的。
这是示例代码:
import numpy as np
from numba import njit
array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
[0, 1, 1, 0, 1, 1, 0, 0],
[0, 1, 1, 0, 1, 1, 0, 0]]).T
num_cols = array_2d.shape[1]
num_rows = array_2d.shape[0]
@njit
def f(array, num_rows, num_cols):
pairs = array.reshape(num_rows // 2, 2, num_cols)
pairs_cast = pairs.astype(numba.int32)
return pairs, pairs_cast
pairs, pairs_cast = f(array_2d, num_rows, num_cols)
print("Pairs:")
print(pairs)
print("\nPairs cast to int32:")
print(pairs_cast)
输出为:
Pairs:
[[[0 0 0]
[1 1 1]]
[[0 1 1]
[1 0 0]]
[[1 1 1]
[0 1 1]]
[[0 0 0]
[1 0 0]]]
Pairs cast to int32:
[[[0 0 0]
[1 1 1]]
[[1 1 1]
[0 1 1]]
[[0 1 1]
[0 0 0]]
[[1 0 0]
[1 0 0]]]
很想知道这里发生了什么。
正如评论中提到的,reshape(在 numba 中)当前仅支持连续数组,但似乎不会在这里触发错误。 如果您在 numba 编译代码中将数组转换为不连续的数组,您会看到错误:
import numpy as np
import numba as nb
@nb.njit
def func(a):
return a.transpose(1, 0).reshape(-1)
func(np.ones((10, 10)))
加薪:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
reshape() supports contiguous array only
- Resolution failure for non-literal arguments:
reshape() supports contiguous array only