让我们考虑将数据集分成多个TFRecord文件:
1.tfrecord
,2.tfrecord
,我想从同一TFRecord文件生成大小为t
(例如3
)的序列,该序列由连续元素组成,我不希望序列具有属于不同TFRecord文件的元素。
例如,如果我们有两个包含如下数据的TFRecord文件:
1.tfrecord
:{0, 1, 2, ..., 7}
2.tfrecord
:{1000, 1001, 1002, ..., 1007}
没有任何改组,我想得到以下批次:
0, 1, 2
,1, 2, 3
,5, 6, 7
,1000, 1001, 1002
,1001, 1002, 1003
,1005, 1006, 1007
,0, 1, 2
,[我知道如何使用tf.data.Dataset.window
或tf.data.Dataset.batch
生成序列数据,但我不知道如何防止序列包含来自不同文件的元素。
我正在寻找可扩展的解决方案,即该解决方案应该可以处理数百个TFRecord文件。
下面是我的失败尝试(完全可重复的示例:):>
import tensorflow as tf # **************************** # Generate toy TF Record files def _create_example(i): example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))}) return tf.train.Example(features=example) def parse_fn(serialized_example): return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data'] num_tf_records = 2 records_per_file = 8 options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP) for i in range(num_tf_records): with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer: for j in range(records_per_file): example = _create_example(j + 1000 * i) writer.write(example.SerializeToString()) # **************************** # **************************** data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\ .map(lambda x: parse_fn(x)) data = data.window(3, 1, 1, True)\ .repeat(-1)\ .flat_map(lambda x: x.batch(3))\ .batch(16) data_it = data.make_initializable_iterator() next_element = data_it.get_next() with tf.Session() as sess: sess.run(data_it.initializer) print(sess.run(next_element))
输出:
[[ 0 1 2] # good
[ 1 2 3] # good
[ 2 3 4] # good
[ 3 4 5] # good
[ 4 5 6] # good
[ 5 6 7] # good
[ 6 7 1000] # bad – mix of elements from 0.tfrecord and 1.tfrecord
[ 7 1000 1001] # bad
[1000 1001 1002] # good
[1001 1002 1003] # good
[1002 1003 1004] # good
[1003 1004 1005] # good
[1004 1005 1006] # good
[1005 1006 1007] # good
[ 0 1 2] # good
[ 1 2 3]] # good
让我们考虑将数据集分成多个TFRecord文件:1.tfrecord,2.tfrecord等。我想生成大小为t(例如3)的序列,该序列由相同的连续元素组成。
flat_map
该功能即可制作windo
数据集:def make_dataset_from_filename(filename):
data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\
.map(lambda x: parse_fn(x))
data = data.window(3, 1, 1, True)\
.repeat(-1)\
.flat_map(lambda x: x.batch(3))\
.batch(16)
tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)