假设我有两组要点:
>>> points1.shape
(10000, 3)
>>> points2.shape
(1529, 3)
我想在points1
的一个点的欧几里德距离cutoff
内列出points2
指数。我可以很容易地使用scipy.spatial.distance.cdist
这样做:
from scipy.spatial.distance import cdist
import numpy
indices = numpy.argwhere(cdist(points1, points2).min(axis=0) < cutoff)
然而,这似乎效率低下,因为我不需要知道彼此之间有多远,只要它们是否在一个截止距离内。 KDTree可以帮助解决这个问题吗?
这里有3个替代品,一个使用cdist,两个使用scipy.spatial.cKDTree:
import itertools as IT
import numpy as np
import scipy.spatial as spatial
import scipy.spatial.distance as dist
np.random.seed(2016)
points1 = np.random.randint(100, size=(10**5, 3))
points2 = np.random.randint(100, size=(1529, 3))
cutoff = 5
def using_cdist(points1, points2, cutoff):
indices = np.where(dist.cdist(points1, points2) <= cutoff)[0]
indices = np.unique(indices)
return indices
def using_kdtree(points1, points2, cutoff):
# build the KDTree using the *smaller* points array
tree = spatial.cKDTree(points2)
groups = tree.query_ball_point(points1, cutoff)
indices = np.unique([i for i, grp in enumerate(groups) if len(grp)])
return indices
def using_kdtree2(points1, points2, cutoff):
# build the KDTree using the *larger* points array
tree = spatial.cKDTree(points1)
groups = tree.query_ball_point(points2, cutoff)
indices = np.unique(IT.chain.from_iterable(groups))
return indices
cdist_result = using_cdist(points1, points2, cutoff)
kdtree_result = using_kdtree(points1, points2, cutoff)
kdtree_result2 = using_kdtree2(points1, points2, cutoff)
assert np.allclose(cdist_result, kdtree_result)
assert np.allclose(cdist_result, kdtree_result2)
在这3个替代品中,using_kdtree2
是最快的:
In [80]: %timeit using_kdtree3(points1, points2, cutoff)
10 loops, best of 3: 92.4 ms per loop
In [103]: %timeit using_kdtree(points1, points2, cutoff)
1 loops, best of 3: 938 ms per loop
In [104]: %timeit using_cdist(points1, points2, cutoff)
1 loops, best of 3: 1.51 s per loop
我对最快速度的直觉证明是完全错误的。我认为使用较小的点数组构建KDTree会是最快的。即使使用较大的点数组构建KDTree有点慢,在较小的点数组上调用tree.query_ball_point
要快得多:
In [68]: %timeit tree = spatial.cKDTree(points2)
1000 loops, best of 3: 312 µs per loop
In [69]: %timeit tree = spatial.cKDTree(points1)
10 loops, best of 3: 45.7 ms per loop
In [66]: %timeit tree = spatial.cKDTree(points2); groups = tree.query_ball_point(points1, cutoff)
1 loops, best of 3: 933 ms per loop
In [67]: %timeit tree = spatial.cKDTree(points1); groups = tree.query_ball_point(points2, cutoff)
10 loops, best of 3: 89.3 ms per loop
请注意,使用时存在一些问题
def orig(points1, points2, cutoff):
return np.argwhere(dist.cdist(points1, points2).min(axis=0) < cutoff)
首先,通过调用min(axis=0)
,如果points1
中的两个点都在cutoff
中的一个点的points2
内,则会丢失信息。您只能得到最近点的索引。另一个问题是,通过在0轴上调用min
,剩下的只是与points2
相关的1轴。所以orig
将指数归还给points2
,而不是points1
。
一些想法(?):
cdist
将计算平方根)来保存平方根的计算。