我有一个有序的数据集(shuffle=False),被分类为“bins”。我将提供一个较小规模的示例来帮助澄清。假设数据集的大小为 60,箱大小为 10、20、30。我想按照垃圾箱的顺序训练我的模型。 (首先是 10,然后是 20 和 30)。我希望我的 DataLoader 获取批量大小为 8 的数据。在这种情况下,在获取前 8 个
datapoints
后,我不想从 bin-1 中获取剩余的 2 个数据,并从下一个数据中获取 6 个数据。我想要的是只得到 2,并在下一次迭代中,从 bin-2 中得到 8。简而言之,我想先完成一个容器中的训练,然后再转移到另一个容器中。另外,如果batch_size恰好大于bin大小,我想在移动到下一个之前仅在一个bin中获取数据。
我可以得到一些关于如何做到这一点的建议吗?我可以想到两种方法:实现一个自定义的 DataLoader(也需要这方面的建议)或者只是为每个 bin 创建单独的 DataLoader,并在最外层循环中迭代 bin 时,获取相应的 DataLoader 并进行训练。后一种方法会有一些严重的缺点吗?
我可以想到两种方法:实现一个自定义的 DataLoader(也需要这方面的建议)或者为每个 bin 创建单独的 DataLoader,并在最外层循环中迭代 bin 时,获取相应的 DataLoader 并进行训练。后一种方法会有一些严重的缺点吗?
第二个选项肯定更容易,我能看到的唯一问题是如果你使用多个工人ie如果你使用多处理:
将会发生的是,每个工作人员都会产生自己的工作进程。因此,您最终会得到一堆进程,其中大多数处于空闲状态(来自您未使用的数据加载器的进程)。如果您有很多这样的数据加载器,那么可能会出现性能问题。
如果您不打算使用多处理,我可能会选择
第一个选项会更难:这是数据加载器的代码,您可能想要创建它的子类,但它变得非常技术性
还有第三个选项,它更简单,尽管使用起来有点老套:*让你的
Dataset
类处理批处理,而不是使用Dataloader
的整理功能:
由于您的数据集是有序的,只需使其批量迭代,并在数据加载器中设置 collate=None
。
最后第四个选项,最干净的选项是使用火炬数据加载器的
batch_sampler
功能请参阅文档。只需构建一个采样器来生成批次的索引:这是一个遵循您的示例的版本:
from torch.utils.data import Sampler
from typing import List
from copy import copy
class BinSampler(torch.utils.data.Sampler[List[int]]):
bin_sizes: List[int]
batch_size: int
def __init__(self, bin_sizes: List[int], batch_size: int):
self.bin_sizes = bin_sizes
self.batch_size = batch_size
def __len__(self):
return sum(self.bin_sizes)
def __iter__(self):
bin_sizes = copy(self.bin_sizes)
remaining_in_bin = bin_sizes.pop()
current_index = 0
while bin_sizes or remaining_in_bin:
if remaining_in_bin == 0:
remaining_in_bin = bin_sizes.pop()
if remaining_in_bin < self.batch_size:
n_to_yield = self.batch_size
else: n_to_yield = remaining_in_bin
next_index = current_index + n_to_yield
yield range(current_index, next_index)
remaining_in_bin -= n_to_yield
current_index = next_index
dataloader = torch.utils.data.DataLoader(
dataset = torch.utils.data.TensorDataset(torch.randn(1000, 10)),
batch_sampler = BinSampler([10, 20, 30,], batch_size=8)
)