我一直在尝试将一个非常简单的玩具 Keras 模型部署到 Cloud Functions,该模型可以预测图像的类别,但由于未知的原因,当执行到
predict
方法时,它会卡住,不会抛出异常任何错误,最终都会超时。
import functions_framework
import io
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from PIL import Image
model = load_model("gs://<my-bucket>/cifar10_model.keras")
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def preprocess_image(image_file):
img = Image.open(io.BytesIO(image_file.read()))
img = img.resize((32, 32))
img = np.array(img)
img = img / 255.0
img = img.reshape(1, 32, 32, 3)
return img
@functions_framework.http
def predict(request):
image = preprocess_image(request.files['image_file'])
print(image.shape) # this prints OK
prediction = model.predict(image)
print(prediction) # this never prints
predicted_class = class_names[np.argmax(prediction)]
return f"Predicted class: {predicted_class}"
本地调试工作正常,预测速度如预期一样快(模型权重文件为2MB)。我还一路添加了几个打印(从上面的代码片段中删除)并且执行工作正常,直到使用
predict
方法。
尽管最小的计算配置应该可以工作,但我尝试保留更多的内存和 CPU,但没有任何效果。该模型托管在存储中,我尝试先下载它,但这也不起作用。我还尝试在
tf.device('/cpu:0')
上下文中进行预测,传递 step=1
参数并首先将图像数组转换为 Keras 数据集,正如 ChatGPT 所建议的那样,得到了相同的结果。实际上,调用 predict
根本不会打印任何内容。打电话给 call
而不是 predict
我无处可去。
我错过了什么?
我建议下载模型到本地存储(
/tmp
)并在冷启动时加载一次。以下是代码概述:
import os
import tempfile # add these imports after [from PIL import Image]
model = None
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def get_model():
global model
if model is None:
local_model_path = os.path.join(tmpfile.gettmpdir(), "cifar10_model.keras") # download the model only once to /tmp directory
if not os.path.exists(local_model_path):
print("Downloading model...")
tf.io.gfile.copy("gs://<my-bucket>/cifar10_model.keras", local_model_path, overwrite=True)
print("Loading model...")
model = load_model(local_model_path)
return model
def preprocess_image(image_file):
img = Image.open(io.BytesIO(image_file.read()))
img = img.resize((32, 32))
img = np.array(img) / 255.0
img = img.reshape(1, 32, 32, 3)
return img
@functions_framework.http
def predict(request):
global model
model = get_model() # get or load the model
image = preprocess_image(request.files['image_file'])
print("Image shape:", image.shape)
prediction = model.predict(image)
print("Prediction:", prediction)
predicted_class = class_names[np.argmax(prediction)]
return f"Predicted class: {predicted_class}"
希望这对您有用!