如何找到张量前$n$个最大值的索引?

问题描述 投票:0回答:1

我知道

torch.argmax(x, dim = 0)
返回
x
中沿维度
0
的第一个最大值的索引。但是有没有一种有效的方法来返回第一个
n
最大值的索引?如果存在重复值,我还想要
n
索引中的索引。

举个具体的例子,说

x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
。我想要一个功能

generalized_argmax(xI torch.tensor, n: int)

这样

generalized_argmax(x, 4)
在此示例中返回
[0, 2, 4, 5]

python machine-learning pytorch argmax
1个回答
3
投票

无论如何,要获取遍历整个张量所需的所有内容,最有效的方法应该是使用

argsort
手动限制为
n
条目。

>>> x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
>>> x.argsort(dim=0, descending=True)[:n]
[2, 4, 0, 5]

如果需要索引升序,请再次排序以获得

[0, 2, 4, 5]

© www.soinside.com 2019 - 2024. All rights reserved.