更新: 添加了输入 DF 和预期输出。
from pyspark.sql import Window, SparkSession
from pyspark.sql.functions import col, row_number, count
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
import functools
from pyspark.sql import DataFrame
# Initialize the Spark session used by the whole script.
spark = SparkSession.builder.appName("RatingAnalysis").getOrCreate()
# Input schema: identifier/agency columns plus base, "Short", and "Long"
# variants of each rank/count metric (35 columns total).
schema = StructType([
    StructField("UniqueID", IntegerType(), True),
    StructField("Agency", StringType(), True),
    StructField("NormRating", StringType(), True),
    StructField("NormRatingShort", StringType(), True),
    StructField("NormRatingLong", StringType(), True),
    StructField("RatingRank", IntegerType(), True),
    StructField("RatingRankShort", IntegerType(), True),
    StructField("RatingRankLong", IntegerType(), True),
    StructField("BLRank", IntegerType(), True),
    StructField("BLRankShort", IntegerType(), True),
    StructField("BLRankLong", IntegerType(), True),
    StructField("HighestRank", IntegerType(), True),
    StructField("HighestRankShort", IntegerType(), True),
    StructField("HighestRankLong", IntegerType(), True),
    StructField("LowestRank", IntegerType(), True),
    StructField("LowestRankShort", IntegerType(), True),
    StructField("LowestRankLong", IntegerType(), True),
    StructField("MidRank", IntegerType(), True),
    StructField("MidRankShort", IntegerType(), True),
    StructField("MidRankLong", IntegerType(), True),
    StructField("DistinctCount", IntegerType(), True),
    StructField("DistinctCountShort", IntegerType(), True),
    StructField("DistinctCountLong", IntegerType(), True),
    StructField("Count", IntegerType(), True),
    StructField("CountShort", IntegerType(), True),
    StructField("CountLong", IntegerType(), True),
    StructField("NRCount", IntegerType(), True),
    StructField("NRCountShort", IntegerType(), True),
    StructField("NRCountLong", IntegerType(), True),
    StructField("NonNRCount", IntegerType(), True),
    StructField("NonNRCountShort", IntegerType(), True),
    StructField("NonNRCountLong", IntegerType(), True),
    StructField("NonNRDistinctCount", IntegerType(), True),
    StructField("NonNRDistinctCountShort", IntegerType(), True),
    StructField("NonNRDistinctCountLong", IntegerType(), True)
])
# Sample input matching the schema above: one row per rating agency
# (MDY, S&P, FITCH) for a single UniqueID.
data = [
    (212015627, "MDY", "A", None, "A", 60, None, 60, 60, None, 60, 60, None, 60, 30, None, 30, 60, 250, 60, 3, 1, 3, 3, 1, 3, 1, 1, 1, 2, 0, 2, 2, 0, 2),
    (212015627, "S&P", "AA", None, "AA", 31, None, 31, 30, None, 30, 60, None, 60, 30, None, 30, 60, 250, 60, 3, 1, 3, 3, 1, 3, 1, 1, 1, 2, 0, 2, 2, 0, 2),
    (212015627, "FITCH", "NR", "NR", "NR", 252, 252, 252, 250, 250, 252, 60, None, 60, 30, None, 30, 60, 250, 60, 3, 1, 3, 3, 1, 3, 1, 1, 1, 2, 0, 2, 2, 0, 2)
]
ratings_df = spark.createDataFrame(data, schema)
def generate_ratings(window_df, count_rate, distinct_rate, rating_col, order_col, nr_flag, term):
    """Pick at most one final-rating row per UniqueID for one selection scenario.

    Parameters
    ----------
    window_df : DataFrame
        Ratings with the rank/count columns for the given ``term``.
    count_rate : int
        Value ``NonNRCount{term}`` must equal for this scenario (ignored when
        ``nr_flag`` is True).
    distinct_rate : int
        Value ``NonNRDistinctCount{term}`` must equal for this scenario.
    rating_col : str
        Rank column (Lowest/Mid/Highest) the winning row's BLRank must match
        in the generic branch.
    order_col : Column
        Sort expression used to break ties within a UniqueID.
    nr_flag : bool
        When True, handle the "all ratings are NR" case instead of the
        count-based rules.
    term : str
        Column-family suffix: "" (base), "Short", or "Long".

    Returns
    -------
    DataFrame
        Columns: UniqueID, SourceAgency{term}, BLRank{term}, FinalRating{term}.
    """

    def _first_per_id(candidates):
        # Shared tail of every branch (previously duplicated three times):
        # rank candidate rows per UniqueID by `order_col`, keep the first,
        # and project the four output columns.
        window_spec = Window.partitionBy("UniqueID").orderBy(order_col)
        return (
            candidates.withColumn("row_num", row_number().over(window_spec))
            .where(col("row_num") == 1)
            .select(
                col("UniqueID"),
                col("Agency").alias(f"SourceAgency{term}"),
                col(f"BLRank{term}"),
                col(f"NormRating{term}").alias(f"FinalRating{term}"),
            )
        )

    if nr_flag:
        # Every rating for the ID is "NR": keep only rows whose BLRank is in
        # the NR band (>= 250) and take the first by rank order.
        filtered_df = window_df.where(
            (window_df[f"NRCount{term}"] > 0)
            & (window_df[f"NRCount{term}"] == window_df[f"Count{term}"])
        )
        return _first_per_id(filtered_df.where(col(f"BLRank{term}") >= 250))

    if count_rate == 3 and distinct_rate == 2:
        # Three non-NR ratings with two distinct values: the winner comes from
        # the majority pair, i.e. the BLRank value shared by exactly 2 rows.
        filtered_df = window_df.where(
            (col(f"NormRating{term}") != "NR")
            & (window_df[f"NonNRCount{term}"] == count_rate)
            & (window_df[f"NonNRDistinctCount{term}"] == distinct_rate)
        )
        majority_ranks = (
            filtered_df.groupBy("UniqueID", f"BLRank{term}")
            .agg(count("*").alias("total_count"))
            .where(col("total_count") == 2)
        )
        return _first_per_id(
            filtered_df.join(majority_ranks, on=["UniqueID", f"BLRank{term}"], how="inner")
        )

    # Generic case: among non-NR rows matching the scenario's counts, keep the
    # rows whose BLRank equals the target rank column (Lowest/Mid/Highest).
    non_nr = window_df.where(col(f"NormRating{term}") != "NR")
    filtered_df = non_nr.where(
        (non_nr[f"NonNRCount{term}"] == count_rate)
        & (non_nr[f"NonNRDistinctCount{term}"] == distinct_rate)
    )
    return _first_per_id(filtered_df.where(col(f"BLRank{term}") == col(rating_col)))
def ratings_loop(df, term=""):
    """Union the output of generate_ratings across every selection scenario.

    Each scenario is (count_rate, distinct_rate, rating_col, order_col,
    nr_flag); ``term`` selects the base (""), "Short", or "Long" columns.
    """
    ascending = col(f"RatingRank{term}").asc()
    descending = col(f"RatingRank{term}").desc()
    scenarios = (
        (0, 0, f"LowestRank{term}", ascending, True),
        (3, 3, f"MidRank{term}", ascending, False),
        (3, 1, f"LowestRank{term}", ascending, False),
        (2, 2, f"HighestRank{term}", descending, False),
        (2, 1, f"LowestRank{term}", ascending, False),
        (1, 1, f"LowestRank{term}", ascending, False),
    )
    # Project only the columns the scenarios actually read.
    projected = df.select(
        "UniqueID", "Agency", f"NormRating{term}",
        f"RatingRank{term}", f"BLRank{term}", f"HighestRank{term}",
        f"MidRank{term}", f"DistinctCount{term}", f"Count{term}",
        f"NRCount{term}", f"NonNRCount{term}", f"NonNRDistinctCount{term}",
    )
    combined = None
    for params in scenarios:
        branch = generate_ratings(projected, *params, term)
        combined = branch if combined is None else combined.union(branch)
    return combined
def join_final_df(df, df_short, df_long, on_column, select_list):
    """Left-join the base, Short, and Long rating frames, then project.

    `df` is the driving frame; `df_short` and `df_long` are attached with
    left joins on `on_column`, and only `select_list` columns are kept.
    """
    with_short = df.join(df_short, on=on_column, how='left')
    with_both = with_short.join(df_long, on=on_column, how='left')
    return with_both.select(select_list)
print("Input df: ")
ratings_df.show()
final_ratings_df = ratings_loop(ratings_df)
final_ratings_df_short = ratings_loop(ratings_df, "Short")
final_ratings_df_long = ratings_loop(ratings_df, "Long")
final_ratings_df_joined = join_final_df(
final_ratings_df,
final_ratings_df_short,
final_ratings_df_long,
["UniqueID"],
[
"UniqueID", "SourceAgency", "SourceAgencyShort",
"SourceAgencyLong", "BLRank", "BLRankShort",
"BLRankLong", "FinalRating", "FinalRatingShort", "FinalRatingLong"
]
)
print("final_ratings_df_joined: ")
final_ratings_df_joined.show()
我有一个 PySpark 脚本,它使用窗口函数和聚合来处理评级数据。代码工作正常,但尚未优化:它使用 for 循环和 `functools.reduce` 来组合 DataFrame,我相信这里可以改进。我希望在保持相同功能的前提下,通过避免 for 循环和 `reduce` 来优化此脚本。
我尝试过使用窗口函数和条件表达式,把对同一 DataFrame 的多个转换合并为单个操作,但始终找不到完全消除循环和 `functools.reduce` 的方法。
从您的评级选择器/计算器代码来看,性能表现已经不错,特别是您已经在使用一些最佳实践,例如谓词下推、尽早过滤行以及高效地应用连接(join)。
只要使用得当,for 循环本身并不是 Spark 中的性能罪魁祸首。当然,应避免在 for 循环中执行大量 `join`,或在函数内部触发 `count` 之类的行动(action),这些确实会导致性能下降。
如果当前实现已经满足 SLA,我不会为了性能而做不必要的改动。不过,这里有一些可供参考的额外优化建议。
**缓存 `df_reduced` 或 `ratings_df`,保留 `union` 操作**:如果您的目标是尽量少改动当前逻辑,一个好的替代方案是直接缓存 `df_reduced`,因为 `generate_ratings` 函数会对它应用 6 次。或者,您也可以缓存 `ratings_df`,因为它会被 3 个 term("", "Short", "Long")各触发一次。示例如下:
def ratings_loop(df, term=""):
    """Run all six rating-selection scenarios and union the results.

    Caches the reduced projection so the six generate_ratings passes do not
    recompute it, and unpersists it before returning.
    """
    scenarios = [
        [0, 0, f"LowestRank{term}", col(f"RatingRank{term}").asc(), True],
        [3, 3, f"MidRank{term}", col(f"RatingRank{term}").asc(), False],
        [3, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False],
        [2, 2, f"HighestRank{term}", col(f"RatingRank{term}").desc(), False],
        [2, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False],
        [1, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False],
    ]
    needed = [
        "UniqueID", "Agency", f"NormRating{term}",
        f"RatingRank{term}", f"BLRank{term}", f"HighestRank{term}",
        f"MidRank{term}", f"DistinctCount{term}", f"Count{term}",
        f"NRCount{term}", f"NonNRCount{term}", f"NonNRDistinctCount{term}",
    ]
    projected = df.select(*needed).cache()  # cache: reused by all six passes
    projected.take(1)  # trigger an action so the cache is populated up front
    combined = functools.reduce(
        DataFrame.union,
        (generate_ratings(projected, *params, term) for params in scenarios),
    )
    projected.unpersist()  # release the cache to avoid memory pressure
    return combined
**避免 for 循环和 `functools.reduce`**:除了缓存(无论如何都有益)之外,您可以尝试重构 `generate_ratings` 中的核心逻辑,或者改用 `DataFrame.transform` 来反复应用 `generate_ratings`。这种做法既避免了 for 循环,又只使用纯 PySpark 内置函数,有助于保障长期性能。同样的策略也可以应用到代码开头的 `ratings_loop`:
def ratings_loop(df, term=""):
    """Transform-based variant: apply generate_ratings once per scenario via
    DataFrame.transform and combine the results with unionByName.

    Fixes vs. the draft version:
    - select from the ``df`` parameter instead of the global ``ratings_df``
      (the argument was silently ignored);
    - unpersist the cached projection once it is no longer needed;
    - return the result instead of calling the Databricks-only ``display()``
      and discarding the frame.
    """
    # Bug fix: use the `df` parameter, not the module-level `ratings_df`.
    df_reduced = df.select(
        "UniqueID", "Agency", f"NormRating{term}",
        f"RatingRank{term}", f"BLRank{term}", f"HighestRank{term}",
        f"MidRank{term}", f"DistinctCount{term}", f"Count{term}", f"NRCount{term}", f"NonNRCount{term}",
        f"NonNRDistinctCount{term}"
    )
    df_reduced.cache()  # cached: reused by all six transform calls below
    df_reduced.take(1)  # force an action so caching happens before the transforms
    # Python NOTE: prefer immutable tuples over lists for fixed parameter sets.
    item_0 = (0, 0, f"LowestRank{term}", col(f"RatingRank{term}").asc(), True)
    item_1 = (3, 3, f"MidRank{term}", col(f"RatingRank{term}").asc(), False)
    item_2 = (3, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False)
    item_3 = (2, 2, f"HighestRank{term}", col(f"RatingRank{term}").desc(), False)
    item_4 = (2, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False)
    item_5 = (1, 1, f"LowestRank{term}", col(f"RatingRank{term}").asc(), False)
    # NOTE(review): DataFrame.transform with extra positional args requires
    # Spark >= 3.3 — confirm against the cluster version.
    window_df0 = df_reduced.transform(generate_ratings, *item_0, term)
    window_df1 = df_reduced.transform(generate_ratings, *item_1, term)
    window_df2 = df_reduced.transform(generate_ratings, *item_2, term)
    window_df3 = df_reduced.transform(generate_ratings, *item_3, term)
    window_df4 = df_reduced.transform(generate_ratings, *item_4, term)
    window_df5 = df_reduced.transform(generate_ratings, *item_5, term)
    all_dfs = (window_df0.unionByName(window_df1)
               .unionByName(window_df2)
               .unionByName(window_df3)
               .unionByName(window_df4)
               .unionByName(window_df5)
               )
    result_df = all_dfs.dropDuplicates(["UniqueID"])
    df_reduced.unpersist()  # bug fix: always free the cache when done
    # Bug fix: return the computed frame; the draft called result_df.display()
    # (a Databricks-only method) and implicitly returned None.
    return result_df