如何重塑张量流数据集结构?

问题描述 投票:1回答:1

我正在使用自己的load_input / real_image函数对Pix2Pix网络进行编码,目前正在使用tf.data.Dataset创建数据集。问题是我的数据集形状错误:

我尝试应用一些tf.data.experimemtal函数,但它们都无法按我的要求工作。

    raw_data = [load_image_train(category)
                for category in SELECTED_CATEGORIES
                for _ in range(min(MAX_SAMPLES_PER_CATEGORY, category[1]))]
    train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
    train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
    train_dataset = train_dataset.batch(1)

我有:

我想要:

python tensorflow tensorflow-datasets
1个回答
0
投票

您可以通过两种方式完成。

选项1(首选)

raw_data1, raw_data2 = tf.unstack(raw_data, axis=1)
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data1, raw_data2))

选项2

def map_fn(data):
    return tf.unstack(data, axis=0)

train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.map(map_fn)
© www.soinside.com 2019 - 2024. All rights reserved.