我想使用nopython模式在Numba中运行递归函数。直到现在我才会收到错误。这是一个非常简单的代码,用户给出一个少于五个元素的元组,然后该函数创建另一个元组,并将新值添加到元组(在本例中为数字3)。这一过程一直重复,直到最后一个元组的长度为5.由于某些原因,这不起作用,不知道为什么。
@njit
def tup(a):
if len(a) == 5:
return a
else:
b = a + (3,)
b = tup(b)
return b
例如,如果a = (0,1)
,我希望最终的结果是元组(0,1,3,3,3)
。
编辑:我正在使用Numba 0.41.0,我得到的错误是内核死亡,'内核似乎已经死了。它会自动重启。
您不应该这样做有几个原因:
O(n)
操作,即使在numba。因此该功能的整体性能将是O(n**2)
。这可以通过使用支持O(1)
附加的数据结构或支持预分配大小的数据结构来改进。或者仅仅通过不使用“循环”或“递归”方法。njit
装饰器并传入包含6个元素的元组会发生什么? (提示:它将达到递归限制,因为它永远不会满足递归的结束条件)。在编写0.43.1时,Numba仅在递归之间参数类型不变时才支持简单递归。在你的情况下类型确实改变,你传入一个tuple(int64 x 2)
但递归调用试图传入一个不同类型的tuple(int64 x 3)
。奇怪的是,它在我的计算机上遇到了StackOverflow
- 这似乎是numba中的一个错误。
我的建议是使用它(没有numba,没有递归):
def tup(a):
if len(a) < 5:
a += (3, ) * (5 - len(a))
return a
这也返回了预期的结果:
>>> tup((1,))
(1, 3, 3, 3, 3)
>>> tup((1, 2))
(1, 2, 3, 3, 3)
根据this list of proposals在当前版本中的说法:
numba中的递归支持目前仅限于使用函数的显式类型注释的自递归。此限制来自无法确定递归调用的返回类型。
所以,请尝试:
from numba import jit
@jit()
def tup(a:tuple) -> tuple:
if len(a) == 5:
return a
return tup(a + (3,))
print(tup((0, 1)))
看看它是否适合你。