I am learning how to use TensorFlow, working from an old piece of code that was handed out at a workshop a few years ago (i.e. it should already have been tested and should work). However, it does not work, and my investigation has brought me back to step 1: loading the data and making sure it is loaded correctly.
The data reading pipeline is as follows:
1. The function used to load an image/mask pair:
```
import tensorflow as tf

@tf.function
def load(path_pair):
    image_path = path_pair[0]
    masks_path = path_pair[1]
    # Read and decode the greyscale input image.
    image_raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(
        image_raw, channels=1, dtype=tf.uint8
    )
    # Read and decode the corresponding mask image.
    masks_raw = tf.io.read_file(masks_path)
    masks = tf.io.decode_image(
        masks_raw, channels=NUM_CONTOURS, dtype=tf.uint8
    )
    # Scale both to the [0, 1] range.
    return image / 255, masks / 255
```
2. The function used to create the dataset
```
def create_datasets(dataset_type):
    # this just gives a list of (image path, mask path) tuples to load
    path_pairs = get_path_pairs(dataset_type)
    dataset = tf.data.Dataset.from_tensor_slices(path_pairs)
    dataset = dataset.shuffle(
        len(path_pairs),
        reshuffle_each_iteration=True,
    )
    dataset = dataset.map(load)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
```
When I use the create_datasets function on a dataset that contains 818 data pairs and check the size of the loaded dataset using len(dataset), it tells me there are only 2 items loaded.
The problem is that you are batching the dataset, so when you call len(dataset) you get the number of batches, not the number of elements in the dataset. After .batch(BATCH_SIZE), len(dataset) equals ceil(818 / BATCH_SIZE), which comes out to 2 for a sufficiently large batch size.
To count the individual samples instead, you can iterate over the batches, for example:
```
num_samples = 0
for batch in dataset:
    # batch[0] is the batch of images, so its length is the number of samples in this batch
    num_samples += len(batch[0])
print(num_samples)
```
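If you want to see the effect without looping over your real data, here is a minimal sketch (assuming TensorFlow 2.x in eager mode, and a purely illustrative batch size of 512, which is not taken from your code):
```
import tensorflow as tf

# 818 dummy elements stand in for the 818 image/mask pairs.
ds = tf.data.Dataset.range(818)
print(len(ds), ds.cardinality().numpy())            # 818 818

# After batching, the dataset yields batches, so its length is the batch count.
batched = ds.batch(512)
print(len(batched), batched.cardinality().numpy())  # 2 2  (batches of 512 and 306)
```
So the easiest check of the sample count is len(dataset) (or dataset.cardinality()) taken before the .batch() call.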