我有以下代码来计算函数的导数:

import jax
import jax.numpy as jnp


def f(x):
    return jnp.prod(x)


df1 = jax.grad(f)
df2 = jax.jacobian(df1)
df3 = jax.jacobian(df2)

有了这个,所有的偏导数都是可用的,例如(另外还有vmap):

x = jnp.array([[ 1.,  2.,  3.,  4.,  5.],
               [ 6.,  7.,  8.,  9., 10.],
               [11., 12., 13., 14., 15.],
               [16., 17., 18., 19., 20.],
               [21., 22., 23., 24., 25.],
               [26., 27., 28., 29., 30.]])
df3_x0_x2_x4 = jax.vmap(df3)(x)[:, 0, 2, 4]
print(df3_x0_x2_x4)
# [  8.  63. 168. 323. 528. 783.]

问题是,我如何才能只计算df3_x0_x2_x4,避免所有不必要的导数计算(并让f只有一个向量参数)?

推荐答案

问题是,我如何才能只计算df3_x0_x2_x4,避免所有不必要的导数计算(并让f只有一个向量参数)?

从本质上讲,您是在请求一种计算稀疏Hessian和Jacobian的方法;JAX对此并不普遍支持(请参阅上一期线程;例如https://github.com/google/jax/issues/1032).

Edit

但是,在这种特殊情况下,由于您有效地计算了每个导数通道中单个元素的梯度/jaacobian,因此您可以通过在每个变换中仅将JVP应用于单个one-hot向量来做得更好.例如:

def deriv(f, x, v):
  return jax.jvp(f, [x], [v])[1]

def one_hot(i):
  return jnp.zeros(x.shape[1]).at[i].set(1)

df_x0 = lambda x: deriv(f, x, one_hot(0))
df2_x0_x2 = lambda x: deriv(df_x0, x, one_hot(2))
df3_x0_x2_x4 = lambda x: deriv(df2_x0_x2, x, one_hot(4))
print(jax.vmap(df3_x0_x2_x4)(x))
# [  8.  63. 168. 323. 528. 783.]

Previous answer

如果你愿意放松你的"让f只有一个参数"的标准,你可以这样做:

def f(*x):
  return jnp.prod(jnp.asarray(x))

df1 = jax.grad(f, argnums=4)
df2 = jax.jacobian(df1, argnums=2)
df3 = jax.jacobian(df2, argnums=0)

df3_x0_x2_x4 = jax.vmap(df3)(*(x.T))
print(df3_x0_x2_x4)
# [  8.  63. 168. 323. 528. 783.]

在这里,您不需要计算所有的渐变并切出结果,而是只计算与您感兴趣的特定三个元素相关的渐变.

Python相关问答推荐

Locust请求中的Python和参数

我必须将Sigmoid函数与r2值的两种类型的数据集(每种6个数据集)进行匹配,然后绘制匹配函数的求导.我会犯错

如何在箱形图中添加绘制线的传奇?

在Pandas DataFrame操作中用链接替换'方法的更有效方法

根据二元组列表在pandas中创建新列

如何使用根据其他值相似的列从列表中获取的中间值填充空NaN数据

如何使用数组的最小条目拆分数组

当从Docker的--env-file参数读取Python中的环境变量时,每个\n都会添加一个\'.如何没有额外的?

在Python argparse包中添加formatter_class MetavarTypeHelpFormatter时, - help不再工作""""

使用NeuralProphet绘制置信区间时出错

如何使用scipy的curve_fit与约束,其中拟合的曲线总是在观测值之下?

isinstance()在使用dill.dump和dill.load后,对列表中包含的对象失败

以逻辑方式获取自己的pyproject.toml依赖项

Flask Jinja2如果语句总是计算为false&

计算机找不到已安装的库'

提取最内层嵌套链接

合并相似列表

Python如何导入类的实例

read_csv分隔符正在创建无关的空列

删除Dataframe中的第一个空白行并重新索引列