在 numba njit 方法中用单个元素分配集合时出现断言错误

问题描述 投票:0回答:1

我正在尝试使用 numba 的 njit 装饰器来优化方法。我收到一个错误,该错误似乎是由于定义具有多个元素的集合,然后用具有单个元素的集合覆盖它而引起的(反之亦然)。 我正在运行 python 3.10.12,numba 版本 0.60.0。

以下最小示例重现了错误

import numba


@numba.njit
def foo(n: int):
    s = {1, 2, 3}
    if n == 1:
        pass
    else:
        s = {2}  # This causes an error
        # s = {2, 3, 4}  # This works

    res = sum(list(s))
    return res


def main():
    res = foo(2)
    print(res)


if __name__ == '__main__':
    main()

根据 numba 文档

JIT 编译函数支持集合上的所有方法和操作。

所以我不明白在这种情况下导致错误的原因。即使在没有事先转换为列表的情况下执行

sum(s)
也会引发错误。

错误的原因是什么,如何使其工作?

python numba
1个回答
0
投票

我分析了这个问题,这就是我的发现。

问题似乎源于根据集合包含的元素数量来解释集合的方式。 具体来说,当集合只有一个元素时,它被视为文字类型,而有两个或多个元素时,它被解释为非文字类型。 您可以通过以下方式验证如何解释类型:

from numba import njit


@njit
def func():
    a = {1}
    b = {1, 2}
    return a, b


func()
func.inspect_types()

重点来了:

#   a = build_set(items=[Var($const4.0, temp.py:6)])  :: set(Literal[int](1))
#   b = build_set(items=[Var($const10.2, temp.py:7), Var($const12.3, temp.py:7)])  :: set(int64)

每个变量的类型在

::
符号后面指示。 如您所见,
a
set
Literal[int]
,而
b
set
int64
。 正如无法将
float64
值分配给
int64
变量一样,两者不兼容。

现在我们知道了原因,解决方法很简单:防止一组元素被解释为文字类型。

例如,你可以这样做:

@njit
def func():
    a = set((1,))
    b = {1, 2}
    return a, b

这将确保

a
b
都被解释为
set(int64)

#   a = call $4load_global.0($const14.2, func=$4load_global.0, args=[Var($const14.2, temp.py:6)], kws=(), vararg=None, varkwarg=None, target=None)  :: (UniTuple(int64, 1),) -> set(int64)
#   b = build_set(items=[Var($const26.4, temp.py:7), Var($const28.5, temp.py:7)])  :: set(int64)

我对您的代码进行了相同的更改,现在它可以按预期工作。

import numba


@numba.njit
def foo(n: int):
    s = {1, 2, 3}
    if n == 1:
        pass
    else:
        s = set((2,))

    res = sum(list(s))
    return res


def main():
    res = foo(1)
    print(res)
    res = foo(2)
    print(res)


if __name__ == '__main__':
    main()

或者,如果您愿意,也可以这样实现。

@numba.njit
def non_literal_set(*args):
    return set(args)


@numba.njit
def foo(n: int):
    s = non_literal_set(1, 2, 3)
    if n == 1:
        pass
    else:
        s = non_literal_set(2)

    res = sum(list(s))
    return res
© www.soinside.com 2019 - 2024. All rights reserved.