目前,jax.lax.cond
适用于一个布尔条件.有没有办法将它扩展到多个布尔条件?
例如,下面是一个不可跟踪的函数:
def func(x):
if x < 0: return x
elif (x >= 0) & (x < 1): return 2*x
else: return 3*x
如何在JAX中以可跟踪的方式编写此函数?
目前,jax.lax.cond
适用于一个布尔条件.有没有办法将它扩展到多个布尔条件?
例如,下面是一个不可跟踪的函数:
def func(x):
if x < 0: return x
elif (x >= 0) & (x < 1): return 2*x
else: return 3*x
如何在JAX中以可跟踪的方式编写此函数?
一种简洁的方式来写这样的东西是使用jnp.select
:
import jax
import jax.numpy as jnp
@jax.jit
def func(x):
return jnp.select([x < 0, x < 1], [x, 2 * x], default=3 * x)
x = jnp.array([-0.5, 0.5, 1.5])
print(func(x))
# [-0.5 1. 4.5]