为什么 timm 视觉变压器位置嵌入初始化为零?

问题描述 投票:0回答:2

我正在查看视觉转换器的

timm
实现,对于位置嵌入,他正在用零初始化他的位置嵌入,如下所示:

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

看这里: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L309

我不确定当稍后将其添加到补丁中时,这实际上如何嵌入有关该位置的任何内容?

x = x + self.pos_embed

如有任何反馈,我们将不胜感激。

pytorch vision transformer-model
2个回答
1
投票

位置嵌入是包含在计算图中并在训练期间更新的参数。因此,是否用零进行初始化并不重要;它们是在培训过程中学到的。


0
投票

如果您现在检查存储库,它确实是使用从正态分布中采样的权重进行初始化的。使用零初始化权重然后在

init_weights
函数内重新初始化是很常见的。您查看的文件版本可能就是这种情况。

当使用可学习的位置嵌入时,它们确实应该随机初始化。这是为了帮助网络学习令牌的位置信息。我以前没有见过使用零初始化,这充其量可能会影响模型收敛。

© www.soinside.com 2019 - 2024. All rights reserved.