我需要在一个需要比我的GPU更大内存的数据集上训练一个模型。
以下是我的步骤。
BATCH_SIZE=32
builder = tfds.builder('mnist')
builder.download_and_prepare()
datasets = builder.as_dataset(batch_size=BATCH_SIZE)
for record in raw_train_ds.take(1):
train_images, train_labels = record['image'], record['label']
print(train_images.shape)
train_images = train_images.numpy().astype(np.float32) / 255.0
train_labels = tf.keras.utils.to_categorical(train_labels)
history = model.fit(train_images,train_labels, epochs=NUM_EPOCHS, validation_split=0.2)
但在第2步,我准备了第一批的数据,错过了其他批次的数据,因为model.fit不在循环范围内(据我所知,只对第一批的数据有效)。另一方面,我不能去掉take(1),把model.fit方法移到循环下。因为是的,在这种情况下,我将处理所有的批次,但同时model.fill将在每次迭代结束时被调用,在这种情况下,它也将无法正常工作。
所以,我应该如何改变我的代码,以能够适当地与一个大的数据集使用model.fit工作吗? 你可以指出文章,任何文件,或只是建议如何处理它?
更新在我下面的帖子中(方法1),我描述了如何解决这个问题的一种方法--有没有其他更好的方法,或者说这只是一种解决方法?
你可以把整个数据集传给 fit
用于培训。如你所见,在 文件,第一个参数的可能值之一是。
- A
tf.data
数据集。应该返回一个元组,其中包括(inputs, targets)
或(inputs, targets, sample_weights)
.
所以你只需要将你的数据集转换为这种格式(一个带有输入和目标的元组),然后将它传递给 fit
:
BATCH_SIZE=32
builder = tfds.builder('mnist')
builder.download_and_prepare()
datasets = builder.as_dataset(batch_size=BATCH_SIZE)
raw_train_ds = datasets['train']
train_dataset_fit = raw_train_ds.map(
lambda x: (tf.cast.dtypes(x['image'], tf.float32) / 255.0, x['label']))
history = model.fit(train_dataset_fit, epochs=NUM_EPOCHS)
这方面的一个问题是,它并不支持 validation_split
参数,但是,如 本指南, tfds
已经为你提供了数据分割的功能。所以你只需要得到测试拆分的数据集,按照上面的方法进行转换,然后将它作为 validation_data
到 fit
.
谢谢 @jdehesa 我修改了我的代码。
raw_train_ds, raw_validation_ds = builder.as_dataset(split=["train[:90%]", "train[10%:]"], batch_size=BATCH_SIZE)
def prepare_data(x):
train_images, train_labels = x['image'], x['label']
# TODO: resize image
train_images = tf.cast(train_images,tf.float32)/ 255.0
# train_labels = tf.keras.utils.to_categorical(train_labels,num_classes=NUM_CLASSES)
train_labels = tf.one_hot(train_labels,NUM_CLASSES)
return (train_images, train_labels)
train_dataset_fit = raw_train_ds.map(prepare_data)
train_dataset_fit = raw_train_ds.map(prepare_data)
history = model.fit(train_dataset_fit, epochs=NUM_EPOCHS)