我正在学习如何使用Numba来加速jit和vectorize的功能。我对此代码的jit版本没有任何问题,但我在使用vectorize时遇到索引错误。我怀疑这个question's的答案是正确的想法,有一个类型错误,但我不知道改变索引的方向。下面包括我一直在玩的函数,它将斐波纳契数输出到序列的选定索引。索引出了什么问题,以及如何纠正我的代码以解决它?
from numba import vectorize
import numpy as np
from timeit import timeit
@vectorize
def fib(n):
'''
Adjusted from:
https://lectures.quantecon.org/py/numba.html
https://en.wikipedia.org/wiki/Fibonacci_number
https://www.geeksforgeeks.org/program-for-nth-fibonacci-number/
'''
if n == 1:
return np.ones(1)
elif n > 1:
x = np.empty(n)
x[0] = 1
x[1] = 1
for i in range(2,n):
x[i] = x[i-1] + x[i-2]
return x
else:
print('WARNING: Check validity of input.')
print(timeit('fib(10)', globals={'fib':fib}))
这导致以下错误输出。
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/instructions.py", line 619, in __init__
typ = typ.elements[i]
AttributeError: 'PointerType' object has no attribute 'elements'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/galen/Projects/myjekyllblog/test_code/quantecon_2.py", line 27, in <module>
print(timeit('fib(10)', globals={'fib':fib}))
File "/usr/lib/python3.6/timeit.py", line 233, in timeit
return Timer(stmt, setup, timer, globals).timeit(number)
File "/usr/lib/python3.6/timeit.py", line 178, in timeit
timing = self.inner(it, self.timer)
File "<timeit-src>", line 6, in inner
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/dufunc.py", line 166, in _compile_for_args
return self._compile_for_argtys(tuple(argtys))
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/dufunc.py", line 188, in _compile_for_argtys
cres, actual_sig)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/ufuncbuilder.py", line 157, in _build_element_wise_ufunc_wrapper
cres.objectmode, cres)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 220, in build_ufunc_wrapper
env=envptr)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 130, in build_fast_loop_body
env=env)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 23, in _build_ufunc_loop_body
store(retval)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 126, in store
out.store_aligned(retval, ind)
File "/usr/local/lib/python3.6/dist-packages/numba/npyufunc/wrappers.py", line 276, in store_aligned
self.context.pack_value(self.builder, self.fe_type, value, ptr)
File "/usr/local/lib/python3.6/dist-packages/numba/targets/base.py", line 482, in pack_value
dataval = self.data_model_manager[ty].as_data(builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 558, in as_data
elems = self._as("as_data", builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 530, in _as
self.get(builder, value, i)))
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 558, in as_data
elems = self._as("as_data", builder, value)
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 530, in _as
self.get(builder, value, i)))
File "/usr/local/lib/python3.6/dist-packages/numba/datamodel/models.py", line 624, in get
name="extracted." + self._fields[pos])
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/builder.py", line 911, in extract_value
instr = instructions.ExtractValue(self.block, agg, idx, name=name)
File "/usr/local/lib/python3.6/dist-packages/llvmlite/ir/instructions.py", line 622, in __init__
% (list(indices), agg.type))
TypeError: Can't index at [0] in i8*
错误是因为你试图vectorize
一个你可以说基本上不可矢量化的函数。我认为你混淆了@jit
和@vectorize
如何工作的功能。为了加速你的功能,你使用@jit
,而@vectorize
用于创建numpy通用函数。见official documentation here:
使用vectorize(),您可以将函数编写为通过输入标量而不是数组进行操作。 Numba将生成周围的循环(或内核),允许对实际输入进行有效迭代。
所以基本上不可能创建一个与你的斐波那契函数具有相同功能的numpy通用函数。如果您有兴趣,这里是official documentation on universal functions的链接。
因此,为了使用@vectorize
,您需要创建一个可以基本上用作numpy通用函数的函数。出于加速代码的目的,您只需使用@jit
。