我使用 PyTorch 编写了以下代码并遇到了运行时错误:
tns = torch.tensor([1,0,1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()
RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.
但是,如果我将张量更改为
float
,错误就会消失:
tns = torch.tensor([1.,0,1])
tns.mean()
---------------------------------------------------------------------------
tensor(0.6667)
我的问题是为什么会发生错误。第一个张量的数据类型是
int64
而不是Long
,为什么PyTorch将其视为Long
?
您应该将 'torch.tensor([1,0,1])' 更改为 'torch.Tensor([1,0,1])。
我在下面遇到了同样的错误:
运行时错误:mean():无法推断输出数据类型。输入数据类型必须是浮点或复数数据类型。得到:长
因为我使用了 mean() 和整数张量,如下所示。 *
mean()
可以表示(平均)零个或多个浮点或复数的0D或多个D张量:
import torch
my_tensor = torch.tensor([2, 7, 4])
torch.mean(input=my_tensor) # Error
所以,我将张量转换为浮点或复数,然后我可以得到如下所示的结果:
import torch
my_tensor = torch.tensor([2, 7, 4])
torch.mean(input=my_tensor.float()) # tensor(4.3333)
torch.mean(input=my_tensor.cfloat()) # tensor(4.3333+0.j)
import torch
my_tensor = torch.tensor([2., 7., 4.])
torch.mean(input=my_tensor) # tensor(4.3333)
my_tensor = torch.tensor([2.+0.j, 7.+0.j, 4.+0.j])
torch.mean(input=my_tensor) # tensor(4.3333+0.j)