在一段Python代码中,我需要在某个时刻分别将两个2x2矩阵的大列表相乘.在代码中,这两个列表都是形状为(n,2,2)的数字array.另一(n,2,2)数组中的预期结果,其中矩阵1是第一列表的矩阵1与第二列表的矩阵1之间的乘法的结果,依此类推.
经过一些分析后,我发现矩阵乘法是性能瓶颈.出于好奇,我试着"显式"地编写矩阵乘法.下面是一个带有测量运行时的代码示例.
import timeit
import numpy as np
def explicit_2x2_matrices_multiplication(
mats_a: np.ndarray, mats_b: np.ndarray
) -> np.ndarray:
matrices_multiplied = np.empty_like(mats_b)
for i in range(2):
for j in range(2):
matrices_multiplied[:, i, j] = (
mats_a[:, i, 0] * mats_b[:, 0, j] + mats_a[:, i, 1] * mats_b[:, 1, j]
)
return matrices_multiplied
matrices_a = np.random.random((1000, 2, 2))
matrices_b = np.random.random((1000, 2, 2))
assert np.allclose( # Checking that the explicit version is correct
matrices_a @ matrices_b,
explicit_2x2_matrices_multiplication(matrices_a, matrices_b),
)
print( # 1.1814142999992328 seconds
timeit.timeit(lambda: matrices_a @ matrices_b, number=10000)
)
print( # 1.1954495010013488 seconds
timeit.timeit(lambda: np.matmul(matrices_a, matrices_b), number=10000)
)
print( # 2.2304022700009227 seconds
timeit.timeit(lambda: np.einsum('lij,ljk->lik', matrices_a, matrices_b), number=10000)
)
print( # 0.19581600800120214 seconds
timeit.timeit(
lambda: explicit_2x2_matrices_multiplication(matrices_a, matrices_b),
number=10000,
)
)
如在代码中测试的,该函数产生与常规矩阵__matmul__
结果相同的结果.然而,不同的是速度:在我的机器上,显式表达式的速度最多快10倍.
这对我来说是一个相当令人惊讶的结果.我原本预计NumPy表达式会更快,或者至少与更长的Python版本相当,而不是像我们在这里看到的那样慢一个数量级.我很想知道为什么业绩差异如此之大.
我运行的是NumPy版本1.25和Python版本3.10.6.