I found that when the number of heads is large (>= 16), scaled_dot_product_attention consumes much more memory. Here is the code I used to reproduce the issue.
import torch

length = 10000
dim = 64
head_num1 = 8
head_num2 = 16
batch = 1
shapes = [[batch, head_num1, length, dim // head_num1],
          [batch, head_num2, length, dim // head_num2]]

for shape in shapes:
    torch.cuda.reset_peak_memory_stats()
    mask_shape = [1, 1, length, length]
    q = torch.rand(shape, dtype=torch.float16).cuda()
    k = torch.rand(shape, dtype=torch.float16).cuda()
    v = torch.rand(shape, dtype=torch.float16).cuda()
    attn_mask = torch.ones(mask_shape, dtype=torch.bool).cuda()
    x = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0)
    peak_memory = torch.cuda.max_memory_allocated()
    print(f"head number {shape[1]} case peak memory: {peak_memory / 1e6:.2f} MB")
Environment:
Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: Tesla V100-SXM2-16GB
cuDNN: 8902
Output:
head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 6716.14 MB
Simply doubling the number of heads requires 16x the memory... is this normal?
I tried it on another machine:
Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: NVIDIA GeForce GTX 1070
cuDNN: 8902
Output:
head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 406.49 MB
I think this is the correct behavior, since a larger head count should not need more memory, but I have not been able to reproduce it on many other machines. I hope someone knows how to achieve this; thanks a lot.
This happens because with dim = 64 and head_num2 = 16 the per-head dimension is 64 // 16 = 4, and 4 is not divisible by 8, so the fused (flash / memory-efficient) kernels reject the input and scaled_dot_product_attention silently falls back to the math implementation. The math path materializes the full [batch, heads, length, length] attention matrix: for 16 heads, length 10000 and fp16 that is about 16 × 10000 × 10000 × 2 bytes ≈ 3.2 GB for the attention weights alone, which roughly accounts for the ~6.7 GB peak.
To use head_num2 = 16, dim also needs to be raised to 128 so that the per-head dimension is 128 // 16 = 8.
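A quick way to check whether a fused kernel is actually being used is a minimal sketch like the one below (try_fused_sdpa is a hypothetical helper, and the exact backend-selection rules differ between PyTorch releases): disable the math fallback with the torch.backends.cuda.sdp_kernel context manager available in PyTorch 2.1, and an unsupported configuration will raise a RuntimeError instead of silently falling back.

import torch
import torch.nn.functional as F

def try_fused_sdpa(dim, head_num, length=10000, batch=1):
    # Hypothetical helper: run SDPA with the math fallback disabled, so an
    # unsupported configuration fails loudly instead of allocating the full
    # length x length attention matrix.
    head_dim = dim // head_num
    shape = (batch, head_num, length, head_dim)
    q = torch.rand(shape, dtype=torch.float16, device="cuda")
    k = torch.rand(shape, dtype=torch.float16, device="cuda")
    v = torch.rand(shape, dtype=torch.float16, device="cuda")
    attn_mask = torch.ones(1, 1, length, length, dtype=torch.bool, device="cuda")
    # sdp_kernel is the PyTorch 2.0/2.1 API; newer releases expose
    # torch.nn.attention.sdpa_kernel instead.
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_mem_efficient=True,
                                        enable_math=False):
        try:
            F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
            print(f"dim={dim}, heads={head_num} (head_dim={head_dim}): fused kernel used")
        except RuntimeError as err:
            print(f"dim={dim}, heads={head_num} (head_dim={head_dim}): "
                  f"no fused kernel, would fall back to the math path ({err})")

try_fused_sdpa(dim=64, head_num=16)   # head_dim = 4, expected to fail the check
try_fused_sdpa(dim=128, head_num=16)  # head_dim = 8, expected to use a fused kernel

If the second call passes the check, any dim that keeps the per-head dimension a multiple of 8 should keep the 16-head peak memory in the same range as the 8-head case.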