这是我的代码
def train_dataloader(self):
if self._is_weighted_sampler:
weights = list(self.label_weight_by_name.values())
sampler = torch.utils.data.sampler.WeightedRandomSampler(
torch.tensor(weights), len(weights)
)
else:
sampler = torch.utils.data.RandomSampler(self._train_dataset)
return DataLoader(self._train_dataset, batch_size=self._batch_size, shuffle=True, sampler=sampler)
请注意,在加权采样器的情况下,它不需要数据集,但 RandomSampler 需要。
在RandomSampler的情况下,这意味着数据集被传递了两次。
我一定遗漏了一些关于如何使用它的信息,请纠正我。
实际上,看起来你对这种差异的看法是正确的;似乎没有一个明显的原因说明为什么一个调用需要数据集对象而另一个调用不需要。根据docs,函数原型是:
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)
深入研究源代码,我们会发现
data_source
从未被 Random_Sampler
索引,仅用作 len(data_source)
。该对象产生索引,而数据集对象仅用于确定数据的长度。
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
从
weights
中随机采样权重(即应设置为与数据集具有相同数量的元素)并返回一组索引,然后必须单独使用这些索引来索引数据集对象。
开发人员的基本原理可能是“在
WeightedRandomSampler
的情况下,用户必须为数据源中的每个项目定义一个权重。在 RandomSampler
的情况下,所有权重都是 1,所以而不是定义一个向量,用户可以简单地传递数据集对象本身。”为什么他们不只是将数据集的长度作为整数传递,这超出了我的理解。