Numpy“:”运营商广播问题

问题描述 投票:2回答:4

在下面的代码中,我写了两个方法,理论上(在我看来)应该做同样的事情。不幸的是他们没有,我无法找出他们为什么不按照numpy文档做同样的事情。

import numpy as np


dW = np.zeros((20, 10))
y = [1 for _ in range(100)]
X =  np.ones((100, 20))

# ===================
# Method 1  (works!)
# ===================
for i in range(len(y)):
  dW[:, y[i]] -=  X[i]


# ===================
# Method 2 (does not work)
# ===================
dW[:, y] -=  X.T
python numpy numpy-broadcasting
4个回答
2
投票

如上所述,原则上,由于NumPy中的缓冲工作方式,您无法在单个操作中多次对同一元素进行操作。为此目的,有at函数,可用于任何标准的NumPy函数(addsubtract等)。对于您的情况,您可以:

import numpy as np

dW = np.zeros((20, 10))
y = [1 for _ in range(100)]
X =  np.ones((100, 20))
# at modifies in place dW, does not return a new array
np.subtract.at(dW, (slice(None), y), X.T)

2
投票

这是this问题的列式版本。

那里的答案可以适用于列式工作如下:

方法1:np.<ufunc>.at

>>> np.subtract.at(dW, (slice(None), y), X.T)

方法2:np.bincount

>>> m, n = dW.shape
>>> dW -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)

请注意,基于bincount的解决方案 - 即使涉及更多步骤 - 速度提高了约6倍。

>>> from timeit import repeat
>>> kwds = dict(globals=globals(), number=5000)
>>>
>>> repeat('np.subtract.at(dW, (slice(None), y), X.T); np.add.at(dW, (slice(None), y), X.T)', **kwds)
[1.590626839082688, 1.5769231889862567, 1.5802007300080732]
>>> repeat('_= dW; _ -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n); _ += np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)', **kwds)
[0.2582490430213511, 0.25572817400097847, 0.25478115503210574]

1
投票

选项1:

for i in range(len(y)):
  dW[:, y[i]] -=  X[i]

这是有效的,因为您正在循环并更新上次更新的值。

选项2:

dW[:, [1,1,1,1,....1,1,1]] -=  [[1,1,1,1...1],
                                [1,1,1,1...1],
                                .
                                .
                                [1,1,1,1...1]]

它不起作用,因为更新同时发生在第一个索引并行而非串行方式。最初都是0,所以在-1中减去结果。


1
投票

我找到了第三个解决这个问题的方法。正常矩阵乘法:

ind = np.zeros((X.shape[0],dW.shape[1]))
ind[range(X.shape[0]),y] = -1
dW = X.T.dot(ind)

我在一些神经网络数据上使用上面提出的方法做了一些实验。在我的例子中X.shape = (500,3073)W.shape = (3073,10)ind.shape = (500,10)

减法版本大约需要0.2秒(最慢)。矩阵乘法方法0.01秒(最快)。正常循环0.015然后bincountmethod 0.04 s。请注意,问题y是一个向量。这不是我的情况。只有一个的情况可以通过一个简单的总和来解决。

© www.soinside.com 2019 - 2024. All rights reserved.