我正在使用Flax训练神经网络.我的训练数据在输出中有大量的nan.我想忽略这些,只使用非nan值进行训练.为了实现这一点,我try 使用jnp.nanmean来计算损失,即:

def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    return jnp.nanmean((pred - targets) ** 2)

def train_step(state, inputs, targets):
    loss, grads = jax.value_and_grad(nanloss)(state.params, inputs, targets)
    state = state.apply_gradients(grads=grads)
    return state, loss

然而,经过一个训练步骤,损失就没有了.

我正在努力实现的目标是否可能?如果是这样,我该如何解决这个问题?

推荐答案

我怀疑您正在解决这里讨论的问题:JAX FAQ: gradients contain NaN where using where.您已经在计算本身中处理了NaN,但由于autodiff的实现方式,它们仍然潜入梯度中.

如果这确实是问题所在,您可以通过在计算损失之前过滤值来解决这个问题;例如:

def nanloss(params, inputs, targets):
    pred = model.apply(params, inputs)
    mask = jnp.isnan(pred) | jnp.isnan(targets)
    pred = jnp.where(mask, 0, pred)
    targets = jnp.where(mask, 0, targets)
    return jnp.mean((pred - targets) ** 2, where=~mask)

Python相关问答推荐

如何在Python中使用io.BytesIO写入现有缓冲区?

使用Keras的线性回归参数估计

更改matplotlib彩色条的字体并勾选标签?

try 与gemini-pro进行多轮聊天时出错

如何让Flask 中的请求标签发挥作用

为什么sys.exit()不能与subproccess.run()或subprocess.call()一起使用

在Mac上安装ipython

所有列的滚动标准差,忽略NaN

Python全局变量递归得到不同的结果

幂集,其中每个元素可以是正或负""""

Python—为什么我的代码返回一个TypeError

如何在Great Table中处理inf和nans

使用__json__的 pyramid 在客户端返回意外格式

解决Geopandas和Altair中的正图和投影问题

Pandas在rame中在组内洗牌行,保持相对组的顺序不变,

Python:从目录内的文件导入目录

Pandas:计数器的滚动和,复位

Regex用于匹配Python中逗号分隔的AWS区域

.awk文件可以使用子进程执行吗?

无法使用请求模块从网页上抓取一些产品的名称