我有一个 numpy 结构化数组,并将其中的一个元素传递给一个函数,如下所示。
from numba import njit
import numpy as np
dtype = np.dtype([
("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4"),
])
a = np.array([(1, b"24q1", 1.0)], dtype=dtype)
@njit
def upsert_numba(a, sid, qtrnm, val):
a[1] = qtrnm
a[2] = val
#i = 0
#a[i+1] = qtrnm
#a[i+2] = val
return a
x = (1, b"24q2", 3.0)
print(upsert_numba(a[0].copy(), *x))
上面的代码运行没有问题。但如果通过注释掉的代码进行更新,即 i=0;a[i+1]=qtrnm;a[i+2]=val,numba 会给出以下错误。
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(Record(id[type=int32;offset=0],qtrnm0[type=[char x 4];offset=4],qtr0[type=float32;offset=8];12;False), int64, readonly bytes(uint8, 1d, C))
似乎只有常量才允许索引,该常量可以是编译时已知的整数或 CharSeq,但不能是常量上的表达式,但常量在编译时也已知。我可以知道幕后发生了什么吗?
我尝试过其他常量作为索引,例如“j=i; a[j]”,这也有效。但不出所料,“j=i+1;a[j]”失败了。
考虑以下功能。
from numba import njit
@njit
def func():
i = 777
t = i
return t
@njit
def func2():
i = 776
t = i + 1
return t
您可以使用以下方法检查如何推断每个变量的类型。
func()
func.inspect_types()
这是关键行:
# i = const(int, 777) :: Literal[int](777)
# t = i :: Literal[int](777)
::
之后的部分是变量的类型。
这表明 i
和 t
都是整数文字类型。
接下来,对于
func2
:
func2()
func2.inspect_types()
# i = const(int, 776) :: Literal[int](776)
# t = i + $const10.2 :: int64
与
func
相比,您可以看到 t
被推断为 int64
而不是整数文字类型。
这意味着,numba 在优化之前对代码执行类型推断。
这是一个合理的选择。
优化需要类型化代码,但生成类型化代码需要类型推断。
所以首先对Python字节码进行类型推断,然后根据推断出的类型进行优化。
有关此流程的更准确和详细的信息,请参阅官方文档
额外说明,numba 不支持使用非文字变量对记录进行索引。 然而,通过重载显式定义映射是可能的。
from operator import setitem
import numpy as np
from numba import njit, types
from numba.core.extending import overload
a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")])
@overload(setitem)
def setitem_overload_for_a(a, index, value):
if getattr(a, "dtype", None) != a_dtype:
return None
if isinstance(value, (types.Integer, types.Float)):
def numeric_impl(a, index, value):
# You need to map these indexes correctly according to the dtype.
if index == 0:
a[0] = value
elif index == 2:
a[2] = value
else:
raise ValueError()
return numeric_impl
elif isinstance(value, (types.Bytes, types.CharSeq)):
def bytes_impl(a, index, value):
if index == 1:
a[1] = value
else:
raise ValueError()
return bytes_impl
else:
raise TypeError(f"Unsupported type: {index=}, {value=}, {a.dtype=}")
@njit
def upsert_numba(a, sid, qtrnm, val):
i = 0
a[i + 1] = qtrnm
a[i + 2] = val
return a
x = (1, b"24q2", 3.0)
a = np.array([(1, b"24q1", 1.0)], dtype=a_dtype)
print(upsert_numba(a[0].copy(), *x)) # (1, b'24q2', 3.)
请注意,这是一种临时策略,要求您为每种记录类型对 setitem 进行硬编码,并且在某些情况下可能不起作用。也就是说,它应该有效,除非你正在做一些非常棘手的事情。