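After training my 'parameters' (w1, w2, the weights of the filters in the ConvNet), I saved them as parameters = sess.run(parameters).
I take an image img of shape [1, 64, 64, 3] and pass it to the mypredict(x, parameters) function for prediction, but it raises an error. The functions are below. Any suggestions as to what is wrong?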
def forward_propagation(X, parameters):
    W1 = parameters['W1']
    W2 = parameters['W2']
    Z1 = tf.nn.conv2d(X,W1,strides=[1,1,1,1],padding='SAME')
    A1 = tf.nn.relu(Z1)
    P1 = tf.nn.max_pool(A1,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
    Z2 = tf.nn.conv2d(P1,W2,strides=[1,1,1,1],padding='SAME')
    A2 = tf.nn.relu(Z2)
    P2 = tf.nn.max_pool(A2,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')
    P2 = tf.contrib.layers.flatten(P2)
    Z3 = tf.contrib.layers.fully_connected(P2,num_outputs=6,activation_fn=None)
    return Z3
def mypredict(X, par):
    W1 = tf.convert_to_tensor(par["W1"])
    W2 = tf.convert_to_tensor(par["W2"])
    params = {"W1": W1,
              "W2": W2}
    x = tf.placeholder("float", [1,64,64,3])
    z3 = forward_propagation(x, params)
    p = tf.argmax(z3)
    sess = tf.Session()
    prediction = sess.run(p, feed_dict = {x:X})
    return prediction
I used the same forward_propagation function to train the weights, but when I pass a single image, it doesn't work.
Error:
FailedPreconditionError                   Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1138     try:
-> 1139       return fn(*args)
   1140     except errors.OpError as e:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1120                                  feed_dict, fetch_list, target_list,
-> 1121                                  status, run_metadata)
   1122

/opt/conda/lib/python3.6/contextlib.py in __exit__(self, type, value, traceback)
     88             try:
---> 89                 next(self.gen)
     90             except StopIteration:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in raise_exception_on_not_ok_status()
    465           compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 466           pywrap_tensorflow.TF_GetCode(status))
    467   finally:

FailedPreconditionError: Attempting to use uninitialized value fully_connected_1/biases
	 [[Node: fully_connected_1/biases/read = Identity[T=DT_FLOAT, _class=["loc:@fully_connected_1/biases"], _device="/job:localhost/replica:0/task:0/CPU:0"]]]

During handling of the above exception, another exception occurred:

FailedPreconditionError                   Traceback (most recent call last)
in <module>()
----> 1 pred = mypredict(t, pp)
      2

in mypredict(X, par)
     49
     50     sess = tf.Session()
---> 51     prediction = sess.run(p, feed_dict = {x:X})
     52
     53     return prediction

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    787     try:
    788       result = self._run(None, fetches, feed_dict, options_ptr,
--> 789                          run_metadata_ptr)
    790       if run_metadata:
    791         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    995     if final_fetches or final_targets:
    996       results = self._do_run(handle, final_targets, final_fetches,
--> 997                              feed_dict_string, options, run_metadata)
    998     else:
    999       results = []

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1130     if handle is None:
   1131       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1132                            target_list, options, run_metadata)
   1133     else:
   1134       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1150         except KeyError:
   1151           pass
-> 1152       raise type(e)(node_def, op, message)
   1153
   1154   def _extend_graph(self):

FailedPreconditionError: Attempting to use uninitialized value fully_connected_1/biases
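You also have to load the parameters of the fully connected layer. tf.contrib.layers.fully_connected creates its own weights and biases variables (fully_connected_1/weights and fully_connected_1/biases in your error), and they are not part of your parameters dict. Inside mypredict they are therefore brand-new variables that were never trained or initialized, which is what the FailedPreconditionError is complaining about.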
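If you want to keep the parameters-dict approach, one option is to make the dense layer's weights explicit so that sess.run(parameters) captures them along with W1 and W2. The sketch below rests on several assumptions: the names W3 and b3 and the W1/W2 shapes are hypothetical (the question does not give them), the dense layer is written with tf.matmul instead of tf.contrib.layers.fully_connected, and it imports tensorflow.compat.v1 so it also runs on TF 2.x (on the TF 1.x from the question, use import tensorflow as tf and drop the disable_v2_behavior() line). Prediction runs in a fresh graph where every weight is a tf.constant, so there is nothing left to initialize.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # on TF 1.x: import tensorflow as tf
tf.disable_v2_behavior()           # on TF 1.x: remove this line


def forward_propagation(X, parameters):
    # Same conv/pool stack as in the question
    Z1 = tf.nn.conv2d(X, parameters['W1'], strides=[1, 1, 1, 1], padding='SAME')
    A1 = tf.nn.relu(Z1)
    P1 = tf.nn.max_pool(A1, ksize=[1, 8, 8, 1], strides=[1, 8, 8, 1], padding='SAME')
    Z2 = tf.nn.conv2d(P1, parameters['W2'], strides=[1, 1, 1, 1], padding='SAME')
    A2 = tf.nn.relu(Z2)
    P2 = tf.nn.max_pool(A2, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='SAME')
    # Flatten to [batch, features]; the feature count is known statically
    F = tf.reshape(P2, [-1, P2.shape[1:].num_elements()])
    # Dense layer from explicit (hypothetical) W3/b3 instead of fully_connected,
    # so its weights live in `parameters` just like W1 and W2
    return tf.matmul(F, parameters['W3']) + parameters['b3']


def mypredict(X, par):
    # Fresh graph: leftover training variables cannot get in the way
    with tf.Graph().as_default():
        # Every trained weight becomes a constant, so nothing needs initializing
        params = {name: tf.constant(value) for name, value in par.items()}
        x = tf.placeholder(tf.float32, [None, 64, 64, 3])
        z3 = forward_propagation(x, params)
        # axis=1: one class index per image
        p = tf.argmax(z3, axis=1)
        with tf.Session() as sess:
            return sess.run(p, feed_dict={x: X})


# Stand-in for training: W1/W2 shapes are hypothetical, W3/b3 are the added dense weights
with tf.Graph().as_default():
    variables = {
        'W1': tf.Variable(tf.random_normal([4, 4, 3, 8])),
        'W2': tf.Variable(tf.random_normal([2, 2, 8, 16])),
        # 64x64 input -> 8x8 after the first pool -> 2x2 after the second
        'W3': tf.Variable(tf.random_normal([2 * 2 * 16, 6])),
        'b3': tf.Variable(tf.zeros([6])),
    }
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        parameters = sess.run(variables)  # dict of NumPy arrays, as in the question

img = np.random.rand(1, 64, 64, 3).astype(np.float32)
print(mypredict(img, parameters))  # an array holding one class index in [0, 6)
```

The placeholder uses None for the batch dimension, so the same function accepts one image or a batch. Note the axis=1 in tf.argmax: without it, tf.argmax reduces over axis 0 (the batch), which for a [1, 6] output returns six zeros instead of a class index.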
That said, I would recommend using TensorFlow's tf.train.Saver (its save and restore methods) in any case.
For saving, here is a toy example:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000) # saving model after 1000 steps
This stores the following files:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
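To check what actually ends up in those files, you can list the variables stored in the checkpoint. A minimal sketch, assuming a TensorFlow release that ships tensorflow.compat.v1 (late 1.x or 2.x) and using a temporary directory instead of the working directory. Because Saver() is created without a var_list, it saves every global variable; in the question's model that would include the weights and biases created by fully_connected.

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # assumes a release that ships tensorflow.compat.v1
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()  # temporary directory instead of './'

with tf.Graph().as_default():
    tf.Variable(tf.random_normal(shape=[2]), name='w1')
    tf.Variable(tf.random_normal(shape=[5]), name='w2')
    saver = tf.train.Saver()  # no var_list: every global variable is saved
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint prefix, i.e. .../my_test_model-1000
        path = saver.save(sess, os.path.join(ckpt_dir, 'my_test_model'), global_step=1000)

files = sorted(os.listdir(ckpt_dir))
stored = tf.train.list_variables(path)  # list of (variable name, shape) pairs
print(files)
print(stored)
```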
Then, to restore, you can first recreate the network from the .meta file and then load the parameters:
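import tensorflow as tf
with tf.Session() as sess:
    # import_meta_graph rebuilds the saved graph and returns a Saver for it
    recreated_net = tf.train.import_meta_graph('my_test_model-1000.meta')
    # load the most recent checkpoint listed in ./checkpoint into this session
    recreated_net.restore(sess, tf.train.latest_checkpoint('./'))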
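After restoring, the Python objects w1 and w2 no longer exist, so you fetch tensors from the restored graph by the name= they were created with; the ":0" suffix selects the op's first output. A minimal self-contained sketch, assuming a TensorFlow release that ships tensorflow.compat.v1 and using a temporary directory; it saves the toy model, restores it into a brand-new graph, and checks that the value came back unchanged. For the question's model you would, hypothetically, give the input placeholder and the logits their own names and fetch them the same way.

```python
import os
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf  # assumes a release that ships tensorflow.compat.v1
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, 'my_test_model')

# Save: the toy graph from above
with tf.Graph().as_default():
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w1 = sess.run(w1)
        saver.save(sess, prefix, global_step=1000)

# Restore into a brand-new graph: no Python handles to w1/w2 exist here
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        recreated_net = tf.train.import_meta_graph(prefix + '-1000.meta')
        recreated_net.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        # Look the variable up by its creation name; ':0' is the op's first output
        restored_w1 = sess.run(graph.get_tensor_by_name('w1:0'))

print(np.allclose(saved_w1, restored_w1))  # True
```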