Pytorch中位数 - 它是错误还是我错误使用它

问题描述 投票:1回答:2

我试图获得每行2D torch.tensor的中位数。但与使用标准阵列或numpy相比,结果并不是我所期望的

import torch
import numpy as np
from statistics import median

print(torch.__version__)
>>> 0.4.1

y = [[1, 2, 3, 5, 9, 1],[1, 2, 3, 5, 9, 1]]
median(y[0])
>>> 2.5

np.median(y,axis=1)
>>> array([2.5, 2.5])

yt = torch.tensor(y,dtype=torch.float32)
yt.median(1)[0]
>>> tensor([2., 2.])
python pytorch median torch
2个回答
2
投票

看起来这是本期提到的Torch的预期行为

https://github.com/pytorch/pytorch/issues/1837 https://github.com/torch/torch7/pull/182

上面链接中提到的推理

在奇数元素的情况下,中位数返回'中间'元素,否则一个中前元素(也可以做另一个约定,取两个中间元素的平均值,但这会贵两倍,所以我决定这个)。


0
投票

您可以使用pytorch模拟numpy中位数:

import torch
import numpy as np
y =[1, 2, 3, 5, 9, 1]
print("numpy=",np.median(y))
print(sorted([1, 2, 3, 5, 9, 1]))
yt = torch.tensor(y,dtype=torch.float32)
ymax = torch.tensor([yt.max()])
print("torch=",yt.median())
print("torch_fixed=",(torch.cat((yt,ymax)).median()+yt.median())/2.)
© www.soinside.com 2019 - 2024. All rights reserved.