pytorch:反向索引

问题描述 投票:0回答:1

我有 2 个火炬张量

a : Tensor, dim (K,)
reverse_indices: Tensor, dim (L,)

其中

reverse_indices
中的所有值都是唯一且已排序的(但
a
的所有值不一定都在
reverse_indices
中)

我正在寻找一种有效的方法来获取张量

indices: Longtensor, dim (M,)

这样

reverse_indices[indices] = a[torch.isin(a, reverse_indices)]

(注意

reverse_indices
不是
torch.unique
的输出,否则只要将正确的结果传递给
torch.unique
就很容易得到结果)。

arrays indexing torch
1个回答
0
投票

好吧,现在我就这么做了,但如果有人有办法在不创建

K×L
矩阵的情况下做到这一点,我会接受。

def reverse_index(a, reverse_indices):
        """
        assuming all elements in a are also in reverse_indices
        Args:
            a : Tensor, dim (K,)
            reverse_indices: Tensor, dim (L,)
        Returns:
            indices: so that ``reverse_indices[indices] = a``
        """
        a = a[torch.isin(a, reverse_indices)]
        correspondance = (a[:, None] == reverse_indices[None, :]).astype(int)
        indices = correspondance.argmax(dim=1) #works cause all of them are 0 and one is 1
        return indices
© www.soinside.com 2019 - 2024. All rights reserved.