I have run into a problem where I have to iterate the same computation on every row of data until the rows converge. My idea was to remove the converged rows after each iteration so the computation time decreases on the next iteration. Instead, I noticed that the computation time grows with every iteration.
I'm fairly new to pyspark and I'm currently trying to rewrite my code using F.when(), but I feel there is a lesson to be learned here either way.
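Roughly, the F.when() rewrite I have in mind would keep everything in one DataFrame and only update the rows that have not converged yet. This is just a rough sketch, and it assumes the frame already carries "output" and "converged" columns initialised before the loop:

from pyspark.sql import functions as F

def doComputationWhen(df):
    # sketch only: recompute "output" just for the rows that have not converged yet
    new_output = (
        F.pow(F.col("a"), 2) + F.pow(F.col("b"), 3) + F.pow(F.col("c"), 4)
        + F.pow(F.col("a") + F.col("b"), 2)
        + F.pow(F.col("b") + F.col("c"), 3)
        + F.pow(F.col("c") + F.col("a"), 4)
    )
    df = df.withColumn("output", F.when(~F.col("converged"), new_output).otherwise(F.col("output")))
    df = df.withColumn("converged", F.col("converged") | (F.col("output") > 5))
    # advance a, b, c only for the rows that still need more iterations
    for name in ["a", "b", "c"]:
        df = df.withColumn(name, F.when(~F.col("converged"), F.col(name) + 0.01).otherwise(F.col(name)))
    return df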
I put together a small demo below to illustrate the problem:
import time
import numpy as np
from pyspark.sql import functions as F
# `spark` is assumed to be an existing SparkSession (e.g. a notebook or pyspark shell)

def create_dummy_df(size=10000):
    data = np.random.uniform(0, 1, size=(size, 6))
    rows = [(float(a), float(b), float(c), float(d), float(e), float(f)) for a, b, c, d, e, f in data]
    return spark.createDataFrame(rows, ["a", "b", "c", "d", "e", "f"])
def doComputation(df):
    """Simulate heavy polynomial calculations with increased complexity"""
    df = df.withColumn("res_a", F.pow(F.col("a"), 2))
    df = df.withColumn("res_b", F.pow(F.col("b"), 3))
    df = df.withColumn("res_c", F.pow(F.col("c"), 4))
    df = df.withColumn("res_d", F.pow(F.col("a") + F.col("b"), 2))
    df = df.withColumn("res_e", F.pow(F.col("b") + F.col("c"), 3))
    df = df.withColumn("res_f", F.pow(F.col("c") + F.col("a"), 4))
    df = df.withColumn("output", F.col("res_a") + F.col("res_b") + F.col("res_c") + F.col("res_d") + F.col("res_e") + F.col("res_f"))
    df = df.withColumn("converged", F.col("output") > 5)
    # Increase a, b, c by 0.01 so more rows will converge in the next iteration
    df = df.withColumn("a", F.col("a") + 0.01).withColumn("b", F.col("b") + 0.01).withColumn("c", F.col("c") + 0.01)
    return df
def testFunc(df, maxIter=40):
    total_rows = df.count()
    converged_rows = []
    for i in range(maxIter):
        t = time.time()
        df = doComputation(df)
        # Identify converged rows
        converged = df.filter(F.col("converged") | F.isnan("output"))
        # Remove converged rows from the working set
        df = df.filter(~F.col("converged") & ~F.isnan("output"))
        # Re-persist the dataframe to avoid an exploding computation tree on later iterations
        df.unpersist()
        df = df.cache()
        remaining = df.count()
        # Store converged rows together with the iteration in which they converged
        if not converged.isEmpty():
            converged_rows.append(converged.withColumn("converged_iteration", F.lit(i + 1)))
        print(f"Iteration {i+1}: {total_rows - remaining}/{total_rows} rows converged, iteration time: {time.time()-t:.2f} seconds")
        if remaining == 0:
            break
    t = time.time()
    # Merge all converged results
    result_df = converged_rows[0]
    for df_chunk in converged_rows[1:]:
        result_df = result_df.union(df_chunk)
    # Merge unconverged rows with iteration 999
    if df.count() > 0:
        result_df = result_df.union(df.withColumn("converged_iteration", F.lit(999)))
    print(f'merge time: {time.time()-t:.2f} seconds')
    return result_df
testdf = create_dummy_df()
testFunc(testdf)
Iteration 1: 4350/10000 rows converged, iteration time: 1.87 seconds
Iteration 2: 4574/10000 rows converged, iteration time: 1.26 seconds
Iteration 3: 4821/10000 rows converged, iteration time: 0.88 seconds
Iteration 4: 5028/10000 rows converged, iteration time: 0.82 seconds
Iteration 5: 5253/10000 rows converged, iteration time: 1.03 seconds
Iteration 6: 5509/10000 rows converged, iteration time: 1.17 seconds
Iteration 7: 5745/10000 rows converged, iteration time: 1.02 seconds
Iteration 8: 5977/10000 rows converged, iteration time: 1.24 seconds
Iteration 9: 6196/10000 rows converged, iteration time: 1.19 seconds
Iteration 10: 6410/10000 rows converged, iteration time: 1.27 seconds
Iteration 11: 6623/10000 rows converged, iteration time: 1.41 seconds
Iteration 12: 6829/10000 rows converged, iteration time: 1.44 seconds
Iteration 13: 6998/10000 rows converged, iteration time: 1.45 seconds
Iteration 14: 7186/10000 rows converged, iteration time: 2.08 seconds
Iteration 15: 7354/10000 rows converged, iteration time: 1.65 seconds
Iteration 16: 7523/10000 rows converged, iteration time: 1.89 seconds
Iteration 17: 7673/10000 rows converged, iteration time: 2.03 seconds
Iteration 18: 7852/10000 rows converged, iteration time: 2.00 seconds
Iteration 19: 7994/10000 rows converged, iteration time: 1.91 seconds
Iteration 20: 8132/10000 rows converged, iteration time: 2.05 seconds
Iteration 21: 8261/10000 rows converged, iteration time: 2.57 seconds
Iteration 22: 8371/10000 rows converged, iteration time: 2.45 seconds
Iteration 23: 8497/10000 rows converged, iteration time: 3.14 seconds
Iteration 24: 8612/10000 rows converged, iteration time: 2.72 seconds
Iteration 25: 8713/10000 rows converged, iteration time: 2.91 seconds
Iteration 26: 8822/10000 rows converged, iteration time: 3.02 seconds
Iteration 27: 8945/10000 rows converged, iteration time: 3.18 seconds
Iteration 28: 9026/10000 rows converged, iteration time: 3.23 seconds
Iteration 29: 9115/10000 rows converged, iteration time: 3.63 seconds
Iteration 30: 9202/10000 rows converged, iteration time: 3.66 seconds
Iteration 31: 9273/10000 rows converged, iteration time: 3.76 seconds
Iteration 32: 9338/10000 rows converged, iteration time: 3.49 seconds
Iteration 33: 9412/10000 rows converged, iteration time: 4.09 seconds
Iteration 34: 9472/10000 rows converged, iteration time: 4.32 seconds
Iteration 35: 9546/10000 rows converged, iteration time: 4.67 seconds
Iteration 36: 9590/10000 rows converged, iteration time: 4.30 seconds
Iteration 37: 9643/10000 rows converged, iteration time: 4.54 seconds
Iteration 38: 9688/10000 rows converged, iteration time: 4.93 seconds
Iteration 39: 9732/10000 rows converged, iteration time: 5.67 seconds
Iteration 40: 9769/10000 rows converged, iteration time: 5.72 seconds
Spark behaves strangely when you keep building up DataFrames in a loop like this. Uncomment the # df.explain() line in the code below and look at how the plan evolves for each test_type of force_plan_execution(). These details also change over time; for example, df.count() used to force execution of the plan, but on newer Spark it no longer does.

Also compare the merge time for each test_type. If you stop at, say, 15 iterations, the merge for 'orig' is faster than the merge for 'rdd'.

As for "remove the converged rows after each iteration so the computation time decreases on the next iteration": depending on how the data is partitioned and so on, this may not help and can even make things worse (pure overhead). I don't fully understand what you are doing in your real code, so I can't say. But be aware that you cannot take whatever you learn here as a general rule, apply it to another experiment (with a completely different doComputation() implementation and/or run on a much larger dataset), and expect the same behaviour.

import time
import numpy as np
from pyspark.sql import functions as F
def create_dummy_df(size=10000):
    data = np.random.uniform(0, 1, size=(size, 6))
    rows = [(float(a), float(b), float(c), float(d), float(e), float(f)) for a, b, c, d, e, f in data]
    return spark.createDataFrame(rows, ["a", "b", "c", "d", "e", "f"])
def force_plan_execution(df, test_type):
    # Force execution of the accumulated plan in different ways: 'as_table', 'rdd', or pass-through ('orig')
    if test_type == 'as_table':
        df.write.saveAsTable("df_temp_view1", mode="overwrite")
        return spark.read.table("df_temp_view1")
    elif test_type == 'rdd':
        r = df.rdd
        s = df.schema
        return r.toDF(s)
    else:
        return df
def doComputation(df):
    """Simulate heavy polynomial calculations with increased complexity"""
    df = df.withColumn("res_a", F.pow(F.col("a"), 2))
    df = df.withColumn("res_b", F.pow(F.col("b"), 3))
    df = df.withColumn("res_c", F.pow(F.col("c"), 4))
    df = df.withColumn("res_d", F.pow(F.col("a") + F.col("b"), 2))
    df = df.withColumn("res_e", F.pow(F.col("b") + F.col("c"), 3))
    df = df.withColumn("res_f", F.pow(F.col("c") + F.col("a"), 4))
    df = df.withColumn("output", F.col("res_a") + F.col("res_b") + F.col("res_c") + F.col("res_d") + F.col("res_e") + F.col("res_f"))
    df = df.withColumn("converged", F.col("output") > 5)
    # Increase a, b, c by 0.01 so more rows will converge in the next iteration
    df = df.withColumn("a", F.col("a") + 0.01).withColumn("b", F.col("b") + 0.01).withColumn("c", F.col("c") + 0.01)
    return df
def testFunc(df, test_type, num_iter=30):
    total_rows = df.count()
    converged_rows = []
    for i in range(num_iter):
        t = time.time()
        df = doComputation(df)
        df = force_plan_execution(df, test_type)
        # df.explain()
        # Identify converged rows
        converged = df.filter(F.col("converged") | F.isnan("output"))
        # Remove converged rows from the working set
        df = df.filter(~F.col("converged") & ~F.isnan("output"))
        # Re-persist the dataframe to avoid an exploding computation tree on later iterations
        df.unpersist()
        df = df.cache()
        remaining = df.count()
        # Store converged rows together with the iteration in which they converged
        if not converged.isEmpty():
            converged_rows.append(converged.withColumn("converged_iteration", F.lit(i + 1)))
        if (i + 1) % 5 == 0:  # print progress every 5 iterations (matches the output below)
print(f"Iteration {i+1}: {total_rows - remaining}/{total_rows} rows converged, iteration time: {time.time()-t:.2f} seconds")
if remaining == 0:
break
t = time.time()
# merge all converged results
result_df = converged_rows[0]
for df_chunk in converged_rows[1:]:
result_df = result_df.union(df_chunk)
# merge unconverged results with iteration 999
if df.count() > 0:
result_df = result_df.union(df.withColumn("converged_iteration", F.lit(999)))
print(f'merge time: {time.time()-t:.2f} seconds')
return result_df
for test_type in ['rdd', 'orig', 'as_table']:
    t = time.time()
    print(f"Start {test_type} {'-'*60}")
    testdf = create_dummy_df()
    testFunc(testdf, test_type)
    print(f"END Total time: {test_type} {time.time()-t:.2f} seconds {'-'*40}")
This prints:
Start rdd ------------------------------------------------------------
Iteration 5: 5299/10000 rows converged, iteration time: 0.93 seconds
Iteration 10: 6439/10000 rows converged, iteration time: 1.31 seconds
Iteration 15: 7358/10000 rows converged, iteration time: 1.90 seconds
Iteration 20: 8151/10000 rows converged, iteration time: 2.76 seconds
Iteration 25: 8774/10000 rows converged, iteration time: 3.83 seconds
Iteration 30: 9232/10000 rows converged, iteration time: 5.07 seconds
merge time: 1.42 seconds
END Total time: rdd 72.91 seconds ----------------------------------------
Start orig ------------------------------------------------------------
Iteration 5: 5283/10000 rows converged, iteration time: 0.45 seconds
Iteration 10: 6364/10000 rows converged, iteration time: 0.84 seconds
Iteration 15: 7344/10000 rows converged, iteration time: 1.40 seconds
Iteration 20: 8099/10000 rows converged, iteration time: 2.01 seconds
Iteration 25: 8726/10000 rows converged, iteration time: 3.03 seconds
Iteration 30: 9209/10000 rows converged, iteration time: 3.97 seconds
merge time: 5.76 seconds
END Total time: orig 58.26 seconds ----------------------------------------
Start as_table ------------------------------------------------------------
Iteration 5: 5229/10000 rows converged, iteration time: 4.34 seconds
Iteration 10: 6347/10000 rows converged, iteration time: 4.46 seconds
Iteration 15: 7296/10000 rows converged, iteration time: 4.39 seconds
Iteration 20: 8114/10000 rows converged, iteration time: 4.58 seconds
Iteration 25: 8743/10000 rows converged, iteration time: 4.54 seconds
Iteration 30: 9208/10000 rows converged, iteration time: 5.07 seconds
merge time: 0.18 seconds
END Total time: as_table 139.36 seconds ----------------------------------------
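One more option, not included in the timings above: if the goal is simply to cut the lineage on every iteration, DataFrame.localCheckpoint() is worth trying as an extra test_type. A sketch of the extra branch (the 'local_checkpoint' name is just my own label for it):

def force_plan_execution(df, test_type):
    if test_type == 'as_table':
        df.write.saveAsTable("df_temp_view1", mode="overwrite")
        return spark.read.table("df_temp_view1")
    elif test_type == 'rdd':
        return df.rdd.toDF(df.schema)
    elif test_type == 'local_checkpoint':
        # eager=True materialises the DataFrame right away and truncates the logical plan,
        # without writing a managed table or round-tripping through the RDD API
        return df.localCheckpoint(eager=True)
    else:
        return df

As with everything else here, benchmark it on your own doComputation() and data before drawing conclusions.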