在不丢失基数信息的情况下对 TensorFlow 数据集进行窗口处理?

问题描述 投票:0回答:1

tf.data.Dataset.window
返回一个新的数据集,其元素是数据集,这些嵌套数据集的元素是所需大小的窗口。如果您有一个数据集(例如,
Dataset.range(10)
,并且想要像
[0 1 2] [1 2 3] ... [7 8 9]
这样的窗口数据集),可以使用
window
flat_map
来实现这一点:

>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]

但是,

flat_map
会导致数据集丢失基数信息:

>>> d.cardinality.numpy()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>

(-2 是 UNKNOWN_CARDINALITY;参见 Tensorflow 2.0:用于展平数据集的数据集的 flat_map() 返回基数 -2

我想创建此类窗口的数据集,同时保留基数信息。使用未知基数的数据集的一个小烦恼是 Keras 训练进度条需要先运行一个 epoch,然后才能生成 ETA。我尝试了

.take(n_windows)
我自己计算
n_windows
,但仍然返回了带有
UNKNOWN_CARDINALITY
的数据集。

是否有某种方法可以在不丢失基数信息的情况下对数据集进行窗口化?

python tensorflow keras tensorflow-datasets
1个回答
2
投票

主要问题是基数是静态计算的。因此无法计算

flat_map
运算的基数。您可以参考这个问题

解决方案,如您所知

flat_map
输入和输出的关系,是使用
tf.data.experimental.assert_cardinality
自行设置基数。

这是如何设置窗口基数的示例:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
print("Original cardinality -> ", ds.cardinality().numpy())
# Output:
# Original cardinality -> 10

ds = ds.window(3, shift=1, drop_remainder=True)
# cardinality at this point is still known.
# as drop_remainder is true, window cardinality will be <= original cardinality
window_cardinality = ds.cardinality()
print("window cardinality ->",window_cardinality.numpy())
# Output:
# window cardinality -> 8

ds = ds.flat_map(lambda x: x.batch(3))
# after flat_map the inferred cardinality is lost.
print("flat cardinality ->",ds.cardinality().numpy())
# Output:
# flat cardinality -> -2

# as we know the flat_map relation is 1:1 we can set the cardinality back to the original value.
ds = ds.apply(tf.data.experimental.assert_cardinality(window_cardinality))
print("dataset cardinality ->",ds.cardinality().numpy())
print("length of dataset ->", len(list(ds)))
# Output: 
# dataset cardinality -> 8
# length of dataset -> 8

for idx, x in ds.enumerate():
    print(f"{idx} -> {x}")
# Output:
# 0 -> [0 1 2]
# 1 -> [1 2 3]
# 2 -> [2 3 4]
# 3 -> [3 4 5]
# 4 -> [4 5 6]
# 5 -> [5 6 7]
# 6 -> [6 7 8]
# 7 -> [7 8 9]
© www.soinside.com 2019 - 2024. All rights reserved.