运行时错误:mean():输入数据类型应该是浮点或复杂数据类型。反而长了

问题描述 投票:0回答:3

我使用 PyTorch 编写了以下代码并遇到了运行时错误:

tns = torch.tensor([1,0,1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()

RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.

但是,如果我将张量更改为

float
,错误就会消失:

tns = torch.tensor([1.,0,1])
tns.mean()
---------------------------------------------------------------------------
tensor(0.6667)

我的问题是为什么会发生错误。第一个张量的数据类型是

int64
而不是
Long
,为什么PyTorch将其视为
Long

python pytorch runtime-error mean tensor
3个回答
6
投票

您应该将 'torch.tensor([1,0,1])' 更改为 'torch.Tensor([1,0,1])。


5
投票

这是因为

torch.int64
torch.long
都指的是相同的数据类型,即 64 位有符号整数。请参阅此处了解所有数据类型的概述。


0
投票

我在下面遇到了同样的错误:

运行时错误:mean():无法推断输出数据类型。输入数据类型必须是浮点或复数数据类型。得到:长

因为我使用了 mean() 和整数张量,如下所示。 *

mean()
可以表示(平均)零个或多个浮点或复数的0D或多个D张量:

import torch

my_tensor = torch.tensor([2, 7, 4])

torch.mean(input=my_tensor) # Error

所以,我将张量转换为浮点或复数,然后我可以得到如下所示的结果:

import torch

my_tensor = torch.tensor([2, 7, 4])

torch.mean(input=my_tensor.float()) # tensor(4.3333)
torch.mean(input=my_tensor.cfloat()) # tensor(4.3333+0.j)
import torch

my_tensor = torch.tensor([2., 7., 4.])

torch.mean(input=my_tensor) # tensor(4.3333)

my_tensor = torch.tensor([2.+0.j, 7.+0.j, 4.+0.j])

torch.mean(input=my_tensor) # tensor(4.3333+0.j)
© www.soinside.com 2019 - 2024. All rights reserved.