火炬 - 像numpy重复一样重复张量

问题描述 投票:3回答:3

我试图以两种方式重复火炬的张量。例如,重复张量{1,2,3,4}两次两种方式产生;

{1,2,3,4,1,2,3,4,1,2,3,4}
{1,1,1,2,2,2,3,3,3,4,4,4}

有一个内置的火炬:repeatTensor函数,它将生成两个中的第一个(如numpy.tile()),但我找不到后者(如numpy.repeat())。我确信我可以在第一次调用sort来给第二个,但我认为对于更大的数组,这可能在计算上是昂贵的?

谢谢。

torch
3个回答
5
投票
a = torch.Tensor{1,2,3,4}

为了获得{1,2,3,4,1,2,3,4,1,2,3,4},我们在第一维中重复三次张量:

a:repeatTensor(3)

为了获得{1,1,1,2,2,2,3,3,3,4,4,4},我们在张量中添加一个维度,并在第二维中重复三次以获得4 x 3张量,我们可以将其展平。

b = a:reshape(4,1):repeatTensor(1,3)
b:view(b:nElement())

1
投票

引用https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853 -

z = torch.FloatTensor([[1,2,3],[4,5,6],[7,8,9]])
1 2 3
4 5 6
7 8 9
z.transpose(0,1).repeat(1,3).view(-1, 3).transpose(0,1)
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
7 7 7 8 8 8 9 9 9

这将让您直观地了解它的工作原理。


0
投票

这是一个在张量中重复元素的通用函数。

def repeat(tensor, dims):
    if len(dims) != len(tensor.shape):
        raise ValueError("The length of the second argument must equal the number of dimensions of the first.")
    for index, dim in enumerate(dims):
        repetition_vector = [1]*(len(dims)+1)
        repetition_vector[index+1] = dim
        new_tensor_shape = list(tensor.shape)
        new_tensor_shape[index] *= dim
        tensor = tensor.unsqueeze(index+1).repeat(repetition_vector).reshape(new_tensor_shape)
    return tensor

如果你有

foo = tensor([[1, 2],
              [3, 4]])

通过调用repeat(foo, [2,1])你得到

tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])

因此,您沿着维度0复制了每个元素,并在维度1上复制了左侧元素。

© www.soinside.com 2019 - 2024. All rights reserved.