我一直在努力有效地处理 Numba 函数的长参数列表。虽然我最初考虑使用字典来提高代码的可读性,但我意识到
namedtuples
可能是更好的选择。根据我在网上找到的信息以及我的测试:
namedtuples
比字典轻得多(在我的代码中,字典的使用使性能下降了大约10%)namedtuples
更通用,因为它们可以组合不同类型的变量(int、float、数组...)namedtuples
也使(至少这是我的感觉)代码更容易阅读但是...似乎有一个问题:我注意到使用
namedtuples
作为 Numba 函数的参数使得无法激活 cache=True
选项。每次运行代码时,都会重新编译它,这会增加计算时间。
我感觉解决方案的一部分在于
namedtuples
类型的明确定义。我发现关于该主题的帖子和细节很少。
我有两个问题:
namedtuples
选项中受益,需要对 cache=True
类型进行明确定义吗?我在网上到处找到了一些信息,我相信分享我的结论可能会很有趣。首先回答我的问题:
cache=True
用作 Numba 函数的参数时,似乎不可能从 namedtuples
中受益。现在,按照我最初帖子的评论中的建议,我用
namedtuples
替换了我的 jitclass
。它工作得很好,最重要的是,我看不到任何性能损失。
但是,如果该类未嵌套在 Numba 函数中(请参阅这篇文章),则将其作为 Numba 函数中的参数传递将导致无法从
cache=True
中受益。幸运的是,嵌套它非常简单。顺便说一句,如果我没记错的话,不可能在 Numba 函数中定义 namedtuple
(嵌套它),这可能就是我最初观察到 cache=True
没有效果的原因。
最后,我创建了一个包含以下内容的classes.py 文件:
import numpy as np
from numba import jit, int32, float32, float64, int64 # import the types
from numba.experimental import jitclass
spec_abr = [
('E', float64),
('K', float64),
('sigma_y', float64),
('nb_abr_elem', int64),
('sharpness', float64),
('amplitude', float64),
('ep_aub', float64),
('ep_aub_ax', float64),
('ep_abr', float64),
('biseau', float64[:])
]
spec_sim = [
('dt', float64),
('t99', float64),
('mu', float64),
('nb_lobe', int64),
('largeur_lobe', float64),
('hauteur_lobe', float64),
('nb_points', int64),
('omega', float64)
]
spec_mtx = [
('Kl', float64[:,:]),
('Ml', float64[:,:]),
('Dl', float64[:,:])
]
@jitclass(spec_abr)
class abradable(object):
def __init__(self, E, K, sigma_y, nb_abr_elem, sharpness, amplitude, ep_aub, ep_aub_ax, ep_abr):
self.E = E
self.K = K
self.sigma_y = sigma_y
self.nb_abr_elem = nb_abr_elem
self.sharpness = sharpness
self.amplitude = amplitude
self.ep_aub = ep_aub
self.ep_aub_ax = ep_aub_ax
self.ep_abr = ep_abr
@jitclass(spec_sim)
class simulation(object):
def __init__(self, dt, t99, mu, nb_lobe, largeur_lobe, hauteur_lobe, nb_points, omega):
self.dt = dt
self.t99 = t99
self.mu = mu
self.nb_lobe = nb_lobe
self.largeur_lobe = largeur_lobe
self.hauteur_lobe = hauteur_lobe
self.nb_points = nb_points
self.omega = omega
@jitclass(spec_mtx)
class matrices(object):
def __init__(self, Kl, Ml, Dl):
self.Kl = Kl
self.Ml = Ml
self.Dl = Dl
我将它们全部导入到需要使用它们的地方(在编写我的 Numba 函数的 python 文件中):
from src.utils.classes import *
...
我相应地定义了我的对象:
# classes
mat_abr = abradable(Eab, Kab, sigma_y, nb_circ_element, sharpness, amplitude, ep_aub, ep_aub_ax, ep_abr)
par_sim = simulation(dt, t99, mu, nb_lobe, width_lobe, height_lobe, nb_points, omega)
matrix = matrices(K, M, D)