在BIG数据集上训练模型的最佳实践是什么?

问题描述 投票:1回答:1

我需要在一个需要比我的GPU更大内存的数据集上训练一个模型。

以下是我的步骤。

  1. 首先,我用batch_size加载数据集。
BATCH_SIZE=32

builder = tfds.builder('mnist')
builder.download_and_prepare()
datasets  = builder.as_dataset(batch_size=BATCH_SIZE)
  1. 第二步我准备数据
for record in raw_train_ds.take(1):
    train_images, train_labels = record['image'],  record['label']
    print(train_images.shape)
    train_images  = train_images.numpy().astype(np.float32) / 255.0
    train_labels = tf.keras.utils.to_categorical(train_labels)
  1. 然后我把数据输入到模型中
history = model.fit(train_images,train_labels, epochs=NUM_EPOCHS, validation_split=0.2)

但在第2步,我准备了第一批的数据,错过了其他批次的数据,因为model.fit不在循环范围内(据我所知,只对第一批的数据有效)。另一方面,我不能去掉take(1),把model.fit方法移到循环下。因为是的,在这种情况下,我将处理所有的批次,但同时model.fill将在每次迭代结束时被调用,在这种情况下,它也将无法正常工作。

所以,我应该如何改变我的代码,以能够适当地与一个大的数据集使用model.fit工作吗? 你可以指出文章,任何文件,或只是建议如何处理它?

更新在我下面的帖子中(方法1),我描述了如何解决这个问题的一种方法--有没有其他更好的方法,或者说这只是一种解决方法?

python tensorflow deep-learning tensorflow2.0 tensorflow2.x
1个回答
2
投票

你可以把整个数据集传给 fit 用于培训。如你所见,在 文件,第一个参数的可能值之一是。

  • A tf.data 数据集。应该返回一个元组,其中包括 (inputs, targets)(inputs, targets, sample_weights).

所以你只需要将你的数据集转换为这种格式(一个带有输入和目标的元组),然后将它传递给 fit:

BATCH_SIZE=32

builder = tfds.builder('mnist')
builder.download_and_prepare()
datasets = builder.as_dataset(batch_size=BATCH_SIZE)
raw_train_ds = datasets['train']
train_dataset_fit = raw_train_ds.map(
    lambda x: (tf.cast.dtypes(x['image'], tf.float32) / 255.0, x['label']))
history = model.fit(train_dataset_fit, epochs=NUM_EPOCHS)

这方面的一个问题是,它并不支持 validation_split 参数,但是,如 本指南, tfds 已经为你提供了数据分割的功能。所以你只需要得到测试拆分的数据集,按照上面的方法进行转换,然后将它作为 validation_datafit.


1
投票

方法1

谢谢 @jdehesa 我修改了我的代码。

  1. load dataset - 实际上,在数据集迭代器第一次调用'next'之前,它不会将数据加载到内存中,即使如此,我认为迭代器也会加载一部分数据(批处理),其大小等于BATCH_SIZE。
raw_train_ds, raw_validation_ds = builder.as_dataset(split=["train[:90%]", "train[10%:]"], batch_size=BATCH_SIZE)
  1. 将所有所需的转化为一种方法
def prepare_data(x):
    train_images, train_labels = x['image'],  x['label']
# TODO: resize image
    train_images = tf.cast(train_images,tf.float32)/ 255.0 
    # train_labels = tf.keras.utils.to_categorical(train_labels,num_classes=NUM_CLASSES) 
    train_labels = tf.one_hot(train_labels,NUM_CLASSES) 
    return (train_images, train_labels)
  1. 将这些变换应用于批处理(数据集)中的每个元素,使用的方法是 td.dataset.map
train_dataset_fit = raw_train_ds.map(prepare_data)
  1. 然后将这个数据集送入model.fit中--据我所知,model.fit会遍历数据集中的所有批次。
train_dataset_fit = raw_train_ds.map(prepare_data)
history = model.fit(train_dataset_fit, epochs=NUM_EPOCHS)
© www.soinside.com 2019 - 2024. All rights reserved.