给定一个这样的张量
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.0534, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.0944, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[1.2780, 1.5430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[1.1799, 1.2002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]
我想通过仅保留每行的最大元素并将所有其他元素设置为 0 来对其进行转换。我试图使用
torch.argmax(tensor, dim=1)
但不确定这是否会有帮助。所以在这种情况下所需的输出是
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.0534, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.0944, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.5430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.2002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]
似乎有效:
a = torch.randn((8, 7))
max_a, ids = torch.max(a, 1, keepdim=True)
b = torch.zeros_like(a) # result tensor
b.scatter_(1, ids, max_a) # set max values on idx indices
解决方案:
a * (a == a.amax(dim=1, keepdim=True))
说明:
a.amax(dim=1, keepdim=True)
将为您提供每行的最大值,其形状为 (N, 1)
。a
中的每个值是否与最大值匹配,如 (N, N) == (N, 1)
。这将为您提供一个二进制掩码,其中 True
位于 a
的最大值位置。如果您的蒙版还包含 True
表示零 (0==0)
,请不要担心,因为它们无论如何都会被忽略。*
将乘以 a
和来自 2.
的掩码,并仅保留 dim=*
上的最大元素。评论:
amax
与max
类似,但支持广播。如果您想在一批图像强度中找到最大值,这很有用(B, H, W)
。
a * (a == a.amax(dim=(1, 2), keepdim=True))