为什么 moco 键编码器反向传播很棘手

问题描述 投票:0回答:1

MoCo的原论文中,是这么说的:

使用队列可以使字典变大,但也使得通过反向传播更新密钥编码器变得困难(梯度应该传播到队列中的所有样本)。

首先我认为bp不能暗示密钥编码器的主要原因是队列操作是不可微的。但这似乎不是真的。您可以计算队列中所有样本的梯度,然后 bp 应该正确执行。请参阅底部的代码。

那么BP难以处理按键编码器的真正原因是什么?在我看来,我认为可能是因为队列(字典)太大,导致内存爆炸。

q = nn.Linear(768,128)
k = nn.Linear(768,128)
bs = 64
ks = 4095
model = nn.ModuleList([q,k])
x = torch.randn(bs, 768)
optim = torch.optim.SGD(model.parameters(),lr=0.01)
loss = nn.CrossEntropyLoss()
def forward(x):
    xq = q(x)
    xk = k(x + 0.1)
    que = torch.rand(ks,128)
    pos = torch.einsum("nc,nc->n",xq,xk)
    neg = torch.einsum("nc,kc->nk",xq,que)
    out = torch.cat([pos.unsqueeze(-1),neg],dim=1)
    t = torch.zeros(out.shape[0],dtype=torch.long)
    l = loss(out,t)
    return l
loss = forward(x)
loss.backward()
optim.step()
deep-learning pytorch self-supervised-learning
1个回答
0
投票

我认为你是对的,因为关键原因与队列的大小有关,这可能会使内存需求爆炸。

在 MoCo 框架中,您维护来自大量历史数据样本的编码密钥表示的队列。当在密钥编码器上执行反向传播时,您需要计算队列中所有样本的梯度。存储这些梯度的内存要求可能会变得非常高,尤其是当队列很大时。

因此导致采用了动量更新策略来修改密钥编码器的参数。

# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params

MoCo论文:https://arxiv.org/pdf/1911.05722.pdf

© www.soinside.com 2019 - 2024. All rights reserved.