torch 查找 2 个 2D 张量中匹配行的索引

问题描述 投票:0回答:3

我有两个不同长度的二维张量,两者都是同一原始二维张量的不同子集,我想找到所有匹配的“行”
例如

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

我只看到 numpy 解决方案,使用 dtype 作为字典,并且不适用于 pytorch。


这是我在 numpy 中的做法

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)
pytorch tensor
3个回答
5
投票

这个答案是在OP用其他限制更新问题之前发布的,这些限制极大地改变了问题。

TL;DR 你可以这样做:

torch.where((A == B).all(dim=1))[0]

首先,假设您有:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

我们可以检查

A == B
返回:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

所以,我们想要的是:它们所在的行都是

True
。为此,我们可以使用
.all()
操作并指定感兴趣的维度,在我们的例子中
1
:

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

您真正想知道的是

True
在哪里。为此,我们可以获得
torch.where()
函数的第一个输出:

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])

2
投票

如果 A 和 B 是 2D 张量,以下代码将查找满足

A[indices] == B
的索引。如果多个索引满足此条件,则返回找到的第一个索引。如果 A 中并非存在 B 的所有元素,则忽略相应的索引。

values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])

0
投票

如果两个张量的行数不同,那么我们无法直接比较张量。我们必须首先在其中一个张量上添加一个虚拟维度。一步一步:

  1. 创建张量
A = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[3,3,3]])
B = torch.tensor([[1,2,3],[7,8,9],[4,4,4]])
  1. 添加虚拟维度并获得两两比较:
B == A.unsqueeze(1)

输出将是一个 4x3x3 张量,其中 4 个

i
子张量中的每一个都是
A[i] == B

  1. 获取表明哪些索引具有完美“行”匹配的张量:
(B == A.unsqueeze(1)).all(-1)

输出是一个 4x3 张量。具有

True
元素的行包含完美的行匹配。

  1. 获取完美匹配的行:
(B == A.unsqueeze(1)).all(-1).any(-1)
  1. 最后,获取
    A
    中与
    B
    中匹配的行的索引:
torch.where((B == A.unsqueeze(1)).all(-1).any(-1))[0]
>> tensor([0, 2])

要获取 B 中与 A 中匹配的行的索引,只需交换张量即可:

torch.where((A == B.unsqueeze(1)).all(-1).any(-1))[0]
>> tensor([0, 1])

这个问题有一个类似的 numpy 版本here,我的答案深受Ehsan答案的启发。

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.