我正在使用Flax训练神经网络.我的训练数据在输出中有大量的nan.我想忽略这些,只使用非nan值进行训练.为了实现这一点,我try 使用jnp.nanmean
来计算损失,即:
def nanloss(params, inputs, targets):
pred = model.apply(params, inputs)
return jnp.nanmean((pred - targets) ** 2)
def train_step(state, inputs, targets):
loss, grads = jax.value_and_grad(nanloss)(state.params, inputs, targets)
state = state.apply_gradients(grads=grads)
return state, loss
然而,经过一个训练步骤,损失就没有了.
我正在努力实现的目标是否可能?如果是这样,我该如何解决这个问题?