如何将所有tf.data.Dataset对象提取到特征和标签中并传递到ImageDataGenerator的flow()方法中?

问题描述 投票:0回答:1

我正在根据cifar10数据集从事一个小型项目。我已经从tfds.load(...)中加载了数据并正在练习图像增强技术。

由于我正在使用tf.data.Dataset对象(这是我的数据集,所以实时数据增强是相当难以实现的,因此,我想将所有功能都传递到tf.keras.preprocessing.image.ImageDataGenerator.flow(...)中以获得实时增强的功能。

但是此flow(...)方法接受与tf.data.Dataset对象无关的NumPy数组。

有人可以在这方面(或任何其他选择)指导我,我该如何进一步进行?

tf.image转换是实时的吗?如果没有,除了ImageDataGenerator.flow(...)以外,什么是最好的方法?

我的代码:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.preprocessing.image import ImageDataGenerator

splitting = tfds.Split.ALL.subsplit(weighted=(70, 20, 10))
dataset_cifar10, dataset_info = tfds.load(name='cifar10', 
                                          split=splitting, 
                                          as_supervised=True, 
                                          with_info=True)

train_dataset, valid_dataset, test_dataset = dataset_cifar10

BATCH_SIZE = 32

train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=1)

image_generator = ImageDataGenerator(rotation_range=45, 
                                     width_shift_range=0.15, 
                                     height_shift_range=0.15, 
                                     zoom_range=0.2, 
                                     horizontal_flip=True, 
                                     vertical_flip=True, 
                                     rescale=1./255)

train_dataset_generator = image_generator.flow(...)

...
python tensorflow-datasets tensorflow2.0 tf.keras data-augmentation
1个回答
0
投票

分割训练和测试数据集后,您可以立即遍历该数据集并追加一个列表,供ImageDataGenerator使用。完整的用例如下:

© www.soinside.com 2019 - 2024. All rights reserved.