我知道
torch.argmax(x, dim = 0)
返回 x
中沿维度 0
的第一个最大值的索引。但是有没有一种有效的方法来返回第一个 n
最大值的索引?如果存在重复值,我还想要 n
索引中的索引。
举个具体的例子,说
x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
。我想要一个功能
generalized_argmax(xI torch.tensor, n: int)
这样
generalized_argmax(x, 4)
在此示例中返回 [0, 2, 4, 5]
。
无论如何,要获取遍历整个张量所需的所有内容,最有效的方法应该是使用
argsort
手动限制为 n
条目。
>>> x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
>>> x.argsort(dim=0, descending=True)[:n]
[2, 4, 0, 5]
如果需要索引升序,请再次排序以获得
[0, 2, 4, 5]
。