Scaled_dot_product_attention consumes much more memory with a higher head num

Question

I found that scaled_dot_product_attention consumes much more memory when the number of heads is large (>= 16). Here is my code to reproduce the problem.


import torch
length = 10000
dim = 64
head_num1 = 8
head_num2 = 16
batch = 1
# Same embedding width (dim = 64), split over 8 heads vs. 16 heads.
shapes = [[batch, head_num1, length, dim // head_num1],
          [batch, head_num2, length, dim // head_num2]]
for shape in shapes:
    torch.cuda.reset_peak_memory_stats()

    shape2 = [1, 1, length, length]  # mask broadcast over batch and heads
    q = torch.rand(shape, dtype=torch.float16).cuda()
    k = torch.rand(shape, dtype=torch.float16).cuda()
    v = torch.rand(shape, dtype=torch.float16).cuda()

    attn_mask = torch.ones(shape2, dtype=torch.bool).cuda()
    x = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0)

    peak_memory = torch.cuda.max_memory_allocated()
    print(f"head number {shape[1]} case peak memory: {peak_memory / 1e6:.2f} MB")

Environment:

Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: Tesla V100-SXM2-16GB
cuDNN: 8902

Output:

head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 6716.14 MB

Just doubling the number of heads takes 16x the memory... is this normal?

I tried it on another machine:

Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: NVIDIA GeForce GTX 1070
cuDNN: 8902

Output:

head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 406.49 MB

I think this is how it should behave, since a higher head count should not require more memory, but I failed to reproduce this result on many other machines. I hope someone knows how to achieve it. Thanks a lot.

pytorch torch transformer-model
1 Answer

This is because dim = 64 and head_num2 = 16, so the head dimension is 64 // 16 = 4, and 4 is not divisible by 8. With a head dimension that is not a multiple of 8, the fused (flash / memory-efficient) attention kernels cannot handle the input, so scaled_dot_product_attention most likely falls back to the math implementation, which materializes the full [batch, heads, length, length] attention matrix and causes the memory blow-up.
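One way to check this yourself is to disable the math fallback and see whether a fused kernel is still available for each head dimension. This is a minimal sketch, assuming the torch.backends.cuda.sdp_kernel context manager shipped with PyTorch 2.1; it mirrors the shapes from the question:

import torch
import torch.nn.functional as F

length, batch = 10000, 1
# dim = 64 split over 8 heads (head_dim = 8) vs. 16 heads (head_dim = 4)
for head_num, head_dim in [(8, 8), (16, 4)]:
    q = torch.rand(batch, head_num, length, head_dim, dtype=torch.float16, device="cuda")
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    mask = torch.ones(1, 1, length, length, dtype=torch.bool, device="cuda")
    try:
        # With the math backend disabled, SDPA must pick a fused kernel or raise.
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print(f"head_dim={head_dim}: a fused kernel was used")
    except RuntimeError as e:
        print(f"head_dim={head_dim}: no fused kernel available, "
              f"the math fallback would be used ({e})")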

To use head_num2 = 16, you also need to set dim = 128, so that 128 // 16 = 8.
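Following that suggestion, the repro can be adjusted as sketched below: keeping 16 heads but widening the model to dim = 128 gives each head head_dim = 8, and the peak-memory measurement can then be re-run the same way as in the question.

import torch
import torch.nn.functional as F

length, batch, dim, head_num = 10000, 1, 128, 16
head_dim = dim // head_num  # 128 // 16 = 8, a multiple of 8

torch.cuda.reset_peak_memory_stats()
q = torch.rand(batch, head_num, length, head_dim, dtype=torch.float16, device="cuda")
k = torch.rand_like(q)
v = torch.rand_like(q)
attn_mask = torch.ones(1, 1, length, length, dtype=torch.bool, device="cuda")

x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(f"dim={dim}, head number {head_num}, peak memory: "
      f"{torch.cuda.max_memory_allocated() / 1e6:.2f} MB")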
