为什么batch()只返回一批?

问题描述 投票:0回答:1

我是图像处理的初学者。我有两个二进制类作为子目录,总共有 496 个图像,我对最后一批剩余的 13 个图像有问题。因此,最后一批中的张量不是 tf.dataset 张量 (32, 300, 300, 3),而是 (16, 300, 300, 3)。事实上,我注意到了:

  1. 洗牌后包含13批
  2. 分批后仅产生 1 批次(我认为是剩余批次)
  3. 当drop_remainder时数据为空

为什么洗牌后只剩下一批了?

image_size = (300, 300)
batch_size = 32

train_dataset = image_dataset_from_directory(
    dataset_dir,
    image_size=(image_size[0], image_size[1]),
    batch_size=batch_size,
    label_mode="binary",
    validation_split=0.2,
    subset="training",
    seed=123,
)

train_dataset = train_dataset.shuffle(1000)
train_dataset = train_dataset.batch(
    batch_size=batch_size, drop_remainder=True
).prefetch(buffer_size=AUTOTUNE)

print(train_dataset.cardinality().numpy())
python image tensorflow tensor
1个回答
0
投票

当您创建

train_dataset
时,它已经批处理到
size 32
中,因为您在
batch_size=32
 中指定了 
image_dataset_from_directory()

当您在已经批处理的数据集上再次调用

.batch()
时,它会尝试 对批次进行批处理,这有效地创建了批次的批次

代码修正

train_dataset = image_dataset_from_directory(
    dataset_dir,
    image_size=image_size,
    batch_size=batch_size,
    label_mode="binary",
    validation_split=0.2,
    subset="training",
    seed=123,
)


train_dataset = train_dataset.shuffle(1000).prefetch(buffer_size=AUTOTUNE)
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.