使用数据增强训练转置 CNN 时防止欠拟合和过拟合的方法

问题描述 投票:0回答:0

我正在训练一个使用来自 JSON 的输入数据构成图像的 CNN(使用 ConvTranspose2D 中的一系列 pytorch)。与自然语言不同,输入数据可以按任何顺序排列,因为它包含有关场景中各种精灵的信息。

在我第一次尝试训练模型时,我没有改变输入数据的顺序(意思是,在每个时期,每个精灵都在输入数据中的相同位置表示)。该模型学习了大约 10 个 epoch,但随后训练损失(继续下降)和测试损失之间开始出现分歧。如此经典的过拟合。

我试图通过一种数据增强的形式来解决这个问题,其中输出数据(在本例中为图像)保持不变,但我打乱了输入数据的顺序。由于我有大约 400 个精灵,最大洗牌是 400!,所以理论上这可以极大地扩展训练数据量。例如,通过打乱输入数据中精灵的顺序,而不是对应于 100K 图像的 100k JSON 文档,你有 400!*100000 训练数据点。当然,在实践中,这种数据量是不切实际的,因此我使用了大约 200 万个数据点进行初始测试。我在这里遇到的问题是该模型根本没有学习——在很快(在前几个小批量之后)出现一定的损失后,它在大约 4 个时期内根本没有学习。如此经典的欠拟合。

像Goldilocks一样,我想在最初的过拟合和随后的欠拟合之间找到“恰到好处”。我想知道我可以尝试的其他策略。我的一个想法是让模型按照预定顺序的精灵进行训练(过度拟合的情况),然后,一旦开始过度拟合(即两个连续的时期,测试和训练损失之间存在分歧),就会对数据进行洗牌。我也可以尝试更改模型,尽管由于硬件限制以及推理需要在 20 毫秒内发生这一事实,它只能变大。

在这种情况下,是否有任何论文或技术被推荐,其中数据增强可以导致更多的数据点,但会导致模型停止学习?提前感谢您的任何提示!

deep-learning pytorch overfitting-underfitting
© www.soinside.com 2019 - 2024. All rights reserved.