如何在 numba 中为小 N 创建长度为 N 的元组(或者如何快速计算一维索引和多维索引之间的双向变化)

问题描述 投票:0回答:1

问题:

这是一个简单的函数,适用于 numpy 但不适用于 numba:

    # @numba.jit(nopython=True, fastmath=False, parallel=False)
    def testgetvalue(tgvarray, tgvindex):
        tgvalue = tgvarray[tuple(tgvindex)]
        return tgvalue

如何制作在 numba 中运行的此函数的版本?

我尝试过:

    @numba.jit(nopython=True, fastmath=False, parallel=False)
    def testgetvalue2(tgvarray, tgvindex):
        tgvalue = tgvarray[tuple(tgvindex)]
        currentdex = tgvindex[0]
        tgvtemp = tgvarray[currentdex]
        for idx in range(1, len(tgvindex)):
            currentdex = tgvindex[idx]
            tgvtemp =  tgvtemp[currentdex]
        return tgvalue

但这在 numba 中也失败了

我发现这个问题,其中答案说可以:

更一般地,您无法生成 N 元组,其中 N 在 Numba 函数中是变量。但是,您可以改为为特定 N 生成并编译函数。如果 N 非常小(例如 <15)

这可以解决我的问题,但答案并没有解释如何生成然后编译特定 N 的函数...除非建议我编写一个脚本来生成一个 .py 文件,该文件用 jit 定义一个函数然后我可以从 .py 文件中调用装饰器。鉴于尺寸变化的频率非常低,我想可能会起作用。我不确定这是否符合最佳实践,但我将开始编写一个脚本来生成 .py 文件,直到有人回答为止。

在建议我复制类似这个问题之类的内容之前,请继续阅读,因为我真正的问题不一定涉及元组,即括号中问题的版本。

我为什么问这个问题: 我有代码,其中数组对象的维数有时会随着 1 到 15 维之间的维数变化。但一旦发生变化,就会对该多维数组进行数以万计的重复操作。对于其中许多操作,我希望能够使用索引数组来修改多维数组中某个位置的值。

这引出了括号中的另一个问题:

在我的代码的早期版本中,我通过获取每个维度中的大小数组并执行以下操作,将多维索引转换为一维索引:

    multipliers = np.cumprod(array_of_sizes_in_each_dimension)
    multipliers = np.roll(multipliers, 1)
    multipliers[0] = 1

然后,我可以将多维索引中的每个值乘以

multipliers
中的相应值以获得一维索引。当我从多维索引转到一维索引时,这工作得很好。但是,对于需要从一维索引转换为多维索引的情况,我无法想到 fast 函数。目前,我最快的版本是构建一个查找表,即大小为
two_dimensional_array
X
np.prod(array_of_sizes_in_each_dimension)
len(array_of_sizes_in_each_dimension)
,其中列出了所有多维索引,以便
two_dimensional_array[one_dimensional_index]
返回相应的多维索引。当我的多维数组恰好只有几个维度并且每个维度都很短时,这种方法效果很好。然而,随着维数的增加,我最终会发现
two_dimensional_array
太大了,以至于出现了内存瓶颈,并且代码速度减慢了几个数量级(例如 3 维数组需要 8 分钟,一个 3 维数组需要 8 天) 11 维数组)。因此,如果有人有一个快速函数来替换
two_dimensional_array
查找表的想法,那也可以解决我的问题。

非常感谢您的帮助!只要向我指出一些相关文档的方向,我将不胜感激。谢谢!

python numpy numba numpy-slicing
1个回答
0
投票

如果使用 numba.np.unsafe.ndarray

compile
时已知元组的大小,实际上可以在 numba 中创建元组。这个函数被标记为不安全,我在文档中找不到它的任何痕迹,所以我不知道它的长期可靠性如何,但我发现它对于 numba
0.58
0.60
非常有效。例如,您可以将其用作形状来初始化 ndim 数组。

import numpy as np
import numba

@numba.njit
def numba_make_tuple(old_tuple):
    CONST = len(old_tuple)
    a = np.empty((CONST,), dtype=np.int64)
    for i in range(CONST):
        a[i] = i

    new_tuple = numba.np.unsafe.ndarray.to_fixed_tuple(a, CONST)
    return new_tuple


t1 = (1,)
nt1 = numba_make_tuple(t1)
print(nt1, type(t1))

t2 = (2,3,1)
nt2 = numba_make_tuple(t2)
print(nt2, type(t2))

请注意,我使用元组作为输入来强制了解

CONST
的编译时知识。我在官方文档中找不到任何内容,但对此进行了讨论,例如在https://github.com/numba/numba/issues/8812

© www.soinside.com 2019 - 2024. All rights reserved.