与标准排序相比,基数排序比预期慢

问题描述 投票:0回答:2

我用 Python 实现了两个版本的基数排序(这种基数排序可以对取值小于 n² 的整数进行排序,其中 n 是待排序列表的长度),用来与标准排序(Timsort)做基准对比。为了让比较更公平,我使用了 PyPy。

令人惊讶的是,即使不使用哈希表(改用可按下标直接访问的数组),我的基数排序实现也比标准排序慢,即使输入规模很大时也是如此。我一定是漏掉了某些微观优化,因为 O(n) 最终应该击败 O(n log n)。我希望得到能提升性能的建议。我这样做是出于学习目的,因此不需要内置函数、第三方库或从 Python 调用的 C 编译代码。

还有哪些微观优化可以做?我的代码真的是 O(n) 吗?

enter image description here

我的代码在 AMD Ryzen 9 7950X CPU 上运行最多需要 10 秒:

import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict

def radix_sort(arr, size):
    """Sort ``arr`` in place with a two-digit LSD radix sort in base ``size``.

    Each value is split as ``q, r = divmod(num, size)``. The first pass groups
    the quotients ``q`` by remainder ``r``. The second pass visits the
    remainders in ascending order and regroups the rebuilt values by ``q``,
    so every ``q`` bucket ends up ordered by ``r``. Reading the ``q`` buckets
    in ascending order then yields the sorted sequence. Buckets are kept in
    ``defaultdict(list)`` maps.

    Args:
        arr: list of ints, each in ``range(size * size)``; overwritten in place.
        size: the radix (the benchmark in this file passes ``len(arr)``).

    Returns:
        ``arr`` itself, now sorted.

    Raises:
        ValueError: if any value lies outside ``range(size * size)``.
    """
    # Only buckets 0..size-1 are read back below. Without this check, a value
    # whose quotient falls outside that range (negative, or >= size**2) would
    # silently vanish from the output. One C-level min/max scan is cheap next
    # to the Python-level passes that follow.
    if arr and (min(arr) < 0 or max(arr) >= size * size):
        raise ValueError("radix_sort requires 0 <= value < size**2")

    least_sig_digit = defaultdict(list)
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)

    highest_sig_digit = defaultdict(list)
    for k in range(size):  # ascending remainder keeps each q bucket ordered
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q * size + k)

    i = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr

def radix_sort_no_hashmap(arr, size):
    """Sort ``arr`` in place with a two-digit LSD radix sort using list buckets.

    Same algorithm as a dict-bucketed radix sort, but the buckets are plain
    lists indexed directly by digit. Each value is split as
    ``q, r = divmod(num, size)``. Quotients are first grouped by ``r``, then
    the remainders are visited in ascending order and the rebuilt values are
    regrouped by ``q``. Finally the ``q`` buckets are read back in order.

    Args:
        arr: list of ints, each in ``range(size * size)``; overwritten in place.
        size: the radix and number of buckets per pass (the benchmark in this
            file passes ``len(arr)``).

    Returns:
        ``arr`` itself, now sorted.

    Raises:
        ValueError: if any value lies outside ``range(size * size)``.
    """
    # Without this check, a negative value yields a negative quotient that
    # silently indexes the bucket list from the end (wrong order, no error),
    # and a value >= size**2 fails with a bare IndexError mid-sort.
    if arr and (min(arr) < 0 or max(arr) >= size * size):
        raise ValueError("radix_sort_no_hashmap requires 0 <= value < size**2")

    least_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)

    highest_sig_digit = [[] for _ in range(size)]
    for k in range(size):  # ascending remainder keeps each q bucket ordered
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q * size + k)

    i = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr


def benchmark_sorting_algorithms():
    """Time both radix sorts and ``sorted`` on random inputs of growing size.

    For each ``size`` the input is ``size`` distinct ints drawn from
    ``range(1, size**2)``, which is within the value range both radix sorts
    support. Every algorithm receives its own copy of the same input, and
    the three outputs are checked to agree.

    Returns:
        ``(sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times)``,
        where each timing list holds elapsed seconds aligned with ``sizes``.

    Raises:
        AssertionError: if the three sorts produce different results.
    """
    sizes = [1000, 10000, 100000, 200000, 1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 10000000]
    radix_times = []
    radix_sort_no_hashmap_times = []
    std_sort_times = []

    def timed(sort_fn, *args):
        """Return ``(sort_fn(*args), elapsed_seconds)``."""
        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring intervals; time.time() can jump and is coarser.
        start = time.perf_counter()
        result = sort_fn(*args)
        return result, time.perf_counter() - start

    for size in sizes:
        array = random.sample(range(1, size**2), size)

        # The copy is made while evaluating the arguments, i.e. before the
        # clock starts, exactly as in the untimed copy-then-time pattern.
        a, elapsed = timed(radix_sort, array.copy(), size)
        radix_times.append(elapsed)

        b, elapsed = timed(radix_sort_no_hashmap, array.copy(), size)
        radix_sort_no_hashmap_times.append(elapsed)

        c, elapsed = timed(sorted, array.copy())
        std_sort_times.append(elapsed)

        # Whole-list comparison runs in C, and an explicit raise (unlike a
        # bare assert) is not stripped under ``python -O``.
        if not (a == b == c):
            raise AssertionError(f"sort results disagree for size={size}")

    return sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times


sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times = benchmark_sorting_algorithms()

# One (timings, legend label) pair per curve, drawn in this order.
curves = (
    (radix_times, 'Radix Sort (O(n))'),
    (std_sort_times, 'Standard Sort (O(nlogn))'),
    (radix_sort_no_hashmap_times, 'Radix Sort (O(n)) - No Hashmap'),
)

plt.figure(figsize=(12, 6))
for timings, curve_label in curves:
    plt.plot(sizes, timings, label=curve_label)

# Log scale on both axes: the sizes span four orders of magnitude, and
# comparing slopes makes O(n) vs O(n log n) growth easier to see.
plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')

plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()

按照 @user24714692 的建议,我也针对 C++ 发布了同样的问题。

python algorithm sorting complexity-theory pypy
2个回答
1
投票
下面这个版本在当前的 CPython 和某个较旧的 PyPy 中似乎比你的版本快 1.5–2.5 倍(希望你能把它加入基准测试和图表中)。我猜它在 C++ 中可能更有用。

它避免了创建大量 list 对象。取而代之的是,它使用由 arr 的下标构成的链表,这些下标存储在三个长列表中。每个余数/商各对应一个链表:first[r] 和 last[r] 分别给出余数 r 对应链表的开头和结尾,next[i] 给出链表中下标 i 之后的下一个元素。First/Last/Next 的作用相同,只是针对商 q。

import random


def radix_sorted(arr):
    """Return a new sorted list of ``arr`` using a linked-list radix sort.

    Instead of allocating one Python list per bucket, each of the two passes
    threads the indices of ``arr`` through singly linked lists held in three
    flat lists. ``head[d]`` and ``tail[d]`` hold the first and last index in
    the bucket for digit ``d``, and ``link[i]`` holds the index that follows
    ``i`` (``-1`` ends a list). The first pass buckets by ``x % n`` and the
    second by ``x // n``, where ``n = len(arr)``.

    ``arr`` itself is not modified.

    Raises:
        ValueError: if any value lies outside ``range(len(arr) ** 2)``.
    """
    n = len(arr)
    # A negative x gives a negative quotient, which would silently index the
    # quotient tables from the end; x >= n*n would overflow them.
    if arr and (min(arr) < 0 or max(arr) >= n * n):
        raise ValueError("radix_sorted requires 0 <= x < len(arr)**2")

    # Pass 1: link indices into buckets by x % n (the low digit).
    head = [-1] * n
    tail = [-1] * n
    link = [-1] * n
    for i, x in enumerate(arr):
        r = x % n
        if head[r] == -1:
            head[r] = i
        else:
            link[tail[r]] = i
        tail[r] = i

    # Pass 2: walk remainder buckets in ascending order and re-link by x // n.
    # Visiting remainders in order keeps each quotient bucket ordered by
    # remainder, so walking the quotient buckets in order yields sorted output.
    q_head = [-1] * n
    q_tail = [-1] * n
    q_link = [-1] * n
    for r in range(n):
        i = head[r]
        while i != -1:
            q = arr[i] // n
            if q_head[q] == -1:
                q_head[q] = i
            else:
                q_link[q_tail[q]] = i
            q_tail[q] = i
            i = link[i]

    # Output: walk the quotient buckets in order.
    out = []
    for q in range(n):
        i = q_head[q]
        while i != -1:
            out.append(arr[i])
            i = q_link[i]
    return out


if __name__ == "__main__":
    # Demo / testing (guarded so importing this module does not sort 10**6 items).
    n = 10**6
    arr = random.sample(range(n**2), n)
    print(sorted(arr) == radix_sorted(arr))

在线尝试!


-3
投票
时间复杂度为 O(N)。有时也写作 O(K * N),其中 K 是位数。不过 K 是常数,所以 O(N) 是正确的。

我会添加一个 random.seed() 调用来固定随机输入,以获得更准确、可复现的结果:

import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict


def radix_sort(arr, base):
    """Sort the integers in ``arr`` in place with a two-pass radix sort.

    Each value is split as ``q, r = divmod(num, base)``. The first pass buckets
    the quotients by ``r``. The second pass walks those buckets in ascending
    ``r`` order and re-buckets the rebuilt values by ``q``. Bucket keys are
    sorted before each walk, which keeps the result correct for any ints
    (given ``base > 0``), even when ``q >= base``. The cost is an extra
    key-sorting step, so the total work is not strictly O(n).

    Returns ``arr`` itself, now sorted.
    """
    least_sig_digit = defaultdict(list)
    for num in arr:
        q, r = divmod(num, base)
        least_sig_digit[r].append(q)
    highest_sig_digit = defaultdict(list)
    for r in sorted(least_sig_digit.keys()):
        for q in least_sig_digit[r]:
            highest_sig_digit[q].append(q * base + r)
    arr.clear()
    for k in sorted(highest_sig_digit.keys()):
        arr.extend(highest_sig_digit[k])
    return arr


def radix_sort_no_hashmap(arr, size):
    """Return a new sorted list of ``arr`` using list-indexed buckets.

    Every value must lie in ``range(size * size)``, so that both digits of
    ``divmod(num, size)`` are valid bucket indices. ``arr`` is not modified.

    Raises:
        ValueError: if any value lies outside ``range(size * size)``.
    """
    # Bucketing by ``q % size`` instead of rejecting large values would
    # silently return unsorted output whenever q >= size.
    if arr and (min(arr) < 0 or max(arr) >= size * size):
        raise ValueError("radix_sort_no_hashmap requires 0 <= value < size**2")
    least_sig_digit = [[] for _ in range(size)]
    highest_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)
    for k in range(size):
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q * size + k)
    return [num for sublist in highest_sig_digit for num in sublist]


def benchmark_sorting_algorithms():
    """Time two radix sorts, ``sorted()`` and ``list.sort()`` on random inputs.

    The seed is fixed so every run benchmarks the same inputs. Each input holds
    ``size`` distinct ints from ``range(1, size ** 2)``, and all four outputs
    must agree before their timings are returned.

    Returns:
        ``(sizes, radix, std_sorted, std_sort, radix_no_hashmap)``; each
        timing list holds elapsed seconds aligned with ``sizes``.

    Raises:
        AssertionError: if the sorts disagree.
    """
    sizes = [1_000, 10_000, 200_000, 400_000, 800_000, 1_200_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []
    random.seed(40)
    for size in sizes:
        array = random.sample(range(1, size ** 2), size)
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()

        # perf_counter is a monotonic, high-resolution interval clock;
        # time.time() is wall-clock time and can jump.
        start_time = time.perf_counter()
        radix_result = radix_sort(A, size)
        radix.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        no_hashmap_result = radix_sort_no_hashmap(B, size)
        radix_no_hashmap.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        expected = sorted(C)
        std_sorted.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        D.sort()
        std_sort.append(time.perf_counter() - start_time)

        # Timings are only meaningful if every algorithm actually sorted.
        if not (radix_result == no_hashmap_result == D == expected):
            raise AssertionError(f"sort results disagree for size={size}")
    return sizes, radix, std_sorted, std_sort, radix_no_hashmap


sizes, radix, std_sorted, std_sort, radix_no_hashmap = benchmark_sorting_algorithms()

plt.figure(figsize=(12, 6))
plt.plot(sizes, radix, label='Radix Sort (O(n))')
plt.plot(sizes, std_sorted, label='Standard Sort using sorted() (O(n log n))')
plt.plot(sizes, std_sort, label='Standard Sort using .sort() (O(n log n))')
plt.plot(sizes, radix_no_hashmap, label='Radix Sort (O(n)) - No Hashmap')
plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()
打印

Python3 绘图

enter image description here

PyPy 7.3.16 绘图

enter image description here


PyPy 7.3.16 绘图

    如果我们使用 PyPy 7.3.16 对 array = list(range(size ** 20, size ** 20 - size, -1)) 进行排序并做基准测试(注意:这些值远大于 size²,超出了 radix_sort_no_hashmap 所针对的取值范围):
def benchmark_sorting_algorithms():
    """Time the four sorts on strictly descending inputs with huge values.

    Each input is ``size`` consecutive ints counting down from ``size ** 20``.
    Those values are far above ``size ** 2``, the range that
    ``radix_sort_no_hashmap`` is built for. Its output is therefore compared
    with ``sorted``, and its timing is recorded as NaN whenever it raises or
    returns a wrong result. This avoids reporting the time of an incorrect
    sort.

    Returns:
        ``(sizes, radix, std_sorted, std_sort, radix_no_hashmap)``; each
        timing list holds elapsed seconds aligned with ``sizes``.
    """
    sizes = [1_000, 10_000, 20_000, 80_000, 200_000, 800_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []
    for size in sizes:
        array = list(range(size ** 20, size ** 20 - size, -1))
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()

        # perf_counter is a monotonic, high-resolution interval clock;
        # time.time() is wall-clock time and can jump.
        start_time = time.perf_counter()
        _ = radix_sort(A, size)
        radix.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        try:
            no_hashmap_result = radix_sort_no_hashmap(B, size)
        except (ValueError, IndexError):
            no_hashmap_result = None  # input outside the supported value range
        no_hashmap_elapsed = time.perf_counter() - start_time

        start_time = time.perf_counter()
        expected = sorted(C)
        std_sorted.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        D.sort()
        std_sort.append(time.perf_counter() - start_time)

        # Keep the timing only if the sort actually produced sorted output.
        radix_no_hashmap.append(
            no_hashmap_elapsed if no_hashmap_result == expected else float("nan")
        )
    return sizes, radix, std_sorted, std_sort, radix_no_hashmap

enter image description here

评论

    请注意,Python 的 sorted() 和 list.sort() 速度更快,它们是用 C 实现的。
  • 我会用 C 实现它并进行基准测试,以获得更准确的结果。
© www.soinside.com 2019 - 2024. All rights reserved.