我有一个 PySpark DataFrame,如下所示:
id | id2 | id3 | h_生成 | 衰减因子 | h_总计 |
---|---|---|---|---|---|
1 | 164 | 1 | 149.8092121 | ||
1 | 164 | 2 | 1417.298433 | 0.944908987 | 1558.854504 |
1 | 164 | 3 | 3833.995761 | 0.886920437 | 5216.575679 |
1 | 164 | 4 | 285.9751331 | 0.816006194 | 4542.733199 |
1 | 164 | 5 | 309.3110378 | 0.926198535 | 4516.783871 |
2 | 315 | 1 | 97.6314541 | ||
2 | 315 | 2 | 335.8205993 | 0.881027299 | 335.8205993 |
2 | 315 | 3 | 3549.666563 | 0.735895859 | 3796.795552 |
2 | 315 | 4 | 2802.059006 | 0.840857282 | 5994.622196 |
2 | 315 | 5 | 2748.337439 | 0.592542112 | 6300.403536 |
每个id的首次填充h_total计算:
h_total[1] = (h_generated[0] * decay_factor[1]) + h_generated[1]
每个 id 的后续行计算:
h_total[n] = (h_total[n-1] * decay_factor[n]) + h_generated[n]
本质上,这个公式在行上累积 h_total,同时考虑到前几行的衰减。
我尝试过的 我尝试在 PySpark 中使用 for 循环来迭代填充 h_total 值,但由于 PySpark 的递归限制,这种方法对于我的大型数据集(数百万行)来说速度缓慢且效率低下。
from pyspark.sql import Window
import pyspark.sql.functions as F
# Define the cooling constant
cooling_constant = 300
# Define a window to partition by id
window_spec = Window.partitionBy("id")
# Step 1: Initialize h_total for the first row in each partition
test_df = sorted_df.withColumn(
"h_total",
F.when(F.row_number().over(window_spec) == 1, F.col("h_generated"))
)
# Step 2: Iteratively update the h_total for each row
# We will loop a number of times to propagate the h_total value to all rows
for i in range(1, sorted_df.count()): # Increase the number of iterations as needed for your dataset
test_df = test_df.withColumn(
"h_total",
F.when(F.col("h_total").isNotNull(), F.col("h_total"))
.otherwise(
(F.lag("h_total").over(window_spec) * F.col("decay_factor")) + F.col("h_generated")
)
)
我面临的最大问题是我想避免使用 for 循环,因为我的数据集包含数百万行。以这种方式使用循环使我无法充分利用 Spark 集群的强大功能进行分布式处理。
有没有办法在 PySpark 中解决这个问题而不诉诸循环?理想情况下,我想要一个能够利用 Spark 分布式功能来提高性能和可扩展性的解决方案。
让我们考虑 id =
1
和 id2 = 164
的情况。为了简化,我使用 id3
列作为行号。在此分区中:
对于 id3 =
1
,h_total[1] = h_generate[1]
对于 id3 =
2
,h_total[2] = h_ generated[2] + 衰减[2] * h_total[1] = h_ generated[2] + 衰减[2] * h_ generated[1]
对于 id3 =
3
,h_total[3] = h_ generated[3] + 衰减[3] * h_total[2] = h_ generated[3] + 衰减[3] * h_ generated[2] + 衰减[3] * 衰减[ 2] * h_生成[1]
如您所见,存在这样的模式:
h_total[n] = h_generated[n] + decay[n] * h_generated[n-1] + decay[n] * decay[n-1] * h_generated[n-2] + ... + (decay[n] * ... * decay[2]) * h_generated[1]
因此,我们可以使用Window函数来实现:
df.withColumn(
"factor", func.aggregate(
func.collect_list(
"decay_factor"
).over(
Window.partitionBy("id", "id2").orderBy("id3").rowsBetween(1, Window.unboundedFollowing)
),
func.lit(1),
lambda x, y: x * y
) * func.col("h_generated")
).withColumn(
"h_total", func.col("h_generated") + func.sum("factor").over(
Window.partitionBy("id", "id2").orderBy("id3").rowsBetween(Window.unboundedPreceding, -1)
)
)
对于
factor
列,我们使用带有分区的collect_list
函数来收集从当前行的下一行到分区末尾的每个分区的所有decay_factor
,然后将所有decay_factor
与当前相乘行h_generated
。