我想使用 `tf.data.Dataset.window` 来创建滑动窗口数据集。在下面的代码中,如何在 `flat_map` 里对 2 个输入(以及标签)分别做批处理(batch)?我在网上找到的示例都只有 1 个输入。
import tensorflow as tf
def make_window_dataset(ds, window_size=3, shift=1, stride=1):
    """Turn ``ds`` into a dataset of fixed-size sliding windows.

    ``ds`` must yield 2-tuples ``(features, label)``, where ``features`` may
    itself be a (nested) tuple of tensors, e.g. ``((x1, x2), v)``.

    Each element of the result has the same nested structure. Every tensor
    component gains a leading axis of length ``window_size``. Windows shorter
    than ``window_size`` (at the end of ``ds``) are dropped.

    Args:
        ds: ``tf.data.Dataset`` of ``(features, label)`` elements.
        window_size: number of consecutive elements per window.
        shift: number of input elements to advance between windows.
        stride: step between the elements taken into one window.

    Returns:
        ``tf.data.Dataset`` of batched windows with the input's nesting.
    """
    windows = ds.window(window_size, shift=shift, stride=stride,
                        drop_remainder=True)

    def _batch_window(window):
        # Each window component is a small Dataset of `window_size` items;
        # batching it collapses the window into one tensor.
        return window.batch(window_size, drop_remainder=True)

    def sub_to_batch(features, label):
        # flat_map unpacks the top-level tuple. For an element ((x1, x2), v),
        # `features` is the tuple (window_ds_x1, window_ds_x2) and `label` is
        # window_ds_v. Calling .batch on the tuple itself is the original bug.
        # Batch every leaf dataset, then zip them back into the same structure.
        batched = tf.nest.map_structure(_batch_window, (features, label))
        return tf.data.Dataset.zip(batched)

    return windows.flat_map(sub_to_batch)
# Two feature inputs; each element of either input is a length-2 vector.
ds = tf.data.Dataset.from_tensor_slices((
    [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
    [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
))
# One label ("validation") dataset, aligned element-wise with the features.
v = tf.data.Dataset.from_tensor_slices([1, 3, 5, 7, 1, 3, 5, 7])
ds = make_window_dataset(tf.data.Dataset.zip((ds, v))).batch(2).repeat(2)
for example in ds.take(10):
    # Examples are expected to be nested tuples like ((x1, x2), v). A tuple
    # has no .numpy(), so convert each tensor leaf individually.
    print('---', tf.nest.map_structure(lambda t: t.numpy(), example))
# model.fit(ds, ...)  # placeholder: `model` is not defined in this snippet.
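答案:在 `sub_to_batch` 中,应对每个输入(以及标签)的窗口数据集分别调用 `batch(window_size, drop_remainder=True)`。然后用 `tf.data.Dataset.zip`,按与原始元素相同的(嵌套)元组结构把它们重新组合起来。这是因为 `flat_map` 传入的 `sub` 是一个由多个窗口数据集组成的元组,而元组本身没有 `batch` 方法。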
import tensorflow as tf
# Switch TensorFlow's global behavior to 2.x semantics (e.g. eager execution).
# This only matters on a TF 1.x install. Presumably it is here so that the
# `for example in ds.take(...)` loop below can iterate the dataset eagerly;
# confirm which TF version this runs on.
tf.compat.v1.enable_v2_behavior()
def make_window_dataset(ds, window_size=3, shift=1, stride=1):
    """Turn ``ds`` into a dataset of fixed-size sliding windows.

    ``ds`` must yield ``(features, label)`` pairs, where ``features`` is a
    tuple of tensors, e.g. ``((x1, x2), y)``. Each output element is
    ``((x1_win, x2_win), (y_win, y_win))``. Every ``*_win`` tensor has a
    leading axis of length ``window_size``. Windows shorter than
    ``window_size`` are dropped.

    Args:
        ds: ``tf.data.Dataset`` of ``(features, label)`` elements.
        window_size: number of consecutive elements per window.
        shift: number of input elements to advance between windows.
        stride: step between the elements taken into one window.

    Returns:
        ``tf.data.Dataset`` of batched windows.
    """
    windows = ds.window(window_size, shift=shift, stride=stride)

    def _batch_window(window):
        # Collapse one window (a Dataset of `window_size` items) into a tensor;
        # incomplete trailing windows yield nothing and so are dropped.
        return window.batch(window_size, drop_remainder=True)

    def sub_to_batch(sub, sub2):
        # flat_map unpacks ((x1, x2), y): `sub` is the tuple of per-feature
        # window datasets, `sub2` is the label window dataset.
        # Every feature must be batched. Zipping an unbatched feature
        # truncates it to the first element of the window, because zip stops
        # at its shortest input.
        features = tf.nest.map_structure(_batch_window, sub)
        labels = _batch_window(sub2)
        # NOTE(review): the label window is emitted twice, presumably to feed
        # a model with two outputs; confirm against the model definition.
        return tf.data.Dataset.zip((features, (labels, labels)))

    return windows.flat_map(sub_to_batch)
# Elements are ((x1, x2), y): two length-2 feature vectors and a length-1 label.
ds = tf.data.Dataset.from_tensor_slices((
    (
        [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
        [[2, 3], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
    ),
    [[1], [3], [5], [7], [1], [3], [5], [7]],
))
# For training, append e.g. `.batch(2).repeat(2)` to the windowed dataset.
ds = make_window_dataset(ds)
print('---sssss')
for example in ds.take(10):
    print('---', example)
# model.fit(ds, ...)  # placeholder: `model` is not defined in this snippet.