我是图像处理的初学者。我有两个二进制类作为子目录,总共有 496 个图像,我对最后一批剩余的 13 个图像有问题。因此,最后一批中的张量不是 tf.dataset 张量 (32, 300, 300, 3),而是 (16, 300, 300, 3)。事实上,我注意到了:
为什么洗牌后只剩下一批了?
image_size = (300, 300)
batch_size = 32
train_dataset = image_dataset_from_directory(
dataset_dir,
image_size=(image_size[0], image_size[1]),
batch_size=batch_size,
label_mode="binary",
validation_split=0.2,
subset="training",
seed=123,
)
train_dataset = train_dataset.shuffle(1000)
train_dataset = train_dataset.batch(
batch_size=batch_size, drop_remainder=True
).prefetch(buffer_size=AUTOTUNE)
print(train_dataset.cardinality().numpy())
当您创建
train_dataset
时,它已经批处理到 size 32
中,因为您在 batch_size=32
中指定了
image_dataset_from_directory()
当您在已经批处理的数据集上再次调用
时,它会尝试 对批次进行批处理,这有效地创建了批次的批次
.batch()
代码修正
train_dataset = image_dataset_from_directory(
dataset_dir,
image_size=image_size,
batch_size=batch_size,
label_mode="binary",
validation_split=0.2,
subset="training",
seed=123,
)
train_dataset = train_dataset.shuffle(1000).prefetch(buffer_size=AUTOTUNE)