使用 pandas 和 sklearn 转换器时如何保留数据类型?

问题描述 投票:0回答:1

在使用大型 sklearn

Pipeline
(适合使用
DataFrame
)时,我遇到了一个错误,导致我的输入数据类型错误。该问题发生在来自 API 的单个观察上,该 API 应该与生产中的模型进行交互。单行中缺少信息使得 pandas(显然)无法推断正确的 dtype,但我认为我的拟合变压器将处理转换。显然我错了。

import pandas as pd
from sklearn.impute import SimpleImputer

X_tr = pd.DataFrame({"A": [5, 8, None, 4], "B": [3, None, 9.9, 12]})
print(X_tr.dtypes)  

#>> A    float64
#>> B    float64

x = pd.DataFrame({"A": [10.1], "B": [None]})
print(x.dtypes)

#>> A    float64
#>> B    object

上面清楚地表明,pandas 推断出训练数据集中 A 列和 B 列的

float64
类型,但是(同样明显)对于单个观察,它不知道 B 列的 dtype,因此它分配
object
。到目前为止还没有问题。但是让我们想象一下
SimpleImputer
中某处的
Pipeline
来替换缺失的值:

imputer = SimpleImputer(
    fill_value=0, strategy="constant", missing_values=pd.NA
).set_output(transform="pandas")

X_tr_im = imputer.fit_transform(X_tr)  # training
print(X_tr_im.dtypes)

#>> A    float64
#>> B    float64

x_im = imputer.transform(x)
print(x_im.dtypes)

#>> A    object
#>> B    object

在所有情况下,输入器都会将

None
值替换为零,但是,发生了两件我没想到的事情:

  • B 列未转换为适合的数据类型
  • A 列已转换为不需要的数据类型
    object

这会创建两种不需要的非数字数据类型,从而导致管道中进一步出现错误。即使保存数据类型不是转换器的任务,就我而言,它仍然非常有帮助。

所以我的问题是,我是否做错了什么?有什么解决办法吗?

python pandas scikit-learn
1个回答
0
投票

您遇到的问题是 pandas 如何处理列中的

None
。 如果列具有其他浮点或整数值,则
None
会被强制转换为
numpy.nan
,它是
float
的实例。 将列的类型保持为数字列。

但是,如果列中没有其他值,只有

None
值,pandas 不会尝试强制列浮动,而是将其保留为 Python
object
类型的列。

为了确保数据帧的列在通过管道的其余部分之前转换为数字类型,您可以在输入器之前使用

sklearn.preprocessing.FunctionTransformer

import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

def df_to_float(x):
    return x.astype(np.float64)

x_tr = pd.DataFrame({"A": [5, 8, None, 4], "B": [3, None, 9.9, 12]})
x = pd.DataFrame({"A": [10.1], "B": [None]})

float_xform = FunctionTransformer(df_to_float)
imputer = SimpleImputer(
    fill_value=0, 
    strategy="constant",
    missing_values=pd.NA
).set_output(transform="pandas")

pipe = Pipeline([('float-transform', float_xform), ('impute-NA', imputer)])

print(pipe.fit_transform(x_tr).dtypes)
# A    float64
# B    float64

print(pipe.transform(x).dtypes)
# A    float64
# B    float64
© www.soinside.com 2019 - 2024. All rights reserved.