优化:从大于(或等于)`x`的数组返回最小值

问题描述 投票:0回答:1

编辑:我的问题与建议的副本不同,因为我已经有一个实现lowest的方法。我的问题不是如何实现lowest,而是如何优化lowest以更快地运行。

假设我有一个数组a。例如:

import numpy as np
a = np.array([2, 1, 3, 4, 5, 6, 7, 8, 9])

假设我有一个浮动x。例如:

x = 6.5

我想返回a中也是大于或等于x的最低值。所以在这种情况下......

print lowest(a, x)
>>> 7

我已经尝试了一些功能来代替lowest。例如:

def lowest(a, x):
""" `a` should be a sorted numpy array"""
    return lowest[lowest >= x][0]

def lowest(a, x):
""" `a` should be a sorted `list`, not a numpy array"""
    k = sorted(a + [x])
    return k[k.index(x) + 1]

然而,函数lowest仍然是我的代码的瓶颈在~90%。

有没有更快的方法来实现lowest功能?

关于我的代码的一些规则:

  • 可以假设a的长度为10
  • 函数lowest运行至少100k次。这可能是一个设计问题,但我感兴趣的是如果首先更快地实现我的问题。
  • a可以在运行这些循环之前进行预处理。 x会有所不同,但a不会。
  • 可以假设a[0] <= x <= a[-1]总是True
python python-2.7 numpy optimization
1个回答
2
投票

这是使用查找表的O(1)解决方案与OP(第一)解决方案和numpy.searchsorted相比较。这不是100%公平,因为OP的解决方案没有矢量化。无论如何,时间:

True                  # results equal
True                  # results equal
0.08163515606429428   # lookup
2.1996873939642683    # OP
0.016975965932942927  # numpy.searchsorted

对于这个小的列表大小seachsorted获胜即使它是O(log n)。

码:

import numpy as np

class find_next:
    def __init__(self, a, max_bins=100000):
        self.a = np.sort(a)
        self.low = self.a[0]
        self.high = self.a[-1]
        self.span = self.a[-1] - self.a[0]
        self.damin = np.diff(self.a).min()
        if self.span // self.damin > max_bins:
            raise ValueError('a too unevenly spaced for max_bins')
        self.lut = np.searchsorted(self.a, np.linspace(self.low, self.high,
                                                       max_bins + 1))
        self.no_bins = max_bins
    def f_pp(self, x):
        i = np.array((x-self.low)/self.span * self.no_bins, int)
        return self.a[self.lut[i + (x > self.a[self.lut[i]])]]
    def lowest(self, x):
        return self.a[self.a >= x][0]
    def f_ss(self, x):
        return self.a[self.a.searchsorted(x)]

a = np.array([2, 1, 3, 4, 5, 6, 7, 8, 9])

x = np.random.uniform(1, 9, (10000,))

fn = find_next(a)
sol_pp = fn.f_pp(x)
sol_OP = [fn.lowest(xi) for xi in x]
sol_ss = fn.f_ss(x)

print(np.all(sol_OP == sol_pp))
print(np.all(sol_OP == sol_ss))

from timeit import timeit
kwds = dict(globals=globals(), number=10000)

print(timeit('fn.f_pp(x)', **kwds))
print(timeit('[fn.lowest(xi) for xi in x]', **kwds))
print(timeit('fn.f_ss(x)', **kwds))
© www.soinside.com 2019 - 2024. All rights reserved.