我的问题与这个类似,但我需要使用
udf
而不是pandas_udf
。
我有一个包含许多列的 Spark 数据框(列数各不相同),我需要对它们应用自定义函数(例如求和)。我知道我可以硬编码列名称,但当列数变化时它不起作用。
请参阅示例:
解决方案是在函数调用中使用
*expression
并在 pd.concat
函数体内使用
pandas_udf
方法
>>> import pandas as pd
>>> import pyspark.sql.functions as F
>>> @F.pandas_udf("double")
... def col_sum(*args: pd.Series) -> pd.Series:
... pdf = pd.concat(args, axis=1)
... col_sum = pdf.sum(axis=1)
... return col_sum
...
>>> df = spark.createDataFrame([(1,1,1),(2,2,2),(3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+
| A| B| C|SUM|
+---+---+---+---+
| 1| 1| 1|3.0|
| 2| 2| 2|6.0|
| 3| 3| 3|9.0|
+---+---+---+---+
>>> df = spark.createDataFrame([(1,1,1,1),(2,2,2,2),(3,3,3,3)],["A","B","C"])
>>> df.withColumn('SUM', col_sum(*df.columns)).show()
+---+---+---+---+----+
| A| B| C| _4| SUM|
+---+---+---+---+----+
| 1| 1| 1| 1| 4.0|
| 2| 2| 2| 2| 8.0|
| 3| 3| 3| 3|12.0|
+---+---+---+---+----+