问题:
这是一个简单的函数,适用于 numpy 但不适用于 numba:
# @numba.jit(nopython=True, fastmath=False, parallel=False)
def testgetvalue(tgvarray, tgvindex):
tgvalue = tgvarray[tuple(tgvindex)]
return tgvalue
如何制作在 numba 中运行的此函数的版本?
我尝试过:
@numba.jit(nopython=True, fastmath=False, parallel=False)
def testgetvalue2(tgvarray, tgvindex):
tgvalue = tgvarray[tuple(tgvindex)]
currentdex = tgvindex[0]
tgvtemp = tgvarray[currentdex]
for idx in range(1, len(tgvindex)):
currentdex = tgvindex[idx]
tgvtemp = tgvtemp[currentdex]
return tgvalue
但这在 numba 中也失败了
我发现这个问题,其中答案说可以:
更一般地,您无法生成 N 元组,其中 N 在 Numba 函数中是变量。但是,您可以改为为特定 N 生成并编译函数。如果 N 非常小(例如 <15)
这可以解决我的问题,但答案并没有解释如何生成然后编译特定 N 的函数...除非建议我编写一个脚本来生成一个 .py 文件,该文件用 jit 定义一个函数然后我可以从 .py 文件中调用装饰器。鉴于尺寸变化的频率非常低,我想可能会起作用。我不确定这是否符合最佳实践,但我将开始编写一个脚本来生成 .py 文件,直到有人回答为止。
在建议我复制类似这个问题之类的内容之前,请继续阅读,因为我真正的问题不一定涉及元组,即括号中问题的版本。
我为什么问这个问题: 我有代码,其中数组对象的维数有时会随着 1 到 15 维之间的维数变化。但一旦发生变化,就会对该多维数组进行数以万计的重复操作。对于其中许多操作,我希望能够使用索引数组来修改多维数组中某个位置的值。
这引出了括号中的另一个问题:
在我的代码的早期版本中,我通过获取每个维度中的大小数组并执行以下操作,将多维索引转换为一维索引:
multipliers = np.cumprod(array_of_sizes_in_each_dimension)
multipliers = np.roll(multipliers, 1)
multipliers[0] = 1
然后,我可以将多维索引中的每个值乘以
multipliers
中的相应值以获得一维索引。当我从多维索引转到一维索引时,这工作得很好。但是,对于需要从一维索引转换为多维索引的情况,我无法想到 fast 函数。目前,我最快的版本是构建一个查找表,即大小为 two_dimensional_array
X np.prod(array_of_sizes_in_each_dimension)
的 len(array_of_sizes_in_each_dimension)
,其中列出了所有多维索引,以便 two_dimensional_array[one_dimensional_index]
返回相应的多维索引。当我的多维数组恰好只有几个维度并且每个维度都很短时,这种方法效果很好。然而,随着维数的增加,我最终会发现 two_dimensional_array
太大了,以至于出现了内存瓶颈,并且代码速度减慢了几个数量级(例如 3 维数组需要 8 分钟,一个 3 维数组需要 8 天) 11 维数组)。因此,如果有人有一个快速函数来替换 two_dimensional_array
查找表的想法,那也可以解决我的问题。
非常感谢您的帮助!只要向我指出一些相关文档的方向,我将不胜感激。谢谢!
如果使用 numba.np.unsafe.ndarray
在
compile时已知元组的大小,实际上可以在 numba 中创建元组。这个函数被标记为不安全,我在文档中找不到它的任何痕迹,所以我不知道它的长期可靠性如何,但我发现它对于 numba
0.58
到 0.60
非常有效。例如,您可以将其用作形状来初始化 ndim 数组。
import numpy as np
import numba
@numba.njit
def numba_make_tuple(old_tuple):
CONST = len(old_tuple)
a = np.empty((CONST,), dtype=np.int64)
for i in range(CONST):
a[i] = i
new_tuple = numba.np.unsafe.ndarray.to_fixed_tuple(a, CONST)
return new_tuple
t1 = (1,)
nt1 = numba_make_tuple(t1)
print(nt1, type(t1))
t2 = (2,3,1)
nt2 = numba_make_tuple(t2)
print(nt2, type(t2))
请注意,我使用元组作为输入来强制了解
CONST
的编译时知识。我在官方文档中找不到任何内容,但对此进行了讨论,例如在https://github.com/numba/numba/issues/8812