numba和numpy.expand_dims

问题描述 投票:-1回答:2

我正在重写我的一些功能,以适合Numba。现在我有一个函数,我在我的脚本中多次调用不同维度的输入数组。

def FormHistMatrix2(x,Whc,Lm):
    if x.ndim == 1:
       x = np.expand_dims(x,axis=1)
    [N,Ncells] = x.shape

这是我的功能的开始,Numba抛出以下错误:

TypingError: Cannot unify array(float64, 2d, A) and array(float64, 3d, A) for 'x', defined at C:/Users/DNP_Student_3/Documents/Python Scripts/GCFuncsTests.py (332)

在这种情况下,'x'是2D阵列,但在其他情况下,它可以是1-D阵列。所以Numba不喜欢if循环吗?或者这里发生了什么?

python numba
2个回答
1
投票

在Numba中,与标准python不同,变量在执行函数期间不能更改其类型。您应该能够将调用结果分配给np.expand_dims到另一个变量,它将起作用。如果有时x是1d并且有时它是2d,只要在函数执行过程中所有变量的类型一致,这是可以的。


0
投票

JoshAdel所说的一般是正确的,但在这种情况下的问题是你需要根据输入类型对函数进行不同的实现/特化。

Numba有这个案件的@generated_jit-decorator

在您的情况下,您需要编写一个专门的expand-dims函数,该函数取决于输入数组的维度:

import numba as nb
@nb.generated_jit(nopython=True)
def nb_expander(x):
    if x.ndim == 1:
        return lambda x: np.expand_dims(x, axis=1)
    else:
        return lambda x: x

需要从其他函数中调用此函数:

@nb.njit
def FormHistMatrix2(x, Whc, Lm):
    x = nb_expander(x)
    [N, Ncells] = x.shape

这将适用于尺寸为1和2的x。对于x.ndim==3,您还需要为形状实现类似的方法。

© www.soinside.com 2019 - 2024. All rights reserved.