PyTorch: how to (efficiently) apply a function that has no "dim" argument to every row of a 2D tensor?


Long story short: I have a 2D matrix of 1s and 0s and, for each row, I need to retrieve the indices of the elements set to 1. The "standard" way to do this is torch.nonzero, but that function is notoriously 1) a real bottleneck, since it does not know the size of the final vector in advance, and 2) impossible to apply to every row of a 2D tensor in one shot, because different rows may contain a different number of nonzero elements.

Recently at::nonzero_static was introduced, which solves the first point by letting you pass the expected maximum number of nonzero elements to the function (which is fine for my application). However, it has no "dim" argument, which means it cannot be applied to each row/column separately. In my opinion this makes little sense, because fixing the output size guarantees that every row yields the same number of items, so the output could be a regular tensor.
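For context, here is a minimal example of nonzero_static on a 1D tensor (assuming PyTorch ≥ 2.1; as the answer below notes, it currently runs only on the CPU backend):

import torch

x = torch.tensor([0, 1, 1, 0, 1])
# `size` fixes the output length; unused slots are filled with fill_value (-1 by default)
torch.nonzero_static(x, size=4)   # -> [[1], [2], [4], [-1]]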

Using a for loop would obviously solve my problem, but it means calling the function many times, which is not GPU-efficient. Does anyone know how to apply nonzero_static efficiently to each row and get back a tensor in which each row is the result of applying it to the corresponding slice of the input? From what I understand, vmap might be a solution, but I am not sure whether it is optimized for the GPU.
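A minimal sketch of the for-loop baseline mentioned above (the toy shape and the per-row upper bound k are illustrative, not from the question):

import torch

data = (torch.rand(4, 6) > 0.5).long()   # toy 0/1 matrix
k = int(data.sum(dim=1).max())           # upper bound on the number of 1s per row

# One nonzero_static call per row, each padded to k entries with -1,
# then stacked back into a [4, k] tensor. One call per row is exactly
# what makes this loop GPU-inefficient.
rows = [torch.nonzero_static(row, size=k).squeeze(-1) for row in data]
idx = torch.stack(rows)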

python pytorch tensor
1 Answer

I implemented a few solutions. Some preliminaries:

  • Unfortunately, nonzero_static() is not compatible with the CUDA backend, which may limit your use case.
  • vmap is unlikely to help much, since it "doesn't provide general autobatching or handle variable-length sequences out of the box" and creates a batched_tensor output. Running vmap over nonzero_static emits the following warning (see the sketch after this list):
    UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::nonzero_static. 
  • In general, leaving the result in a list-like form (i.e. two 1D tensors holding the row indices and the column indices, respectively) is faster than scattering those indices into a tensor with the data's original shape and then sorting it so the useful indices come first, which adds some extra time on top.
  • The takeaway from my very rough experiments is that, for most reasonable tensor sizes, vanilla nonzero() is the fastest, or nearly as fast as the index-broadcasting solution; not knowing the allocation size up front is usually not a big bottleneck compared to the relatively clunky alternatives. It would be interesting to re-evaluate this if nonzero_static is ever optimized for batched computation with vmap, or if a CUDA backend is implemented for nonzero_static; hopefully that happens eventually, since it is a relatively new feature in PyTorch.
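For reference, a minimal sketch of what the vmap route looks like (assuming PyTorch ≥ 2.1, where nonzero_static exists on CPU; the toy tensor and the padding mask are illustrative):

import torch

x = torch.tensor([[1, 0, 1],
                  [0, 0, 1]])

# vmap maps nonzero_static over dim 0 (the rows); keyword arguments such as
# `size` are passed through unbatched. Each row is padded to `size` entries
# with fill_value (-1 by default), and the UserWarning above is emitted.
per_row = torch.func.vmap(torch.nonzero_static)
idx = per_row(x, size=x.shape[1]).squeeze(-1)  # tensor([[0, 2, -1], [2, -1, -1]])
valid = idx >= 0                               # mask: real indices vs. padding

The benchmark below compares the different approaches (nonzero, nonzero_static in list-like and matrix form, vmap, and index broadcasting) on CPU and GPU: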
import torch
import time
m = 2000
n = 1000
trials = 100

results = {}
for t in range(trials):
    
    device = torch.device("cpu")
    data = torch.rand([m,n],device = device).round().long()
    
    # use nonzero 
    name = "nonzero"
    t1 = time.time()
    idx = data.nonzero()
    midx = idx[:,0]
    nidx = idx[:,1]
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx 
    output = output.sort(dim = 1,descending = True).values  # keep sorted values only
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    
    # use nonzero_static and leave in "listy" form
    name = "nonzero_static"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    # use nonzero_static and put in matrix form, leave unsorted
    name = "nonzero_static -> matrix"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx 
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    
    
    # use nonzero_static and put in matrix form, then sort
    name = "nonzero_static -> sorted matrix"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx 
    output = output.sort(dim = 1,descending = True).values  # keep sorted values only
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    
    # vmap nonzero_static
    name = "vmap nonzero_static"
    t1 = time.time()
    test = torch.func.vmap(torch.nonzero_static)
    output = test(data,size = n).squeeze(-1)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    # use index broadcasting then sort
    name = "index broadcasting"
    t1 = time.time()
    index_array = torch.arange(n).unsqueeze(0).expand(m,n)
    output = data*index_array
    output = output.sort(dim = 1,descending = True).values  # keep sorted values only
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    
    
    device = torch.device("cuda:0")
    data = data.to(device)
    torch.cuda.synchronize()
    
    # use index broadcasting then sort on GPU
    name = "GPU index broadcasting"
    t1 = time.time()
    index_array = torch.arange(n,device = device).unsqueeze(0).expand(m,n)
    output = data*index_array
    output = output.sort(dim = 1,descending = True).values  # keep sorted values only
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    
    del output
    torch.cuda.empty_cache()
    
    #use nonzero and leave in listy form
    name = "GPU nonzero"
    t1 = time.time()
    idx = data.nonzero()
    midx = idx[:,0]
    nidx = idx[:,1]
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx 
    output = output.sort(dim = 1,descending = True).values  # keep sorted values only
    
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
        
print("Results for [{},{}] over {} trials".format(m,n,trials))
for key in results:
    print("{:.5f}s for {}".format(results[key]/trials,key))
   

Results for [200,100] over 100 trials
0.00051s for nonzero
0.00035s for nonzero_static
0.00037s for nonzero_static -> matrix
0.00062s for nonzero_static -> sorted matrix
0.00191s for vmap nonzero_static
0.00033s for index broadcasting
0.00015s for GPU index broadcasting
0.00019s for GPU nonzero
Results for [2000,1000] over 100 trials
0.00575s for nonzero
0.01028s for nonzero_static
0.01036s for nonzero_static -> matrix
0.01302s for nonzero_static -> sorted matrix
0.03645s for vmap nonzero_static
0.00466s for index broadcasting
0.00129s for GPU index broadcasting
0.00198s for GPU nonzero
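
Based on these timings, here is a minimal standalone sketch of the fastest option above, index broadcasting on the GPU (assuming a CUDA device is available and that the padded, descending-sorted matrix form is acceptable):

import torch

def rowwise_one_indices(data: torch.Tensor) -> torch.Tensor:
    """For an [m, n] tensor of 0s and 1s, return an [m, n] tensor whose rows hold
    the column indices of the 1s in descending order, padded with 0s.
    Note: a 1 in column 0 yields index 0, which looks the same as padding."""
    m, n = data.shape
    cols = torch.arange(n, device=data.device).unsqueeze(0).expand(m, n)
    return (data * cols).sort(dim=1, descending=True).values

data = torch.rand(2000, 1000, device="cuda").round().long()
idx = rowwise_one_indices(data)   # shape [2000, 1000]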
