我有4个通道的时间序列数据,并试图使用我的模型生成长度为N的序列.我从提供给我的序列生成函数的输入数据中确定N:

def generate_sequence(self, input_data):
    predicted_sequence = tf.convert_to_tensor(input_data, dtype=tf.float32)
    data_shape = predicted_sequence.shape
    for i in range(len(predicted_sequence)):
        model_input = tf.reshape(predicted_sequence, shape=data_shape)
        result = self.model(model_input)
        predicted_sequence = tf.concat([predicted_sequence[:, 1:, :], result], 0)
    return predicted_sequence

这会导致以下错误:

ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [1,1439,4] vs. shape[1] = [1,1,4] [Op:ConcatV2] name: concat

这似乎表明我使用了错误的方法来生成我的序列(我天真地编写了这个函数,假设tensorflow张量的行为类似numpy数组).在我的循环中,我从输入数据开始:

[[[a1, b1, c1, d1]
  [a2, b2, c2, d2]
  ...
  [aN, bN, cN, dN]]

然后我用我的模型生成一个预测

[[aP1, bP1, cP1, dP1]]

这里我的意图是删除输入数据中的第一个条目,因为它是最老的一行数据,并将预测数据添加到最后:

[[[a2, b2, c2, b2]
  [a3, b3, c3, d3]
  ...
  [aN, bN, cN, dN]
  [aP1, bP1, cP1, dP1]]]

从这里运行循环,直到整个序列包含下N行数据的预测.

有没有另一种tensorflow方法更适合这个问题,或者我在tf.concat方法中遗漏了什么?

任何帮助将不胜感激.

推荐答案

一切都是好的,除了连接轴.应该是axis=1.

def generate_sequence(self, input_data):
predicted_sequence = tf.convert_to_tensor(input_data, dtype=tf.float32)
data_shape = predicted_sequence.shape
for i in range(len(predicted_sequence)):
    model_input = tf.reshape(predicted_sequence, shape=data_shape)
    result = self.model(model_input)
    predicted_sequence = tf.concat([predicted_sequence[:, 1:, :], result], axis=1)
return predicted_sequence

经验法则是"all the tensors should possess the same shape in all the axes except the concatenating axis"

Python相关问答推荐

大Pandas 胚胎中产生组合

当多个值具有相同模式时返回空

如何避免Chained when/then分配中的Mypy不兼容类型警告?

运行终端命令时出现问题:pip start anonymous"

为什么这个带有List输入的简单numba函数这么慢

如何在Django基于类的视图中有效地使用UTE和RST HTIP方法?

将tdqm与cx.Oracle查询集成

海上重叠直方图

未知依赖项pin—1阻止conda安装""

如何并行化/加速并行numba代码?

什么是合并两个embrame的最佳方法,其中一个有日期范围,另一个有日期没有任何共享列?

如何从列表框中 Select 而不出错?

如何按row_id/row_number过滤数据帧

如何在Python中使用Iscolc迭代器实现观察者模式?

在Python中控制列表中的数据步长

为用户输入的整数查找根/幂整数对的Python练习

504未连接IB API TWS错误—即使API连接显示已接受''

如何写一个polars birame到DuckDB

比较两个有条件的数据帧并删除所有不合格的数据帧

Polars定制函数返回多列