I need to clean up a DataFrame by eliminating any overlapping intervals. The DataFrame typically looks like this:
   start    end     speaker
0   0.03   0.33  SPEAKER_02
1   1.24   6.91  SPEAKER_02
2   1.38   2.03  SPEAKER_00
3   7.11   9.64  SPEAKER_02
4   9.80  21.02  SPEAKER_02
5  15.37  15.52  SPEAKER_01
6  16.55  16.80  SPEAKER_00
7  21.36  26.40  SPEAKER_02
8  26.76  27.01  SPEAKER_02
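Where rows overlap, I just want to create new rows that do not overlap at all.
The original rows 1 and 2 would become 3 rows with these non-overlapping intervals:
1   1.24   1.37  SPEAKER_02
2   1.38   2.03  SPEAKER_00
3   2.04   6.91  SPEAKER_02
The original rows 4, 5 and 6 would become 4 rows with these non-overlapping intervals:
4   9.80  15.36  SPEAKER_02
5  15.37  15.52  SPEAKER_01
6  15.53  16.54  SPEAKER_02
7  16.55  16.80  SPEAKER_00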
The output DataFrame would therefore consist of these 11 rows of non-overlapping intervals:
    start    end     speaker
0    0.03   0.33  SPEAKER_02
1    1.24   1.37  SPEAKER_02
2    1.38   2.03  SPEAKER_00
3    2.04   6.91  SPEAKER_02
4    7.11   9.64  SPEAKER_02
5    9.80  15.36  SPEAKER_02
6   15.37  15.52  SPEAKER_01
7   15.53  16.54  SPEAKER_02
8   16.55  16.80  SPEAKER_00
9   21.36  26.40  SPEAKER_02
10  26.76  27.01  SPEAKER_02
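As always, the data is huge and efficiency is critical.
I have tried some of the built-in interval tools, but I keep falling back to custom for loops because I ultimately need to alter rows.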
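For checking a faster solution on small inputs, the requested splitting can be written as a slow reference in plain Python. This is only a sketch under assumptions that the question does not state: split_overlaps and its scale parameter are hypothetical names, timestamps are assumed to have 0.01 resolution with inclusive ends, and where intervals overlap the one that started last wins. It also keeps the tail of an outer interval after its last nested one (16.81 to 21.02), which the example output above does not list.

```python
def split_overlaps(rows, scale=100):
    """Split (start, end, speaker) rows into non-overlapping pieces.

    Hypothetical reference helper: bounds are inclusive and are assumed
    to be multiples of 1/scale (0.01 for the sample data).
    """
    # Work in integer ticks so that "end + one tick" is exact.
    ticks = [(round(s * scale), round(e * scale), spk) for s, e, spk in rows]
    # Every start, and the tick right after every end, is a possible cut.
    cuts = sorted({s for s, _, _ in ticks} | {e + 1 for _, e, _ in ticks})
    out = []
    for lo, nxt in zip(cuts, cuts[1:]):
        hi = nxt - 1
        # Rows that fully cover the elementary piece [lo, hi].
        covering = [(s, spk) for s, e, spk in ticks if s <= lo and hi <= e]
        if not covering:
            continue  # a gap between intervals, nothing to emit
        # Assumption: the interval that started last wins the piece.
        _, speaker = max(covering, key=lambda c: c[0])
        out.append((lo / scale, hi / scale, speaker))
    return out


rows = [
    (0.03, 0.33, "SPEAKER_02"), (1.24, 6.91, "SPEAKER_02"),
    (1.38, 2.03, "SPEAKER_00"), (7.11, 9.64, "SPEAKER_02"),
    (9.80, 21.02, "SPEAKER_02"), (15.37, 15.52, "SPEAKER_01"),
    (16.55, 16.80, "SPEAKER_00"), (21.36, 26.40, "SPEAKER_02"),
    (26.76, 27.01, "SPEAKER_02"),
]
result = split_overlaps(rows)
for piece in result:
    print(piece)
```

The loop rescans every row for every piece, so it is quadratic and only meant for validating results, not for large data.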
One option is conditional_join (from janitor):
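import numpy as np
import pandas as pd
import janitor  # registers DataFrame.conditional_join

def to_intervals(x):
    # sort all start/end boundaries of the group and pair consecutive values
    vals = np.sort(x.values.ravel())
    return pd.DataFrame({'start': vals[:-1],
                         # shift every end by 0.01 except the group's last one
                         'end': vals[1:]-0.01*np.r_[np.ones(len(vals)-2), 0]
                        })

tmp = (df
   # identify the overlapping intervals
   .sort_values(by=['start', 'end'])
   .assign(max_end=lambda d: d['end'].cummax(),
           group=lambda d: d['start'].ge(d['max_end'].shift()).cumsum())
)

# for each new interval, keep the last original row that contains it
out = (tmp.groupby('group')[['start', 'end']].apply(to_intervals)
          .conditional_join(tmp,
                            ('start', 'start', '>='),
                            ('end', 'end', '<='),
                            how='left', keep='last',
                            right_columns=['speaker', 'group'],
                           )
      )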
Output:
    start    end     speaker  group
0    0.03   0.33  SPEAKER_02      0
1    1.24   1.37  SPEAKER_02      1
2    1.38   2.02  SPEAKER_00      1
3    2.03   6.91  SPEAKER_02      1
4    7.11   9.64  SPEAKER_02      2
5    9.80  15.36  SPEAKER_02      3
6   15.37  15.51  SPEAKER_01      3
7   15.52  16.54  SPEAKER_02      3
8   16.55  16.79  SPEAKER_00      3
9   16.80  21.02  SPEAKER_02      3
10  21.36  26.40  SPEAKER_02      4
11  26.76  27.01  SPEAKER_02      5
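The boundary handling in to_intervals can be traced on group 3 alone (the rows spanning 9.80 to 21.02). The sketch below repeats the same arithmetic with NumPy only; pieces_from_bounds and its step parameter are hypothetical names, and the final rounding to two decimals is an addition made here so the values print cleanly.

```python
import numpy as np


def pieces_from_bounds(starts, ends, step=0.01):
    # Pool and sort every boundary of one overlapping group.
    vals = np.sort(np.concatenate([starts, ends]))
    new_start = vals[:-1]
    # Pull each end back by one step so that pieces do not touch;
    # the last piece keeps the group's true end.
    shift = step * np.r_[np.ones(len(vals) - 2), 0]
    new_end = vals[1:] - shift
    # Rounding is an addition for display, not part of to_intervals.
    return np.round(new_start, 2), np.round(new_end, 2)


# Boundaries of group 3 from the sample data.
starts = np.array([9.80, 15.37, 16.55])
ends = np.array([21.02, 15.52, 16.80])
new_start, new_end = pieces_from_bounds(starts, ends)
print(np.column_stack([new_start, new_end]))
```

Because each piece starts exactly on a sorted boundary, a nested interval gives up 0.01 at its end (15.37 to 15.51 rather than 15.37 to 15.52). This differs slightly from the convention in the question, where the nested interval stays intact and the outer piece resumes 0.01 later.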
Intermediate tmp:
   start    end     speaker  max_end  group
0   0.03   0.33  SPEAKER_02     0.33      0
1   1.24   6.91  SPEAKER_02     6.91      1
2   1.38   2.03  SPEAKER_00     6.91      1
3   7.11   9.64  SPEAKER_02     9.64      2
4   9.80  21.02  SPEAKER_02    21.02      3
5  15.37  15.52  SPEAKER_01    21.02      3
6  16.55  16.80  SPEAKER_00    21.02      3
7  21.36  26.40  SPEAKER_02    26.40      4
8  26.76  27.01  SPEAKER_02    27.01      5
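The max_end column is what keeps nested intervals in the same group. As a small illustration on the sample data (the variable names naive and robust are only for this sketch), comparing each start with the previous row's end alone would split row 6 away from the long interval that still contains it, whereas comparing with the running maximum of end does not.

```python
import pandas as pd

df = pd.DataFrame({
    "start": [0.03, 1.24, 1.38, 7.11, 9.80, 15.37, 16.55, 21.36, 26.76],
    "end": [0.33, 6.91, 2.03, 9.64, 21.02, 15.52, 16.80, 26.40, 27.01],
})
df = df.sort_values(["start", "end"])

# Naive rule: a new group starts when a row begins at or after the
# previous row's end. Row 6 (16.55) begins after row 5 ends (15.52),
# so it is wrongly separated from row 4, which runs until 21.02.
naive = df["start"].ge(df["end"].shift()).cumsum()

# Rule used for tmp: compare with the largest end seen so far.
robust = df["start"].ge(df["end"].cummax().shift()).cumsum()

print(pd.DataFrame({"naive": naive, "robust": robust}))
```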
To list every speaker of an overlapping interval, the last part of the code can be changed to:
out = (tmp.groupby('group')[['start', 'end']].apply(to_intervals)
          .conditional_join(tmp,
                            ('start', 'start', '>='),
                            ('end', 'end', '<='),
                            how='left',
                            right_columns=['speaker', 'group'],
                           )
          .groupby(['start', 'end'], as_index=False)
          ['speaker'].agg(','.join)
      )
Output:
    start    end                speaker
0    0.03   0.33             SPEAKER_02
1    1.24   1.37             SPEAKER_02
2    1.38   2.02  SPEAKER_02,SPEAKER_00
3    2.03   6.91             SPEAKER_02
4    7.11   9.64             SPEAKER_02
5    9.80  15.36             SPEAKER_02
6   15.37  15.51  SPEAKER_02,SPEAKER_01
7   15.52  16.54             SPEAKER_02
8   16.55  16.79  SPEAKER_02,SPEAKER_00
9   16.80  21.02             SPEAKER_02
10  21.36  26.40             SPEAKER_02
11  26.76  27.01             SPEAKER_02