我正在尝试扩展numpy的“ tensordot”,例如:K_ijklm = A_ki * B_jml
可以这样清晰地写:K = mytensordot(A,B,[2,0],[1,4,3])
据我了解,numpy的张量点(带有可选参数0)将能够执行以下操作:K_kijml = A_ki * B_jml
,即保持索引的顺序。因此,我将不得不做很多np.swapaxes()
来获得矩阵“ K_ijklm”,在复杂的情况下,矩阵很容易成为错误的来源(可能很难调试)。
问题是我的实现速度很慢(比tensordot慢10倍[编辑:实际上比tensordot慢得多]),即使使用numba时也是如此。我想知道是否有人会对提高我的算法性能可以采取的措施有所了解。
import numpy as np
import numba as nb
import itertools
import timeit
@nb.jit()
def myproduct(dimN):
N=np.prod(dimN)
L=len(dimN)
Product=np.zeros((N,L),dtype=np.int32)
rn=0
for n in range(1,N):
for l in range(L):
if l==0:
rn=1
v=Product[n-1,L-1-l]+rn
rn = 0
if v == dimN[L-1-l]:
v = 0
rn = 1
Product[n,L-1-l]=v
return Product
@nb.jit()
def mytensordot(A,B,iA,iB):
iA,iB = np.array(iA,dtype=np.int32),np.array(iB,dtype=np.int32)
dimA,dimB = A.shape,B.shape
NdimA,NdimB=len(dimA),len(dimB)
if len(iA) != NdimA: raise ValueError("iA must be same size as dim A")
if len(iB) != NdimB: raise ValueError("iB must be same size as dim B")
NdimN = NdimA + NdimB
dimN=np.zeros(NdimN,dtype=np.int32)
dimN[iA]=dimA
dimN[iB]=dimB
Out=np.zeros(dimN)
indexes = myproduct(dimN)
for nidxs in indexes:
idxA = tuple(nidxs[iA])
idxB = tuple(nidxs[iB])
v=A[(idxA)]*B[(idxB)]
Out[tuple(nidxs)]=v
return Out
A=np.random.random((4,5,3))
B=np.random.random((6,4))
def runmytdot():
return mytensordot(A,B,[0,2,3],[1,4])
def runtensdot():
return np.tensordot(A,B,0).swapaxes(1,3).swapaxes(2,3)
print(np.all(runmytdot()==runtensdot()))
print(timeit.timeit(runmytdot,number=100))
print(timeit.timeit(runtensdot,number=100))
True
1.4962144780438393
0.003484356915578246
tensordot
(带标量轴的值)可能不明显。我在
How does numpy.tensordot function works step-by-step?
据我推断,np.tensordot(A, B, axes=0)
与axes=[[], []]
是等效的。
In [757]: A=np.random.random((4,5,3))
...: B=np.random.random((6,4))
In [758]: np.tensordot(A,B,0).shape
Out[758]: (4, 5, 3, 6, 4)
In [759]: np.tensordot(A,B,[[],[]]).shape
Out[759]: (4, 5, 3, 6, 4)
这等效于以新的大小为1的产品和维度调用dot
:
In [762]: np.dot(A[...,None],B[...,None,:]).shape
Out[762]: (4, 5, 3, 6, 4)
(4,5,3,1) * (6,1,4) # the 1 is the last of A and 2nd to the last of B
dot
快速,使用BLAS(或等效)代码。交换轴和重塑也相对较快。
einsum
使我们对轴有很多控制权
复制上述产品:
In [768]: np.einsum('jml,ki->jmlki',A,B).shape
Out[768]: (4, 5, 3, 6, 4)
和交换:
In [769]: np.einsum('jml,ki->ijklm',A,B).shape
Out[769]: (4, 4, 6, 3, 5)
次要点-双交换可以写为一个转置:
.swapaxes(1,3).swapaxes(2,3)
.transpose(0,3,1,2,4)
您已经遇到a known issue。创建多维数组时为numpy.zeros
requires a tuple。如果您传递的不是元组,则有时可以使用,但这仅是因为numpy
聪明地将对象首先转换为元组。
麻烦是numba
当前不支持conversion of arbitrary iterables into tuples。因此,当您尝试以nopython=True
模式进行编译时,此行将失败。 (其他几个也失败了,但这是第一个。)
Out=np.zeros(dimN)
从理论上讲,您可以调用np.prod(dimN)
,创建零的平面数组,然后对其进行整形,但是随后您遇到了同样的问题:reshape
数组的numpy
方法需要一个元组!
numba
这是一个非常令人头疼的问题-我以前从未遇到过。我真的怀疑我找到的解决方案是正确的,但这是一个可行的解决方案,它允许我们以nopython=True
模式编译版本。
核心思想是通过直接实现跟随数组strides
的索引器来避免使用元组进行索引:
@nb.jit(nopython=True)
def index_arr(a, ix_arr):
strides = np.array(a.strides) / a.itemsize
ix = int((ix_arr * strides).sum())
return a.ravel()[ix]
@nb.jit(nopython=True)
def index_set_arr(a, ix_arr, val):
strides = np.array(a.strides) / a.itemsize
ix = int((ix_arr * strides).sum())
a.ravel()[ix] = val
这使我们无需元组即可获取和设置值。
[我们还可以通过将输出缓冲区传递到jitted函数中并将该函数包装在帮助程序中来避免使用reshape
:
@nb.jit() # We can't use nopython mode here...
def mytensordot(A, B, iA, iB):
iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
dimA, dimB = A.shape, B.shape
NdimA, NdimB = len(dimA), len(dimB)
if len(iA) != NdimA:
raise ValueError("iA must be same size as dim A")
if len(iB) != NdimB:
raise ValueError("iB must be same size as dim B")
NdimN = NdimA + NdimB
dimN = np.zeros(NdimN, dtype=np.int32)
dimN[iA] = dimA
dimN[iB] = dimB
Out = np.zeros(dimN)
return mytensordot_jit(A, B, iA, iB, dimN, Out)
由于辅助程序不包含循环,因此增加了一些开销,但是开销非常小。这是最后的固定功能:
@nb.jit(nopython=True)
def mytensordot_jit(A, B, iA, iB, dimN, Out):
for i in range(np.prod(dimN)):
nidxs = int_to_idx(i, dimN)
a = index_arr(A, nidxs[iA])
b = index_arr(B, nidxs[iB])
index_set_arr(Out, nidxs, a * b)
return Out
不幸的是,这并没有像我们希望的那样产生尽可能多的加速。在较小的阵列上,它比tensordot
慢大约5倍;在较大的阵列上,速度仍然慢50倍。 (但至少不慢1000倍!)回想起来,这并不奇怪,因为dot
和tensordot
都在幕后使用BLAS,如@hpaulj reminds us。
完成此代码后,我看到einsum
解决了您的实际问题-很好!
但是您最初的问题所指向的潜在问题-在jitted代码中不可能使用任意长度的元组进行索引-仍然令人沮丧。因此希望这对其他人有用!