更改np数组不会自动更改Torch Tensor?

问题描述 投票:6回答:1

我正在阅读PyTorch的基本教程,并遇到了NumPy数组和Torch张量之间的转换。文件说:

Torch Tensor和NumPy阵列将共享其底层内存位置,更改一个将改变另一个。

但是,在以下代码中似乎不是这种情况:

import numpy as np

a = np.ones((3,3))
b = torch.from_numpy(a)

np.add(a,1,out=a)
print(a)
print(b)

在上面的例子中,我看到输出中自动反映的变化:

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], dtype=torch.float64)

但是当我写这样的东西时,同样不会发生:

a = np.ones((3,3))
b = torch.from_numpy(a)

a = a + 1
print(a)
print(b)

我得到以下输出:

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

我在这里错过了什么?

python numpy pytorch torch tensor
1个回答
4
投票

每当你用Python编写=符号时,你就会创建一个新对象。

因此,在第二种情况下表达式的右侧使用原始a,然后计算一个新对象,即a + 1,它取代了原始a。 b仍然指向原始a的内存位置,但现在指向内存中的新对象。

换句话说,在a = a + 1中,表达式a + 1创建一个新对象,然后Python将该新对象分配给名称a

然而,使用a += 1,Python使用参数1调用a的就地添加方法(__iadd__)。

numpy代码:np.add(a,1,out=a),在第一种情况下负责将该值添加到现有数组中。

(感谢@Engineero@Warren Weckesser在评论中指出这些解释)

© www.soinside.com 2019 - 2024. All rights reserved.