火炬分离给出错误的答案

问题描述 投票:0回答:2
import torch
x = torch.rand(2, 3)
print(x)
splitted = x.split(split_size=2, dim=0)  # should get 2 tensors of 1 x 3  and 1 x 3 size I thought
print(splitted) #instead, get a tuple of len 1, with [0] =  tensor same as input
print(type(splitted), len(splitted))
print(splitted[0].shape)
print(torch.__version__)

提供以下输出:

tensor([[0.0702, 0.1275, 0.3735],
        [0.0260, 0.9393, 0.9448]])
(tensor([[0.0702, 0.1275, 0.3735],
        [0.0260, 0.9393, 0.9448]]),)
<class 'tuple'> 1
torch.Size([2, 3])
1.3.1

为什么我在一个元组中没有两个张量?我本来希望将输入分成两部分。我在Windows 10下]

import torch x = torch.rand(2,3)print(x)splitted = x.split(split_size = 2,dim = 0)#应该得到2个张量为1 x 3和1 x 3的张量,我认为print( #取而代之的是获得len 1的元组,其中[0] = ...

split pytorch torch
2个回答
0
投票

您可能误解了split(...)文档。它说:


0
投票

我认为行为符合预期。请注意,该参数为chunk(...),而不是分割数。因此,您要指定拆分的大小。您将大小指定为2,而第一个维度是大小2,因此您将获得一个包含一个元素(即import torch x = torch.rand(2, 3) chunks = x.chunk(chunks=2, dim=0) 的全部)的元组。

© www.soinside.com 2019 - 2024. All rights reserved.