我是TF新手。由于某些原因,我必须使用TF1.10,在此我发现.ppm
不支持tf.image.decode_image
。
我的网络的最终目标是读取RGBD输入,并使用它生成更多特征(通过手工制作方法(例如采样法线),并最终利用这些附加特征和地面实况来计算损耗。
由于我的数据集很大,因此我使用tf.data.TextLineDataset
获取了input_fn
中文件路径列表的数据集,并尝试使用Dataset.map
生成要素。当我尝试解码.ppm
文件时,我发现了此问题。 (代码如下所示。)
或,还有其他方法可以避免在读取strings
中的图像之前将路径Tensor
转换为input_fn
,然后可以使用cv2.imread
吗?但是,如果我这样做,我想我必须使用所有Tensors
来构建我的数据集和迭代器,这可能会占用大量内存。 (也许我错了。)
或者,如果您认为我完全误解了dataset
和Estimator
的用法,请告诉我正确的方法。谢谢。
def input_fn(self, dataset, mode="train"):
self.dict_dataset_lists = {}
ds_rgb = os.path.expandvars(dataset["rgb"])
ds_d = os.path.expandvars(dataset["d"])
ds_gt = os.path.expandvars(dataset["gt"])
self.dict_dataset_lists["rgb"] = tf.data.TextLineDataset(ds_rgb)
self.dict_dataset_lists["d"] = tf.data.TextLineDataset(ds_d)
self.dict_dataset_lists["gt"] = tf.data.TextLineDataset(ds_gt)
with tf.name_scope("Dataset_API"):
tf_dataset = tf.data.Dataset.zip(self.dict_dataset_lists)
# load path to imgs(tensor)
if mode == "train":
tf_dataset = tf_dataset.repeat(self.parameters.max_epochs)
if self.parameters.shuffle:
tf_dataset = tf_dataset.shuffle(
buffer_size=self.parameters.steps_per_epoch * self.parameters.batch_size)
tf_dataset = tf_dataset.map(load_img, num_parallel_calls=1)
tf_dataset = tf_dataset.batch(self.parameters.batch_size)
tf_dataset = tf_dataset.prefetch(buffer_size=self.parameters.prefetch_buffer_size)
# make iterator
iterator = tf_dataset.make_one_shot_iterator()
dict_tf_input = iterator.get_next()
这里是如何在tf.data.Dataset.map
函数中从张量获取字符串部分的示例。
下面是我在代码中实现的步骤。
tf.py_function(get_path, [x], [tf.string])
装饰地图功能。您可以找到有关tf.py_function here的更多信息。bytes.decode(file_path.numpy())
来获得琴弦部分。代码-
%tensorflow_version 2.x
import tensorflow as tf
import numpy as np
def get_path(file_path):
print("file_path: ",bytes.decode(file_path.numpy()),type(bytes.decode(file_path.numpy())))
return file_path
train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(get_path, [x], [tf.string]))
for one_element in train_dataset:
print(one_element)
输出-
file_path: /content/bird.jpg <class 'str'>
(<tf.Tensor: shape=(), dtype=string, numpy=b'/content/bird.jpg'>,)
希望这能回答您的问题。