我有一个用例,我需要在运行总和不超过特定阈值的分区上计算运行总和。
例如:
// Input dataset
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 0.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 0.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 0.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 0.0 | 10.0 |
// Output requirement
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 3.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 8.0 | 10.0 |
在这里,任何
id
的阈值对于具有该id
的所有行都是相同的。
请注意,第 3 行被跳过,因为 running_sum
会超过 threshold
值。但是添加了第 4 行,因为 running_sum
没有超过 threshold
值。
我能够在不考虑使用窗口函数的阈值的情况下计算运行总和,如下所示:
final WindowSpec window = Window.partitionBy(col("id"))
.orderBy(col("created_on").asc())
.rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum", sum(col("value")).over(window)).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 11.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 16.0 | 10.0 |
我尝试在窗口中使用
when()
,也尝试了lag()
,但它给了我意想不到的结果。
// With just sum over window
final WindowSpec window = Window.partitionBy(col("id"))
.orderBy(col("created_on").asc())
.rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum",
when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
.otherwise(sum(col("value")).over(window).minus(col("value")))
).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 3.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 11.0 | 10.0 |
// With combination of sum and lag
final WindowSpec lagWindow = Window.partitionBy(col("id")).orderBy(col("created_on").asc());
final WindowSpec window = Window.partitionBy(col("id"))
.orderBy(col("created_on").asc())
.rowsBetween(Window.unboundedPreceding(), Window.currentRow());
dataset.withColumn("running_sum",
when(sum(col("value")).over(window).leq(col("threshold")), sum(col("value")).over(window))
.otherwise(lag(col("running_sum"), 1, 0).over(lagWindow))
).show();
// Output
| id | created_on | value | running_sum | threshold |
| -- | ----------- | ----- | ------------ | --------- |
| A | 2021-01-01 | 1.0 | 1.0 | 10.0 |
| A | 2021-01-02 | 2.0 | 3.0 | 10.0 |
| A | 2021-01-03 | 8.0 | 0.0 | 10.0 |
| A | 2021-01-04 | 5.0 | 0.0 | 10.0 |
在网上浏览了一些资源后,我遇到了用户定义的聚合函数(UDAF),我相信它应该可以解决我的问题。
但我更喜欢在不使用 UDAF 的情况下实现它。请让我知道是否有任何其他方法可以做到这一点,或者如果我在我尝试过的代码中遗漏了一些东西。
谢谢!