我有一个场景是我有三个数组:

  • 上界的数组
  • 一个下界的数组
  • 的目标数组

我想要得到的是一个数组列表,其中包含目标数组的索引,这些索引位于每个索引的边界数组内,例如:

ub_arr = [4,5,6]

lb_arr = [1,2,3]

target_arr = [12, 3, 5, 10, 6, 1, 3, 4]

结果将是:

[[1, 5, 6, 7], [1, 2, 6, 7], [1, 2, 4, 6, 7]]

目前我使用的代码看起来像:

matching_indices = []
for idx in range(len(ub_arr)):
    indices = np.array(np.where((target_arr >= lb_arr[idx]) & (target_arr <= ub_arr)))
    matching_indices.append(indices)

如果我们必须判断两个不同的目标数组,每个数组都有它们的上限,那么我会对target matching_indices列表中的每个索引进行约简,以获得索引的交集.

我想知道是否有更有效的方法来做这件事?

[编辑]

可以生成潜在的测试数组:

ub_arr = np.random.uniform(0,10,10_000)
lb_arr = np.random.uniform(0,10,10_000)
target_arr = np.random.uniform(0,10,20_000)

推荐答案

我建议使用二进制搜索(例如bisect内置模块):

from bisect import bisect_left, bisect_right

ub_arr = [4, 5, 6]
lb_arr = [1, 2, 3]
target_arr = [12, 3, 5, 10, 6, 1, 3, 4]

idxs, vals = zip(*sorted(enumerate(target_arr), key=lambda k: k[1]))

out = []
for l, u in zip(lb_arr, ub_arr):
    a, b = bisect_left(vals, l), bisect_right(vals, u)
    out.append(idxs[a:b])

print(out)

打印:

[(5, 1, 6, 7), (1, 6, 7, 2), (1, 6, 7, 2, 4)]

或:使用numpy:

idxs = np.argsort(target_arr)
vals = target_arr[idxs]

a = np.searchsorted(vals, lb_arr, side="left")
b = np.searchsorted(vals, ub_arr, side="right")

out = []
for i, j in zip(a, b):
    out.append(idxs[i:j])

print(out)

基准:

from bisect import bisect_left, bisect_right
from timeit import timeit

import numpy as np


def get_arrays():
    lb_arr = np.random.uniform(0, 10, 10_000)
    ub_arr = np.random.uniform(0, 10, 10_000)
    target_arr = np.random.uniform(0, 10, 20_000)
    return lb_arr, ub_arr, target_arr


def get_indices(lb_arr, ub_arr, target_arr):
    idxs, vals = zip(*sorted(enumerate(target_arr), key=lambda k: k[1]))

    out = []
    for l, u in zip(lb_arr, ub_arr):
        a, b = bisect_left(vals, l), bisect_right(vals, u)
        out.append(idxs[a:b])

    return out


def get_indices_numpy(lb_arr, ub_arr, target_arr):
    idxs = np.argsort(target_arr)
    vals = target_arr[idxs]

    a = np.searchsorted(vals, lb_arr, side="left")
    b = np.searchsorted(vals, ub_arr, side="right")

    out = []
    for i, j in zip(a, b):
        out.append(idxs[i:j])
    return out


def get_indices_original(lb_arr, ub_arr, target_arr):
    matching_indices = []
    for idx in range(len(ub_arr)):
        indices = np.array(
            np.where((target_arr >= lb_arr[idx]) & (target_arr <= ub_arr[idx]))
        )
        matching_indices.append(indices)
    return matching_indices


t_python = timeit(
    "get_indices(l, u, t)", setup="l,u,t=get_arrays()", globals=globals(), number=1
)
t_numpy = timeit(
    "get_indices_numpy(l, u, t)",
    setup="l,u,t=get_arrays()",
    globals=globals(),
    number=1,
)
t_original = timeit(
    "get_indices_original(l, u, t)",
    setup="l,u,t=get_arrays()",
    globals=globals(),
    number=1,
)

print(f"{t_python=}")
print(f"{t_numpy=}")
print(f"{t_original=}")

这在我的电脑上打印(AMD 5700x/Python 3.11):

t_python=0.1390752261504531
t_numpy=0.005280003882944584
t_original=0.2402120700571686

Python相关问答推荐

拆分pandas列并创建包含这些拆分值计数的新列

Locust请求中的Python和参数

Odoo 14 hr. emergency.public内的二进制字段

Django mysql图标不适用于小 case

运行总计基于多列pandas的分组和总和

为什么符号没有按顺序添加?

优化pytorch函数以消除for循环

' osmnx.shortest_track '返回有效源 node 和目标 node 的'无'

如何使用它?

将9个3x3矩阵按特定顺序排列成9x9矩阵

Pandas计数符合某些条件的特定列的数量

提取相关行的最快方法—pandas

字符串合并语法在哪里记录

为什么\b在这个正则表达式中不解释为反斜杠

如何在海上配对图中使某些标记周围的黑色边框

使用Python异步地持久跟踪用户输入

Django Table—如果项目是唯一的,则单行

Pandas数据框上的滚动平均值,其中平均值的中心基于另一数据框的时间

遍历列表列表,然后创建数据帧

在Pandas 中以十六进制显示/打印列?