记录类型上的 Numba 索引(numpy 中的结构化数组)

问题描述 投票:0回答:1

我有一个 numpy 结构化数组,并将其中的一个元素传递给一个函数,如下所示。

from numba import njit
import numpy as np

dtype = np.dtype([
    ("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4"),
])
a = np.array([(1, b"24q1", 1.0)], dtype=dtype)

@njit
def upsert_numba(a, sid, qtrnm, val):
    a[1] = qtrnm
    a[2] = val
    #i = 0
    #a[i+1] = qtrnm
    #a[i+2] = val
    return a

x = (1, b"24q2", 3.0)
print(upsert_numba(a[0].copy(), *x))

上面的代码运行没有问题。但如果通过注释掉的代码进行更新,即 i=0;a[i+1]=qtrnm;a[i+2]=val,numba 会给出以下错误。

No implementation of function Function(<built-in function setitem>) found for signature:

 >>> setitem(Record(id[type=int32;offset=0],qtrnm0[type=[char x 4];offset=4],qtr0[type=float32;offset=8];12;False), int64, readonly bytes(uint8, 1d, C))

似乎只有常量才允许索引,该常量可以是编译时已知的整数或 CharSeq,但不能是常量上的表达式,但常量在编译时也已知。我可以知道幕后发生了什么吗?

我尝试过其他常量作为索引,例如“j=i; a[j]”,这也有效。但不出所料,“j=i+1;a[j]”失败了。

numba
1个回答
0
投票

考虑以下功能。

from numba import njit


@njit
def func():
    i = 777
    t = i
    return t


@njit
def func2():
    i = 776
    t = i + 1
    return t

您可以使用以下方法检查如何推断每个变量的类型。

func()
func.inspect_types()

这是关键行:

    #   i = const(int, 777)  :: Literal[int](777)
    #   t = i  :: Literal[int](777)

::
之后的部分是变量的类型。 这表明
i
t
都是整数文字类型。

接下来,对于

func2

func2()
func2.inspect_types()
    #   i = const(int, 776)  :: Literal[int](776)
    #   t = i + $const10.2  :: int64

func
相比,您可以看到
t
被推断为
int64
而不是整数文字类型。 这意味着,numba 在优化之前对代码执行类型推断。 这是一个合理的选择。 优化需要类型化代码,但生成类型化代码需要类型推断。 所以首先对Python字节码进行类型推断,然后根据推断出的类型进行优化。 有关此流程的更准确和详细的信息,请参阅

官方文档

综上所述,在Python字节码阶段需要一个常量变量。

额外说明,numba 不支持使用非文字变量对记录进行索引。 然而,通过重载显式定义映射是可能的。

from operator import setitem import numpy as np from numba import njit, types from numba.core.extending import overload a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")]) @overload(setitem) def setitem_overload_for_a(a, index, value): if getattr(a, "dtype", None) != a_dtype: return None if isinstance(value, (types.Integer, types.Float)): def numeric_impl(a, index, value): # You need to map these indexes correctly according to the dtype. if index == 0: a[0] = value elif index == 2: a[2] = value else: raise ValueError() return numeric_impl elif isinstance(value, (types.Bytes, types.CharSeq)): def bytes_impl(a, index, value): if index == 1: a[1] = value else: raise ValueError() return bytes_impl else: raise TypeError(f"Unsupported type: {index=}, {value=}, {a.dtype=}") @njit def upsert_numba(a, sid, qtrnm, val): i = 0 a[i + 1] = qtrnm a[i + 2] = val return a x = (1, b"24q2", 3.0) a = np.array([(1, b"24q1", 1.0)], dtype=a_dtype) print(upsert_numba(a[0].copy(), *x)) # (1, b'24q2', 3.)

请注意,这是一种临时策略,要求您为每种记录类型对 setitem 进行硬编码,并且在某些情况下可能不起作用。也就是说,它应该有效,除非你正在做一些非常棘手的事情。

© www.soinside.com 2019 - 2024. All rights reserved.