如何高效地在其他行上运行计算

问题描述 投票:0回答:2

我正在使用 Polars DataFrame,需要使用其他行的值对每一行执行计算。目前,我正在使用map_elements方法,但效率不高。

在以下示例中,我向 DataFrame 添加两个新列:

  1. sum_lower:所有小于当前元素的元素之和。
  2. max_other:DataFrame 中的最大值,不包括当前元素。

这是我当前的实现:

import polars as pl

COL_VALUE = "value"

def fun_sum_lower(current_row, df):
    tmp_df = df.filter(pl.col(COL_VALUE) < current_row[COL_VALUE])
    sum_lower = tmp_df.select(pl.sum(COL_VALUE)).item()
    return sum_lower

def fun_max_other(current_row, df):
    tmp_df = df.filter(pl.col(COL_VALUE) != current_row[COL_VALUE])
    max_other = tmp_df.select(pl.col(COL_VALUE)).max().item()
    return max_other

if __name__ == '__main__':
    df = pl.DataFrame({COL_VALUE: [3, 7, 1, 9, 4]})

    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_sum_lower(row, df), return_dtype=pl.Int64)
        .alias("sum_lower")
    )

    df = df.with_columns(
        pl.struct([COL_VALUE])
        .map_elements(lambda row: fun_max_other(row, df), return_dtype=pl.Int64)
        .alias("max_other")
    )

    print(df)

上述代码的输出为:

shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ ---   ┆ ---       ┆ ---       │
│ i64   ┆ i64       ┆ i64       │
╞═══════╪═══════════╪═══════════╡
│ 3     ┆ 1         ┆ 9         │
│ 7     ┆ 8         ┆ 9         │
│ 1     ┆ 0         ┆ 9         │
│ 9     ┆ 15        ┆ 7         │
│ 4     ┆ 4         ┆ 9         │
└───────┴───────────┴───────────┘

虽然此代码有效,但由于使用了 lambda 和逐行操作,效率不高。

是否有更有效的方法可以在 Polars 中实现此目的,而不使用 lambda、迭代行或运行 Python 代码?

我还尝试使用 Polars 方法:

cum_sum
group_by_dynamic
rolling
,但我认为这些方法不能用于此任务。

python dataframe python-polars
2个回答
6
投票

您可以尝试非等值连接:

.join_where()

df_sum = (
   df.unique("value") # summing requires unique LHS
     .join_where(df, pl.col.value > pl.col.value_right)
     .group_by("value")
     .sum()
)

df_max = (
   df.join_where(df, pl.col.value != pl.col.value_right)
     .group_by("value")
     .max()
)

(df.select("value")
   .join(df_sum, on="value", how="left")
   .join(df_max, on="value", how="left")
)
shape: (5, 3)
┌───────┬─────────────┬───────────────────┐
│ value ┆ value_right ┆ value_right_right │
│ ---   ┆ ---         ┆ ---               │
│ i64   ┆ i64         ┆ i64               │
╞═══════╪═════════════╪═══════════════════╡
│ 3     ┆ 1           ┆ 9                 │
│ 7     ┆ 8           ┆ 9                 │
│ 1     ┆ null        ┆ 9                 │
│ 9     ┆ 15          ┆ 7                 │
│ 4     ┆ 4           ┆ 9                 │
└───────┴─────────────┴───────────────────┘

尽管如此,对于这些特定任务 - 它并不能真正扩展到更大的数据。


或者,如果您可以对数据进行排序,问题本质上就变成了

.cum_sum()
+
.max()
操作(带有几个“移位”) - 这是非常快的操作。

(df
  .sort("value")
  .with_columns(
     pl.when(pl.col.value.is_last_distinct())
       .then(pl.col.value.cum_sum())
       .shift()
       .forward_fill()
       .fill_null(0)
       .alias("sum_lower")
  )
  .with_columns(max_other = pl.col.value.max())
  .with_columns(
     pl.when(pl.col.value == pl.col.max_other)
       .then(pl.col.value.rle().get(-2).struct.field("value"))
       .otherwise(pl.col.max_other)
       .alias("max_other")
  )
)
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ ---   ┆ ---       ┆ ---       │
│ i64   ┆ i64       ┆ i64       │
╞═══════╪═══════════╪═══════════╡
│ 1     ┆ 0         ┆ 9         │
│ 3     ┆ 1         ┆ 9         │
│ 4     ┆ 4         ┆ 9         │
│ 7     ┆ 8         ┆ 9         │
│ 9     ┆ 15        ┆ 7         │
└───────┴───────────┴───────────┘

发生了什么事?

如果我们在您的示例中添加一些重复的值。

df = pl.DataFrame({"value": [3, 7, 1, 9, 4, 4, 3, 3]})

(df.sort("value")
   .with_columns(is_last = pl.col.value.is_last_distinct())
   .with_columns(pl.when(pl.col.value.is_last_distinct()).then(pl.col.value.cum_sum()).alias("sum0"))
   .with_columns(pl.col("sum0").shift().alias("sum1"))
   .with_columns(pl.col("sum1").forward_fill().alias("sum2"))
)
shape: (8, 5)
┌───────┬─────────┬──────┬──────┬──────┐
│ value ┆ is_last ┆ sum0 ┆ sum1 ┆ sum2 │
│ ---   ┆ ---     ┆ ---  ┆ ---  ┆ ---  │
│ i64   ┆ bool    ┆ i64  ┆ i64  ┆ i64  │
╞═══════╪═════════╪══════╪══════╪══════╡
│ 1     ┆ true    ┆ 1    ┆ null ┆ null │
│ 3     ┆ false   ┆ null ┆ 1    ┆ 1    │
│ 3     ┆ false   ┆ null ┆ null ┆ 1    │
│ 3     ┆ true    ┆ 10   ┆ null ┆ 1    │
│ 4     ┆ false   ┆ null ┆ 10   ┆ 10   │
│ 4     ┆ true    ┆ 18   ┆ null ┆ 10   │
│ 7     ┆ true    ┆ 25   ┆ 18   ┆ 18   │
│ 9     ┆ true    ┆ 34   ┆ 25   ┆ 25   │
└───────┴─────────┴──────┴──────┴──────┘

对数据进行排序后,

.is_last_distinct()
会为您提供每次运行/连续运行中的最后一个值,并向前移动 1 步。

然后使用

.forward_fill()
处理可能的重复值,因此它们各自获得“sum_lower”值。


“max_other”逻辑似乎是您在除最后一种情况之外的所有情况下都需要

.max()
值,即当 value == max 时。

对于排序数据,

.rle()
是一种快速方法,可为您提供一组唯一的值并保留其排序顺序。

df.sort("value").select(pl.col.value.rle())
shape: (5, 1)
┌───────────┐
│ value     │
│ ---       │
│ struct[2] │
╞═══════════╡
│ {1,1}     │
│ {3,3}     │
│ {2,4}     │
│ {1,7}     │ # <-
│ {1,9}     │
└───────────┘

4
投票

对于您的特定用例,您实际上并不需要连接,您可以使用窗口函数计算值。

(
    df
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
        pl.when(pl.col.value.max() != pl.col.value)
        .then(pl.col.value.max())
        .otherwise(pl.col.value.bottom_k(2).min()) 
    )
)
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ ---   ┆ ---       ┆ ---       │
│ i64   ┆ i64       ┆ i64       │
╞═══════╪═══════════╪═══════════╡
│ 1     ┆ 0         ┆ 9         │
│ 3     ┆ 1         ┆ 9         │
│ 4     ┆ 4         ┆ 9         │
│ 7     ┆ 8         ┆ 9         │
│ 9     ┆ 15        ┆ 7         │
└───────┴───────────┴───────────┘

您还可以使用

pl.DataFrame.with_row_index()
来保留当前顺序,以便您可以在最后使用
pl.DataFrame.sort()
恢复到当前顺序。

(
    df.with_row_index()
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
        pl.when(pl.col.value.max() != pl.col.value)
        .then(pl.col.value.max())
        .otherwise(pl.col.value.bottom_k(2).min()) 
    )
    .sort("index")
    .drop("index")
)

另一种可能的解决方案是使用 DuckDB 与 Polars 集成

使用窗口函数,利用优秀的 DuckDB 窗口框架选项

import duckdb

duckdb.sql("""
    select
        d.value,
        coalesce(sum(d.value) over(
            order by d.value
            rows unbounded preceding
            exclude current row
        ), 0) as sum_lower,
        max(d.value) over(
            rows between unbounded preceding and unbounded following
            exclude current row
        ) as max_other
    from df as d
""").pl()
shape: (5, 3)
┌───────┬───────────────┬───────────┐
│ value ┆ sum_lower     ┆ max_other │
│ ---   ┆ ---           ┆ ---       │
│ i64   ┆ decimal[38,0] ┆ i64       │
╞═══════╪═══════════════╪═══════════╡
│ 1     ┆ 0             ┆ 9         │
│ 3     ┆ 1             ┆ 9         │
│ 4     ┆ 4             ┆ 9         │
│ 7     ┆ 8             ┆ 9         │
│ 9     ┆ 15            ┆ 7         │
└───────┴───────────────┴───────────┘

或使用

lateral join
:

import duckdb

duckdb.sql("""
    select
        d.value,
        coalesce(s.value, 0) as sum_lower,
        m.value as max_other
    from df as d,
        lateral (select sum(t.value) as value from df as t where t.value < d.value) as s,
        lateral (select max(t.value) as value from df as t where t.value != d.value) as m
""").pl()
shape: (5, 3)
┌───────┬───────────┬───────────┐
│ value ┆ sum_lower ┆ max_other │
│ ---   ┆ ---       ┆ ---       │
│ i64   ┆ i64       ┆ i64       │
╞═══════╪═══════════╪═══════════╡
│ 3     ┆ 1         ┆ 9         │
│ 7     ┆ 8         ┆ 9         │
│ 1     ┆ 0         ┆ 9         │
│ 9     ┆ 15        ┆ 7         │
│ 4     ┆ 4         ┆ 9         │
└───────┴───────────┴───────────┘

重复值

如果没有重复值,上面的纯极坐标解决方案效果很好,但如果有,您也可以解决它。 以下是 2 个示例,具体取决于您是否要保留原始订单:

# not keeping original order
(
    df
    .select(pl.col.value.value_counts()).unnest("value")
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
        pl.when(pl.col.value.max() != pl.col.value)
        .then(pl.col.value.max())
        .otherwise(pl.col.value.bottom_k(2).min()),
        value = pl.col.value.repeat_by("count")
    ).drop("count").explode("value")
)
# keeping original order
(
    df.with_row_index()
    .group_by("value").agg("index")
    .sort("value")
    .with_columns(
        sum_lower = pl.col.value.shift(1).cum_sum().fill_null(0),
        max_other =
        pl.when(pl.col.value.max() != pl.col.value)
        .then(pl.col.value.max())
        .otherwise(pl.col.value.bottom_k(2).min()) 
    )
    .explode("index")
    .sort("index")
    .drop("index")
)
© www.soinside.com 2019 - 2024. All rights reserved.