我有一个m n x n矩阵的列表和一个m实值(alphas)的列表.n和m的值可能相当大.我试图用alphas计算矩阵的加权和.

我想知道是否有一个numpy函数(或任何其他库)可以比manual for loop方法更快地完成这项工作.

我在下面列出了我当前的功能.

def calculate_matrix_sums(mats, alphas):
    """
        Calculate the weighted sum of matrices in mats with weights alpha
    """
    k_mults = [np.multiply(mats[i], alphas[i]) for i in range(len(alphas))]
    k_sums1 = np.matrix(k_mults[0]) + np.matrix(k_mults[1])
    for i in range(2, len(k_mults)):
        k_sums1 = k_sums1 + np.asmatrix(k_mults[i])
    k_sums2 = np.asarray(k_sums1).astype(float)
    k_sums2 = k_sums2.reshape(len(mats[0]), len(mats[0]))
    return k_sums2

和示例代码:

matrices = np.asarray([np.array([[1., 0.77841638, 0.53239253, 0.9444068, 0.93024477],
                                 [0.77841638, 1., 0.7221497, 0.5805838, 0.68501944],
                                 [0.53239253, 0.7221497, 1., 0.36986265, 0.62792847],
                                 [0.9444068, 0.5805838, 0.36986265, 1., 0.88303226],
                                 [0.93024477, 0.68501944, 0.62792847, 0.88303226, 1.]]),
                       np.array([[1., 0.45650032, 0.13898701, 0.83605729, 0.79743304],
                                 [0.45650032, 1., 0.36094014, 0.18229867, 0.30596445],
                                 [0.13898701, 0.36094014, 1., 0.04443844, 0.23300302],
                                 [0.83605729, 0.18229867, 0.04443844, 1., 0.67745532],
                                 [0.79743304, 0.30596445, 0.23300302, 0.67745532, 1.]])])
alpha_vals = [0.47547796, 0.52452204]

print(calculate_matrix_sums(matrices, alpha_vals))

欢迎提出任何建议.

推荐答案

您可以reshape alpha_vals的形状,使其在matrices的第一个轴上正确广播:

(np.array(alpha_vals)[:, None, None] * matrices).sum(axis=0)

或者,您可以调整matrices的步长,使最后一个维度对应于alpha_vals:

(np.moveaxis(matrices, 0, -1) * alpha_vals).sum(axis=-1)

您也可以将np.einsum用于这种情况(可能是最优雅的解决方案):

np.einsum('ijk,i->jk', matrices, alpha_vals)

Python相关问答推荐

如何在BeautifulSoup中链接Find()方法并处理无?

我在使用fill_between()将最大和最小带应用到我的图表中时遇到问题

如何使用symy打印方程?

在Pandas DataFrame操作中用链接替换'方法的更有效方法

如何获取TFIDF Transformer中的值?

如何在虚拟Python环境中运行Python程序?

为什么默认情况下所有Python类都是可调用的?

如果值发生变化,则列上的极性累积和

如何在Polars中从列表中的所有 struct 中 Select 字段?

Plotly Dash Creating Interactive Graph下拉列表

在matplotlib中删除子图之间的间隙_mosaic

如何在两列上groupBy,并使用pyspark计算每个分组列的平均总价值

Python—压缩叶 map html作为邮箱附件并通过sendgrid发送

如何在Python请求中组合多个适配器?

如何在Airflow执行日期中保留日期并将时间转换为00:00

如何获取包含`try`外部堆栈的`__traceback__`属性的异常

如何提高Pandas DataFrame中随机列 Select 和分配的效率?

Django抛出重复的键值违反唯一约束错误

是否需要依赖反转来确保呼叫方和被呼叫方之间的分离?

解析CSV文件以将详细信息添加到XML文件