我建议使用二分查找(例如内置的 bisect 模块):
from bisect import bisect_left, bisect_right

ub_arr = [4, 5, 6]
lb_arr = [1, 2, 3]
target_arr = [12, 3, 5, 10, 6, 1, 3, 4]

# Sort the (index, value) pairs by value; ``sorted`` is stable, so equal values
# keep their original index order. Unzip into parallel index/value tuples.
idxs, vals = zip(*sorted(enumerate(target_arr), key=lambda k: k[1]))

# For each closed interval [lower, upper]: bisect_left finds the first sorted
# position whose value is >= lower, bisect_right the position just past the
# last value <= upper. The index slice between them holds exactly the matches.
out = [
    idxs[bisect_left(vals, lower) : bisect_right(vals, upper)]
    for lower, upper in zip(lb_arr, ub_arr)
]
print(out)
输出结果:
[(5, 1, 6, 7), (1, 6, 7, 2), (1, 6, 7, 2, 4)]
或者使用 numpy:
import numpy as np

# Same example data as the bisect version, so this snippet runs on its own.
ub_arr = [4, 5, 6]
lb_arr = [1, 2, 3]
# Must be an ndarray: fancy indexing ``target_arr[idxs]`` fails on a plain list.
target_arr = np.asarray([12, 3, 5, 10, 6, 1, 3, 4])

# Indices that sort target_arr; a stable sort keeps equal values in index
# order (the same tie order as the bisect version).
idxs = np.argsort(target_arr, kind="stable")
vals = target_arr[idxs]
# Vectorized binary search for all bounds at once: a[k] is the first sorted
# position with value >= lb_arr[k]; b[k] is the position just past the last
# value <= ub_arr[k].
a = np.searchsorted(vals, lb_arr, side="left")
b = np.searchsorted(vals, ub_arr, side="right")
out = [idxs[i:j] for i, j in zip(a, b)]
print(out)
基准测试:
from bisect import bisect_left, bisect_right
from timeit import repeat
from timeit import timeit

import numpy as np
def get_arrays(n_bounds=10_000, n_targets=20_000, low=0, high=10):
    """Generate random benchmark inputs.

    Args:
        n_bounds: Number of lower bounds and of upper bounds to draw.
        n_targets: Number of target values to draw.
        low: Inclusive lower limit of the uniform distribution.
        high: Exclusive upper limit of the uniform distribution.

    Returns:
        A tuple ``(lb_arr, ub_arr, target_arr)`` of 1-D float arrays drawn
        independently from ``np.random.uniform(low, high)``. Because lower and
        upper bounds are drawn independently, some pairs have ``lb > ub``
        (an empty interval).
    """
    lb_arr = np.random.uniform(low, high, n_bounds)
    ub_arr = np.random.uniform(low, high, n_bounds)
    target_arr = np.random.uniform(low, high, n_targets)
    return lb_arr, ub_arr, target_arr
def get_indices(lb_arr, ub_arr, target_arr):
    """Return, per ``[lb, ub]`` pair, the indices of ``target_arr`` values inside it.

    ``target_arr`` is sorted once; each closed interval is then located with
    two binary searches (``bisect_left`` / ``bisect_right``) over the sorted
    values.

    Args:
        lb_arr: Inclusive lower bounds, paired positionally with ``ub_arr``.
        ub_arr: Inclusive upper bounds.
        target_arr: Indexable sequence of comparable values to search.

    Returns:
        A list with one tuple per bound pair (pairs are formed with ``zip``,
        so extra bounds in the longer input are ignored). Each tuple holds the
        matching indices ordered by their value in ``target_arr``; equal values
        keep index order because ``sorted`` is stable. A pair with no matches,
        including one with ``lb > ub``, yields ``()``.
    """
    # Sort the index range by value rather than using
    # ``zip(*sorted(enumerate(...)))``: that unpacking raises ValueError when
    # target_arr is empty.
    idxs = tuple(sorted(range(len(target_arr)), key=target_arr.__getitem__))
    vals = [target_arr[i] for i in idxs]
    return [
        idxs[bisect_left(vals, lower) : bisect_right(vals, upper)]
        for lower, upper in zip(lb_arr, ub_arr)
    ]
def get_indices_numpy(lb_arr, ub_arr, target_arr):
    """Vectorized lookup of target values falling in each ``[lb, ub]`` interval.

    Sorts ``target_arr`` once with ``np.argsort`` and finds every interval's
    boundaries with two ``np.searchsorted`` calls over all bounds at once.

    Args:
        lb_arr: Inclusive lower bounds (array-like), paired with ``ub_arr``.
        ub_arr: Inclusive upper bounds (array-like).
        target_arr: 1-D array-like of values to search (plain lists accepted).

    Returns:
        A list of 1-D integer arrays, one per bound pair (paired with ``zip``),
        holding the indices of ``target_arr`` elements with
        ``lb <= value <= ub``, ordered by value. An interval with no match
        (including ``lb > ub``) gives an empty array.
    """
    # Fancy indexing ``target_arr[idxs]`` below requires an ndarray, not a list.
    target_arr = np.asarray(target_arr)
    # Stable sort: equal values keep index order, so ties come out in the same
    # order as the pure-Python bisect implementation.
    idxs = np.argsort(target_arr, kind="stable")
    vals = target_arr[idxs]
    starts = np.searchsorted(vals, lb_arr, side="left")
    stops = np.searchsorted(vals, ub_arr, side="right")
    return [idxs[i:j] for i, j in zip(starts, stops)]
def get_indices_original(lb_arr, ub_arr, target_arr):
    """Brute-force baseline: scan the whole target array once per bound pair.

    For each position ``k`` of ``ub_arr``, build a boolean mask of the
    ``target_arr`` elements with ``lb_arr[k] <= value <= ub_arr[k]`` and
    collect their indices with ``np.where``.

    Args:
        lb_arr: Inclusive lower bounds, indexed by position in ``ub_arr``.
        ub_arr: Inclusive upper bounds; its length sets the number of results.
        target_arr: numpy array of values (element-wise ``>=``/``<=`` and
            ``&`` are used, so a plain list will not work).

    Returns:
        A list with one integer array per bound pair. ``np.where`` returns a
        tuple of index arrays (one per dimension) and ``np.array`` stacks it,
        so for a 1-D ``target_arr`` each result has shape ``(1, n_matches)``,
        with indices in ascending order.
    """
    return [
        np.array(np.where((target_arr >= lb_arr[k]) & (target_arr <= upper)))
        for k, upper in enumerate(ub_arr)
    ]
# Build ONE data set and time every implementation on it. The previous
# per-call ``setup="l,u,t=get_arrays()"`` gave each method different random
# input; sharing the arrays through ``globals`` makes the timings comparable.
lb, ub, target = get_arrays()


def _best_time(stmt, runs=3):
    # Minimum of several single runs: the timeit docs recommend the min, since
    # larger values are typically caused by other processes interfering.
    return min(repeat(stmt, globals=globals(), number=1, repeat=runs))


t_python = _best_time("get_indices(lb, ub, target)")
t_numpy = _best_time("get_indices_numpy(lb, ub, target)")
t_original = _best_time("get_indices_original(lb, ub, target)")
print(f"{t_python=}")
print(f"{t_numpy=}")
print(f"{t_original=}")
在我的电脑上(AMD 5700X / Python 3.11)输出如下:
t_python=0.1390752261504531
t_numpy=0.005280003882944584
t_original=0.2402120700571686