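I want to create a pipeline that preprocesses both the training features and the target, and then trains a model. The dataset looks like this:
   v1 v2 target
0   1  a    yes
1   5  c     no
2   3  f    yes
I have a pipeline like this: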
num_cols = ['v1']
cat_cols = ['v2']
clf = DecisionTreeClassifier()
num_transformer = Pipeline(steps=[
    ('impute', SimpleImputer(strategy='mean')),
    ('scale', MinMaxScaler())
])
cat_transformer = Pipeline(steps=[('impute', SimpleImputer(strategy='most_frequent'))])
col_trans = ColumnTransformer(transformers=[
    ('num_pipeline', num_transformer, num_cols),
    ('cat_pipeline', cat_transformer, cat_cols)
])
clf_pipeline = Pipeline(steps=[
    ('col_trans', col_trans),
    ('model', clf)
])
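The idea is to label-encode the target so that it becomes:
target
1
0
1
If possible, it would also be useful to decode the predictions back to the original labels.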
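One way to solve this is to encode the target with a LabelEncoder outside the pipeline and decode the predictions afterwards: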
import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Example data from the question
data = {
    'v1': [1, 5, 3],
    'v2': ['a', 'c', 'f'],
    'target': ['yes', 'no', 'yes']
}
df = pd.DataFrame(data)
X = df.drop('target', axis=1)
y = df['target']
num_cols = ['v1']
cat_cols = ['v2']
# Numeric features: impute missing values with the mean, then scale to [0, 1]
num_transformer = Pipeline(steps=[
    ('impute', SimpleImputer(strategy='mean')),
    ('scale', MinMaxScaler())
])
# Categorical features: impute with the most frequent value, then one-hot encode
cat_transformer = Pipeline(steps=[
    ('impute', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])
col_trans = ColumnTransformer(transformers=[
    ('num_pipeline', num_transformer, num_cols),
    ('cat_pipeline', cat_transformer, cat_cols)
])
clf = DecisionTreeClassifier()
clf_pipeline = Pipeline(steps=[
    ('col_trans', col_trans),
    ('model', clf)
])
# Encode the target outside the pipeline ('no' -> 0, 'yes' -> 1)
le = LabelEncoder()
y_encoded = le.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
clf_pipeline.fit(X_train, y_train)
y_pred = clf_pipeline.predict(X_test)
# Map the integer predictions back to the original string labels
y_pred_decoded = le.inverse_transform(y_pred)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(f'Predictions: {y_pred_decoded}')
This gives the encoded target (y_encoded)
array([1, 0, 1])
and the predictions
Accuracy: 1.0
Predictions: ['yes']
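As a side note, here is a minimal sketch of two related points, reusing the example data from the question. First, LabelEncoder assigns codes in the sorted order of the labels, which is why 'no' becomes 0 and 'yes' becomes 1. Second, scikit-learn classifiers accept string targets directly and return predictions as the original labels, so the explicit encode/decode step may not be needed at all. The names `mapping`, `pre` and `pipe`, the simplified preprocessing, and the fixed `random_state=0` are my own choices for illustration, not part of the original code.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

df = pd.DataFrame({
    'v1': [1, 5, 3],
    'v2': ['a', 'c', 'f'],
    'target': ['yes', 'no', 'yes'],
})
X, y = df[['v1', 'v2']], df['target']

# LabelEncoder stores the sorted unique labels in classes_;
# the integer code of a label is its position in that array.
le = LabelEncoder().fit(y)
mapping = {str(label): code for code, label in enumerate(le.classes_)}
print(mapping)  # {'no': 0, 'yes': 1}

# Simplified preprocessing (illustrative): one-hot encode v2, pass v1 through.
pre = ColumnTransformer(
    transformers=[('cat', OneHotEncoder(handle_unknown='ignore'), ['v2'])],
    remainder='passthrough',
)
pipe = Pipeline(steps=[
    ('pre', pre),
    ('model', DecisionTreeClassifier(random_state=0)),
])

# Fit directly on the string labels; no LabelEncoder involved here.
pipe.fit(X, y)
pred = pipe.predict(X)
model_classes = list(pipe.named_steps['model'].classes_)
print(model_classes)
print(pred)
```

The predictions come back as 'yes'/'no' strings, so no inverse_transform is needed. An explicit LabelEncoder is still handy when some downstream step really requires integer targets.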