我有一个 PyTorch 模型,当我在本地运行它时,它可以生成高质量和高速度的图像。然而,当我在 Flask 服务器中部署相同的代码时,生成的图像质量低得多,而且过程非常慢。
详情:
操作系统:Windows 10
Python版本:3.12
PyTorch版本:(2.3.1+cu121)
GPU:
部署:Flask
服务员:服务员
托管:本地
资源:GPU
代码片段(单独):
from diffusers import StableDiffusionPipeline
from torch import float16
pipeline = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', torch_dtype=float16)
pipeline.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
pipeline.to('cuda')
pipeline.enable_attention_slicing()
pipeline('best quality, high quality, photorealistic, an astronaut riding a white horse in space ', num_inference_steps=20, negative_prompt='bad quality, low quality', num_images_per_prompt=1, height=800, width=1000).images[0].save('1.png')
结果(不到1分钟):
代码片段(烧瓶):
from flask import Flask
from diffusers import StableDiffusionPipeline
from torch import float16
app = Flask(__name__)
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=float16) # stabilityai/sdxl-turbo
pipeline.to("cuda")
pipeline.enable_attention_slicing()
@app.route('/texttoimage/generate', methods=['POST'])
def ttig():
global count
if eval(request.cookies['season']) in (user := info[request.cookies['name'].encode()])[1]:
user = user[4]
pipeline.safety_checker = lambda images, **kwargs: (images, [False if request.args['NSFW'] == 'true' else True] * len(images))
images = pipeline(request.args['prompt'], negative_prompt=request.args.get('negative'), num_inference_steps=int(request.args['quality']), num_images_per_prompt=int(request.args['count']), height=int(request.args['height']) // 8 * 8, width=int(request.args['width']) // 8 * 8).images
for k,j in enumerate(images):
user[f"{count + k}.{request.args['type']}"] = 'a'
j.save(f"s/{count + k}.{request.args['type']}")
count += len(images)
return str(count)
if __name__ == '__main__':
from waitress import serve
serve(app, port=80)
结果(约10分钟):
[
尝试过的步骤:
问题:
任何指导或建议将不胜感激!
这些是我的建议:
确保您的 Flask 服务器确实正在使用 GPU。有时,模型可能会无意中回退到 CPU,从而导致速度显着下降。 使用
torch.cuda.is_available()
验证 GPU 是否可用且正在使用。
使用nvidia-smi
等工具检查是否有任何其他进程消耗 GPU 资源。
考虑使用性能更高的 WSGI 服务器,例如具有多个工作进程和线程的
gunicorn
。
根据您机器的资源调整worker
(-w)
和线程(-k
gthread)的数量。