PyTorch matrix multiplication precision depends on tensor size


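I have the following code, which multiplies a tensor X by a matrix C. Depending on the size of X and on whether C is attached to the computation graph, I get different results when I compare the batched multiplication with looping over each slice of X.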

import torch
from torch import nn

for X,C in [(torch.rand(8,  50, 32), nn.Parameter(torch.randn(32,32))),
            (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32,32))),
            (torch.rand(8,  50, 32), nn.Parameter(torch.randn(32,32)).detach())
            ]:

    #multiply each entry
    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:,t,:] = (C @ X[:,t,:].unsqueeze(-1)).squeeze(-1)

    #multiply in batch
    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)
    
    print('equal:', (A1 == A).all().item(), ', close:', torch.allclose(A1, A))

It returns:

equal: False , close: False
equal: True , close: True
equal: True , close: True

What is going on? I would expect them to be equal in all three cases.

For reference,

import sys, platform
print('OS:', platform.platform())
print('Python:', sys.version)
print('Pytorch:', torch.__version__)

gives:

OS: macOS-14.4.1-arm64-arm-64bit
Python: 3.12.1 | packaged by conda-forge | (main, Dec 23 2023, 08:01:35) [Clang 16.0.6 ]
Pytorch: 2.2.0
1 Answer

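This is one of the classic challenges of numerical computing. Floating-point addition is not associative, so two mathematically identical products can differ in the last few bits when the underlying kernel accumulates the partial sums in a different order, and PyTorch does not guarantee that a batched call gives bitwise-identical results to computing each slice separately.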
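A minimal illustration of that non-associativity, written in plain Python rather than PyTorch so it runs anywhere. Python floats are float64, so the discrepancy shows up around the 16th significant digit instead of around the 7th as with float32, but the mechanism is the same. `dot_forward` and `dot_backward` are hypothetical helpers that only stand in for "two kernels that sum in different orders"; they say nothing about the order PyTorch's CPU backend actually uses.

```python
# Floating-point addition is not associative: grouping the same three
# numbers differently changes the rounded result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False


def dot_forward(a, b):
    """Accumulate the dot product from the first element to the last."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total


def dot_backward(a, b):
    """Accumulate the same dot product from the last element to the first."""
    total = 0.0
    for x, y in zip(reversed(a), reversed(b)):
        total += x * y
    return total


a = [0.1, 0.2, 0.3]
b = [1.0, 1.0, 1.0]
# Same terms, different accumulation order, different last bit.
print(dot_forward(a, b), dot_backward(a, b))  # 0.6000000000000001 0.6
```

A matrix product is many such dot products, so any change in blocking or vectorization between the batched and the per-slice path can produce differences of this size.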

For your case, however, you can simply compare the results with allclose() and call double() to switch to double precision (float64).

Code

import torch
from torch import nn


def _compare(X, C):

    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:, t, :] = (C @ X[:, t, :].unsqueeze(-1)).squeeze(-1)

    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)

    equal = (A1 == A).all().item()
    close = torch.allclose(A1, A)
    max_diff = (A1 - A).abs().max().item()

    print(f'equal: {equal}, close: {close}, diff: {max_diff:.16f}')


torch.manual_seed(4)

for X, C in [
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32)).detach())
]:
    _compare(X, C)

Prints

equal: False, close: False, diff: 0.0000023841857910
equal: False, close: False, diff: 0.0000019073486328
equal: True, close: True, diff: 0.0000000000000000
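The float32 runs report close: False even though the largest difference is only about 2.4e-6. According to its documentation, torch.allclose(input, other) checks |input - other| <= atol + rtol * |other| elementwise, with defaults rtol=1e-05 and atol=1e-08, so an entry whose value is close to zero can fail even with a tiny absolute error. The sketch below re-implements that criterion in plain Python for same-length lists; `allclose_like` is a hypothetical helper, not a PyTorch function, and NaN handling is omitted. In PyTorch itself you would loosen the check with something like torch.allclose(A1, A, atol=1e-5), where 1e-5 is only an illustrative value.

```python
def allclose_like(actual, expected, rtol=1e-05, atol=1e-08):
    """Elementwise |actual - expected| <= atol + rtol * |expected|,
    mirroring the criterion documented for torch.allclose (no NaN handling)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


diff = 2.4e-6  # roughly the max float32 difference reported above

# Far from zero the relative term dominates (bound ~3e-5), so the check passes.
print(allclose_like([3.0 + diff], [3.0]))              # True
# Near zero the bound shrinks to ~1e-6, so the same difference fails.
print(allclose_like([0.1 + diff], [0.1]))              # False
# A looser absolute tolerance accepts the same difference.
print(allclose_like([0.1 + diff], [0.1], atol=1e-5))   # True
```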

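Using double():

  • You can increase the precision.
  • You can use double precision, i.e. torch.float64. It slows the program down considerably, but this is a trade-off between computation time and accuracy, and the choice is yours.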
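To make the precision gain concrete: float32 has a 24-bit significand (machine epsilon 2**-23, about 1.19e-7), while float64 has a 53-bit one (epsilon 2**-52, about 2.22e-16). The sketch below emulates float32 rounding with Python's struct module (an assumption for illustration; PyTorch is not involved) and shows that a relative change of 2**-24 is lost in float32 but kept in float64. That roughly nine orders of magnitude gap is consistent with the much smaller differences the float64 run reports. `to_float32` is a hypothetical helper.

```python
import struct
import sys


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and widen it back."""
    return struct.unpack("f", struct.pack("f", x))[0]


EPS32 = 2.0 ** -23              # float32 machine epsilon
EPS64 = sys.float_info.epsilon  # float64 machine epsilon, equal to 2**-52

tiny = 2.0 ** -24  # half of the float32 epsilon

# float32: 1 + 2**-24 lies exactly halfway between 1 and 1 + 2**-23, and
# round-half-to-even sends it back to 1.0, so the change is lost.
lost_in_f32 = to_float32(1.0 + tiny) == 1.0
# float64 keeps 29 more significand bits, so the same change survives.
kept_in_f64 = (1.0 + tiny) != 1.0

print(f"float32 eps = {EPS32:.3e}, float64 eps = {EPS64:.3e}")
print("lost in float32:", lost_in_f32)  # True
print("kept in float64:", kept_in_f64)  # True
```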

Code

import torch
from torch import nn


def _compare(X, C):
    X, C = X.double(), C.double()

    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:, t, :] = (C @ X[:, t, :].unsqueeze(-1)).squeeze(-1)

    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)

    equal = (A1 == A).all().item()
    close = torch.allclose(A1, A)
    max_diff = (A1 - A).abs().max().item()

    print(f'equal: {equal}, close: {close}, diff: {max_diff:.32f}')


torch.manual_seed(4)

for X, C in [
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32)).detach())
]:
    _compare(X, C)

Prints

equal: False, close: True, diff: 0.00000000000000266453525910037570
equal: True, close: True, diff: 0.00000000000000000000000000000000
equal: True, close: True, diff: 0.00000000000000000000000000000000
© www.soinside.com 2019 - 2024. All rights reserved.