在 Theano、Lenet5 中 Pickle 多层 CNN

问题描述 投票:0回答:1

我正在使用 Theano 进行图像识别,我想使用经过训练的模型创建一个预测系统。

我参考了LeNet5 卷积神经网络(LeNet)并训练了自己的数据,现在我想使用训练好的模型来预测新图像。

使用逻辑回归对 MNIST 数字进行分类中描述了pickle训练模型的方法,但它只是一个逻辑回归,而不是多层CNN。以同样的方式我保存了每一层,但我不能用它来预测。

这是我的代码:

def predict():
"""
An example of how to load a trained model and use it
to predict labels.
"""

# load the saved model
#x = Data
x = T.matrix('x')
Data = x.reshape((1, 1, 32, 32))
layer0
layer1
layer2_input = layer1.output.flatten(2)
layer2
layer3

# compile a predictor function
predict_model = theano.function([layer0.input],
    layer0.output)
    #inputs=[layer0.input],
    #outputs=layer3.y_pred)

# We can test it on some examples from test test
#dataset='facedata_cross_6_2_6.pkl.gz'
#datasets = load_data(dataset)
#test_set_x, test_set_y = datasets[2]
#test_set_x = test_set_x.get_value()
#reshape=np.reshape(test_set_x[26],(28,28))
#plt.imshow(reshape)

predicted_values = predict_model(Data)
print("Predicted values for the first 10 examples in test set:")
print(predicted_values)
deep-learning theano conv-neural-network image-recognition
1个回答
2
投票

保存模型的方法有很多。我经常使用的方法是通过pickle每一层的权重和偏差(顺序由你决定):

f = file('Models/bestmodel.pickle','wb')
cPickle.dump(layer0.W.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)
cPickle.dump(layer1.W.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)
cPickle.dump(layer2.W.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)
...
cPickle.dump(layer0.b.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)            
cPickle.dump(layer1.b.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)
cPickle.dump(layer2.b.get_value(borrow=True),f,protocol=cPickle.HIGHEST_PROTOCOL)
...
f.close()

然后对于预测系统,创建相同的模型架构并使用保存的模型作为初始值(与保存的顺序相同):

f=file('Models/bestmodel.pickle','rb')
layer0.W.set_value(cPickle.load(f), borrow=True)
layer1.W.set_value(cPickle.load(f), borrow=True)
layer2.W.set_value(cPickle.load(f), borrow=True)
...
layer0.b.set_value(cPickle.load(f), borrow=True)
layer1.b.set_value(cPickle.load(f), borrow=True)
layer2.b.set_value(cPickle.load(f), borrow=True)
...
f.close()
© www.soinside.com 2019 - 2024. All rights reserved.