我有以下代码来计算函数的导数:
import jax
import jax.numpy as jnp
def f(x):
return jnp.prod(x)
df1 = jax.grad(f)
df2 = jax.jacobian(df1)
df3 = jax.jacobian(df2)
有了这个,所有的偏导数都是可用的,例如(另外还有vmap
):
x = jnp.array([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20.],
[21., 22., 23., 24., 25.],
[26., 27., 28., 29., 30.]])
df3_x0_x2_x4 = jax.vmap(df3)(x)[:, 0, 2, 4]
print(df3_x0_x2_x4)
# [ 8. 63. 168. 323. 528. 783.]
问题是,我如何才能只计算df3_x0_x2_x4
,避免所有不必要的导数计算(并让f
只有一个向量参数)?