为什么我使用 numpy 进行矩阵乘法这么慢?

问题描述 投票:0回答:1

我正在尝试将 numpy 中两个维度相当大的矩阵相乘。 请参阅以下 3 种方法。我随机实现了 3 个矩阵来展示我的问题。第一个矩阵,即

Y1[:,:,0]
最初是更大的 3d 数组的一部分。第二个是该矩阵的
.copy()
,第三个是它自己的矩阵。

为什么第一个乘法比后两个乘法慢得多?

import numpy as np
from time import time

Y1 = np.random.uniform(-1, 1, (5000, 1093, 201))
Y2 = Y1[:,:,0].copy()
Y3 = np.random.uniform(-1, 1, (5000, 1093))

W = np.random.uniform(-1, 1, (1093, 30))

# method 1
START = time()
Y1[:,:,0].dot(W)
END = time()
print(f"Method 1 : {END - START}")

# method 2
START = time()
Y2.dot(W)
END = time()
print(f"Method 2 : {END - START}")

# method 3
START = time()
Y3.dot(W)
END = time()
print(f"Method 3 : {END - START}")

输出时间分别约为34、0.06、0.06秒。

我看到了区别:最后两个矩阵是“真正的”2d 数组,而第一个矩阵是更大的 3d 数组的一部分。

子集化

Y1[:,:,0]
是什么让它这么慢?另外,我注意到为矩阵 Y2 创建 Y1 的副本也非常慢。

毕竟,我得到了这个 3d 数组,并且必须重复计算 Y1 切片与(可能不同的)矩阵 W 的矩阵乘积。有没有更好/更快的方法来做到这一点?

提前致谢!

numpy matrix-multiplication
1个回答
0
投票

如果您想比较各种方法的性能,那么您需要原子地考虑操作。我会考虑两种情况:

  1. 切片复制点
Y2 = Y1[:,:,0]
# get slice time
Y2 = Y2.copy()
# get copy time
Y2 = Y2.dot(W)
# get dot time
  1. 切片点
Y2 = Y1[:,:,0]
# get slice time
Y2 = Y2.dot(W)
# get dot time

这会告诉你时间是花在切片、复制还是点上。我怀疑复制是大型数组中最昂贵的部分。 它还会告诉您

dot
在整个数组和切片上的表现是否不同。

一旦您知道瓶颈在哪里,您就可以查明您的问题,以使 that 部分更快。

© www.soinside.com 2019 - 2024. All rights reserved.