为什么使用set_printoptions(精度=1)后张量的浮点数仍然波动

问题描述 投票:0回答:1

我遵循了一个教程,该教程展示了如何以正确的方式进行张量运算。他们说,通常张量之间的运算是通过迭代张量数组的循环手动完成的。然后他们展示了更好的方法,通过与三角形单位矩阵进行“点积”,对张量进行平均而不丢失其空间信息,并表明两种方法都会产生相同的结果,

print(torch.allclose(xbow, xbow2))
给出“True”作为返回值表明两种方法的工作方式相同。但是当我遵循他们的路径时,我的结果最终为“False”,显示张量运算结果可能存在差异。

根据我向周围专家询问,由于他们使用随机数生成器

torch.randn()
作为产生张量的方式,产生的数字可能会给出不同的浮点精度,从而降低准确性。尽管他们对此很确定,但他们不知道为什么教程没有遇到与我相同的问题。因此,根据我所拥有的,我使用
set_printoptions(precision=1)
来限制张量值精度点。但结果仍然是“错误”。我做错了什么或者我应该在这里寻找什么?

代码

教程中生成张量的方法

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
print(x.shape)
print(x[0]) 

它显示了

torch.Size([4, 8, 2])
tensor([[ 0.2, -0.1],
        [-0.4, -0.9],
        [ 0.6,  0.0],
        [ 1.0,  0.1],
        [ 0.4,  1.2],
        [-1.3, -0.5],
        [ 0.2, -0.2],
        [-0.9,  1.5]])

使用教程中的循环进行平均

# We want x[b, t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

我尝试打印张量的第一个成员

xbow[0]

有结果

tensor([[ 0.2, -0.1],
        [-0.1, -0.5],
        [ 0.1, -0.3],
        [ 0.4, -0.2],
        [ 0.4,  0.1],
        [ 0.1, -0.0],
        [ 0.1, -0.1],
        [-0.0,  0.1]])

然后教程展示了不同的方法,用

torch.tril(torch.ones(T,T))

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # B(T, T, T) @ (B, T, C ) ----> (B, T, C)

然后我打印结果的第一个成员

print(xbow2[0])

结果是

tensor([[ 0.2, -0.1],
        [-0.1, -0.5],
        [ 0.1, -0.3],
        [ 0.4, -0.2],
        [ 0.4,  0.1],
        [ 0.1, -0.0],
        [ 0.1, -0.1],
        [-0.0,  0.1]])

从第一个成员看来它是相等的,但是当我这样做时

xbow == xbow2
它表明某些张量不相等

tensor([[[ True,  True],
         [ True,  True],
         [ True, False],
         [ True,  True],
         [ True, False],
         [False,  True],
         [False, False],
         [ True,  True]],

        [[ True,  True],
         [ True,  True],
         [False,  True],
         [ True,  True],
         [False, False],
         [False, False],
         [False,  True],
         [False,  True]],

        [[ True,  True],
         [ True,  True],
         [False, False],
         [ True,  True],
         [False, False],
         [False,  True],
         [False, False],
...

这里发生了什么?

python pytorch floating-point tensor
1个回答
0
投票

您可能遇到浮点精度问题,请查看这篇文章,了解为什么 this 有时会出现问题。 pytorch 中的“allclose”方法可以帮助解决这个问题。

© www.soinside.com 2019 - 2024. All rights reserved.