我使用sklearn创建了一个VotingClassifier()对象。稍后,我使用joblib将其保存到voting_predictor.pkl文件中。当我成功加载它时,当我尝试将某些数据预测为voting_predictor.predict(X_test)
时,我收到以下错误:
TypeError:根据规则'safe',无法将数组数据从dtype('O')转换为dtype('int64')
我试图用pickle转储/加载对象,我得到了同样的错误。代码如下所示:
eclf1 = VotingClassifier(estimators=estimators, voting='hard')
eclf1 = eclf1.fit(X_train, y_train)
y_pred = eclf1.predict(X_test)
report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)
print(report)
print(poll_accuracy)
# successful object dump
filename = 'voting_predictor.pkl'
joblib.dump(eclf1, filename)
#successful object load
voting_predictor = joblib.load(filename)
# this prints the object correctly, showing all its parameters
print(voting_predictor)
#error shows here
y_pred = voting_predictor.predict(X_test)
report = classification_report(y_test, y_pred)
poll_accuracy = accuracy_score(y_test, y_pred)
print(voting_predictor)
成功打印了对象及其所有参数。有关为什么会发生这种情况的任何想法?
我和其他预测因子一起使用catbooster时遇到了同样的错误。我找到了this解决方案,但我正在寻找一个更优雅的解决方案。
问题是目标列是类的名称,如string。似乎将字符串值保留而不将其编码为某个整数,导致此错误。但是,在任何其他情况下,sklearn正确处理每个类的字符串名称,提供诸如classification_report和accuracy_score之类的所有度量而没有错误。仅当我从文件加载对象时才会发生错误。