I have the following code, which computes the Mahalanobis distance for a batch of features. It takes about 100 ms on my machine, most of which is spent on the matrix multiplication between delta and inv_covariance.
delta is a matrix of shape 874x32x100 and inv_covariance is a matrix of shape 874x100x100.
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))

    distances = ((delta @ inv_covariance) * delta).sum(2).transpose(1, 0)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
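For reference, the whole contraction above can be written as a single np.einsum call; whether it beats the batched @ depends on the BLAS backend and the contraction path, but it is a useful equivalence to know. A sketch on toy shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
f, b, c = 6, 3, 4  # tiny stand-ins for height*width, batch, channel
delta = rng.standard_normal((f, b, c))
inv_cov = rng.standard_normal((f, c, c))

# reference: the expression from the function above
ref = ((delta @ inv_cov) * delta).sum(2).transpose(1, 0)

# same contraction collapsed into one einsum call
out = np.einsum('fbc,fcd,fbd->bf', delta, inv_cov, delta, optimize=True)

assert np.allclose(ref, out)
```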
I have already tried converting the code to use Numba and @njit. I pre-allocated the intermediate matrix, and I perform the smaller matrix multiplications in a for loop, since matmul is not supported for 3-dimensional matrices.
def compute_distance(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    # calculate mahalanobis distances
    delta = np.ascontiguousarray((embedding - mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)

    intermediate_matrix = np.zeros_like(delta)
    for i in range(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = delta[i] @ inv_covariance[i]

    distances = (intermediate_matrix * delta).sum(2).transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
I added several ascontiguousarray calls; the last one is required or the code does not work, while the others are there to suppress a warning saying that @ would perform better on contiguous arrays (it does not seem to matter much).
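The warning comes from the fact that transpose returns a strided view rather than a copy, so the result is not C-contiguous until ascontiguousarray copies it. A minimal illustration:

```python
import numpy as np

a = np.zeros((4, 3, 2))
t = a.transpose(2, 0, 1)      # a view with permuted strides, not C-contiguous
c = np.ascontiguousarray(t)   # copies the data into C order

assert not t.flags['C_CONTIGUOUS']
assert c.flags['C_CONTIGUOUS']
```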
Is there a way to make this code faster, either by improving it directly or by rethinking the math differently?
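One mathematical reformulation worth considering: since each inv_covariance[i] is symmetric positive definite, it factors as L @ L.T (Cholesky), and the quadratic form delta_i @ inv_cov @ delta_i collapses to the squared norm of delta @ L. That replaces the multiply-and-sum with a single matmul plus an elementwise square, and the factorization can be precomputed once per position. A sketch of the identity on toy data:

```python
import numpy as np

rng = np.random.default_rng(0)
b, c = 4, 5  # tiny stand-ins for the batch and channel dimensions
delta = rng.standard_normal((b, c))
A = rng.standard_normal((c, c))
inv_cov = A @ A.T + c * np.eye(c)  # symmetric positive definite

# direct quadratic form: d_i = delta_i @ inv_cov @ delta_i
direct = ((delta @ inv_cov) * delta).sum(1)

# Cholesky reformulation: inv_cov = L @ L.T, so the quadratic form
# becomes the squared row norm of delta @ L
L = np.linalg.cholesky(inv_cov)
via_chol = ((delta @ L) ** 2).sum(1)

assert np.allclose(direct, via_chol)
```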
Edit - Final implementation
Based on Jérôme Richard's answer, I ended up with this code:
@nb.njit()
def matmul(delta: np.ndarray, inv_covariance: np.ndarray):
    """Computes ((delta @ inv_covariance) * delta).sum(1) using numba.

    Args:
        delta: Matrix of dimension BxD
        inv_covariance: Matrix of dimension DxD

    Returns:
        Vector of dimension B
    """
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    assert sk == inv_covariance.shape[0]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res
@nb.njit
def mean_subtraction(embeddings: np.ndarray, mean: np.ndarray):
    """Computes embeddings - mean using numba; this is required as the default
    numpy implementation raises errors here.

    Args:
        embeddings: Embedding matrix of dimension FxBxD
        mean: Mean matrix of dimension BxD

    Returns:
        Delta matrix of dimension FxBxD
    """
    output_matrix = np.zeros_like(embeddings)
    for i in range(embeddings.shape[0]):
        output_matrix[i] = embeddings[i] - mean
    return output_matrix
@nb.njit(parallel=True)
def compute_distance_numba(embedding: np.ndarray, mean: np.ndarray, inv_covariance: np.ndarray) -> np.ndarray:
    """Compute distance score using numba.

    Args:
        embedding: Embedding Vector
        mean: Mean of the multivariate Gaussian distribution
        inv_covariance: Inverse Covariance matrix of the multivariate Gaussian distribution.
    """
    batch, channel, height, width = embedding.shape
    embedding = embedding.reshape(batch, channel, height * width)

    delta = np.ascontiguousarray(mean_subtraction(embedding, mean).transpose(2, 0, 1))
    inv_covariance = np.ascontiguousarray(inv_covariance)

    intermediate_matrix = np.zeros((delta.shape[0], delta.shape[1]), dtype=delta.dtype)
    for i in nb.prange(intermediate_matrix.shape[0]):
        intermediate_matrix[i] = matmul(delta[i], inv_covariance[i])

    distances = intermediate_matrix.transpose(1, 0)
    distances = np.ascontiguousarray(distances)
    distances = distances.reshape(batch, 1, height, width)
    distances = np.sqrt(distances.clip(0))
    return distances
Compared to the accepted answer, the changes are a custom function for the subtraction and an explicit dtype on the intermediate matrix to avoid defaulting to np.float64.
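Since @njit compiles the same Python source, the kernel's logic can be sanity-checked without Numba by running the identical triple loop in plain NumPy on a tiny input and comparing against the direct expression. A sketch (the reference loop here is for testing only, not part of the final code):

```python
import numpy as np

def matmul_ref(delta, inv_covariance):
    # same triple loop as the njit kernel above, in plain Python/NumPy
    si, sj, sk = delta.shape[0], inv_covariance.shape[1], delta.shape[1]
    line = np.zeros(sj, dtype=delta.dtype)
    res = np.zeros(si, dtype=delta.dtype)
    for i in range(si):
        line.fill(0.0)
        for k in range(sk):
            factor = delta[i, k]
            for j in range(sj):
                line[j] += factor * inv_covariance[k, j]
        for j in range(sj):
            res[i] += line[j] * delta[i, j]
    return res

rng = np.random.default_rng(0)
d = rng.standard_normal((5, 4))
ic = rng.standard_normal((4, 4))

# the loop must agree with the direct quadratic-form expression
assert np.allclose(matmul_ref(d, ic), ((d @ ic) * d).sum(1))
```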