Scikit-learn: my linear regression is not a straight line, it is messy

Question  Votes: 0  Answers: 2

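I am simply trying to plot a regression line, but I get a tangle of lines instead. Is it because I fit the model with 2 features, so that the only suitable visualization is a 3D plane?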

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
from sklearn.linear_model import LinearRegression

# prepare data
boston = load_boston()
X = pd.DataFrame(boston.data, columns=boston.feature_names)[['AGE','RM']]
y = boston.target

# split dataset into training and test data
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=33)

# apply linear regression on dataset
lm = LinearRegression()
lm.fit(X_train, y_train)
pred_train = lm.predict(X_train)
pred_test = lm.predict(X_test)

#plot relationship between RM and price
plt.scatter(X_train['RM'],
            y_train,
            c='g',
            s=40,
            alpha=0.5)
plt.plot(X_train['RM'], pred_train, color='r')
plt.title('Relationship between RM and Price')
plt.ylabel('Price')
plt.xlabel('RM')


python machine-learning scikit-learn linear-regression
2 Answers
1
vote

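You are right. You are training on more than one feature, namely AGE and RM, but you are drawing a two-dimensional plot against only one of them, RM. Because the predictions also depend on AGE, they do not fall on a single line when plotted against RM alone. Try a 3D plot instead. In general, linear regression with two features produces a plane. It is still linear regression, which is why we use the term "hyperplane": it reduces to a line for one feature, a plane for two features, and so on.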
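If a straight 2D line against RM is still wanted, one option is to slice the fitted plane by holding AGE at a fixed value. The following is a minimal sketch of that idea, not part of the original answer. It uses synthetic stand-in data (an assumption, since `load_boston` is no longer available in current scikit-learn releases), and the names `rm_grid`, `age_fixed` and `line` are illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic stand-in data (assumption): two features playing the roles of AGE and RM.
rng = np.random.default_rng(0)
rm = rng.uniform(4.0, 9.0, size=200)
age = rng.uniform(0.0, 100.0, size=200)
X = np.column_stack([age, rm])  # same column order as the question: AGE, RM
y = 9.0 * rm - 0.07 * age - 30.0 + rng.normal(0.0, 2.0, size=200)

lm = LinearRegression().fit(X, y)

# Slice the fitted plane: vary RM over a grid while holding AGE at its mean.
rm_grid = np.linspace(rm.min(), rm.max(), 50)
age_fixed = np.full_like(rm_grid, age.mean())
line = lm.predict(np.column_stack([age_fixed, rm_grid]))

# Along this slice the prediction changes by lm.coef_[1] per unit of RM,
# so the slice is a straight line.
slopes = np.diff(line) / np.diff(rm_grid)
print(np.allclose(slopes, lm.coef_[1]))
```

Passing `rm_grid` and `line` to `plt.plot` should then draw a single straight line over the scatter. Note that it shows the model's prediction at the average AGE only, not at each point's own AGE.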

Here is the 3D output:

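# passing projection to gca() is no longer supported in recent Matplotlib releases;
# add_subplot(projection='3d') creates the 3D axes instead
plt3d = plt.figure().add_subplot(projection='3d')
plt3d.view_init(azim=135)
plt3d.plot_trisurf(X_train['RM'].values, X_train['AGE'].values, pred_train, alpha=0.7, antialiased=True)
plt.show()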
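As an alternative to `plot_trisurf`, the fitted plane can be evaluated on a regular grid and drawn with `plot_surface`, with the data points scattered on top. This is a hedged sketch on synthetic stand-in data (an assumption, not the Boston data set), and all variable names are illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Synthetic stand-in data (assumption): AGE and RM-like features with a linear target.
rng = np.random.default_rng(1)
rm = rng.uniform(4.0, 9.0, size=150)
age = rng.uniform(0.0, 100.0, size=150)
y = 9.0 * rm - 0.07 * age - 30.0 + rng.normal(0.0, 2.0, size=150)
lm = LinearRegression().fit(np.column_stack([age, rm]), y)

# Evaluate the model on a regular (RM, AGE) grid; columns are ordered AGE, RM as in fit().
rm_g, age_g = np.meshgrid(np.linspace(4.0, 9.0, 20), np.linspace(0.0, 100.0, 20))
grid = np.column_stack([age_g.ravel(), rm_g.ravel()])
z = lm.predict(grid).reshape(rm_g.shape)

# Draw the observations and the fitted plane in one 3D axes.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(rm, age, y, c='g', alpha=0.5)
ax.plot_surface(rm_g, age_g, z, color='r', alpha=0.4)
ax.set_xlabel('RM')
ax.set_ylabel('AGE')
ax.set_zlabel('Price')
```

Calling `plt.show()` afterwards displays the figure. The surface is flat because the model is linear in both features.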



0
votes

The problem is that the arguments must be sorted before plotting:

plt.plot(np.sort(X_train['RM']), np.sort(pred_train), color='r')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
from sklearn.linear_model import LinearRegression

# prepare data
boston = load_boston()
X = pd.DataFrame(boston.data, columns=boston.feature_names)[['AGE','RM']]
y = boston.target

# split dataset into training and test data
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=33)

# apply linear regression on dataset
lm = LinearRegression()
lm.fit(X_train, y_train)
pred_train = lm.predict(X_train)
pred_test = lm.predict(X_test)

#plot relationship between RM and price
plt.scatter(X_train['RM'],
            y_train,
            c='g',
            s=40,
            alpha=0.5)
plt.plot(np.sort(X_train['RM']), np.sort(pred_train), color='r')
plt.title('Relationship between RM and Price')
plt.ylabel('Price')
plt.xlabel('RM')
plt.show()

Result: [output plot]
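One caveat about the sorting step, offered as a sketch with made-up numbers rather than as a correction from the original poster: `np.sort` applied to each array separately reorders the x values and the predictions independently, so a plotted point no longer pairs an RM value with its own prediction. Sorting both arrays by a single `np.argsort` permutation keeps the pairs intact.

```python
import numpy as np

# Hypothetical toy values: RM, and predictions that also depend on a second feature.
rm = np.array([6.0, 5.0, 7.0, 5.5])
pred = np.array([20.0, 22.0, 21.0, 15.0])

# Independent sorts: each array is reordered on its own, so pairs get mixed up.
x_indep, y_indep = np.sort(rm), np.sort(pred)

# Paired sort: one permutation, taken from RM, applied to both arrays.
order = np.argsort(rm)
x_paired, y_paired = rm[order], pred[order]

original = set(zip(rm, pred))
print(set(zip(x_paired, y_paired)) == original)  # True: pairs preserved
print(set(zip(x_indep, y_indep)) == original)    # False: pairs broken
```

With two features, even the paired version will generally still zigzag when plotted against RM alone, because the predictions also vary with AGE. The independently sorted curve looks smooth only because it is forced to be monotone.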

If you make a 3D plot, you can easily see the relationship between the covariates RM and AGE: [3D plot]
