在MoCo的原论文中,是这么说的:
使用队列可以使字典变大,但也使得通过反向传播更新密钥编码器变得困难(梯度应该传播到队列中的所有样本)。
首先我认为bp不能暗示密钥编码器的主要原因是队列操作是不可微的。但这似乎不是真的。您可以计算队列中所有样本的梯度,然后 bp 应该正确执行。请参阅底部的代码。
那么BP难以处理按键编码器的真正原因是什么?在我看来,我认为可能是因为队列(字典)太大,导致内存爆炸。
q = nn.Linear(768,128)
k = nn.Linear(768,128)
bs = 64
ks = 4095
model = nn.ModuleList([q,k])
x = torch.randn(bs, 768)
optim = torch.optim.SGD(model.parameters(),lr=0.01)
loss = nn.CrossEntropyLoss()
def forward(x):
xq = q(x)
xk = k(x + 0.1)
que = torch.rand(ks,128)
pos = torch.einsum("nc,nc->n",xq,xk)
neg = torch.einsum("nc,kc->nk",xq,que)
out = torch.cat([pos.unsqueeze(-1),neg],dim=1)
t = torch.zeros(out.shape[0],dtype=torch.long)
l = loss(out,t)
return l
loss = forward(x)
loss.backward()
optim.step()
我认为你是对的,因为关键原因与队列的大小有关,这可能会使内存需求爆炸。
在 MoCo 框架中,您维护来自大量历史数据样本的编码密钥表示的队列。当在密钥编码器上执行反向传播时,您需要计算队列中所有样本的梯度。存储这些梯度的内存要求可能会变得非常高,尤其是当队列很大时。
因此导致采用了动量更新策略来修改密钥编码器的参数。
# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params