I am trying to use cross-validation from PySpark's MLlib to search over the parameters of xgboost. I am using the xgboost module's SparkXGBRegressor, which works with PySpark DataFrames.
My DataFrame has 40 columns and 6 million rows.
When I fit the data with the parameter search already defined, the error message below appears. Is it because cross-validation is incompatible with the xgboost module, or because the data is too large?
(Hardware capacity is not an issue in this case.)
from xgboost.spark import SparkXGBRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
import numpy as np

# Distributed XGBoost regressor trained directly on the Spark DataFrame
xgb = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    num_workers=4
)

# Hyperparameter grid for the cross-validated search
paramGrid = (ParamGridBuilder()
    .addGrid(xgb.learning_rate, [float(v) for v in np.arange(0.01, 0.25, 0.01)])
    .addGrid(xgb.colsample_bytree, [float(v) for v in np.arange(0.8, 1.01, 0.1)])
    .addGrid(xgb.subsample, [float(v) for v in np.arange(0.5, 1.01, 0.1)])
    .addGrid(xgb.n_estimators, [int(v) for v in np.arange(100, 3000, 100)])
    .addGrid(xgb.reg_alpha, [float(v) for v in np.arange(0.01, 0.5, 0.05)])
    .addGrid(xgb.max_depth, [5, 10, 30, 50, 80, 100, 150])
    .addGrid(xgb.gamma, [int(v) for v in np.arange(0, 10, 2)])
    .build())

# 5-fold cross-validation scored by mean absolute error
mae_evaluator = RegressionEvaluator(metricName="mae", labelCol="label", predictionCol="prediction")
cv = CrossValidator(estimator=xgb, estimatorParamMaps=paramGrid, evaluator=mae_evaluator, numFolds=5)
cvModel = cv.fit(df)
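For reference, this is a minimal diagnostic sketch I added (not part of the original code) to check whether the executors' Python environment can import xgboost at all, since the traceback below ends in a ModuleNotFoundError raised on a worker. It assumes the notebook's active SparkSession is available as `spark` (as Databricks provides):

def has_xgboost(_):
    # import inside the function so the check runs in the executor's Python process
    import importlib.util
    yield importlib.util.find_spec("xgboost") is not None

# a handful of trivial tasks, one per partition, to touch several executors
flags = spark.sparkContext.parallelize(range(8), 8).mapPartitions(has_xgboost).collect()
print("xgboost importable on the sampled executors:", all(flags))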
Py4JJavaError Traceback (most recent call last)
<command-1695020085907756> in <module>
1 # fit the cross-validator on the dataset (this is where it breaks --> Error: "The spark driver has stopped unexpectedly and is restarting. Your notebook will be automatically reattached.")
----> 2 cvModel = cv.fit(df)
/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
28 call_succeeded = False
29 try:
---> 30 result = original_method(self, *args, **kwargs)
31 call_succeeded = True
32 return result
/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
159 return self.copy(params)._fit(dataset)
160 else:
--> 161 return self._fit(dataset)
162 else:
163 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
/databricks/spark/python/pyspark/ml/tuning.py in _fit(self, dataset)
711 subModels[i][j] = subModel
712
--> 713 _cancel_on_failure(dataset._sc, self.uid, sub_task_failed, calculate_metrics)
714 # END-EDGE
715
/databricks/spark/python/pyspark/ml/util.py in _cancel_on_failure(sc, uid, sub_task_failed, f)
94 "issue, you should enable pyspark pinned thread mode."
95 .format(uid))
---> 96 raise e
97
98 old_job_group = sc.getLocalProperty("spark.jobGroup.id")
/databricks/spark/python/pyspark/ml/util.py in _cancel_on_failure(sc, uid, sub_task_failed, f)
88 if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() != "true":
89 try:
---> 90 return f()
91 except Exception as e:
92 warnings.warn("{} fit call failed but some spark jobs "
/databricks/spark/python/pyspark/ml/util.py in wrapper(*args, **kwargs)
148
149 ipython.events.register("post_run_cell", on_cancel)
--> 150 return f(*args, **kwargs)
151
152 return wrapper
/databricks/spark/python/pyspark/ml/tuning.py in calculate_metrics()
705 return task()
706
--> 707 for j, metric, subModel in pool.imap_unordered(run_task, tasks):
708 metrics[j] += (metric / nFolds)
709 metrics_all[i][j] = metric
/usr/lib/python3.8/multiprocessing/pool.py in next(self, timeout)
866 if success:
867 return value
--> 868 raise value
869
870 __next__ = next # XXX
/usr/lib/python3.8/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
123 job, i, func, args, kwds = task
124 try:
--> 125 result = (True, func(*args, **kwds))
126 except Exception as e:
127 if wrap_exception and func is not _helper_reraises_exception:
/databricks/spark/python/pyspark/ml/tuning.py in run_task(task)
703 if sub_task_failed[0]:
704 raise RuntimeError("Terminate this task because one of other task failed.")
--> 705 return task()
706
707 for j, metric, subModel in pool.imap_unordered(run_task, tasks):
/databricks/spark/python/pyspark/ml/tuning.py in singleTask()
68
69 def singleTask():
---> 70 index, model = next(modelIter)
71 # TODO: duplicate evaluator to take extra params from input
72 # Note: Supporting tuning params in evaluator need update method
/databricks/spark/python/pyspark/ml/base.py in __next__(self)
67 raise StopIteration("No models remaining.")
68 self.counter += 1
---> 69 return index, self.fitSingleModel(index)
70
71 def next(self):
/databricks/spark/python/pyspark/ml/base.py in fitSingleModel(index)
124
125 def fitSingleModel(index):
--> 126 return estimator.fit(dataset, paramMaps[index])
127
128 return _FitMultipleIterator(fitSingleModel, len(paramMaps))
/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
28 call_succeeded = False
29 try:
---> 30 result = original_method(self, *args, **kwargs)
31 call_succeeded = True
32 return result
/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
157 elif isinstance(params, dict):
158 if params:
--> 159 return self.copy(params)._fit(dataset)
160 else:
161 return self._fit(dataset)
/databricks/python/lib/python3.8/site-packages/xgboost/spark/core.py in _fit(self, dataset)
862 return ret[0], ret[1]
863
--> 864 (config, booster) = _run_job()
865
866 result_xgb_model = self._convert_to_sklearn_model(
/databricks/python/lib/python3.8/site-packages/xgboost/spark/core.py in _run_job()
853 def _run_job():
854 ret = (
--> 855 dataset.mapInPandas(
856 _train_booster, schema="config string, booster string"
857 )
/databricks/spark/python/pyspark/rdd.py in collect(self)
965 # Default path used in OSS Spark / for non-credential passthrough clusters:
966 with SCCallSiteSync(self.context) as css:
--> 967 sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
968 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
969
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1302
1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
1305 answer, self.gateway_client, self.target_id, self.name)
1306
/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
115 def deco(*a, **kw):
116 try:
--> 117 return f(*a, **kw)
118 except py4j.protocol.Py4JJavaError as e:
119 converted = convert_exception(e.java_exception)
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(640, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: 'pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 165, in \_read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 469, in loads
return pickle.loads(obj, encoding=encoding)
ModuleNotFoundError: No module named 'xgboost''. Full traceback below:
Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 165, in \_read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 469, in loads
return pickle.loads(obj, encoding=encoding)
ModuleNotFoundError: No module named 'xgboost'
During handling of the above exception, another exception occurred:
pyspark.serializers.SerializationError: Caused by Traceback (most recent call last):
File "/databricks/spark/python/pyspark/serializers.py", line 165, in \_read_with_length
return self.loads(obj)
File "/databricks/spark/python/pyspark/serializers.py", line 469, in loads
return pickle.loads(obj, encoding=encoding)
ModuleNotFoundError: No module named 'xgboost'
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:642)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:101)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:595)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:489)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.hasNext(SerDeUtil.scala:88)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:82)
at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:442)
at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:797)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:521)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2241)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:313)
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2913)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2860)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2854)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2854)
at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2611)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3118)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3062)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3050)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1115)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2500)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1036)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1034)
at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:260)
at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:295)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:251)
at java.lang.Thread.run(Thread.java:750)