keras中train_on_batch()有什么用?

问题描述 投票:0回答:6

train_on_batch()
fit()
有什么不同?什么情况下我们应该使用
train_on_batch()

machine-learning deep-learning keras
6个回答
57
投票

对于这个问题,这是主要作者的简单回答

使用

fit_generator
,您可以使用生成器来生成验证数据,如下所示 出色地。一般来说,我建议使用
fit_generator
,但是使用
train_on_batch
也很好用。这些方法的存在只是为了 不同用例的便利性,没有“正确”的方法。

train_on_batch
允许您根据您提供的样本集合明确更新权重,而不考虑任何固定的批量大小。您可以在您想要的情况下使用它:在显式样本集合上进行训练。您可以使用该方法在传统训练集的多个批次上维护自己的迭代,但允许
fit
fit_generator
为您迭代批次可能更简单。

最好使用

train_on_batch
的一种情况是在一批新样本上更新预训练模型。假设您已经训练并部署了一个模型,稍后您收到了一组以前从未使用过的新训练样本。您可以使用
train_on_batch
仅在这些样本上直接更新现有模型。其他方法也可以做到这一点,但在这种情况下使用
train_on_batch
是相当明确的。

除了像这样的特殊情况(无论是出于教学原因在不同的训练批次中维护自己的光标,还是在特殊批次上进行某种类型的半在线训练更新),最好始终使用

fit
(适用于适合内存的数据)或
fit_generator
(适用于将批量数据作为生成器进行流式传输)。


22
投票

train_on_batch()
使您可以更好地控制 LSTM 的状态,例如,当使用有状态 LSTM 并需要控制对
model.reset_states()
的调用时。您可能有多个系列数据,并且需要在每个系列之后重置状态,您可以使用
train_on_batch()
来完成此操作,但如果您使用
.fit()
那么网络将在所有系列数据上进行训练,而无需重置状态。没有对错之分,这取决于您使用的数据以及您希望网络如何运行。


4
投票

如果您使用大型数据集并且没有易于序列化的数据(如高阶 numpy 数组)来写入 tfrecords,那么与 fit 和 fit 生成器相比,Train_on_batch 还将获得性能提升。

在这种情况下,当整个集合无法容纳在内存中时,您可以将数组保存为 numpy 文件并在内存中加载它们的较小子集(traina.npy、trainb.npy 等)。然后,您可以使用 tf.data.Dataset.from_tensor_slices ,然后将 train_on_batch 与您的子数据集一起使用,然后加载另一个数据集并再次调用批量训练,等等,现在您已经对整个集合进行了训练,并且可以准确控制多少和什么您的数据集训练您的模型。然后,您可以使用简单的循环和函数来定义自己的纪元、批量大小等,以从数据集中获取。


2
投票

来自Keras - 模型训练API

  • fit:训练模型固定数量的 epoch(数据集上的迭代)。
  • train_on_batch:在单批次数据上运行单个梯度更新

当我们一次使用一批训练数据集更新鉴别器和生成器时,我们可以在 GAN 中使用它。我看到 Jason Brownlee 在他的教程中使用了 train_on_batch (How to Develop a 1D Generative Adversarial Network From Scratch in Keras)

快速搜索提示:键入 Control+F 并在搜索框中键入您要搜索的术语(例如 train_on_batch)。


1
投票

确实@nbro 的回答有帮助,只是添加了一些场景,假设您正在训练一些 seq to seq 模型或具有一个或多个编码器的大型网络。我们可以使用 train_on_batch 创建自定义训练循环,并使用部分数据直接在编码器上进行验证,而无需使用回调。为复杂的验证过程编写回调可能很困难。有几种情况我们希望批量训练。

问候, 卡西克


0
投票

我建议阅读文章。了解 Keras 中 3 种训练方法之间的差异非常有用。

  • fit
  • fit_generator
  • train_on_batch
© www.soinside.com 2019 - 2024. All rights reserved.