在这里使用np.min
有什么问题?为什么numba不喜欢在该函数中使用列表,是否还有其他方法可以让np.min
工作?
from numba import njit
import numpy as np
@njit
def availarray(length):
out=np.ones(14)
if length>0:
out[0:np.min([int(length),14])]=0
return out
availarray(3)
该功能与min
工作正常,但np.min
应该更快...
问题是np.min
的numba版本需要array
作为输入。
from numba import njit
import numpy as np
@njit
def test_numba_version_of_numpy_min(inp):
return np.min(inp)
>>> test_numba_version_of_numpy_min(np.array([1, 2])) # works
1
>>> test_numba_version_of_numpy_min([1, 2]) # doesn't work
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function amin at 0x000001B5DBDEE598>) with argument(s) of type(s): (reflected list(int64))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
更好的解决方案是使用numba版本的Pythons min
:
from numba import njit
import numpy as np
@njit
def availarray(length):
out = np.ones(14)
if length > 0:
out[0:min(length, 14)] = 0
return out
由于np.min
和min
实际上都是这些函数的Numba版本(至少在njit
ted函数中)min
在这种情况下也应该快得多。然而,它不太可能引人注意,因为数组的分配和将一些元素设置为零将成为这里的主要运行时贡献者。
请注意,这里甚至不需要min
调用 - 因为即使使用更大的停止索引,切片也会在数组末尾隐式停止:
from numba import njit
import numpy as np
@njit
def availarray(length):
out = np.ones(14)
if length > 0:
out[0:length] = 0
return out
要使代码与numba
一起使用,您必须在NumPy数组上应用np.min
,这意味着您必须将列表[int(length),14]
转换为NumPy数组,如下所示
from numba import njit
import numpy as np
@njit
def availarray(length):
out=np.ones(14)
if length>0:
out[0:np.min(np.array([int(length),14]))]=0
return out
availarray(3)
# array([0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])