Deploying a Keras model for prediction on Google Cloud Functions


I have been trying to deploy a very simple toy Keras model to Cloud Functions that predicts the class of an image, but for some unknown reason, when execution reaches the predict method it hangs without raising any error and eventually times out.

import functions_framework
import io
import numpy as np
import tensorflow as tf

from tensorflow.keras.models import load_model
from PIL import Image

model = load_model("gs://<my-bucket>/cifar10_model.keras")

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def preprocess_image(image_file):
    img = Image.open(io.BytesIO(image_file.read()))
    img = img.resize((32, 32))
    img = np.array(img)
    img = img / 255.0
    img = img.reshape(1, 32, 32, 3)
    return img

@functions_framework.http
def predict(request):
    image = preprocess_image(request.files['image_file'])
    print(image.shape) # this prints OK
    prediction = model.predict(image)
    print(prediction) # this never prints
    predicted_class = class_names[np.argmax(prediction)]
    return f"Predicted class: {predicted_class}"
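As an aside, the reshape in preprocess_image assumes the decoded image already has exactly three channels; an RGBA PNG or a grayscale JPEG would make the reshape to (1, 32, 32, 3) fail before predict is even reached. A minimal, self-contained sketch (an illustration, not the original function) that forces RGB first:

```python
import io

import numpy as np
from PIL import Image

def preprocess_image(image_bytes):
    """Decode, force 3 channels, resize to 32x32, scale to [0, 1]."""
    img = Image.open(io.BytesIO(image_bytes))
    img = img.convert("RGB")  # guards against RGBA or grayscale uploads
    img = img.resize((32, 32))
    arr = np.array(img, dtype=np.float32) / 255.0
    return arr.reshape(1, 32, 32, 3)

# quick check with a synthetic RGBA image
buf = io.BytesIO()
Image.new("RGBA", (64, 48)).save(buf, format="PNG")
batch = preprocess_image(buf.getvalue())
print(batch.shape)  # (1, 32, 32, 3)
```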

Local debugging works fine, and prediction is as fast as expected (the model weights file is 2 MB). I also added several prints along the way (removed from the snippet above), and execution proceeds normally right up to the predict call.

Even though the smallest compute configuration should suffice, I tried reserving more memory and CPU, with no effect. The model is hosted in Cloud Storage, and I tried downloading it first, but that did not work either. I also tried running the prediction inside a tf.device('/cpu:0') context, passing the steps=1 argument, and converting the image array to a Keras dataset first, as ChatGPT suggested, all with the same result. In fact, the call to predict never prints anything at all. Calling call instead of predict got me nowhere either.

What am I missing?

python machine-learning keras google-cloud-functions predict
1 Answer

I suggest downloading the model to local storage (/tmp) and loading it once on cold start. Here is an outline of the code:

import os
import tempfile # add these imports after [from PIL import Image]

model = None
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def get_model():
    global model
    if model is None: 
        local_model_path = os.path.join(tempfile.gettempdir(), "cifar10_model.keras") # download the model only once to the /tmp directory
        if not os.path.exists(local_model_path):
            print("Downloading model...")
            tf.io.gfile.copy("gs://<my-bucket>/cifar10_model.keras", local_model_path, overwrite=True)
        print("Loading model...")
        model = load_model(local_model_path)
    return model

def preprocess_image(image_file):
    img = Image.open(io.BytesIO(image_file.read()))
    img = img.resize((32, 32))
    img = np.array(img) / 255.0
    img = img.reshape(1, 32, 32, 3)
    return img

@functions_framework.http
def predict(request):
    model = get_model() # load on the first invocation, reuse afterwards
    image = preprocess_image(request.files['image_file'])
    print("Image shape:", image.shape)
    prediction = model.predict(image)
    print("Prediction:", prediction) 
    predicted_class = class_names[np.argmax(prediction)]
    return f"Predicted class: {predicted_class}"
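The payoff of this pattern is that only the first (cold-start) invocation pays the download-and-load cost; warm invocations reuse the cached model. A stand-alone sketch of the same load-once logic, with fake_load_model as a hypothetical stand-in for load_model so it runs without TensorFlow or a bucket:

```python
import os
import tempfile

load_calls = []  # records how many times the "model" is actually loaded

def fake_load_model(path):
    # hypothetical stand-in for tensorflow.keras.models.load_model
    load_calls.append(path)
    return {"loaded_from": path}

model = None

def get_model():
    global model
    if model is None:  # only the cold-start invocation reaches this branch
        local_model_path = os.path.join(tempfile.gettempdir(), "cifar10_model.keras")
        model = fake_load_model(local_model_path)
    return model

# two "invocations": the second reuses the cached model
first = get_model()
second = get_model()
print(first is second, len(load_calls))  # True 1
```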

Hope this works for you!
