优化代码以查找最接近值列表的点子集

问题描述 投票:0回答:1

我有两个一维数组

A
B
(分别为长度
m
n
),以及一些数字
n_0
0 < n_0 < n
。我正在寻找最有效的方法来识别
n_0
B
A
的每个单独元素的最近点。

也就是说,我需要生成一个

m x n
布尔数组
M
这样:

  • ith
    M
    列恰好包含
    n_0
    个1和
    n-n_0
    个0
  • M[:,i]
    中带有1的位置是距离
    n_0
    最近的
    B
    A[i]
    点的索引。

这是我尝试过的。

import numpy as np

m = 1000
n = 1200
n_0 = 300

A = np.random.rand(m)
B = np.random.rand(n)

A_dist, B_dist = np.meshgrid(A, B, sparse=True, copy=False)

dist = (A_dist - B_dist)**2

qq = np.quantile(dist, n_0/n, axis = 0)

M = (dist <= qq)

我认为这可以改进的原因是,倒数第二行中分位数的计算没有利用矩阵

dist
的源 - 即
ith
jth
dist
是从两点
A[i,0]
A[j,0]
到 B 中的 完全相同的值的距离。

欢迎任何想法!

python arrays optimization quantile
1个回答
0
投票

您可以尝试利用 并行性来计算

M
:

import numba as nb
import numpy as np

def compute_M_orig(A, B, m, n, n_0):
    A_dist, B_dist = np.meshgrid(A, B, sparse=True, copy=False)
    dist = (A_dist - B_dist) ** 2
    qq = np.quantile(dist, n_0 / n, axis=0)
    M = dist <= qq
    return M

@nb.njit(parallel=True)
def compute_M_numba(A, B, m, n, n_0):
    B_sorted_idx = np.argsort(B)
    B_sorted = B[B_sorted_idx]

    M = np.zeros(shape=(n, m), dtype="uint8")

    indices = np.searchsorted(B_sorted, A)
    for i in nb.prange(len(indices)):
        idx = indices[i]

        l, h = idx - n_0, idx + n_0

        if l < 0:
            l = 0

        if h >= len(B):
            h = len(B)

        tmp = np.argsort(np.abs(B_sorted[l:h] - A[i]))
        closest = B_sorted_idx[l + tmp[:n_0]]
        M[closest, i] = 1

    return M


m = 10
n = 15
n_0 = 6

np.random.seed(0)

A = np.random.rand(m)
B = np.random.rand(n)

m1 = compute_M_orig(A, B, m, n, n_0)
m2 = compute_M_numba(A, B, m, n, n_0)

assert np.allclose(m1, m2)

基准与

perfplot

import perfplot


def _setup(m):
    A = np.random.rand(m)
    B = np.random.rand(m + int(m * 0.5))
    return A, B, len(A), len(B), int(len(B) * 0.3)


perfplot.show(
    setup=_setup,
    kernels=[
        compute_M_orig,
        compute_M_numba,
    ],
    labels=["orig", "numba"],
    n_range=[500, 1000, 5000, 10_000, 25_000],
    xlabel="m",
    logx=True,
    logy=True,
    equality_check=np.allclose,
)

在我的 AMD 5700x 上创建此图:

enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.