Numba 命名元组签名

问题描述 投票:0回答:2

我正在尝试在 Numba 中指定命名元组的返回类型,但我无法这样做。有人可以帮忙吗?考虑以下最少代码:

import numba as nb
from   collections import namedtuple

NT = namedtuple('NT',['sum','sum2'])

@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
    arraysum = 0.0
    out = NT(sum=arraysum,sum2=arraysum)
    return out

我收到错误

No conversion from NT(float64 x 2) to NT(float64, float64) for '$20return_value.7', defined at None

File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
    <source elided>
    out = NT(sum=arraysum,sum2=arraysum)
    return out
    ^

During: typing of assignment at numbanamedtuple.py (10)

File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
    <source elided>
    out = NT(sum=arraysum,sum2=arraysum)
    return out
python-3.x signature numba namedtuple
2个回答
3
投票

问题是“过度优化”的 numba 编译器(bug)。将不同类型的变量添加到元组中以告诉编译器使用异构元组(内部类)。

import numba as nb
import numpy as np
from collections import namedtuple

NT = namedtuple('NT',['sum','sum2','dummy'])

@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64,nb.int64],NT))(nb.int64,nb.float64[:,:]))
def arrsum_njit(nn,xx):
    arraysum = 0.0
    out = NT(sum=arraysum,sum2=arraysum,dummy=1)
    return out

arrsum_njit(1, np.array([[1.], [2.]]))
# >>> NT(sum=0.0, sum2=0.0, dummy=1)

2
投票

请使用

NamedUniTuple
代替。它是同质命名元组的 numba 规范类型。

© www.soinside.com 2019 - 2024. All rights reserved.