考虑以下示例:
import polars as pl
df = pl.DataFrame(
[
pl.Series(
"name", ["A", "B", "C", "D"], dtype=pl.Enum(["A", "B", "C", "D"])
),
pl.Series("month", [1, 2, 12, 1], dtype=pl.Int8()),
pl.Series(
"category", ["x", "x", "y", "z"], dtype=pl.Enum(["x", "y", "z"])
),
]
)
print(df)
shape: (4, 3)
┌──────┬───────┬──────────┐
│ name ┆ month ┆ category │
│ --- ┆ --- ┆ --- │
│ enum ┆ i8 ┆ enum │
╞══════╪═══════╪══════════╡
│ A ┆ 1 ┆ x │
│ B ┆ 2 ┆ x │
│ C ┆ 12 ┆ y │
│ D ┆ 1 ┆ z │
└──────┴───────┴──────────┘
我们可以计算数据框中与一年中每个月匹配的月份数:
from math import inf
binned_df = (
df.select(
pl.col.month.hist(
bins=[x + 1 for x in range(11)],
include_breakpoint=True,
).alias("binned"),
)
.unnest("binned")
.with_columns(
pl.col.breakpoint.map_elements(
lambda x: 12 if x == inf else x, return_dtype=pl.Float64()
)
.cast(pl.Int8())
.alias("month")
)
.drop("breakpoint")
.select("month", "count")
)
print(binned_df)
shape: (12, 2)
┌───────┬───────┐
│ month ┆ count │
│ --- ┆ --- │
│ i8 ┆ u32 │
╞═══════╪═══════╡
│ 1 ┆ 2 │
│ 2 ┆ 1 │
│ 3 ┆ 0 │
│ 4 ┆ 0 │
│ 5 ┆ 0 │
│ … ┆ … │
│ 8 ┆ 0 │
│ 9 ┆ 0 │
│ 10 ┆ 0 │
│ 11 ┆ 0 │
│ 12 ┆ 1 │
└───────┴───────┘
(注意:有 3 个类别
"x"
、"y"
和 "z"
,因此我们期望形状为 12 x 3 = 36 的数据框。)
假设我想对每列的数据进行分类
"category"
。我可以做以下事情:
# initialize an empty dataframe
category_binned_df = pl.DataFrame()
for cat in df["category"].unique():
# repeat the binning logic from earlier, except on a dataframe filtered for
# the particular category we are iterating over
binned_df = (
df.filter(pl.col.category.eq(cat)) # <--- the filter
.select(
pl.col.month.hist(
bins=[x + 1 for x in range(11)],
include_breakpoint=True,
).alias("binned"),
)
.unnest("binned")
.with_columns(
pl.col.breakpoint.map_elements(
lambda x: 12 if x == inf else x, return_dtype=pl.Float64()
)
.cast(pl.Int8())
.alias("month")
)
.drop("breakpoint")
.select("month", "count")
.with_columns(category=pl.lit(cat).cast(df["category"].dtype))
)
# finally, vstack ("append") the resulting dataframe
category_binned_df = category_binned_df.vstack(binned_df)
print(category_binned_df)
shape: (36, 3)
┌───────┬───────┬──────────┐
│ month ┆ count ┆ category │
│ --- ┆ --- ┆ --- │
│ i8 ┆ u32 ┆ enum │
╞═══════╪═══════╪══════════╡
│ 1 ┆ 1 ┆ x │
│ 2 ┆ 1 ┆ x │
│ 3 ┆ 0 ┆ x │
│ 4 ┆ 0 ┆ x │
│ 5 ┆ 0 ┆ x │
│ … ┆ … ┆ … │
│ 8 ┆ 0 ┆ z │
│ 9 ┆ 0 ┆ z │
│ 10 ┆ 0 ┆ z │
│ 11 ┆ 0 ┆ z │
│ 12 ┆ 1 ┆ z │
└───────┴───────┴──────────┘
over
来做到这一点,比如 pl.col.month.hist(bins=...).over("category")
,但尝试这样做的第一步会引发错误:
df.select(
pl.col.month.hist(
bins=[x + 1 for x in range(11)],
include_breakpoint=True,
)
.over("category")
.alias("binned"),
)
ComputeError: the length of the window expression did not match that of the group
Error originated in expression: 'col("month").hist([Series]).over([col("category")])'
那么,当我想到
over
时,我犯了某种概念性错误吗?有没有办法在这里使用over
?
嗯,用groupbyexplode也可以,但是不知道效率如何
df.group_by("category").agg(
# one row per category with the whole list as the element
pl.col.month.hist(
bins=[x + 1 for x in range(11)],
include_breakpoint=True,
)
).explode("month").unnest("month").with_columns(
pl.col("breakpoint").replace(float("inf"), 12.0).cast(int)
)
over
类似于 transform
中的 pandas
,因为它保留了行的原始标识,因此在这里使用它会很奇怪 - 您尝试执行的操作是 groupby-agg,没有连接回原始表行。 (也就是说,我不确定为什么 over
在这里不做 某事;它失败了,这似乎很奇怪。)