我需要根据B从A中提取一个值,但是使用for循环对于大型数据集来说是不切实际的.虽然下面的代码避免使用for循环,但它仍然需要相同的时间.

import numpy as np

# Given matrices A and B
A = np.array([[254, 0, 0],
              [109, 0, 1],
              [126, 0, 2],
              [66, 0, 3],
              [220, 1, 0],
              [98, 1, 1],
              [230, 1, 2],
              [17, 1, 3],
              [83, 2, 0],
              [106, 2, 1],
              [123, 2, 2],
              [57, 2, 3]])

B = np.array([[1, 2],
              [0, 1],
              [1, 0],
              [1, 1],
              [-1, 2],
              [1, 3],
              [1, 1],
              [0, 0],
              [2, 2],
              [1, 0],
              [2, 3],
              [0, 1]])

## these two functions give the same results.
def get_pixel2d(A, B): 
    corresponding_rows = np.all(A[:, 1:3] == B[:, None], axis=-1)
    get_pixel_final = (corresponding_rows * A[:, 0]).sum(axis=1)
    return get_pixel_final

# this is more faster
def get_pixel2d(A, B):
    corresponding_rows = np.all(A[:, 1:3] == B[:, None], axis=-1)
    get_pixel_final = np.sum(A[:, 0] * corresponding_rows, axis=1)
    return get_pixel_final

result = get_pixel2d(A, B)
print(result)

[230 109 220  98   0  17  98 254 123 220  57 109]

推荐答案

您可以转换查找表,然后使用普通索引+np.where(我建议在此之前转换查找表(A),以进一步提高速度):

def get_pixel2d_3(A, B):
    a, b = np.max(A[:, 1]), np.max(A[:, 2])

    lookup_table = np.zeros((a + 1, b + 1), dtype=np.uint8)
    lookup_table[A[:, 1], A[:, 2]] = A[:, 0]

    m1 = (B[:, 0] >= 0) & (B[:, 0] <= a)
    m2 = (B[:, 1] >= 0) & (B[:, 1] <= b)
    m = m1 & m2

    return np.where(m, lookup_table[B[:, 0], B[:, 1]], 0)

基准:

import perfplot

perfplot.show(
    setup=lambda n: (A, np.tile(B, (n, 1))),
    kernels=[
        get_pixel2d,
        get_pixel2d_2,
        get_pixel2d_3,
    ],
    labels=["get_pixel2d", "get_pixel2d_2", "get_pixel2d_3"],
    n_range=[1, 10, 20, 50, 100, 200, 500, 1000],
    xlabel="N",
    equality_check=lambda a, b: np.all(a == b),
    logx=True,
    logy=True,
)

创建此图表:

enter image description here

Python相关问答推荐

替换为Pandas

我们可以在apps.py?中使用Post_Save信号吗

使用Python计算cmyk,在PDF上发现覆盖范围

从收件箱获取特定列中的重复行

如何以实现以下所述的预期行为的方式添加两只Pandas pyramme

手动为pandas中的列上色

拆分pandas列并创建包含这些拆分值计数的新列

Python中是否有方法从公共域检索搜索结果

计算相同形状的两个张量的SSE损失

无法使用equals_html从网址获取全文

如果条件为真,则Groupby.mean()

acme错误-Veritas错误:模块收件箱没有属性linear_util'

Python中绕y轴曲线的旋转

Streamlit应用程序中的Plotly条形图中未正确显示Y轴刻度

NumPy中条件嵌套for循环的向量化

创建可序列化数据模型的最佳方法

实现神经网络代码时的TypeError

启动带有参数的Python NTFS会导致文件路径混乱

在Python中计算连续天数

剪切间隔以添加特定日期