基数排序实现的优化:与标准排序相比,比预期慢

问题描述 投票:0回答:1

我在 Python 中实现了两个版本的基数排序(该版本允许对值达到 n² 的整数进行排序,其中 n 是要排序的列表的大小),以针对标准排序 (Timsort) 进行基准测试。令人惊讶的是,即使不使用哈希图(使用直接访问数组),我的基数排序实现也比标准排序慢,即使对于较大的输入大小也是如此。我正在寻求有关优化实施以实现更好性能的建议。我这样做不是为了实用目的,只是为了学习目的,因为我对编程相当陌生,因此我不是在寻找内置函数、库或 C 编译的代码来从 Python 调用。

我可以进行微观优化吗?我的代码真的是 O(n) 吗?

enter image description here

我使用 PyPy 进行更公平的比较。

这是我的代码(该代码在AMD Ryzen 9 7950X CPU上运行最多需要10秒,仅供参考):

import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict

def radix_sort(arr, size):
    least_sig_digit = defaultdict(list)
    for k in range(len(arr)):
        q, r = divmod(arr[k], size)
        least_sig_digit[r].append(q)
    highest_sig_digit = defaultdict(list)
    for k in range(size):  # k goes in order of lowest significant digit
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q*size+k)
    i: int = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr

def radix_sort_no_hashmap(arr, size):
    least_sig_digit = [[] for _ in range(size)]
    for k in range(len(arr)):
        q, r = divmod(arr[k], size)
        least_sig_digit[r].append(q)
    highest_sig_digit = [[] for _ in range(size)]
    for k in range(size):  # k goes in order of lowest significant digit
        for q in least_sig_digit[k]:
            highest_sig_digit[q].append(q*size+k)
    i: int = 0
    for k in range(size):
        for num in highest_sig_digit[k]:
            arr[i] = num
            i += 1
    return arr


def benchmark_sorting_algorithms():
    sizes = [1000, 10000, 100000, 200000, 1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 10000000]
    radix_times = []
    radix_sort_no_hashmap_times = []
    std_sort_times = []

    for size in sizes:
        array = random.sample(range(1, size**2), size)
        
        new_arr = array.copy()
        start_time = time.time()
        a = radix_sort(new_arr, size)
        radix_times.append(time.time() - start_time)
        
        new_arr = array.copy()
        start_time = time.time()
        b = radix_sort_no_hashmap(new_arr, size)
        radix_sort_no_hashmap_times.append(time.time() - start_time)

        new_arr = array.copy()
        start_time = time.time()
        c = sorted(new_arr)
        std_sort_times.append(time.time() - start_time)

        for k in range(len(array)):
            assert a[k] == b[k] == c[k]

    return sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times


sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times = benchmark_sorting_algorithms()

plt.figure(figsize=(12, 6))
plt.plot(sizes, radix_times, label='Radix Sort (O(n))')
plt.plot(sizes, std_sort_times, label='Standard Sort (O(nlogn))')
plt.plot(sizes, radix_sort_no_hashmap_times, label='Radix Sort (O(n)) - No Hashmap')
plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()
python algorithm sorting complexity-theory
1个回答
0
投票

时间复杂度为

O(N)
。有时也称为
O(K * N))
,K 是位数。然而,
K
是常数,并且
O(N)
是正确的。

我会添加一个

seed()
以获得更准确的结果:

import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict


def radix_sort(arr, base):
    least_sig_digit = defaultdict(list)

    for num in arr:
        q, r = divmod(num, base)
        least_sig_digit[r].append(q)

    highest_sig_digit = defaultdict(list)

    for r in sorted(least_sig_digit.keys()):
        for q in least_sig_digit[r]:
            highest_sig_digit[q].append(q * base + r)

    arr.clear()

    for k in sorted(highest_sig_digit.keys()):
        arr.extend(highest_sig_digit[k])

    return arr


def radix_sort_no_hashmap(arr, size):
    least_sig_digit = [[] for _ in range(size)]
    highest_sig_digit = [[] for _ in range(size)]

    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)

    for k in range(size):
        for q in least_sig_digit[k]:
            highest_sig_digit[q % size].append(q * size + k)

    return [num for sublist in highest_sig_digit for num in sublist]


def benchmark_sorting_algorithms():
    sizes = [1_000, 10_000, 200_000, 400_000, 800_000, 1_200_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []

    random.seed(40)
    for size in sizes:
        array = random.sample(range(1, size ** 2), size)
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()
        start_time = time.time()
        _ = radix_sort(A, size)
        radix.append(time.time() - start_time)

        start_time = time.time()
        _ = radix_sort_no_hashmap(B, size)
        radix_no_hashmap.append(time.time() - start_time)

        start_time = time.time()
        _ = sorted(C)
        std_sorted.append(time.time() - start_time)

        start_time = time.time()
        D.sort()
        std_sort.append(time.time() - start_time)

    return sizes, radix, std_sorted, std_sort, radix_no_hashmap


sizes, radix, std_sorted, std_sort, radix_no_hashmap = benchmark_sorting_algorithms()

plt.figure(figsize=(12, 6))
plt.plot(sizes, radix, label='Radix Sort (O(n))')
plt.plot(sizes, std_sorted, label='Standard Sort using sorted() (O(n log n))')
plt.plot(sizes, std_sort, label='Standard Sort using .sort() (O(n log n))')
plt.plot(sizes, radix_no_hashmap, label='Radix Sort (O(n)) - No Hashmap')
plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()

打印

enter image description here


评论

  • 请注意,Python Sort 和 sort() 速度更快,并且是用 C 实现的。
  • 我会在 C 中实现它并对其进行基准测试以获得更准确的结果。
© www.soinside.com 2019 - 2024. All rights reserved.