I have converted MNIST to TFRecord. I verified that the records are valid with the following code:

    example.ParseFromString(raw_record.numpy())
    print(example)
Output:

    features {
      feature {
        key: "depth"
        value {
          int64_list {
            value: 1
          }
        }
      }
    }
    ...
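For reference, a runnable version of that inspection loop might look like this (the /tmp path and the one-record dummy file are stand-ins so the snippet can be executed end to end):

```python
import tensorflow as tf

path = '/tmp/sanity_check.tfrecords'  # hypothetical scratch file

# Write a single dummy record so the inspection loop below has something to read.
with tf.io.TFRecordWriter(path) as writer:
    ex = tf.train.Example(features=tf.train.Features(feature={
        'depth': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    }))
    writer.write(ex.SerializeToString())

# Iterate the raw (unparsed) records and pretty-print the first one.
for raw_record in tf.data.TFRecordDataset(path).take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)
```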
I then read and parse the files with a TFRecordDataset and a map function:

    def parse(record):
        feature = {
            'height': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'width': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'depth': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'label': tf.io.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
            'image_raw': tf.io.FixedLenFeature([], tf.string)
        }
        features = tf.io.parse_single_example(record, feature)
        image = tf.io.decode_raw(features['image_raw'], tf.uint8)
        image.set_shape([28 * 28])
        image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
        image = tf.reshape(image, [28, 28, 1])
        label = tf.cast(features['label'], tf.int32)
        return image, label
    BATCH_SIZE = 64 * 10
    NUM_PARALLEL_BATCHES = 2
    LOCAL_PATTERN = './data/*'

    ds = tf.data.Dataset.list_files(LOCAL_PATTERN)
    ds = tf.data.TFRecordDataset(filenames=ds) \
        .map(parse) \
        .batch(batch_size=BATCH_SIZE)
This raises the error:

    File "/Users/Projects/mobileye/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
        inputs, attrs, num_outputs)
    tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0
        [[node IteratorGetNext (defined at cnn_mnist_pipe_gcp.py:94)]] [Op:__inference_train_function_885]
Edit: the code that writes the tfrecords:

    with tf.python_io.TFRecordWriter(filename) as writer:
        image_raw = images[index].tostring()
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'depth': _int64_feature(depth),
                    'label': _int64_feature(int(labels[index])),
                    'image_raw': _bytes_feature(image_raw)
                }))
        writer.write(example.SerializeToString())
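The helpers _int64_feature and _bytes_feature are not shown above; presumably they are the usual thin wrappers from the TensorFlow tutorials, along these lines:

```python
import tensorflow as tf

def _int64_feature(value):
    # Wrap a scalar int in an Int64List-backed Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    # Wrap raw bytes in a BytesList-backed Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Round-trip check: serialize an Example and parse it back.
example = tf.train.Example(features=tf.train.Features(feature={
    'depth': _int64_feature(1),
    'image_raw': _bytes_feature(b'\x00' * 4),
}))
parsed = tf.train.Example.FromString(example.SerializeToString())
```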
I ran into the same problem when training an efficientDet-object-detection model on tfrecord files exported from Roboflow. Here is how I solved it:

This error means the tfrecord files you generated are corrupted. Check the state of a tfrecord file with the following script:
    import tensorflow as tf

    def is_tfrecord_corrupted(tfrecord_file):
        try:
            for record in tf.data.TFRecordDataset(tfrecord_file):
                # Attempt to parse the record
                _ = tf.train.Example.FromString(record.numpy())
        except tf.errors.DataLossError as e:
            print(f"DataLossError encountered: {e}")
            return True
        except Exception as e:
            print(f"An error occurred: {e}")
            return True
        return False

    # Replace with your TFRecord file paths
    tfrecord_files = ['your_test_record_fname', 'your_train_record_fname']

    for tfrecord_file in tfrecord_files:
        if is_tfrecord_corrupted(tfrecord_file):
            print(f"The TFRecord file {tfrecord_file} is corrupted.")
        else:
            print(f"The TFRecord file {tfrecord_file} is fine.")
To fix the corrupted tfrecords, I exported the dataset in pascal-voc format instead, and then used the following script (hosted on GitHub) to generate new tfrecords from the pascal-voc formatted dataset. The script is here: https://github.com/arrafi-musabbir/license-plate-detection-recognition/blob/main/generate_tfrecord.py

label-map-pbtxt:

    label_path = "your label_map.pbtxt path"
    # modify according to your dataset class names
    labels = [{'name':'license', 'id':1}]

    with open(label_path, 'w') as f:
        for label in labels:
            f.write('item { \n')
            f.write('\tname:\'{}\'\n'.format(label['name']))
            f.write('\tid:{}\n'.format(label['id']))
            f.write('}\n')
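As a quick sanity check (stdlib only; the temp path is a stand-in for your own label_map.pbtxt path), you can read the file back and confirm it contains the name/id entries the Object Detection API expects:

```python
import os
import re
import tempfile

label_path = os.path.join(tempfile.gettempdir(), 'label_map.pbtxt')  # stand-in path

# Same writer as above, for a single-class dataset.
labels = [{'name': 'license', 'id': 1}]
with open(label_path, 'w') as f:
    for label in labels:
        f.write('item { \n')
        f.write('\tname:\'{}\'\n'.format(label['name']))
        f.write('\tid:{}\n'.format(label['id']))
        f.write('}\n')

# Read back and extract every (name, id) pair that was written.
text = open(label_path).read()
pairs = re.findall(r"name:'([^']+)'\n\tid:(\d+)", text)
print(pairs)
```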
Then generate the new tfrecords:

    python generate_tfrecord.py -x {train_dir_path} -l {labelmap_path} -o {new_train_record_path}
    python generate_tfrecord.py -x {valid_dir_path} -l {labelmap_path} -o {new_valid_record_path}
    python generate_tfrecord.py -x {test_dir_path} -l {labelmap_path} -o {new_test_record_path}
After that, run is_tfrecord_corrupted(tfrecord_file) again and you will see that the tfrecords are fine.