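For performance reasons, I want to convert a function to Numba. My MWE is below. If I remove the @njit decorator, the code works, but with @njit I get an exception when it runs. The exception is most likely caused by using dtype=object to define result_arr, but I also tried dtype=float64 and got a similar exception.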
import numpy as np
from numba import njit
from timeit import timeit
######-----------Required NUMBA function----------###
#@njit #<----without this, the code works
def required_numba_function():
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = np.empty((151 * 151 * 24), dtype=object)
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                size_rows = np.random.randint(8, 15)
                size_cols = np.random.randint(2, 6)
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                result_arr[flat_idx] = args
    return result_arr
######------------------main()-------##################
if __name__ == "__main__":
    required_numba_function()
    print()
How can I fix the Numba exception?
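For context, the two attempts fail for different reasons. Numba's nopython mode does not support arrays with dtype=object at all, so that version cannot compile. A float64 array fails even in plain NumPy, because each slot holds a single scalar and cannot take a whole (size_rows, size_cols) block. The small check below shows the plain-NumPy side of this; try_store is a hypothetical helper written only for illustration.

```python
import numpy as np

def try_store(dtype):
    """Try to put one 2-D block into a single slot of a 1-D array."""
    container = np.empty(3, dtype=dtype)
    block = np.ones((2, 2))
    try:
        container[0] = block
        return True
    except ValueError:
        # A float64 slot holds exactly one scalar, so a (2, 2) block cannot fit
        return False

print(try_store(object))      # True: an object slot can reference any Python object
print(try_store(np.float64))  # False: "setting an array element with a sequence"
```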
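As you said, a list of arrays is fine, so you can replace the allocation of result_arr as an empty array with dtype=object by an empty list, and append to it on each iteration. This is compatible with Numba: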
import numpy as np
import numba as nb

@nb.njit
def required_numba_function2():
    # For testing only: Numba keeps its own random state, separate from NumPy's,
    # so the seed has to be set inside the jitted function to take effect
    np.random.seed(0)
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = []  # list of 2-D float64 arrays instead of an object array
    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):
                size_rows = np.random.randint(8, 15)
                size_cols = np.random.randint(2, 6)
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                result_arr.append(args)
    return result_arr
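The list version no longer computes flat_idx, but because it appends in the same frame, row, col loop order, position k in the returned list holds the block that the original stored at flat_idx == k. A sketch of that mapping and its inverse, using hypothetical helper names and assuming the 151 x 151 grid from the question:

```python
def flat_index(frame, row, col, n_rows=151, n_cols=151):
    """Same formula as flat_idx in the question."""
    return frame * (n_rows * n_cols) + row * n_cols + col

def unflat_index(idx, n_rows=151, n_cols=151):
    """Inverse of flat_index: recover (frame, row, col) from a list position."""
    frame, rem = divmod(idx, n_rows * n_cols)
    row, col = divmod(rem, n_cols)
    return frame, row, col

print(flat_index(1, 2, 3))  # 23106
print(unflat_index(23106))  # (1, 2, 3)
```

So result2[flat_index(frame, row, col)] retrieves the block for a given frame, row and column.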
Test:
np.random.seed(0)
result = required_numba_function()
result2 = required_numba_function2()
for i, j in zip(result, result2):
    assert np.allclose(i, j)
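One caveat about this test: zip() stops silently at the shorter sequence, and np.allclose may broadcast or raise on blocks of different shapes. A slightly stricter comparison, sketched here as a hypothetical helper, checks the length and every shape first:

```python
import numpy as np

def blocks_match(a, b):
    """Compare two sequences of 2-D arrays element by element."""
    if len(a) != len(b):  # zip() alone would ignore extra trailing blocks
        return False
    # Check shapes before values so mismatched blocks never get broadcast
    return all(x.shape == y.shape and np.allclose(x, y) for x, y in zip(a, b))
```

With it, the loop above becomes assert blocks_match(result, result2).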
Timings:
%timeit required_numba_function()
%timeit required_numba_function2()
2.08 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
606 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
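The %timeit magic only works in IPython. Outside it, the timeit module that the question already imports can do the same job. The sketch below (time_function is a hypothetical helper) makes one warm-up call first, because the first call of an @njit function triggers compilation, which should not be counted in the runtime.

```python
from timeit import timeit

def time_function(func, number=3):
    """Average seconds per call of func, measured after one warm-up call."""
    func()  # warm-up: for an @njit function this compiles it
    return timeit(func, number=number) / number
```

For example, time_function(required_numba_function2) should report a per-call time comparable to the %timeit figure above, without the compilation cost.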