如何在火炬中用2-d张量索引3-d张量?

问题描述 投票:0回答:1
import torch
a = torch.rand(5,256,120)
min_values, indices = torch.min(a,dim=0)
aa = torch.zeros(256,120)
for i in range(256):
    for j in range(120):
        aa[i,j] = a[indices[i,j],i,j]

print((aa==min_values).sum()==256*120)

我想知道如何避免使用for-for循环获取aa值? (我想使用索引来选择另一个3维张量中的元素,因此我不能直接使用由min返回的值)

import torch a = torch.rand(5,256,120)min_values,index = torch.min(a,dim = 0)aa = torch.zeros(256,120)for i in range(256):for j in range(120): aa [i,j] = a [indices [i,j],i,j] print(((aa = ...

pytorch torch
1个回答
1
投票

您可以使用torch.gather

© www.soinside.com 2019 - 2024. All rights reserved.