与本地执行相比,PyTorch 模型在 Flask 服务器中的表现较差

问题描述 投票:0回答:1

我有一个 PyTorch 模型,当我在本地运行它时,它可以生成高质量和高速度的图像。然而,当我在 Flask 服务器中部署相同的代码时,生成的图像质量低得多,而且过程非常慢。

详情:

当地环境:

操作系统:Windows 10

Python版本:3.12

PyTorch版本:(2.3.1+cu121)

GPU:

GPU information

服务器环境:

部署:Flask

服务员:服务员

托管:本地

资源:GPU

代码片段(单独):

from diffusers import StableDiffusionPipeline
from torch import float16
pipeline = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=float16)
pipeline.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
pipeline.to('cuda')
pipeline.enable_attention_slicing()
pipeline('best quality, high quality, photorealistic, an astronaut riding a white horse in space ', num_inference_steps=20, negative_prompt='bad quality, low quality', num_images_per_prompt=1, height=800, width=1000).images[0].save('1.png')

结果(不到1分钟):

perfect result

代码片段(烧瓶):

from flask import Flask
from diffusers import StableDiffusionPipeline
from torch import float16
app = Flask(__name__)
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=float16) # stabilityai/sdxl-turbo
pipeline.to("cuda")
pipeline.enable_attention_slicing()
@app.route('/texttoimage/generate', methods=['POST'])
def ttig():
    global count
    if eval(request.cookies['season']) in (user := info[request.cookies['name'].encode()])[1]:
        user = user[4]
        pipeline.safety_checker = lambda images, **kwargs: (images, [False if request.args['NSFW'] == 'true' else True] * len(images))
        images = pipeline(request.args['prompt'], negative_prompt=request.args.get('negative'), num_inference_steps=int(request.args['quality']), num_images_per_prompt=int(request.args['count']), height=int(request.args['height']) // 8 * 8, width=int(request.args['width']) // 8 * 8).images
        for k,j in enumerate(images):
            user[f"{count + k}.{request.args['type']}"] = 'a'
            j.save(f"s/{count + k}.{request.args['type']}")
        count += len(images)
        return str(count)
if __name__ == '__main__':
    from waitress import serve
    serve(app, port=80)

结果(约10分钟):

[terrible result(https://i.sstatic.net/M6ghfzpB.jpg)

尝试过的步骤:

  • 确保模型仅加载一次。
  • 在 Flask 的调试和非调试模式下进行了测试。
  • 已验证服务器有足够的资源。

问题:

  1. Flask 服务器性能缓慢和图像质量较低的原因可能是什么?
  2. 使用 Flask 部署 PyTorch 模型是否有任何最佳实践以确保最佳性能?
  3. WSGI 服务器或服务器配置的选择会影响性能吗?

任何指导或建议将不胜感激!

python flask pytorch
1个回答
0
投票

这些是我的建议:

确保您的 Flask 服务器确实正在使用 GPU。有时,模型可能会无意中回退到 CPU,从而导致速度显着下降。 使用

torch.cuda.is_available()
验证 GPU 是否可用且正在使用。 使用
nvidia-smi
等工具检查是否有任何其他进程消耗 GPU 资源。

考虑使用性能更高的 WSGI 服务器,例如具有多个工作进程和线程的

gunicorn

根据您机器的资源调整worker

(-w)
和线程(
-k
gthread)的数量。

© www.soinside.com 2019 - 2024. All rights reserved.