我正在try 实现快速 routine 来计算能量数组,并找到最小计算值及其指数.以下是我运行良好的代码:

@jit
def findMinEnergy(x):
  def calcEnergy(a):
    return a*a  # very simplified body, it is actually 15 lines of code
  energies = vmap(calcEnergy, in_axes=(0))(x)
  idx = energies.argmin(axis=0)
  minenrgy = energies[idx]
  return idx, minenrgy

我想知道是否有可能不使用(单独的)argmin调用,而是从VMAP返回最小计算的能量值及其索引(类似于其他聚合函数的工作原理,例如jax.sum)?我希望它能更有效率.

推荐答案

如果您对当前的方法进行JIT编译,您会发现它和做一些更复杂的事情一样高效.

查看argmin的实现,您将看到它在仅返回索引之前计算值和索引:https://github.com/google/jax/blob/jax-v0.4.18/jax/_src/lax/lax.py#L3892-L3914

如果需要,可以遵循此实现并使用lax.reduce定义一个函数,该函数在一次传递中同时返回这两个值:

import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin_onepass(x):
  # This only works for 1D float arrays, but you could generalize it.
  assert x.ndim == 1
  assert jnp.issubdtype(x.dtype, jnp.floating)
  def reducer(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
    pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
    return (jnp.where(pick_op_val, op_val, acc_val),
            jnp.where(pick_op_index, op_index, acc_index))
  indices = jnp.arange(len(x))
  return jax.lax.reduce((x, indices), (jnp.inf, 0), reducer, (0,))

对此进行测试,我们发现它与不太复杂的方法的输出相匹配:

@jax.jit
def min_and_argmin(x):
  i = jnp.argmin(x)
  return x[i], i

x = jax.random.uniform(jax.random.key(0), (1000000,))
print(min_and_argmin_onepass(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
print(min_and_argmin(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))

如果将两者的运行时进行比较,您将看到类似的运行时:

%timeit jax.block_until_ready(min_and_argmin_onepass(x))
# 2.17 ms ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(min_and_argmin(x))
# 2.07 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这里的jax.jit修饰符意味着编译器以不太复杂的方法优化操作序列,其结果是您不会从试图更巧妙地表达事物中获得太多好处.鉴于此,我认为您最好的 Select 是坚持使用原始代码,而不是试图超越XLA编译器进行优化.

Python相关问答推荐

如何在Pandas 中存储二进制数?

在Python中添加期货之间的延迟

如何将 map 数组组合到pyspark中每列的单个 map 中

如何观察cv2.erode()的中间过程?

Numpy索引argsorted使用integer数组,同时保留排序顺序

KNN分类器中的GridSearchCV

如何使用Tkinter创建两个高度相同的框架(顶部和底部)?

阅读Polars Python中管道的函数定义

如果索引不存在,pandas系列将通过索引获取值,并填充值

试图找到Python方法来部分填充numpy数组

Pandas 滚动最接近的价值

PMMLPipeline._ fit()需要2到3个位置参数,但给出了4个位置参数

数据抓取失败:寻求帮助

如何在Python脚本中附加一个Google tab(已经打开)

numpy卷积与有效

实现神经网络代码时的TypeError

* 动态地 * 修饰Python中的递归函数

python中csv. Dictreader. fieldname的类型是什么?'

python—telegraph—bot send_voice发送空文件

PYTHON、VLC、RTSP.屏幕截图不起作用