我正在 Codewars 上解决“双线性”型
事情是这样的:
考虑一个序列 u,其中 u 定义如下:
数字 u(0) = 1 是 u 中的第一个数字。 对于 u 中的每个 x,则 y = 2 * x + 1 和 z = 3 * x + 1 也必须在 u 中。 u 中没有其他数字。
例如:u = [1, 3, 4, 7, 9, 10, 13, 15, 19, 21, 22, 27, ...]
1 给出 3 和 4,然后 3 给出 7 和 10,4 给出 9 和 13,然后 7 给出 15 和 22 等等...
任务:
给定参数 n,函数 dbl_线性(或 dblLinear...)返回有序的元素 u(n)(带有 <) sequence u (so, there are no duplicates).
示例:
dbl_linear(10) 应返回 22
我提出了解决方案
def dbl_linear(n):
u_list= [1]
for i in range(0, n+1):
u_list.append(u_list[i] * 2 + 1)
u_list.append(u_list[i] * 3 + 1)
u_list= sorted(set(u_list))
return u_list[n]
它可以工作,但是当它达到 6 位数时,时间就耗尽了
我该如何优化这个解决方案?
如果排序不在循环中,它会在 200 之后产生错误的数字,所以它可能需要进入 而且时间还不够
我意识到这个答案有点晚了,但无论如何:
使用二等分
标准库有一个名为
bisect
的不错的模块,它允许使用复杂度为 O(logN) 的二分算法将元素添加到排序列表,同时保持排序。
不幸的是,
bisect.insort
没有避免插入重复值的选项,因此我们需要在简单的辅助函数中使用bisect.bisect()
:找到可能的插入点,确保我们没有插入重复值,并且最后插入我们的元素。
处理列表值的结尾
如果我们考虑到我们会经常在列表末尾附加数字,我们可以做另一个小的优化,所以在这种情况下我们不需要任何检查,只需
append
该元素即可。
这就是结果
from bisect import bisect
def my_insort(a, x):
if x > a[-1]:
a.append(x)
return
ip = bisect(a, x)
if x != a[ip-1]:
a.insert(ip, x)
def dbl_bisect(n):
u_list= [1]
for i in range(0, n+1):
my_insort(u_list, u_list[i] * 2 + 1)
my_insort(u_list, u_list[i] * 3 + 1)
return u_list[n]
随着
n
值的增加,速度会变得更加明显:让我们在 n = 10000
上检查一下
timeit('dbl_linear(10000)', globals=globals(), number=1)
8.237784899945837
timeit('dbl_bisect(10000)', globals=globals(), number=1000) #note: 1000 repeats
5.394557400024496
正如您所见,对于
n = 10000
,我们进行了 1500 倍的优化。还不错。
仅生成所需的元素
还有另一种优化:
dbl_linear()
和dbl_bisect()
都会生成太多无用元素:如果我们在乘以2时调用my_insort()
,并且我们得到一个超出第n
位置的插入点,我们知道任何新数字都将大于该数字,因此我们可以停止迭代;如果我们乘以 3,则情况并非如此,因为下一个数字可能小于当前数字;但我们至少可以停止考虑乘以 3。
这是新版本
def my_insort2(a, x, n):
if x > a[-1]:
if len(a) < n:
a.append(x)
return True
else:
return False
ip = bisect(a, x)
if x != a[ip-1]:
if ip <= n:
a.insert(ip, x)
return True
else:
return False
return True
def dbl_bisect2(n):
u_list= [1]
continue_2 = continue_3 = True
for i in range(0, n+1):
continue_2 = my_insort2(u_list, u_list[i] * 2 + 1, n)
if not continue_2:
break
if continue_3:
continue_3 = my_insort2(u_list, u_list[i] * 3 + 1, n)
return u_list[n]
timeit('dbl_bisect2(10000)', globals=globals(), number=1000)
2.7067045000148937
好的,现在我们的速度提高了 3000 倍(当
n
增长时,进一步的 2 倍加速似乎不会改变)