我有一个图像数据集,我通过使用 tf.data.Dataset.list_files()
.
在我 .map()
函数,我读取和解码图像,就像下面这样。
def map_function(filepath):
image = tf.io.read_file(filename=filepath)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
return image
如果我使用(下面这个函数)
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.map(map_function)
for image in dataset.as_numpy_iterator():
#Correctly outputs the numpy array, no error is displayed/encountered
print(image)
但是,如果我使用(下面这个抛出错误)。
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.batch(32).map(map_function)
for image in dataset.as_numpy_iterator():
#Error is displayed
print(image)
ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [?].
现在,根据这个。https:/www.tensorflow.orgguidedata_performance#vectorizing_mapping我的代码不应该失败,预处理步骤应该优化(批处理与一次性处理)。
我的代码哪里错了?
*** 如果我使用 map().batch()
行得通
出现错误的原因是 map_function
期望使用非批次的元素,但在第二个例子中,你给了它批次的元素。
中的例子是 https:/www.tensorflow.orgguidedata_performance 是通过定义一个 increment
函数,它可以同时适用于分批和非分批元素,因为在一个分批元素上加1,比如[1, 2, 3],会得到[2, 3, 4]。
def increment(x):
return x+1
要使用矢量化,你需要写一个 vectorized_map_function
,它接收了一个非字节元素的向量,将map函数应用于向量中的每个元素,然后返回一个结果的向量。
不过在你的情况下,我不认为矢量化会有明显的影响,因为读取和解码文件的成本远远高于调用函数的开销。当map函数非常便宜,以至于函数调用的时间与在map函数中做实际工作的时间相当时,矢量化的影响最大。