我正在try 训练一个具有两个带有梯度下降的输出的模型.因此,我的成本函数返回两个错误.处理这个问题的典型方法是什么?

我看到这里和那里提到这个问题,但我还没有想出一个令人满意的解决方案.

这是一个重现我的问题的玩具例子:

from jax import jit, random, grad
import optax


@jit
def my_model(forz, params):
    a, b = params

    a_vect = a + forz**b
    b_vect = b + forz**a

    return a_vect, b_vect*50.


@jit
def rmse(predictions, targets):

    rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return rmse


@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a, loss_b


grad_myloss = jit(grad(my_loss, argnums=1))

# synthetic true data
key = random.PRNGKey(758493)
forz = random.uniform(key, shape=(1000,))

true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)

# Train
model_params = random.uniform(key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)

for i in range(1000):

    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)

我明白,要么这两个错误必须以某种方式聚合到一个单一的实现某种归一化的损失(我的输出向量有不可比较的单位),

@jit
def normalized_rmse(predictions, targets):
   std_dev_targets = jnp.std(targets)
   rmse = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
   return rmse/std_dev_targets


@jit
def my_loss_single(forz, params, true_a, true_b):

   sim_a, sim_b = my_model(forz, params)

   loss_a = normalized_rmse(sim_a, true_a)
   loss_b = normalized_rmse(sim_b, true_b)

   return jnp.sqrt((loss_a ** 2) + (loss_b * 2)) 

或者我应该以某种方式使用雅可比矩阵(jacrev)?

推荐答案

像大多数优化框架一样,optax只能优化单值损失函数.你应该决定什么样的单一价值损失对你的特定问题是有意义的.给出个人损失的均方根形式,一个好的 Select 可能是平方和:

@jit
def my_loss(forz, params, true_a, true_b):

    sim_a, sim_b = my_model(forz, params)

    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)

    return loss_a ** 2 + loss_b ** 2

进行此更改后,您的代码执行时不会出现错误.

Python相关问答推荐

大Pandas 胚胎中产生组合

替换字符串中的多个重叠子字符串

标题:如何在Python中使用嵌套饼图可视化分层数据?

2D空间中的反旋算法

Python虚拟环境的轻量级使用

ODE集成中如何终止solve_ivp的无限运行

如果条件不满足,我如何获得掩码的第一个索引并获得None?

如何从数据库上传数据到html?

多处理队列在与Forking http.server一起使用时随机跳过项目

如何在Pyplot表中舍入值

Python Tkinter为特定样式调整所有ttkbootstrap或ttk Button填充的大小,适用于所有主题

OpenGL仅渲染第二个三角形,第一个三角形不可见

为什么在FastAPI中创建与数据库的连接时需要使用生成器?

ModuleNotFoundError:没有模块名为x时try 运行我的代码''

如何在Python请求中组合多个适配器?

语法错误:文档. evaluate:表达式不是合法表达式

在极点中读取、扫描和接收有什么不同?

为什么后跟inplace方法的`.rename(Columns={';b';:';b';},Copy=False)`没有更新原始数据帧?

对数据帧进行分组,并按组间等概率抽样n行

将像素信息写入文件并读取该文件