我正在使用自己的load_input / real_image函数对Pix2Pix网络进行编码,目前正在使用tf.data.Dataset创建数据集。问题是我的数据集形状错误:
我尝试应用一些tf.data.experimemtal函数,但它们都无法按我的要求工作。
raw_data = [load_image_train(category)
for category in SELECTED_CATEGORIES
for _ in range(min(MAX_SAMPLES_PER_CATEGORY, category[1]))]
train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(1)
我有:
我想要:
您可以通过两种方式完成。
选项1(首选)
raw_data1, raw_data2 = tf.unstack(raw_data, axis=1)
train_dataset = tf.data.Dataset.from_tensor_slices((raw_data1, raw_data2))
选项2
def map_fn(data):
return tf.unstack(data, axis=0)
train_dataset = tf.data.Dataset.from_tensor_slices(raw_data)
train_dataset = train_dataset.map(map_fn)