我遇到了与this相同的问题,但是在该问题下面,我没有找到想要的答案。我想在pytorch中对矩阵的每一行中的值进行重复数据删除。给定一个矩阵,例如
torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]])
to
torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])
或
torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 4, 0, 0]])
我知道torch.unique()无法实现这一点,所以我想知道如何在没有循环的情况下实现此功能。
x = torch.tensor([
[1, 2, 3, 4, 3, 3, 4],
[1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)
y, indices = x.sort(dim=-1)
indices = indices.sort(dim=-1)[1]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()
result = torch.gather(y, 1, indices)
print(result)
输出
tensor([[1, 2, 3, 4, 0, 0, 0],
[1, 6, 3, 5, 0, 0, 4]])