LayerNorm in PyTorch


Consider the following example:

    import torch

    batch, sentence_length, embedding_dim = 2, 3, 4
    embedding = torch.randn(batch, sentence_length, embedding_dim)
    print(embedding)
    
    # Output:
    tensor([[[-2.1918,  1.2574, -0.3838,  1.3870],
             [-0.4043,  1.2972, -1.7326,  0.4047],
             [ 0.4560,  0.6482,  1.0858,  2.2086]],
    
            [[-1.4964,  0.3722, -0.7766,  0.3062],
             [ 0.9812,  0.1709, -0.9177, -1.2558],
             [-1.1560, -0.0367,  0.5496, -1.1142]]])

Applying LayerNorm, which normalizes across the embedding dimension, I get:

    layer_norm = torch.nn.LayerNorm(embedding_dim)
    layer_norm(embedding)

    # Output:
    tensor([[[-1.5194,  0.8530, -0.2758,  0.9422],
             [-0.2653,  1.2620, -1.4576,  0.4609],
             [-0.9470, -0.6641, -0.0204,  1.6315]],

            [[-1.4058,  0.9872, -0.4840,  0.9026],
             [ 1.3933,  0.4803, -0.7463, -1.1273],
             [-0.9869,  0.5545,  1.3619, -0.9294]]],
           grad_fn=<NativeLayerNormBackward0>)

Now, when I normalize the first vector of the embedding tensor above with a naive Python implementation, I get:

    import math
    import statistics

    a = [-2.1918,  1.2574, -0.3838,  1.3870]
    mean_a = statistics.mean(a)
    var_a = statistics.stdev(a)
    eps = 1e-5
    d = [(i - mean_a) / math.sqrt(var_a + eps) for i in a]
    print(d)
    
    # Output:
    [-1.7048934056508998, 0.9571791768620398, -0.3094894774404756, 1.0572037062293356]

The normalized values are not the same as what I get from PyTorch's LayerNorm. Is there something wrong with how I am computing the LayerNorm?

machine-learning deep-learning pytorch nlp attention-model
1 Answer

You want the variance, not the standard deviation (the standard deviation is the square root of the variance, and you take another square root of it when computing d). Also, LayerNorm uses the biased variance (statistics.pvariance). To reproduce the expected result using the statistics module, you would use:

    import math
    import statistics

    a = [-2.1918,  1.2574, -0.3838,  1.3870]
    mean_a = statistics.mean(a)
    var_a = statistics.pvariance(a)  # biased (population) variance, as LayerNorm uses
    eps = 1e-5
    d = [(i - mean_a) / math.sqrt(var_a + eps) for i in a]
    print(d)

    # Output:
    [-1.519391435327454, 0.8530327107709863, -0.2758152854532861, 0.942174010009754]
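
As a quick sanity check (a hypothetical snippet, assuming the embedding and layer_norm objects from the question are still in scope), the list d matches the first normalized vector of PyTorch's output up to float32 precision:

    # Compare the hand-computed list against PyTorch's first normalized vector.
    # PyTorch works in float32, so allow a small absolute tolerance.
    torch_row = layer_norm(embedding)[0, 0].tolist()
    print(all(math.isclose(x, y, abs_tol=1e-4) for x, y in zip(d, torch_row)))  # True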

Another way to verify that the result is correct:

    [[torch.mean(i).item(), torch.var(i, unbiased=False).item()] for i in layer_norm(embedding)]

    # Output:
    [[1.9868215517249155e-08, 0.9999885559082031],
     [-1.9868215517249155e-08, 0.9999839663505554]]

This shows that the mean and variance of the normalized embeddings are (very close to) 0 and 1, as expected.
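
The whole tensor can also be reproduced with plain tensor operations. This is a minimal sketch, assuming the default affine parameters of LayerNorm (weight = 1, bias = 0), which leave the normalized values unchanged:

    # Manual LayerNorm over the last dimension: biased variance, eps inside the sqrt.
    eps = 1e-5
    mean = embedding.mean(dim=-1, keepdim=True)
    var = embedding.var(dim=-1, unbiased=False, keepdim=True)
    manual = (embedding - mean) / torch.sqrt(var + eps)
    print(torch.allclose(manual, layer_norm(embedding), atol=1e-5))  # True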

Relevant documentation: "The standard-deviation is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False)."
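
To make the biased/unbiased distinction concrete, here is a small self-contained check showing that statistics.pvariance corresponds to torch.var(..., unbiased=False) (divide by n), while statistics.variance corresponds to the unbiased default (divide by n - 1):

    import statistics
    import torch

    a = [-2.1918, 1.2574, -0.3838, 1.3870]
    t = torch.tensor(a)
    print(statistics.pvariance(a), torch.var(t, unbiased=False).item())  # ~2.1137 (biased)
    print(statistics.variance(a), torch.var(t, unbiased=True).item())    # ~2.8183 (unbiased)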
