JAX documentation中,涵盖了具有单个输出的函数的自定义导数.我想知道如何实现自定义衍生物的函数与多个输出,如这一个?

# want to define custom derivative of out_2 with respect to *args
def test_func(*args, **kwargs):
    ...
    return out_1, out_2

推荐答案

您可以为具有任意数量的输入和输出的函数定义定制的导数:只需将适当数量的元素添加到custom_jvp规则中的primalstangents个元组.例如:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return x * y, x / y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primals_out = f(x, y)
  tangents_out = (x_dot * y + y_dot * x, 
                  x_dot / y - y_dot * x / y ** 2)
  return primals_out, tangents_out

x = jnp.float32(0.5)
y = jnp.float32(2.0)

jax.jacobian(f, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
#  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))

将其与使用相同函数的标准非定制导数规则计算的结果进行比较,可以看出结果是等价的:

def f2(x, y):
  return x * y, x / y

jax.jacobian(f2, argnums=(0, 1))(x, y)
# ((Array(2., dtype=float32), Array(0.5, dtype=float32)),
#  (Array(0.5, dtype=float32), Array(-0.125, dtype=float32)))

Python相关问答推荐

Godot:需要碰撞的对象的AdditionerBody2D或Area2D以及queue_free?

OR—Tools中CP—SAT求解器的IntVar设置值

当点击tkinter菜单而不是菜单选项时,如何执行命令?

如何在Pyplot表中舍入值

重置PD帧中的值

在Docker容器(Alpine)上运行的Python应用程序中读取. accdb数据库

使用字典或列表的值组合

以异步方式填充Pandas 数据帧

使用SeleniumBase保存和加载Cookie时出现问题

如何在Python中将超链接添加到PDF中每个页面的顶部?

504未连接IB API TWS错误—即使API连接显示已接受''

使用np.fft.fft2和cv2.dft重现相位谱.为什么结果并不相似呢?

将相应的值从第2列合并到第1列(Pandas )

解析CSV文件以将详细信息添加到XML文件

生产者/消费者-Queue.get by list

这是什么排序算法?(将迭代器与自身合并&&Q;)

torch 二维张量与三维张量欧氏距离的计算

如何强制SqlalChemy指向与连接字符串的默认架构不同的架构

sklearn ridgeCV与ElasticNetCV

优化将索引分配给极轴上的拆分数据组