我正在测试 tf.data,这是目前官方推荐的批量加载数据的方式。不过我加载的是自定义数据集,因此需要 str 类型的文件名;而用 tf.data.Dataset.from_tensor_slices 创建数据集后,传给 map 函数的文件名是 Tensor 对象,而不是 str。
# Question's original loader, kept exactly as posted because it reproduces the
# reported TypeError. It reads a NIfTI file, crops it to the bounding box of its
# non-zero voxels, keeps at most 100 voxels per axis and returns a float32
# tensor of shape (100, 100, 100, 1) together with the unchanged label.
def load_image(file, label):
# `file` arrives here as a Tensor, but the loader needs a str/bytes/os.PathLike
# path, which is what triggers the error.
nifti = np.asarray(nibabel.load(file).get_fdata()) # <- here is the problem
# Indices of all non-zero voxels; unpacking into three arrays assumes a 3-D volume.
xs, ys, zs = np.where(nifti != 0)
# Crop to the tight bounding box around the non-zero voxels.
nifti = nifti[min(xs):max(xs) + 1, min(ys):max(ys) + 1, min(zs):max(zs) + 1]
# Keep at most the first 100 voxels along each axis.
nifti = nifti[0:100, 0:100, 0:100]
# Add a trailing channel axis; this only succeeds if the crop is exactly 100^3.
nifti = np.reshape(nifti, (100, 100, 100, 1))
nifti = tf.convert_to_tensor(nifti, np.float32)
return nifti, label
# Question's original attempt to run load_image from Dataset.map through
# tf.py_function (kept as posted; it did not resolve the error).
# NOTE(review): Tout declares (tf.string, tf.int32), but load_image returns the
# image converted with np.float32, so the first declared dtype does not match.
def load_image_wrapper(file, labels):
file = tf.py_function(load_image, [file, labels], (tf.string, tf.int32))
return file
# Question's original input pipeline (kept as posted).
# `train` and `labels` are defined elsewhere -- presumably file names and their labels.
dataset = tf.data.Dataset.from_tensor_slices((train, labels))
# Load each file through the wrapper, up to 6 elements in parallel.
dataset = dataset.map(load_image_wrapper, num_parallel_calls=6)
dataset = dataset.batch(6)
dataset = dataset.prefetch(buffer_size=6)
# Pull the first batch to trigger the loading code.
iterator = iter(dataset)
batch_of_images = iterator.get_next()
错误信息如下:TypeError: expected str, bytes or os.PathLike object, not Tensor
我尝试使用'py_function'包装器,但无济于事。有什么想法吗?
我在 TensorFlow 2.1 中解决了这个问题:在 py_function 调用的函数内部,先用 file.numpy().decode('utf-8') 把 Tensor 转换成 str,再传给 nibabel.load:
def load_image(file, label):
    """Load a NIfTI volume and return it as a fixed-size float64 tensor.

    Meant to be executed eagerly through ``tf.py_function``: ``file`` arrives
    as a scalar string tensor, so its bytes are decoded into a ``str`` path
    before being handed to nibabel.

    The volume is cropped to the bounding box of its non-zero voxels, cut to
    at most 100 voxels per axis, zero-padded at the far end of every axis
    that is shorter than 100, and given a trailing channel axis.

    Args:
        file: Scalar string tensor holding the path of the NIfTI file.
        label: Label that belongs to the file; returned unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a ``float64`` tensor of
        shape ``(100, 100, 100, 1)``.
    """
    size = 100

    # nibabel needs a real path string, not a Tensor.
    path = file.numpy().decode('utf-8')
    volume = np.asarray(nibabel.load(path).get_fdata())

    # Unpacking into three index arrays assumes a 3-D volume.
    xs, ys, zs = np.where(volume != 0)
    if xs.size:
        # Tight bounding box around the non-zero voxels.
        volume = volume[
            xs.min():xs.max() + 1,
            ys.min():ys.max() + 1,
            zs.min():zs.max() + 1,
        ]
    # An all-zero volume has no bounding box; it is left uncropped instead of
    # failing on the min/max of an empty index array.

    volume = volume[:size, :size, :size]

    # Zero-pad short axes so the reshape below cannot fail when the cropped
    # volume is smaller than the target size.
    pad_widths = [(0, size - extent) for extent in volume.shape]
    volume = np.pad(volume, pad_widths, mode='constant')

    volume = volume.reshape(size, size, size, 1)
    return tf.convert_to_tensor(volume, np.float64), label
def load_image_wrapper(file, labels):
    """Adapt ``load_image`` for use inside ``Dataset.map``.

    ``load_image`` needs eager access to the file name (``.numpy()``), so it
    is run through ``tf.py_function``.

    Args:
        file: Scalar string tensor with the path of the file to load.
        labels: Label tensor for that file; passed through unchanged.

    Returns:
        A list ``[image, label]``: ``image`` is a ``float64`` tensor with
        static shape ``(100, 100, 100, 1)``; ``label`` keeps the dtype and
        static shape of ``labels``.
    """
    # Declare the label output with the dtype it actually has, so integer
    # labels work as well as float64 ones.
    image, label = tf.py_function(
        load_image, [file, labels], [tf.float64, labels.dtype]
    )
    # py_function outputs carry no static shape; restore it so downstream
    # batching and model layers can see the element shapes.
    image.set_shape((100, 100, 100, 1))
    label.set_shape(labels.shape)
    return [image, label]
# Input pipeline: slice (file name, label) pairs, load every file through the
# py_function wrapper with up to 6 parallel calls, then batch and prefetch.
dataset = (
    tf.data.Dataset.from_tensor_slices((train, labels))
    .map(load_image_wrapper, num_parallel_calls=6)
    .batch(2)
    .prefetch(buffer_size=2)
)

# Fetch the first batch from the pipeline.
iterator = iter(dataset)
batch_of_images = iterator.get_next()