namedtuple 作为 Numba 函数的参数

问题描述 投票:0回答:1

我一直在努力有效地处理 Numba 函数的长参数列表。虽然我最初考虑使用字典来提高代码的可读性,但我意识到

namedtuples
可能是更好的选择。根据我在网上找到的信息以及我的测试:

  1. namedtuples
    比字典轻得多(在我的代码中,字典的使用使性能下降了大约10%)
  2. namedtuples
    更通用,因为它们可以组合不同类型的变量(int、float、数组...)
  3. namedtuples
    也使(至少这是我的感觉)代码更容易阅读

但是...似乎有一个问题:我注意到使用

namedtuples
作为 Numba 函数的参数使得无法激活
cache=True
选项。每次运行代码时,都会重新编译它,这会增加计算时间。

我感觉解决方案的一部分在于

namedtuples
类型的明确定义。我发现关于该主题的帖子细节很少。

我有两个问题:

  • 要从 Numba 中的
    namedtuples
    选项中受益,需要对
    cache=True
    类型进行明确定义吗?
  • 有没有简单的方法?
python-3.x arguments numba namedtuple
1个回答
0
投票

我在网上到处找到了一些信息,我相信分享我的结论可能会很有趣。首先回答我的问题:

  • cache=True
    用作 Numba 函数的参数时,似乎不可能从
    namedtuples
    中受益。

现在,按照我最初帖子的评论中的建议,我用

namedtuples
替换了我的
jitclass
。它工作得很好,最重要的是,我看不到任何性能损失。

但是,如果该类未嵌套在 Numba 函数中(请参阅这篇文章),则将其作为 Numba 函数中的参数传递将导致无法从

cache=True
中受益。幸运的是,嵌套它非常简单。顺便说一句,如果我没记错的话,不可能在 Numba 函数中定义
namedtuple
(嵌套它),这可能就是我最初观察到
cache=True
没有效果的原因。

最后,我创建了一个包含以下内容的classes.py 文件:

import numpy as np
from numba import jit, int32, float32, float64, int64    # import the types
from numba.experimental import jitclass

spec_abr = [
    ('E', float64),
    ('K', float64),
    ('sigma_y', float64),
    ('nb_abr_elem', int64),
    ('sharpness', float64),
    ('amplitude', float64),
    ('ep_aub', float64),
    ('ep_aub_ax', float64),
    ('ep_abr', float64),
    ('biseau', float64[:])
]

spec_sim = [
    ('dt', float64),
    ('t99', float64),
    ('mu', float64),
    ('nb_lobe', int64),
    ('largeur_lobe', float64),
    ('hauteur_lobe', float64),
    ('nb_points', int64),
    ('omega', float64)
]

spec_mtx = [
    ('Kl', float64[:,:]),
    ('Ml', float64[:,:]),
    ('Dl', float64[:,:])
]

@jitclass(spec_abr)
class abradable(object):
    def __init__(self, E, K, sigma_y, nb_abr_elem, sharpness, amplitude, ep_aub, ep_aub_ax, ep_abr):
        self.E = E
        self.K = K
        self.sigma_y = sigma_y
        self.nb_abr_elem = nb_abr_elem
        self.sharpness = sharpness
        self.amplitude = amplitude
        self.ep_aub = ep_aub
        self.ep_aub_ax = ep_aub_ax
        self.ep_abr = ep_abr

@jitclass(spec_sim)
class simulation(object):
    def __init__(self, dt, t99, mu, nb_lobe, largeur_lobe, hauteur_lobe, nb_points, omega):
        self.dt = dt
        self.t99 = t99
        self.mu = mu
        self.nb_lobe = nb_lobe
        self.largeur_lobe = largeur_lobe
        self.hauteur_lobe = hauteur_lobe
        self.nb_points = nb_points
        self.omega = omega

@jitclass(spec_mtx)
class matrices(object):
    def __init__(self, Kl, Ml, Dl):
        self.Kl = Kl
        self.Ml = Ml
        self.Dl = Dl

我将它们全部导入到需要使用它们的地方(在编写我的 Numba 函数的 python 文件中):

from src.utils.classes import *
...

我相应地定义了我的对象:

# classes
mat_abr = abradable(Eab, Kab, sigma_y, nb_circ_element, sharpness, amplitude, ep_aub, ep_aub_ax, ep_abr)
par_sim = simulation(dt, t99, mu, nb_lobe, width_lobe, height_lobe, nb_points, omega)
matrix = matrices(K, M, D)    
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.