import torch
x = torch.rand(2, 3)
print(x)
splitted = x.split(split_size=2, dim=0) # should get 2 tensors of 1 x 3 and 1 x 3 size I thought
print(splitted) #instead, get a tuple of len 1, with [0] = tensor same as input
print(type(splitted), len(splitted))
print(splitted[0].shape)
print(torch.__version__)
提供以下输出:
tensor([[0.0702, 0.1275, 0.3735],
[0.0260, 0.9393, 0.9448]])
(tensor([[0.0702, 0.1275, 0.3735],
[0.0260, 0.9393, 0.9448]]),)
<class 'tuple'> 1
torch.Size([2, 3])
1.3.1
为什么我在一个元组中没有两个张量?我本来希望将输入分成两部分。我在Windows 10下]
import torch x = torch.rand(2,3)print(x)splitted = x.split(split_size = 2,dim = 0)#应该得到2个张量为1 x 3和1 x 3的张量,我认为print( #取而代之的是获得len 1的元组,其中[0] = ...
您可能误解了split(...)
文档。它说:
我认为行为符合预期。请注意,该参数为chunk(...)
,而不是分割数。因此,您要指定拆分的大小。您将大小指定为2,而第一个维度是大小2,因此您将获得一个包含一个元素(即import torch
x = torch.rand(2, 3)
chunks = x.chunk(chunks=2, dim=0)
的全部)的元组。