我正在加载许多Keras模型,如下所示:
from keras import backend as K # Tensorflow backend
from MiscFunctions import *
def main():
for i in range(...):
K.clear_session() # Needed to speed up model loading
model = load_model(...)
model._make_predict_function()
main()
但是,我稍后在脚本中有一个函数调用,它接受模型输入并从该模型输出预测。
length = get_length(model, ...)
这里是get_length
的缩短代码
def get_length(model, ...):
...
# input_vector is the correct size
return model.predict(np.asarray(input_vector).reshape(1,1,len(input_vector)))
除了prediction
方法调用给我错误:
tensorflow.python.framework.errors_impl.NotFoundError: FetchOutputs node dense_1/Softmax:0: not found
Exception tensorflow.python.framework.errors_impl.InvalidArgumentError: InvalidArgumentError() in <bound method _Callable.__del__ of <tensorflow.python.client.session._Callable object at 0x7f619b8c7e10>> ignored
我怀疑K.clear_session()
线可能导致问题,但我需要清除会话以加快模型加载。我该如何解决这个问题?
要有效地加载模型,请将其设置为全局并将其加载到另一个函数中,这样您就不必一次又一次地加载它。使其成为全局之后,它将在main函数中可访问:
def load_model():
global model
json_file = open('model.json', 'r')
model_json = json_file.read()
model = model_from_json(model_json)
model.load_weights("model.h5")
model._make_predict_function()