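So I have a loop that looks perfectly suitable for parallelization, but when I pass it to Numba with parallel=True it always gives wrong results. All that happens in the loop is this: one element of the input array is set to 0, a matrix multiplication is performed and its result is stored in a new array, and then the zeroed element is set back to its original value. It looks as though the array a is being modified each time it is sent to Numba, so I tried copying a to another variable inside the loop and modifying only the copy, but I got the same wrong results (not shown). Below is a simple example. I just don't know what the problem is or how to fix it: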
import numpy as np
from scipy.stats import random_correlation
import numba as nb


def myfunc(a, corr):
    b = np.zeros(a.shape[0])
    for i in range(b.shape[0]):
        temp = a[i]
        a[i] = 0
        b[i] = a@corr@a.T
        a[i] = temp
    return b


@nb.njit(parallel=True)
def numbafunc(a, corr):
    b = np.zeros(a.shape[0])
    for i in nb.prange(b.shape[0]):
        temp = a[i]
        a[i] = 0
        b[i] = a@corr@a.T
        a[i] = temp
    return b


if __name__ == '__main__':
    a = np.random.rand(10)
    corr = random_correlation.rvs(eigs=[2,2,1,1,1,1,0.5,0.5,0.5,0.5])
    b_1 = myfunc(a, corr)
    b_2 = numbafunc(a, corr)
    # check if serial and Numba results match off the same inputs
    print(np.isclose(b_1,b_2))
    # double check the original function returns the same results again..
    b_1_check = myfunc(a, corr)
    print(np.isclose(b_1, b_1_check))
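This returns all False values, or at least 9 out of 10 are False... Can anyone point out which part of the code is a problem for parallelization? It looks fine to me. Many thanks!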
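For reference, here is a minimal sketch of what each iteration is meant to compute, written so that a is never modified. A plausible cause of the mismatch is that prange iterations run concurrently while all of them write to the same array a, so one iteration can read a while another has an element zeroed. The first function below gives every iteration its own private copy; the second gets the same values algebraically, with no loop and no mutation. The names reference_loop and closed_form are made up for this illustration and are not part of Numba or of the code above, and this is plain NumPy that I have not run under Numba.

```python
import numpy as np


def reference_loop(a, corr):
    """b[i] = a_i @ corr @ a_i, where a_i is a with element i set to 0.

    Hypothetical helper: each iteration works on its own copy,
    so the caller's array is never written to.
    """
    b = np.empty(a.shape[0])
    for i in range(a.shape[0]):
        a_i = a.copy()   # private copy for this iteration
        a_i[i] = 0.0
        b[i] = a_i @ corr @ a_i
    return b


def closed_form(a, corr):
    """Same values without any mutation (hypothetical helper).

    With a_i = a - a[i] * e_i, expanding the quadratic form gives
    a_i @ corr @ a_i = a @ corr @ a
                       - a[i] * (corr @ a)[i]
                       - a[i] * (a @ corr)[i]
                       + a[i]**2 * corr[i, i]
    which holds for any square corr, symmetric or not.
    """
    full = a @ corr @ a
    return full - a * (corr @ a) - a * (a @ corr) + a**2 * np.diag(corr)
```

Comparing either of these against the output of numbafunc on a fresh copy of a should show whether the shared writes are what breaks the parallel version.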