如何使用类似递归的操作在 PySpark 中计算累积衰减和?

问题描述 投票:0回答:1

我有一个 PySpark DataFrame,如下所示:

id id2 id3 h_生成 衰减因子 h_总计
1 164 1 149.8092121
1 164 2 1417.298433 0.944908987 1558.854504
1 164 3 3833.995761 0.886920437 5216.575679
1 164 4 285.9751331 0.816006194 4542.733199
1 164 5 309.3110378 0.926198535 4516.783871
2 315 1 97.6314541
2 315 2 335.8205993 0.881027299 335.8205993
2 315 3 3549.666563 0.735895859 3796.795552
2 315 4 2802.059006 0.840857282 5994.622196
2 315 5 2748.337439 0.592542112 6300.403536

每个id的首次填充h_total计算:

h_total[1] = (h_generated[0] * decay_factor[1]) + h_generated[1]

每个 id 的后续行计算:

h_total[n] = (h_total[n-1] * decay_factor[n]) + h_generated[n]

本质上,这个公式在行上累积 h_total,同时考虑到前几行的衰减。

我尝试过的 我尝试在 PySpark 中使用 for 循环来迭代填充 h_total 值,但由于 PySpark 的递归限制,这种方法对于我的大型数据集(数百万行)来说速度缓慢且效率低下。

from pyspark.sql import Window
import pyspark.sql.functions as F

# Define the cooling constant
cooling_constant = 300

# Define a window to partition by id
window_spec = Window.partitionBy("id")

# Step 1: Initialize h_total for the first row in each partition
test_df = sorted_df.withColumn(
    "h_total",
    F.when(F.row_number().over(window_spec) == 1, F.col("h_generated"))
)

# Step 2: Iteratively update the h_total for each row
# We will loop a number of times to propagate the h_total value to all rows
for i in range(1, sorted_df.count()):  # Increase the number of iterations as needed for your dataset
    test_df = test_df.withColumn(
        "h_total",
        F.when(F.col("h_total").isNotNull(), F.col("h_total"))
        .otherwise(
            (F.lag("h_total").over(window_spec) * F.col("decay_factor")) + F.col("h_generated")
        )
    )

我面临的最大问题是我想避免使用 for 循环,因为我的数据集包含数百万行。以这种方式使用循环使我无法充分利用 Spark 集群的强大功能进行分布式处理。

有没有办法在 PySpark 中解决这个问题而不诉诸循环?理想情况下,我想要一个能够利用 Spark 分布式功能来提高性能和可扩展性的解决方案。

python apache-spark pyspark python-multiprocessing aws-glue
1个回答
0
投票

让我们考虑 id =

1
和 id2 =
164
的情况。为了简化,我使用
id3
列作为行号。在此分区中:

对于 id3 =

1
,h_total[1] = h_generate[1]

对于 id3 =

2
,h_total[2] = h_ generated[2] + 衰减[2] * h_total[1] = h_ generated[2] + 衰减[2] * h_ generated[1]

对于 id3 =

3
,h_total[3] = h_ generated[3] + 衰减[3] * h_total[2] = h_ generated[3] + 衰减[3] * h_ generated[2] + 衰减[3] * 衰减[ 2] * h_生成[1]

如您所见,存在这样的模式:

h_total[n] = h_generated[n] + decay[n] * h_generated[n-1] + decay[n] * decay[n-1] * h_generated[n-2] + ... + (decay[n] * ... * decay[2]) * h_generated[1]

因此,我们可以使用Window函数来实现:

df.withColumn(
    "factor", func.aggregate(
        func.collect_list(
            "decay_factor"
        ).over(
            Window.partitionBy("id", "id2").orderBy("id3").rowsBetween(1, Window.unboundedFollowing)
        ),
        func.lit(1),
        lambda x, y: x * y
    ) * func.col("h_generated")
).withColumn(
    "h_total", func.col("h_generated") + func.sum("factor").over(
        Window.partitionBy("id", "id2").orderBy("id3").rowsBetween(Window.unboundedPreceding, -1)
    )
)

对于

factor
列,我们使用带有分区的
collect_list
函数来收集从当前行的下一行到分区末尾的每个分区的所有
decay_factor
,然后将所有
decay_factor
与当前相乘行
h_generated

© www.soinside.com 2019 - 2024. All rights reserved.