给出要使用 tf.data.Dataset.map 加载的图像路径对列表,但似乎只读取 2

问题描述 投票:0回答:1

我正在学习如何使用 Tensorflow 以及几年前在研讨会上提供的一段旧代码(即它应该已经过测试并且可以工作)。然而它不起作用,我的调查让我回到了加载数据并确保它已正确加载的步骤 1。

数据读取管道如下:

  1. 负载功能
@tf.function
def load(path_pair):
    image_path = path_pair[0]
    masks_path = path_pair[1]

    image_raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(
        image_raw, channels=1, dtype=tf.uint8
    )

    masks_raw = tf.io.read_file(masks_path)
    masks = tf.io.decode_image(
        masks_raw, channels=NUM_CONTOURS, dtype=tf.uint8
    )

    return image / 255, masks / 255```
2. The function used to create the dataset
```def create_datasets(dataset_type):
    path_pairs = get_path_pairs(dataset_type) # this just gives a list of 2 x 2 tuples containing the image/mask path to load
    dataset = tf.data.Dataset.from_tensor_slices(path_pairs)
    dataset = dataset.shuffle(
        len(path_pairs),
        reshuffle_each_iteration=True,
    )
    dataset = dataset.map(load)

    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset```



When I use the create_datasets function on a dataset that contains 818 data pairs and check the size of the loaded dataset using len(dataset)it tells me there is only 2 items loaded. 
python tensorflow tensorflow-datasets
1个回答
0
投票

问题是,您正在对数据集进行批处理,因此当您使用

len(dataset)
时,您会得到批次数,而不是数据集中的元素数。 例如,要获取它们,您可以迭代您的批次:

num_samples = 0
for batch in dataset:
    num_samples += len(batch[0])
print(num_samples)
© www.soinside.com 2019 - 2024. All rights reserved.