TL;DR: I created an instance of my equinox.Module model and fitted it with Optax. Everything worked fine. When I create a *new* instance of the same model and try to fit it from scratch with the same code, same initial values, same everything, I get:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

...somewhere deep inside Optax. My code doesn't compare any arrays, and the error message doesn't show where exactly the comparison happens. What's going on?

Code

# 1. Import dependencies.
import jax; jax.config.update("jax_enable_x64", True)
import jax.numpy as np, jax.random as rnd, equinox as eqx
import optax

# 2. Define loss function. I'm fairly confident this is correct.
def npdf(x, var):
    return np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var)

def mixpdf(x, ps, vars):
    return ps.dot(npdf(x, vars))

def loss(model, series):
    weights, condvars = model(series)
    return -jax.vmap(
        lambda x, vars: np.log(mixpdf(x, weights, vars))
    )(series[1:], condvars[:-1]).mean()

# 3. Define recurrent neural network.
class RNNCell(eqx.Module):
    bias: np.ndarray
    Wx: np.ndarray
    Wh: np.ndarray
    def __init__(self, ncomp: int, n_in: int=1, *, key: np.ndarray):
        k1, k2, k3 = rnd.split(key, 3)
        self.bias = rnd.uniform(k1, (ncomp, ))
        self.Wx = rnd.uniform(k2, (ncomp, n_in))
        self.Wh = 0.9 * rnd.uniform(k3, (ncomp, ))

    def __call__(self, vars_prev, obs):
        vars_new = self.bias + self.Wx @ obs + self.Wh * vars_prev
        return vars_new, vars_new

class RNN(eqx.Module):
    cell: RNNCell
    logits: np.ndarray
    vars0: np.ndarray = eqx.field(static=True)

    def __init__(self, vars0: np.ndarray, n_in=1, *, key: np.ndarray):
        self.vars0 = np.array(vars0)
        K = len(self.vars0)
        self.cell = RNNCell(K, n_in, key=key)
        self.logits = np.zeros(K)

    def __call__(self, series: np.ndarray):
        _, hist = jax.lax.scan(self.cell.__call__, self.vars0, series**2)
        return jax.nn.softmax(self.logits), abs(hist)

    def condvar(self, series):
        weights, variances = self(series)
        return variances @ weights

    def predict(self, series: np.ndarray):
        return self.condvar(series).flatten()[-1]

# 4. Training/fitting code.
def fit(model, logret, nepochs: int, optimizer, loss):
    loss_and_grad = eqx.filter_value_and_grad(loss)
    
    @eqx.filter_jit
    def make_step(model, opt_state):
        loss_val, grads = loss_and_grad(model, logret)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss_val, model, opt_state

    opt_state = optimizer.init(model)
    for epoch in range(nepochs):
        loss_val, model, opt_state = make_step(model, opt_state)
    print("Works!")
    return model

def experiment():
    series = rnd.normal(rnd.PRNGKey(8), (100, 1))
    model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
    return fit(model, series, 100, optax.adam(0.01), loss)

# 5. Run the exact same code twice.
experiment() # 1st call, works
experiment() # 2nd call, error

Error message

> python my_RNN.py
Works!
Traceback (most recent call last):
  File "/Users/forcebru/test/my_RNN.py", line 75, in <module>
    experiment() # 2nd call, error
    ^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 72, in experiment
    return fit(model, series, 100, optax.adam(0.01), loss)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 65, in fit
    loss_val, model, opt_state = make_step(model, opt_state)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_module.py", line 935, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
    infer_params_fn(*args, **kwargs)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 495, in common_infer_params
    jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
                                                                    ^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1150, in _pjit_jaxpr
    jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
                                                   ^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1089, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2314, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
    out = fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 59, in make_step
    updates, opt_state = optimizer.update(grads, opt_state)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/combine.py", line 59, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/base.py", line 337, in update
    return tx.update(updates, state, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/transform.py", line 369, in update_fn
    mu_hat = bias_correction(mu, b1, count_inc)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
    infer_params_fn(*args, **kwargs)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
    canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in __eq__
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 745, in __bool__
    check_bool_conversion(self)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 662, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element is "
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Questions

  • The error message says File "<string>", line 4, in __eq__, which isn't helpful.
  • It points at the line mu_hat = bias_correction(mu, b1, count_inc) in the Optax code, but as far as I can tell, that line doesn't compare any arrays.
  • It also references the JAX code responsible for JIT compilation, which seems outside my control.

Is there a mistake in my model definition (RNNCell, RNN)? Did I implement the training loop incorrectly? I basically copied it straight from the Equinox docs, so it should be fine. Why does experiment() work the first time I call it, but not the second?

Answer

This appears to be a bug in equinox. The function _process_in_axis_resources is decorated with functools.lru_cache, which means its inputs are checked for equality against the arguments of previous calls. On the second run, this triggers a call to equinox.Module.__eq__, which raises the error. You can see the problem by performing the equality check directly:

model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
model2 = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
model == model2
# ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
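The caching mechanism can be reproduced without JAX at all: functools.lru_cache keys its cache on the arguments, and on a repeat call the dict lookup invokes __eq__ on the stored argument. A minimal stdlib-only sketch (the Tracker class is hypothetical, purely for illustration):

```python
import functools

class Tracker:
    """Hypothetical stand-in for a module whose __eq__ misbehaves on arrays."""
    def __init__(self, v):
        self.v = v
    def __hash__(self):
        return hash(self.v)      # equal hashes force a key comparison on lookup
    def __eq__(self, other):
        # This is the call site where an array-valued field would raise
        # "The truth value of an array ... is ambiguous".
        print("__eq__ called")
        return self.v == other.v

@functools.lru_cache(maxsize=None)
def cached_fn(x):
    return x.v * 2

cached_fn(Tracker(1))  # first call: cache miss, no comparison needed
cached_fn(Tracker(1))  # second call: equal hash, so the lookup calls __eq__
```

This mirrors what happens here: the first experiment() populates the cache, and only the second one triggers the equality check.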

I'd suggest reporting this as a bug at https://github.com/patrick-kidger/equinox/issues

You can work around it by not storing an array (vars0) as a static attribute. I suspect Equinox assumes all static attributes are hashable, and arrays are not.

Edit: I just checked, and changing this:

vars0: np.ndarray = eqx.field(static=True)

to this:

vars0: np.ndarray

fixes the problem.

Edit 2: it actually looks like static fields in Equinox are required to be hashable, so this isn't an Equinox bug but a usage error (see the discussion at https://github.com/patrick-kidger/equinox/issues/154#issuecomment-1561735995). You could try storing vars0 as a tuple (which is hashable) instead of an array (which is not).
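To illustrate the difference, here is a small sketch using plain NumPy as a stand-in (JAX arrays are unhashable in the same way): a tuple serves as the hashable static value and is converted back to an array wherever numerics are needed, e.g. via np.asarray(self.vars0) inside __call__.

```python
import numpy as onp  # plain NumPy stand-in; jax.numpy arrays behave the same way

vars0_array = onp.array([0.4, 0.6, 0.8])
try:
    hash(vars0_array)                 # arrays are unhashable...
except TypeError as err:
    print("array:", err)

vars0_tuple = (0.4, 0.6, 0.8)         # ...tuples are hashable
print("tuple hashes fine:", isinstance(hash(vars0_tuple), int))

# Convert back to an array at the point of use, e.g. inside RNN.__call__:
#     vars0 = np.asarray(self.vars0)
restored = onp.asarray(vars0_tuple)
print("round trip matches:", bool((restored == vars0_array).all()))
```

With vars0 stored as a tuple, the static field satisfies Equinox's hashability requirement and the cached equality check no longer hits ambiguous array comparisons.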
