我在 Python 中实现了两个版本的基数排序(该实现可以对值最大到 n² 的整数进行排序,其中 n 是待排序列表的长度),用来与标准排序(Timsort)做基准对比。令人惊讶的是,即使不使用哈希表(改用直接寻址数组),我的基数排序实现也比标准排序慢,输入规模较大时也是如此。我希望得到一些关于优化实现、提升性能的建议。我这样做并非出于实用目的,只是为了学习,因为我在编程方面还是新手;因此我并不是在寻找内置函数、库,或从 Python 调用的 C 编译代码。
有哪些微优化可以做吗?我的代码真的是 O(n) 吗?
我使用 PyPy 进行更公平的比较。
这是我的代码(该代码在AMD Ryzen 9 7950X CPU上运行最多需要10秒,仅供参考):
import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict
def radix_sort(arr, size):
    """Sort ``arr`` in place with a two-pass LSD radix sort in base ``size``.

    Every value is split into a low digit (``value % size``) and a high
    digit (``value // size``). Values are first grouped by low digit, then
    redistributed by high digit while the low digits are visited in
    increasing order, so each high-digit bucket ends up internally ordered.

    Args:
        arr: Mutable sequence of integers in ``range(0, size ** 2)``.
        size: Radix, i.e. the number of possible digits per pass. Must be at
            least 1 when ``arr`` is not empty.

    Returns:
        The same ``arr`` object, sorted in ascending order.

    Raises:
        ValueError: If ``size`` is smaller than 1 for a non-empty ``arr`` or
            a value lies outside ``range(0, size ** 2)``. The check happens
            before anything is written back, so ``arr`` is left unchanged.
    """
    if len(arr) > 0 and size < 1:
        raise ValueError("size must be at least 1 to sort a non-empty list")

    least_sig_digit = defaultdict(list)
    for num in arr:
        q, r = divmod(num, size)
        # The final pass only visits high digits 0..size-1; a value with any
        # other high digit would be dropped silently, so reject it here.
        if not 0 <= q < size:
            raise ValueError(
                "value %r is outside range(0, size ** 2)" % (num,)
            )
        least_sig_digit[r].append(q)

    highest_sig_digit = defaultdict(list)
    for k in range(size):  # k goes in order of lowest significant digit
        # .get() keeps the defaultdict from growing an empty list for every
        # digit that never occurred.
        for q in least_sig_digit.get(k, ()):
            highest_sig_digit[q].append(q * size + k)

    i = 0
    for k in range(size):
        for num in highest_sig_digit.get(k, ()):
            arr[i] = num
            i += 1
    return arr
def radix_sort_no_hashmap(arr, size):
    """Sort ``arr`` in place with a two-pass LSD radix sort using list buckets.

    Same algorithm as the dictionary-based variant, but both bucket tables
    are plain lists of ``size`` lists indexed directly by digit.

    Args:
        arr: Mutable sequence of integers in ``range(0, size ** 2)``.
        size: Radix, i.e. the number of buckets per pass. Must be at least 1
            when ``arr`` is not empty.

    Returns:
        The same ``arr`` object, sorted in ascending order.

    Raises:
        ValueError: If ``size`` is smaller than 1 for a non-empty ``arr`` or
            a value lies outside ``range(0, size ** 2)``. The check happens
            before anything is written back, so ``arr`` is left unchanged.
    """
    if len(arr) > 0 and size < 1:
        raise ValueError("size must be at least 1 to sort a non-empty list")

    least_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        # A negative quotient would index the bucket table from the end and
        # silently misplace the value; a quotient >= size has no bucket.
        if not 0 <= q < size:
            raise ValueError(
                "value %r is outside range(0, size ** 2)" % (num,)
            )
        least_sig_digit[r].append(q)

    highest_sig_digit = [[] for _ in range(size)]
    # Buckets are visited in order of lowest significant digit.
    for k, quotients in enumerate(least_sig_digit):
        for q in quotients:
            highest_sig_digit[q].append(q * size + k)

    i = 0
    for bucket in highest_sig_digit:
        for num in bucket:
            arr[i] = num
            i += 1
    return arr
def benchmark_sorting_algorithms(sizes=None):
    """Time both radix sort variants against the built-in ``sorted``.

    For every input size a list of ``size`` distinct random integers drawn
    from ``range(1, size ** 2)`` is generated, and each algorithm sorts its
    own copy of that list.

    Args:
        sizes: Iterable of input sizes to benchmark. Each size must be at
            least 2, otherwise the sampling range is too small. Defaults to
            the original list of sizes from 1000 up to 10000000.

    Returns:
        A tuple ``(sizes, radix_times, std_sort_times,
        radix_sort_no_hashmap_times)``; each timing list holds one duration
        in seconds per entry of ``sizes``.

    Raises:
        AssertionError: If the three sorted results differ for some size.
    """
    if sizes is None:
        sizes = [1000, 10000, 100000, 200000, 1000000, 2000000, 3000000,
                 4000000, 5000000, 6000000, 10000000]
    radix_times = []
    radix_sort_no_hashmap_times = []
    std_sort_times = []
    for size in sizes:
        array = random.sample(range(1, size**2), size)

        # perf_counter() is the clock intended for measuring short
        # durations; time.time() can jump when the system clock is adjusted.
        new_arr = array.copy()
        start_time = time.perf_counter()
        a = radix_sort(new_arr, size)
        radix_times.append(time.perf_counter() - start_time)

        new_arr = array.copy()
        start_time = time.perf_counter()
        b = radix_sort_no_hashmap(new_arr, size)
        radix_sort_no_hashmap_times.append(time.perf_counter() - start_time)

        new_arr = array.copy()
        start_time = time.perf_counter()
        c = sorted(new_arr)
        std_sort_times.append(time.perf_counter() - start_time)

        # Explicit raise instead of ``assert`` so the sanity check is not
        # stripped when Python runs with -O.
        if not (a == b == c):
            raise AssertionError(
                "sorting implementations disagree for size %d" % size
            )
    return sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times
# Run the benchmark once; the four result series stay available at module
# level under their original names.
sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times = (
    benchmark_sorting_algorithms()
)

plt.figure(figsize=(12, 6))

# One curve per algorithm, drawn in the same order as the legend entries.
for _timings, _legend_label in (
    (radix_times, 'Radix Sort (O(n))'),
    (std_sort_times, 'Standard Sort (O(nlogn))'),
    (radix_sort_no_hashmap_times, 'Radix Sort (O(n)) - No Hashmap'),
):
    plt.plot(sizes, _timings, label=_legend_label)

# Both axes are logarithmic.
plt.xlabel('Input size (n)')
plt.ylabel('Time (seconds)')
plt.xscale('log')
plt.yscale('log')

plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()
时间复杂度为 O(N)。有时也写作 O(K * N),其中 K 是位数。不过这里 K 是常数,因此 O(N) 是正确的。
我会添加一个 seed() 调用,以获得更准确(可复现)的结果:
import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict
def radix_sort(arr, base):
    """Sort ``arr`` in place using two dictionary-bucket passes in ``base``.

    Values are grouped by remainder (``value % base``) first and then by
    quotient (``value // base``). The keys of both bucket tables are visited
    in sorted order, so only digits that actually occur are touched.

    Args:
        arr: List of integers; it is cleared and refilled in sorted order.
        base: Divisor used to split each value into quotient and remainder.

    Returns:
        The same ``arr`` object after it has been refilled.
    """
    by_remainder = defaultdict(list)
    for value in arr:
        quotient, remainder = divmod(value, base)
        by_remainder[remainder].append(quotient)

    # Walking remainders in ascending order keeps each quotient bucket
    # ordered by remainder.
    by_quotient = defaultdict(list)
    for remainder in sorted(by_remainder):
        for quotient in by_remainder[remainder]:
            by_quotient[quotient].append(quotient * base + remainder)

    arr.clear()
    arr.extend(
        value
        for quotient in sorted(by_quotient)
        for value in by_quotient[quotient]
    )
    return arr
def radix_sort_no_hashmap(arr, size):
    """Return a sorted copy of ``arr`` using a two-pass LSD radix sort.

    Both bucket tables are plain lists of ``size`` lists indexed directly by
    digit. ``arr`` itself is not modified.

    Args:
        arr: Sequence of integers in ``range(0, size ** 2)``.
        size: Radix, i.e. the number of buckets per pass. Must be at least 1
            when ``arr`` is not empty.

    Returns:
        A new list with the values of ``arr`` in ascending order.

    Raises:
        ValueError: If ``size`` is smaller than 1 for a non-empty ``arr`` or
            a value lies outside ``range(0, size ** 2)``.
    """
    if len(arr) > 0 and size < 1:
        raise ValueError("size must be at least 1 to sort a non-empty list")

    least_sig_digit = [[] for _ in range(size)]
    highest_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        # Wrapping the quotient with ``% size`` would put out-of-range values
        # into the wrong bucket and return an unsorted list; reject instead.
        if not 0 <= q < size:
            raise ValueError(
                "value %r is outside range(0, size ** 2)" % (num,)
            )
        least_sig_digit[r].append(q)

    # Buckets are visited in order of lowest significant digit.
    for k, quotients in enumerate(least_sig_digit):
        for q in quotients:
            highest_sig_digit[q].append(q * size + k)
    return [num for sublist in highest_sig_digit for num in sublist]
def benchmark_sorting_algorithms(sizes=None):
    """Time both radix sorts against ``sorted()`` and ``list.sort()``.

    The random generator is seeded with a fixed value, so every run sorts
    the same inputs. For each size a list of ``size`` distinct integers is
    drawn from ``range(1, size ** 2)`` and each algorithm gets its own copy.

    Args:
        sizes: Iterable of input sizes to benchmark. Each size must be at
            least 2, otherwise the sampling range is too small. Defaults to
            the original list of sizes from 1_000 up to 1_200_000.

    Returns:
        A tuple ``(sizes, radix, std_sorted, std_sort, radix_no_hashmap)``;
        each timing list holds one duration in seconds per entry of
        ``sizes``.

    Raises:
        AssertionError: If the four sorted results differ for some size.
    """
    if sizes is None:
        sizes = [1_000, 10_000, 200_000, 400_000, 800_000, 1_200_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []
    random.seed(40)
    for size in sizes:
        array = random.sample(range(1, size ** 2), size)
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()

        # perf_counter() is the clock intended for measuring short
        # durations; time.time() can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        radix_result = radix_sort(A, size)
        radix.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        no_hashmap_result = radix_sort_no_hashmap(B, size)
        radix_no_hashmap.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        sorted_result = sorted(C)
        std_sorted.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        D.sort()
        std_sort.append(time.perf_counter() - start_time)

        # Checked outside the timed regions: timings of a sort that returns
        # the wrong answer would be meaningless.
        if not (radix_result == no_hashmap_result == sorted_result == D):
            raise AssertionError(
                "sorting implementations disagree for size %d" % size
            )
    return sizes, radix, std_sorted, std_sort, radix_no_hashmap
# Run the benchmark once; the five result series stay available at module
# level under their original names.
sizes, radix, std_sorted, std_sort, radix_no_hashmap = (
    benchmark_sorting_algorithms()
)

plt.figure(figsize=(12, 6))

# One curve per algorithm, drawn in the same order as the legend entries.
for _timings, _legend_label in (
    (radix, 'Radix Sort (O(n))'),
    (std_sorted, 'Standard Sort using sorted() (O(n log n))'),
    (std_sort, 'Standard Sort using .sort() (O(n log n))'),
    (radix_no_hashmap, 'Radix Sort (O(n)) - No Hashmap'),
):
    plt.plot(sizes, _timings, label=_legend_label)

# Both axes are logarithmic.
plt.xlabel('Input size (n)')
plt.ylabel('Time (seconds)')
plt.xscale('log')
plt.yscale('log')

plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()