If I have a table like the following:
ID | Threshold | Value |
---|---|---|
1 | 2 | 1 |
1 | 2 | 2 |
1 | 2 | 3 |
1 | 2 | 4 |
2 | 4 | 1 |
2 | 4 | 3 |
2 | 4 | 5 |
How can I get the following using Spark?
ID | Threshold | Total Above Threshold |
---|---|---|
1 | 2 | 7 |
2 | 4 | 5 |
I can think of a workaround: create an extra column that flags values at or below the threshold, then aggregate only the unflagged values. But does Spark provide a cleaner way that avoids the extra column (e.g. a window function)?
You can use `groupBy` with a conditional aggregation, as follows:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when

spark = SparkSession.builder.getOrCreate()

data = [
    (1, 2, 1),
    (1, 2, 2),
    (1, 2, 3),
    (1, 2, 4),
    (2, 4, 1),
    (2, 4, 3),
    (2, 4, 5),
]
df = spark.createDataFrame(data, ["ID", "Threshold", "Value"])

result_df = (
    df
    .groupBy("ID", "Threshold")
    # Sum only the values strictly above the threshold; rows at or
    # below the threshold contribute 0 to the group's sum.
    .agg(
        sum(when(col("Value") > col("Threshold"), col("Value")).otherwise(0))
        .alias("total_above_threshold")
    )
)
result_df.show()
# +---+---------+---------------------+
# | ID|Threshold|total_above_threshold|
# +---+---------+---------------------+
# |  1|        2|                    7|
# |  2|        4|                    5|
# +---+---------+---------------------+