如何折叠用 PyTorch 展开且有重叠的张量?

问题描述 投票:0回答:2

我有一个大小为:

torch.Size([1, 63840])
的张量,然后我展开:

inp_unfolded = inp_seq.unfold(1, 160, 80)

这给了我一个形状:

torch.Size([1, 797, 160])

我怎样才能重新

fold
以获得
torch.Size([1, 63840])
的张量?

python pytorch tensor
2个回答
3
投票

嗯,实际上,鉴于

t.unfold(i, n, s)
,条件是:

  • n >= s
    (否则步骤会跳过一些原始数据,我们无法恢复它)
  • n + s <= t.shape[i]

然后我们可以通过:

def roll(x, n, s, axis=1):
    return torch.cat((x[0], x[1:][:, n-s:].flatten()), axis)

说明:

x[0]
是起始块,在开始时始终是唯一的

x[1:][:, n-s:]
- 然后,我们获取其余的卷并
n-s
描述卷之间有多少元素会重叠,因此我们想忽略它们并仅获取
n-s

中的元素

图示:

x.unfold(0, 5, 2)
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 3.,  4.,  5.,  6.,  7.], # 3, 4, 5 are repeated
        [ 5.,  6.,  7.,  8.,  9.], # 5, 6, 7 are repeated...
        [ 7.,  8.,  9., 10., 11.],
        [ 9., 10., 11., 12., 13.],
        [11., 12., 13., 14., 15.],
        [13., 14., 15., 16., 17.]])

示例:

>> x = torch.arange(1., 18)
>> p = x.unfold(0, 5, 2)
>> roll(p, 5, 2, 0)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16., 17.])

你也可以尝试一下

x = torch.arange(1., 18).reshape(1, 17)

和轴 1


3
投票

对于该特定配置,由于

63840
可以被
160
整除,并且步长是切片大小的倍数,因此您可以简单地选择沿该维度的每隔一个元素,然后
flatten
得到的张量:

inp_unfolded[:, ::2, :].flatten(1, 2)

更一般地,对于

t.unfold(i, n, s)
,如果
t.shape[i] % n == 0 and n % s == 0
成立,那么你可以通过以下方式恢复原始张量:

index = [slice(None) for __ in t.shape]
index[i] = slice(None, None, n // s)
original = t.unfold(i, n, s)[tuple(index)].flatten(i, i+1)

当然,如果事先知道尺寸

i
,您也可以使用切片表示法。例如
i == 1
如您的示例所示:

original = t.unfold(1, n, s)[:, ::n//s, ...].flatten(1, 2)
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.