如何复制
numba jitclass
实例(其变量是 numpy
标量和数组)?
除了创建一个新实例并在循环中复制所有变量之外,还有更好的方法吗?我也尝试了
copy.copy
和copy.deepcopy
,但都失败了TypeError: can't pickle MyJitClassName objects
我认为原则上这是
numba
可以支持的东西(考虑打开一个问题),但现在我认为唯一的选择是定义你自己的。
请注意,jitclasses 保存对数组的引用,因此如果您想复制基础数据,则需要
array.copy()
。
from numba import jitclass, float64
spec = [
('scalar', float64),
('array', float64[:]),
]
@jitclass(spec)
class MyJitClass:
def __init__(self, scalar, array):
self.scalar = scalar
self.array = array
def copy(self):
return MyJitClass(self.scalar, self.array.copy())
根据 chrisb 的答案和 nivniv 的评论,这是一种创建复制函数的方法,该函数循环遍历规范中的所有变量而无需重新输入它们的名称,并使用 copy.deepcopy() 来避免检查类变量是否是一个数组。
from numba import float64, jit, objmode
from numba.experimental import jitclass
import copy
spec = [
('scalar', float64),
('array', float64[:])
]
@jitclass(spec)
class MyJitClass:
def __init__(self, scalar, array):
self.scalar = scalar
self.array = array
def copy(self, jc_new):
with objmode():
for i in range(len(spec)):
setattr(self,spec[i][0],copy.deepcopy(getattr(jc_new,spec[i][0])))