我有一个数字数组,我想计算所有可能的对组合,对该对的XOR运算大于AND运算.

Example:

4,3,5,2

possible pairs are:

(4,3) -> xor=7, and = 0
(4,5) -> xor=1, and = 4
(4,2) -> xor=6, and = 0
(3,5) -> xor=6, and = 1
(3,2) -> xor=1, and = 2
(5,2) -> xor=7, and = 0

Valid pairs for which xor > and are (4,3), (4,2), (3,5), (5,2) so result is 4.

这是我的程序:

def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans

时间复杂度为O(n^2),但我的数组大小是1到10^5,数组中的每个元素都是1到2^30.那么如何才能降低这个程序的时间复杂度呢?

推荐答案

这(有效地)使用与您相同的算法,因此它仍然是O(n^2),但是您可以使用NumPy来加速操作:

  • np.bitwise_xor对两个数组执行按位异或运算
  • np.bitwise_and对两个数组执行按位与运算
  • 给这些函数一个行向量和一个列向量允许NumPy将结果广播到一个方阵.
  • 比较得到的矩阵,我们得到一个布尔array.我们只需要这个矩阵的下三角形.因为我们知道a ^ a == 0,所以我们只需将整个数组相加,然后将其结果除以2即可得到答案.
import numpy as np

def solve(nums):
    xor_arr = np.bitwise_xor(nums, nums[:, None])
    and_arr = np.bitwise_and(nums, nums[:, None])

    return (xor_arr > and_arr).sum() // 2

您也可以完全跳过Numpy,使用numba在代码运行之前对其进行JIT编译.

import numba

@numba.njit
def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans

最后,下面是我对Dave's algorithm的实现:

from collections import defaultdict
def new_alg(array):
    msb_num_count = defaultdict(int)
    for num in array:
        msb = len(bin(num)) - 2 # This was faster than right-shifting until zero
        msb_num_count[msb] += 1 # Increment the count of numbers that have this MSB
    
    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = sum(msb_num_count.values())
    for group_len in msb_num_count.values():
        cnt += group_len * (len_all_groups - group_len)

    return cnt // 2

作为与Numba兼容的函数.我需要定义一个get_msb,因为numba.njit不会处理内置的python函数

@numba.njit
def get_msb(num):
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

@numba.njit
def new_alg_numba(array):
    msb_num_count = {}
    for num in array:
        msb = get_msb(num)
        if msb not in msb_num_count:
            msb_num_count[msb] = 0
        msb_num_count[msb] += 1

    # Now, for each number, the count will be the sum of the numbers in all other groups
    cnt = 0
    len_all_groups = 0

    for grp_len in msb_num_count.values():
        len_all_groups += grp_len
        
    for grp_len in msb_num_count.values():
        cnt += grp_len * (len_all_groups - grp_len)

    return cnt // 2

比较运行时,我们看到Numba方法比Numpy方法要快得多,Numpy方法本身也比Python中的循环快.

Dave给出的线性时间算法从一开始就比NumPy方法更快,而且开始比针对Inputs&>loopy0个元素的Numba编译代码更快.这种方法的Numba编译版本甚至更快--它在~loopy个元素上超过了Numba编译的loopy个元素.

对于更大的输入,Dave算法的Kelly's excellent implementation%与我的实现的Numba版本相当

enter image description here

Python相关问答推荐

使用mySQL的SQlalchemy过滤重叠时间段

海运图:调整行和列标签

如何从.cgi网站刮一张表到rame?

Python解析整数格式说明符的规则?

如何获得每个组的时间戳差异?

如何使用两个关键函数来排序一个多索引框架?

幂集,其中每个元素可以是正或负""""

numpy.unique如何消除重复列?

为什么常规操作不以其就地对应操作为基础?

ruamel.yaml dump:如何阻止map标量值被移动到一个新的缩进行?

Pandas:填充行并删除重复项,但保留不同的值

Numpyro AR(1)均值切换模型抽样不一致性

PYTHON、VLC、RTSP.屏幕截图不起作用

如何从pandas DataFrame中获取. groupby()和. agg()之后的子列?

Flask运行时无法在Python中打印到控制台

按条件添加小计列

我可以不带视频系统的pygame,只用于游戏手柄输入吗?''

如何根据一定条件生成段id

查找数据帧的给定列中是否存在特定值

对包含JSON列的DataFrame进行分组