我正在使用图像制作一个关于狗品种预测的项目,我想将其部署在我使用 Flask 的网站上,我从 chatgpt 和多个 YouTube 教程中获得了帮助,因为这是我第一次使用 Flask,但是当我运行我的 app.py 文件时我 cmd 它给了我这个错误
# Load the pre-trained TensorFlow model
model = tf.keras.models.load_model("model path")
# Function to preprocess the image
def preprocess_image(image_path):
# Open the image using Pillow
img = Image.open(image_path)
# Resize image to match model's expected input shape
img = img.resize((224, 224))
# Convert image to numpy array
img_array = np.array(img)
# Normalize pixel values to range [0, 1]
img_array = img_array / 255.0
# Expand dimensions to match model's input shape (add batch dimension)
img_array = np.expand_dims(img_array, axis=0)
return img_array
# Route to handle image upload and prediction
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return 'No file part'
file = request.files['file']
# Check if the file is empty
if file.filename == '':
return 'No selected file'
# Check if the file is an image
if file and allowed_file(file.filename):
# Read the image file
img = Image.open(file)
# Preprocess the image
img_array = preprocess_image(img)
# Perform prediction using the loaded TensorFlow model
prediction = model.predict(img_array)
# Process prediction result (you may need to adjust this based on your model)
# For demonstration, let's assume the model outputs a class index
predicted_class_index = np.argmax(prediction)
# Convert class index to a human-readable prediction
class_labels = ['Class 0', 'Class 1', 'Class 2'] # Replace with your class labels
prediction_result = class_labels[predicted_class_index]
return f'The predicted class is: {prediction_result}'
return 'Invalid file format'
# Function to check if the file extension is allowed
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}
请帮助我,我收到此错误,并且我无法理解该错误是在 Flask 代码中还是在我的 DL 模型代码中
给定错误的解决方案是这里。您可以通过传递custom_objects来解决问题:
import tensorflow_hub as hub
model = tf.keras.models.load_model(
"model path",
custom_objects={'KerasLayer':hub.KerasLayer}
)
您还需要检查
preprocess_image()
输入。看来你向它发送了一个文件(不是文件路径)。