我一直在尝试为个人项目重新创建 Dino V1 训练设置。为此,我从这个存储库中获取了大部分代码:https://github.com/facebookresearch/dino[dinov1链接]1
我几乎已经完成了它,除了 main_dino.py 文件中的一部分有一个名为 train_one_epoch 的函数,在第 318 行中他们给出了:
teacher_output= teacher (images[:2]) # only the 2 global views pass through the teacher
现在我知道 pytorch 张量索引/切片是如何工作的。因此,如果图像是结构的一批图像:
(批量大小、作物数量、c、h、w)
在调用
train_one_epoch()
之前,对模型进行了另一项修改,student
和 teacher
模型都用 MultiCropWrapper
类包装。只需看一下类的文档字符串,如下所示:
class MultiCropWrapper(nn.Module):
"""
Perform forward pass separately on each resolution input.
The inputs corresponding to a single resolution are clubbed and single
forward is run on the same resolution inputs. Hence we do several
forward passes = number of different resolutions used. We then
concatenate all the output features and run the head forward on these
concatenated features.
"""
因此这个 MultiCropWrapper 类处理前向传递,并且还提到它针对不同的分辨率执行多次前向传递。