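At some point in my code I call a Numba function, and all subsequent computations are carried out with Numba jitted functions until the post-processing step.
Over the past few days I have been looking for an efficient way to send all my variables (mostly booleans, integers, floats, and float arrays) to the Numba part of the code while keeping the code readable and clear. In my case, that means limiting the number of arguments and, where possible, grouping variables according to the system they refer to.
I identified four ways to do this: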
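[...] benefit from the cache=True option. This is a big problem for me, because the compilation time can exceed the execution time of the code itself.
Same as namedtuples: if an object from a @jitclass is initialized outside a jitted function, I found it impossible to benefit from the cache=True option (see this post).
In the end, none of these four options lets me do what I want. I may have missed something...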
Here is what I ended up doing: I combine a regular Python class with a Numba @jitclass, so that I can still benefit from the cache=True option.
Here is my MWE:
import numba as nb
from numba import jit
from numba.experimental import jitclass

spec_cls = [
    ('a', nb.types.float64),
    ('b', nb.types.float64),
]

# python class
class ClsWear_py(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

# mirror Numba class
@jitclass(spec_cls)
class ClsWear(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

def function_python(obj):
    print('from the python class :', obj.a)
    # call of a Numba function => this is where I must list explicitly all the keys of the python class object
    oa, ob = function_numba(obj.a, obj.b)
    return obj, oa, ob

@jit(nopython=True)
def function_numba(oa, ob):
    # at the beginning of the Numba function, the arguments are used to define the @jitclass object
    obj_nb = ClsWear(oa, ob)
    print('from the numba class :', obj_nb.a)
    return obj_nb.a, obj_nb.b

# main code :
obj_py = ClsWear_py(11, 22)
obj_rt, a, b = function_python(obj_py)
The output of this code is:
$ python mwe.py
from the python class : 11
from the numba class : 11.0
On the plus side, cache=True works, but the approach also has drawbacks.
Am I missing something? Is there a more obvious way to do this?
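Based on your research (which I agree with), there does not seem to be an intuitive way to do this. So I will suggest the least scary workaround.
The idea is to pass the values to the jitted function as a tuple and then convert that tuple into the desired class inside the jitted function.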
from typing import NamedTuple

from numba import njit


class Config(NamedTuple):
    a: int
    b: float
    c: str


@njit(cache=True)
def f(config_values):
    config = Config(*config_values)  # Convert to a named tuple.
    return config.a


def main():
    config = Config(1, 2.0, "3")
    f(tuple(config))  # Pass as a tuple.
    print("Cache:", f.stats)


main()
Result (second run):
Cache: _CompileStats(cache_path=..., cache_hits=Counter({(Tuple((int64, float64, unicode_type)),): 1}), cache_misses=Counter())
As you can see, it is cached correctly as a tuple.
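One problem with this workaround is that named tuples are read-only, so you cannot modify their fields. In that case, you can do the same thing with a jitclass. Also note that you can create the mirror class for Numba from its Python counterpart, because @jitclass is a regular Python decorator.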
from numba import njit
from numba.experimental import jitclass


class Config:
    # These type hints are interpreted as the default specs for jitclass.
    # If you only use primitive types, this should be sufficient.
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        return self.a, self.b, self.c


JitConfig = jitclass(Config)
# JitConfig = jitclass(specs)(Config) if you need to specify the specs.


@njit(cache=True)
def f(config_values):
    config = JitConfig(*config_values)  # Convert to a jitclass.
    return config.a


def main():
    # Since we use a tuple as an argument, it is not mandatory to use a jitclass here.
    config = Config(1, 2.0, "3")
    f(config.as_tuple())  # Pass as a tuple.
    print("Cache:", f.stats)


main()
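Writing as_tuple by hand means listing every field a second time. As a minimal sketch for the Python side only, the field order can be read from the class annotations instead. The helper name fields_as_tuple is hypothetical, and it assumes that every field is annotated in the class body in the same order as the __init__ parameters. It is meant to be called outside jitted code, right before the tuple is passed in.

```python
class Config:
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c


def fields_as_tuple(obj):
    """Return the annotated fields of obj as a plain tuple, in declaration order.

    Hypothetical helper: assumes all fields are annotated on the class itself.
    """
    return tuple(getattr(obj, name) for name in type(obj).__annotations__)


config = Config(1, 2.0, "3")
values = fields_as_tuple(config)
print(values)  # (1, 2.0, '3')
```

The resulting tuple has the same shape as the one returned by the hand-written as_tuple above, so it could be passed to the jitted function in the same way.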
You can also use overload to hide the jitclass completely.
Note that in the code below, the jitted function no longer needs to refer to the jitclass explicitly.
from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass


class Config:
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        return self.a, self.b, self.c


_JitConfig = jitclass(Config)


@overload(Config, strict=False)
def overload_config_init(*args):
    def jit_config_init(*args):
        return _JitConfig(*args)

    return jit_config_init


@njit(cache=True)
def f(config_values):
    # Since it will be overloaded to a jitclass, you can use a Python class here.
    config = Config(*config_values)
    return config.a


def main():
    config = Config(1, 2.0, "3")
    f(config.as_tuple())
    print("Cache:", f.stats)


main()
As for performance, the overhead should be negligible. Here is a benchmark:
import timeit
from typing import NamedTuple

from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass


class NamedTupleContainer(NamedTuple):
    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float


class JitclassContainer:
    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float

    def __init__(self, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
        self.a0 = a0
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.a4 = a4
        self.a5 = a5
        self.a6 = a6
        self.a7 = a7
        self.a8 = a8
        self.a9 = a9

    def as_tuple(self):
        return (
            self.a0,
            self.a1,
            self.a2,
            self.a3,
            self.a4,
            self.a5,
            self.a6,
            self.a7,
            self.a8,
            self.a9,
        )


_JitclassContainer = jitclass(JitclassContainer)


@overload(JitclassContainer)
def overload_container_init(*args):
    def jit_container_init(*args):
        return _JitclassContainer(*args)

    return jit_container_init


@njit(cache=True)
def f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
    return a0


@njit(cache=True)
def f_namedtuple(args):
    c = NamedTupleContainer(*args)
    return c.a0


@njit(cache=True)
def f_jitclass(args):
    c = JitclassContainer(*args)
    return c.a0


def main():
    def benchmark(f):
        n_runs = 10000
        return min(timeit.repeat(f, repeat=100, number=n_runs)) / n_runs

    values = 1, 2, 3, 4, 5, 6.0, 7.0, 8.0, 9.0, 10.0
    a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = values
    named_tuple_container = NamedTupleContainer(*values)
    jitclass_container = JitclassContainer(*values)

    t = benchmark(lambda: f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9))
    print(f"f_multi_args: {t * 10 ** 9:.0f} ns")

    t = benchmark(lambda: f_namedtuple(tuple(named_tuple_container)))
    print(f"f_namedtuple: {t * 10 ** 9:.0f} ns")

    t = benchmark(lambda: f_jitclass(jitclass_container.as_tuple()))
    print(f"f_jitclass : {t * 10 ** 9:.0f} ns")


main()
Result:
f_multi_args: 525 ns
f_namedtuple: 695 ns
f_jitclass : 652 ns
On my PC, the difference is less than 200 ns per function call with 10 arguments/fields.
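Part of that gap is presumably spent building the tuple on the Python side before the call. Here is a rough sketch to time that step in isolation, using only the standard library. The class Params and its fields are hypothetical and not taken from the benchmark above, and the numbers will vary by machine, so no output is shown.

```python
import timeit
from typing import NamedTuple


class Params(NamedTuple):  # hypothetical container, not from the benchmark above
    a0: int
    a1: int
    a2: float
    a3: float


params = Params(1, 2, 3.0, 4.0)


def per_call_ns(func, n_runs=10_000, repeat=5):
    """Best-of-repeat time per call, in nanoseconds."""
    return min(timeit.repeat(func, repeat=repeat, number=n_runs)) / n_runs * 1e9


# Cost of the conversion alone, without any Numba call.
t_convert = per_call_ns(lambda: tuple(params))
# Baseline: a lambda that does nothing, to estimate the call overhead itself.
t_baseline = per_call_ns(lambda: None)

print(f"tuple(params): {t_convert:.0f} ns (lambda overhead: {t_baseline:.0f} ns)")
```

Subtracting the baseline from the conversion time gives an estimate of what the tuple construction alone adds to each call.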