我想写一个混合训练的函数。在这个 场地 我找到了一些代码,并对我以前的代码进行了修改。但在原来的代码中,只有一个随机变量生成的批次(64)。但我想为批次中的每张图片生成随机值。在代码中,每批图片只有一个变量。
def mixup_data(x, y, alpha=1.0):
lam = np.random.beta(alpha, alpha)
batch_size = x.size()[0]
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index,:]
mixed_y = lam * y + (1 - lam) * y[index,:]
return mixed_x, mixed_y
输入的x和y来自pytorch DataLoader.x输入大小。torch.Size([64, 3, 256, 256])
y输入大小: torch.Size([64, 3474])
这段代码很好用 然后我把它改成这样。
def mixup_data(x, y):
batch_size = x.size()[0]
lam = torch.rand(batch_size)
index = torch.randperm(batch_size)
mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]
return mixed_x, mixed_y
但它给出了一个错误。RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3
我对代码工作原理的理解是,它把批处理中的第一个图像乘以 lam
张量(64值长)。我怎样才能做到这一点?
你需要替换下面一行。
lam = torch.rand(batch_size)
由
lam = torch.rand(batch_size, 1, 1, 1)
用你现在的代码。lam[index] * x
乘法是不可能的,因为 lam[index]
是大小 torch.Size([64])
而 x
是大小 torch.Size([64, 3, 256, 256])
. 所以,你需要将 lam[index]
作为 torch.Size([64, 1, 1, 1])
以便成为可广播的。
为了应对下面的语句。
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
我们可以重塑 lam
语句前的张量。
lam = lam.reshape(batch_size, 1)
mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
问题是,两个相乘的张量的大小不一致。让我们来看看 lam[index] * x
为例。其尺寸为:
x
: torch.Size([64, 3, 256, 256])
lam[index]
: torch.Size([64])
为了将它们相乘,它们应该具有相同的大小,其中 lam[index]
使用相同的值 [3, 256, 256]
每批次,因为你想用相同的值乘以该批次中的每个元素,但每批次都不同。
lam[index].view(batch_size, 1, 1, 1).expand_as(x)
# => Size: torch.Size([64, 3, 256, 256])
.expand_as(x)
重复奇异的尺寸,使它的大小与x相同,参见 .expand()
文件 了解详情。
您不需要展开张量,因为如果存在奇异维度,PyTorch 会自动为您展开。这就是所谓的广播。PyTorch - 广播语义. 所以,只要有一个尺寸为----------------------------------------就足够了。torch.Size([64, 1, 1, 1])
相乘 x
.
lam[index].view(batch_size, 1, 1, 1) * x
这也适用于 y
但大小 torch.Size([64, 1])
,因为 y
有尺寸 torch.Size([64, 3474])
.
mixed_x = lam[index].view(batch_size, 1, 1, 1) * x + (1 - lam[index]).view(batch_size, 1, 1, 1) * x[index, :]
mixed_y = lam[index].view(batch_size, 1) * y + (1 - lam[index]).view(batch_size, 1) * y[index, :]
只是一个小的附带说明。lam[index]
只是重新排列了 lam
但由于你是随机创建的,所以无论你是否重新排列它都没有任何区别。唯一重要的是 x
和 y
重新排列,就像在原始代码中一样。