数组作为numba中的函数参数

问题描述 投票:4回答:1

以下简单示例失败并显示错误:

独立模块:

from numba.pycc import CC

cc = CC('foo')

@cc.export('product','float64(float64[:], float64[:])')
def product(a, b):
    prod = 0
    for i in range(a.size):
        prod += a[i] * b[i]
    return prod

if __name__ == "__main__":
    cc.compile()

测试程序:

import numpy as np
import foo

x = np.array([2,3,1,0])
y = np.array([2,3,1,0])

print(foo.product(x,y))

失败并显示错误消息:

Traceback (most recent call last):
  File "\temp\test.py", line 7, in <module>
    print(foo.product(x,y))
SystemError: exception RuntimeError<class 'BytesWarning'> not a BaseException subclass

在Windows上使用的numba版本是0.42.0和Python 3.7.2。任何提示?

python numba
1个回答
0
投票

所以,我终于让你的代码工作了:

from numba.pycc import CC

cc = CC('foo')
cc.verbose = True
@cc.export('producti','int64(int64[:], int64[:])')  #<--- Your data type was wrong
def product(a, b):
    prod = 0
    for i in range(a.size):
        y = a[i] * b[i]
        prod += y
    return prod

if __name__ == "__main__":
    cc.compile()

用于测试上述功能的代码:

import numpy as np
import foo

x = np.array([2, 3, 1, 0])
y = np.array([2, 3, 1, 0])

print(foo.producti(x, y))   # Output : 14

这里要注意一些要点:

  • 你创建xy数组的方式,dtype默认设置为int64,所以当你将它转换为float64时,它被错误地转换。

打印(x.dtype)

输出:dtype('int64')

  • 所以,简单地将你的类型修复到int64就可以了(或者你可以使用i8作为速记)。
  • 链接到谷歌Colab笔记本与运行代码:Notebook Link

参考文献:

© www.soinside.com 2019 - 2024. All rights reserved.