仅保留张量的行最大值并将所有其他条目设置为零

问题描述 投票:0回答:2

给定一个这样的张量

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0534, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0944, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.2780, 1.5430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.1799, 1.2002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]

我想通过仅保留每行的最大元素并将所有其他元素设置为 0 来对其进行转换。我试图使用

torch.argmax(tensor, dim=1)
但不确定这是否会有帮助。所以在这种情况下所需的输出是

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0534, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0944, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.5430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.2002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]
pytorch
2个回答
2
投票

似乎有效:

a = torch.randn((8, 7))
max_a, ids = torch.max(a, 1, keepdim=True)
b = torch.zeros_like(a)     # result tensor
b.scatter_(1, ids, max_a)   # set max values on idx indices

0
投票

解决方案:

a * (a == a.amax(dim=1, keepdim=True))

说明:

  1. a.amax(dim=1, keepdim=True)
    将为您提供每行的最大值,其形状为
    (N, 1)
  2. 接下来,我们检查
    a
    中的每个值是否与最大值匹配,如
    (N, N) == (N, 1)
    。这将为您提供一个二进制掩码,其中
    True
    位于
    a
    的最大值位置。如果您的蒙版还包含
    True
    表示零
    (0==0)
    ,请不要担心,因为它们无论如何都会被忽略。
  3. *
    将乘以
    a
    和来自
    2.
    的掩码,并仅保留
    dim=*
    上的最大元素。

评论:

amax
max
类似,但支持广播。如果您想在一批图像强度中找到最大值,这很有用
(B, H, W)

a * (a == a.amax(dim=(1, 2), keepdim=True))

© www.soinside.com 2019 - 2024. All rights reserved.