我正在尝试用梯度下降训练一个具有两个输出的模型。因此,我的成本函数会返回两个误差。处理这个问题的典型方法是什么?
我在不少地方看到有人提到这个问题,但我还没有找到一个令人满意的解决方案。
下面是一个重现我的问题的玩具示例:
import jax.numpy as jnp
import optax
from jax import jit, random, grad
@jit
def my_model(forz, params):
    """Evaluate the toy two-output model.

    Args:
        forz: Forcing values, used as the base of both power terms.
        params: Two-element sequence ``(a, b)``.

    Returns:
        A pair ``(a_vect, b_vect * 50.)`` where ``a_vect = a + forz**b``
        and ``b_vect = b + forz**a``.
    """
    a, b = params
    a_vect = a + forz**b
    b_vect = b + forz**a
    # NOTE(review): the factor 50 presumably gives the second output a
    # different scale from the first -- confirm with the author.
    return a_vect, b_vect * 50.
@jit
def rmse(predictions, targets):
    """Return the root-mean-square error between predictions and targets.

    Args:
        predictions: Simulated values.
        targets: Reference values, subtracted elementwise from
            ``predictions``.

    Returns:
        The square root of the mean squared difference.
    """
    squared_error = (predictions - targets) ** 2
    return jnp.sqrt(jnp.mean(squared_error))
@jit
def my_loss(forz, params, true_a, true_b):
    """Compute one RMSE per model output.

    Args:
        forz: Forcing values passed to ``my_model``.
        params: Model parameters passed to ``my_model``.
        true_a: Reference values for the first model output.
        true_b: Reference values for the second model output.

    Returns:
        A pair ``(loss_a, loss_b)``: the RMSE of each simulated output
        against its reference.
    """
    sim_a, sim_b = my_model(forz, params)
    loss_a = rmse(sim_a, true_a)
    loss_b = rmse(sim_b, true_b)
    # Returned as a pair rather than a scalar; jax.grad only accepts
    # scalar-output functions, so this cannot be differentiated directly.
    return loss_a, loss_b
# NOTE: my_loss returns a pair of losses, but jax.grad requires a
# scalar-output function, so calling grad_myloss below raises an error.
grad_myloss = jit(grad(my_loss, argnums=1))

# synthetic true data
key = random.PRNGKey(758493)
# Split the key so the forcing data and the initial parameters are drawn
# from independent random streams instead of reusing one key twice.
data_key, init_key = random.split(key)
forz = random.uniform(data_key, shape=(1000,))
true_params = [8.9, 6.6]
true_a, true_b = my_model(forz, true_params)

# Train
model_params = random.uniform(init_key, shape=(2,))
optimizer = optax.adabelief(1e-1)
opt_state = optimizer.init(model_params)
for _ in range(1000):
    grads = grad_myloss(forz, model_params, true_a, true_b)  # this fails
    updates, opt_state = optimizer.update(grads, opt_state)
    model_params = optax.apply_updates(model_params, updates)
我的理解是,要么必须通过某种归一化把这两个误差聚合成一个单一的损失(我的两个输出向量的单位不可比),例如:
@jit
def normalized_rmse(predictions, targets):
    """Return the RMSE divided by the standard deviation of the targets.

    Args:
        predictions: Simulated values.
        targets: Reference values; their standard deviation is the
            normalizing factor.

    Returns:
        ``rmse / std(targets)``. There is no guard against a zero
        standard deviation (constant ``targets``).
    """
    std_dev_targets = jnp.std(targets)
    root_mean_sq = jnp.sqrt(jnp.mean((predictions - targets) ** 2))
    return root_mean_sq / std_dev_targets
@jit
def my_loss_single(forz, params, true_a, true_b):
    """Combine the two normalized per-output losses into one scalar.

    Args:
        forz: Forcing values passed to ``my_model``.
        params: Model parameters passed to ``my_model``.
        true_a: Reference values for the first model output.
        true_b: Reference values for the second model output.

    Returns:
        The Euclidean norm ``sqrt(loss_a**2 + loss_b**2)`` of the two
        normalized RMSE values.
    """
    sim_a, sim_b = my_model(forz, params)
    loss_a = normalized_rmse(sim_a, true_a)
    loss_b = normalized_rmse(sim_b, true_b)
    # Both terms are squared so the two outputs contribute symmetrically.
    return jnp.sqrt((loss_a ** 2) + (loss_b ** 2))
或者,我应该以某种方式使用雅可比矩阵(jacrev)?