我有一个用于我的任务的Python代码,它做得很好,但是时间复杂度很差,我不知道如何更改它。任务:
给定两个整数数组“a”和“b”,每个数组的长度为“n”,您需要找到索引对“i”和“j”的数量,以满足不等式: aᵢ ⊕ aʲ ≥ bᵢ ⊕ bʲ。 请注意,在此问题中,如果“i”不等于“j”,则 (i, j) 和 (j, i) 对被视为不同。符号 ⊕ 表示按位加法模 2,或 XOR 运算。
n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
count = 0
for i in range(n):
for j in range(n):
if a[i] ^ a[j] >= b[i] ^ b[j]:
count += 1
print(count)
有什么改进运行时间的建议吗?
使用 Numpy 提高了代码效率。
比较:
import datetime
import numpy as np
n = 10000
a = [1,2,3,4,5,16,17,18,19,20] * int(n/10) # List with 10 K items created for testing
b = [10,9,8,7,6,11,12,13,14,15] * int(n/10) # List with 10 K items created for testing
print("Len a : ", len(a))
print("Len b : ", len(b))
######### Your Code Test ######################
count = 0
print("Starting your code test")
start = datetime.datetime.now() # Start time
for i in range(n):
for j in range(n):
if a[i] ^ a[j] >= b[i] ^ b[j]:
count += 1
end = datetime.datetime.now() # end time
delta = end-start
print("Your code time delta : ", delta)
print("Result : ", count)
############## New Code test #######################
count = 0
print("Starting new code test")
start = datetime.datetime.now() # Start time
a_arr = np.array(a).reshape(-1,1)
a_arr_biwise = np.bitwise_xor(a_arr, a_arr.transpose()) #bitwise XOR of array on itself
b_arr = np.array(b).reshape(-1,1)
b_arr_biwise = np.bitwise_xor(b_arr, b_arr.transpose()) #bitwise XOR of array on itself
comparison_matrix = np.greater_equal(a_arr_biwise,b_arr_biwise) # axor >= bxor comparison on each position.
count = comparison_matrix.sum() # find True count , i.e. where >= is True
end = datetime.datetime.now() # end time
delta = end-start
print("New code time delta : ", delta)
print("Result : ", count)
输出
Len a : 10000
Len b : 10000
Starting your code test
Your code time delta : 0:00:46.925000
Result : 80000000
Starting new code test
New code time delta : 0:00:00.453003
Result : 80000000
可以看到两个代码的结果是相同的。
由于 xor 是对称的:a[i] ^ a[j] == a[j] ^ a[i] (对于 b 也是如此),因此您可以将内部循环减少到 j>i。
由于 a[i] ^ a[i] >= b[i] ^ b[i] 始终为 true,因此您也可以跳过这种情况。
这不会提高复杂度 O(n²),但速度至少会提高一倍。
n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
count = n // the cases i=j alsways fulfill the condition a[i] ^ a[j] >= b[i] ^ b[j]:
for i in range(n):
for j =0; j<i; j++ :
if a[i] ^ a[j] >= b[i] ^ b[j]:
count += 2
print(count)