是否可以为TensorFlows的py_function
指定嵌套的输出类型?
作为特定情况,我希望py_function
的返回类型为py_function
,其中各个元素的尺寸不一定相同。是否可以为((tf.float32, tf.float32), (tf.float32, tf.float32))
指定此方法?
就像了解为什么在我的情况下有用时一样,我有一个py_function
,其中包含文件路径列表。 tf.data.Dataset
采用这些文件路径之一,并从文件中生成负例和正例以及相应的标签,结果为py_function
(请注意,标签不一定是单个值,但它们也不是相同的形状作为输入数据)。可以将此((positive_data, positive_label), (negative_data, negative_label))
映射到数据集,并且(具有上述结构)将一级展平以生成具有py_function
结构化元素的训练数据集。虽然可以通过一种解决方法将数据和标签堆叠在(data, label)
中,然后再将其堆叠(或从py_function完全非结构化开始,然后再进行配对),但这会导致设置混乱且令人困惑。如果py_function
可以直接输出py_function
类型,则可以进行更整洁的设置。
((tf.float32, tf.float32), (tf.float32, tf.float32))
的输出类型不能为嵌套序列。但是,当将tf.py_function
与tf.py_function
API一起使用时,您需要创建一个包装函数(在下面的示例中为tf.data
),然后可以将输出嵌套在该函数中。
tf_foo
这也在import tensorflow as tf
# The python function.
def foo(x):
return x, x, x, x
# Wrap the python function to make it compatible with `tf.data.Dataset.map`.
def tf_foo(x):
a, b, c, d = tf.py_function(foo, [x], Tout=[tf.float32, tf.float32, tf.float32, tf.float32])
return (a, b), (c, d)
dset = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
dset.map(tf_foo)
# <MapDataset shapes: ((<unknown>, <unknown>), (<unknown>, <unknown>)),
# types: ((tf.float32, tf.float32), (tf.float32, tf.float32))>
中得到了证明。