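Long story short, I have a 2D matrix of ones and zeros, and for each row I need to retrieve the indices of the elements set to 1. The "standard" way to do this is torch.nonzero, but that function is known to be 1) a real bottleneck, because it does not know the size of the output in advance, and 2) not applicable to every row of a 2D tensor in one shot, because different rows may contain different numbers of nonzero elements.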
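at::nonzero_static was introduced recently. It addresses the first point by letting the caller pass the expected maximum number of nonzero elements (which is fine for my application). However, it has no "dim" argument, so it cannot be applied to each row or column separately. In my opinion that makes little sense, because fixing the output size guarantees that every row yields the same number of items, so the output could be a tensor.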
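A for loop obviously solves my problem, but it means calling the function many times, which is not GPU-efficient. Does anyone know how to apply nonzero_static efficiently to every row and return a tensor in which each row is the result of applying it to the corresponding slice of the input? As I understand it, vmap might be a solution, but I am not sure whether it is optimized for GPU.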
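To make the problem concrete, here is a minimal sketch on a hypothetical 3x4 matrix (the values are chosen only for illustration). It shows that torch.nonzero returns a data-dependent number of coordinate pairs, and that grouping the column indices by row gives lists of unequal length.

```python
import torch

# Small binary matrix: the rows contain 2, 0, and 3 ones respectively.
x = torch.tensor([[1, 0, 1, 0],
                  [0, 0, 0, 0],
                  [1, 1, 0, 1]])

# torch.nonzero returns one (row, col) pair per nonzero element,
# so the length of the output depends on the data.
coords = torch.nonzero(x)
print(coords.shape)  # torch.Size([5, 2])

# Grouping the column indices by row gives lists of different lengths,
# which cannot be stacked into one tensor without padding.
per_row = [torch.nonzero(row).squeeze(-1).tolist() for row in x]
print(per_row)  # [[0, 2], [], [0, 1, 3]]
```

Any single-tensor result therefore needs a fixed width per row plus a padding value, which is what a per-row nonzero_static would provide.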
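I implemented a few solutions. Some preliminaries:
nonzero_static() is not compatible with the CUDA backend at the time of writing, which may limit your use case.
vmap is unlikely to work, since it "does not provide general auto-batching or handle variable-length sequences out of the box", and it creates a batched_tensor output. Running vmap on nonzero_static produces the warning UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::nonzero_static.
nonzero() is the fastest, or nearly as fast as the index broadcasting solution. Compared with relatively clunky workarounds, not knowing in advance how much memory to allocate is usually not a big bottleneck.
It would be interesting to re-evaluate this once nonzero_static is optimized for batched computation with vmap, or once a CUDA backend is implemented for nonzero_static. Hopefully that will happen eventually, since it is a relatively new feature in pytorch.
import torch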
import time

m = 2000
n = 1000
trials = 100
results = {}  # accumulated wall-clock time per method

for t in range(trials):
    device = torch.device("cpu")
    data = torch.rand([m,n],device = device).round().long()

    # use nonzero, put in matrix form, then sort
    name = "nonzero"
    t1 = time.time()
    idx = data.nonzero()
    midx = idx[:,0]
    nidx = idx[:,1]
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx
    output = output.sort(dim = 1,descending = True)
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # use nonzero_static and leave in "listy" form
    name = "nonzero_static"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # use nonzero_static and put in matrix form, leave unsorted
    name = "nonzero_static -> matrix"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # use nonzero_static and put in matrix form, then sort
    name = "nonzero_static -> sorted matrix"
    t1 = time.time()
    count_nonzero = int(data.sum().item())
    d = data.view(-1)
    idx = d.nonzero_static(size = count_nonzero)
    midx,nidx = idx//n, idx%n
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx
    output = output.sort(dim = 1,descending = True)
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # vmap nonzero_static
    name = "vmap nonzero_static"
    t1 = time.time()
    test = torch.func.vmap(torch.nonzero_static)
    output = test(data,size = n).squeeze(-1)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # use index broadcasting then sort
    name = "index broadcasting"
    t1 = time.time()
    index_array = torch.arange(n).unsqueeze(0).expand(m,n)
    output = data*index_array
    output = output.sort(dim = 1,descending = True)
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

    # move the same data to the GPU for the remaining tests
    device = torch.device("cuda:0")
    data = data.to(device)
    torch.cuda.synchronize()

    # use index broadcasting then sort on GPU
    name = "GPU index broadcasting"
    t1 = time.time()
    index_array = torch.arange(n,device = device).unsqueeze(0).expand(m,n)
    output = data*index_array
    output = output.sort(dim = 1,descending = True)
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1
    del output
    torch.cuda.empty_cache()

    # use nonzero, put in matrix form, then sort on GPU
    name = "GPU nonzero"
    t1 = time.time()
    idx = data.nonzero()
    midx = idx[:,0]
    nidx = idx[:,1]
    output = torch.zeros([m,n],device = device,dtype = torch.long)
    output[midx,nidx] = nidx
    output = output.sort(dim = 1,descending = True)
    torch.cuda.synchronize()
    try:
        results[name] += time.time()- t1
    except KeyError:
        results[name] = time.time() - t1

# report the mean time per trial for each method
print("Results for [{},{}] over {} trials".format(m,n,trials))
for key in results:
    print("{:.5f}s for {}".format(results[key]/trials,key))
Results for [200,100] over 100 trials
0.00051s for nonzero
0.00035s for nonzero_static
0.00037s for nonzero_static -> matrix
0.00062s for nonzero_static -> sorted matrix
0.00191s for vmap nonzero_static
0.00033s for index broadcasting
0.00015s for GPU index broadcasting
0.00019s for GPU nonzero
Results for [2000,1000] over 100 trials
0.00575s for nonzero
0.01028s for nonzero_static
0.01036s for nonzero_static -> matrix
0.01302s for nonzero_static -> sorted matrix
0.03645s for vmap nonzero_static
0.00466s for index broadcasting
0.00129s for GPU index broadcasting
0.00198s for GPU nonzero
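One detail of the index broadcasting approach is worth noting: multiplying the 0/1 matrix by the column indices writes 0 both for "no element" and for a one in column 0, so those two cases cannot be told apart in the output. Below is a minimal sketch of a variant that avoids this. The function name rowwise_nonzero_padded and the -1 padding value are my own choices for illustration, not part of the benchmark above; the idea is to give zeros a sentinel key that sorts last and then replace it with the padding value.

```python
import torch

def rowwise_nonzero_padded(data: torch.Tensor, fill_value: int = -1) -> torch.Tensor:
    """For each row of a 0/1 matrix, return the column indices of the ones
    in ascending order, right-padded with fill_value (hypothetical helper)."""
    m, n = data.shape
    cols = torch.arange(n, device=data.device).expand(m, n)
    # Positions holding a zero get the sentinel n, which sorts after every real index.
    keyed = torch.where(data.bool(), cols, torch.full_like(cols, n))
    sorted_cols, _ = keyed.sort(dim=1)
    # Replace the sentinel with the padding value.
    return sorted_cols.masked_fill(sorted_cols == n, fill_value)

x = torch.tensor([[1, 0, 1, 0],
                  [0, 0, 0, 0],
                  [1, 1, 0, 1]])
out = rowwise_nonzero_padded(x)
print(out.tolist())  # [[0, 2, -1, -1], [-1, -1, -1, -1], [0, 1, 3, -1]]
```

The output always has width n. If a known maximum count per row is available, as in the question, the result can be sliced to that width without any data-dependent synchronization. I have not benchmarked this variant, so its cost relative to the timings above is not established.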