我正在使用内置的Tensorflow数据集模块读取一批MNIST数据。这给了一个numpy数组作为批处理。但是,如果我将数组复制到另一个变量并对第二个变量进行更改,则原始批处理数组也会更改。我怀疑为什么原始数组和复制的数组之间存在任何连接。
您可以在此CoLab链接上进行测试:
https://colab.research.google.com/drive/1DN4n5_YCO33LozxtidM7STqEAUWypNOv
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def test_reconstruction(mnist, h=28, w=28, batch_size=100):
# Test the trained model: reconstruction
batch = mnist.test.next_batch(batch_size)
batch_clean = batch[0]
print('before damage:', np.mean(batch_clean))
batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))
tmp = batch_damaged
tmp[:, 10:20, 10:20] = 0
print('after damage:', np.mean(batch_clean))
test_reconstruction(mnist)
预期:两个打印语句都应返回相同的平均值
实际:我得到两个打印语句的不同平均值
在您的行batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))
中,您复制batch_clean的引用而不是其值。您应该使用numpy.copy
返回数组的副本。 batch_damaged = np.copy(np.reshape(batch_clean, (batch_size, 28, 28)))