我正在从Udemy课程中学习如何在tensorflow 2.0和Keras中从头开始创建MNIST模型。
所以,我得到了mnist数据集,如下所示
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
而且一切都很好,即使我测试模型的准确性达到了97%,我也很高兴。
当我尝试做不同于课程的事情时,问题就开始了。我尝试使用matplotlib plt.imshow()
从mnist_dataset打印一些示例,但我完全失败了。然后,我开始了一些研究,找到了解决方案,我需要获取像这样的数据集:
mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']
其中mnistt
是我可以使用matplotlib操作和打印的数据集。
所以我的问题如下:在哪里可以获得有关您可以获取的tfds.load()类型的信息,以及如何根据需要正确地对其进行操作? (并且可以像我这样在张量流中从初学者开始扩展)。
tfds.load
方法的主调用包含您需要的一切:
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
name="mnist"
->您正在指定要使用的构建器(错误)with_info=True
->您要让tfds.load
返回包含关于返回的数据集的[[您需要知道的所有信息]的info
对象as_supervised=True
->您要让tfds.load
仅获得监督学习任务所需的数据集元素(图像和标签对)。mnist_dataset
获取数据(与matplotlib
一起使用失败,因为您可以从中看到]print(mnist_info) #run me!
数据集包含2个不同的分割:train
和test
。
tfds.core.DatasetInfo( name='mnist', version=1.0.0, description='The MNIST database of handwritten digits.', urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'], features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), total_num_examples=70000, splits={ 'test': 10000, 'train': 60000, }, supervised_keys=('image', 'label'), citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist}, volume={2}, year={2010} }""", redistribution_info=, )
因此,tfds.load
返回的对象是字典:
{ "train": <train dataset>, "test": <test dataset> }
事实上,在示例的下一行中,您以这种方式提取“ train”和“ test”数据集:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
从mnist_info
对象中,您可以获得处理数据集所需的每条信息:分割数,数据类型(例如,“图像”是具有dtype tf.uint8的28x28x1图像,等等...] >