在一个项目中,我创建了一个类,我需要在这个新类和一个真实矩阵之间进行操作,所以我像这样重载了
__rmul__
函数
class foo(object):
aarg = 0
def __init__(self):
self.aarg = 1
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0
但是当我调用它时,结果并不是我所期望的
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C = A * R
输出:
0
0
0
1
0
2
该函数似乎被调用了 6 次,每个元素调用一次。
相反,
__mul__
功能非常有效
C = R * A
输出:
[[0 0]
[0 1]
[0 2]]
如果
A
不是数组,而只是列表的列表,则两者都可以正常工作
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
R = foo()
C = A * R
C = R * A
输出
[[0, 0], [0, 1], [0, 2]]
[[0, 0], [0, 1], [0, 2]]
我真的希望我的
__rmul__
函数也能在数组上工作(我原来的乘法函数是不可交换的)。我该如何解决?
该行为是预期的。
首先你必须了解像
x*y
这样的操作实际上是如何执行的。 python 解释器将first尝试计算x.__mul__(y)
。
如果此调用返回 NotImplemented
,它将 then 尝试计算 y.__rmul__(x)
。
除了,当y
是x
类型的真子类时,在这种情况下,解释器将首先考虑y.__rmul__(x)
,然后考虑x.__mul__(y)
。 (请参阅 __rmul__
的文档)
现在发生的事情是
numpy
根据他认为参数是标量还是数组来不同地对待参数。
处理数组时
*
进行逐个元素的乘法,而标量乘法将数组的所有条目乘以给定的标量。
在您的情况下,
foo()
被numpy视为标量,因此numpy将数组的所有元素乘以foo
。此外,由于 numpy 不知道类型 foo
它返回一个带有 dtype=object
的数组,所以返回的对象是:
array([[0, 0],
[0, 0],
[0, 0]], dtype=object)
注意:当您尝试计算乘积时,
numpy
的数组不会返回NotImplemented
,因此解释器调用numpy的数组
__mul__
方法,该方法执行我们所说的标量乘法。此时 numpy 将尝试将数组的每个条目乘以“标量”
foo()
,这就是调用
__rmul__
方法的地方,因为当数组中的数字
NotImplemented
为 时,数组中的数字将返回
__mul__
使用
foo
参数调用。显然,如果您将参数的顺序更改为初始乘法,您的
__mul__
方法将立即被调用,并且您不会遇到任何麻烦。因此,要回答您的问题,处理此问题的一种方法是让
foo
继承自
ndarray
,以便应用第二条规则:
class foo(np.ndarray):
def __new__(cls):
# you must implement __new__
# code as before
但是警告子类化ndarray
并不简单。 此外,您可能还会产生其他副作用,因为现在您的班级是
ndarray
。
__numpy_ufunc__
函数。它甚至可以在无需子类化
np.ndarray
的情况下工作。您可以在此处找到文档。 这是一个基于您的案例的示例:
class foo(object):
aarg = 0
def __init__(self):
self.aarg = 1
def __numpy_ufunc__(self, *args):
pass
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0
如果我们尝试一下,
A = [[i*j for i in np.arange(2) ] for j in np.arange(3)]
A = np.array(A)
R = foo()
C = A * R
输出:
[[0 0]
[0 1]
[0 2]]
有效!
您可以通过定义
__array_priority__
来强制 numpy 使用您的评估方法。正如 numpy 文档中here所解释的那样。 在您的情况下,您必须将类定义更改为:
MAGIC_NUMBER = 15.0
# for the necessary lowest values of MAGIC_NUMBER look into the numpy docs
class foo(object):
__array_priority__ = MAGIC_NUMBER
aarg = 0
def __init__(self):
self.aarg = 1
def __rmul__(self,A):
print(A)
return 0
def __mul__(self,A):
print(A)
return 0