Tensorflow MNIST数据加载器提供互连的numpy数组

问题描述 投票:0回答:1

我正在使用内置的Tensorflow数据集模块读取一批MNIST数据。这给了一个numpy数组作为批处理。但是,如果我将数组复制到另一个变量并对第二个变量进行更改,则原始批处理数组也会更改。我怀疑为什么原始数组和复制的数组之间存在任何连接。

您可以在此CoLab链接上进行测试:

https://colab.research.google.com/drive/1DN4n5_YCO33LozxtidM7STqEAUWypNOv

from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def test_reconstruction(mnist, h=28, w=28, batch_size=100):
    # Test the trained model: reconstruction
    batch = mnist.test.next_batch(batch_size)
    batch_clean = batch[0]

    print('before damage:', np.mean(batch_clean))
    batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))
    tmp = batch_damaged
    tmp[:, 10:20, 10:20] = 0
    print('after damage:', np.mean(batch_clean))

test_reconstruction(mnist)

预期:两个打印语句都应返回相同的平均值

实际:我得到两个打印语句的不同平均值

tensorflow mnist
1个回答
0
投票

在您的行batch_damaged = np.reshape(batch_clean, (batch_size, 28, 28))中,您复制batch_clean的引用而不是其值。您应该使用numpy.copy返回数组的副本。 batch_damaged = np.copy(np.reshape(batch_clean, (batch_size, 28, 28)))

© www.soinside.com 2019 - 2024. All rights reserved.