我有以下代码,将张量 X 乘以矩阵 C。根据 X 的大小以及 C 是否附加到计算图,当我把批量乘法的结果与逐个切片循环处理 X 的结果进行比较时,会得到不同的结果。
import torch
from torch import nn
# Compare two ways of computing C @ x for every vector x in X:
# one matmul per index along dim 1 versus a single broadcast matmul.
# Three cases: two tensor sizes with C as an nn.Parameter, and one with C detached.
for X, C in [
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32)).detach()),
]:
    # Multiply each slice X[:, t, :] separately and store it in A.
    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:, t, :] = (C @ X[:, t, :].unsqueeze(-1)).squeeze(-1)

    # Multiply the whole tensor in one batched call.
    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)

    # Report exact element-wise equality and tolerance-based closeness.
    print('equal:', (A1 == A).all().item(), ', close:', torch.allclose(A1, A))
输出:
equal: False , close: False
equal: True , close: True
equal: True , close: True
发生什么事了?我希望它们在所有三种情况下都是相等的。
仅供参考,
import platform
import sys

# Report the OS, Python and PyTorch versions, one labelled line each.
for label, value in (
    ('OS:', platform.platform()),
    ('Python:', sys.version),
    ('Pytorch:', torch.__version__),
):
    print(label, value)
给出:
OS: macOS-14.4.1-arm64-arm-64bit
Python: 3.12.1 | packaged by conda-forge | (main, Dec 23 2023, 08:01:35) [Clang 16.0.6 ]
Pytorch: 2.2.0
浮点运算的精度问题是计算领域最大的挑战之一。
不过,对于您的情况,您可以简单地使用 allclose() 函数来比较结果,并使用 double() 将张量转换为双精度(float64)。
import torch
from torch import nn
def _compare(X, C):
    """Print how well the looped and the batched product ``C @ x`` agree.

    The product is computed twice: once slice by slice along dim 1 of ``X``
    and once with a single broadcast matmul. One line is printed with exact
    equality, ``torch.allclose`` with default tolerances, and the largest
    absolute difference.

    Args:
        X: Tensor indexed as ``X[:, t, :]``; the examples in this file use
            shape ``(batch, steps, 32)``.
        C: Matrix applied to each vector along the last dim of ``X``; the
            examples use shape ``(32, 32)``.

    Returns:
        None. The result is only printed.
    """
    # Reference result: one matmul per index along dim 1.
    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:, t, :] = (C @ X[:, t, :].unsqueeze(-1)).squeeze(-1)

    # Same product computed in a single broadcast matmul.
    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)

    equal = (A1 == A).all().item()
    close = torch.allclose(A1, A)
    # max() raises on a zero-element tensor; report 0.0 when there is
    # nothing to compare.
    max_diff = (A1 - A).abs().max().item() if A.numel() else 0.0
    print(f'equal: {equal}, close: {close}, diff: {max_diff:.16f}')
# Seed the global RNG before the random inputs are drawn.
torch.manual_seed(4)

# Three cases: two sizes of X with C as an nn.Parameter, and one with C detached.
for X, C in [
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32)).detach()),
]:
    _compare(X, C)
equal: False, close: False, diff: 0.0000023841857910
equal: False, close: False, diff: 0.0000019073486328
equal: True, close: True, diff: 0.0000000000000000
使用 double() 可以将张量转换为 torch.float64(双精度)。它会大大减慢程序的速度,但这是计算时间和准确性之间的权衡,由您自行决定。
import torch
from torch import nn
def _compare(X, C):
    """Print how well the looped and the batched product agree in float64.

    Both inputs are converted with ``.double()`` first. The product
    ``C @ x`` is then computed twice: once slice by slice along dim 1 of
    ``X`` and once with a single broadcast matmul. One line is printed with
    exact equality, ``torch.allclose`` with default tolerances, and the
    largest absolute difference.

    Args:
        X: Tensor indexed as ``X[:, t, :]``; the examples in this file use
            shape ``(batch, steps, 32)``.
        C: Matrix applied to each vector along the last dim of ``X``; the
            examples use shape ``(32, 32)``.

    Returns:
        None. The result is only printed.
    """
    # Work in double precision from here on.
    X, C = X.double(), C.double()

    # Reference result: one matmul per index along dim 1.
    A = torch.empty_like(X)
    for t in range(X.shape[1]):
        A[:, t, :] = (C @ X[:, t, :].unsqueeze(-1)).squeeze(-1)

    # Same product computed in a single broadcast matmul.
    A1 = (C @ X.unsqueeze(-1)).squeeze(-1)

    equal = (A1 == A).all().item()
    close = torch.allclose(A1, A)
    # max() raises on a zero-element tensor; report 0.0 when there is
    # nothing to compare.
    max_diff = (A1 - A).abs().max().item() if A.numel() else 0.0
    print(f'equal: {equal}, close: {close}, diff: {max_diff:.32f}')
# Seed the global RNG before the random inputs are drawn.
torch.manual_seed(4)

# Three cases: two sizes of X with C as an nn.Parameter, and one with C detached.
for X, C in [
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(16, 50, 32), nn.Parameter(torch.randn(32, 32))),
    (torch.rand(8, 50, 32), nn.Parameter(torch.randn(32, 32)).detach()),
]:
    _compare(X, C)
equal: False, close: True, diff: 0.00000000000000266453525910037570
equal: True, close: True, diff: 0.00000000000000000000000000000000
equal: True, close: True, diff: 0.00000000000000000000000000000000