我有一个包含 n 列的数据框。我想创建一个新的数据框,其中包含所有现有列以及任何 2 (ab,ac,...), 3(abc, abd, ...) , 4(abcd, abce, ...)、5(abcde、abcdf、...)、...和 n(abcdef...) 列。唯一的标准是任何产品中都不存在重复的列。这是 n=3 的示例:
原df:
col1: a
col2: b
col3: c
修改后的df:
col1: a
col2: b
col3: c
col4: a*b
col5: a*c
col6: b*c
col7: a*b*c
创建此类数据框最有效的方法是什么?
powerset
,然后eval
它:
# pip install more_itertools
from more_itertools import powerset
N = df.shape[1] # Nb of columns
out = (
df.assign(
**{
col: df.eval(col)
for col in map("*".join, list(powerset(df.columns))[N+1:])
}
)
)
注意:如果您无法安装 more_itertools,您可以使用文档中的this配方:
from itertools import chain, combinations
def powerset(iterable):
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
输出:
# powerset
[(),
('a',),
('b',),
('c',),
('a', 'b'), # we moved the cursor up to this one with N+1 == 4
('a', 'c'), # to avoid evaluating empty string and original cols
('b', 'c'),
('a', 'b', 'c')]
# out
a b c a*b a*c b*c a*b*c
0 0.55 0.72 0.60 0.39 0.33 0.43 0.24
1 0.54 0.42 0.65 0.23 0.35 0.27 0.15
2 0.44 0.89 0.96 0.39 0.42 0.86 0.38
3 0.38 0.79 0.53 0.30 0.20 0.42 0.16
4 0.57 0.93 0.07 0.53 0.04 0.07 0.04
5 0.09 0.02 0.83 0.00 0.07 0.02 0.00
6 0.78 0.87 0.98 0.68 0.76 0.85 0.66
7 0.80 0.46 0.78 0.37 0.62 0.36 0.29
8 0.12 0.64 0.14 0.08 0.02 0.09 0.01
9 0.94 0.52 0.41 0.49 0.39 0.22 0.20
[10 rows x 7 columns]
使用的输入:
import pandas as pd
import numpy as np
np.random.seed(0)
df = pd.DataFrame(np.random.rand(10, 3), columns=list("abc"))
我建议使用
itertools.combination()
获取指定长度的所有列组合,并循环遍历要包含的列数。
示例:
import numpy as np
import pandas as pd
import itertools
df = pd.DataFrame({col: np.random.rand(10) for col in 'abcdef'})
def all_combined_product_cols(df):
cols = list(df.columns)
product_cols = []
for length in range(1, len(cols) + 1):
for combination in itertools.combinations(cols, r=length):
combined_col = None
for col in combination:
if combined_col is None:
combined_col = df[col].copy()
else:
combined_col *= df[col]
combined_col.name = '_'.join(combination)
product_cols.append(combined_col)
return pd.concat(product_cols, axis=1)
print(all_combined_product_cols(df))