我有一个大小为:
torch.Size([1, 63840])
的张量,然后我展开:
inp_unfolded = inp_seq.unfold(1, 160, 80)
这给了我一个形状:
torch.Size([1, 797, 160])
我怎样才能重新
fold
以获得torch.Size([1, 63840])
的张量?
嗯,实际上,鉴于
t.unfold(i, n, s)
,条件是:
n >= s
(否则步骤会跳过一些原始数据,我们无法恢复它)n + s <= t.shape[i]
然后我们可以通过:
def roll(x, n, s, axis=1):
return torch.cat((x[0], x[1:][:, n-s:].flatten()), axis)
说明:
x[0]
是起始块,在开始时始终是唯一的
x[1:][:, n-s:]
- 然后,我们获取其余的卷并 n-s
描述卷之间有多少元素会重叠,因此我们想忽略它们并仅获取 n-s
中的元素
图示:
x.unfold(0, 5, 2)
tensor([[ 1., 2., 3., 4., 5.],
[ 3., 4., 5., 6., 7.], # 3, 4, 5 are repeated
[ 5., 6., 7., 8., 9.], # 5, 6, 7 are repeated...
[ 7., 8., 9., 10., 11.],
[ 9., 10., 11., 12., 13.],
[11., 12., 13., 14., 15.],
[13., 14., 15., 16., 17.]])
示例:
>> x = torch.arange(1., 18)
>> p = x.unfold(0, 5, 2)
>> roll(p, 5, 2, 0)
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16., 17.])
你也可以尝试一下
x = torch.arange(1., 18).reshape(1, 17)
和轴 1
对于该特定配置,由于
63840
可以被 160
整除,并且步长是切片大小的倍数,因此您可以简单地选择沿该维度的每隔一个元素,然后 flatten
得到的张量:
inp_unfolded[:, ::2, :].flatten(1, 2)
更一般地,对于
t.unfold(i, n, s)
,如果t.shape[i] % n == 0 and n % s == 0
成立,那么你可以通过以下方式恢复原始张量:
index = [slice(None) for __ in t.shape]
index[i] = slice(None, None, n // s)
original = t.unfold(i, n, s)[tuple(index)].flatten(i, i+1)
当然,如果事先知道尺寸
i
,您也可以使用切片表示法。例如 i == 1
如您的示例所示:
original = t.unfold(1, n, s)[:, ::n//s, ...].flatten(1, 2)