如何在Tensorflow 2.x中正确操作tfds.load()数据集?

问题描述 投票:0回答:1

我正在从Udemy课程中学习如何在tensorflow 2.0和Keras中从头开始创建MNIST模型。

所以,我得到了mnist数据集,如下所示

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

而且一切都很好,即使我测试模型的准确性达到了97%,我也很高兴。

当我尝试做不同于课程的事情时,问题就开始了。我尝试使用matplotlib plt.imshow()从mnist_dataset打印一些示例,但我完全失败了。然后,我开始了一些研究,找到了解决方案,我需要获取像这样的数据集:

mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']

其中mnistt是我可以使用matplotlib操作和打印的数据集。

所以我的问题如下:在哪里可以获得有关您可以获取的tfds.load()类型的信息,以及如何根据需要正确地对其进行操作? (并且可以像我这样在张量流中从初学者开始扩展)。

python tensorflow tensorflow-datasets tensorflow2.0
1个回答
0
投票

tfds.load方法的主调用包含您需要的一切:

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
  • name="mnist"->您正在指定要使用的构建器(错误)
  • [with_info=True->您要让tfds.load返回包含关于返回的数据集的[[您需要知道的所有信息]的info对象
  • [as_supervised=True->您要让tfds.load仅获得监督学习任务所需的数据集元素(图像和标签对)。
  • 您第一次尝试使用mnist_dataset获取数据(与matplotlib一起使用失败,因为您可以从中看到]

    print(mnist_info) #run me!

    数据集包含2个不同的分割:traintest

    tfds.core.DatasetInfo( name='mnist', version=1.0.0, description='The MNIST database of handwritten digits.', urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'], features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), total_num_examples=70000, splits={ 'test': 10000, 'train': 60000, }, supervised_keys=('image', 'label'), citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist}, volume={2}, year={2010} }""", redistribution_info=, )

    因此,tfds.load返回的对象是

    字典:

    { "train": <train dataset>, "test": <test dataset> }
    事实上,在示例的下一行中,您以这种方式提取“ train”和“ test”数据集:

    mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

    mnist_info对象中,您可以获得处理数据集所需的每条信息:分割数,数据类型(例如,“图像”是具有dtype tf.uint8的28x28x1图像,等等...] >
  • © www.soinside.com 2019 - 2024. All rights reserved.