我正在尝试使用 numba 的 njit 装饰器来优化方法。我收到一个错误,该错误似乎是由于定义具有多个元素的集合,然后用具有单个元素的集合覆盖它而引起的(反之亦然)。 我正在运行 python 3.10.12,numba 版本 0.60.0。
以下最小示例重现了错误
import numba
@numba.njit
def foo(n: int):
s = {1, 2, 3}
if n == 1:
pass
else:
s = {2} # This causes an error
# s = {2, 3, 4} # This works
res = sum(list(s))
return res
def main():
res = foo(2)
print(res)
if __name__ == '__main__':
main()
根据 numba 文档
JIT 编译函数支持集合上的所有方法和操作。
所以我不明白在这种情况下导致错误的原因。即使在没有事先转换为列表的情况下执行
sum(s)
也会引发错误。
错误的原因是什么,如何使其工作?
我分析了这个问题,这就是我的发现。
问题似乎源于根据集合包含的元素数量来解释集合的方式。 具体来说,当集合只有一个元素时,它被视为文字类型,而有两个或多个元素时,它被解释为非文字类型。 您可以通过以下方式验证如何解释类型:
from numba import njit
@njit
def func():
a = {1}
b = {1, 2}
return a, b
func()
func.inspect_types()
重点来了:
# a = build_set(items=[Var($const4.0, temp.py:6)]) :: set(Literal[int](1))
# b = build_set(items=[Var($const10.2, temp.py:7), Var($const12.3, temp.py:7)]) :: set(int64)
每个变量的类型在
::
符号后面指示。
如您所见,a
是 set
的 Literal[int]
,而 b
是 set
的 int64
。
正如无法将 float64
值分配给 int64
变量一样,两者不兼容。
现在我们知道了原因,解决方法很简单:防止一组元素被解释为文字类型。
例如,你可以这样做:
@njit
def func():
a = set((1,))
b = {1, 2}
return a, b
这将确保
a
和 b
都被解释为 set(int64)
。
# a = call $4load_global.0($const14.2, func=$4load_global.0, args=[Var($const14.2, temp.py:6)], kws=(), vararg=None, varkwarg=None, target=None) :: (UniTuple(int64, 1),) -> set(int64)
# b = build_set(items=[Var($const26.4, temp.py:7), Var($const28.5, temp.py:7)]) :: set(int64)
我对您的代码进行了相同的更改,现在它可以按预期工作。
import numba
@numba.njit
def foo(n: int):
s = {1, 2, 3}
if n == 1:
pass
else:
s = set((2,))
res = sum(list(s))
return res
def main():
res = foo(1)
print(res)
res = foo(2)
print(res)
if __name__ == '__main__':
main()
或者,如果您愿意,也可以这样实现。
@numba.njit
def non_literal_set(*args):
return set(args)
@numba.njit
def foo(n: int):
s = non_literal_set(1, 2, 3)
if n == 1:
pass
else:
s = non_literal_set(2)
res = sum(list(s))
return res