当 PyTorch Lightning 中 `limit_train_batches = n` 时,训练期间使用哪些批次

问题描述 投票:0回答:1

我对使用 PyTorch Lightning 时 Trainer 类中

limit_train_batches
的行为有疑问,非常感谢您的帮助。

我正在尝试训练一些类似 Unet 的模型来对我的数据进行去噪。然而,我的训练数据集非常大,我需要很多时间来训练所有这些模型。为了加快速度,我想限制一个时期内训练模型的样本数量。例如,如果我的整个训练数据集由 1000 个样本组成,并且我使用的批量大小为 50,则需要对我的训练数据进行 20 次传递才能“看到”所有数据。由于时间限制,我想将样本限制为 200 个。假设我仍然使用 50 个批量大小,我现在只需要对我的模型的训练数据进行 4 次传递,即可看到所有 200 个样本某个时代。我不想从原始数据集中创建 200 个样本的单个子集,而是希望确保模型在每个时期看到 200 个样本的不同子集,以确保训练数据中有足够的可变性。

如果我设置

limit_train_batch = 4
,Trainer 类最终会使用数据加载器创建的前 4 个批次吗?

这是一个小例子来说明我的意思。

假设我的训练数据由 20 个 0-10 之间的随机数组成

from torch.utils.data import DataLoader

toy_training_data = [6 3 7 4 6 9 2 6 7 4 3 7 7 2 5 4 1 7 5 1]
dataloader = DataLoader(toy_training_data, batch_size=4, shuffle=True, drop_last=True)

for epoch in range(2):
    for batch in dataloader:
        print(f"Epoch: {epoch}, batch: {batch}")
    print("End of epoch", epoch)

当我在每个纪元后对数据进行洗牌时,数据加载器的输出如下:

Epoch: 0, batch: tensor([7, 5, 7, 1])
Epoch: 0, batch: tensor([3, 4, 5, 4])
Epoch: 0, batch: tensor([4, 7, 6, 6])
Epoch: 0, batch: tensor([3, 6, 2, 1])
Epoch: 0, batch: tensor([7, 2, 7, 9])
End of epoch 0
Epoch: 1, batch: tensor([5, 7, 2, 7])
Epoch: 1, batch: tensor([4, 6, 6, 5])
Epoch: 1, batch: tensor([4, 4, 6, 7])
Epoch: 1, batch: tensor([7, 7, 1, 2])
Epoch: 1, batch: tensor([3, 1, 3, 9])
End of epoch 1

如果我要设置

limit_train_batches = 2
,我是否会期望它在每个时期仅使用前 2 个批次,因此
batch: tensor([7, 5, 7, 1], batch: tensor([3, 4, 5, 4]
代表纪元 0,然后
batch: tensor([5, 7, 2, 7], batch: tensor([4, 6, 6, 5]
代表纪元 1,依此类推。

假设这确实是

limit_train_batches
的行为,您会说这是一个很好的策略,可以确保训练数据的可变性,同时仍然控制训练时间。如果没有,我也很想听听你们对如何更有效地解决这个问题的意见。

非常感谢您花时间阅读我的帖子,我期待听到您的答案!

python deep-learning pytorch pytorch-lightning pytorch-dataloader
1个回答
0
投票

不幸的是,我没有足够的信用来发表评论,但我面临着同样的问题。 我使用的是闪电版本 1.7.7。 我深入研究了闪电网络的 trainer.py 文件,寻找

limit_train_batches
,但在调用顺序中找不到任何有用的东西。我也非常感谢任何帮助。

© www.soinside.com 2019 - 2024. All rights reserved.