Python DummyRegressor: minimum by group

Question  Votes: 0  Answers: 1

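I'm trying to use DummyRegressor from sklearn.dummy to create a baseline for my model, a regression model with encoded categorical variables that predicts a continuous target. The baseline strategy I want is "minimum", and specifically the minimum by group. A reproducible example is below. My actual dataset is larger: it is a set of records of runners ('a' ids) racing on tracks ('c' ids), with the time recorded for each performance as the target 'T'. I'm trying to see whether the model does better than each runner's best/fastest recorded time (min).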

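# Imports needed by the snippets in this question
import pandas as pd
from sklearn import metrics
from sklearn.dummy import DummyRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

df = pd.DataFrame([['a1','c1',10],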
                   ['a1','c2',15],
                   ['a1','c3',20],
                   ['a1','c1',15],
                   ['a2','c2',26],
                   ['a2','c3',15],
                   ['a2','c1',23],
                   ['a2','c2',15],
                   ['a3','c3',20],
                   ['a3','c3',13],
                   ['a1','c3',19],
                   ['a3','c3',19],
                   ['a3','c3',12],
                   ['a3','c3',20]], columns=['aid','cid','T'])

X = pd.get_dummies(df, columns=['aid','cid'],prefix_sep='',prefix='')
X.drop(['T'], axis=1, inplace=True)
y = df['T']


# train test split 80-20
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

regr = LinearRegression()
Lin_model = regr.fit(X_train, y_train)

y_pred = Lin_model.predict(X_test)
print('R-squared:', metrics.r2_score(y_test, y_pred))
print('MAE:', metrics.mean_absolute_error(y_test, y_pred))

For comparison, I want to use DummyRegressor. Using the mean as the strategy works and, as far as I understand, it uses the mean of the entire column.

dummy_mean = DummyRegressor(strategy='mean')
dummy_mean.fit(X_train, y_train)
y_pred2 = dummy_mean.predict(X_test)
print('R-squared:', metrics.r2_score(y_test, y_pred2))
print('MAE:', metrics.mean_absolute_error(y_test, y_pred2))
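
As a quick check of that understanding, here is a minimal sketch on made-up toy data (X_demo and y_demo are illustrative names, not part of the question's dataset). It shows that strategy='mean' returns the mean of the training targets for every row, whatever the features are:

```python
import numpy as np
from sklearn.dummy import DummyRegressor

# Toy data, made up for illustration: the features carry no useful information.
X_demo = np.array([[0], [1], [2], [3]])
y_demo = np.array([10.0, 20.0, 30.0, 40.0])

dummy = DummyRegressor(strategy='mean').fit(X_demo, y_demo)

# The input features are ignored: each prediction is the mean of y_demo (25.0).
preds = dummy.predict(np.array([[100], [-5], [7]]))
print(preds)
```

The same holds for the other strategies: DummyRegressor produces one value per output, not one value per group.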

To compare against the lowest T, i.e. the fastest time, I tried the constant strategy and defined the constant as the minimum by group:

min_value = df.groupby('aid').agg({'T': ['min']})
dummy_min = DummyRegressor(strategy='constant',constant = min_value)
dummy_min.fit(X_train, y_train)

y_pred4 = dummy_min.predict(X_test)

This returns:

ValueError: could not broadcast input array from shape (1,3) into shape (3,1)

What am I missing?


python linear-regression sklearn-pandas
1 answer
0
votes

When you use min_value = df.groupby('aid').agg({'T': ['min']}), the resulting DataFrame has shape (3, 1). Try changing it to min_value = df.groupby('aid').agg({'T': ['min']}).values.reshape(1,-1). Hope this helps.
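
Note that DummyRegressor ignores the input features, so whatever shape the constant has, it cannot return a different value for each runner. One possible alternative, sketched below under stated assumptions, is to compute the per-runner baseline directly with pandas. The names train, test, best_per_runner and y_pred_min are illustrative. Using only the training rows for the minimum, and falling back to the overall training minimum for a runner missing from the training rows, are assumptions of this sketch and not something the question specifies.

```python
import pandas as pd
from sklearn import metrics
from sklearn.model_selection import train_test_split

df = pd.DataFrame([['a1','c1',10], ['a1','c2',15], ['a1','c3',20], ['a1','c1',15],
                   ['a2','c2',26], ['a2','c3',15], ['a2','c1',23], ['a2','c2',15],
                   ['a3','c3',20], ['a3','c3',13], ['a1','c3',19], ['a3','c3',19],
                   ['a3','c3',12], ['a3','c3',20]], columns=['aid','cid','T'])

# The grouped result from the question: one row per runner, one column.
min_value = df.groupby('aid').agg({'T': ['min']})
print(min_value.shape)  # (3, 1)

# Split the raw rows so the 'aid' column is still available for the lookup.
train, test = train_test_split(df, test_size=0.2, random_state=0)

# Fastest recorded time per runner, computed on the training rows only.
best_per_runner = train.groupby('aid')['T'].min()

# Look up each test row's runner. Assumed fallback: the overall training
# minimum for a runner that does not appear in the training rows.
y_pred_min = test['aid'].map(best_per_runner).fillna(train['T'].min())

print('MAE:', metrics.mean_absolute_error(test['T'], y_pred_min))
```

The resulting MAE can be compared with the MAE of the linear model, provided both are evaluated on the same test rows.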
