扩展基于NUMPY的代码,该代码检测连续数字的频率用于多维数组而不是1D数组

问题描述 投票:0回答:1
-stackoverflow答案

提供了一种简单的方法(下图),以找到连续重复数字的频率和索引。该解决方案比基于循环的代码快得多(请参见上面的原始帖子)。 boundaries = np.where(np.diff(aa) != 0)[0] + 1 #group boundaries get_idx_freqs = lambda i, d: (np.concatenate(([0], i))[d >= 2], d[d >= 2]) idx, freqs = get_idx_freqs(boundaries, np.diff(np.r_[0, boundaries, len(aa)]))

输出
# aa=np.array([1,2,2,3,3,3,4,4,4,4,5,5,5,5,5])
(array([ 1,  3,  6, 10]), array([2, 3, 4, 5]))

# aa=np.array([1,1,1,np.nan,np.nan,1,1,np.nan])
(array([0, 5]), array([3, 2]))

wondering如果可以扩展该解决方案以在多维数组而不是慢速传统循环上工作,则如下:

#%%
def get_frequency_of_events_fast(aa):
    boundaries = np.where(np.diff(aa) != 0)[0] + 1 #group boundaries

    get_idx_freqs = lambda i, d: (np.concatenate(([0], i))[d >= 2], d[d >= 2])

    idx, freqs = get_idx_freqs(boundaries, np.diff(np.r_[0, boundaries, len(aa)]))
    return idx,freqs

tmp2_file=np.load('tmp2.npz')
tmp2 = tmp2_file['arr_0']

idx_all=[]
frq_all=[]
for i in np.arange(tmp2.shape[1]):
    for j in np.arange(tmp2.shape[2]):
        print("==>> i, j "+str(i)+' '+str(j))
        idx,freq=get_frequency_of_events_fast(tmp2[:,i,j])
        idx_all.append(idx)
        frq_all.append(freq)
        #if j == 69:
        #    break
        print(idx)
        print(freq)
    #if i == 0:
    #    break

I将索引和频率附加到了一维列表中,我也想知道是否有一种方法可以附加到二维数组。

可以从
Box.com
下载该文件。这是样本输出

==>> i, j 0 61 [ 27 73 226 250 627 754 760 798 825 891 906] [ 12 8 5 17 109 5 12 26 30 12 3] ==>> i, j 0 62 [ 29 75 226 250 258 627 754 761 800 889] [ 11 7 5 6 6 114 5 14 57 21] ==>> i, j 0 63 [ 33 226 622 680 754 762 801 888] [ 9 5 56 63 5 21 58 26] ==>> i, j 0 64 [ 33 226 615 622 693 753 762 801 889 972 993] [12 5 4 68 54 6 21 60 26 3 2] ==>> i, j 0 65 [ 39 615 621 693 801 891 972 987 992] [ 7 3 70 90 61 24 3 2 7] ==>> i, j 0 66 [ 39 617 657 801 891 970 987] [ 7 34 132 63 30 5 13] ==>> i, j 0 67 [ 39 88 621 633 657 680 801 804 891 969 986] [ 11 4 6 2 6 110 2 63 30 6 14] ==>> i, j 0 68 [ 39 87 681 715 740 766 807 873 891 969 984] [12 6 33 3 22 24 60 3 31 6 16]

可能的解决方案(在我的计算机上,它瞬间运行):
# data = np.load('tmp2.npz')
# tmp2 = data['arr_0']

def get_freqs(aa):
    boundaries = np.where(np.diff(aa) != 0)[0] + 1
    edges = np.r_[0, boundaries, len(aa)]
    group_lengths = np.diff(edges)
    valid = group_lengths >= 2
    idx = np.concatenate(([0], boundaries))[valid]
    return idx, group_lengths[valid]

out = {
    (i, j): get_freqs(tmp2[:, i, j])
    for i, j in np.ndindex(tmp2.shape[1], tmp2.shape[2])
}
python numpy
1个回答
0
投票
该函数在一维数组中计算连续组的起始索引和长度,其中该值保持不变,忽略了少于两个元素的组。它通过首先使用

np.diff
识别更改点来做到这一点,然后用

np.r_

构建组边缘,并根据最小长度标准计算np.diff
的组长度,最后构建组长度。字典理解将此函数应用于3-D阵列
(i, j)的每个
tmp2
切片(即沿着第一维),将结果存储在由
(i, j)
索引键入的字典中。
    

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.