Numba抱怨输入-但正在提供所有类型

问题描述 投票:0回答:1

我在输入Numba时遇到问题-我阅读了手册,但最终碰到了砖墙。

有问题的功能是一个更大的项目的一部分-尽管它需要快速运行-Python列表是不可能的,因此我决定尝试Numba。可悲的是,该函数在nopython = True模式下失败,尽管事实上-根据我的理解-正在提供所有类型。

代码如下:

from Numba import jit, njit, uint8, int64, typeof

@jit(uint8[:,:,:](int64))
def findWhite(cropped):
    h1 = int64(0)
    for i in cropped:
        for j in i:
            if np.sum(j) == 765:
                h1 = h1 + int64(1)
            else:
                pass
    return h1

也另外:

print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64

在这种情况下,“裁剪”是一个很大的uint8 3D C矩阵(RGB tiff文件理解-PIL.Image)。有人可以向Numba新手解释我做错了什么吗?

python numba typing
1个回答
1
投票

您是否考虑过使用Numpy?这通常是Python列表和Numba之间的良好中介,例如:

h1 = (cropped.sum(axis=-1) == 765).sum()

h1 = (cropped == 255).all(axis=-1).sum()

您提供的示例代码不是有效的Numba。您的签名也不正确,因为输入是3D数组,而输出是整数,所以它可能应该是:

@njit(int64(uint8[:,:,:]))

像您一样遍历数组是无效代码。仔细翻译您的代码将是这样的:

@njit(int64(uint8[:,:,:]))
def findWhite(cropped):

    h1 = int64(0)    
    ys, xs, n_bands = cropped.shape

    for i in range(ys):
        for j in range(xs):
            if cropped[i, j, :].sum() == 765:
                h1 += 1

    return h1

但是速度不是很快,并且在我的机器上没有击败Numpy。使用Numba可以显式地遍历数组中的每个元素,这已经快得多了:

@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):

    h1 = int64(0)    
    ys, xs, zs = cropped.shape

    for i in range(ys):
        for j in range(xs):

            incr = 1
            for k in range(zs):

                if cropped[i, j, k] != 255:
                    incr = 0
                    break

            h1 += incr

    return h1

对于5000x5000x3数组,这是我的结果:

数字键(h1 = (cropped == 255).all(axis=-1).sum()):

427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

findWhite:

612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

findWhite_numba:

31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numpy方法的好处是它可以推广到任意数量的尺寸。

© www.soinside.com 2019 - 2024. All rights reserved.