Pytorch:测试第一个二维张量的每一行是否也存在于第二个张量中?

问题描述 投票:0回答:3

给定两个张量

t1
t2

t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])

如果

t1
的行元素存在于
t2
中,则返回
True
,否则返回
False
。理想的结果是
[Ture, False, True]
。 我尝试了
torch.isin(t1, t2)
,但它按元素而不是按行返回结果。顺便说一句,如果它们是numpy数组,可以通过

完成
np.in1d(t1.view('i,i').reshape(-1), t2.view('i,i').reshape(-1))

我想知道如何在张量中得到类似的结果?

arrays numpy pytorch tensor isin
3个回答
1
投票
def rowwise_in(a,b):
  """ 
  a - tensor of size a0,c
  b - tensor of size b0,c
  returns - tensor of size a1 with 1 for each row of a in b, 0 otherwise
  """
  
  # dimensions
  a0 = a.shape[0]
  b0 = b.shape[0]
  c  = a.shape[1]
  assert c == b.shape[1] , "Tensors must have same number of columns"

  a_expand = a.unsqueeze(1).expand(a0,b0,c)
  b_expand = b.unsqueeze(0).expand(a0,b0,c)

  # element-wise equality
  equal = a_expand == b_expand

  # sum along dim 2 (all elements along this dimension must be true for the summed dimension to be True)
  row_equal = torch.prod(equal,dim = 2)

  row_in_b = torch.max(row_equal, dim = 1)[0]
  return row_in_b

0
投票

除了DerekG的伟大解决方案之外,这个小改变似乎更加快速和稳健

a,b = torch.tensor([[1,2,3],[3,4,5],[5,6,7]],device=torch.device(0)), torch.tensor([[1,2,3],[5,6,7]],device=torch.device(0))
# dimensions
shape1 = a.shape[0]
shape2 = b.shape[0]
c  = a.shape[1]
assert c == b.shape[1] , "Tensors must have same number of columns"

a_expand = a.unsqueeze(1).expand(-1,shape2,c)
b_expand = b.unsqueeze(0).expand(shape1,-1,c)
# element-wise equality
mask = (a_expand == b_expand).all(-1).any(-1)

我尝试过对于 10 000 行的张量,它的工作速度非常快并且不会浪费内存


0
投票

这是使用广播的较短版本:

t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])

torch.any(torch.all(t1[:, None] == t2, axis=2), axis=1)

Out[1]: tensor([ True, False,  True])

说明:

  • t1[:, None]
    (形状:
    3,1,2
    ):添加新维度。
  • t1[:, None] == t2
    (形状:
    len(t1), len(t2), 2
    ):沿着扩展的
    t2
    广播
    t1
    以创建与所有元素比较相对应的 3-D 布尔数组。
  • torch.all(t1[:, None] == t2, axis=2)
    (形状:
    len(t1), len(t2)
    ):检查行中的所有元素是否匹配。
  • torch.any(torch.all(t1[:, None] == t2, axis=2), axis=1)
    (形状:
    len(t1)
    ):检查
    t1
    中的任何行是否与
    t2
  • 中的匹配
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.