我有 2 个火炬张量
a : Tensor, dim (K,)
reverse_indices: Tensor, dim (L,)
其中
reverse_indices
中的所有值都是唯一且已排序的(但 a
的所有值不一定都在 reverse_indices
中)
我正在寻找一种有效的方法来获取张量
indices: Longtensor, dim (M,)
这样
reverse_indices[indices] = a[torch.isin(a, reverse_indices)]
(注意
reverse_indices
不是torch.unique
的输出,否则只要将正确的结果传递给torch.unique
就很容易得到结果)。
好吧,现在我就这么做了,但如果有人有办法在不创建
K×L
矩阵的情况下做到这一点,我会接受。
def reverse_index(a, reverse_indices):
"""
assuming all elements in a are also in reverse_indices
Args:
a : Tensor, dim (K,)
reverse_indices: Tensor, dim (L,)
Returns:
indices: so that ``reverse_indices[indices] = a``
"""
a = a[torch.isin(a, reverse_indices)]
correspondance = (a[:, None] == reverse_indices[None, :]).astype(int)
indices = correspondance.argmax(dim=1) #works cause all of them are 0 and one is 1
return indices