Given two tensors t1 and t2:
t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])
If a row of t1 is present in t2, return True for that row, otherwise return False. The desired result is [True, False, True].
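I tried torch.isin(t1, t2), but it returns results element-wise rather than row-wise. Incidentally, if these were NumPy arrays, this could be done with
np.in1d(t1.view('i,i').reshape(-1), t2.view('i,i').reshape(-1))
How can I get a similar result with tensors?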
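To make the problem concrete, here is a minimal sketch of what torch.isin does with the tensors above. The extra `probe` row is my own illustration, not part of the question; it shows why reducing the element-wise result with all(dim=1) is not a reliable row test.

```python
import torch

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])

# torch.isin tests every scalar of t1 against the set of values in t2,
# so the result has the same shape as t1, not one entry per row.
elementwise = torch.isin(t1, t2)
print(elementwise)

# Hypothetical probe row: [1, 6] is not a row of t2, yet both of its
# values occur somewhere in t2, so an all(dim=1) reduction reports True.
probe = torch.tensor([[1, 6]])
false_positive = torch.isin(probe, t2).all(dim=1)
print(false_positive)
```

For t1 and t2 the reduction happens to give the right answer, which is why the mismatch is easy to miss.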
import torch

def rowwise_in(a, b):
    """
    a - tensor of size a0,c
    b - tensor of size b0,c
    returns - tensor of size a0 with 1 for each row of a in b, 0 otherwise
    """
    # dimensions
    a0 = a.shape[0]
    b0 = b.shape[0]
    c = a.shape[1]
    assert c == b.shape[1], "Tensors must have same number of columns"

    # pair every row of a with every row of b (expand creates views, not copies)
    a_expand = a.unsqueeze(1).expand(a0, b0, c)
    b_expand = b.unsqueeze(0).expand(a0, b0, c)

    # element-wise equality
    equal = a_expand == b_expand

    # product along dim 2 (all elements along this dimension must be True for the product to be 1)
    row_equal = torch.prod(equal, dim=2)

    # max along dim 1 (1 if the row of a equals at least one row of b)
    row_in_b = torch.max(row_equal, dim=1)[0]
    return row_in_b
In addition to DerekG's great solution, this small change seems to be faster and more robust:
a,b = torch.tensor([[1,2,3],[3,4,5],[5,6,7]],device=torch.device(0)), torch.tensor([[1,2,3],[5,6,7]],device=torch.device(0))
# dimensions
shape1 = a.shape[0]
shape2 = b.shape[0]
c = a.shape[1]
assert c == b.shape[1] , "Tensors must have same number of columns"
a_expand = a.unsqueeze(1).expand(-1,shape2,c)
b_expand = b.unsqueeze(0).expand(shape1,-1,c)
# element-wise equality
mask = (a_expand == b_expand).all(-1).any(-1)
I tried this on tensors with 10,000 rows; it runs very fast and does not waste memory.
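A rough way to see where memory does and does not go, as a sketch on CPU with small made-up shapes (the tensors `a` and `b` below are illustrative, not the ones from the timing above): expand returns a view over the original storage, while the comparison itself still materializes one boolean per (row of a, row of b, column) triple.

```python
import torch

a = torch.arange(12).reshape(4, 3)  # 4 rows, 3 columns
b = torch.arange(6).reshape(2, 3)   # 2 rows, 3 columns

a_expand = a.unsqueeze(1).expand(-1, b.shape[0], -1)  # shape (4, 2, 3)

# expand() does not copy: the view shares a's storage and uses
# stride 0 along the repeated dimension.
shares_storage = a_expand.data_ptr() == a.data_ptr()
repeat_stride = a_expand.stride()[1]

# The comparison result is a real tensor with a0 * b0 * c entries.
equal = a_expand == b.unsqueeze(0)
n_bools = equal.numel()

mask = equal.all(-1).any(-1)
print(shares_storage, repeat_stride, n_bools, mask)
```

So the intermediate boolean tensor grows with the product of the two row counts, which is worth keeping in mind for much larger inputs.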
Here is a shorter version using broadcasting:
t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])
torch.any(torch.all(t1[:, None] == t2, axis=2), axis=1)
Out[1]: tensor([ True, False, True])
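Explanation:
t1[:, None] (shape: 3, 1, 2): adds a new dimension.
t1[:, None] == t2 (shape: len(t1), len(t2), 2): broadcasts t1 against t2 along the new dimension, producing a 3-D boolean array that holds every element-wise comparison.
torch.all(t1[:, None] == t2, axis=2) (shape: len(t1), len(t2)): checks whether all elements in a pair of rows match.
torch.any(torch.all(t1[:, None] == t2, axis=2), axis=1) (shape: len(t1)): checks, for each row of t1, whether it matches any row of t2.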
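The same steps can be traced one at a time. This is only a sketch: the names step1 to step4 are mine, and it uses the dim keyword instead of axis.

```python
import torch

t1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
t2 = torch.tensor([[1, 2], [5, 6]])

step1 = t1[:, None]              # new middle dimension: (3, 1, 2)
step2 = step1 == t2              # broadcast comparison: (3, 2, 2)
step3 = torch.all(step2, dim=2)  # row pair fully equal?: (3, 2)
step4 = torch.any(step3, dim=1)  # row of t1 found in t2?: (3,)

for name, t in [("step1", step1), ("step2", step2),
                ("step3", step3), ("step4", step4)]:
    print(name, tuple(t.shape))
print(step4)
```

Here step3[i, j] is True exactly when row i of t1 equals row j of t2, so the final any over dim 1 answers the original question for each row of t1.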