我有一个具有 nan 值的 2D 火炬张量,我想获取列最小值并忽略具有 nan 值的单元格。
import torch
data = torch.tensor([[ 0., 1., float('nan'), 3.],[ 4., 5., 6., 7.], [ 8., 9., 10., 11.]])
torch.min(data,0)
# What I would like to get is
# tensor([0., 1., 6., 3.])
有什么建议吗? 谢谢
您可以通过将张量中的所有 nan 值转换为令人难以置信的高值,然后运行 torch.min 来实现此目的:
#I am replacing nan with 10^15
data = torch.nan_to_num(data,nan = 10e14)
data = torch.min(data,0)
torch.nan_to_num 用于将张量中的所有 nan 值转换为某个值。根据您的要求进行操作。