Fastest way to simultaneously delete rows and columns from a PyTorch tensor


I'm looking for a fast way to simultaneously delete rows and columns from a PyTorch tensor. Let `t` be a square 2D tensor of shape `[l,l]`. This question is similar to this one, but that question concerns deletion along a single dimension only. The operation I want can be achieved as follows:

 # index to delete
 idx = 499

 keep = [_ for _ in range(l)]
 keep.remove(idx)
 t = t[keep,:][:,keep]

 # t now has dimension [l-1,l-1]

But this is very slow (list indexing is not a tensor view operation, and the indexing is done in two passes). Can anyone suggest a faster approach?

Below are some of the other approaches I've tried.

# Double index with a list
    import torch
    import time

    l = 8000
    iterations = 500
    idx = 599

    t = torch.rand([l,l])

    # test 1 - double list index
    keep = [_ for _ in range(l)]
    keep.remove(idx)

    total = 0
    for i in range(iterations):
    
        start = time.time()
        t2 = t[keep,:][:,keep]
        torch.cuda.synchronize()
        elapsed = time.time() - start
        total += elapsed
    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 35.0s for 500 iterations, 0.07009070539474488s/it

# Double index with a tensor
(marginally faster in some cases than the above, but both involve gathers from non-contiguous memory)

    # test 2 - double tensor index
    keep = torch.tensor(keep)

    total = 0
    for i in range(iterations):
    
        start = time.time()
        t2 = t[keep,:][:,keep]
        torch.cuda.synchronize()
        elapsed = time.time() - start
        total += elapsed
    
    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 34.6s for 500 iterations, 0.06911029624938965s/it
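For reference, the same double selection can also be written with `torch.index_select`; I'd expect it to behave much like the fancy indexing above, since it too materializes a copy rather than a view. A sketch only, not separately benchmarked:

    # sketch (not benchmarked): index_select along each dim in turn
    keep = torch.tensor([i for i in range(l) if i != idx])
    t2 = torch.index_select(torch.index_select(t, 0, keep), 1, keep)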

# Concatenation
    # test 3 - concat_based approach

    total = 0
    for i in range(iterations):
        start = time.time()
    
        # stack the left and right column blocks, then join them side by side
        t2 = torch.cat([torch.cat([t[:idx,:idx],t[idx+1:,:idx]],dim = 0),
                        torch.cat([t[:idx,idx+1:],t[idx+1:,idx+1:]],dim = 0)],dim = 1)
    
        torch.cuda.synchronize()
        elapsed = time.time() - start
        total += elapsed
    
    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 31.1s for 500 iterations, 0.06218040370941162s/it

# Shift and delete the last row/column 
(which is a traditional indexing op that's fast but requires cloning `t`). This and the above approach are the fastest, but neither scales nicely to deleting more than one row/column: you either have to delete one index at a time (see the sketch after the timings below) or nest a conditional block for each possible number of deletions.

    # test 4 - shift and remove end approach
    total = 0

    for i in range(iterations):
        start = time.time()
        t2 = torch.clone(t)

        # shift rows up, then shift columns left; the bottom-right block must
        # be copied from t[idx+1:,idx+1:] so that the two shifts compose
        t2[idx:-1,:] = t[idx+1:,:]
        t2[:idx,idx:-1] = t[:idx,idx+1:]
        t2[idx:-1,idx:-1] = t[idx+1:,idx+1:]
        t2 = t2[:-1,:-1]
    
        torch.cuda.synchronize()
        elapsed = time.time() - start
        total += elapsed
    
    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 26.5s for 500 iterations, 0.052913659572601315s/it

Alternatively, excluding the time taken to clone the tensor from the total:

Took 11.7s for 500 iterations, 0.023439947128295897s/it
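Generalizing this approach to several indices by deleting them one at a time would look roughly like the following (hypothetical helpers of mine, not benchmarked; each earlier deletion shifts every later index down by one):

    # hypothetical helpers (sketch, not benchmarked)
    def delete_idx(t, idx):
        # single-index shift-and-truncate, as in test 4 above
        t2 = torch.clone(t)
        t2[idx:-1,:] = t[idx+1:,:]
        t2[:idx,idx:-1] = t[:idx,idx+1:]
        t2[idx:-1,idx:-1] = t[idx+1:,idx+1:]
        return t2[:-1,:-1]

    def delete_many(t, idxs):
        # each deletion removes one row/column before every later index,
        # so shift each subsequent index down by the number already deleted
        for k, idx in enumerate(sorted(idxs)):
            t = delete_idx(t, idx - k)
        return t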

# Flatten into a 1D array for single indexing
    total = 0
    for i in range(iterations):

        start = time.time()
        # build a flat index of the entries to keep, then gather in one pass
        o = torch.ones(t.shape,dtype = int)
        o[:,idx] = 0
        o[idx,:] = 0
        o = o.view(-1).nonzero().squeeze(1)
        t2 = t.view(-1)[o].view(l-1,l-1)
        elapsed = time.time() - start
        total += elapsed

    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 87.1s for 500 iterations, 0.11186568689346313s/it

# 2D boolean mask indexing
(under the hood this seems to be roughly the same as the above; it also requires a reshaping `view` at the end, and the time taken is about the same)

    total = 0 
    for i in range(iterations):
        o = torch.ones(t.shape,dtype = bool)

        start = time.time()
        o[:,idx] = 0
        o[idx,:] = 0
        t2 = t[o].view(l-1,l-1)
        elapsed = time.time() - start
        total += elapsed

    print("Took {:.1f}s for {} iterations, {}s/it".format(total,iterations,total/iterations))

Took 86.1s for 500 iterations, 0.17228954696655274s/it
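As a sanity check (my addition, since the approaches differ substantially), the results can be compared against the double-index reference:

    # sanity check: boolean-mask result should match the double-index reference
    keep = [i for i in range(l) if i != idx]
    ref = t[keep,:][:,keep]
    o = torch.ones(t.shape,dtype = bool)
    o[:,idx] = False
    o[idx,:] = False
    assert torch.equal(ref, t[o].view(l-1,l-1))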

pytorch

1 Answer

Could you try the following? It's similar to option 4, but it should be slightly faster since there are fewer copies and the exact amount of memory is allocated up front. The ordering (left->right, top->bottom) should also be the most cache-friendly.

I don't think you can go faster than this with PyTorch's high-level API. Further optimization would probably require modifying the copy kernel itself to improve the order of the read/write operations.

H,W = t.shape
# Empty initialization, just a memory alloc (match t's dtype/device)
t2 = torch.empty(H-1, W-1, dtype=t.dtype, device=t.device)
# top-left block
t2[:idx, :idx] = t[:idx, :idx]
# top-right block
t2[:idx, idx:] = t[:idx, idx+1:]
# bottom-left block
t2[idx:, :idx] = t[idx+1:, :idx]
# bottom-right block
t2[idx:, idx:] = t[idx+1:, idx+1:]

I agree that it doesn't scale well if you have many indices to delete. But in that case I think you can keep this block-decomposition structure and simply wrap two for loops around this code to iterate over the indices and copy the blocks in the right order, as in the sketch below.
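For example, something along these lines (my sketch of the suggested generalization, not a tested implementation; it assumes a square `t` and the same deleted indices for rows and columns):

def delete_rows_cols(t, idxs):
    # copy the kept blocks lying between consecutive deleted indices,
    # in left->right, top->bottom order
    H, W = t.shape
    idxs = sorted(idxs)
    # kept index segments, e.g. idxs=[2,5] on length 8 gives (0,2),(3,5),(6,8)
    bounds = [-1] + idxs + [H]
    segs = [(a + 1, b) for a, b in zip(bounds[:-1], bounds[1:]) if a + 1 < b]
    n = len(idxs)
    t2 = torch.empty(H - n, W - n, dtype=t.dtype, device=t.device)
    r_out = 0
    for r0, r1 in segs:            # loop over row blocks
        c_out = 0
        for c0, c1 in segs:        # loop over column blocks
            t2[r_out:r_out + (r1 - r0), c_out:c_out + (c1 - c0)] = t[r0:r1, c0:c1]
            c_out += c1 - c0
        r_out += r1 - r0
    return t2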
