当我在 `replace_max_row_wise_add_first_delete_last()` 中使用并行化(把行循环改为 `nb.prange`)时,
我得到了约 70% 的加速:
@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row

    Row 0 of the (float64) result is -inf and row i + 1 is max_bar_one(d[i]).
    Each prange iteration writes a distinct result row, so the iterations are
    independent of each other. An input with zero rows yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0:
        # Numba does not bounds-check by default, so result[0] on an empty
        # array would be a silent out-of-bounds write.
        return result
    result[0] = -np.inf
    for i in nb.prange(0, m - 1):  # <-- using prange here
        result[i + 1, :] = max_bar_one(d[i, :])
    return result
@nb.njit
def main_function_parallel(d, subcusum, j):
    """Same computation as main_function, but built on the prange-parallel row helper.

    out[r, c] = max(shifted[r, c], d[r, c]) + subcusum[j, c], where `shifted`
    is the array returned by replace_max_row_wise_add_first_delete_last_parallel(d).
    """
    # <-- the parallel version of the row helper is used here
    out = replace_max_row_wise_add_first_delete_last_parallel(d)
    n_rows, n_cols = out.shape
    for r in range(n_rows):
        for c in range(n_cols):
            out[r, c] = max(out[r, c], d[r, c]) + subcusum[j, c]
    return out
编辑:进一步的加速来自去掉 `missing_maxes = np.empty_like(arr)` 这一临时分配:结果直接写入输出数组的对应行,
并把"与下一行取最大值再加上 subcusum"融合进同一个循环。这样总体约快 3 倍(7.00 秒 → 2.28 秒):
@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """In-place fused variant of max_bar_one followed by a max-and-add step.

    For every index k this writes
        result[k] = max(<max of arr excluding position k>, to_compare[k]) + to_add[k]
    into the caller-supplied `result` buffer (no temporary array is allocated)
    and returns that same buffer object.
    """
    top, runner_up = find_two_largest(arr)
    for k in range(arr.shape[0]):
        # First store the "maximum of the other elements" into the output slot.
        if arr[k] != top:
            result[k] = top
        elif top == runner_up:
            result[k] = top  # maximum is duplicated, so dropping one copy changes nothing
        else:
            result[k] = runner_up
        # Then fuse the element-wise max with `to_compare` and the per-column add.
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result
@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """
    Fused, allocation-free version of main_function's row update.

    Row 0 of the (float64) result is d[0] + to_add. For i >= 1, row i is
    max(max_bar_one(d[i - 1]), d[i]) + to_add element-wise, computed by
    max_bar_one2 directly into result[i]. Each prange iteration writes a
    distinct result row, so the iterations are independent. An input with
    zero rows yields an empty result.

    `to_add` is indexed per column; presumably it has length n (d.shape[1]),
    e.g. one row of subcusum -- TODO confirm with callers.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0:
        # Numba does not bounds-check by default; d[0] / result[0] would be
        # silent out-of-bounds accesses on an empty input.
        return result
    result[0] = d[0] + to_add
    for i in nb.prange(0, m - 1):
        max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
    return result
@nb.njit
def main_function_parallel2(d, subcusum, j):
    """Fully fused main_function: delegate to the parallel, allocation-free row helper.

    Row j of subcusum is passed as the per-column increment.
    """
    row_increment = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, row_increment)
完整的基准测试代码:
from timeit import timeit
import numba as nb
import numpy as np
@nb.njit(cache=True)
def find_two_largest(arr):
    """Return (largest, second_largest) of a 1-D array in a single pass.

    second_largest is the largest value left after removing ONE occurrence of
    the maximum, so for NaN-free input it equals largest when the maximum is
    duplicated.

    Raises ValueError if arr has fewer than two elements.
    """
    # Numba does not bounds-check by default, so arr[0] / arr[1] on a short
    # array would silently read out-of-bounds memory; fail loudly instead.
    if arr.shape[0] < 2:
        raise ValueError("find_two_largest requires at least two elements")
    # Seed the pair from the first two elements.
    if arr[0] >= arr[1]:
        largest = arr[0]
        second_largest = arr[1]
    else:
        largest = arr[1]
        second_largest = arr[0]
    # Scan the rest, demoting the old maximum when a new one appears.
    for num in arr[2:]:
        if num > largest:
            second_largest = largest
            largest = num
        elif num > second_largest:
            second_largest = num
    return largest, second_largest
@nb.njit(cache=True)
def max_bar_one(arr):
    """For each position k, return the maximum of all elements of arr except arr[k].

    Only positions holding the maximum need special treatment: they get the
    runner-up value. Every other position simply gets the maximum.
    """
    top, runner_up = find_two_largest(arr)
    out = np.empty_like(arr)
    for k in range(arr.shape[0]):
        if arr[k] != top:
            out[k] = top
        elif top == runner_up:
            out[k] = top  # maximum is duplicated, so excluding one copy changes nothing
        else:
            out[k] = runner_up
    return out
@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row

    Row 0 of the (float64) result is -inf and row i + 1 is max_bar_one(d[i]),
    so the result has the same shape as d. An input with zero rows yields an
    empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0:
        # Numba does not bounds-check by default, so result[0] on an empty
        # array would be a silent out-of-bounds write.
        return result
    result[0] = -np.inf
    for i in range(0, m - 1):
        result[i + 1, :] = max_bar_one(d[i])
    return result
@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """Serial baseline: combine the shifted max-bar-one rows with d and add row j of subcusum.

    out[r, c] = max(shifted[r, c], d[r, c]) + subcusum[j, c], where `shifted`
    is the array returned by replace_max_row_wise_add_first_delete_last(d).
    """
    out = replace_max_row_wise_add_first_delete_last(d)
    n_rows, n_cols = out.shape
    for r in range(n_rows):
        for c in range(n_cols):
            out[r, c] = max(out[r, c], d[r, c]) + subcusum[j, c]
    return out
@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row

    Row 0 of the (float64) result is -inf and row i + 1 is max_bar_one(d[i]).
    Each prange iteration writes a distinct result row, so the iterations are
    independent of each other. An input with zero rows yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0:
        # Numba does not bounds-check by default, so result[0] on an empty
        # array would be a silent out-of-bounds write.
        return result
    result[0] = -np.inf
    for i in nb.prange(0, m - 1):
        result[i + 1, :] = max_bar_one(d[i, :])
    return result
@nb.njit
def main_function_parallel(d, subcusum, j):
    """Same computation as main_function, but built on the prange-parallel row helper.

    out[r, c] = max(shifted[r, c], d[r, c]) + subcusum[j, c], where `shifted`
    is the array returned by replace_max_row_wise_add_first_delete_last_parallel(d).
    """
    out = replace_max_row_wise_add_first_delete_last_parallel(d)
    n_rows, n_cols = out.shape
    for r in range(n_rows):
        for c in range(n_cols):
            out[r, c] = max(out[r, c], d[r, c]) + subcusum[j, c]
    return out
@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """In-place fused variant of max_bar_one followed by a max-and-add step.

    For every index k this writes
        result[k] = max(<max of arr excluding position k>, to_compare[k]) + to_add[k]
    into the caller-supplied `result` buffer (no temporary array is allocated)
    and returns that same buffer object.
    """
    top, runner_up = find_two_largest(arr)
    for k in range(arr.shape[0]):
        # First store the "maximum of the other elements" into the output slot.
        if arr[k] != top:
            result[k] = top
        elif top == runner_up:
            result[k] = top  # maximum is duplicated, so dropping one copy changes nothing
        else:
            result[k] = runner_up
        # Then fuse the element-wise max with `to_compare` and the per-column add.
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result
@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """
    Fused, allocation-free version of main_function's row update.

    Row 0 of the (float64) result is d[0] + to_add. For i >= 1, row i is
    max(max_bar_one(d[i - 1]), d[i]) + to_add element-wise, computed by
    max_bar_one2 directly into result[i]. Each prange iteration writes a
    distinct result row, so the iterations are independent. An input with
    zero rows yields an empty result.

    `to_add` is indexed per column; presumably it has length n (d.shape[1]),
    e.g. one row of subcusum -- TODO confirm with callers.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0:
        # Numba does not bounds-check by default; d[0] / result[0] would be
        # silent out-of-bounds accesses on an empty input.
        return result
    result[0] = d[0] + to_add
    for i in nb.prange(0, m - 1):
        max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
    return result
@nb.njit
def main_function_parallel2(d, subcusum, j):
    """Fully fused main_function: delegate to the parallel, allocation-free row helper.

    Row j of subcusum is passed as the per-column increment.
    """
    row_increment = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, row_increment)
def get_d_cumsum_rows(n):
    """Build random benchmark inputs from the global NumPy RNG.

    Returns (d, cusum_rows):
      - cusum_rows: float array of shape (n, n) holding row-wise cumulative sums
        of random integers in [-300, 400);
      - d: integer array of shape (n, n) with values in [-300, 400).
    The cumsum source is drawn before d, so seeding np.random first makes the
    pair reproducible.
    """
    shape = (n, n)
    row_source = np.random.randint(-300, 400, shape).astype(float)
    cusum_rows = row_source.cumsum(axis=1)
    d = np.random.randint(-300, 400, shape)
    return d, cusum_rows
# Small-size sanity check: every variant must reproduce the serial baseline
# when fed identical (re-seeded) inputs.
n = 10
_outs = []
for _variant in (main_function, main_function_parallel, main_function_parallel2):
    np.random.seed(42)
    _outs.append(_variant(*get_d_cumsum_rows(n), 0))
out1, out2, out3 = _outs
assert np.allclose(out1, out2)
assert np.allclose(out1, out3)
# Time 100 calls of each variant on fresh 5000x5000 inputs; the functions were
# already JIT-compiled by the correctness check above.
t1, t2, t3 = (
    timeit(
        stmt,
        setup="n=5000;a,b=get_d_cumsum_rows(n)",
        globals=globals(),
        number=100,
    )
    for stmt in (
        "main_function(a, b, 0)",
        "main_function_parallel(a, b, 0)",
        "main_function_parallel2(a, b, 0)",
    )
)
print(t1, t2, t3, sep="\n")
在我的电脑(AMD Ryzen 7 5700X)上的输出(依次为 t1、t2、t3,即每个函数调用 100 次的总耗时,单位为秒):
7.003944834927097
4.12014868715778
2.2788363839499652