用tf.data.Dataset()矢量化映射时出错。

问题描述 投票:0回答:1

我有一个图像数据集,我通过使用 tf.data.Dataset.list_files().

在我 .map() 函数,我读取和解码图像,就像下面这样。

def map_function(filepath):
    image = tf.io.read_file(filename=filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
    return image

如果我使用(下面这个函数)

 dataset = tf.data.Dataset.list_files(file_pattern=...)
 dataset = dataset.map(map_function)
 for image in dataset.as_numpy_iterator():
    #Correctly outputs the numpy array, no error is displayed/encountered
    print(image)

但是,如果我使用(下面这个抛出错误)。

  dataset = tf.data.Dataset.list_files(file_pattern=...)
  dataset = dataset.batch(32).map(map_function)
  for image in dataset.as_numpy_iterator():
    #Error is displayed 
      print(image)

ValueError: Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') Shape must be rank 0 but is rank 1 for 'ReadFile' (op: 'ReadFile') with input shapes: [?].

现在,根据这个。https:/www.tensorflow.orgguidedata_performance#vectorizing_mapping我的代码不应该失败,预处理步骤应该优化(批处理与一次性处理)。

我的代码哪里错了?

*** 如果我使用 map().batch() 行得通

python tensorflow tensorflow2.0 tensorflow-datasets tensorflow2.x
1个回答
0
投票

出现错误的原因是 map_function 期望使用非批次的元素,但在第二个例子中,你给了它批次的元素。

中的例子是 https:/www.tensorflow.orgguidedata_performance 是通过定义一个 increment 函数,它可以同时适用于分批和非分批元素,因为在一个分批元素上加1,比如[1, 2, 3],会得到[2, 3, 4]。

def increment(x):
    return x+1

要使用矢量化,你需要写一个 vectorized_map_function,它接收了一个非字节元素的向量,将map函数应用于向量中的每个元素,然后返回一个结果的向量。

不过在你的情况下,我不认为矢量化会有明显的影响,因为读取和解码文件的成本远远高于调用函数的开销。当map函数非常便宜,以至于函数调用的时间与在map函数中做实际工作的时间相当时,矢量化的影响最大。

© www.soinside.com 2019 - 2024. All rights reserved.