使用 Spark pandas_udf 创建列,并具有动态数量的输入列

问题描述 投票:0回答:5

我有这个 df:

df = spark.createDataFrame(
    [('row_a', 5.0, 0.0, 11.0),
     ('row_b', 3394.0, 0.0, 4543.0),
     ('row_c', 136111.0, 0.0, 219255.0),
     ('row_d', 0.0, 0.0, 0.0),
     ('row_e', 0.0, 0.0, 0.0),
     ('row_f', 42.0, 0.0, 54.0)],
    ['value', 'col_a', 'col_b', 'col_c']
)

我想使用 Pandas 中的

.quantile(0.25, axis=1)
来添加一列:

import pandas as pd
pdf = df.toPandas()
pdf['25%'] = pdf.quantile(0.25, axis=1)
print(pdf)
#    value     col_a  col_b     col_c      25%
# 0  row_a       5.0    0.0      11.0      2.5
# 1  row_b    3394.0    0.0    4543.0   1697.0
# 2  row_c  136111.0    0.0  219255.0  68055.5
# 3  row_d       0.0    0.0       0.0      0.0
# 4  row_e       0.0    0.0       0.0      0.0
# 5  row_f      42.0    0.0      54.0     21.0

性能对我来说很重要,所以我认为

pandas_udf
中的
pyspark.sql.functions
可以以更优化的方式做到这一点。但我很难做出一个高性能且有用的功能。这是我最好的尝试:

from pyspark.sql import functions as F
import pandas as pd
@F.pandas_udf('double')
def quartile1_on_axis1(a: pd.Series, b: pd.Series, c: pd.Series) -> pd.Series:
    pdf = pd.DataFrame({'a':a, 'b':b, 'c':c})
    return pdf.quantile(0.25, axis=1)

df = df.withColumn('25%', quartile1_on_axis1('col_a', 'col_b', 'col_c'))
  1. 我不喜欢我需要为每一列提供一个参数,然后在函数中分别处理这些参数以创建 df。所有这些列都有相同的目的,所以恕我直言,应该有一种方法可以将它们全部解决,就像这个伪代码一样:

    def quartile1_on_axis1(*cols) -> pd.Series:
        pdf = pd.DataFrame(cols)
    

    这样我就可以将此函数用于任意数量的列。

  2. 是否需要在UDF内部创建

    pd.Dataframe
    ?对我来说,这似乎与没有 UDF 相同(Spark df -> Pandas df -> Spark df),如上所示。如果没有 UDF,它会更短。我真的应该尝试让它在性能方面发挥作用吗?我认为
    pandas_udf
    是专门为这种目的而设计的......
    
    

apache-spark pyspark apache-spark-sql user-defined-functions pyspark-pandas
5个回答
4
投票

pandas_udf

pyspark.sql.functions.pandas_udf

请注意,类型提示在所有情况下都应使用
@F.pandas_udf('double') def quartile1_on_axis1(s: pd.DataFrame) -> pd.Series: return s.quantile(0.25, axis=1) cols = ['col_a', 'col_b', 'col_c'] df = df.withColumn('25%', quartile1_on_axis1(F.struct(*cols))) df.show() # +-----+--------+-----+--------+-------+ # |value| col_a|col_b| col_c| 25%| # +-----+--------+-----+--------+-------+ # |row_a| 5.0| 0.0| 11.0| 2.5| # |row_b| 3394.0| 0.0| 4543.0| 1697.0| # |row_c|136111.0| 0.0|219255.0|68055.5| # |row_d| 0.0| 0.0| 0.0| 0.0| # |row_e| 0.0| 0.0| 0.0| 0.0| # |row_f| 42.0| 0.0| 54.0| 21.0| # +-----+--------+-----+--------+-------+

,但 有一种变体应该使用

pandas.Series
当输入或输出列是时输入或输出类型提示
pandas.DataFrame

    



2
投票
GroupedData

。因为这需要您传递 df 的架构,所以添加具有所需数据类型的列并获取架构。需要时传递该架构。代码如下; df.to_pandas_on_spark().quantile(0.25, axis=1) NotImplementedError: axis should be either 0 or "index" currently.

您还可以在 udf 中使用 numpy 来完成此任务。如果您不想列出所有列,请按索引对它们(列)进行切片。

#Generate new schema by adding new column sch =df.withColumn('quantile25',lit(110.5)).schema #udf def quartile1_on_axis1(pdf): pdf =pdf.assign(quantile25=pdf.quantile(0.25, axis=1)) return pdf #apply udf df.groupby('value').applyInPandas(quartile1_on_axis1, schema=sch).show() #outcome +-----+--------+-----+--------+----------+ |value| col_a|col_b| col_c|quantile25| +-----+--------+-----+--------+----------+ |row_a| 5.0| 0.0| 11.0| 2.5| |row_b| 3394.0| 0.0| 4543.0| 1697.0| |row_c|136111.0| 0.0|219255.0| 68055.5| |row_d| 0.0| 0.0| 0.0| 0.0| |row_e| 0.0| 0.0| 0.0| 0.0| |row_f| 42.0| 0.0| 54.0| 21.0| +-----+--------+-----+--------+----------+



0
投票
quartile1_on_axis1=udf(lambda x: float(np.quantile(x, 0.25)),FloatType()) df.withColumn("0.25%", quartile1_on_axis1(array(df.columns[1:]))).show(truncate=False) +-----+--------+-----+--------+-------+ |value|col_a |col_b|col_c |0.25% | +-----+--------+-----+--------+-------+ |row_a|5.0 |0.0 |11.0 |2.5 | |row_b|3394.0 |0.0 |4543.0 |1697.0 | |row_c|136111.0|0.0 |219255.0|68055.5| |row_d|0.0 |0.0 |0.0 |0.0 | |row_e|0.0 |0.0 |0.0 |0.0 | |row_f|42.0 |0.0 |54.0 |21.0 | +-----+--------+-----+--------+-------+

代替

pandas_udf
。如果我能以类似的方式使用
udf
那就太好了。
pandas_udf

from pyspark.sql import functions as F
import numpy as np

@F.udf('double')
def lower_quart(*cols):
    return float(np.quantile(cols, 0.25))


0
投票

第一种方法将所有列直接作为参数传递给 UDF。这更简单,但不保留列名称:

df = df.withColumn('25%', lower_quart('col_a', 'col_b', 'col_c')) df.show() #+-----+--------+-----+--------+-------+ #|value| col_a|col_b| col_c| 25%| #+-----+--------+-----+--------+-------+ #|row_a| 5.0| 0.0| 11.0| 2.5| #|row_b| 3394.0| 0.0| 4543.0| 1697.0| #|row_c|136111.0| 0.0|219255.0|68055.5| #|row_d| 0.0| 0.0| 0.0| 0.0| #|row_e| 0.0| 0.0| 0.0| 0.0| #|row_f| 42.0| 0.0| 54.0| 21.0| #+-----+--------+-----+--------+-------+

当您需要保留列名或执行特定于列的操作时,可以使用包装函数:

@F.pandas_udf("double") def fn(*cols) -> pd.Series: # cols will be a tuple of pandas Series, one for each column return pd.concat(cols, axis=1).quantile(0.25, axis=1) df.withColumn("%25", fn(*df.columns)).show() # +-----+--------+-----+--------+-------+ # |value| col_a|col_b| col_c| %25| # +-----+--------+-----+--------+-------+ # |row_a| 5.0| 0.0| 11.0| 2.5| # |row_b| 3394.0| 0.0| 4543.0| 1697.0| # |row_c|136111.0| 0.0|219255.0|68055.5| # |row_d| 0.0| 0.0| 0.0| 0.0| # |row_e| 0.0| 0.0| 0.0| 0.0| # |row_f| 42.0| 0.0| 54.0| 21.0| # +-----+--------+-----+--------+-------+

	
© www.soinside.com 2019 - 2024. All rights reserved.