假设我有一个昂贵的 PySpark 查询,会产生一个大数据帧
sdf_input
。我现在想向此数据框添加一列,该列需要计算 sdf_input
中的总行数。为了简单起见,假设我想将另一列 A
除以 total_num_rows
。
我可以简单地做这样的事情:
total_num_rows = sdf_input.count()
sdf_output = sdf_input.withColumn('B', F.col('A')/total_num_rows)
这可行,但它将执行昂贵的
sdf_input
计算两次,因为 count()
操作将触发查询执行(物化 sdf_input)是否有(正确的)方法在不物化 sdf_input 的情况下获取 total_num_rows
?
感觉这应该不难,有什么想法可以正确地做到这一点吗?
我想出了一些“解决方案”,但所有这些都感觉过于复杂和/或计算成本高昂。
1。只需将
sdf_input
写入磁盘
不理想,因为它是一个很大的数据集并且感觉非常低效。
2。缓存
sdf_input
total_num_rows = sdf_input.cache().count()
sdf_output = sdf_input.withColumn('B', F.col('A')/total_num_rows)
整个
sdf_input
无法装入内存,所以这也感觉效率很低。
3.使用窗口操作
在整个数据帧上使用窗口来获取
total_num_rows
并附加到每一行:
sdf_output = sdf_input.withColumn('B', F.col('A')/F.count("*").over(W.partitionBy()))
这可能是一个非常糟糕的主意,因为它将所有数据移动到单个分区(Spark 甚至给出警告..)
4。聚合整个数据帧并连接回每一行
sdf_total_sum = sdf_input.agg(F.count("*").alias("total_num_rows"))
sdf_output = sdf_input.crossJoin(F.broadcast(sdf_total_sum)).withColumn("B", F.col("A") / F.col("total_num_rows"))
直观上这是有道理的,但仅仅获得一个计数就感觉太复杂了..
您的选项 4,聚合整个数据帧以获取计数 然后将该计数广播到数据帧的其余部分,也是 一个常见的模式。虽然看起来很复杂,但它是一个相当 Spark 中处理此类情况的标准方法。
这是使用广播对选项 4 进行的稍微简化 变量:
from pyspark.sql import functions as F
# Count the total number of rows
total_num_rows = sdf_input.count()
# Broadcast the count
broadcast_count = spark.sparkContext.broadcast(total_num_rows)
# Use the broadcasted count to add the new column
sdf_output = sdf_input.withColumn('B', F.col('A') / broadcast_count.value)
这样,您可以通过使用广播变量来避免连接操作。 计数计算一次,然后发送到所有节点,其中可以 用于计算新列而不触发任何额外的
的计算。sdf_input