将张量拆分为张量列表的最快方法

问题描述 投票:0回答:1

假设我有一个形状为 [A,B,C,...] 的张量,并且我希望创建沿第 0 维的张量列表。这意味着输出应该是长度为 A 的一维列表,其每个元素都是形状为 [B,C,...] 的张量。

我已经看到了 torch.split 函数,但问题是它返回形状为 [1,B,C,...] 的一维张量元组,这不是我想要的。

例如,

>>> minibatch = torch.rand ((4, 2))
    
    tensor([[0.8218, 0.8997],
            [0.4612, 0.9416],
            [0.1481, 0.2389],
            [0.7764, 0.7884]])

>>> torch.split (minibatch, split_size_or_sections=1)

    (tensor([[0.8218, 0.8997]]), 
     tensor([[0.4612, 0.9416]]), 
     tensor([[0.1481, 0.2389]]), 
     tensor([[0.7764, 0.7884]]))

我想要的输出:

    [tensor([0.8218, 0.8997]), 
     tensor([0.4612, 0.9416]), 
     tensor([0.1481, 0.2389]), 
     tensor([0.7764, 0.7884])]

当然,我现在可以将这些张量中的每一个压缩到第 0 维并以所需的格式获取它,但是如果有关于执行此操作的更简单方法的快速指南,我将不胜感激。

python pytorch
1个回答
0
投票

我参加聚会迟到了,但是:

您正在寻找的是torch.unbind
来自文档:

删除张量维度。返回沿给定维度的所有切片的元组,已经没有它了。

示例:

>>> import torch

>>> minibatch = torch.rand((4, 2))
>>> minibatch
    tensor([[0.2150, 0.8344],
            [0.2773, 0.4082],
            [0.7117, 0.3576],
            [0.6589, 0.0147]])

>>> torch.unbind(minibatch, dim=0)
    (tensor([0.2150, 0.8344]),
     tensor([0.2773, 0.4082]),
     tensor([0.7117, 0.3576]),
     tensor([0.6589, 0.0147]))
© www.soinside.com 2019 - 2024. All rights reserved.