正如下面的代码,我试图获得模型的47次重复输出的平均值。但它总是内存不足。如果我删除z_proto_class_list.append(z_proto_class)
,那就没关系了。我想这是因为如果我不附加张量,则释放内存。我总是试图一次生成47输出,但它显然比我目前的选择更多的内存消耗。有没有办法解决我目前的问题?谢谢。
z_proto_class_list = []
for support_input_ids, support_input_mask, support_segment_ids in dataloader:
s_z, s_pooled_output = model(support_input_ids, support_input_mask, support_segment_ids, output_all_encoded_layers=False)
sz_dim = s_z.size(-1)
index = torch.LongTensor(support_idx_list).unsqueeze(1).unsqueeze(2).expand(len(support_idx_list),1,sz_dim).cuda()
z_proto_raw = torch.gather(s_z,1,index)
z_proto_class = z_proto_raw.view(1,n_support, sz_dim).mean(1)
z_proto_class_list.append(z_proto_class)
torch.cuda.empty_cache()
z_proto = torch.cat(z_proto_class_list, 0)
似乎z_proto_class_list.append(z_proto_class)
存储了整个计算图,因此内存不会自动释放。我使用z_proto_class_list.append(z_proto_class.detach())
并解决了这个问题。但问题是这不适合我原来的实现,因为我希望在给定类实例的质心的情况下更新模型的参数。