如何将4维PyTorch张量乘以1维张量?

问题描述 投票:0回答:1

我想写一个混合训练的函数。在这个 场地 我找到了一些代码,并对我以前的代码进行了修改。但在原来的代码中,只有一个随机变量生成的批次(64)。但我想为批次中的每张图片生成随机值。在代码中,每批图片只有一个变量。

def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    mixed_y = lam * y + (1 - lam) * y[index,:]

    return mixed_x, mixed_y

输入的x和y来自pytorch DataLoader.x输入大小。torch.Size([64, 3, 256, 256])y输入大小: torch.Size([64, 3474])

这段代码很好用 然后我把它改成这样。

def mixup_data(x, y):
    batch_size = x.size()[0]
    lam = torch.rand(batch_size)
    index = torch.randperm(batch_size)

    mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]

    return mixed_x, mixed_y

但它给出了一个错误。RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

我对代码工作原理的理解是,它把批处理中的第一个图像乘以 lam 张量(64值长)。我怎样才能做到这一点?

python pytorch torch
1个回答
3
投票

你需要替换下面一行。

lam = torch.rand(batch_size)

lam = torch.rand(batch_size, 1, 1, 1)

用你现在的代码。lam[index] * x 乘法是不可能的,因为 lam[index] 是大小 torch.Size([64])x 是大小 torch.Size([64, 3, 256, 256]). 所以,你需要将 lam[index] 作为 torch.Size([64, 1, 1, 1]) 以便成为可广播的。

为了应对下面的语句。

mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

我们可以重塑 lam 语句前的张量。

lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]

0
投票

问题是,两个相乘的张量的大小不一致。让我们来看看 lam[index] * x 为例。其尺寸为:

  • x: torch.Size([64, 3, 256, 256])
  • lam[index]: torch.Size([64])

为了将它们相乘,它们应该具有相同的大小,其中 lam[index] 使用相同的值 [3, 256, 256] 每批次,因为你想用相同的值乘以该批次中的每个元素,但每批次都不同。

lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])

.expand_as(x) 重复奇异的尺寸,使它的大小与x相同,参见 .expand() 文件 了解详情。

您不需要展开张量,因为如果存在奇异维度,PyTorch 会自动为您展开。这就是所谓的广播。PyTorch - 广播语义. 所以,只要有一个尺寸为----------------------------------------就足够了。torch.Size([64, 1, 1, 1]) 相乘 x.

lam[index].view(batch_size, 1, 1, 1) * x

这也适用于 y 但大小 torch.Size([64, 1]),因为 y 有尺寸 torch.Size([64, 3474]).

mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]

只是一个小的附带说明。lam[index] 只是重新排列了 lam但由于你是随机创建的,所以无论你是否重新排列它都没有任何区别。唯一重要的是 xy 重新排列,就像在原始代码中一样。

© www.soinside.com 2019 - 2024. All rights reserved.