我在输入Numba时遇到问题-我阅读了手册,但最终碰到了砖墙。
有问题的功能是一个更大的项目的一部分-尽管它需要快速运行-Python列表是不可能的,因此我决定尝试Numba。可悲的是,该函数在nopython = True模式下失败,尽管事实上-根据我的理解-正在提供所有类型。
代码如下:
from Numba import jit, njit, uint8, int64, typeof
@jit(uint8[:,:,:](int64))
def findWhite(cropped):
h1 = int64(0)
for i in cropped:
for j in i:
if np.sum(j) == 765:
h1 = h1 + int64(1)
else:
pass
return h1
也另外:
print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64
在这种情况下,“裁剪”是一个很大的uint8 3D C矩阵(RGB tiff文件理解-PIL.Image)。有人可以向Numba新手解释我做错了什么吗?
您是否考虑过使用Numpy?这通常是Python列表和Numba之间的良好中介,例如:
h1 = (cropped.sum(axis=-1) == 765).sum()
或
h1 = (cropped == 255).all(axis=-1).sum()
您提供的示例代码不是有效的Numba。您的签名也不正确,因为输入是3D数组,而输出是整数,所以它可能应该是:
@njit(int64(uint8[:,:,:]))
像您一样遍历数组是无效代码。仔细翻译您的代码将是这样的:
@njit(int64(uint8[:,:,:]))
def findWhite(cropped):
h1 = int64(0)
ys, xs, n_bands = cropped.shape
for i in range(ys):
for j in range(xs):
if cropped[i, j, :].sum() == 765:
h1 += 1
return h1
但是速度不是很快,并且在我的机器上没有击败Numpy。使用Numba可以显式地遍历数组中的每个元素,这已经快得多了:
@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):
h1 = int64(0)
ys, xs, zs = cropped.shape
for i in range(ys):
for j in range(xs):
incr = 1
for k in range(zs):
if cropped[i, j, k] != 255:
incr = 0
break
h1 += incr
return h1
对于5000x5000x3数组,这是我的结果:
数字键(h1 = (cropped == 255).all(axis=-1).sum()
):
427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite:
612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite_numba:
31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numpy方法的好处是它可以推广到任意数量的尺寸。