我用tf.estimator训练了一个resnet,在训练过程中保存了模型。保存的文件包括.data
,.index
和.meta
。我想加载这个模型并获得新图像的预测。使用tf.data.Dataset
在训练期间将数据输入模型。我已经密切关注了here给出的resnet实现。
我想使用feed_dict恢复模型并将输入提供给节点。
第一次尝试
#rebuild input pipeline
images, labels = input_fn(data_dir, batch_size=32, num_epochs=1)
#rebuild graph
prediction= imagenet_model_fn(images,labels,{'batch_size':32,'data_format':'channels_first','resnet_size':18},mode = tf.estimator.ModeKeys.EVAL).predictions
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(r'./model')
saver.restore(sess, ckpt.model_checkpoint_path)
while True:
try:
pred,im= sess.run([prediction,images])
print(pred)
except tf.errors.OutOfRangeError:
break
我提供了一个数据集,该数据集使用classifier.evaluate
在同一模型上进行了评估,但上述方法给出了错误的预测。该模型为所有图像提供相同的类和概率1.0。
第二次尝试
saver = tf.train.import_meta_graph(r'.\resnet\model\model-3220.meta')
sess = tf.Session()
saver.restore(sess,tf.train.latest_checkpoint(r'.\resnet\model'))
graph = tf.get_default_graph()
inputImage = graph.get_tensor_by_name('image:0')
logits= graph.get_tensor_by_name('logits:0')
#Get prediction
print(sess.run(logits,feed_dict={inputImage:newimage}))
与classifier.evaluate
相比,这也给出了错误的预测。我甚至可以在没有sess.run(logits)
的情况下运行feed_dict
!
第三次尝试
def serving_input_fn():
receiver_tensor = {'feature': tf.placeholder(shape=[None, 384, 256, 3], dtype=tf.float32)}
features = {'feature': receiver_tensor['images']}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensor)
它失败了
Traceback (most recent call last):
File "imagenet_main.py", line 213, in <module>
tf.app.run(argv=[sys.argv[0]] + unparsed)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 124, in run
_sys.exit(main(argv))
File "imagenet_main.py", line 204, in main
resnet.resnet_main(FLAGS, imagenet_model_fn, input_fn)
File "C:\Users\Photogauge\Desktop\iprings_images\models-master\models-master\official\resnet\resnet.py", line 527, in resnet_main
classifier.export_savedmodel(export_dir_base=r"C:\Users\Photogauge\Desktop\iprings_images\models-master\models-master\official\resnet\export", serving_input_receiver_fn=serving_input_fn)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 528, in export_savedmodel
config=self.config)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 725, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "imagenet_main.py", line 200, in imagenet_model_fn
loss_filter_fn=None)
File "C:\Users\Photogauge\Desktop\iprings_images\models-master\models-master\official\resnet\resnet.py", line 433, in resnet_model_fn
tf.argmax(labels, axis=1), predictions['classes'])
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\util\deprecation.py", line 316, in new_func
return func(*args, **kwargs)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\math_ops.py", line 208, in argmax
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 508, in arg_max
name=name)
File "C:\Users\Photogauge\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 528, in _apply_op_helper
(input_name, err))
ValueError: Tried to convert 'input' to a tensor and failed. Error: None values not supported.
我用于培训和构建模型的代码如下:
解析数据集的规范:
def parse_record(raw_record, is_training):
keys_to_features = {
'image/encoded':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/class/label':
tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
}
parsed = tf.parse_single_example(raw_record, keys_to_features)
image = tf.image.decode_image(
tf.reshape(parsed['image/encoded'], shape=[]),3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
label = tf.cast(
tf.reshape(parsed['image/class/label'], shape=[]),
dtype=tf.int32)
return image, tf.one_hot(label,2)
以下函数解析数据并创建用于培训的批处理
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset = tf.data.Dataset.from_tensor_slices(
filenames(is_training, data_dir))
if is_training:
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(lambda value: parse_record(value, is_training),
num_parallel_calls=5)
dataset = dataset.prefetch(batch_size)
if is_training:
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
return images, labels
如下创建分类器,用于训练集和验证集的评估
classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={
'resnet_size': flags.resnet_size,
'data_format': flags.data_format,
'batch_size': flags.batch_size,
})
#Training cycle
classifier.train(
input_fn=lambda: input_function(
training_phase=True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
hooks=[logging_hook])
# Evaluate the model
eval_results = classifier.evaluate(input_fn=lambda: input_function(
training_phase=False, flags.data_dir, flags.batch_size))
这就是我尝试从模型中加载和获取预测的方法。
恢复已保存模型并对其进行推理的正确方法是什么。我想直接提供图像而不使用tf.data.Dataset
。
更新
ckpt
的价值是在运行ckpt = tf.train.get_checkpoint_state(r'./model')
之后
model_checkpoint_path:“。/ model \ model.ckpt-5980”all_model_checkpoint_paths:“./ modelmodel_checkpoint_606”all_model_checkpoint_paths:“./ modelmodel_checkpoint_paths:”。/ modelmodel.ckpt- 5520“all_model_checkpoint_paths:”./ model.model.ckpt-5521“all_model_checkpoint_paths:”./ model.model.ckpt-5980“saver.restore
的完整路径给出相同的输出在所有情况下相同的模型,model.ckpt-5980
恢复注意:只要有更多信息,这个答案就会发生变化。我不确定这是最合适的方法,但感觉比仅使用评论更好。如果这是不合适的,请随意发表评论。
我对import_meta_graph
方法没有太多经验,但如果sess.run(logits)
没有抱怨就运行,我认为元图也包含你的输入管道。
我刚刚做的一个快速测试证实,当你加载元图时,管道确实也恢复了。这意味着,您实际上并没有通过feed_dict
传递任何内容,因为输入来自检查点被使用时使用的基于Dataset
的管道。根据我的研究,我找不到为图形提供不同输入功能的方法。
你的代码看起来对我来说,所以我怀疑加载的检查点文件是不正确的。我在评论中提出了一些澄清,我会在信息可用后立即更新此部分
如果您有模型pb或pb.txt,那么推理很容易。使用预测模块,我们可以进行推理。查看here了解更多信息。对于图像数据,它将类似于下面的示例。希望这可以帮助 !!
示例代码:
import numpy as np
import matplotlib.pyplot as plt
def extract_data(index=0, filepath='data/cifar-10-batches-bin/data_batch_5.bin'):
bytestream = open(filepath, mode='rb')
label_bytes_length = 1
image_bytes_length = (32 ** 2) * 3
record_bytes_length = label_bytes_length + image_bytes_length
bytestream.seek(record_bytes_length * index, 0)
label_bytes = bytestream.read(label_bytes_length)
image_bytes = bytestream.read(image_bytes_length)
label = np.frombuffer(label_bytes, dtype=np.uint8)
image = np.frombuffer(image_bytes, dtype=np.uint8)
image = np.reshape(image, [3, 32, 32])
image = np.transpose(image, [1, 2, 0])
image = image.astype(np.float32)
result = {
'image': image,
'label': label,
}
bytestream.close()
return result
predictor_fn = tf.contrib.predictor.from_saved_model(
export_dir = saved_model_dir, signature_def_key='predictions')
N = 1000
labels = []
images = []
for i in range(N):
result = extract_data(i)
images.append(result['image'])
labels.append(result['label'][0])
output = predictor_fn(
{
'images': images,
}
)