Problem:我有S个序列,每个序列包含T个时间步,每个时间步包含F个特征,所以总的来说,一个

Goal:使用LSTMs对架构进行建模/训练,以学习/实现函数逼近器模型M,并给出序列s,以预测Target_1Target_2

比如:

M(s) ~ (Target_1, Target_2)

我真的很难找到一种方法,下面是Keras实现的一个例子,可能不起作用.我制作了两个模型,一个用于第一个目标值,一个用于第二个目标值.

model1 = Sequential()
model1.add(Masking(mask_value=-10.0))
model1.add(LSTM(1, input_shape=(batch, timesteps, features), return_sequences = True))  
model1.add(Flatten())
model1.add(Dense(hidden_units, activation = "relu"))
model1.add(Dense(1, activation = "linear"))
model1.compile(loss='mse', optimizer=Adam(learning_rate=0.0001))
model1.fit(x_train, y_train[:,0], validation_data=(x_test, y_test[:,0]), epochs=epochs, batch_size=batch, shuffle=False)

model2 = Sequential()
model2.add(Masking(mask_value=-10.0))
model2.add(LSTM(1, input_shape=(batch, timesteps, features), return_sequences=True))
model2.add(Flatten())
model2.add(Dense(hidden_units, activation = "relu"))
model2.add(Dense(1, activation = "linear"))
model2.compile(loss='mse', optimizer=Adam(learning_rate=0.0001))
model2.fit(x_train, y_train[:,1], validation_data=(x_test, y_test[:,1]), epochs=epochs, batch_size=batch, shuffle=False)

我想以某种方式充分利用LSTMs时间相关记忆,以实现良好的回归.

推荐答案

IIUC,您可以使用两个输出层,从一个简单(天真)的方法开始:

import tensorflow as tf

timesteps, features = 20, 5
inputs = tf.keras.layers.Input((timesteps, features))
x = tf.keras.layers.Masking(mask_value=-10.0)(inputs)
x = tf.keras.layers.LSTM(32, return_sequences=False)(x)
x = tf.keras.layers.Dense(32, activation = "relu")(x)
output1 = Dense(1, activation = "linear", name='output1')(x)
output2 = Dense(1, activation = "linear", name='output2')(x)

model = tf.keras.Model(inputs, [output1, output2])
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001))

x_train = tf.random.normal((500, timesteps, features))
y_train = tf.random.normal((500, 2))
model.fit(x_train, [y_train[:,0],y_train[:,1]] , epochs=5, batch_size=32, shuffle=False)

Python相关问答推荐

跳过包含某些键的字典

用ctype构建指针链

使用子字符串动态更新Python DataFrame中的列

在Python中,如何才能/应该使用decorator 来实现函数多态性?

如何观察cv2.erode()的中间过程?

从 struct 类型创建MultiPolygon对象,并使用Polars列出[list[f64]列

"Discord机器人中缺少所需的位置参数ctx

如何在Python中使用ijson解析SON期间检索文件位置?

将轨迹优化问题描述为NLP.如何用Gekko解决这个问题?当前面临异常:@错误:最大方程长度错误

从webhook中的短代码(而不是电话号码)接收Twilio消息

运行Python脚本时,用作命令行参数的SON文本

有没有一种方法可以从python的pussompy比较结果中提取文本?

cv2.matchTemplate函数匹配失败

为一个组的每个子组绘制,

joblib:无法从父目录的另一个子文件夹加载转储模型

不能使用Gekko方程'

Polars asof在下一个可用日期加入

python中csv. Dictreader. fieldname的类型是什么?'

剪切间隔以添加特定日期

如何获取Python synsets列表的第一个内容?