如何在tensorflow数据集中加载numpy数组

问题描述 投票:2回答:1

我正在尝试从numpy数组开始在tensorflow 1.14中创建一个Dataset对象(我对此特定项目无法更改一些旧代码),但是每次尝试时,我都会将所有内容复制到图形上并为此我创建事件日志文件的原因很大(在这种情况下为719 MB)。

最初,我尝试使用此函数“ tf.data.Dataset.from_tensor_slices()”,但是它没有用,然后我读到它是一个常见问题,有人建议我尝试使用生成器,因此我尝试了以下方法代码,但是我又得到了一个巨大的事件文件(再次为719 MB)

def fetch_batch(x, y, batch):
    i = 0
    while i < batch:
        yield (x[i,:,:,:], y[i])
        i +=1

train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train  
images = images/255

training_dataset = tf.data.Dataset.from_generator(fetch_batch, 
    args=[images, np.int32(labels), batch_size], output_types=(tf.float32, tf.int32), 
    output_shapes=(tf.TensorShape(features_shape), tf.TensorShape(labels_shape)))

file_writer = tf.summary.FileWriter("/content", graph=tf.get_default_graph())

我知道在这种情况下,我可以使用tensorflow_datasets API,这会更容易,但这是一个更笼统的问题,它涉及到如何一般地创建数据集,而不仅仅是使用mnist。你能告诉我我做错了什么吗?谢谢

python tensorflow tensorflow-datasets
1个回答
0
投票

我想是因为您在args中使用了from_generator。这无疑会将提供的args放在图表中。

© www.soinside.com 2019 - 2024. All rights reserved.