Python-polars:rolling_sum,其中window_size位于另一列中

问题描述 投票:0回答:2

考虑以下数据框:

df = pl.DataFrame(
    {
        "date": pl.date_range(
            pl.date(2023, 2, 1),
            pl.date(2023, 2, 5),
            interval="1d",
            eager=True),
        "periods": [2, 2, 2, 1, 1],
        "quantity": [10, 12, 14, 16, 18],
        "calculate": [22, 26, 30, 16, 18]
    }
)

列计算就是我想要的。这是通过 Rolling_sum 完成的,其中

window_size
参数取自
periods
列,而不是固定值。

我可以执行以下操作(window_size = 2):

df.select(pl.col("quantity").rolling_sum(window_size=2))

但是,当我尝试执行此操作时出现错误:

df.select(pl.col("quantity").rolling_sum(window_size=pl.col("periods")))

这是错误

TypeError: argument 'window_size': 'Expr' object cannot be converted to 'PyString'

如何根据另一列传递

window_size
的值?我也看过使用
rolling
但也无法弄清楚。

python python-polars
2个回答
1
投票

与@jqurious 非常相似,但(我认为)有点简化

df.lazy() \
    .with_row_count('i') \
    .with_columns(
        window = 
            pl.arange(
                pl.col("i"), 
                pl.col("i") + pl.col("periods")),
            qty=pl.col('quantity').list()
) \
.with_columns(
    rollsum=pl.col('qty').arr.take(pl.col('window')).arr.sum()
) \
.select(pl.exclude(['window','qty','i'])) \
.collect()

它的工作原理相同,但它本质上只是将整个

quantity
列重新创建为列表,然后使用
window
列将该列表过滤为相应的值并对它们进行求和。

另一种方法是仅使用循环,这将提高内存效率。

首先,您想要获取周期的所有唯一值,然后在 df 中初始化滚动总和的列,反转顺序,然后用每个周期的计算替换该列。 最后,将行放回原来的顺序。

periods=df.get_column('periods').unique()
df=df.with_columns(pl.lit(None).cast(pl.Float64()).alias("rollsum")).sort('date',reverse=True)
for period in periods:
    df=df.with_columns((pl.when(pl.col('periods')==period).then(pl.col('quantity').rolling_sum(window_size=period)).otherwise(pl.col('rollsum'))).alias('rollsum'))
df=df.sort('date')
df

1
投票

看起来这应该更容易做到,这表明我可能错过了一些明显的东西。

您可以为每个窗口生成索引范围,例如使用

.int_ranges()

(df.with_row_index()
   .with_columns(
      pl.int_ranges(pl.col.index, pl.col.index + pl.col.periods)
        .alias("window")
   )
)   
shape: (5, 6)
┌───────┬────────────┬─────────┬──────────┬───────────┬───────────┐
│ index ┆ date       ┆ periods ┆ quantity ┆ calculate ┆ window    │
│ ---   ┆ ---        ┆ ---     ┆ ---      ┆ ---       ┆ ---       │
│ u32   ┆ date       ┆ i64     ┆ i64      ┆ i64       ┆ list[i64] │
╞═══════╪════════════╪═════════╪══════════╪═══════════╪═══════════╡
│ 0     ┆ 2023-02-01 ┆ 2       ┆ 10       ┆ 22        ┆ [0, 1]    │
│ 1     ┆ 2023-02-02 ┆ 2       ┆ 12       ┆ 26        ┆ [1, 2]    │
│ 2     ┆ 2023-02-03 ┆ 2       ┆ 14       ┆ 30        ┆ [2, 3]    │
│ 3     ┆ 2023-02-04 ┆ 1       ┆ 16       ┆ 16        ┆ [3]       │
│ 4     ┆ 2023-02-05 ┆ 1       ┆ 18       ┆ 18        ┆ [4]       │
└───────┴────────────┴─────────┴──────────┴───────────┴───────────┘

爆炸后,

.search_sorted()
可用于查找具有该索引的行的位置。

然后使用

.get()
从该行中提取
quantity
值。

.group_by
用于恢复原来的框架形状。

(df.with_row_index()
   .with_columns(
      pl.int_ranges(pl.col.index, pl.col.index + pl.col.periods)
        .alias("window")
   )
   .explode("window")
   .with_columns(
      pl.col.quantity.get(
         pl.col.index.search_sorted(pl.col.window)
      )
      .alias("total")
   )
   .group_by("index", maintain_order=True)
   .agg(
      pl.exclude("total").first(), 
      pl.col.total.sum()
   )
)
shape: (5, 7)
┌───────┬────────────┬─────────┬──────────┬───────────┬────────┬───────┐
│ index ┆ date       ┆ periods ┆ quantity ┆ calculate ┆ window ┆ total │
│ ---   ┆ ---        ┆ ---     ┆ ---      ┆ ---       ┆ ---    ┆ ---   │
│ u32   ┆ date       ┆ i64     ┆ i64      ┆ i64       ┆ i64    ┆ i64   │
╞═══════╪════════════╪═════════╪══════════╪═══════════╪════════╪═══════╡
│ 0     ┆ 2023-02-01 ┆ 2       ┆ 10       ┆ 22        ┆ 0      ┆ 22    │
│ 1     ┆ 2023-02-02 ┆ 2       ┆ 12       ┆ 26        ┆ 1      ┆ 26    │
│ 2     ┆ 2023-02-03 ┆ 2       ┆ 14       ┆ 30        ┆ 2      ┆ 30    │
│ 3     ┆ 2023-02-04 ┆ 1       ┆ 16       ┆ 16        ┆ 3      ┆ 16    │
│ 4     ┆ 2023-02-05 ┆ 1       ┆ 18       ┆ 18        ┆ 4      ┆ 18    │
└───────┴────────────┴─────────┴──────────┴───────────┴────────┴───────┘
© www.soinside.com 2019 - 2024. All rights reserved.