在 Tensorflow 中批处理时如何调整数据集形状变化

问题描述 投票:0回答:1

我有一个这样的串联数据集:

<ConcatenateDataset shapes: ((None, 1024), (1,)), types: (tf.float16, tf.float32)>

然后我需要洗牌和批处理它,所以我做了:

dataset.shuffle(buffer_size=1024).batch(hyperparams["batch_size"]).prefetch(tf.data.experimental.AUTOTUNE)

它返回给我一个像这样的新数据集:

<PrefetchDataset shapes: ((None, None, 1024), (None, 1)), types: (tf.float16, tf.float32)>

但我认为我希望形状保持不变 ((None, None, 1024), (1,)),因为我使用第二维作为预测值,这意味着我需要保持输出奇异。那么有没有办法在batching之后调整形状呢?

python tensorflow deep-learning artificial-intelligence tensorflow2.0
1个回答
0
投票

是的,您可以在批处理后调整数据集的形状,方法是使用 tf.data.Dataset.map() 方法对每个批次应用一个函数。

要保留第二个维度的形状并保持其单一,您可以使用 tf.squeeze() 函数删除任何大小为 1 的维度。这是一个示例函数,您可以使用它来调整批次的形状:

def adjust_batch_shape(x, y):
    x = tf.squeeze(x, axis=1) # remove the second dimension if it has size 1
    y = tf.squeeze(y, axis=1) # remove the second dimension if it has size 1
    return x, y

然后您可以将此函数与 tf.data.Dataset.map() 方法结合使用来调整批次的形状:

dataset = dataset.shuffle(buffer_size=1024).batch(hyperparams["batch_size"]).map(adjust_batch_shape).prefetch(tf.data.experimental.AUTOTUNE)

在这个例子中,adjust_batch_shape() 函数有两个参数 x 和 y,分别代表批次的特征和标签。它使用 tf.squeeze() 从 x 和 y 中删除任何大小为 1 的维度,然后返回调整后的 x 和 y。

请注意,此函数假定 y 的第二个维度的大小为 1。如果不是这种情况,您可能需要相应地修改函数以保留 y 的形状。

© www.soinside.com 2019 - 2024. All rights reserved.