Numba njit 使涉及长表达式的函数调用极其缓慢

问题描述 投票:0回答:1

我正在编写一个有限体积代码来求解无粘性、可压缩欧拉方程。作为其中的一部分,我正在执行所谓的柯西-科瓦列夫斯卡亚过程。下面给出了代码片段

def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
    if k == 0:
        return np.array([drho_x[0], dq_x[0], dE_x[0]], dtype=np.float64)
    
    elif k == 1:
        rho_t = -dq_x[1]

        q_t = -dq_x[0]**2*drho_x[1]*gamma/(2*drho_x[0]**2) + 3*dq_x[0]**2*drho_x[1]/(2*drho_x[0]**2) + dq_x[0]*dq_x[1]*gamma/drho_x[0] - 3*dq_x[0]*dq_x[1]/drho_x[0] - dE_x[1]*gamma + dE_x[1]

        E_t = -dq_x[0]**3*drho_x[1]*gamma/drho_x[0]**3 + dq_x[0]**3*drho_x[1]/drho_x[0]**3 + 3*dq_x[0]**2*dq_x[1]*gamma/(2*drho_x[0]**2) - 3*dq_x[0]**2*dq_x[1]/(2*drho_x[0]**2) - dq_x[0]*dE_x[1]*gamma/drho_x[0] + dq_x[0]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**2 - dq_x[1]*dE_x[0]*gamma/drho_x[0]

        return np.array([rho_t, q_t, E_t], dtype=np.float64)
    
    elif k == 2:
        

        rho_tt = dq_x[0]**2*drho_x[2]*gamma/(2*drho_x[0]**2) - 3*dq_x[0]**2*drho_x[2]/(2*drho_x[0]**2) - dq_x[0]**2*drho_x[1]**2*gamma/drho_x[0]**3 + 3*dq_x[0]**2*drho_x[1]**2/drho_x[0]**3 + 2*dq_x[0]*dq_x[1]*drho_x[1]*gamma/drho_x[0]**2 - 6*dq_x[0]*dq_x[1]*drho_x[1]/drho_x[0]**2 - dq_x[0]*dq_x[2]*gamma/drho_x[0] + 3*dq_x[0]*dq_x[2]/drho_x[0] - dq_x[1]**2*gamma/drho_x[0] + 3*dq_x[1]**2/drho_x[0] + dE_x[2]*gamma - dE_x[2]

        q_tt = dq_x[0]**3*drho_x[2]*gamma**2/(2*drho_x[0]**3) + dq_x[0]**3*drho_x[2]*gamma/drho_x[0]**3 - 7*dq_x[0]**3*drho_x[2]/(2*drho_x[0]**3) - 3*dq_x[0]**3*drho_x[1]**2*gamma**2/(2*drho_x[0]**4) - 3*dq_x[0]**3*drho_x[1]**2*gamma/drho_x[0]**4 + 21*dq_x[0]**3*drho_x[1]**2/(2*drho_x[0]**4) + 5*dq_x[0]**2*dq_x[1]*drho_x[1]*gamma**2/(2*drho_x[0]**3) + 8*dq_x[0]**2*dq_x[1]*drho_x[1]*gamma/drho_x[0]**3 - 45*dq_x[0]**2*dq_x[1]*drho_x[1]/(2*drho_x[0]**3) - dq_x[0]**2*dq_x[2]*gamma**2/(2*drho_x[0]**2) - 5*dq_x[0]**2*dq_x[2]*gamma/(2*drho_x[0]**2) + 6*dq_x[0]**2*dq_x[2]/drho_x[0]**2 - dq_x[0]*dq_x[1]**2*gamma**2/drho_x[0]**2 - 5*dq_x[0]*dq_x[1]**2*gamma/drho_x[0]**2 + 12*dq_x[0]*dq_x[1]**2/drho_x[0]**2 + 3*dq_x[0]*dE_x[2]*gamma/drho_x[0] - 3*dq_x[0]*dE_x[2]/drho_x[0] - dq_x[0]*drho_x[1]*dE_x[1]*gamma**2/drho_x[0]**2 - 2*dq_x[0]*drho_x[1]*dE_x[1]*gamma/drho_x[0]**2 + 3*dq_x[0]*drho_x[1]*dE_x[1]/drho_x[0]**2 - dq_x[0]*drho_x[2]*dE_x[0]*gamma**2/drho_x[0]**2 + dq_x[0]*drho_x[2]*dE_x[0]*gamma/drho_x[0]**2 + 2*dq_x[0]*drho_x[1]**2*dE_x[0]*gamma**2/drho_x[0]**3 - 2*dq_x[0]*drho_x[1]**2*dE_x[0]*gamma/drho_x[0]**3 + dq_x[1]*dE_x[1]*gamma**2/drho_x[0] + 2*dq_x[1]*dE_x[1]*gamma/drho_x[0] - 3*dq_x[1]*dE_x[1]/drho_x[0] - 2*dq_x[1]*drho_x[1]*dE_x[0]*gamma**2/drho_x[0]**2 + 2*dq_x[1]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**2 + dq_x[2]*dE_x[0]*gamma**2/drho_x[0] - dq_x[2]*dE_x[0]*gamma/drho_x[0]

        E_tt = dq_x[0]**4*drho_x[2]*gamma**2/(4*drho_x[0]**4) + 2*dq_x[0]**4*drho_x[2]*gamma/drho_x[0]**4 - 9*dq_x[0]**4*drho_x[2]/(4*drho_x[0]**4) - dq_x[0]**4*drho_x[1]**2*gamma**2/drho_x[0]**5 - 8*dq_x[0]**4*drho_x[1]**2*gamma/drho_x[0]**5 + 9*dq_x[0]**4*drho_x[1]**2/drho_x[0]**5 + dq_x[0]**3*dq_x[1]*drho_x[1]*gamma**2/drho_x[0]**4 + 37*dq_x[0]**3*dq_x[1]*drho_x[1]*gamma/(2*drho_x[0]**4) - 39*dq_x[0]**3*dq_x[1]*drho_x[1]/(2*drho_x[0]**4) - 7*dq_x[0]**3*dq_x[2]*gamma/(2*drho_x[0]**3) + 7*dq_x[0]**3*dq_x[2]/(2*drho_x[0]**3) - 21*dq_x[0]**2*dq_x[1]**2*gamma/(2*drho_x[0]**3) + 21*dq_x[0]**2*dq_x[1]**2/(2*drho_x[0]**3) - dq_x[0]**2*dE_x[2]*gamma**2/(2*drho_x[0]**2) + 3*dq_x[0]**2*dE_x[2]*gamma/drho_x[0]**2 - 3*dq_x[0]**2*dE_x[2]/(2*drho_x[0]**2) + dq_x[0]**2*drho_x[1]*dE_x[1]*gamma**2/(2*drho_x[0]**3) - 15*dq_x[0]**2*drho_x[1]*dE_x[1]*gamma/(2*drho_x[0]**3) + 3*dq_x[0]**2*drho_x[1]*dE_x[1]/drho_x[0]**3 - dq_x[0]**2*drho_x[2]*dE_x[0]*gamma**2/(2*drho_x[0]**3) - 3*dq_x[0]**2*drho_x[2]*dE_x[0]*gamma/(2*drho_x[0]**3) + 3*dq_x[0]**2*drho_x[1]**2*dE_x[0]*gamma**2/(2*drho_x[0]**4) + 9*dq_x[0]**2*drho_x[1]**2*dE_x[0]*gamma/(2*drho_x[0]**4) - dq_x[0]*dq_x[1]*dE_x[1]*gamma**2/drho_x[0]**2 + 8*dq_x[0]*dq_x[1]*dE_x[1]*gamma/drho_x[0]**2 - 3*dq_x[0]*dq_x[1]*dE_x[1]/drho_x[0]**2 - dq_x[0]*dq_x[1]*drho_x[1]*dE_x[0]*gamma**2/drho_x[0]**3 - 7*dq_x[0]*dq_x[1]*drho_x[1]*dE_x[0]*gamma/drho_x[0]**3 + 2*dq_x[0]*dq_x[2]*dE_x[0]*gamma/drho_x[0]**2 + 2*dq_x[1]**2*dE_x[0]*gamma/drho_x[0]**2 + dE_x[0]*dE_x[2]*gamma**2/drho_x[0] - dE_x[0]*dE_x[2]*gamma/drho_x[0] + dE_x[1]**2*gamma**2/drho_x[0] - dE_x[1]**2*gamma/drho_x[0] - drho_x[1]*dE_x[0]*dE_x[1]*gamma**2/drho_x[0]**2 + drho_x[1]*dE_x[0]*dE_x[1]*gamma/drho_x[0]**2


        return np.array([rho_tt, q_tt, E_tt], dtype=np.float64)


    else:
        raise Exception("Invalid order")

为了简洁起见,我还排除了 k = 3、4 和 5 的语句。可以说,这些行变得越来越长(一行中超过 80000 个字符),并且具有与上面类似的内存访问、幂和乘法负载。

当我在这段代码中使用 njit 时,a)它比没有 njit 慢(对于 k = 5,差异是 500 到 600 倍),b)无论我包含 elif 语句的 k 越高,它变得越来越慢我正在评估其中 k 的函数。据我所知,该函数中应该没有任何内容导致使用 Numba 的性能如此糟糕。

numpy numpy-ndarray numba
1个回答
0
投票

首先,编译后的代码相当庞大,特别是如果应用相同的逻辑“对于 k = 3、4 和 5”。这是一个问题,原因有两个。首先,要编译的代码越大,构建速度就越慢。其次,大型编译代码仍然不是最佳选择,因为本机编译代码需要从缓存/RAM 中获取、解码然后执行。前两个步骤的成本很高,尤其是当代码很大时。因此,由于前两个开销,循环有时会更快。

为了确保编译时间在基准测试/应用程序中间不是问题(由于延迟编译),您可以通过提供函数的目标签名来急切地编译函数。请注意,Numba 目前不关心像

npt.NDArray[np.float64]
这样的类型注释。这是签名示例:

@nb.njit('(float64[::1], float64[::1], float64[::1], int64, float64)')

在我的机器上,函数的构建需要 3.74 秒。对于如此简单的计算函数来说,这是很多,并且应该花费大量时间在 Numba Python 代码解析和语义分析上。虽然由于急切的编译,在定义函数时只需要支付一次时间,但如此减慢程序的初始化对我来说是不合理的。关键是避免重复表达。请注意,这不会对编译函数的执行时间产生太大影响,因为编译器(如 Numba)可以分析重复的术语并避免重新计算。另请注意,这也使代码更具可读性(避免超过 1700 个字符的丑陋行)。这是一个更好的版本:

def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
    g = gamma
    r0, r1, r2 = drho_x
    q0, q1, q2 = dq_x
    e0, e1, e2 = dE_x

    r0p2, r1p2 = r0**2, r1**2
    q0p2, q1p2 = q0**2, q1**2
    e0p2, e1p2 = e0**2, e1**2
    gp2 = g**2
    r0p3, r0p4 = r0**3, r0**4
    q0p3, q0p4 = q0**3, q0**4

    if k == 0:
        return np.array([r0, q0, e0], dtype=np.float64)
    
    elif k == 1:
        rho_t = -q1
        q_t = -q0p2*r1*g/(2*r0p2) + 3*q0p2*r1/(2*r0p2) + q0*q1*g/r0 - 3*q0*q1/r0 - e1*g + e1
        E_t = -q0p3*r1*g/r0p3 + q0p3*r1/r0p3 + 3*q0p2*q1*g/(2*r0p2) - 3*q0p2*q1/(2*r0p2) - q0*e1*g/r0 + q0*r1*e0*g/r0p2 - q1*e0*g/r0
        return np.array([rho_t, q_t, E_t], dtype=np.float64)
    
    elif k == 2:
        rho_tt = q0p2*r2*g/(2*r0p2) - 3*q0p2*r2/(2*r0p2) - q0p2*r1p2*g/r0p3 + 3*q0p2*r1p2/r0p3 + 2*q0*q1*r1*g/r0p2 - 6*q0*q1*r1/r0p2 - q0*q2*g/r0 + 3*q0*q2/r0 - q1p2*g/r0 + 3*q1p2/r0 + e2*g - e2
        q_tt = q0p3*r2*gp2/(2*r0p3) + q0p3*r2*g/r0p3 - 7*q0p3*r2/(2*r0p3) - 3*q0p3*r1p2*gp2/(2*r0p4) - 3*q0p3*r1p2*g/r0p4 + 21*q0p3*r1p2/(2*r0p4) + 5*q0p2*q1*r1*gp2/(2*r0p3) + 8*q0p2*q1*r1*g/r0p3 - 45*q0p2*q1*r1/(2*r0p3) - q0p2*q2*gp2/(2*r0p2) - 5*q0p2*q2*g/(2*r0p2) + 6*q0p2*q2/r0p2 - q0*q1p2*gp2/r0p2 - 5*q0*q1p2*g/r0p2 + 12*q0*q1p2/r0p2 + 3*q0*e2*g/r0 - 3*q0*e2/r0 - q0*r1*e1*gp2/r0p2 - 2*q0*r1*e1*g/r0p2 + 3*q0*r1*e1/r0p2 - q0*r2*e0*gp2/r0p2 + q0*r2*e0*g/r0p2 + 2*q0*r1p2*e0*gp2/r0p3 - 2*q0*r1p2*e0*g/r0p3 + q1*e1*gp2/r0 + 2*q1*e1*g/r0 - 3*q1*e1/r0 - 2*q1*r1*e0*gp2/r0p2 + 2*q1*r1*e0*g/r0p2 + q2*e0*gp2/r0 - q2*e0*g/r0
        E_tt = q0p4*r2*gp2/(4*r0p4) + 2*q0p4*r2*g/r0p4 - 9*q0p4*r2/(4*r0p4) - q0p4*r1p2*gp2/r0**5 - 8*q0p4*r1p2*g/r0**5 + 9*q0p4*r1p2/r0**5 + q0p3*q1*r1*gp2/r0p4 + 37*q0p3*q1*r1*g/(2*r0p4) - 39*q0p3*q1*r1/(2*r0p4) - 7*q0p3*q2*g/(2*r0p3) + 7*q0p3*q2/(2*r0p3) - 21*q0p2*q1p2*g/(2*r0p3) + 21*q0p2*q1p2/(2*r0p3) - q0p2*e2*gp2/(2*r0p2) + 3*q0p2*e2*g/r0p2 - 3*q0p2*e2/(2*r0p2) + q0p2*r1*e1*gp2/(2*r0p3) - 15*q0p2*r1*e1*g/(2*r0p3) + 3*q0p2*r1*e1/r0p3 - q0p2*r2*e0*gp2/(2*r0p3) - 3*q0p2*r2*e0*g/(2*r0p3) + 3*q0p2*r1p2*e0*gp2/(2*r0p4) + 9*q0p2*r1p2*e0*g/(2*r0p4) - q0*q1*e1*gp2/r0p2 + 8*q0*q1*e1*g/r0p2 - 3*q0*q1*e1/r0p2 - q0*q1*r1*e0*gp2/r0p3 - 7*q0*q1*r1*e0*g/r0p3 + 2*q0*q2*e0*g/r0p2 + 2*q1p2*e0*g/r0p2 + e0*e2*gp2/r0 - e0*e2*g/r0 + e1p2*gp2/r0 - e1p2*g/r0 - r1*e0*e1*gp2/r0p2 + r1*e0*e1*g/r0p2
        return np.array([rho_tt, q_tt, E_tt], dtype=np.float64)

    else:
        raise Exception("Invalid order")

它还远未达到完美,但我无法找到所有子表达式的通用模式(开发表达式比分解它更简单)。

这个函数在我的机器上构建需要 1.55 秒(执行时间相同)。这是 2.4 更快的构建

您可以使用编译标志进一步减少此开销

cache=True


这是用于测量代码性能的设置:

import numpy as np
import numpy.typing as npt
import numba as nb

%%time
@nb.njit('(float64[::1], float64[::1], float64[::1], int64, float64)')
def cauchyKovalevskaya(drho_x: npt.NDArray[np.float64], dq_x: npt.NDArray[np.float64], dE_x: npt.NDArray[np.float64], k: int, gamma: float) -> npt.NDArray:
    # [...]

drho_x = np.random.rand(3)
dq_x = np.random.rand(3)
dE_x = np.random.rand(3)
gamma = 42.0
%timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 0, gamma)
%timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 1, gamma)
%timeit -n 20_000 cauchyKovalevskaya(drho_x, dq_x, dE_x, 2, gamma)

关于执行时间,这是我机器上的计时(带有急切的编译):

1.27 µs ± 8.18 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)
1.28 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)
1.37 µs ± 24.4 ns per loop (mean ± std. dev. of 7 runs, 20000 loops each)

>90% 的时间来自开销(例如创建输出数组、检查输入类型并将其转换为本机类型、从 CPython 调用本机函数)而不是计算。预分配输出有助于稍微减少执行时间。话虽这么说,只要从 CPython 调用此函数,就无法大幅减少执行时间。这是主要瓶颈。您可以使用 Numba 编译调用者函数,以显着减少此开销。 在我的机器上,看起来计算只需要 0.1 µs。

使用 Numba 编译调用方函数后,请注意,您可以添加编译标志

fastmath=True

以加速数学计算。话虽这么说,如果值可以是

NaN

-inf
+inf
、次法线,或者即使操作的关联性很重要,则此选项是有害的。有关这方面的更多信息,请阅读
这篇文章
。在这种情况下我不推荐它,因为我不希望修改后的代码有很大的加速。

© www.soinside.com 2019 - 2024. All rights reserved.