我正在使用 Python 构建后端,以使用此 git 存储库中的 RealESRGAN 模型来升级图像:https://github.com/xinntao/Real-ESRGAN。
这是一个简化的
app.py
文件,我在其中运行测试:
import os
from flask import Flask, request, jsonify
from flask_cors import CORS
from realesrgan import RealESRGANer
from PIL import Image
import numpy as np
import io
app = Flask(__name__)
CORS(app)
# Initialize the RealESRGAN model
model_path = "./RealESRGAN_x4plus_anime_6B.pth" # Path to the model that can be dowloaded here: https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
model = RealESRGANer(scale=4, model_path=model_path) # Initialize the model with the specified scale
@app.route('/upscale', methods=['POST'])
def upscale_image():
if 'image' not in request.files:
return jsonify({"error": "No image provided"}), 400
file = request.files['image']
try:
# Read the image file
img = Image.open(file.stream).convert('RGB')
# Convert image to numpy array and upscale
img_array = np.array(img)
upscaled_image = model.predict(img_array)
# Convert upscaled image back to PIL format
upscaled_image = Image.fromarray(upscaled_image)
# Save the upscaled image to a byte stream
img_byte_arr = io.BytesIO()
upscaled_image.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
return jsonify({"message": "Image upscaled successfully", "upscaled_image": img_byte_arr.getvalue().decode('latin1')})
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
app.run(debug=True)
但是,我在尝试运行该应用程序时遇到以下错误:
model.load_state_dict(loadnet[keyname], strict=True)
^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'load_state_dict'
关于如何解决这个问题有什么建议吗?
以下是我在调试时尝试做的一些事情:
加载模型:我尝试使用以下代码“手动”加载模型:
# Detect the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize RealESRGANer
self.model = RealESRGANer(
scale=4, # Upscale factor
model_path=model_path, # Model path for pretrained weights
tile=0, # No tiling
tile_pad=10, # Padding if tiling is used
pre_pad=0, # Pre-padding
half=False, # Avoid half precision issues
device=device # Use detected device (CUDA or CPU)
)
# Load the weights into the model architecture
state_dict = torch.load(model_path, map_location=device)
# Ensure 'params_ema' is present in the state_dict
if "params_ema" in state_dict:
self.model.model.load_state_dict(state_dict["params_ema"], strict=True)
else:
raise ValueError("The state_dict does not contain 'params_ema'. Please check the model file.")
# Set the model to evaluation mode (important for inference)
self.model.model.eval()
文件路径:我还添加了以下代码以确保我的模型路径正确:
print(os.path.exists("./RealESRGAN_x4plus_anime_6B.pth"))
替代模型:我还测试了可以在这里找到的替代模型:https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
问题是您尝试启动一个辅助方法,该方法需要“真实”模型作为
model
参数。该模型参数的默认值为 None
,它解释了错误消息。你可以在RealESRGANer
的init中看到它。您还可以在下面的行中查看实际模型是如何加载到推理脚本中的。 如果您要使用的权重是固定的,您可以复制模型加载线:
model_path = "./RealESRGAN_x4plus_anime_6B.pth"
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) # for `RealESRGAN_x4plus_anime_6B`
model = RealESRGANer(scale=4, model=model, model_path=model_path)
请注意,在推理脚本中,如果使用 --face_enhance
参数,则会加载另一个模型(code),并使用
RealESRGANer
模型作为参数(此处称为
upsampler
)。另请注意,我无法测试我的解决方案,它只是基于挖掘源代码。