Creating baseline regression models using the mean and minimum in Python


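I want to compare the results of a regression analysis on encoded categorical variables against two baseline models, where the baseline prediction is either the group mean or the group minimum. I chose R-squared and MAE for the comparison. Below is a simplified example of my code for illustration. It works in the sense that it produces the output I am after. Is this a correct and/or optimal approach?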

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split 
from sklearn import metrics

df = pd.DataFrame([['a1','c1',10],
                   ['a1','c2',15],
                   ['a1','c3',20],
                   ['a1','c1',15],
                   ['a2','c2',20],
                   ['a2','c3',15],
                   ['a2','c1',20],
                   ['a2','c2',15],
                   ['a3','c3',20],
                   ['a3','c3',15],
                   ['a3','c3',15],
                   ['a3','c3',20]], columns=['aid','cid','T'])

df_dummies = pd.get_dummies(df, columns=['aid','cid'],prefix_sep='',prefix='')
df_dummies

X = df_dummies
y = df_dummies['T']

# train test split 80-20
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

regr = LinearRegression()
regr.fit(X_train, y_train)

y_pred = regr.predict(X_test)
print('R-squared:', metrics.r2_score(y_test, y_pred))
print('MAE:', metrics.mean_absolute_error(y_test, y_pred))

# Baseline model with group average as prediction
y_pred = df.groupby('aid').agg({'T': ['mean']})
print('R-squared:', metrics.r2_score(y_test, y_pred))
print('MAE:', metrics.mean_absolute_error(y_test, y_pred))

# Baseline model with group min as prediction
y_pred = df.groupby('aid').agg({'T': ['min']})
print('R-squared:', metrics.r2_score(y_test, y_pred))
print('MAE:', metrics.mean_absolute_error(y_test, y_pred))

Tags: python-3.x scikit-learn regression pandas-groupby linear-regression
1 Answer

First, I would give y_pred a different name each time to avoid confusion.

In general:

y_pred = df.groupby('aid').agg({'T': ['mean']})

will give you the mean of T for each group in the aid column.

y_pred = df.groupby('aid').agg({'T': ['min']})

will give you the minimum of T for each aid group.
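Note that both groupby calls return one row per aid group, not one prediction per test row. They line up with y_test only when the number of groups happens to equal the number of test rows, as it does in this example (3 groups, 3 test rows), and they also use the test rows when computing the statistics. A minimal sketch of one way to align the baselines with the test set follows. The names train_df, test_df, group_stats, y_pred_mean and y_pred_min, and the fallback to the overall training statistic for unseen groups, are assumptions for illustration and not part of the original code.

```python
import pandas as pd
from sklearn import metrics
from sklearn.model_selection import train_test_split

df = pd.DataFrame(
    [['a1', 'c1', 10], ['a1', 'c2', 15], ['a1', 'c3', 20], ['a1', 'c1', 15],
     ['a2', 'c2', 20], ['a2', 'c3', 15], ['a2', 'c1', 20], ['a2', 'c2', 15],
     ['a3', 'c3', 20], ['a3', 'c3', 15], ['a3', 'c3', 15], ['a3', 'c3', 20]],
    columns=['aid', 'cid', 'T'])

# Split the raw frame so the group labels stay available for both parts
train_df, test_df = train_test_split(df, test_size=0.2, random_state=0)

# Per-group statistics computed from the training rows only
group_stats = train_df.groupby('aid')['T'].agg(['mean', 'min'])

# Look up each test row's group statistic; fall back to the overall
# training statistic if a group never appears in the training rows
# (this fallback is an assumption, other choices are possible)
y_pred_mean = test_df['aid'].map(group_stats['mean']).fillna(train_df['T'].mean())
y_pred_min = test_df['aid'].map(group_stats['min']).fillna(train_df['T'].min())

y_test = test_df['T']
for name, pred in [('group mean', y_pred_mean), ('group min', y_pred_min)]:
    print(name, 'R-squared:', metrics.r2_score(y_test, pred))
    print(name, 'MAE:', metrics.mean_absolute_error(y_test, pred))
```

Each baseline now yields exactly one prediction per test row, indexed the same way as y_test, so the metrics compare like with like.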

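There is also a useful class for this, DummyRegressor: https://scikit-learn.org/stable/modules/generated/sklearn.dummy.DummyRegressor.html

It is handy for dummy (baseline) regression and supports other strategies as well.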
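As a rough sketch of how DummyRegressor could be used here: strategy='mean' predicts the training mean, and strategy='quantile' with quantile=0.0 predicts the training minimum. Keep in mind that DummyRegressor ignores the features, so these are global baselines over all training rows, not per-group ones. The baselines and results dictionaries are illustrative names, not from the original post.

```python
import pandas as pd
from sklearn.dummy import DummyRegressor
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

df = pd.DataFrame(
    [['a1', 'c1', 10], ['a1', 'c2', 15], ['a1', 'c3', 20], ['a1', 'c1', 15],
     ['a2', 'c2', 20], ['a2', 'c3', 15], ['a2', 'c1', 20], ['a2', 'c2', 15],
     ['a3', 'c3', 20], ['a3', 'c3', 15], ['a3', 'c3', 15], ['a3', 'c3', 20]],
    columns=['aid', 'cid', 'T'])

# Features exclude the target T; DummyRegressor ignores them anyway
X = pd.get_dummies(df[['aid', 'cid']])
y = df['T']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

baselines = {
    'global mean': DummyRegressor(strategy='mean'),
    # The 0.0 quantile of the training targets is their minimum
    'global min': DummyRegressor(strategy='quantile', quantile=0.0),
}

results = {}
for name, model in baselines.items():
    model.fit(X_train, y_train)
    pred = model.predict(X_test)
    results[name] = (r2_score(y_test, pred), mean_absolute_error(y_test, pred))
    print(name, 'R-squared:', results[name][0], 'MAE:', results[name][1])
```

Because each dummy model exposes the same fit/predict interface as LinearRegression, it can be scored with the same metric calls as the real model.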
