我正在使用 TensorFlow 2.0,想把 numpy 数组保存到 TFRecord 文件中,并通过 TFRecordDataset 重新加载。针对 2.0 的相关文档似乎完全缺失,而且 API 也并不直观。
我在这里以 notebook 的形式创建了一个最小可复现示例:https://gist.github.com/vicpara/3b4ea00553a1990620a2df77d8b6aa1f。
感谢您的任何建议。
一种方法是遍历数据集,并逐个保存示例中提供的(特征, 标签)对:
import numpy as np
import tensorflow as tf
import random
import os
# For TensorFlow < 2.0, uncomment the next line to turn on eager execution:
# tf.enable_eager_execution()
# Directory the .tfrecord files are written to (passed as `path_to_save`).
SAVE_PATH = "."
def _make_tf_example(feature, labels):
    """Pack one (feature, labels) pair into a ``tf.train.Example``.

    Each argument is serialized with ``tf.io.serialize_tensor`` and stored
    as a single-element bytes feature, under the keys ``'feature'`` and
    ``'labels'`` respectively.
    """
    # Serialize both tensors first, then wrap each blob as a bytes feature.
    serialized = {
        'feature': tf.io.serialize_tensor(feature),
        'labels': tf.io.serialize_tensor(labels),
    }
    features = tf.train.Features(
        feature={key: _bytes_feature(blob) for key, blob in serialized.items()}
    )
    return tf.train.Example(features=features)
def _bytes_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-item bytes list.

    If ``value`` has the same type as ``tf.constant(0)``, it is replaced by
    ``value.numpy()`` before being placed in the ``BytesList``; any other
    value is used as given.
    """
    constant_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, constant_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))
def _save_single_entry(feature, label, path_to_save, idx):
    """Write one (feature, label) pair to its own TFRecord file.

    The file is named ``my_tfrecord_<idx>.tfrecord`` (``idx`` zero-padded to
    five digits) and is created inside ``path_to_save``.

    Args:
        feature: Feature tensor, forwarded to ``_make_tf_example``.
        label: Label tensor, forwarded to ``_make_tf_example``.
        path_to_save: Directory in which the record file is created.
        idx: Integer index used to build the file name.
    """
    tfrecord_full_filename = os.path.join(path_to_save, f"my_tfrecord_{idx:05}.tfrecord")
    # Build and serialize the record BEFORE opening the writer: opening the
    # writer creates/truncates the target file, so a serialization failure
    # would otherwise leave an empty (or clobbered) .tfrecord file behind.
    serialized_example = _make_tf_example(feature, label).SerializeToString()
    with tf.io.TFRecordWriter(tfrecord_full_filename) as writer:
        writer.write(serialized_example)
然后(假设 `orig` 是您的 `tf.data.Dataset`):
# Write every (feature, label) element of `orig` to its own numbered file.
for idx, pair in enumerate(orig):
    feature, label = pair
    _save_single_entry(feature, label, SAVE_PATH, idx)