使用 1024 个线程时,我在 numba cuda 中的累计总和给出了错误的结果

问题描述 投票:0回答:1

我正在尝试通过并行计算在 numba cuda 中实现累积和。此函数将采用一维数组 ( a = [1, 1, 1, 1]) 并对其进行计算,以便数组的每个元素都是所有先前元素的总和 ( ans = [1, 2, 3, 4])。这是我实现的功能:

from numba import cuda
import math
import numpy as np

@cuda.jit
def c_sum(array):

   tidx = cuda.threadIdx.x
   nb_it = math.log2(cuda.blockDim.x)
   if tidx < array.size:
       for l in range(nb_it):
           if tidx > 2**l-1:
               array[tidx] += array[tidx-(2**l)]
           cuda.syncthreads()

input_ar = np.ones([1024]) 
c_sum[1,1024](input_ar)

该函数起作用的要求是线程数大于或等于数组的大小。我已经对大小为 16,32,64,128 到 512 的数组进行了测试,它工作得很好!由于某种原因,它不适用于达到 1024 的数组,不幸的是,这正是我的应用程序所需要的。 1024应该是线程块中允许的最大线程数。数组的最后 5 个元素应为 [1020,1021,1022,1023,1024]。这是我得到的:

print(input_ar[-5:])
[1352. 1348. 1344. 1348. 1344.]

它实际上每次都在变化。有人有想法吗?

python cuda numba cumulative-sum
1个回答
0
投票

您遇到了全局内存竞争状况。 线程并非全部同步执行。 因此,这一行会导致不明确的行为:

array[tidx] += array[tidx-(2**l)]

例如,如果线程 32 完全完成其工作,然后线程 31 运行怎么办? 相反,如果线程 31 完全完成其工作然后线程 32 运行会怎样? 对于这两种不同的情况,您会期望得到两种不同的结果。

您可以通过使用适当的屏障将货物与商店分开来解决此问题:

$ cat t3.py
from numba import cuda
import math
import numpy as np

@cuda.jit
def c_sum(array):

   tidx = cuda.threadIdx.x
   nb_it = math.log2(cuda.blockDim.x)
   if tidx < array.size:
       for l in range(nb_it):
           if tidx > 2**l-1:
               c = array[tidx-(2**l)]
           cuda.syncthreads()
           if tidx > 2**l-1:
               array[tidx] += c
           cuda.syncthreads()

input_ar = np.ones([1024])
c_sum[1,1024](input_ar)
print(input_ar[-5:])
$ python3 t3.py
/home/bob/.local/lib/python3.10/site-packages/numba/cuda/dispatcher.py:536: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
/home/bob/.local/lib/python3.10/site-packages/numba/cuda/cudadrv/devicearray.py:888: NumbaPerformanceWarning: Host array used in CUDA kernel will incur copy overhead to/from device.
  warn(NumbaPerformanceWarning(msg))
[1020. 1021. 1022. 1023. 1024.]
$

这种方法不能直接扩展到单个线程块之外,因为

cuda.syncthreads()
只是块中线程的执行屏障。 要超越单个块,您需要使用网格范围(设备范围)的屏障,或者使用替代公式来获取前缀和。

© www.soinside.com 2019 - 2024. All rights reserved.