使用TensorFlow计算简单线性回归时,为什么会得到[nan]?

问题描述 投票:0回答:2

当我使用TensorFlow计算简单的线性回归时,得到[nan],包括:w,b和loss。

这是我的代码:

import tensorflow as tf

w = tf.Variable(tf.zeros([1]), tf.float32)
b = tf.Variable(tf.zeros([1]), tf.float32)
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

liner = w*x+b

loss = tf.reduce_sum(tf.square(liner-y))

train = tf.train.GradientDescentOptimizer(1).minimize(loss)

sess = tf.Session()

x_data = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]
y_data = [265000, 324000, 340000, 412000, 436000, 490000, 574000, 585000, 680000]                                                    

sess.run(tf.global_variables_initializer())

for i in range(1000):
    sess.run(train, {x: x_data, y: y_data})

nw, nb, nloss = sess.run([w, b, loss], {x: x_data, y: y_data})

print(nw, nb, nloss)

输出:

[ nan] [ nan] nan

Process finished with exit code 0

为什么会发生这种情况,我该如何解决?

python tensorflow machine-learning linear-regression
2个回答
1
投票

使用如此高的学习率(在你的情况下为1),你正在溢出。尝试学习率为0.001。此外,您的数据需要除以1000并且迭代次数增加,它应该工作。这是我测试过的代码并且工作得很好。

x_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y_data = [265, 324, 340, 412, 436, 490, 574, 585, 680]

plt.plot(x_data, y_data, 'ro', label='Original data')
plt.legend()
plt.show()

W = tf.Variable(tf.random_uniform([1], 0, 1))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))

optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)

for step in range(0,50000):
   sess.run(train)
   print(step, sess.run(loss))
print (step, sess.run(W), sess.run(b))

plt.plot(x_data, y_data, 'ro')
plt.plot(x_data, sess.run(W) * x_data + sess.run(b))
plt.legend()
plt.show()

1
投票

这给出了我认为的解释:

for i in range(10):
     print(sess.run([train, w, b, loss], {x: x_data, y: y_data}))

给出以下结果:

[None, array([  4.70380012e+10], dtype=float32), array([ 8212000.], dtype=float32), 2.0248419e+12] 
[None, array([ -2.68116614e+19], dtype=float32), array([ -4.23342041e+15], dtype=float32),
6.3058345e+29] 
[None, array([  1.52826476e+28], dtype=float32), array([  2.41304958e+24], dtype=float32), inf] [None, array([
-8.71110858e+36], dtype=float32), array([ -1.37543819e+33], dtype=float32), inf] 
[None, array([ inf], dtype=float32), array([ inf], dtype=float32), inf]

你的学习率太大了,所以你在每次迭代时“过度修正”w的值(看它在负面和正面之间振荡,绝对值增加)。您获得越来越高的值,直到某些东西达到无穷大,从而产生Nan值。降低(很多)学习率。

© www.soinside.com 2019 - 2024. All rights reserved.