This (effectively) uses the same algorithm as yours, so it's still O(n^2), but you can use NumPy to speed up the operations:
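- `np.bitwise_xor` performs a bitwise XOR on two arrays.
- `np.bitwise_and` performs a bitwise AND on two arrays.
- Giving these functions a row vector and a column vector lets NumPy broadcast the result to a square matrix.
- Comparing the resulting matrices gives a boolean array. We only need the lower triangle of this matrix. However, the matrix is symmetric, and since `a ^ a == 0` (which is never greater than `a & a == a` for non-negative `a`), its diagonal is all `False`. So we can simply sum the entire array and divide the result by 2.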
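```python
import numpy as np

def solve(nums):
    nums = np.asarray(nums)  # nums[:, None] needs a NumPy array, not a plain list
    # Row vector combined with a column vector broadcasts to an n x n matrix
    xor_arr = np.bitwise_xor(nums, nums[:, None])
    and_arr = np.bitwise_and(nums, nums[:, None])
    # Symmetric boolean matrix with an all-False diagonal, so halve the total
    return (xor_arr > and_arr).sum() // 2
```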
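To make the broadcasting and the symmetry argument concrete, here is a small sketch on an arbitrary four-element input. It repeats `solve` so it runs on its own, and checks the matrix shape, its symmetry, the all-`False` diagonal, and the halved sum against a brute-force count over `i < j`:

```python
import numpy as np

def solve(nums):
    # Same NumPy approach as above, repeated so this snippet is standalone
    nums = np.asarray(nums)
    xor_arr = np.bitwise_xor(nums, nums[:, None])
    and_arr = np.bitwise_and(nums, nums[:, None])
    return (xor_arr > and_arr).sum() // 2

nums = np.array([1, 2, 3, 4])
mask = np.bitwise_xor(nums, nums[:, None]) > np.bitwise_and(nums, nums[:, None])

print(mask.shape)                    # (4, 4): row vector x column vector -> square matrix
print(np.array_equal(mask, mask.T))  # True: the comparison is symmetric in i and j
print(mask.diagonal().any())         # False: a ^ a == 0 never beats a & a == a

# Brute-force count over unordered pairs i < j, for comparison
brute = sum(
    1
    for i in range(len(nums))
    for j in range(i + 1, len(nums))
    if (nums[i] ^ nums[j]) > (nums[i] & nums[j])
)
print(solve(nums), brute)            # 5 5
```

Halving the full sum is just a shortcut for summing the strict lower triangle (for example with `np.tril(mask, k=-1).sum()`), which would give the same count.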
You could also skip NumPy entirely and use `numba` to JIT-compile your code before it runs:
```python
import numba

@numba.njit
def solve(array):
    n = len(array)
    ans = 0
    for i in range(n):
        p1 = array[i]
        # Start at i + 1 so each unordered pair (i, j) with i < j is checked once
        for j in range(i + 1, n):
            p2 = array[j]
            if (p1 ^ p2) > (p1 & p2):
                ans += 1
    return ans
```
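Finally, here's my implementation of Dave's algorithm. It relies on the fact that, for non-negative integers, `a ^ b > a & b` holds exactly when `a` and `b` have different most significant bits (MSBs), so it is enough to count how many numbers share each MSB position: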
```python
from collections import defaultdict

def new_alg(array):
    msb_num_count = defaultdict(int)
    for num in array:
        # len(bin(num)) - 2 was faster than right-shifting until zero,
        # but bin(0) == '0b0' would put 0 in the same group as 1, so map 0 to 0
        msb = len(bin(num)) - 2 if num else 0
        msb_num_count[msb] += 1  # Increment the count of numbers that have this MSB
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = sum(msb_num_count.values())
    for group_len in msb_num_count.values():
        cnt += group_len * (len_all_groups - group_len)
    # Every cross-group pair was counted once from each side
    return cnt // 2
```
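The grouping rests on the MSB claim, which is easy to spot-check by brute force. This sketch compares `a ^ b > a & b` with a bit-length test for every pair below an arbitrarily chosen `LIMIT` (an assumption: it is a check on small values, not a proof), using Python's `int.bit_length()` as the MSB position. It also shows why 0 needs care: `len(bin(0)) - 2` is 1, the same as for 1, while `(0).bit_length()` is 0.

```python
def xor_beats_and(a, b):
    return (a ^ b) > (a & b)

def same_msb(a, b):
    # int.bit_length() is the position of the most significant set bit (0 for 0)
    return a.bit_length() == b.bit_length()

# Exhaustively compare the two predicates on all pairs of small non-negative ints;
# the claim is that XOR beats AND exactly when the MSBs differ
LIMIT = 256
mismatches = [
    (a, b)
    for a in range(LIMIT)
    for b in range(LIMIT)
    if xor_beats_and(a, b) == same_msb(a, b)
]
print(len(mismatches))                     # 0

# bin() puts 0 and 1 in the same bucket; bit_length() does not
print(len(bin(0)) - 2, (0).bit_length())   # 1 0
```

Why it holds: if both numbers have their top set bit at the same position, AND keeps that bit while XOR clears it, so AND is larger. If the positions differ, XOR keeps the higher top bit and AND cannot contain it, so XOR is larger.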
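And here it is as a Numba-compatible function. I needed to define `get_msb` because `numba.njit` doesn't handle built-in Python functions such as `bin`, and the `defaultdict` is replaced with a plain dict: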
```python
@numba.njit
def get_msb(num):
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

@numba.njit
def new_alg_numba(array):
    msb_num_count = {}
    for num in array:
        msb = get_msb(num)
        if msb not in msb_num_count:
            msb_num_count[msb] = 0
        msb_num_count[msb] += 1
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = 0
    for grp_len in msb_num_count.values():
        len_all_groups += grp_len
    for grp_len in msb_num_count.values():
        cnt += grp_len * (len_all_groups - grp_len)
    return cnt // 2
```
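The final `cnt // 2` is a counting step: summing `group_len * (len_all_groups - group_len)` visits each cross-group pair once from each side. As a hedged sketch of an equivalent formulation, one can instead take all pairs and subtract the pairs that fall inside a single MSB group. The function names here are hypothetical, and the sketch uses `collections.Counter`, `math.comb` (Python 3.8+) and `int.bit_length()` rather than the helpers above:

```python
from collections import Counter
from math import comb

def count_pairs_by_complement(array):
    # Group sizes keyed by bit length (the MSB position; 0 for the number 0)
    groups = Counter(int(num).bit_length() for num in array)
    n = sum(groups.values())
    # All unordered pairs minus the pairs inside a single MSB group
    return comb(n, 2) - sum(comb(g, 2) for g in groups.values())

def count_pairs_by_products(array):
    # Sum-of-products form: each cross-group pair is counted twice, hence // 2
    groups = Counter(int(num).bit_length() for num in array)
    n = sum(groups.values())
    return sum(g * (n - g) for g in groups.values()) // 2

data = [1, 2, 3, 4, 0, 7, 8, 8]
print(count_pairs_by_complement(data), count_pairs_by_products(data))  # 25 25
```

Both forms agree because the sum of `g * (n - g)` over the groups equals `n**2 - sum(g**2)`, and `comb(n, 2) - sum(comb(g, 2))` simplifies to exactly half of that.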
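Comparing the runtimes, the Numba-compiled loop is much faster than the NumPy approach, which is itself faster than the plain Python loop.

The linear-time algorithm given by Dave is faster than the NumPy approach from the start, and becomes faster than the Numba-compiled loop version ("loopy") for inputs of more than ~… elements. The Numba-compiled version of this approach is faster still: it overtakes the Numba-compiled loopy version at ~… elements.

For larger inputs, Kelly's excellent implementation of Dave's algorithm is on par with the Numba version of my implementation.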