我对 PyTorch 完全陌生,我想知道在
moveaxis()
和 movedim()
方法方面是否缺少任何内容。对于相同的参数,输出完全相同。还有这两种方法都不能用permute()
代替吗?
参考示例:
import torch
mytensor = torch.randn(3, 6, 3, 1, 7, 21, 4)
t_md = torch.movedim(mytensor, 2, 5)
t_ma = torch.moveaxis(mytensor, 2, 5)
print(t_md.shape, t_ma.shape)
print(torch.allclose(t_md, t_ma))
t_p = torch.permute(mytensor, (0, 1, 3, 4, 5, 2, 6))
print(t_p.shape)
print(torch.allclose(t_md, t_p))
moveaxis() 是 movedim() 的别名,如 the doc 所说,所以两者完全相同,重塑 0D 或更多 D 张量,它们类似于 permute() 可以重塑 1D或更多 D 张量,通过设置维度如下所示:
*备注:
permute()
、movedim()
和 moveaxis()
都可以在张量和 torch 中使用。movedim()
和 moveaxis()
会创建张量的副本 permute()
更轻。import torch
my_tensor = torch.tensor([[[0, 1, 2], [3, 4, 5]], # The size is [4, 2, 3].
[[6, 7, 8], [9, 10, 11]],
[[12, 13, 14], [15, 16, 17]],
[[18, 19, 20], [21, 22, 23]]])
torch.permute(input=my_tensor, dims=(0, 1, 2))
my_tensor.permute(dims=(0, 1, 2))
torch.movedim(input=my_tensor, source=0, destination=0)
my_tensor.movedim(source=0, destination=0)
torch.moveaxis(input=my_tensor, source=0, destination=0)
my_tensor.moveaxis(source=0, destination=0)
# tensor([[[0, 1, 2], [3, 4, 5]],
# [[6, 7, 8], [9, 10, 11]],
# [[12, 13, 14], [15, 16, 17]],
# [[18, 19, 20], [21, 22, 23]]])
# The size is [4, 2, 3].
torch.permute(input=my_tensor, dims=(2, 0, 1))
torch.movedim(input=my_tensor, source=2, destination=0)
torch.moveaxis(input=my_tensor, source=2, destination=0)
# tensor([[[0, 3], [6, 9], [12, 15], [18, 21]],
# [[1, 4], [7, 10], [13, 16], [19, 22]],
# [[2, 5], [8, 11], [14, 17], [20, 23]]])
# The size is [3, 4, 2].
torch.permute(input=my_tensor, dims=(1, 2, 0))
torch.movedim(input=my_tensor, source=0, destination=2)
torch.moveaxis(input=my_tensor, source=0, destination=2)
# tensor([[[0, 6, 12, 18], [1, 7, 13, 19], [2, 8, 14, 20]],
# [[3, 9, 15, 21], [4, 10, 16, 22], [5, 11, 17, 23]]])
# The size is [2, 3, 4].
torch.permute(input=my_tensor, dims=(0, 2, 1))
torch.movedim(input=my_tensor, source=1, destination=2)
torch.moveaxis(input=my_tensor, source=1, destination=2)
# tensor([[[0, 3], [1, 4], [2, 5]],
# [[6, 9], [7, 10], [8, 11]],
# [[12, 15], [13, 16], [14, 17]],
# [[18, 21], [19, 22], [20, 23]]])
# The size is [4, 3, 2].