我有一个这样的串联数据集:
<ConcatenateDataset shapes: ((None, 1024), (1,)), types: (tf.float16, tf.float32)>
然后我需要对其进行洗牌和批处理,所以我做了:
dataset.shuffle(buffer_size=1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)
它返回给我一个像这样的新数据集:
<PrefetchDataset shapes: ((None, None, 1024), (None, 1)), types: (tf.float16, tf.float32)>
但我期待
((None, None, 1024), (1,))
中的形状,因为我使用第二个维度作为预测值,预测值不需要留占位符,这意味着我需要保持输出单数。那么有没有办法在batching之后调整形状呢?