I trained a ResNet-34 model and deployed it behind a Flask server for online image classification. When I start the server and load the model, it takes about 1.5 GB of memory; after the first image prediction over HTTP, usage jumps to about 3.0 GB and stays at that level no matter how many more images I predict. Stranger still, when I serve YOLOv5 with a Flask app in the same way, it also takes about 3.2 GB. I don't understand why the two occupy roughly the same amount of memory, since ResNet-34 has far fewer parameters than YOLOv5. Is my ResNet-34 Flask server code wrong? How can I reduce the memory footprint?
Here is my code:
import gc
import json
import os
from urllib.parse import unquote

import flask
import torch
from flask import jsonify, request
from PIL import Image
from torchvision import transforms

from model import resnet34  # adjust to wherever your resnet34 is defined

# initialize the flask app
app = flask.Flask(__name__)
app.model = None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
with open(json_path, "r") as f:
    class_indict = json.load(f)


def load_model():
    """Load the pre-trained model; you can swap in your own model just as easily."""
    model = resnet34(num_classes=7).to(device)
    weights_path = './resnet34.pth'
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def load_image(path):
    assert os.path.exists(path), "file: '{}' does not exist.".format(path)
    img = Image.open(path).convert("RGB")
    img = data_transform(img)          # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # expand batch dimension -> [N, C, H, W]
    return img


@app.route("/resnet", methods=['GET', 'POST'])
def predict():
    if app.model is None:
        app.model = load_model()
    path = unquote(request.args.get("path", ""))
    if not path:  # args.get with a default never returns None, so test for empty
        return jsonify({'error': 'Path not provided'}), 400
    image = load_image(path)
    with torch.no_grad():
        # predict class
        output = torch.squeeze(app.model(image.to(device))).cpu()
        probs = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(probs).numpy()
    torch.cuda.empty_cache()
    del image
    gc.collect()
    return jsonify({
        'label': str(predict_cla),
        'predicted_class': class_indict[str(predict_cla)],
        'probability': float(probs[predict_cla].numpy())
    })


if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until the server has fully started")
    # start the classification service and wait for requests
    app.run(port=5012)
Any suggestions would be greatly appreciated.
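To pin down where the memory goes, it can help to log the process's peak resident set size before loading the model, after loading it, and after the first prediction. A minimal Linux-only sketch using the stdlib `resource` module (on Linux, `ru_maxrss` is reported in kilobytes):

```python
import resource

def peak_rss_mb():
    """Peak resident set size of this process, in MB (Linux reports KB)."""
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

# call this at each stage, e.g. right after load_model() and after the
# first request, to see which step accounts for the jump
print(f"peak RSS: {peak_rss_mb():.1f} MB")
```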
When you move a tensor from the GPU over to NumPy, you should detach it first; otherwise the tensor stays attached to the autograd graph, and holding a reference to it keeps the graph's buffers alive (a memory leak):
predict_cla = torch.argmax(predict).detach().cpu().numpy()
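A minimal sketch of why `detach()` matters, using a toy tensor rather than the model above: calling `.numpy()` on a tensor that requires grad raises a `RuntimeError`, and `detach()` returns a view with no graph reference:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()  # y is attached to the autograd graph

try:
    y.numpy()  # raises: can't call numpy() on a tensor that requires grad
except RuntimeError as e:
    print("numpy() failed:", e)

val = y.detach().cpu().numpy()  # detach() drops the graph reference
print(float(val))  # 6.0
```

Note that in the server above the forward pass runs inside `torch.no_grad()`, so the output carries no graph and `.detach()` is a safety belt rather than the source of the leak; the steady ~3 GB is more likely framework and allocator overhead (CUDA context, cuDNN workspaces, caching allocator) than leaked activations.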