如何将几个附加参数传递给 numba cfunc 作为 LowLevelCallable 传递给 scipy.integrate.quad

问题描述 投票:0回答:1

我正在尝试复制另一个问题所选答案的第 1 段中描述的技术:如何将附加参数传递给 numba cfunc 作为 LowLevelCallable 传递给 scipy.integrate.quad

但是,我不知道如何修改实现,使 xx[1] 是一个浮点数数组而不是唯一的浮点数。

python scipy numba
1个回答
2
投票

我通过将Jacques Gaudinhttps://stackoverflow.com/a/49732825/3925704中的代码修改为:

解决了这个问题
import numpy as np
import scipy.integrate as si
import numba
from numba import cfunc, carray
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable


def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)
    
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        values = carray(xx, n)
        return jitted_function(values)
    return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(args):
    t = args[0]
    a = args[1]
    return np.exp(-t/a) / t**2

def do_integrate(func, a):
    """
    Integrate the given function from 1.0 to +inf with additional argument a.
    """
    return si.quad(func, 1, np.inf, args=(a,))
© www.soinside.com 2019 - 2024. All rights reserved.