如何在pytorch中对矩阵的每一行中的值进行重复数据删除?

问题描述 投票:2回答:1

我遇到了与this相同的问题,但是在该问题下面,我没有找到想要的答案。我想在pytorch中对矩阵的每一行中的值进行重复数据删除。给定一个矩阵,例如

torch.Tensor(([1, 2, 3, 4, 3, 3, 4],
          [1, 6, 3, 5, 3, 5, 4]])

to

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
          [1, 6, 3, 5, 0, 0, 4]])

torch.Tensor(([1, 2, 3, 4, 0, 0, 0],
          [1, 6, 3, 5, 4, 0, 0]])

我知道torch.unique()无法实现这一点,所以我想知道如何在没有循环的情况下实现此功能。

python numpy pytorch torch
1个回答
0
投票
x = torch.tensor([
    [1, 2, 3, 4, 3, 3, 4],
    [1, 6, 3, 5, 3, 5, 4]
], dtype=torch.long)

y, indices = x.sort(dim=-1)
indices = indices.sort(dim=-1)[1]
y[:, 1:] *= ((y[:, 1:] - y[:, :-1]) !=0).long()
result = torch.gather(y, 1, indices)
print(result)

输出

tensor([[1, 2, 3, 4, 0, 0, 0],
        [1, 6, 3, 5, 0, 0, 4]])
© www.soinside.com 2019 - 2024. All rights reserved.