Troubleshooting an mpi4py error in Python code accelerated with Numba


I'm trying to speed up a piece of Python code with Numba. The code also uses mpi4py for parallelization. However, I've run into an error. Below is a minimal example that reproduces it:

```python
import numpy as np
from numba import njit, prange
from mpi4py import MPI

#####MPI setting#####
comm=MPI.COMM_WORLD
rank=comm.Get_rank()
size=comm.Get_size()
N_theta_to_scan=1000
# split the N_theta_to_scan indices into contiguous blocks;
# the first n_more ranks get one extra index
n_per_proc=N_theta_to_scan//size
n_more=N_theta_to_scan%size
if rank<n_more:
    start=rank*(n_per_proc+1)
    number_to_cal=n_per_proc+1
else:
    start=n_more*(n_per_proc+1)+(rank-n_more)*n_per_proc
    number_to_cal=n_per_proc

#####function#####
@njit(parallel=True)
def func1(na,nb,nc):
    # temporary array of na*nb*nc float64 values (8 bytes each)
    to_sum=np.zeros(na*nb*nc)
    for a in prange(0,na):
        # Numba parallelizes only the outermost prange; inner ones run as range
        for b in prange(0,nb):
            for c in prange(0,nc):
                to_sum[a*nb*nc+b*nc+c]=a*b*c
    out=np.sum(to_sum)
    return out

@njit(parallel=True)
def func2(start,number_to_cal):
    to_sum=np.zeros(number_to_cal)
    for i in prange(start,start+number_to_cal):
        to_sum[i-start]=func1(i,i*i,i*i*i)
    out2=np.sum(to_sum)
    return out2

#####main section#####
to_be_gather=np.array([func2(start,number_to_cal)])
# on the root, the receive buffer must have the same length (1) as the send buffer
gatheres=np.zeros(1) if rank==0 else None
comm.Reduce(to_be_gather,gatheres,op=MPI.SUM,root=0)
```
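The MPI setting section splits the 1000 indices into contiguous blocks and gives the first `N_theta_to_scan % size` ranks one extra index. A plain-Python sketch of that arithmetic (`block_range` is a hypothetical helper, and 7 ranks is an arbitrary example) checks that every index is assigned exactly once and shows that rank 0 always gets the block starting at i = 0, while the last rank gets the largest i values:

```python
def block_range(n_items, size, rank):
    """Return (start, count) for `rank`, mirroring the MPI setting section."""
    n_per_proc = n_items // size
    n_more = n_items % size
    if rank < n_more:
        # the first n_more ranks take one extra index each
        return rank * (n_per_proc + 1), n_per_proc + 1
    return n_more * (n_per_proc + 1) + (rank - n_more) * n_per_proc, n_per_proc


if __name__ == "__main__":
    n_items, size = 1000, 7  # 7 ranks is an arbitrary example
    covered = []
    for r in range(size):
        start, count = block_range(n_items, size, r)
        covered.extend(range(start, start + count))
        print(f"rank {r}: i in [{start}, {start + count - 1}]")
    # every index appears exactly once, in rank order
    assert covered == list(range(n_items))
```

This matters for the question because the work (and memory) per index grows quickly with i, so the ranks are far from equally loaded.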

This code fails with the following error:

```
MemoryError: Allocation failed (probably too large).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/share/workspace/wuliang/hanlin/test/test.py", line 38, in <module>
    to_be_gather=func2(start,number_to_cal)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
SystemError: CPUDispatcher(<function func2 at 0x2aed19d7b420>) returned a result with an exception set
```
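One likely cause, worked out as a sketch (the helper names are hypothetical and the 64 GiB figure is an assumed machine size): `func1` allocates `np.zeros(na*nb*nc)`, i.e. na·nb·nc float64 values of 8 bytes each, and `func2` calls it with na = i, nb = i*i, nc = i*i*i, so each call requests i**6 × 8 bytes. With `N_theta_to_scan=1000`, i goes up to 999, which would need about 8 × 10^18 bytes. Under this reading the MemoryError comes from Numba's allocator and is not caused by MPI itself.

```python
def func1_bytes(i, itemsize=8):
    """Bytes requested by func1(i, i*i, i*i*i): np.zeros(i**6) float64 values."""
    return i * (i * i) * (i * i * i) * itemsize


def largest_feasible_i(budget_bytes):
    """Largest i whose single func1 allocation still fits in budget_bytes."""
    i = 0
    while func1_bytes(i + 1) <= budget_bytes:
        i += 1
    return i


if __name__ == "__main__":
    for i in (10, 50, 100, 999):
        print(f"i = {i:3d}: {func1_bytes(i) / 2**30:.3g} GiB")
    # assumed 64 GiB of RAM, ignoring that func2's prange runs several
    # func1 calls (and allocations) at the same time
    print("largest i for 64 GiB:", largest_feasible_i(64 * 2**30))
```

Already at i = 100 a single call asks for 8 × 10^12 bytes (about 7,451 GiB), and under the 64 GiB assumption only i ≤ 45 fits.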

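Next, I tried removing `func1()` and modifying `func2()`, replacing `to_sum[i-start]=func1(i,i*i,i*i*i)` with `to_sum[i-start]=i`, and the code ran successfully. I also tried running the main section on rank 0 only, i.e. inside an `if rank==0:` block, and that ran successfully too. Where does the error come from, and what should I change so that the current code does what it is meant to do?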
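Both observations in the question are consistent with an allocation-size problem rather than an MPI problem: with `to_sum[i-start]=i` nothing large is allocated, and rank 0's block starts at i = 0, so (depending on how many ranks were launched) its allocations are the smallest. If the real `func1` only needs the sum, one option is to accumulate it directly instead of storing every term. The sketch below assumes the toy body `a*b*c` from the question; `func1_reduce` and `func1_closed_form` are hypothetical names, and the import falls back to plain Python so the sketch runs even without Numba. Each `prange` iteration sums into its own local variable and adds it once to `total`, which is a scalar reduction that Numba's `parallel=True` supports.

```python
try:
    from numba import njit, prange
except ImportError:
    # assumed fallback so the sketch runs without Numba: serial plain Python
    prange = range

    def njit(*args, **kwargs):
        # supports both @njit and @njit(parallel=True)
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit(parallel=True)
def func1_reduce(na, nb, nc):
    """Sum of a*b*c over a < na, b < nb, c < nc without a temporary array."""
    total = 0.0
    for a in prange(na):  # only this outer loop runs in parallel
        partial = 0.0     # private to this iteration
        for b in range(nb):
            for c in range(nc):
                partial += a * b * c
        total += partial  # scalar reduction across prange iterations
    return total


def func1_closed_form(na, nb, nc):
    """Same value for this toy body only: the triple sum factorizes."""
    # sum_{a<na} a = na*(na-1)/2, and likewise for b and c
    return (na * (na - 1) // 2) * (nb * (nb - 1) // 2) * (nc * (nc - 1) // 2)


if __name__ == "__main__":
    print(func1_reduce(3, 4, 5), func1_closed_form(3, 4, 5))
```

The reduction removes the MemoryError but not the cost: the loop still runs i**6 times per call, so for i near 1000 only a restructured computation (such as the closed form, which holds for this toy body alone) is practical.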

Tags: python, numpy, mpi, numba, mpi4py

1 answer (0 votes)

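Regarding the code in the original question and the comments about coupling Numba with MPI: the numba-mpi package (https://pypi.org/project/numba-mpi/) makes it possible to call MPI from Numba-compiled code, including code compiled with `parallel=True`. Here is the parallel pi computation example from the package README, in which the reduction is performed inside a `@numba.jit`-decorated function (with mpi4py, the reduction has to be done outside the JIT-compiled code, as in the question):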

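```python
# Run with e.g.: mpiexec -n 4 python script.py
import numba, numpy as np, numba_mpi

N_TIMES = 10000  # repetitions of the whole computation (used for timing)

@numba.jit
def get_pi_part(n_intervals=1000000, rank=0, size=1):
    # midpoint rule for the integral of 4/(1+x**2) over [0, 1];
    # each rank handles every size-th interval
    h = 1 / n_intervals
    partial_sum = 0.0
    # upper bound n_intervals + 1 so that the last interval is included
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum

@numba.jit
def pi_numba_mpi(n_intervals):
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        part[0] = get_pi_part(n_intervals, numba_mpi.rank(), numba_mpi.size())
        # the reduction runs inside JIT-compiled code
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
    return pi[0]

if __name__ == "__main__":
    result = pi_numba_mpi(1000000)
    if numba_mpi.rank() == 0:
        print(result)
```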
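To see what the SUM reduction combines, here is a serial sketch with no MPI and no Numba: it evaluates each rank's slice in an ordinary loop and adds the slices, which is the value an allreduce with a SUM operator would leave in `pi` on every rank. `simulated_allreduce_sum` is a hypothetical helper, and `get_pi_part` is a plain-Python adaptation of the README function that sums all midpoints i = 1..n_intervals.

```python
import math


def get_pi_part(n_intervals=1_000_000, rank=0, size=1):
    """Plain-Python midpoint-rule slice for one rank (no Numba, no MPI)."""
    h = 1 / n_intervals
    partial_sum = 0.0
    # intervals i = rank+1, rank+1+size, ... up to and including n_intervals
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum


def simulated_allreduce_sum(n_intervals, size):
    """Value a SUM allreduce would leave on every rank: the sum of all parts."""
    return sum(get_pi_part(n_intervals, rank, size) for rank in range(size))


if __name__ == "__main__":
    for size in (1, 2, 4):
        approx = simulated_allreduce_sum(10_000, size)
        print(size, approx, abs(approx - math.pi))
```

Because the ranks take interleaved intervals, the result does not depend on the number of ranks beyond floating-point rounding, which is why only one scalar per rank needs to be reduced.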

The package README also includes a performance comparison with mpi4py: https://github.com/numba-mpi/numba-mpi?tab=readme-ov-file#example-comparing-numba-mpi-vs-mpi4py-performance

© www.soinside.com 2019 - 2024. All rights reserved.