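I have the following piece of code, which computes the Mahalanobis distance for a batch of features. It takes about 100 ms on my device, and most of that time comes from the matrix multiplication between delta and inv_covariance.

delta is an array of shape 874x32x100 and inv_covariance is an array of shape 874x100x100.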

def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))

    distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

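I have tried converting the code to use Numba and @njit. I preallocated the intermediate matrix, and I am trying to perform the smaller matrix multiplications in a for loop, because matmul is not supported for 3-dimensional arrays.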

def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    
    intermediate_matrix = np.zeros_like(delta)
    for i in range(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = delta[i] @ inv_covariance[i]

    distances = (intermediate_matrix * delta).sum(2).transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

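I added a few ascontiguousarray calls. The last one is essential, otherwise the code does not work; the others are there to suppress a warning saying that @ would run faster on contiguous arrays (it does not seem to make much difference).

Is there a way to make this code faster, either by improving it or by rethinking it mathematically?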

Edit - Final implementation

Based on Jérôme Richard's answer, I ended up with this code:

@nb.njit()
def matmul(delta: np.ndarray, inv_covariance: np.ndarray):
    """Computes ((delta @ inv_covariance) * delta).sum(1) using numba.

    Args:
        delta: Matrix of dimension BxD
        inv_covariance: Matrix of dimension DxD

    Returns:
        Vector of dimension B
    """
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res


@nb.njit
def mean_subtraction(embeddings: np.ndarray, mean: np.ndarray):
    """Computes embeddings - mean using numba, this is required as I have errors with the default numpy
    implementation.

    Args:
        embeddings: Embedding matrix of dimension FxBxD
        mean: Mean matrix of dimension BxD

    Returns:
        Delta matrix of dimension FxBxD
    """
    output_matrix = np.zeros_like(embeddings)
    for i in range(embeddings.shape[0]):
        output_matrix[i] = embeddings[i] - mean

    return output_matrix


@nb.njit(parallel=True)
def compute_distance_numba(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    """Compute distance score using numba.

    Args:
        embedding: Embedding Vector
        mean: Mean of the multivariate Gaussian distribution
        inv_covariance: Inverse Covariance matrix of the multivariate Gaussian distribution.
    """
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    delta = np.ascontiguousarray(mean_subtraction(embedding, mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)

    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]), dtype=delta.dtype)
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])

    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

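Compared to the accepted answer, the changes are a custom function for the subtraction and a dtype for the intermediate matrix, to avoid the default np.float64.

Recommended answer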
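As a small illustration of the dtype point, here is a minimal sketch (the array names and shapes are made up for the example) showing that np.zeros allocates float64 unless a dtype is passed, so a float32 pipeline would otherwise be silently widened:

```python
import numpy as np

# Hypothetical float32 input standing in for `delta`.
delta = np.ones((4, 3), dtype=np.float32)

# Without dtype, np.zeros falls back to float64.
default_buf = np.zeros(delta.shape)

# Passing dtype keeps the buffer in the same precision as the input.
typed_buf = np.zeros(delta.shape, dtype=delta.dtype)

print(default_buf.dtype, typed_buf.dtype)
```

Keeping the buffer in float32 halves its memory footprint compared to float64, which matters when the computation is memory-bound.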

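First of all, matrix multiplication is carried out by a library called BLAS, and most implementations are efficient and parallel. That said, for a batch of small matrices, a parallel implementation cannot be that efficient: the granularity is too small, so the overhead of using multiple threads becomes significant. It is better to parallelize the outer loop and use a sequential matrix-multiplication code.

Since the matrices involved are very small, it is better to reimplement the matrix multiplication manually. This removes the overhead of calling the BLAS functions and ensures that no threads are used inside the matrix multiplication. However, one needs to take care to read and write values contiguously so that the operation is SIMD-friendly.

Most importantly, the matrix multiplication can be merged with the next line, (intermediate_matrix * delta).sum(2), so that a smaller output array is written and the large temporary array is never read back. This is critical since the RAM is slow. This strategy also reduces the memory footprint while being faster and scaling better. Merging the operation with the (embedding - mean).transpose(2, 0, 1) line would certainly also be a good idea, though I did not test it.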
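For that last, untested idea, one possible shape of a fully fused kernel is sketched below. This is an assumption-laden illustration, not measured code: the function name fused_mahalanobis_sq is hypothetical, the inputs are assumed to have shapes (B, C, N), (C, N) and (N, C, C), and the read of embedding[b, c, p] is strided, so whether it actually beats the two-step version would need to be benchmarked. The try/except only exists so the sketch still runs (slowly) when Numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    njit, prange = nb.njit, nb.prange
except ImportError:
    # Fallback so the sketch runs as plain Python without Numba.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f
    prange = range


@njit(parallel=True)
def fused_mahalanobis_sq(embedding, mean, inv_covariance):
    # embedding: (B, C, N), mean: (C, N), inv_covariance: (N, C, C)
    batch, channel, n = embedding.shape
    out = np.zeros((batch, n), dtype=embedding.dtype)
    for p in prange(n):  # parallel over spatial positions
        delta = np.empty(channel, dtype=embedding.dtype)
        line = np.empty(channel, dtype=embedding.dtype)
        for b in range(batch):
            # Subtraction fused in: no (N, B, C) delta array is materialized.
            for c in range(channel):
                delta[c] = embedding[b, c, p] - mean[c, p]
            # line = delta @ inv_covariance[p], reading rows contiguously.
            line[:] = 0.0
            for k in range(channel):
                factor = delta[k]
                for j in range(channel):
                    line[j] += factor * inv_covariance[p, k, j]
            # Reduction fused in: only one scalar per (b, p) is written.
            acc = 0.0
            for j in range(channel):
                acc += line[j] * delta[j]
            out[b, p] = acc
    return out
```

The result holds squared distances of shape (B, N); the reshape, clip and square root from the original function would still be applied afterwards.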


Implementation

Here is an implementation that takes all of these points into account except the last one:

@nb.njit()
def matmul(delta, inv_covariance):
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res

@nb.njit(parallel=True)
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)
    
    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]))
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])

    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))

    return distances

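Results

On my i5-9600KF CPU (6 cores), this is about 3 times faster. Most of the time now seems to be spent in the first line, which could also be merged to get better performance (assuming the array strides are reasonable). Note that compilation time is not included in the timings, and the results are equal (checked with np.allclose).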
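A check of this kind can be sketched as follows. This is only an illustration with small random inputs: distances_reference repeats the NumPy code from the question, while distances_einsum is a second, independent formulation written for this example (it is not claimed to be faster). The same np.allclose comparison applies to any optimized variant.

```python
import numpy as np


def distances_reference(embedding, mean, inv_covariance):
    # Same computation as the original NumPy version in the question.
    batch, channel, height, width = embedding.shape
    flat = embedding.reshape(batch, channel, height * width)
    delta = np.ascontiguousarray((flat - mean).transpose(2, 0, 1))
    d = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    return np.sqrt(d.reshape(batch, 1, height, width).clip(0))


def distances_einsum(embedding, mean, inv_covariance):
    # d[b, n] = sum_ij delta[b, i, n] * inv_cov[n, i, j] * delta[b, j, n]
    batch, channel, height, width = embedding.shape
    delta = embedding.reshape(batch, channel, height * width) - mean
    d = np.einsum("bin,nij,bjn->bn", delta, inv_covariance, delta)
    return np.sqrt(d.reshape(batch, 1, height, width).clip(0))


# Small made-up shapes: batch=2, channel=3, height=width=2.
rng = np.random.default_rng(0)
embedding = rng.standard_normal((2, 3, 2, 2))
mean = rng.standard_normal((3, 4))
a = rng.standard_normal((4, 3, 3))
# Symmetric positive definite matrices, like a real inverse covariance.
inv_covariance = a @ a.transpose(0, 2, 1) + np.eye(3)

print(np.allclose(distances_reference(embedding, mean, inv_covariance),
                  distances_einsum(embedding, mean, inv_covariance)))
```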
