我有两个长度相同的一维NumPy数组A和B.我想要找到两个数组的交集,这意味着我想要找到A中也存在于B中的所有元素.

当数组A中索引处的元素也是数组B的成员时,结果应该是一个布尔数组,其值为True,保留顺序,以便我可以使用结果来索引另一个array.

如果没有布尔掩码约束,我会将两个数组都转换为集合,并使用集合交集操作符(&).然而,我try 使用np.isinnp.in1d,发现使用普通的Python列表理解要快得多.

在给定的设置下:

import numba
import numpy as np

primes = np.array([
    2,   3,   5,   7,  11,  13,  17,  19,  23,  29,  31,  37,  41,
    43,  47,  53,  59,  61,  67,  71,  73,  79,  83,  89,  97, 101,
    103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167,
    173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239,
    241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
    317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
    401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467,
    479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569,
    571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643,
    647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
    739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
    827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997],
    dtype=np.int64)

@numba.vectorize(nopython=True, cache=True, fastmath=True, forceobj=False)
def reverse_digits(n, base):
    out = 0
    while n:
        n, rem = divmod(n, base)
        out = out * base + rem
    return out

flipped = reverse_digits(primes, 10)

def set_isin(a, b):
    return a in b

vec_isin = np.vectorize(set_isin)

primes包含primes0以下的所有质数,总数为168.我之所以 Select 它,是因为它的大小合适,而且是预先确定的.我做过各种测试:

In [2]: %timeit np.isin(flipped, primes)
51.3 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit np.in1d(flipped, primes)
46.2 µs ± 386 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [4]: %timeit setp = set(primes)
12.9 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [5]: %timeit setp = set(primes.tolist())
6.84 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [6]: %timeit setp = set(primes.flat)
11.5 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [7]: setp = set(primes.tolist())

In [8]: %timeit [x in setp for x in flipped]
23.3 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [9]: %timeit [x in setp for x in flipped.tolist()]
12.1 µs ± 76.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [10]: %timeit [x in setp for x in flipped.flat]
19.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [11]: %timeit vec_isin(flipped, setp)
40 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [12]: %timeit np.frompyfunc(lambda x: x in setp, 1, 1)(flipped)
25.7 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [13]: %timeit setf = set(flipped.tolist())
6.51 µs ± 44 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [14]: setf = set(flipped.tolist())

In [15]: %timeit np.array(sorted(setf & setp))
9.42 µs ± 78.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

setp = set(primes.tolist()); [x in setp for x in flipped.tolist()]大约需要19微秒,这比NumPy方法更快.我想知道为什么会这样,是否有办法让它更快.

(我写了所有的代码,我使用了AI建议的编辑功能来编辑问题)

推荐答案

为什么提供的解决方案效率不高

np.isin有两个实现.第一种方法是对两个数组进行排序(使用合并排序),然后将它们合并.此解决方案在O(n log n + m log m + n+m)O(n log n + m log m)中运行.另一种实现是基于查找表的.该第二实现基于第二数组创建布尔值的数组,然后判断是否为第一数组的每一项设置了lookupTable[item].对于包含小整数的数组,第二种实现可能会更快(这稍微复杂一些,但却是explained in the documentation).第二个解决方案在O(n + m + max(arr2))中运行(理论上甚至在某些具有大的隐藏常量的平台上运行O(n + m)).然而,它可以使用much more memory.默认情况下,Numpy会try 挑选最好的.在您的例子中,两个数组很小,其中的整数也相对较小,因此这两个解决方案相对较快.对于具有小整数的较大数组,查找表应该更快.

问题是,Numpy在这里效率不高,因为与实际计算相比,调用这样的Numpy函数的开销相对较大.此外,第二个数组已经排序,因此再次排序效率不高.


更快地实施

例如,您可以只使用二进制搜索来查找第二个数组中第一个数组的值,而无需分配任何额外的临时array.您可以使用Numba so来减少在小数组上调用Numpy多个函数的开销,甚至可以使用jited循环更快地填充结果.以下是最终实现:

# Assume primes is sorted
@numba.njit('bool_[:](int64[:],int64[:])')
def compute(flipped, primes):
    assert primes.size > 0 and primes.size == flipped.size
    res = np.empty(flipped.size, dtype=np.bool_)
    idx = np.searchsorted(primes, flipped)
    for i in range(res.size):
        if idx[i] < len(primes) and primes[idx[i]] == flipped[i]:
            res[i] = True
        else:
            res[i] = False
    return res

在我的机器上,这个解决方案是15 times faster than 100,比所有其他替代方案都要快(明显快得多).提供的输入仅需约2微米S.它的规模也相对较好.


大型数组的最快解决方案

对于大型数组,使用查找表应该会更快,因为上面的解决方案运行时间是O(n log m)倍,而这里的查找表实现可以在线性时间运行.也就是说,查找表使用的内存也要多得多.最好的方法是使用Bloom filter来使查找表更加紧凑(这要归功于散列).然而,这个解决方案的实施要复杂得多.setdif1dan exemple here.最快的解决方案往往以更复杂的代码为代价(没有免费的午餐).

Python相关问答推荐

将每个关键字值对转换为pyspark中的Intramame列

aiohTTP与pytest的奇怪行为

OdooElectron 商务产品详情页面中add_qty参数动态更新

在Python中是否可以输入使用任意大小参数列表的第一个元素的函数

如何判断. text文件中的某个字符,然后读取该行

想要使用Polars groupby_Dynamic来缩减时间序列收件箱(包括空垃圾箱)

提取两行之间的标题的常规表达

如何让剧作家等待Python中出现特定cookie(然后返回它)?

发生异常:TclMessage命令名称无效.!listbox"

加速Python循环

NP.round解算数据后NP.unique

用NumPy优化a[i] = a[i-1]*b[i] + c[i]的迭代计算

当从Docker的--env-file参数读取Python中的环境变量时,每个\n都会添加一个\'.如何没有额外的?

将输入聚合到统一词典中

Pandas GroupBy可以分成两个盒子吗?

isinstance()在使用dill.dump和dill.load后,对列表中包含的对象失败

在嵌套span下的span中擦除信息

基于形状而非距离的两个numpy数组相似性

基于行条件计算(pandas)

跳过嵌套JSON中的级别并转换为Pandas Rame