考虑以下数据框:
df = pl.DataFrame({
"letters": ["A", "B", "C", "D", "E", "F", "G", "H"],
"values": ["aa", "bb", "cc", "dd", "ee", "ff", "gg", "hh"]
})
print(df)
shape: (8, 2)
┌─────────┬────────┐
│ letters ┆ values │
│ --- ┆ --- │
│ str ┆ str │
╞═════════╪════════╡
│ A ┆ aa │
│ B ┆ bb │
│ C ┆ cc │
│ D ┆ dd │
│ E ┆ ee │
│ F ┆ ff │
│ G ┆ gg │
│ H ┆ hh │
└─────────┴────────┘
如何在满足给定条件的任何行周围选取大小为 +/- N 的窗口?例如,条件是
pl.col("letters").contains("D|F")
和N = 2
。那么,输出应该是:
┌─────────┬────────────────────────────────┐
│ letters ┆ output │
│ --- ┆ --- │
│ str ┆ list[str] │
╞═════════╪════════════════════════════════╡
│ D ┆ ["bb", "cc", "dd", "ee", "ff"] │
│ F ┆ ["dd", "ee", "ff", "gg", "hh"] │
└─────────┴────────────────────────────────┘
请注意,在这种情况下,窗口是重叠的(
F
窗口还包含dd
,D
窗口还包含ff
)。另外,请注意,为了简单起见,此处 N = 2,但实际上,它会更大(~10 - 20)。而且数据集相对较大,因此我希望尽可能高效地完成此操作,而不会导致内存使用量激增。
编辑:为了使问题更明确,这里是 DuckDB 的 SQL 语法中的查询,它给出了正确的答案(我想知道如何将其转换为 Polars):
df_table = df.to_arrow()
con = duckdb.connect()
query = """
SELECT
letters,
list(values) OVER (
ROWS BETWEEN 2 PRECEDING
AND 2 FOLLOWING
) as combined
FROM df_table
QUALIFY letters in ('D', 'F')
"""
print(pl.from_arrow(con.execute(query).arrow()))
shape: (2, 2)
┌─────────┬────────────────────────┐
│ letters ┆ combined │
│ --- ┆ --- │
│ str ┆ list[str] │
╞═════════╪════════════════════════╡
│ D ┆ ["bb", "cc", ... "ff"] │
│ F ┆ ["dd", "ee", ... "hh"] │
└─────────┴────────────────────────┘
我在 Amazon 的一台
ml.c5.xlarge
机器上的 Jupyter 笔记本中运行了建议的解决方案。当笔记本电脑运行时,我还在终端中保持 htop
打开,以观察 CPU 和内存的使用情况。数据集有 12M+ 行。
我通过 eager 和惰性 API 运行了这两种解决方案。为了更好地衡量,我还尝试使用简单的 Python for 循环在识别感兴趣的行以及 DuckDB 后提取切片。
Polars 具有非常强大的性能和明智的内存使用(使用 @jqurious' 方法),因为
.shift()
的巧妙、无复制实现。令人惊讶的是,一个经过深思熟虑的 Python for 循环也能起到同样的效果。 DuckDB 在速度和内存使用方面都表现得相当差。
Polars 和 DuckDB 都不使用多个核心进行操作。不确定这是否是由于缺乏优化造成的,或者这个问题是否适合并行化。我想我们只过滤一列,然后获取同一列的切片,因此没有太多多个线程可以做。
方法 | CPU使用 | 内存使用 | 时间 |
---|---|---|---|
ΩΠΟΚΚΚΡΥΜΜΕΝΟΣ | 单核 | 爆炸 | |
我很好奇 | 单核 | 2.53G 至 2.53G | 4.63秒 |
(智能)for循环 | 单核 | 2.53G至2.58G | 4.91秒 |
鸭数据库 | 单核 | 1.62G至6.13G | 38.6秒 |
preceding = 2
following = 2
look_around = [pl.col("body").shift(-i)
for i in range(-preceding, following + 1)]
(
df
.with_columns(
pl.when(pl.col('body').str.contains(regex))
.then(pl.concat_list(look_around))
.alias('combined')
)
.filter(pl.col('combined').is_not_null())
)
不幸的是,在我相当大的数据集上,这个解决方案导致内存使用爆炸,并且内核因急切 API 和惰性 API 崩溃。
preceding = 2
following = 2
look_around = [
pl.col("body").shift(-i).alias(f"lag_{i}") for i in range(-preceding, following + 1)
]
(
df
.with_columns(
look_around
)
.filter(pl.col("body").str.contains(regex))
.select(
pl.col("body"),
pl.concat_list([f"lag_{i}" for i in range(-2, 3)]).alias("output")
)
)
渴望:
懒:
preceding = 2
following = 2
output = []
indices = df.with_row_index().select(
pl.col("index").filter(pl.col("body").str.contains(regex))
)["index"]
for idx, x in enumerate(indices):
offset = max(0, x - preceding)
length = preceding + following + 1
output.append(df["body"].slice(offset, length))
请注意,在运行查询之前,我首先将
df
转换为 Arrow.Table
,以便 DuckDB 可以直接对其进行操作。另外,我不确定将结果转换回 Arrow 是否会占用大量计算量并且对它不公平。
preceding = 2
following = 2
query = f"""
SELECT
body,
list(body) OVER (
ROWS BETWEEN {preceding} PRECEDING
AND {following} FOLLOWING
) as combined
FROM df_table
QUALIFY regexp_matches(body, '{regex}')
"""
result = con.execute(query).arrow()
使用 DuckDB,我第一次尝试运行计算失败了。我必须通过直接读取 Arrow Table 来重试,而不使用 Polars(这节省了大约 1GB 内存),以便为 DuckDB 提供更多内存可供使用。
第一次尝试:
第二次尝试:
>>> (
... df
... .with_columns(
... [pl.col("values").shift(i).alias(f"lag_{i}") for i in range(-2, 3)])
... .filter(pl.col("letters").str.contains("D|F"))
... .select([
... pl.col("letters"),
... pl.concat_list(reversed([f"lag_{i}" for i in range(-2, 3)])).alias("output")
... ])
... )
shape: (2, 2)
┌─────────┬────────────────────────────────┐
│ letters | output │
│ --- | --- │
│ str | list[str] │
╞═════════╪════════════════════════════════╡
│ D | ["bb", "cc", "dd", "ee", "ff"] │
├─────────┼────────────────────────────────┤
│ F | ["dd", "ee", "ff", "gg", "hh"] │
└─//──────┴─//─────────────────────────────┘
你可以试试这个:
preceding = 2
following = 2
look_around = [pl.col("values").shift(-i)
for i in range(-preceding, following + 1)]
(
df
.with_column(
pl.when(pl.col('letters').str.contains('D|F'))
.then(pl.concat_list(look_around))
.alias('combined')
)
.filter(pl.col('combined').is_not_null())
)
shape: (2, 3)
┌─────────┬────────┬────────────────────────┐
│ letters ┆ values ┆ combined │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ list[str] │
╞═════════╪════════╪════════════════════════╡
│ D ┆ dd ┆ ["bb", "cc", ... "ff"] │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ F ┆ ff ┆ ["dd", "ee", ... "hh"] │
└─────────┴────────┴────────────────────────┘