pytorch min by columns with nan

问题描述 投票:0回答:1

我有一个具有 nan 值的 2D 火炬张量,我想获取列最小值并忽略具有 nan 值的单元格。

import torch

data = torch.tensor([[ 0.,  1., float('nan'),  3.],[ 4.,  5.,  6.,  7.], [ 8.,  9., 10., 11.]])
torch.min(data,0)

# What I would like to get is
# tensor([0., 1., 6., 3.])

有什么建议吗? 谢谢

nan tensor torch minimum
1个回答
0
投票

您可以通过将张量中的所有 nan 值转换为令人难以置信的高值,然后运行 torch.min 来实现此目的:

 #I am replacing nan with 10^15
 data = torch.nan_to_num(data,nan = 10e14)
 data = torch.min(data,0)

torch.nan_to_num 用于将张量中的所有 nan 值转换为某个值。根据您的要求进行操作。

© www.soinside.com 2019 - 2024. All rights reserved.