如何修复 numba-scipy 以与 scipy.special 一起使用

问题描述 投票:0回答:1

我正在尝试编写一个涉及非常大的矩阵乘法和 for 循环的模拟。为了加快这个过程,我考虑使用 numba-scipy。使用 conda 安装软件包后,我尝试运行以下代码

import scipy.special as sp
from numba import jit
@jit(nopython=True)
def Grh_gtr(J2,t1,t2):
t = t1-t2
if t == 0:
    return -1j/2
else:
    return (sp.jv(1,J2*t)-1j*sp.struve(1,J2*t))/(2j*J2*t)

我得到了以下错误

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<ufunc 'jv'>) found for signature:
 
 >>> jv(Literal[int](1), int64)

我也尝试使用 pip 安装

numba-special
包,但它给了我以下错误

ERROR: Failed building wheel for numba-special
  Running setup.py clean for numba-special
Failed to build numba-special
ERROR: ERROR: Failed to build installable wheels for some pyproject.toml based projects (numba-special)

非常感谢任何帮助。

python python-3.x scipy numba
1个回答
0
投票

假设您已成功安装 numba-scipy。

第一个错误抱怨类型不匹配。

根据文档,jv支持的签名如下。

Supported signature(s): float64(float64,float64)

另一方面,您对 jv 的使用解释如下。

 >>> jv(Literal[int](1), int64)

第一个参数被解释为文字 int,第二个参数被解释为 int64 变量。 因此,一个简单的解决方法是将两者都转换为浮动。

@jit(nopython=True)
def Grh_gtr(J2, t1, t2):
    t = t1 - t2
    if t == 0:
        return -1j / 2
    return (sp.jv(float(1), float(J2 * t)) - 1j * sp.struve(float(1), float(J2 * t))) / (2j * J2 * t)

请注意,从 int 到 float 的转换很便宜,但它不是免费的。 通过将其设为变量来最小化它可能会更好。

import scipy.special as sp
from numba import jit

@jit(nopython=True)
def Grh_gtr(J2, t1, t2):
    t = t1 - t2
    if t == 0:
        return -1j / 2
    one = float(1.0)
    j2t = float(J2 * t)
    return (sp.jv(one, j2t) - 1j * sp.struve(one, j2t)) / (2j * j2t)


print(Grh_gtr(1.0, 2.0, 1.0))  # Also better to call with the float arguments if possible.
© www.soinside.com 2019 - 2024. All rights reserved.