`movedim()` vs. `moveaxis()` vs. `permute()`

Question (0 votes, 2 answers)

I'm completely new to PyTorch, and I'm wondering whether I'm missing something about the moveaxis() and movedim() methods. For the same arguments, their outputs are exactly the same. Also, couldn't both of these methods be replaced by permute()?

Reference example:

import torch

mytensor = torch.randn(3, 6, 3, 1, 7, 21, 4)

t_md = torch.movedim(mytensor, 2, 5)
t_ma = torch.moveaxis(mytensor, 2, 5)

print(t_md.shape, t_ma.shape)
print(torch.allclose(t_md, t_ma))

t_p = torch.permute(mytensor, (0, 1, 3, 4, 5, 2, 6))

print(t_p.shape)
print(torch.allclose(t_md, t_p))
Tags: python, pytorch, permutation, tensor, difference
2 Answers

6 votes

Yes, moveaxis is an alias of movedim (just as swapaxes is an alias of swapdims).¹
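
As a quick sanity check (my own addition, not part of the answer), each alias pair produces an identical result for the same arguments:

import torch

x = torch.randn(2, 3, 4)

# moveaxis/movedim and swapaxes/swapdims are alias pairs,
# so the same arguments yield identical tensors.
print(torch.equal(torch.moveaxis(x, 0, 2), torch.movedim(x, 0, 2)))  # True
print(torch.equal(torch.swapaxes(x, 0, 1), torch.swapdims(x, 0, 1)))  # True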

Yes, the same thing can be achieved with permute, but moving a single axis while keeping the relative order of all the other axes is a common enough use case to warrant its own syntactic sugar.
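
A minimal sketch (my own helper, not from the original answer) of what that syntactic sugar saves you: building the dims tuple for permute by hand. movedim_via_permute is a hypothetical name and handles only a single non-negative source/destination.

import torch

def movedim_via_permute(t, source, destination):
    # Build the permutation that torch.movedim(t, source, destination)
    # would apply for a single axis: keep the other axes in their
    # original order and re-insert the moved axis at its new position.
    dims = list(range(t.dim()))
    dims.remove(source)
    dims.insert(destination, source)
    return t.permute(dims)

x = torch.randn(3, 6, 3, 1, 7, 21, 4)
print(torch.equal(movedim_via_permute(x, 2, 5), torch.movedim(x, 2, 5)))  # True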


  1. The terminology is taken from NumPy. The PyTorch docs describe torch.moveaxis() as an alias for torch.movedim() and note that the function is equivalent to NumPy's moveaxis function.

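For reference, a quick cross-check (my own addition; assumes NumPy is installed) that torch.moveaxis matches numpy.moveaxis:

import numpy as np
import torch

a = np.random.rand(3, 6, 3, 1, 7, 21, 4)
t = torch.from_numpy(a)

# The same axis move in both libraries; shapes and values should agree.
print(np.moveaxis(a, 2, 5).shape == tuple(torch.moveaxis(t, 2, 5).shape))  # True
print(np.allclose(np.moveaxis(a, 2, 5), torch.moveaxis(t, 2, 5).numpy()))  # True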

0 votes

moveaxis() is an alias of movedim(), as the doc says, so the two are exactly the same: they rearrange the dimensions of a 0-D or higher-dimensional tensor. They are similar to permute(), which rearranges the dimensions of a 1-D or higher-dimensional tensor. You set the dimensions as shown below:

*Remarks:

  • permute(), movedim() and moveaxis() can all be called either as tensor methods or as functions in the torch namespace.
  • None of them copies the data: all three return a view of the input tensor rather than a copy (see the storage check after the examples below).
import torch

my_tensor = torch.tensor([[[0, 1, 2], [3, 4, 5]], # The size is [4, 2, 3].
                          [[6, 7, 8], [9, 10, 11]],
                          [[12, 13, 14], [15, 16, 17]],
                          [[18, 19, 20], [21, 22, 23]]])
torch.permute(input=my_tensor, dims=(0, 1, 2))
my_tensor.permute(dims=(0, 1, 2))
torch.movedim(input=my_tensor, source=0, destination=0)
my_tensor.movedim(source=0, destination=0)
torch.moveaxis(input=my_tensor, source=0, destination=0)
my_tensor.moveaxis(source=0, destination=0)
# tensor([[[0, 1, 2], [3, 4, 5]],        
#         [[6, 7, 8], [9, 10, 11]],
#         [[12, 13, 14], [15, 16, 17]],
#         [[18, 19, 20], [21, 22, 23]]])
# The size is [4, 2, 3].

torch.permute(input=my_tensor, dims=(2, 0, 1))
torch.movedim(input=my_tensor, source=2, destination=0)
torch.moveaxis(input=my_tensor, source=2, destination=0)
# tensor([[[0, 3], [6, 9], [12, 15], [18, 21]],
#         [[1, 4], [7, 10], [13, 16], [19, 22]],
#         [[2, 5], [8, 11], [14, 17], [20, 23]]])
# The size is [3, 4, 2].

torch.permute(input=my_tensor, dims=(1, 2, 0))
torch.movedim(input=my_tensor, source=0, destination=2)
torch.moveaxis(input=my_tensor, source=0, destination=2)
# tensor([[[0, 6, 12, 18], [1, 7, 13, 19], [2, 8, 14, 20]],
#         [[3, 9, 15, 21], [4, 10, 16, 22], [5, 11, 17, 23]]])
# The size is [2, 3, 4].

torch.permute(input=my_tensor, dims=(0, 2, 1))
torch.movedim(input=my_tensor, source=1, destination=2)
torch.moveaxis(input=my_tensor, source=1, destination=2)
# tensor([[[0, 3], [1, 4], [2, 5]],
#         [[6, 9], [7, 10], [8, 11]],
#         [[12, 15], [13, 16], [14, 17]],
#         [[18, 21], [19, 22], [20, 23]]])
# The size is [4, 3, 2].
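
As a quick check of the second remark above (my own addition, not part of the original answer), all three calls return views that share memory with my_tensor, i.e. no data is copied:

for view in (my_tensor.permute(2, 0, 1),
             my_tensor.movedim(2, 0),
             my_tensor.moveaxis(2, 0)):
    # A view points at the same underlying memory as the original tensor;
    # only the shape/stride metadata differs.
    print(view.data_ptr() == my_tensor.data_ptr())  # True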