PyTorch,将用户指定的参数传递给DataLoader

问题描述 投票:0回答:1

新年快乐,

[我正在使用U-Net并实施2015年论文中所述的加权技术(U-Net:生物医学卷积网络图像分割)和2019年(U-Net –用于细胞计数,检测和形态计量学的深度学习)。在该技术中,存在方差σ和权重w_0。我希望,尤其是σ,成为可学习的参数,而不是猜测每个数据集之间哪个值最好。

  1. 根据搜索到的内容,可以使用nn.Parameter。
  2. 要在各个时期之间使用学习到的σ,我需要通过某种方式将这个新值通过DataLoader传递给DataSet的get_item函数。

我目前的观点是扩展torch.utils.data.DataLoader,其中新的init具有一个接受用户指定/可学习的参数的附加参数。给定torch.utils.data.DataLoader的源代码,我不了解DataLoader在何处以及如何调用DataSet实例,因此无法传递这些参数。

在代码方面,在DataSet定义中有函数

def __getitem__(self, index):

我可以更改为

def __getitem__(self, index, sigma):

并利用更新的,新近学习的σ。

我的问题是,在训练期间,我将训练数据集迭代为:

for epoch in range( checkpoint[ 'epoch'], num_epochs):
....
    for ii, ( X, y, y_weight, fname) in enumerate( dataLoader[ phase]):

在该DataLoader的枚举中,如何将新的σ传递给DataLoader,以便DataLoader将其传递给上述的DataSet getitem函数?

有什么想法或其他方法吗?

谢谢

parameter-passing pytorch unity3d-unet dataloader
1个回答
1
投票

您需要做的是将sigma设置为数据集的属性,并在各个纪元之间进行更改。

用于数据集定义

© www.soinside.com 2019 - 2024. All rights reserved.