Polars 相当于 groupby 聚合中的 Pandas idxmin 是多少?

问题描述 投票:0回答:2

我正在寻找在 group_by agg 操作中相当于 Pandas idxmin 的 Polars

使用 Polars 和此示例数据框:

import polars as pl

dct = {
    "a": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    "b": [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5],
    "c": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
    "d": [0.18418279, 0.67394382, 0.80951643, 0.10115085, 0.03998497, 0.17771175, 0.28428486,     0.24553192, 0.31388881, 0.07525366, 0.28622033, 0.61240989]
}

df = pl.DataFrame(dct)
a   b   c   d
i64 i64 i64 f64
0   0   0   0.184183
0   0   1   0.673944
0   1   2   0.809516
0   1   3   0.101151
1   2   4   0.039985
1   2   5   0.177712
1   3   6   0.284285
1   3   7   0.245532
2   4   8   0.313889
2   4   9   0.075254
2   5   10  0.28622
2   5   11  0.61241

和group_by操作:

df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').min())
shape: (6, 3)
a   b   d
i64 i64 f64
0   0   0.184183
0   1   0.101151
1   2   0.039985
1   3   0.245532
2   4   0.075254
2   5   0.28622

我需要原始数据帧 c 列中的值与 d 列中的聚合值相对应。

在 Pandas 中我可以通过以下方式得到结果:

import pandas as pd
df2 = pd.DataFrame(dct)
df2.iloc[df2.groupby(["a", "b"]).agg({"d": "idxmin"}).d]
    a   b   c   d
0   0   0   0   0.184183
3   0   1   3   0.101151
4   1   2   4   0.039985
7   1   3   7   0.245532
9   2   4   9   0.075254
10  2   5   10  0.286220

我在 Polars 中没有找到 idxmin 的等效函数。有 arg_min,但它没有给出所需的结果,如以下输出所示:

df.group_by(["a", "b"], maintain_order=True).agg(pl.col('d').arg_min())
shape: (6, 3)
a   b   d
i64 i64 u32
0   0   0
0   1   1
1   2   0
1   3   1
2   4   1
2   5   0

实际的数据帧有几百万行,并且在 Pandas 中需要很长时间(分钟)来计算。 如何在 Polars 中高效地达到相同的结果?

group-by python-polars
2个回答
3
投票

您可以使用

Expr.get
按索引从列中获取。

df.group_by(["a", "b"], maintain_order=True).agg(
    c=pl.col("c").get(pl.col("d").arg_min()), d=pl.col("d").min(),
)
┌─────┬─────┬─────┬──────────┐
│ a   ┆ b   ┆ c   ┆ d        │
│ --- ┆ --- ┆ --- ┆ ---      │
│ i64 ┆ i64 ┆ i64 ┆ f64      │
╞═════╪═════╪═════╪══════════╡
│ 0   ┆ 0   ┆ 0   ┆ 0.184183 │
│ 0   ┆ 1   ┆ 3   ┆ 0.101151 │
│ 1   ┆ 2   ┆ 4   ┆ 0.039985 │
│ 1   ┆ 3   ┆ 7   ┆ 0.245532 │
│ 2   ┆ 4   ┆ 9   ┆ 0.075254 │
│ 2   ┆ 5   ┆ 10  ┆ 0.28622  │
└─────┴─────┴─────┴──────────┘

2
投票

你可以做

df.with_row_count().group_by("a", "b", maintain_order=True).agg(
    pl.col("d").min(),
    pl.col("row_nr").filter(pl.col("d") == pl.col("d").min()).first()
)

这给出了

shape: (6, 4)
┌─────┬─────┬──────────┬────────┐
│ a   ┆ b   ┆ d        ┆ row_nr │
│ --- ┆ --- ┆ ---      ┆ ---    │
│ i64 ┆ i64 ┆ f64      ┆ u32    │
╞═════╪═════╪══════════╪════════╡
│ 0   ┆ 0   ┆ 0.184183 ┆ 0      │
│ 0   ┆ 1   ┆ 0.101151 ┆ 3      │
│ 1   ┆ 2   ┆ 0.039985 ┆ 4      │
│ 1   ┆ 3   ┆ 0.245532 ┆ 7      │
│ 2   ┆ 4   ┆ 0.075254 ┆ 9      │
│ 2   ┆ 5   ┆ 0.28622  ┆ 10     │
└─────┴─────┴──────────┴────────┘
© www.soinside.com 2019 - 2024. All rights reserved.