Numba 中重塑和类型转换之间的奇怪交互

问题描述 投票:0回答:1

我注意到,当我将 0 和 1 的二维数组传递到 Numba njit 函数中,对其进行整形,然后将其转换为 np.int32 或 numba.int32 时,打印时生成的数组是不同的。

这是示例代码:

import numpy as np
from numba import njit

array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T

num_cols = array_2d.shape[1]
num_rows = array_2d.shape[0]

@njit
def f(array, num_rows, num_cols):
    pairs = array.reshape(num_rows // 2, 2, num_cols)

    pairs_cast = pairs.astype(numba.int32)

    return pairs, pairs_cast

pairs, pairs_cast = f(array_2d, num_rows, num_cols)

print("Pairs:")
print(pairs)
print("\nPairs cast to int32:")
print(pairs_cast)

输出为:

Pairs:
[[[0 0 0]
  [1 1 1]]

 [[0 1 1]
  [1 0 0]]

 [[1 1 1]
  [0 1 1]]

 [[0 0 0]
  [1 0 0]]]

Pairs cast to int32:
[[[0 0 0]
  [1 1 1]]

 [[1 1 1]
  [0 1 1]]

 [[0 1 1]
  [0 0 0]]

 [[1 0 0]
  [1 0 0]]]

很想知道这里发生了什么。

python numba jit
1个回答
0
投票

正如评论中提到的,reshape(在 numba 中)当前仅支持连续数组,但似乎不会在这里触发错误。 如果您在 numba 编译代码中将数组转换为不连续的数组,您会看到错误:

import numpy as np
import numba as nb

@nb.njit
def func(a):
    return a.transpose(1, 0).reshape(-1)

func(np.ones((10, 10)))

加薪:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
reshape() supports contiguous array only
- Resolution failure for non-literal arguments:
reshape() supports contiguous array only
© www.soinside.com 2019 - 2024. All rights reserved.