我正在使用 Polars DataFrame,需要使用其他行的值对每一行执行计算。目前,我正在使用map_elements方法,但效率不高。
在以下示例中,我向 DataFrame 添加两个新列:
这是我当前的实现:
import polars as pl
COL_VALUE = "value"
def fun_sum_lower(current_row, df):
tmp_df = df.filter(pl.col(COL_VALUE) < current_row[COL_VALUE])
sum_lower = tmp_df.select(pl.sum(COL_VALUE)).item()
return sum_lower
def fun_max_other(current_row, df):
tmp_df = df.filter(pl.col(COL_VALUE) != current_row[COL_VALUE])
max_other = tmp_df.select(pl.col(COL_VALUE)).max().item()
return max_other
if __name__ == '__main__':
df = pl.DataFrame({COL_VALUE: [3, 7, 1, 9, 4]})
df = df.with_columns(
pl.struct([COL_VALUE])
.map_elements(lambda row: fun_sum_lower(row, df), return_dtype=pl.Int64)
.alias("sum_lower")
)
df = df.with_columns(
pl.struct([COL_VALUE])
.map_elements(lambda row: fun_max_other(row, df), return_dtype=pl.Int64)
.alias("max_other")
)
print(df)
上述代码的输出为:
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 3 ┆ 1 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 1 ┆ 0 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
│ 4 ┆ 4 ┆ 9 │
└───────┴───────────┴───────────┘
虽然此代码有效,但由于使用了 lambda 和逐行操作,效率不高。
是否有更有效的方法可以在 Polars 中实现此目的,而不使用 lambda、迭代行或运行 Python 代码?
我还尝试使用 Polars 方法:
cum_sum
、group_by_dynamic
和rolling
,但我认为这些方法不能用于此任务。
.join_where()
df_sum = (
df.unique("value") # summing requires unique LHS
.join_where(df, pl.col.value > pl.col.value_right)
.group_by("value")
.sum()
)
df_max = (
df.join_where(df, pl.col.value != pl.col.value_right)
.group_by("value")
.max()
)
(df.select("value")
.join(df_sum, on="value", how="left")
.join(df_max, on="value", how="left")
)
shape: (5, 3)
┌───────┬─────────────┬───────────────────┐
│ value ┆ value_right ┆ value_right_right │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═════════════╪═══════════════════╡
│ 3 ┆ 1 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 1 ┆ null ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
│ 4 ┆ 4 ┆ 9 │
└───────┴─────────────┴───────────────────┘
尽管如此,对于这些特定任务 - 它并不能真正扩展到更大的数据。
或者,如果您可以对数据进行排序,问题本质上就变成了
.cum_sum()
+ .max()
操作(带有几个“移位”) - 这是非常快的操作。
(df
.sort("value")
.with_columns(
pl.when(pl.col.value.is_last_distinct())
.then(pl.col.value.cum_sum())
.shift()
.forward_fill()
.fill_null(0)
.alias("sum_lower")
)
.with_columns(max_other = pl.col.value.max())
.with_columns(
pl.when(pl.col.value == pl.col.max_other)
.then(pl.col.value.rle().get(-2).struct.field("value"))
.otherwise(pl.col.max_other)
.alias("max_other")
)
)
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 1 ┆ 0 ┆ 9 │
│ 3 ┆ 1 ┆ 9 │
│ 4 ┆ 4 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
└───────┴───────────┴───────────┘
如果我们在您的示例中添加一些重复的值。
df = pl.DataFrame({"value": [3, 7, 1, 9, 4, 4, 3, 3]})
(df.sort("value")
.with_columns(is_last = pl.col.value.is_last_distinct())
.with_columns(pl.when(pl.col.value.is_last_distinct()).then(pl.col.value.cum_sum()).alias("sum0"))
.with_columns(pl.col("sum0").shift().alias("sum1"))
.with_columns(pl.col("sum1").forward_fill().alias("sum2"))
)
shape: (8, 5)
┌───────┬─────────┬──────┬──────┬──────┐
│ value ┆ is_last ┆ sum0 ┆ sum1 ┆ sum2 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ i64 ┆ i64 ┆ i64 │
╞═══════╪═════════╪══════╪══════╪══════╡
│ 1 ┆ true ┆ 1 ┆ null ┆ null │
│ 3 ┆ false ┆ null ┆ 1 ┆ 1 │
│ 3 ┆ false ┆ null ┆ null ┆ 1 │
│ 3 ┆ true ┆ 10 ┆ null ┆ 1 │
│ 4 ┆ false ┆ null ┆ 10 ┆ 10 │
│ 4 ┆ true ┆ 18 ┆ null ┆ 10 │
│ 7 ┆ true ┆ 25 ┆ 18 ┆ 18 │
│ 9 ┆ true ┆ 34 ┆ 25 ┆ 25 │
└───────┴─────────┴──────┴──────┴──────┘
对数据进行排序后,
.is_last_distinct()
会为您提供每次运行/连续运行中的最后一个值,并向前移动 1 步。
然后使用 .forward_fill()
处理可能的重复值,因此它们各自获得“sum_lower”值。
“max_other”逻辑似乎是您在除最后一种情况之外的所有情况下都需要
.max()
值,即当 value == max 时。
.rle()
是一种快速方法,可为您提供一组唯一的值并保留其排序顺序。
df.sort("value").select(pl.col.value.rle())
shape: (5, 1)
┌───────────┐
│ value │
│ --- │
│ struct[2] │
╞═══════════╡
│ {1,1} │
│ {3,3} │
│ {2,4} │
│ {1,7} │ # <-
│ {1,9} │
└───────────┘
对于您的特定用例,您实际上并不需要连接,您可以使用窗口函数计算值。
pl.Expr.shift()
排除当前行。pl.Expr.cum_sum()
计算当前行之前所有元素的总和。pl.Expr.max()
计算最大值。pl.Expr.bottom_k()
计算 2 个最大元素,因此我们可以将 pl.Expr.min()
作为第二大元素。(
df
.sort("value")
.with_columns(
sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
max_other =
pl.when(pl.col.value.max() != pl.col.value)
.then(pl.col.value.max())
.otherwise(pl.col.value.bottom_k(2).min())
)
)
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 1 ┆ 0 ┆ 9 │
│ 3 ┆ 1 ┆ 9 │
│ 4 ┆ 4 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
└───────┴───────────┴───────────┘
pl.DataFrame.with_row_index()
来保留当前顺序,以便您可以在最后使用 pl.DataFrame.sort()
恢复到当前顺序。
(
df.with_row_index()
.sort("value")
.with_columns(
sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
max_other =
pl.when(pl.col.value.max() != pl.col.value)
.then(pl.col.value.max())
.otherwise(pl.col.value.bottom_k(2).min())
)
.sort("index")
.drop("index")
)
另一种可能的解决方案是使用 DuckDB 与 Polars 集成。
使用窗口函数,利用优秀的 DuckDB 窗口框架选项。
max(arg, n)
计算前 2 个最大元素。import duckdb
duckdb.sql("""
select
d.value,
coalesce(sum(d.value) over(
order by d.value
rows unbounded preceding
exclude current row
), 0) as sum_lower,
max(d.value) over(
rows between unbounded preceding and unbounded following
exclude current row
) as max_other
from df as d
""").pl()
shape: (5, 3)
┌───────┬───────────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ decimal[38,0] ┆ i64 │
╞═══════╪═══════════════╪═══════════╡
│ 1 ┆ 0 ┆ 9 │
│ 3 ┆ 1 ┆ 9 │
│ 4 ┆ 4 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
└───────┴───────────────┴───────────┘
lateral join
:
import duckdb
duckdb.sql("""
select
d.value,
coalesce(s.value, 0) as sum_lower,
m.value as max_other
from df as d,
lateral (select sum(t.value) as value from df as t where t.value < d.value) as s,
lateral (select max(t.value) as value from df as t where t.value != d.value) as m
""").pl()
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═══════╪═══════════╪═══════════╡
│ 3 ┆ 1 ┆ 9 │
│ 7 ┆ 8 ┆ 9 │
│ 1 ┆ 0 ┆ 9 │
│ 9 ┆ 15 ┆ 7 │
│ 4 ┆ 4 ┆ 9 │
└───────┴───────────┴───────────┘
重复值
如果没有重复值,上面的纯极坐标解决方案效果很好,但如果有,您也可以解决它。 以下是 2 个示例,具体取决于您是否要保留原始订单:
# not keeping original order
(
df
.select(pl.col.value.value_counts()).unnest("value")
.sort("value")
.with_columns(
sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
max_other =
pl.when(pl.col.value.max() != pl.col.value)
.then(pl.col.value.max())
.otherwise(pl.col.value.bottom_k(2).min()),
value = pl.col.value.repeat_by("count")
).drop("count").explode("value")
)
# keeping original order
(
df.with_row_index()
.group_by("value").agg("index")
.sort("value")
.with_columns(
sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
max_other =
pl.when(pl.col.value.max() != pl.col.value)
.then(pl.col.value.max())
.otherwise(pl.col.value.bottom_k(2).min())
)
.explode("index")
.sort("index")
.drop("index")
)