I'm trying to further speed up some Python code compiled with Numba. Looking at the assembly Numba generates, I noticed that double-precision operations are being emitted, which surprised me since both the inputs and the output should be float32.
I declare the variable/array types as float32 outside the jitted loop and pass them into the function. Strangely, after running the tests I find that the variable `scalarout` has been converted to a Python float, which is effectively a 64-bit value.
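For reference, in plain NumPy (outside Numba) float32 scalar arithmetic stays float32, and `isinstance` shows that `np.float64` is a subclass of the Python `float` type while `np.float32` is not, which is a quick way to check which scalar type you actually have:

```python
import numpy as np

s = np.float32(0.0)
s += np.float32(1.5) * np.float32(2.0)  # float32 * float32 stays float32
print(type(s))                          # <class 'numpy.float32'>

# np.float64 subclasses the Python float type; np.float32 does not
print(isinstance(np.float64(1.0), float))  # True
print(isinstance(np.float32(1.0), float))  # False
```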
My code:
from scipy import ndimage, misc
import matplotlib.pyplot as plt
import numpy.fft
from timeit import default_timer as timer
import numba
# numba.config.DUMP_ASSEMBLY = 1
from numba import float32
from numba import jit, njit, prange
from numba import cuda
import numpy as np
import scipy as sp
# import llvmlite.binding as llvm
# llvm.set_option('', '--debug-only=loop-vectorize')
@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen, scalarout):
    scalarout = np.float32(0.0)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout += a[y, x] * b[y, x]
    return scalarout
# ======================================== TESTS ========================================
print()
xlen = 100000
ylen = 16
a = np.random.rand(ylen, xlen).astype(np.float32)
b = np.random.rand(ylen, xlen).astype(np.float32)
print("a type = ", type(a[1,1]))
scalarout = np.float32(0.0)
print("scalarout type, before execution = ", type(scalarout))
iters=1000
time = 100.0
for n in range(iters):
    start = timer()
    scalarout = mydot(a, b, xlen, ylen, scalarout)
    end = timer()
    if end - start < time:
        time = end - start
print("Numba njit function time, in us = %16.10f" % ((end - start) * 10**6))
print("function output = %f" % scalarout)
print("scalarout type, after execution = ", type(scalarout))
This is more of an extended comment than an answer. If you change `scalarout` to a length-1 float32 array and modify it in place, the output is float32:
@njit(fastmath=True, parallel=False)
def mydot(a, b, xlen, ylen):
    scalarout = np.array([0.0], dtype=np.float32)
    for y in prange(ylen):
        for x in prange(xlen):
            scalarout[0] += a[y, x] * b[y, x]
    return scalarout
If you change `return scalarout` to `return scalarout[0]`, the output is a Python float again. And in your original `mydot`, the result is a Python float even if you write `return np.float32(scalarout)`.
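As a dtype-preserving reference point (a sketch in plain NumPy, independent of whichever Numba version you are on), NumPy's own dot product keeps float32 all the way through, so it can serve as a type and correctness baseline for the jitted kernel:

```python
import numpy as np

a = np.random.rand(16, 1000).astype(np.float32)
b = np.random.rand(16, 1000).astype(np.float32)

# np.dot on two float32 arrays accumulates and returns in float32
ref = np.dot(a.ravel(), b.ravel())
print(type(ref), ref.dtype)  # <class 'numpy.float32'> float32
```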