为什么训练集越大,张量流损失会趋于无穷大?

问题描述 投票:0回答:1

我创建了一个非常简单的张量流模型,如果我有一组训练数据,它就可以工作。 但是,如果我再添加一个训练示例,那么损失就会变为无穷大,并且模型将不起作用。 即使这两个示例的模型是相同的也是如此。 唯一的区别是增加了一个训练示例。

我想制作一个大的训练集,但如果训练集太大,损失会发散,这似乎是不可能的。 当损失减少到很低的时候,预测也是完全错误的。 在下面的代码中,模型有 20 个训练样本,损失趋于无穷大。 Model2 有 19 个训练样本,损失函数接近于零。

<pre>    
    import tensorflow as tf
    import numpy as np
    from tensorflow import keras

    print(tf.__version__)``

    def hw_function(x):
        y = (2. * x) - 1.
        return y

    # Build a simple Sequential model
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(units=1)])

    # Compile the model
    model.compile(optimizer='sgd', loss='mean_squared_error')

    # Declare model inputs and outputs for training
    xs=[x for x in range(-1, 19, 1)]
    ys=[x for x in range(-3, 36, 2)]

    xs=np.array(xs, dtype=float)
    ys=np.array(ys, dtype=float)

    # Train the model
    model.fit(xs, ys, verbose=1, epochs=500)

    # Make a prediction
    p = np.array([100.0, 900.0], dtype=float)
    print(model.predict(p))


    # Build exactly the same model but have one more training example
    model2 = tf.keras.Sequential([
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(units=1)])
    model2.compile(optimizer='sgd', loss='mean_squared_error')
    xs2=[x for x in range(-1, 18, 1)]
    ys2=[x for x in range(-3, 34, 2)]

    xs2=np.array(xs2, dtype=float)
    ys2=np.array(ys2, dtype=float)

    # Train the model
    model2.fit(xs2, ys2, verbose=1, epochs=500)
    p = np.array([100.0, 900.0], dtype=float)
    print(model2.predict(p))
<code>
tensorflow loss
1个回答
0
投票

这与示例的数量无关,而是与数据的数值大小有关。例如,您可以扩展“下面”的数据集,它仍然有效:

xs=[x for x in range(-2, 18, 1)]
ys=[x for x in range(-5, 34, 2)]

您可能找到了优化开始变得不稳定的确切数值阈值。您可以通过降低学习率来修复它。对于“更大”的数据集,这会失败:

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt, loss='mean_squared_error')

这很好用:

opt = tf.keras.optimizers.SGD(learning_rate=0.001)
model.compile(optimizer=opt, loss='mean_squared_error')

至于预测“完全错误”,我无法重现。它们略有偏差,但那是因为参数不完全匹配。例如。在一个示例中,我得到

1.996 * x - 0.9469
而不是
2 * x - 1
,这相当接近,但对于 100 或 900 这样的大输入,差异会更大。

© www.soinside.com 2019 - 2024. All rights reserved.