我有以下代码:

@nb.njit(cache=True)
def find_two_largest(arr):
    """Return ``(largest, second_largest)`` of ``arr`` in a single pass.

    ``arr[0]`` and ``arr[1]`` are read unconditionally, so the input must
    hold at least two elements. Repeated values count separately: when the
    maximum occurs more than once, both returned values are equal.
    """
    # Seed the running pair from the first two elements.
    first, second = arr[0], arr[1]
    if first >= second:
        top, runner_up = first, second
    else:
        top, runner_up = second, first

    # Scan the rest; a new maximum demotes the previous one to runner-up.
    for k in range(2, len(arr)):
        value = arr[k]
        if value > top:
            runner_up = top
            top = value
        elif value > runner_up:
            runner_up = value
    return top, runner_up


@nb.njit(cache=True)
def max_bar_one(arr):
    """Return, for every position, the maximum of ``arr`` without that position.

    Positions whose value equals the overall maximum receive the
    second-largest value; every other position receives the maximum.
    The result is a new array created with ``np.empty_like(arr)``.
    """
    top, runner_up = find_two_largest(arr)
    out = np.empty_like(arr)
    # The replacement used at "maximum" positions does not depend on the
    # index, so it is selected once up front.
    at_top = runner_up if top != runner_up else top
    for pos in range(arr.shape[0]):
        out[pos] = at_top if arr[pos] == top else top
    return out


@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """Shift the row-wise ``max_bar_one`` of ``d`` down by one row.

    Row 0 of the result is filled with ``-inf``; row ``k`` (``k >= 1``)
    holds ``max_bar_one(d[k - 1])``. The last row of ``d`` is never read.
    The result is allocated with ``np.empty((m, n))`` for ``m, n = d.shape``.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = -np.inf
    for dst in range(1, rows):
        out[dst, :] = max_bar_one(d[dst - 1, :])
    return out


@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """Combine the shifted leave-one-out maxima with ``d`` and row ``j`` of ``subcusum``.

    Each returned entry ``[r, c]`` is
    ``max(shifted[r, c], d[r, c]) + subcusum[j, c]`` where ``shifted`` is the
    output of ``replace_max_row_wise_add_first_delete_last(d)``; that array
    is updated in place and returned.
    """
    shifted = replace_max_row_wise_add_first_delete_last(d)
    n_rows, n_cols = shifted.shape[0], shifted.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            best = max(shifted[r, c], d[r, c])
            shifted[r, c] = best + subcusum[j, c]
    return shifted

然后,我按如下方式准备数据:

# Problem size: every matrix below is n x n.
n = 5000
# Random integer steps in [-3, 3] stored as floats, and their running sums
# along each row.
A = np.random.randint(-3, 4, (n, n)).astype(float)
cusum_rows = np.cumsum(A, axis=1)
rowseq = np.arange(n)
# Size d from n (instead of a repeated literal 5000) so it always matches
# cusum_rows when n is changed.
d = np.random.randint(-3, 4, (n, n))

然后,可以用下面的方式计时:

%timeit main_function(d, cusum_rows, 0)
166 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

有没有可能并行化for循环或整体代码来加快速度?我尝试在replace_max_row_wise_add_first_delete_last中使用parallel=True,但它并没有加速代码,只报告:

Instruction hoisting:
loop #1:
  Failed to hoist the following:
    dependency: $value_var.73 = getitem(value=_72call__function_11, index=$parfor__index_72.90, fn=<built-in function getitem>)

这是令人惊讶的,因为for循环中的所有调用都是独立的.

这段代码可以加速和/或并行化吗?

推荐答案

当我在replace_max_row_wise_add_first_delete_last()中使用并行化时,我得到了约70%的加速:

@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """Shift the row-wise ``max_bar_one`` of ``d`` down by one row, using ``prange``.

    Row 0 of the result is filled with ``-inf``; row ``k`` (``k >= 1``)
    holds ``max_bar_one(d[k - 1])``. The last row of ``d`` is never read.
    Each iteration writes a different output row.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = -np.inf
    for src in nb.prange(rows - 1):  # rows are distributed via prange
        out[src + 1, :] = max_bar_one(d[src, :])
    return out


@nb.njit
def main_function_parallel(d, subcusum, j):
    """Same computation as ``main_function``, built on the ``prange`` row helper.

    Only the call to ``replace_max_row_wise_add_first_delete_last_parallel``
    uses ``prange``; the element-wise combination below is an ordinary
    nested ``range`` loop. The helper's output is updated in place and
    returned.
    """
    shifted = replace_max_row_wise_add_first_delete_last_parallel(d)
    n_rows, n_cols = shifted.shape[0], shifted.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            best = max(shifted[r, c], d[r, c])
            shifted[r, c] = best + subcusum[j, c]
    return shifted

编辑:去掉 missing_maxes = np.empty_like(arr) 这一临时数组分配可以进一步加速。在这种情况下,加速比为300%:

@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """Fused leave-one-out maximum, element-wise max and add, written into ``result``.

    For every index ``k``, ``result[k]`` ends up as
    ``max(m_k, to_compare[k]) + to_add[k]``, where ``m_k`` is the
    second-largest value of ``arr`` if ``arr[k]`` equals its maximum and the
    maximum otherwise. No temporary array is allocated; ``result`` is
    modified in place and also returned.
    """
    top, runner_up = find_two_largest(arr)
    # Replacement for positions holding the maximum; independent of the index.
    at_top = runner_up if top != runner_up else top
    for k in range(arr.shape[0]):
        # Store first, then combine, so the value read back by max() has
        # already been converted to result's element type.
        result[k] = at_top if arr[k] == top else top
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """Build the full result in one ``prange`` pass using ``max_bar_one2``.

    Row 0 of the output is ``d[0] + to_add``. Each row ``k >= 1`` is filled
    in place by ``max_bar_one2(d[k - 1], out[k], d[k], to_add)``, so the
    leave-one-out maximum, the comparison with ``d[k]`` and the addition of
    ``to_add`` all happen in a single call per row.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = d[0] + to_add
    for src in nb.prange(rows - 1):
        max_bar_one2(d[src], out[src + 1], d[src + 1], to_add)
    return out


@nb.njit
def main_function_parallel2(d, subcusum, j):
    """Entry point for the fused variant: pass row ``j`` of ``subcusum`` as the offsets."""
    offsets = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)

基准:

from timeit import timeit

import numba as nb
import numpy as np


@nb.njit(cache=True)
def find_two_largest(arr):
    """Return ``(largest, second_largest)`` of ``arr`` in a single pass.

    ``arr[0]`` and ``arr[1]`` are read unconditionally, so the input must
    hold at least two elements. Repeated values count separately: when the
    maximum occurs more than once, both returned values are equal.
    """
    # Seed the running pair from the first two elements.
    first, second = arr[0], arr[1]
    if first >= second:
        top, runner_up = first, second
    else:
        top, runner_up = second, first

    # Scan the rest; a new maximum demotes the previous one to runner-up.
    for k in range(2, len(arr)):
        value = arr[k]
        if value > top:
            runner_up = top
            top = value
        elif value > runner_up:
            runner_up = value
    return top, runner_up


@nb.njit(cache=True)
def max_bar_one(arr):
    """Return, for every position, the maximum of ``arr`` without that position.

    Positions whose value equals the overall maximum receive the
    second-largest value; every other position receives the maximum.
    The result is a new array created with ``np.empty_like(arr)``.
    """
    top, runner_up = find_two_largest(arr)
    out = np.empty_like(arr)
    # The replacement used at "maximum" positions does not depend on the
    # index, so it is selected once up front.
    at_top = runner_up if top != runner_up else top
    for pos in range(arr.shape[0]):
        out[pos] = at_top if arr[pos] == top else top
    return out


@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """Shift the row-wise ``max_bar_one`` of ``d`` down by one row.

    Row 0 of the result is filled with ``-inf``; row ``k`` (``k >= 1``)
    holds ``max_bar_one(d[k - 1])``. The last row of ``d`` is never read.
    The result is allocated with ``np.empty((m, n))`` for ``m, n = d.shape``.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = -np.inf
    for dst in range(1, rows):
        out[dst, :] = max_bar_one(d[dst - 1])
    return out


@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """Combine the shifted leave-one-out maxima with ``d`` and row ``j`` of ``subcusum``.

    Each returned entry ``[r, c]`` is
    ``max(shifted[r, c], d[r, c]) + subcusum[j, c]`` where ``shifted`` is the
    output of ``replace_max_row_wise_add_first_delete_last(d)``; that array
    is updated in place and returned.
    """
    shifted = replace_max_row_wise_add_first_delete_last(d)
    n_rows, n_cols = shifted.shape[0], shifted.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            best = max(shifted[r, c], d[r, c])
            shifted[r, c] = best + subcusum[j, c]
    return shifted


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """Shift the row-wise ``max_bar_one`` of ``d`` down by one row, using ``prange``.

    Row 0 of the result is filled with ``-inf``; row ``k`` (``k >= 1``)
    holds ``max_bar_one(d[k - 1])``. The last row of ``d`` is never read.
    Each iteration writes a different output row.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = -np.inf
    for src in nb.prange(rows - 1):
        out[src + 1, :] = max_bar_one(d[src, :])
    return out


@nb.njit
def main_function_parallel(d, subcusum, j):
    """Same computation as ``main_function``, built on the ``prange`` row helper.

    Only the call to ``replace_max_row_wise_add_first_delete_last_parallel``
    uses ``prange``; the element-wise combination below is an ordinary
    nested ``range`` loop. The helper's output is updated in place and
    returned.
    """
    shifted = replace_max_row_wise_add_first_delete_last_parallel(d)
    n_rows, n_cols = shifted.shape[0], shifted.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            best = max(shifted[r, c], d[r, c])
            shifted[r, c] = best + subcusum[j, c]
    return shifted


@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """Fused leave-one-out maximum, element-wise max and add, written into ``result``.

    For every index ``k``, ``result[k]`` ends up as
    ``max(m_k, to_compare[k]) + to_add[k]``, where ``m_k`` is the
    second-largest value of ``arr`` if ``arr[k]`` equals its maximum and the
    maximum otherwise. No temporary array is allocated; ``result`` is
    modified in place and also returned.
    """
    top, runner_up = find_two_largest(arr)
    # Replacement for positions holding the maximum; independent of the index.
    at_top = runner_up if top != runner_up else top
    for k in range(arr.shape[0]):
        # Store first, then combine, so the value read back by max() has
        # already been converted to result's element type.
        result[k] = at_top if arr[k] == top else top
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """Build the full result in one ``prange`` pass using ``max_bar_one2``.

    Row 0 of the output is ``d[0] + to_add``. Each row ``k >= 1`` is filled
    in place by ``max_bar_one2(d[k - 1], out[k], d[k], to_add)``, so the
    leave-one-out maximum, the comparison with ``d[k]`` and the addition of
    ``to_add`` all happen in a single call per row.
    """
    rows, cols = d.shape
    out = np.empty((rows, cols))
    out[0] = d[0] + to_add
    for src in nb.prange(rows - 1):
        max_bar_one2(d[src], out[src + 1], d[src + 1], to_add)
    return out


@nb.njit
def main_function_parallel2(d, subcusum, j):
    """Entry point for the fused variant: pass row ``j`` of ``subcusum`` as the offsets."""
    offsets = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)


def get_d_cumsum_rows(n):
    """Generate random benchmark inputs of shape ``(n, n)``.

    Returns ``(d, cusum_rows)``: ``d`` is an integer matrix drawn from
    ``[-300, 400)`` and ``cusum_rows`` is the row-wise cumulative sum of a
    second such matrix converted to float. The matrix behind ``cusum_rows``
    is drawn first, then ``d``, which fixes the result for a given seed.
    """
    shape = (n, n)
    steps = np.random.randint(-300, 400, shape).astype(float)
    d = np.random.randint(-300, 400, shape)
    return d, steps.cumsum(axis=1)


n = 10

# Correctness check on a small problem: re-seed before every call so that
# each implementation receives exactly the same random inputs.
np.random.seed(42)
out1 = main_function(*get_d_cumsum_rows(n), 0)
np.random.seed(42)
out2 = main_function_parallel(*get_d_cumsum_rows(n), 0)
np.random.seed(42)
out3 = main_function_parallel2(*get_d_cumsum_rows(n), 0)

assert np.allclose(out1, out2)
assert np.allclose(out1, out3)

# Time the three variants in order; each gets fresh 5000 x 5000 inputs from
# the setup statement and is executed 100 times.
t1, t2, t3 = (
    timeit(
        f"{func_name}(a, b, 0)",
        setup="n=5000;a,b=get_d_cumsum_rows(n)",
        globals=globals(),
        number=100,
    )
    for func_name in (
        "main_function",
        "main_function_parallel",
        "main_function_parallel2",
    )
)

print(t1, t2, t3, sep="\n")

在我的电脑(AMD 5700x)上的输出:

7.003944834927097
4.12014868715778
2.2788363839499652

Python相关问答推荐

分组数据并删除重复数据

如何让 turtle 通过点击和拖动来绘制?

如果条件为真,则Groupby.mean()

点到面的Y距离

从numpy数组和参数创建收件箱

无法使用requests或Selenium抓取一个href链接

从一个系列创建一个Dataframe,特别是如何重命名其中的列(例如:使用NAs/NaN)

递归访问嵌套字典中的元素值

Pandas Loc Select 到NaN和值列表

转换为浮点,pandas字符串列,混合千和十进制分隔符

将一个DataFrame截取到掩码(mask)的第一个实例,最好的方法是什么?

在matplotlib中删除子图之间的间隙_mosaic

如何在seaborn的配对图(pairplot)中为某些标记加上黑色边框

从源代码显示不同的输出(机器学习)(Python)

Python类型提示:对于一个可以迭代的变量,我应该使用什么?

为用户输入的整数查找根/幂整数对的Python练习

504未连接IB API TWS错误—即使API连接显示已接受''

Matplotlib中的曲线箭头样式

Python:在cmd中添加参数时的语法

组颠倒大Pandas 数据帧