dataset.repeat() 如何与 model.fit() 一起使用?

问题描述 投票:0回答:1

我有一个包含 1061 个样本的数据集,我正在使用以下内容:

epochs = 50 

batch_size=20

dataset = dataset.repeat(epochs).batch(batch_size)

现在如果我使用:

model.fit(dataset , epochs = epochs , batch_size = batch_size)

在一个训练周期中输入模型的样本数量是多少?是1061吗?或者 1061 * 50 = 53050 个样本?

我搜索了 dataset.repeat() 是如何工作的,但我仍然对它如何与 model.fit() 一起工作具体感到困惑。

python tensorflow keras tf.dataset
1个回答
0
投票

你理解正确!当您调用

dataset.repeat(epochs)
时,您将创建一个重复原始样本 50 次(纪元 = 50)的数据集。含义该数据集将包含 1061 * 50 = 53050 个样本。

每个 epoch 将处理 1061 个样本。

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

# Create a sample dataset
num_samples = 1061
num_features = 10
X = np.random.rand(num_samples, num_features).astype(np.float32)
y = np.random.randint(0, 2, size=(num_samples,)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((X, y))
epochs = 50
batch_size = 20

# Repeat
dataset = dataset.repeat(epochs).batch(batch_size)

model = models.Sequential([
    layers.Dense(32, activation='relu', input_shape=(num_features,)),
    layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Counting samples
class SampleCountCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        samples_processed = (epoch + 1) * num_samples
        print(f"Epoch {epoch + 1}: {samples_processed} samples processed")

# Model fitting
model.fit(dataset, epochs=epochs, steps_per_epoch=num_samples // batch_size, callbacks=[SampleCountCallback()])

enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.