我想计算一下两个二阶张量在Python(NumPy)中的次对称和大对称交叉并矢积.

C[i,j,k,l] = (
    A[i, k] * B[j, l] + A[i, l] * B[k, j] + B[i, k] * A[j, l] + B[i, l] * A[k, j]
) / 4

我有两个形状(3, 3, n)的二阶张量,其中最后一个尾轴表示批次维度.我已经在我最近发布的Python包hyperelastic中实现了该函数,希望得到您对我的代码效率的反馈.我想避免任何像Numba这样的JIT东西.

给定两个对称的二阶张量作为缩减向量存储中的NumPy数组,

import numpy as np

# any random 3x3 tensor
F = np.eye(3).reshape(3, 3, 1) + np.random.rand(3, 3, 100000) / 10

# a symmetric 3x3 tensor
C = F.T @ F

# Voigt-notation vector storage 
# (upper triangle items but sorted on the diagonals; starting from the main diag.)
i = [0, 1, 2, 0, 1, 0]
j = [0, 1, 2, 1, 2, 2]
A = C[i, j]

我当前的实现如下所示:

import numpy as np

def cdya(A, B):
    i, j = [a.ravel() for a in np.indices((6, 6))]

    a = np.array([(0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (0, 2)])
    b = np.array([0, 3, 5, 3, 1, 4, 5, 4, 2]).reshape(3, 3)

    i, j, k, l = np.hstack([a[i], a[j]]).T

    ik = b[i, k].reshape(6, 6)
    jl = b[j, l].reshape(6, 6)

    il = b[i, l].reshape(6, 6)
    kj = b[k, j].reshape(6, 6)

    C = (A[ik] * B[jl] + A[il] * B[kj]) / 2

    if A is not B:
        C += (B[ik] * A[jl] + B[il] * A[kj]) / 2
        C /= 2

    return C


A = np.random.rand(6, 100000)
B = np.random.rand(6, 100000)

C = cdya(A, B)

在我的机器上,这个函数大约需要花费几分钟.np.einsumms.在全张量存储中使用np.einsum个相同的实现大约需要花费.160毫秒.

A = np.random.rand(3, 3, 100000)
B = np.random.rand(3, 3, 100000)

C = (
    np.einsum("ij...,kl...->ikjl...", A, B) + np.einsum("ij...,kl...->ilkj...", A, B) +
    np.einsum("ij...,kl...->ikjl...", B, A) + np.einsum("ij...,kl...->ilkj...", B, A)
) / 4

你认为还有改进的余地吗?还是我已经找到了一个很好的解决方案?

首先要感谢大家!

推荐答案

Numpy的问题在于,默认情况下会创建许多相对较大的临时数组(为了方便起见).这太贵了.你可以做in-place operations个,尽管这做起来很麻烦.此外,您还可以对最终的乘法进行因式分解,以确定条件是否成立:

下面是一个例子:

def cdya_opt(A, B):
    i, j = [a.ravel() for a in np.indices((6, 6))]

    a = np.array([(0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (0, 2)])
    b = np.array([0, 3, 5, 3, 1, 4, 5, 4, 2]).reshape(3, 3)

    i, j, k, l = np.hstack([a[i], a[j]]).T

    ik = b[i, k].reshape(6, 6)
    jl = b[j, l].reshape(6, 6)

    il = b[i, l].reshape(6, 6)
    kj = b[k, j].reshape(6, 6)

    x1, x2, x3, x4 = A[ik], B[jl], A[il], B[kj]
    np.multiply(x1, x2, out=x1)
    np.multiply(x3, x4, out=x3)
    np.add(x1, x3, out=x1)

    if A is not B:
        x5, x6, x7, x8 = B[ik], A[jl], B[il], A[kj]
        np.multiply(x5, x6, out=x5)
        np.multiply(x7, x8, out=x7)
        np.add(x5, x7, out=x5)
        np.add(x5, x1, out=x1)
        np.multiply(x1, 0.25, out=x1)
    else:
        np.multiply(x1, 0.5, out=x1)

    return x1

如果您可以使用100之类的其他包,那么您可以编写更快的实现:

import numexpr as ne

def cdya_opt2(A, B):
    i, j = [a.ravel() for a in np.indices((6, 6))]

    a = np.array([(0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (0, 2)])
    b = np.array([0, 3, 5, 3, 1, 4, 5, 4, 2]).reshape(3, 3)

    i, j, k, l = np.hstack([a[i], a[j]]).T

    ik = b[i, k].reshape(6, 6)
    jl = b[j, l].reshape(6, 6)

    il = b[i, l].reshape(6, 6)
    kj = b[k, j].reshape(6, 6)

    x1, x2, x3, x4 = A[ik], B[jl], A[il], B[kj]

    if A is not B:
        x5, x6, x7, x8 = B[ik], A[jl], B[il], A[kj]
        ne.evaluate('(x1 * x2 + x3 * x4 + x5 * x6 + x7 * x8) * 0.25', out=x1)
    else:
        ne.evaluate('(x1 * x2 + x3 * x4) * 0.5', out=x1)

    return x1

性能结果

以下是在我的机器(i5-9600KF CPU,运行在Windows上,安装了Numpy 1.24.3)上的结果:

cdya:       112.0 ms
cdya_opt:    70.7 ms
cdya_opt2:   59.1 ms  <-----

上一个版本的速度几乎快了一倍.大部分时间都花在像A[ik]这样的操作上,而这些操作在Numpy中并没有得到很好的优化.使用Cython(或Numba)是使这段代码更快的好方法.请注意,与Numba相反,Cython不是JIT.

Python相关问答推荐

将大小为n*512的数组绘制到另一个大小为n*256的数组的PC组件

定义同侪组并计算同侪组分析

如何观察cv2.erode()的中间过程?

机器人与Pyton Minecraft服务器状态不和

运行回文查找器代码时发生错误:[类型错误:builtin_index_or_system对象不可订阅]

如何避免Chained when/then分配中的Mypy不兼容类型警告?

不理解Value错误:在Python中使用迭代对象设置时必须具有相等的len键和值

在Python Attrs包中,如何在field_Transformer函数中添加字段?

如何使用LangChain和AzureOpenAI在Python中解决AttribeHelp和BadPressMessage错误?

在Python中管理打开对话框

如何将Docker内部运行的mariadb与主机上Docker外部运行的Python脚本连接起来

如何在Python数据框架中加速序列的符号化

有没有一种方法可以从python的pussompy比较结果中提取文本?

如果条件不满足,我如何获得掩码的第一个索引并获得None?

cv2.matchTemplate函数匹配失败

为一个组的每个子组绘制,

未知依赖项pin—1阻止conda安装""

需要帮助重新调整python fill_between与数据点

解决调用嵌入式函数的XSLT中表达式的语法移位/归约冲突

从Windows Python脚本在WSL上运行Linux应用程序