I'm trying to write some custom classes that work with the existing MLlib codebase and MLflow on Databricks. For example: writing Transformers or Estimators, or extending existing MLlib classes, so that they can be added to a Pipeline, fitted (if necessary), logged to MLflow, and served.
Does anyone have experience writing custom MLlib classes that work with MLflow and could help me out?
I can create Transformers, Estimators, and so on, and even extend MLlib classes, but when logging them to an MLflow run I always get the following warning:
2023/02/27 02:37:31 WARNING mlflow.utils.environment: Encountered an unexpected error while inferring pip requirements (model URI: /tmp/tmpjpcb462k, flavor: spark), fall back to return ['pyspark==3.3.0']. Set logging level to DEBUG to see the full traceback. Out[7]: <mlflow.models.model.ModelInfo at 0x7f5ea18f9610>
If I load the model from a notebook after `%run`-ing the class definitions, it still works fine, but if I serve the model and call the REST endpoint, for example, it does not work.
Is there an "official" way to do this? Can anyone help me with this?
Cheers,
Toy example:
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml import Pipeline, Transformer, Model
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from typing import List, Sequence, Callable, Tuple, Optional, cast
from multiprocessing.pool import ThreadPool
from pyspark import inheritable_thread_target, keyword_only
from pyspark.sql import DataFrame
from pyspark.sql.functions import round as spark_round
import numpy as np
import mlflow
class ValueRounder(Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable):
    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(ValueRounder, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setInputCol(self, value):
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        return self._set(outputCol=value)

    def _transform(self, dataset):
        return dataset.withColumn(self.getOutputCol(), spark_round(self.getInputCol()))
df = spark.createDataFrame(
    [
        (1.0, 2, None),
        (1.2, None, 3),
        (None, 2, None),
        (1.5, None, 3),
        (1.7, 2, 3)
    ],
    ['A', 'B', 'C']
)
myAss = VectorAssembler(inputCols=['A', 'B', 'C'], outputCol='features', handleInvalid='keep')
myRounder = ValueRounder(inputCol='A', outputCol='rounded(A)')
model = Pipeline(stages=[myAss, myRounder]).fit(df)
# the following will spit out the aforementioned warning
mlflow.spark.log_model(model, artifact_path='model')
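One workaround I've been trying (a sketch, not an official answer): the warning says MLflow failed to *infer* the pip requirements, so you can pass them explicitly via `pip_requirements`, and ship the file that defines the custom class through `code_paths` so the serving environment can import `ValueRounder` when it deserializes the pipeline. Here `my_transformers.py` is a hypothetical module name for wherever the class lives; the `pip_requirements` and `code_paths` parameters of `mlflow.spark.log_model` are real, but whether this fully fixes REST serving for custom MLlib stages is exactly what I'm unsure about:

```python
# Sketch of a workaround, assuming the warning comes from failed
# pip-requirement inference. "my_transformers.py" is a hypothetical
# file containing the ValueRounder class definition.

PIP_REQUIREMENTS = [
    "pyspark==3.3.0",  # pin to the cluster's Spark version
]

def log_pipeline_model(model, artifact_path="model"):
    """Log a fitted PipelineModel with explicit requirements and code paths."""
    import mlflow  # imported here so the sketch stays importable without mlflow installed

    return mlflow.spark.log_model(
        model,
        artifact_path=artifact_path,
        pip_requirements=PIP_REQUIREMENTS,  # skips the (failing) inference step
        code_paths=["my_transformers.py"],  # hypothetical file defining ValueRounder
    )
```

With explicit `pip_requirements`, MLflow skips the inference step that produced the warning, and `code_paths` copies the listed files into the model artifact and puts them on the Python path at load time, which is what the serving endpoint needs to find the custom class.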