Pyspark:在 pandas_udf 中传递多列

问题描述 投票:0回答:1

我的问题与这个类似,但我需要使用

udf
而不是
pandas_udf

我有一个包含许多列的 Spark 数据框(列数各不相同),我需要对它们应用自定义函数(例如求和)。我知道我可以硬编码列名称,但当列数变化时它不起作用。

请参阅示例:

pyspark apache-spark-sql user-defined-functions
1个回答
0
投票

解决方案是在函数调用中使用

*expression
并在 pd.concat
 函数体内使用 
pandas_udf
 方法

>>> import pandas as pd
>>> import pyspark.sql.functions as F

>>> @F.pandas_udf("double")
... def col_sum(*args: pd.Series) -> pd.Series:
...     pdf = pd.concat(args, axis=1)
...     col_sum = pdf.sum(axis=1)
...     return col_sum
... 

>>> df = spark.createDataFrame([(1,1,1),(2,2,2),(3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+                                                               
|  A|  B|  C|SUM|
+---+---+---+---+
|  1|  1|  1|3.0|
|  2|  2|  2|6.0|
|  3|  3|  3|9.0|
+---+---+---+---+

>>> df = spark.createDataFrame([(1,1,1,1),(2,2,2,2),(3,3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+----+
|  A|  B|  C| _4| SUM|
+---+---+---+---+----+
|  1|  1|  1|  1| 4.0|
|  2|  2|  2|  2| 8.0|
|  3|  3|  3|  3|12.0|
+---+---+---+---+----+
© www.soinside.com 2019 - 2024. All rights reserved.