目前,jax.lax.cond适用于一个布尔条件.有没有办法将它扩展到多个布尔条件?

例如,下面是一个不可跟踪的函数:

def func(x):
    if x < 0: return x
    elif (x >= 0) & (x < 1): return 2*x
    else: return 3*x

如何在JAX中以可跟踪的方式编写此函数?

推荐答案

一种简洁的方式来写这样的东西是使用jnp.select:

import jax
import jax.numpy as jnp

@jax.jit
def func(x):
  return jnp.select([x < 0, x < 1], [x, 2 * x], default=3 * x)

x = jnp.array([-0.5, 0.5, 1.5])
print(func(x))
# [-0.5  1.   4.5]

Python相关问答推荐

如何标记Spacy中不包含特定符号的单词?

未删除映射表的行

输出中带有南的亚麻神经网络

使可滚动框架在tkinter环境中看起来自然

C#使用程序从Python中执行Exec文件

迭代嵌套字典的值

UNIQUE约束失败:customuser. username

合并帧,但不按合并键排序

无论输入分辨率如何,稳定扩散管道始终输出512 * 512张图像

Polars Group by描述扩展

如何在GEKKO中使用复共轭物

如何在一组行中找到循环?

如何在Airflow执行日期中保留日期并将时间转换为00:00

在第一次调用时使用不同行为的re. sub的最佳方式

Python如何导入类的实例

我怎么才能用拉夫分拣呢?

TypeError:';Locator';对象无法在PlayWriter中使用.first()调用

Matplotlib中的曲线箭头样式

如何将验证器应用于PYDANC2中的EACHY_ITEM?

Fake pathlib.使用pyfakefs的类变量中的路径'