我有两个长度相同的一维NumPy数组A和B.我想要找到两个数组的交集,这意味着我想要找到A中也存在于B中的所有元素.
当数组A中索引处的元素也是数组B的成员时,结果应该是一个布尔数组,其值为True
,保留顺序,以便我可以使用结果来索引另一个array.
如果没有布尔掩码约束,我会将两个数组都转换为集合,并使用集合交集操作符(&
).然而,我try 使用np.isin
和np.in1d
,发现使用普通的Python列表理解要快得多.
在给定的设置下:
import numba
import numpy as np
primes = np.array([
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101,
103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167,
173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239,
241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467,
479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569,
571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643,
647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997],
dtype=np.int64)
@numba.vectorize(nopython=True, cache=True, fastmath=True, forceobj=False)
def reverse_digits(n, base):
out = 0
while n:
n, rem = divmod(n, base)
out = out * base + rem
return out
flipped = reverse_digits(primes, 10)
def set_isin(a, b):
return a in b
vec_isin = np.vectorize(set_isin)
primes
包含primes
0以下的所有质数,总数为168.我之所以 Select 它,是因为它的大小合适,而且是预先确定的.我做过各种测试:
In [2]: %timeit np.isin(flipped, primes)
51.3 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [3]: %timeit np.in1d(flipped, primes)
46.2 µs ± 386 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [4]: %timeit setp = set(primes)
12.9 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [5]: %timeit setp = set(primes.tolist())
6.84 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [6]: %timeit setp = set(primes.flat)
11.5 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [7]: setp = set(primes.tolist())
In [8]: %timeit [x in setp for x in flipped]
23.3 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [9]: %timeit [x in setp for x in flipped.tolist()]
12.1 µs ± 76.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [10]: %timeit [x in setp for x in flipped.flat]
19.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [11]: %timeit vec_isin(flipped, setp)
40 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [12]: %timeit np.frompyfunc(lambda x: x in setp, 1, 1)(flipped)
25.7 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [13]: %timeit setf = set(flipped.tolist())
6.51 µs ± 44 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [14]: setf = set(flipped.tolist())
In [15]: %timeit np.array(sorted(setf & setp))
9.42 µs ± 78.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
setp = set(primes.tolist()); [x in setp for x in flipped.tolist()]
大约需要19微秒,这比NumPy方法更快.我想知道为什么会这样,是否有办法让它更快.
(我写了所有的代码,我使用了AI建议的编辑功能来编辑问题)