我在 Polars 中有一个
train_test_split
函数,可以处理急切的 DataFrame。我希望编写一个等效函数,它可以将 LazyFrame 作为输入并返回两个 LazyFrame,而不对其进行评估。
我的功能如下。它会打乱所有行,然后根据全帧的高度使用行索引将其分割。
def train_test_split(
df: pl.DataFrame, train_fraction: float = 0.75
) -> tuple[pl.DataFrame, pl.DataFrame]:
"""Split polars dataframe into two sets.
Args:
df (pl.DataFrame): Dataframe to split
train_fraction (float, optional): Fraction that goes to train. Defaults to 0.75.
Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Tuple of train and test dataframes
"""
df = df.with_columns(pl.all().shuffle(seed=1))
split_index = int(train_fraction * df.height)
df_train = df[:split_index]
df_test = df[split_index:]
return df_train, df_test
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [4, 3, 2, 1]})
train, test = train_test_split(df)
# this is what the above looks like:
train = pl.DataFrame({'a': [2, 3, 4], 'b': [3, 2, 1]})
test = pl.DataFrame({'a': [1], 'b': [4]})
然而,Lazyframes 的高度未知,因此我们必须采用另一种方式。我有两个想法,但都遇到了问题:
df.sample(frac=train_fraction, with_replacement=False, shuffle=False)
。这样我可以获得火车部分,但无法获得测试部分。.apply
相当于 np.random.uniform
,这会非常耗时。.with_row_count()
并过滤大于总行数的一部分的行,但这里我还需要高度,并且创建行数可能会很昂贵。最后,我可能会以错误的方式处理这个问题:我可以事先计算总行数,但我不知道这被认为有多昂贵。
这里有一个大数据框可供测试(需要约 1 秒)以急切地运行我的函数:
N = 50_000_000
df_big = pl.DataFrame(
[
pl.arange(0, N, eager=True),
pl.arange(0, N, eager=True),
pl.arange(0, N, eager=True),
pl.arange(0, N, eager=True),
pl.arange(0, N, eager=True),
],
schema=["a", "b", "c", "d", "e"],
)
这是在惰性模式下使用 Polars 进行此操作的一种方法 with_row_index:
def train_test_split_lazy(
df: pl.DataFrame, train_fraction: float = 0.75
) -> tuple[pl.DataFrame, pl.DataFrame]:
"""Split polars dataframe into two sets.
Args:
df (pl.DataFrame): Dataframe to split
train_fraction (float, optional): Fraction that goes to train. Defaults to 0.75.
Returns:
Tuple[pl.DataFrame, pl.DataFrame]: Tuple of train and test dataframes
"""
df = df.with_columns(pl.all().shuffle(seed=1)).with_row_index()
df_train = df.filter(pl.col("index") < pl.col("index").max() * train_fraction)
df_test = df.filter(pl.col("index") >= pl.col("index").max() * train_fraction)
return df_train, df_test
然后:
df_big = pl.DataFrame(
[
pl.int_range(N, eager=True),
pl.int_range(N, eager=True),
pl.int_range(N, eager=True),
pl.int_range(N, eager=True),
pl.int_range(N, eager=True),
],
schema=["a", "b", "c", "d", "e"],
).lazy()
train, test = train_test_split_lazy(df_big)
print(train.collect())
print(test.collect())
shape: (37_500_000, 6)
┌──────────┬──────────┬──────────┬──────────┬──────────┬──────────┐
│ row_nr ┆ a ┆ b ┆ c ┆ d ┆ e │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡
│ 0 ┆ 27454110 ┆ 27454110 ┆ 27454110 ┆ 27454110 ┆ 27454110 │
│ 1 ┆ 2309916 ┆ 2309916 ┆ 2309916 ┆ 2309916 ┆ 2309916 │
│ 2 ┆ 15065100 ┆ 15065100 ┆ 15065100 ┆ 15065100 ┆ 15065100 │
│ 3 ┆ 12766444 ┆ 12766444 ┆ 12766444 ┆ 12766444 ┆ 12766444 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 37499996 ┆ 40732880 ┆ 40732880 ┆ 40732880 ┆ 40732880 ┆ 40732880 │
│ 37499997 ┆ 32447037 ┆ 32447037 ┆ 32447037 ┆ 32447037 ┆ 32447037 │
│ 37499998 ┆ 41754221 ┆ 41754221 ┆ 41754221 ┆ 41754221 ┆ 41754221 │
│ 37499999 ┆ 7019133 ┆ 7019133 ┆ 7019133 ┆ 7019133 ┆ 7019133 │
└──────────┴──────────┴──────────┴──────────┴──────────┴──────────┘
shape: (12_500_000, 6)
┌──────────┬──────────┬──────────┬──────────┬──────────┬──────────┐
│ row_nr ┆ a ┆ b ┆ c ┆ d ┆ e │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡
│ 37500000 ┆ 29107559 ┆ 29107559 ┆ 29107559 ┆ 29107559 ┆ 29107559 │
│ 37500001 ┆ 26750366 ┆ 26750366 ┆ 26750366 ┆ 26750366 ┆ 26750366 │
│ 37500002 ┆ 17450938 ┆ 17450938 ┆ 17450938 ┆ 17450938 ┆ 17450938 │
│ 37500003 ┆ 30333846 ┆ 30333846 ┆ 30333846 ┆ 30333846 ┆ 30333846 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 49999996 ┆ 17167194 ┆ 17167194 ┆ 17167194 ┆ 17167194 ┆ 17167194 │
│ 49999997 ┆ 9092583 ┆ 9092583 ┆ 9092583 ┆ 9092583 ┆ 9092583 │
│ 49999998 ┆ 1929693 ┆ 1929693 ┆ 1929693 ┆ 1929693 ┆ 1929693 │
│ 49999999 ┆ 35668469 ┆ 35668469 ┆ 35668469 ┆ 35668469 ┆ 35668469 │
在我的机器上,我在 100 次运行中平均在 0.455 秒内获得此输出。
如果我在您的
df.height
版本中作弊并用 50_000_000
替换 train_test_split
,然后运行它的惰性模式,我在 100 次运行中平均在 0.446 秒 内得到相同的输出,这相当于性能。