I'm trying to optimize a piece of code that calculates the (squared) Mahalanobis distance in a vectorized manner. I have a standard implementation that uses the regular Python matrix multiplication operator (@), and another implementation that uses einsum. However, I'm surprised that the einsum implementation is slower than the standard one. Is there anything I'm doing inefficiently with einsum, or are there other methods, such as tensordot, that I should be looking into?

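import numpy as np

BATCH_SZ = 128
GAUSSIANS = 100  # not defined in the question; 100 is consistent with the FLOP counts quoted in the answer

xvals = np.random.random((BATCH_SZ, 1, 4))
means = np.random.random((GAUSSIANS, 1, 4))
inv_covs = np.random.random((GAUSSIANS, 4, 4))
# Broadcast to (BATCH_SZ, GAUSSIANS, 1, 4) and use batched matrix products
xvals_newdim = xvals[:, np.newaxis, ...]
means_newdim = means[np.newaxis, ...]
diff_newdim = xvals_newdim - means_newdim
regular = diff_newdim @ inv_covs @ (diff_newdim).transpose(0, 1, 3, 2)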

>> 731 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
diff = xvals - means.squeeze(1)
einsum = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)

>> 949 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


First things first: to optimize such code, one needs to understand what is going on, then profile it, then estimate the expected time, and only then look for a better solution.

TL;DR: both versions are inefficient and run serially. Neither BLAS libraries nor Numpy are designed to optimize this use case. In fact, even basic Numpy operations are not efficient when the last axis is very small (i.e. 4). This can be optimized with Numba by writing an implementation specifically designed for your matrix sizes.


@ is a Python operator, but internally it calls a Numpy function, just like + or * do. It loops over all the matrices and calls a highly optimized BLAS implementation on each one. A BLAS is a numerical linear algebra library. Many BLAS implementations exist, but the default one for Numpy is generally OpenBLAS, which is quite well optimized, especially for large matrices. Note also that np.einsum can call BLAS implementations for specific patterns (provided the optimize flag is set appropriately), but this is not the case here. It is also worth mentioning that np.einsum is well optimized for 2 input matrices, less so for 3, and not optimized for more. This is because the number of possible cases grows exponentially and the specialized code paths are written by hand. For more information about how np.einsum works, please read "How is numpy.einsum implemented?".
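As a quick sanity check, here is a sketch reusing the shapes from the question (with GAUSSIANS = 100 assumed, since the question does not define it). The batched @ version, the plain three-operand einsum and einsum with optimize=True all compute the same values; optimize=True only changes how the expression is split into pairwise contractions. As far as I can tell, neither pairwise step qualifies for a BLAS call here, because the index j is shared by both operands and kept in the output (it acts as a batch index).

```python
import numpy as np

# Shapes from the question; GAUSSIANS = 100 is an assumption (it is not
# defined in the question, but matches the FLOP counts quoted in the answer).
BATCH_SZ, GAUSSIANS = 128, 100
rng = np.random.default_rng(0)
xvals = rng.random((BATCH_SZ, 1, 4))
means = rng.random((GAUSSIANS, 1, 4))
inv_covs = rng.random((GAUSSIANS, 4, 4))

# Batched matmul version: one tiny (1x4)@(4x4)@(4x1) product per (i, j) pair.
diff_newdim = xvals[:, np.newaxis, ...] - means[np.newaxis, ...]  # (128, 100, 1, 4)
regular = (diff_newdim @ inv_covs @ diff_newdim.transpose(0, 1, 3, 2))[..., 0, 0]

# Three-operand einsum, evaluated in a single pass (no path optimization).
diff = xvals - means.squeeze(1)  # (128, 100, 4)
plain = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)

# optimize=True lets einsum split the expression into pairwise contractions.
optimized = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff, optimize=True)

print(regular.shape)                  # (128, 100)
print(np.allclose(regular, plain))    # True
print(np.allclose(plain, optimized))  # True
```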


The thing is, you are multiplying many very small matrices, and most BLAS implementations are not optimized for that. Numpy is not either: the cost of the generic loop iteration can become large compared to the computation itself, not to mention the function call to the BLAS. Profiling Numpy shows that the slowest function of the np.einsum implementation is PyArray_TransferNDimToStrided. This is not the main computing function but a helper. In fact, the main computing function takes only 20% of the overall time, which leaves a lot of room for improvement! The same is true for the BLAS implementation: cblas_dgemv takes only about 20% of the time, as does dgemv_n_HASWELL (the main computing kernel of the BLAS cblas_dgemv function). The rest is nearly pure overhead of the BLAS library or Numpy (roughly half the time for both). Moreover, both versions run serially. Indeed, np.einsum is not optimized to run with multiple threads, and the BLAS cannot use multiple threads because the matrices are too small for multi-threading to pay off (it has a significant overhead). This means both versions are pretty inefficient.
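To illustrate that the per-matrix overhead, not the arithmetic, is the problem, here is a pure-NumPy sketch (not part of the original answer and not benchmarked here). It unrolls the two small axes k and l in Python, so the 12,800 tiny 4x4 products become 16 element-wise operations over whole (i, j) arrays. The function name mahalanobis_sq_unrolled is hypothetical, and whether it beats the versions above on a given machine has to be measured.

```python
import numpy as np

def mahalanobis_sq_unrolled(diff, inv_covs):
    """Squared Mahalanobis distances for every (i, j) pair.

    diff: (ni, nj, nk) differences, inv_covs: (nj, nk, nl), assumes nk == nl.
    The small k/l axes are looped over in Python so that every NumPy call
    operates on a whole (ni, nj) array instead of on a single 4x4 matrix.
    """
    ni, nj, nk = diff.shape
    nl = inv_covs.shape[2]
    # Move the small axis first so that d[k] is a contiguous (ni, nj) slice.
    d = np.ascontiguousarray(diff.transpose(2, 0, 1))
    res = np.zeros((ni, nj))
    for k in range(nk):
        for l in range(nl):
            # inv_covs[:, k, l] has shape (nj,) and broadcasts over i.
            res += d[k] * inv_covs[:, k, l] * d[l]
    return res

rng = np.random.default_rng(1)
diff = rng.random((128, 100, 4))
inv_covs = rng.random((100, 4, 4))
reference = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)
print(np.allclose(mahalanobis_sq_unrolled(diff, inv_covs), reference))  # True
```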

Performance metric

To know how inefficient the versions are, one needs to know the amount of computation to perform and the speed of the processor. The number of Flop (floating-point operations) is provided by np.einsum_path and is 5.120e+05 (for an optimized implementation; otherwise it is 6.144e+05). Mainstream CPUs usually perform >=100 GFlops/s with multiple threads and dozens of GFlops/s serially. For example, my i5-9600KF processor can achieve 300-400 GFlops/s in parallel and 50-60 GFlops/s serially. Since the computation lasts 0.52 ms for the BLAS version (the best one), the code runs at about 1 GFlops/s (5.12e5 Flop / 0.52e-3 s ≈ 0.98 GFlops/s), which is a poor result compared to the optimum.
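The FLOP counts and the resulting throughput can be reproduced roughly as follows. This is a sketch: the array shapes assume GAUSSIANS = 100 (not defined in the question), and the 0.52 ms timing is copied from the text rather than measured.

```python
import numpy as np

# Same shapes as in the question; GAUSSIANS = 100 is an assumption.
diff = np.random.random((128, 100, 4))
inv_covs = np.random.random((100, 4, 4))

# einsum_path returns the contraction order and a human-readable report.
path, info = np.einsum_path("ijk,jkl,ijl->ij", diff, inv_covs, diff,
                            optimize="greedy")
print(info)  # includes "Naive FLOP count" and "Optimized FLOP count" lines

# Achieved throughput = FLOP count / elapsed time.
optimized_flop = 5.120e5  # optimized FLOP count reported by einsum_path
elapsed_s = 0.52e-3       # timing of the BLAS (@) version quoted in the text
gflops = optimized_flop / elapsed_s / 1e9
print(f"{gflops:.2f} GFlops/s")  # 0.98 GFlops/s
```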


One solution to speed up the computation is to write a Numba (JIT compiler) or Cython (Python to C compiler) implementation that is optimized for your specific matrix sizes. Indeed, the last dimension is too small for generic code to be fast. Even basic compiled code would not be very fast in this case: the overhead of a C loop alone can be large compared to the actual computation. We can tell the compiler that the size of some axes is small and fixed at compile time, so it can generate much faster code (thanks to loop unrolling, tiling and SIMD instructions). This can be done with a basic assert in Numba. In addition, we can use the fastmath=True flag to speed up the computation even more if no special floating-point (FP) values like NaN or subnormal numbers are involved. This flag can also impact the accuracy of the result since it assumes FP math is associative (which is not true). In short, it breaks the IEEE-754 standard for the sake of performance. Here is the resulting code:

import numba as nb
import numpy as np

# Use `fastmath=True` for better performance if no special values
# (NaN, Inf, subnormals) are involved and accuracy is not critical.
@nb.njit('(float64[:,:,::1], float64[:,:,::1])', fastmath=True)
def compute_fast_einsum(diff, inv_covs):
    ni, nj, nk = diff.shape
    nl = inv_covs.shape[2]
    assert inv_covs.shape == (nj, nk, nl)
    assert nk == 4 and nl == 4
    res = np.empty((ni, nj), dtype=np.float64)
    for i in range(ni):
        for j in range(nj):
            s = 0.0
            for k in range(nk):
                for l in range(nl):
                    s += diff[i, j, k] * inv_covs[j, k, l] * diff[i, j, l]
            res[i, j] = s
    return res

diff = xvals - means.squeeze(1)
compute_fast_einsum(diff, inv_covs)


Here are performance results on my machine (mean ± std. dev. of 7 runs, 1000 loops each):

@ operator:           602 µs ± 3.33 µs per loop 
einsum:               698 µs ± 4.62 µs per loop

Numba code:           193 µs ± 544 ns per loop
Numba + fastmath:     177 µs ± 624 ns per loop
Best Numba:         < 100 µs                     <------ 6x-7x faster !

Note that 100 µs is spent computing diff, which is not efficient. This step can also be optimized with Numba. In fact, diff can be computed on the fly in the i-based loop from the other arrays. This makes the computation more cache-friendly. This version is called "Best Numba" in the results. Note that the Numba versions do not even use multiple threads. That being said, the overhead of multi-threading is generally about 5-500 µs, so using multiple threads may be slower on some machines (on mainstream PCs, i.e. not computing servers, the overhead is generally 5-100 µs, and it is about 10 µs on my machine).
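A minimal sketch of this fused idea follows. It is not the author's exact "Best Numba" code and its timing is not measured here. Assumptions: the function name mahalanobis_sq_fused is hypothetical, xvals and means are passed as 2-D (n, 4) arrays (e.g. xvals[:, 0] and means[:, 0]), and the size-4 axis is hard-coded. The fallback decorator only exists so the sketch still runs (slowly, as plain Python) when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Stand-in for numba.njit when Numba is missing: return the function unchanged.
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func

@njit(fastmath=True)
def mahalanobis_sq_fused(xvals, means, inv_covs):
    # xvals: (ni, 4), means: (nj, 4), inv_covs: (nj, 4, 4)
    ni = xvals.shape[0]
    nj = means.shape[0]
    # Fixed sizes let the compiler unroll/vectorize the k and l loops.
    assert xvals.shape[1] == 4 and means.shape[1] == 4
    assert inv_covs.shape[0] == nj and inv_covs.shape[1] == 4 and inv_covs.shape[2] == 4
    res = np.empty((ni, nj), dtype=np.float64)
    for i in range(ni):
        for j in range(nj):
            s = 0.0
            for k in range(4):
                # The difference is computed on the fly: no (ni, nj, 4) temporary.
                dk = xvals[i, k] - means[j, k]
                for l in range(4):
                    s += dk * inv_covs[j, k, l] * (xvals[i, l] - means[j, l])
            res[i, j] = s
    return res

rng = np.random.default_rng(2)
xvals = rng.random((128, 1, 4))
means = rng.random((100, 1, 4))
inv_covs = rng.random((100, 4, 4))

result = mahalanobis_sq_fused(xvals[:, 0], means[:, 0], inv_covs)
diff = xvals - means.squeeze(1)
reference = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)
print(np.allclose(result, reference))  # True
```

No scratch buffer is shared between iterations, so a multi-threaded variant (parallel=True with numba.prange on the i loop) would also be safe, although, as noted above, the threading overhead may not pay off at this size.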

