我想计算一下:

from jax.scipy.linalg import expm
import jax.numpy as jnp
from functools import lru_cache, reduce

num_qubits=2
theta = jnp.asarray(np.pi*np.random.random((15,2,2,2,2,2,2,2,2)))

def pauli_matrix(num_qubits):
    _pauli_matrices = jnp.array(
    [[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
    )
    return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]

def SpecialUnitary(num_qubits,theta):
    assert theta.shape[0] == 15
    A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]])
    print(f'{A.shape= }{pauli_matrix(num_qubits).shape=}{theta.shape=}')
    return expm(1j*A/2)

SpecialUnitary(num_qubits,theta)

形状:A.shape= (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)pauli_matrix(num_qubits).shape=(15, 4, 4)theta.shape=(15, 2, 2, 2, 2, 2, 2, 2, 2) 错误:ValueError: expected A to be a square matrix

我被卡住了,因为文档说exPM是在最后两个轴上计算的,这两个轴必须是正方形的,这是完成的.

推荐答案

最近的JAX版本支持Batched expm,您会发现这在JAX v0.4.7或更高版本中工作得很好:

import jax.numpy as jnp
import jax.scipy.linalg

X = jnp.arange(128.0).reshape(2, 2, 2, 4, 4)

result = jax.scipy.linalg.expm(X)
print(result.shape)
# (2, 2, 2, 4, 4)

如果出于某种原因,您必须使用较旧的JAX版本,您可以使用jax.numpy.vectorize来解决这个问题.例如:

expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(n,n)->(n,n)')

result = expm(X)
print(result.shape)
# (2, 2, 2, 4, 4)

Python相关问答推荐

调查TensorFlow和PyTorch性能的差异

Pandas或pyspark跨越列创建

使用unmanagedexports从Python调用的c#DLC

从收件箱获取特定列中的重复行

FastAPI:使用APIRouter路由子模块功能

将numpy数组与空数组相加

如何判断LazyFrame是否为空?

在两极中实施频率编码

Python中的函数中是否有充分的理由接受float而不接受int?

如何使用Google Gemini API为单个提示生成多个响应?

在函数内部使用eval(),将函数的输入作为字符串的一部分

如何在BeautifulSoup中链接Find()方法并处理无?

在内部列表上滚动窗口

我从带有langchain的mongoDB中的vector serch获得一个空数组

Excel图表-使用openpyxl更改水平轴与Y轴相交的位置(Python)

为什么这个带有List输入的简单numba函数这么慢

如何从在虚拟Python环境中运行的脚本中运行需要宿主Python环境的Shell脚本?

Julia CSV for Python中的等效性Pandas index_col参数

删除marplotlib条形图上的底边

Flash只从html表单中获取一个值