If you JIT-compile your current approach, you'll find that it's about as efficient as doing something more sophisticated.
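Looking at the implementation of `argmin`, you'll see that it computes both the values and the indices before returning only the index: https://github.com/google/jax/blob/jax-v0.4.18/jax/_src/lax/lax.py#L3892-L3914
If you want, you can follow this implementation and use `lax.reduce` to define a function that returns both in a single pass: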
```python
import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin_onepass(x):
    # This only works for 1D float arrays, but you could generalize it.
    assert x.ndim == 1
    assert jnp.issubdtype(x.dtype, jnp.floating)
    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
        pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
        return (jnp.where(pick_op_val, op_val, acc_val),
                jnp.where(pick_op_index, op_index, acc_index))
    indices = jnp.arange(len(x))
    return jax.lax.reduce((x, indices), (jnp.inf, 0), reducer, (0,))
```
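The comment in the function above notes that it only handles 1D float arrays. As a sketch of one way to generalize it (the name `min_and_argmin_axis` and its `axis` parameter are hypothetical, not part of JAX), the same reducer can be applied along one axis of an N-dimensional array by building the index array with `lax.broadcasted_iota` and passing `(axis,)` as the reduction dimensions. The init values are given explicit dtypes so they match the operands exactly:

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


@functools.partial(jax.jit, static_argnames=("axis",))
def min_and_argmin_axis(x, axis=-1):
    # Hypothetical generalization: reduce a float array of any rank along `axis`.
    assert jnp.issubdtype(x.dtype, jnp.floating)
    axis = axis % x.ndim  # normalize negative axes

    def reducer(op_val_index, acc_val_index):
        op_val, op_index = op_val_index
        acc_val, acc_index = acc_val_index
        # Prefer the smaller value, and let NaN win, as in the 1D version.
        pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
        # On ties, keep the lower index so the first occurrence wins.
        pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
        return (jnp.where(pick_op_val, op_val, acc_val),
                jnp.where(pick_op_index, op_index, acc_index))

    # Same shape as x; each element holds its own position along `axis`.
    indices = lax.broadcasted_iota(jnp.int32, x.shape, axis)
    init = (jnp.array(jnp.inf, dtype=x.dtype), jnp.array(0, dtype=jnp.int32))
    return lax.reduce((x, indices), init, reducer, (axis,))
```

Because `axis` is marked static, each distinct axis value triggers its own compilation, which is the usual trade-off for shape-determining arguments under `jax.jit`.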
Testing this, we see that it matches the output of the simpler approach:
```python
@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i

x = jax.random.uniform(jax.random.key(0), (1000000,))
print(min_and_argmin_onepass(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
print(min_and_argmin(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
```
If you time the two, you'll see that their runtimes are very similar:
```python
%timeit jax.block_until_ready(min_and_argmin_onepass(x))
# 2.17 ms ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(min_and_argmin(x))
# 2.07 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
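`%timeit` is an IPython magic. In a plain Python script, a comparable measurement can be sketched with the standard-library `timeit` module; the helper name `bench` below is hypothetical, and it reports the fastest run rather than mean ± std. dev. Calling the function once before timing keeps JIT compilation out of the measurement, and `jax.block_until_ready` prevents JAX's asynchronous dispatch from making a call look faster than it is:

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i


def bench(fn, *args, number=100, repeat=7):
    """Hypothetical helper: seconds per call, taken from the fastest of `repeat` runs."""
    jax.block_until_ready(fn(*args))  # warm-up call triggers compilation
    times = timeit.repeat(lambda: jax.block_until_ready(fn(*args)),
                          number=number, repeat=repeat)
    return min(times) / number  # each entry in `times` covers `number` calls


x = jax.random.uniform(jax.random.key(0), (1_000_000,))
print(f"min_and_argmin: {bench(min_and_argmin, x) * 1e3:.3f} ms per call")
```

The same `bench` call can be pointed at `min_and_argmin_onepass` to reproduce the comparison outside IPython; absolute numbers will depend on your hardware and backend.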
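The `jax.jit` decorator here means that the XLA compiler optimizes the sequence of operations in the simpler approach, so you don't gain much by trying to express the computation more cleverly. Given that, I think your best option is to stick with your original code rather than trying to out-optimize the XLA compiler.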
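If you want to see what XLA actually does with the simple version, one option is to inspect the optimized program it compiles. A sketch, assuming a JAX version with the ahead-of-time `lower`/`compile` API (JAX 0.4.x has it): `.lower(x).compile().as_text()` returns the post-optimization HLO as a string, which you can search for `reduce` ops and compare between the two implementations. The exact text varies by backend and JAX version, so treat what you see as specific to your setup:

```python
import jax
import jax.numpy as jnp


@jax.jit
def min_and_argmin(x):
    i = jnp.argmin(x)
    return x[i], i


x = jnp.zeros(1000, dtype=jnp.float32)
# Optimized HLO (after XLA passes such as fusion) for this input shape and backend.
hlo = min_and_argmin.lower(x).compile().as_text()
# Print only the lines that mention reductions to keep the output short.
for line in hlo.splitlines():
    if "reduce" in line:
        print(line.strip())
```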