假设我有一个形状为 [A,B,C,...] 的张量,并且我希望创建沿第 0 维的张量列表。这意味着输出应该是长度为 A 的一维列表,其每个元素都是形状为 [B,C,...] 的张量。
我已经看到了 torch.split 函数,但问题是它返回形状为 [1,B,C,...] 的一维张量元组,这不是我想要的。
例如,
>>> minibatch = torch.rand ((4, 2))
tensor([[0.8218, 0.8997],
[0.4612, 0.9416],
[0.1481, 0.2389],
[0.7764, 0.7884]])
>>> torch.split (minibatch, split_size_or_sections=1)
(tensor([[0.8218, 0.8997]]),
tensor([[0.4612, 0.9416]]),
tensor([[0.1481, 0.2389]]),
tensor([[0.7764, 0.7884]]))
我想要的输出:
[tensor([0.8218, 0.8997]),
tensor([0.4612, 0.9416]),
tensor([0.1481, 0.2389]),
tensor([0.7764, 0.7884])]
当然,我现在可以将这些张量中的每一个压缩到第 0 维并以所需的格式获取它,但是如果有关于执行此操作的更简单方法的快速指南,我将不胜感激。
我参加聚会迟到了,但是:
您正在寻找的是torch.unbind。
来自文档:
删除张量维度。返回沿给定维度的所有切片的元组,已经没有它了。
示例:
>>> import torch
>>> minibatch = torch.rand((4, 2))
>>> minibatch
tensor([[0.2150, 0.8344],
[0.2773, 0.4082],
[0.7117, 0.3576],
[0.6589, 0.0147]])
>>> torch.unbind(minibatch, dim=0)
(tensor([0.2150, 0.8344]),
tensor([0.2773, 0.4082]),
tensor([0.7117, 0.3576]),
tensor([0.6589, 0.0147]))