我在 Python 中实现了两个版本的基数排序(可对小于 n² 的非负整数排序,其中 n 是待排序列表的长度),并与标准排序(Timsort)进行基准对比。为了让比较更公平,我使用 PyPy 运行。
令人惊讶的是,即使不使用哈希表(改用直接访问的数组),我的基数排序实现也比标准排序慢,即使输入规模很大也是如此。我想我一定遗漏了某些微观优化,因为 O(n) 最终应该胜过 O(n log n)。我希望得到提升性能的建议。我这样做是为了学习,因此不考虑内置函数、第三方库或从 Python 调用的 C 编译代码。
有哪些可以做的微观优化?我的代码真的是 O(n) 吗?
我的代码在 AMD Ryzen 9 7950X CPU 上运行最多需要 10 秒:
import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict
def radix_sort(arr, size):
    """Sort ``arr`` in place with a two-pass LSD radix sort in base ``size``.

    Every value is split as ``value = q * size + r``. The first pass buckets the
    quotients by the low digit ``r``. The second pass walks those buckets in
    increasing ``r`` and re-buckets the reconstructed values by the high digit
    ``q``. Because the second pass is stable, the final order is by ``(q, r)``,
    which is numeric order. Buckets are ``defaultdict(list)`` instances.

    Work is two passes over ``arr`` plus two loops over ``range(size)``.

    Args:
        arr: List of ints, each in ``range(size * size)``. It is overwritten
            with the sorted values.
        size: The radix. The benchmark in this file passes ``len(arr)``.

    Returns:
        ``arr`` itself, now in ascending order.

    Raises:
        ValueError: If any value lies outside ``[0, size**2)``. Such values
            would land in a bucket that is never emitted, silently dropping
            them and leaving stale entries at the end of ``arr``.
    """
    # min/max run in C, so the precondition check is cheap next to the sort.
    if arr and (min(arr) < 0 or max(arr) >= size * size):
        raise ValueError("radix_sort requires 0 <= value < size**2 for every value")

    least_sig_digit = defaultdict(list)
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)

    highest_sig_digit = defaultdict(list)
    for k in range(size):  # k visits the low digits in increasing order
        # .get does not invoke default_factory, so empty buckets are not inserted.
        for q in least_sig_digit.get(k, ()):
            highest_sig_digit[q].append(q * size + k)

    # Slice assignment keeps the in-place contract (the caller's list object is
    # the one that ends up sorted) while letting a comprehension do the copy.
    arr[:] = [num for k in range(size) for num in highest_sig_digit.get(k, ())]
    return arr
def radix_sort_no_hashmap(arr, size):
    """Sort ``arr`` in place with a two-pass LSD radix sort using list buckets.

    This is the same algorithm as the ``defaultdict`` variant, but the buckets
    are ``size`` pre-allocated lists indexed directly by digit, so no hashing
    is involved. Each value is split as ``value = q * size + r``. The values
    are first bucketed by ``r``, then stably re-bucketed by ``q``.

    Args:
        arr: List of ints, each in ``range(size * size)``. It is overwritten
            with the sorted values.
        size: The radix, which is also the number of buckets per pass.

    Returns:
        ``arr`` itself, now in ascending order.

    Raises:
        ValueError: If any value lies outside ``[0, size**2)``. A negative
            quotient would otherwise index the bucket list from the end and
            silently mis-sort. A quotient >= ``size`` would raise an opaque
            IndexError.
    """
    if arr and (min(arr) < 0 or max(arr) >= size * size):
        raise ValueError("radix_sort_no_hashmap requires 0 <= value < size**2 for every value")

    least_sig_digit = [[] for _ in range(size)]
    for num in arr:
        q, r = divmod(num, size)
        least_sig_digit[r].append(q)

    highest_sig_digit = [[] for _ in range(size)]
    for k, quotients in enumerate(least_sig_digit):  # k = low digit, ascending
        for q in quotients:
            highest_sig_digit[q].append(q * size + k)

    # Flatten the high-digit buckets back into the caller's list object.
    arr[:] = [num for bucket in highest_sig_digit for num in bucket]
    return arr
def benchmark_sorting_algorithms(sizes=None):
    """Time ``radix_sort``, ``radix_sort_no_hashmap`` and ``sorted`` on random data.

    For every n in ``sizes``, n distinct integers are drawn from
    ``range(1, n**2)``. Each algorithm sorts its own copy, so a sort that
    mutates its input in place cannot hand an already-sorted list to the next
    one. The three results are then checked for equality.

    Args:
        sizes: Input sizes to benchmark. ``None`` (the default) uses the
            original schedule from 1,000 up to 10,000,000 elements.

    Returns:
        ``(sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times)``.
        Each time list holds elapsed seconds, one entry per size.
    """
    if sizes is None:
        sizes = [1000, 10000, 100000, 200000, 1000000, 2000000, 3000000,
                 4000000, 5000000, 6000000, 10000000]
    radix_times = []
    radix_sort_no_hashmap_times = []
    std_sort_times = []
    for size in sizes:
        array = random.sample(range(1, size**2), size)

        # perf_counter is monotonic and high-resolution. time.time() reads the
        # wall clock, which can jump and is too coarse for short runs.
        new_arr = array.copy()
        start_time = time.perf_counter()
        a = radix_sort(new_arr, size)
        radix_times.append(time.perf_counter() - start_time)

        new_arr = array.copy()
        start_time = time.perf_counter()
        b = radix_sort_no_hashmap(new_arr, size)
        radix_sort_no_hashmap_times.append(time.perf_counter() - start_time)

        new_arr = array.copy()
        start_time = time.perf_counter()
        c = sorted(new_arr)
        std_sort_times.append(time.perf_counter() - start_time)

        # One list comparison (run in C) replaces the per-index Python loop.
        assert a == b == c, f"sort results differ for size={size}"
    return sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times
sizes, radix_times, std_sort_times, radix_sort_no_hashmap_times = benchmark_sorting_algorithms()

# One curve per algorithm, drawn in a fixed order so colours and legend order
# stay stable. Both axes are logarithmic.
plt.figure(figsize=(12, 6))
for series, series_label in (
    (radix_times, 'Radix Sort (O(n))'),
    (std_sort_times, 'Standard Sort (O(nlogn))'),
    (radix_sort_no_hashmap_times, 'Radix Sort (O(n)) - No Hashmap'),
):
    plt.plot(sizes, series, label=series_label)

plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()
按照 @user24714692 的建议,我用 C++ 发布了同样的问题。
它避免创建大量 `list` 对象,而是使用由 `arr` 的索引构成的链表,这些索引存储在三个长列表中。每个余数/商各对应一个链表:`first[r]` 和 `last[r]` 分别给出余数为 `r` 的链表的开头和结尾,而 `next[i]` 给出链表中索引 `i` 之后的下一个元素。`First`/`Last`/`Next` 的含义相同,只是针对商 `q`。
def radix_sorted(arr):
    """Return a new ascending list of ``arr`` via a two-pass radix sort in base ``len(arr)``.

    Instead of allocating one Python list per bucket, each bucket is a linked
    list of indices into ``arr``, stored in flat arrays:

    * ``first[d]`` / ``last[d]`` hold the head and tail index of bucket ``d``
      (``-1`` when empty).
    * ``nxt[i]`` holds the index that follows ``i`` in its bucket.

    Pass 1 buckets by the low digit ``x % n``. Pass 2 walks those buckets in
    order and re-buckets by the high digit ``x // n``. Appending at the tail
    keeps both passes stable, so the output is in numeric order.

    Args:
        arr: Sequence of ints, each in ``range(len(arr) ** 2)``. Not modified.

    Returns:
        A new list with the values of ``arr`` in ascending order.

    Raises:
        ValueError: If a value is outside ``[0, n**2)``. A negative high digit
            would otherwise index the bucket arrays from the end and silently
            mis-sort. A high digit >= ``n`` would raise an opaque IndexError.
    """
    n = len(arr)
    if n and (min(arr) < 0 or max(arr) >= n * n):
        raise ValueError("radix_sorted requires 0 <= value < len(arr)**2 for every value")

    # Pass 1: chain indices into buckets keyed by the low digit x % n.
    first = [-1] * n
    last = [-1] * n
    nxt = [-1] * n  # renamed from `next`, which shadowed the builtin
    for i, x in enumerate(arr):
        r = x % n
        if first[r] == -1:
            first[r] = i
        else:
            nxt[last[r]] = i
        last[r] = i

    # Pass 2: visit low-digit buckets in order, chaining by the high digit x // n.
    first_q = [-1] * n
    last_q = [-1] * n
    next_q = [-1] * n
    for r in range(n):
        i = first[r]
        while i != -1:
            q = arr[i] // n
            if first_q[q] == -1:
                first_q[q] = i
            else:
                next_q[last_q[q]] = i
            last_q[q] = i
            i = nxt[i]

    # Output: read the high-digit chains in order.
    out = []
    for q in range(n):
        i = first_q[q]
        while i != -1:
            out.append(arr[i])
            i = next_q[i]
    return out
# Demo / sanity check: one million distinct values from [0, n**2), compared
# against the built-in sort.
import random

n = 1_000_000
arr = random.sample(range(n * n), n)
print(sorted(arr) == radix_sorted(arr))
你的代码是 `O(N)`。有时也写作 `O(K * N)`,其中 K 是位数。不过这里 K 是常数(以 n 为基数时,小于 n² 的值恰好有 2 位),因此 `O(N)` 是正确的。我会添加 `seed()`,以获得可复现的结果:
import matplotlib.pyplot as plt
import random
import time
from collections import defaultdict
def radix_sort(arr, base):
    """Sort ``arr`` in place by grouping on the low digit, then on the high digit.

    Each value is split as ``value = high * base + low``. The values are first
    grouped by ``low`` and then, visiting the ``low`` groups in ascending order,
    regrouped by ``high``. Emitting the ``high`` groups in ascending order gives
    numeric order. Only keys that actually occur are visited, via ``sorted``.

    Args:
        arr: List of ints. It is overwritten with the sorted values.
        base: The divisor used to split each value into two digits.

    Returns:
        ``arr`` itself, now sorted.
    """
    by_low = defaultdict(list)
    for value in arr:
        high, low = divmod(value, base)
        by_low[low].append(high)

    by_high = defaultdict(list)
    for low in sorted(by_low):
        for high in by_low[low]:
            by_high[high].append(high * base + low)

    arr[:] = [value for high in sorted(by_high) for value in by_high[high]]
    return arr
def radix_sort_no_hashmap(arr, size):
    """Return a new list with the non-negative integers of ``arr`` in ascending order.

    This is an LSD radix sort in base ``size`` with plain lists as buckets (no
    hashing). Each pass stably distributes the values by one base-``size``
    digit, starting from the least significant. Passes continue until every
    digit of the largest value has been processed. For values below
    ``size**2`` that is at most two passes, as in a fixed two-pass design.
    Larger values get the extra passes they need. A fixed second pass keyed on
    ``q % size`` would mis-order them silently.

    Args:
        arr: Integers to sort. Not modified.
        size: The radix (bucket count per pass). Values below 2 are treated
            as 2, because a radix of 0 or 1 cannot make progress.

    Returns:
        A new sorted list.

    Raises:
        ValueError: If ``arr`` contains a negative value.
    """
    base = max(size, 2)
    result = list(arr)
    if not result:
        return result
    if min(result) < 0:
        raise ValueError("radix_sort_no_hashmap requires non-negative values")

    largest = max(result)
    place = 1  # weight (base ** pass_index) of the digit handled in this pass
    while place <= largest:
        buckets = [[] for _ in range(base)]
        for num in result:
            buckets[num // place % base].append(num)
        # Concatenating buckets in index order keeps the pass stable.
        result = [num for bucket in buckets for num in bucket]
        place *= base
    return result
def benchmark_sorting_algorithms(sizes=None):
    """Time two radix sorts, ``sorted()`` and ``list.sort()`` on seeded random data.

    ``random.seed(40)`` is set once before the loop, so the generated inputs
    are reproducible between runs. For every n in ``sizes``, n distinct
    integers are drawn from ``range(1, n**2)``. Each algorithm gets its own
    copy. Results are not checked for correctness here.

    Args:
        sizes: Input sizes to benchmark. ``None`` (the default) uses the
            original schedule from 1,000 to 1,200,000 elements.

    Returns:
        ``(sizes, radix, std_sorted, std_sort, radix_no_hashmap)``. Each time
        list holds elapsed seconds, one entry per size.
    """
    if sizes is None:
        sizes = [1_000, 10_000, 200_000, 400_000, 800_000, 1_200_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []
    random.seed(40)
    for size in sizes:
        array = random.sample(range(1, size ** 2), size)
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        _ = radix_sort(A, size)
        radix.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        _ = radix_sort_no_hashmap(B, size)
        radix_no_hashmap.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        _ = sorted(C)
        std_sorted.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        D.sort()
        std_sort.append(time.perf_counter() - start_time)
    return sizes, radix, std_sorted, std_sort, radix_no_hashmap
sizes, radix, std_sorted, std_sort, radix_no_hashmap = benchmark_sorting_algorithms()

# Draw one curve per algorithm, always in the same order, on log-log axes.
plt.figure(figsize=(12, 6))
for series, series_label in (
    (radix, 'Radix Sort (O(n))'),
    (std_sorted, 'Standard Sort using sorted() (O(n log n))'),
    (std_sort, 'Standard Sort using .sort() (O(n log n))'),
    (radix_no_hashmap, 'Radix Sort (O(n)) - No Hashmap'),
):
    plt.plot(sizes, series, label=series_label)

plt.xlabel('Input size (n)')
plt.xscale('log')
plt.ylabel('Time (seconds)')
plt.yscale('log')
plt.title('Radix Sort vs Standard Sort')
plt.legend()
plt.grid(True)
plt.show()
尝试用 `array = list(range(size ** 20, size ** 20 - size, -1))` 生成数据,进行排序和基准测试:
def benchmark_sorting_algorithms(sizes=None):
    """Time two radix sorts, ``sorted()`` and ``list.sort()`` on very large values.

    For every n in ``sizes``, the input is the n consecutive integers counting
    down from ``n ** 20``. The list is strictly descending and its values are
    far above ``n ** 2``. Each algorithm gets its own copy. Results are not
    checked for correctness here.

    Args:
        sizes: Input sizes to benchmark. ``None`` (the default) uses the
            original schedule from 1,000 to 800,000 elements.

    Returns:
        ``(sizes, radix, std_sorted, std_sort, radix_no_hashmap)``. Each time
        list holds elapsed seconds, one entry per size.
    """
    if sizes is None:
        sizes = [1_000, 10_000, 20_000, 80_000, 200_000, 800_000]
    std_sorted, std_sort, radix, radix_no_hashmap = [], [], [], []
    for size in sizes:
        array = list(range(size ** 20, size ** 20 - size, -1))
        A, B, C, D = array.copy(), array.copy(), array.copy(), array.copy()

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        _ = radix_sort(A, size)
        radix.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        _ = radix_sort_no_hashmap(B, size)
        radix_no_hashmap.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        _ = sorted(C)
        std_sorted.append(time.perf_counter() - start_time)

        start_time = time.perf_counter()
        D.sort()
        std_sort.append(time.perf_counter() - start_time)
    return sizes, radix, std_sorted, std_sort, radix_no_hashmap
评论