我有一个 4 维的 ndarray 结构:(K,N,2,2)。您可以将其想象为 K 个不同的堆栈,每个堆栈包含 N 个维度为 2x2 的矩阵。 对于每个堆栈,我尝试计算其 N 个矩阵的矩阵乘积(没有 For 循环)。所以最后,我应该有 K 个 2x2 维度的矩阵(N 个 2x2 矩阵的乘积仍然是 2x2 矩阵)。
如果我的数组的维度为 (N,2,2),则使用以下函数沿第一个轴执行矩阵乘积非常简单:
A_total = np.linalg.multi_dot(A) # A being a (N,2,2) array
但是对于 (K,N,2,2) 结构,我无法在忽略第一个轴的情况下沿第二个轴执行相同的操作。
您对此有何建议?我尝试过 np.einsum() 和 np.tensordot() 但不太明白如何正确使用这些函数。
让我们尝试一下小 K=3、N=5 的 2 个明显替代方案
In [95]: A = np.random.rand(3,5,2,2)
首先迭代 K,堆栈:
In [96]: res = np.array([np.linalg.multi_dot(A[i,...]) for i in range(A.shape[0])])
In [97]: res.shape
Out[97]: (3, 2, 2)
现在链上N:
In [98]: res1=A[:,0,:,:].copy()
...: for j in range(1,A.shape[1]):
...: res1 = res1@A[:,j,:,:]
...:
他们匹配:
In [99]: np.allclose(res, res1)
Out[99]: True
迭代链的时机要好得多:
In [100]: timeit res = np.array([np.linalg.multi_dot(A[i,...]) for i in range(A.shape[0])])
329 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [101]: %%timeit
...: res1=A[:,0,:,:].copy()
...: for j in range(1,A.shape[1]):
...: res1 = res1@A[:,j,:,:]
...:
33.7 µs ± 767 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
因此,除非 K/N 比率非常不同,否则我会坚持重复的
matmul
。 matmul
在(3+)的引导尺寸上是线性的,但该迭代是在编译代码中。