如何使用此代码打印具有最低 RMSE 的模型名称?

问题描述 投票:0回答:1
X = dataset[['bmi', 's5', 'bp']]
model4 = LinearRegression()
model4.fit(X,Y)

res4 = model4.fit(X,Y)
rmse4 = np.sqrt(mean_squared_error(res4.predict(X),Y))
print("RMSE for Linear Regression: ", rmse4)

model5 = Lasso()
model5.fit(X,Y)

res5 = model5.fit(X,Y)
rmse5 = np.sqrt(mean_squared_error(res5.predict(X),Y))
print("RMSE for Lasso Regression: ", rmse5)

model6 = Ridge()
model6.fit(X,Y)

res6 = model6.fit(X,Y)
rmse6 = np.sqrt(mean_squared_error(res6.predict(X),Y))
print("RMSE for Ridge Regression: ", rmse6)

我想打印 RMSE 最低的型号名称。 我知道使用 MIN 函数会带来最小值,但我需要获得有关哪个模型具有最低 RMSE 的结果,而不是最低值本身。谢谢你的帮助!!

minimum
1个回答
0
投票

将模型的所有名称存储在列表/字典中,并在 min 函数中使用 key 参数来提取密钥。

你可以尝试这样的事情:

rmse_list = {}
rmse_dict['Linear Regression'] = rmse4
rmse_dict['Lasso Regression'] = rmse5
rmse_dict['Ridge Regression'] = rmse6

min_rmse_model = min(rmse_dict, key=rmse_dict.get)

print("Model with lowest RMSE:", min_rmse_model)
print("RMSE:", rmse_dict[min_rmse_model])
© www.soinside.com 2019 - 2024. All rights reserved.