用于具有 numba 函数的代码的 python 类和 numba jitclass

问题描述 投票:0回答:1

在我的代码中的某个时刻,我调用 Numba 函数,并且所有后续计算都使用 Numba jitted 函数进行,直到后处理步骤。

在过去的几天里,我一直在寻找一种有效的方法来将所有变量(主要是布尔值、整数、浮点数和浮点数组)发送到代码的 Numba 部分,同时尝试保持代码的可读性和清晰性。就我而言,这意味着限制参数的数量,并且如果可能的话,根据它们引用的系统重新组合一些变量。

我确定了四种方法来做到这一点:

  1. 蛮力:将所有变量作为第一个调用的 Numba 函数的参数一一发送。我发现这个解决方案是不可接受的,因为它使代码几乎不可读(非常大的参数列表)并且与我重新分组变量的愿望不一致,
  2. Numba 类型字典(例如,参见这篇文章):我发现这个解决方案不可接受,因为我知道给定的字典只能包含类似类型的变量(例如 float64 的字典),而给定的系统可能有不同类型的相关变量。此外,我观察到使用此选项会显着降低性能(~ +10% 计算时间),
  3. Numbanamedtuples:相当容易实现和使用,但据我了解,这些只有在 Numba 函数中定义时才能有效使用,因此不能从非 jitted 函数/代码发送到 jitted 函数而不使其有效不可能从
    cache=True
    选项中受益。这对我来说是一个大问题,因为编译时间可能超过代码本身的执行时间。
  4. Numba @jitclass:我最初不愿意在我的代码中使用类,但事实证明它非常实用......但与
    namedtuples
    相同,如果来自
    @jitclass
    的对象在非内部初始化jitted 函数,我发现不可能从
    cache=True
    选项中受益(参见 this post)。

最终,这四种选择都不允许我做我想做的事。我可能错过了一些东西......

这就是我最终所做的:我结合使用常规 python 类和 Numba

@jitclass
,以保持从
cache=True
选项中受益的可能性。

这是我的mwe:

import numba as nb
from numba import jit
from numba.experimental import jitclass

spec_cls = [
    ('a', nb.types.float64),
    ('b', nb.types.float64),
]

# python class
class ClsWear_py(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b

# mirror Numba class
@jitclass(spec_cls)
class ClsWear(object):
    def __init__(self, a, b):
        self.a = a
        self.b = b
   
def function_python(obj):
    print('from the python class :', obj.a)
    # call of a Numba function => this is where I must list explicitly all the keys of the python class object
    oa, ob = function_numba(obj.a, obj.b)
    return obj, oa, ob

@jit(nopython=True)
def function_numba(oa, ob):
# at the beginning of the Numba function, the arguments are used to define the @jitclass object
    obj_nb = ClsWear(oa, ob)
    print('from the numba class :', obj_nb.a)
    return obj_nb.a, obj_nb.b

# main code :  
obj_py = ClsWear_py(11,22)

obj_rt, a, b = function_python(obj_py)

这段代码的输出是:

$ python mwe.py 
from the python class : 11
from the numba class : 11.0

好的一面:

  • 我在 python 和 Numba 中有一个干净的数据结构(使用类)
  • 我保持快速运行的代码并且
    cache=True
    正在工作

但也有缺点:

  • 我必须在 Numba 中定义 python 类及其镜像
  • 代码中仍然有一个几乎难以阅读的部分:第一次调用 jitted 函数,其中必须显式列出我的对象的所有内容

我错过了什么吗?有没有更明显的方法来做到这一点?

python class numba
1个回答
0
投票

根据您的研究(我同意),似乎没有一种直观的方法可以做到这一点。 所以我将建议最不可怕的解决方法。

想法是将其作为元组传递给 jit 函数,然后将其转换为 jit 函数中所需的类。

from typing import NamedTuple
from numba import njit


class Config(NamedTuple):
    a: int
    b: float
    c: str


@njit(cache=True)
def f(config_values):
    config = Config(*config_values)  # Convert to a named tuple.
    return config.a


def main():
    config = Config(1, 2.0, "3")
    f(tuple(config))  # Pass as a tuple.
    print("Cache:", f.stats)


main()

结果(第二次运行):

Cache: _CompileStats(cache_path=..., cache_hits=Counter({(Tuple((int64, float64, unicode_type)),): 1}), cache_misses=Counter())

如您所见,它已正确缓存为元组。

此解决方法的问题之一是命名元组是只读的。所以你不能修改字段。 在这种情况下,您可以使用 jitclass 执行相同的操作。 另请注意,您可以从其 Python 类对应项为 numba 创建镜像类,因为

@jitclass
是常规 Python 装饰器。

from numba import njit
from numba.experimental import jitclass


class Config:
    # These type hints are interpreted as the default specs for jitclass.
    # If you only use primitive types, this should be sufficient.
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        return self.a, self.b, self.c


JitConfig = jitclass(Config)
# _JitConfig = jitclass(specs)(Config) if you need to specify the specs.


@njit(cache=True)
def f(config_values):
    config = JitConfig(*config_values)  # Convert to a jitclass.
    return config.a


def main():
    # Since we use a tuple as an argument, it is not mandatory to use a jitclass here.
    config = Config(1, 2.0, "3")

    f(config.as_tuple())  # Pass as a tuple.
    print("Cache:", f.stats)


main()

您还可以使用

overload
完全隐藏 jitclass。 请注意,在下面的代码中,不再需要显式使用 jitclass。

from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass


class Config:
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        return self.a, self.b, self.c


_JitConfig = jitclass(Config)


@overload(Config, strict=False)
def overload_config_init(*args):
    def jit_config_init(*args):
        return _JitConfig(*args)

    return jit_config_init


@njit(cache=True)
def f(config_values):

    # Since it will be overloaded to a jitclass, you can use a Python class here.
    config = Config(*config_values)

    return config.a


def main():
    config = Config(1, 2.0, "3")

    f(config.as_tuple())
    print("Cache:", f.stats)


main()

至于性能,应该可以忽略不计。 这是基准:

import timeit
from typing import NamedTuple

from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass


class NamedTupleContainer(NamedTuple):
    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float


class JitclassContainer:
    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float

    def __init__(self, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
        self.a0 = a0
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.a4 = a4
        self.a5 = a5
        self.a6 = a6
        self.a7 = a7
        self.a8 = a8
        self.a9 = a9

    def as_tuple(self):
        return (
            self.a0,
            self.a1,
            self.a2,
            self.a3,
            self.a4,
            self.a5,
            self.a6,
            self.a7,
            self.a8,
            self.a9,
        )


_JitclassContainer = jitclass(JitclassContainer)


@overload(JitclassContainer)
def overload_container_init(*args):
    def jit_container_init(*args):
        return _JitclassContainer(*args)

    return jit_container_init


@njit(cache=True)
def f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
    return a0


@njit(cache=True)
def f_namedtuple(args):
    c = NamedTupleContainer(*args)
    return c.a0


@njit(cache=True)
def f_jitclass(args):
    c = JitclassContainer(*args)
    return c.a0


def main():
    def benchmark(f):
        n_runs = 10000
        return min(timeit.repeat(f, repeat=100, number=n_runs)) / n_runs

    values = 1, 2, 3, 4, 5, 6.0, 7.0, 8.0, 9.0, 10.0
    a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = values
    named_tuple_container = NamedTupleContainer(*values)
    jitclass_container = JitclassContainer(*values)

    t = benchmark(lambda: f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9))
    print(f"f_multi_args: {t * 10 ** 9:.0f} ns")

    t = benchmark(lambda: f_namedtuple(tuple(named_tuple_container)))
    print(f"f_namedtuple: {t * 10 ** 9:.0f} ns")

    t = benchmark(lambda: f_jitclass(jitclass_container.as_tuple()))
    print(f"f_jitclass  : {t * 10 ** 9:.0f} ns")


main()

结果:

f_multi_args: 525 ns
f_namedtuple: 695 ns
f_jitclass  : 652 ns

在我的 PC 上,每个具有 10 个参数/字段的函数调用的差异不到 200 纳秒。

© www.soinside.com 2019 - 2024. All rights reserved.