我使用 joblib 训练并保存了以下模型:
def to_dense(x):
return np.asarray(x.todense())
to_dense_array = FunctionTransformer(to_dense, accept_sparse=True)
model = make_pipeline(
TfidfVectorizer(),
to_dense_array,
HistGradientBoostingClassifier()
)
est = model.fit(texts, y)
save_path = os.path.join(os.getcwd(), "VAT_estimator.pkl")
joblib.dump(est, save_path)
模型工作正常,准确性好,保存到joblib期间没有任何消息发出。
现在,我尝试使用以下代码从 joblib 重新加载模型:
import joblib
# Load the saved model
estimator_file = "VAT_estimator.pkl"
model = joblib.load(estimator_file)
然后我收到以下错误消息:
AttributeError: Can't get attribute 'to_dense' on <module '__main__'>
我无法避免在管道中转换为密集数组的步骤。
我尝试在导入后将转换步骤插入模型中,但是在预测时,我收到消息称 FunctionTransformer 不可调用。
我看不到任何出路。
出现问题的原因是管道中的
FunctionTransformer
使用了 __main__
范围中定义的自定义函数 to_dense,当您使用 joblib
重新加载模型时,它不知道如何找到 to_dense
,因为它不是相同范围内。
要解决此问题,您需要确保在加载模型时在同一模块(文件)中定义该函数,或者为
joblib
提供查找自定义函数的方法。
有多种选择可以解决此问题: 其中之一是:定义main之外的函数并保存/加载模型
示例: 首先,我们将重新创建您主要所做的事情:
import os
import numpy as np
import joblib
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import HistGradientBoostingClassifier
# Step 1: Define the `to_dense` function at the top level
def to_dense(x):
return np.asarray(x.todense())
# Create a transformer using the `to_dense` function
to_dense_array = FunctionTransformer(to_dense, accept_sparse=True)
# Step 2: Define the model pipeline
model = make_pipeline(
TfidfVectorizer(),
to_dense_array,
HistGradientBoostingClassifier()
)
# Assume `texts` and `y` are your training data and labels
# est = model.fit(texts, y)
# Step 3: Save the model using joblib
save_path = os.path.join(os.getcwd(), "VAT_estimator.pkl")
joblib.dump(model, save_path)
现在我们要加载模型: PS:确保在加载模型时定义或导入
to_dense
:
import joblib
import numpy as np
from sklearn.preprocessing import FunctionTransformer
# Step 1: Define the `to_dense` function again or import it from your module
def to_dense(x):
return np.asarray(x.todense())
# Step 2: Load the saved model
estimator_file = "VAT_estimator.pkl"
model = joblib.load(estimator_file)
# Now, you can use `model` to make predictions or further train the model good luck.
# Example: model.predict(new_texts)
通过在脚本的顶层定义
to_dense
并确保在加载模型时它存在,joblib
将正确定位该函数,并且管道应该可以正常工作。
这个完整的工作流程应该保存并加载您的模型,而不会遇到
AttributeError
。