TL;DR:我为我的equinox.Module
型号创建了一个新实例,并使用OpTax对其进行了调整.一切都运行得很好.当我创建同一模型的new个实例,并try 使用相同的代码、相同的初始值、相同的一切从头开始适应它时,我得到:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
...在税法深处的某个地方.我的代码不比较任何array.错误消息不会显示比较的确切位置.怎么了?
代码
# 1. Import dependencies.
import jax; jax.config.update("jax_enable_x64", True)
import jax.numpy as np, jax.random as rnd, equinox as eqx
import optax
# 2. Define loss function. I'm fairly confident this is correct.
def npdf(x, var):
return np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var)
def mixpdf(x, ps, vars):
return ps.dot(npdf(x, vars))
def loss(model, series):
weights, condvars = model(series)
return -jax.vmap(
lambda x, vars: np.log(mixpdf(x, weights, vars))
)(series[1:], condvars[:-1]).mean()
# 3. Define recurrent neural network.
class RNNCell(eqx.Module):
bias: np.ndarray
Wx: np.ndarray
Wh: np.ndarray
def __init__(self, ncomp: int, n_in: int=1, *, key: np.ndarray):
k1, k2, k3 = rnd.split(key, 3)
self.bias = rnd.uniform(k1, (ncomp, ))
self.Wx = rnd.uniform(k2, (ncomp, n_in))
self.Wh = 0.9 * rnd.uniform(k3, (ncomp, ))
def __call__(self, vars_prev, obs):
vars_new = self.bias + self.Wx @ obs + self.Wh * vars_prev
return vars_new, vars_new
class RNN(eqx.Module):
cell: RNNCell
logits: np.ndarray
vars0: np.ndarray = eqx.field(static=True)
def __init__(self, vars0: np.ndarray, n_in=1, *, key: np.ndarray):
self.vars0 = np.array(vars0)
K = len(self.vars0)
self.cell = RNNCell(K, n_in, key=key)
self.logits = np.zeros(K)
def __call__(self, series: np.ndarray):
_, hist = jax.lax.scan(self.cell.__call__, self.vars0, series**2)
return jax.nn.softmax(self.logits), abs(hist)
def condvar(self, series):
weights, variances = self(series)
return variances @ weights
def predict(self, series: np.ndarray):
return self.condvar(series).flatten()[-1]
# 4. Training/fitting code.
def fit(model, logret, nepochs: int, optimizer, loss):
loss_and_grad = eqx.filter_value_and_grad(loss)
@eqx.filter_jit
def make_step(model, opt_state):
loss_val, grads = loss_and_grad(model, logret)
updates, opt_state = optimizer.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return loss_val, model, opt_state
opt_state = optimizer.init(model)
for epoch in range(nepochs):
loss_val, model, opt_state = make_step(model, opt_state)
print("Works!")
return model
def experiment():
series = rnd.normal(rnd.PRNGKey(8), (100, 1))
model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
return fit(model, series, 100, optax.adam(0.01), loss)
# 5. Run the exact same code twice.
experiment() # 1st call, works
experiment() # 2nd call, error
错误讯息
> python my_RNN.py
Works!
Traceback (most recent call last):
File "/Users/forcebru/test/my_RNN.py", line 75, in <module>
experiment() # 2nd call, error
^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 72, in experiment
return fit(model, series, 100, optax.adam(0.01), loss)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 65, in fit
loss_val, model, opt_state = make_step(model, opt_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
return self._call(False, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_module.py", line 935, in __call__
return self.__func__(self.__self__, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
infer_params_fn(*args, **kwargs)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 495, in common_infer_params
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1150, in _pjit_jaxpr
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1089, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2314, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
out = fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 59, in make_step
updates, opt_state = optimizer.update(grads, opt_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/combine.py", line 59, in update_fn
updates, new_s = fn(updates, s, params, **extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/base.py", line 337, in update
return tx.update(updates, state, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/transform.py", line 369, in update_fn
mu_hat = bias_correction(mu, b1, count_inc)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
infer_params_fn(*args, **kwargs)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<string>", line 4, in __eq__
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 745, in __bool__
check_bool_conversion(self)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 662, in check_bool_conversion
raise ValueError("The truth value of an array with more than one element is "
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
问题
- 错误消息显示为
File "<string>", line 4, in __eq__
,这无济于事. - 它指的是操作税代码中的第
mu_hat = bias_correction(mu, b1, count_inc)
行,但据我所知,它是doesn't compare any arrays行. - 它还引用了应该负责JIT编译的JAX代码,但这似乎超出了我的控制范围.
我的模型定义(RNNCell
或RNN
)中有错误吗?我是不是执行了错误的训练循环?我基本上是直接从Equinox docs开始复制的,所以应该没问题.为什么当我第一次拨打experiment()
时,它能起作用,而第二次却不行?