我的数据集变得非常大,因此我无法使用典型的OLS方法来计算线性回归估计量,因此我想使用典型的优化器(Adam似乎很适合)

我知道我可以用Keras相当简单地做到这一点,请参阅下面的例子

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.optimizers import Adam

    # Define the model
    def build_model(input_dim):
        model = Sequential()
        # Using a smaller standard deviation for the normal initializer
        model.add(Dense(1, input_dim=input_dim, kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05), activation='linear'))
        # Increased learning rate
        optimizer = Adam(learning_rate=0.1)
        model.compile(loss='mse', optimizer=optimizer, metrics=['mse'])
        return model

    # Example usage:
    X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]], dtype=float)
    y = np.array([3, 5, 7, 9, 11], dtype=float)

    # Build and train the model
    model = build_model(input_dim=2)
    model.fit(X, y, epochs=1000, verbose=0, batch_size=5)  # Reduced number of epochs and batch size

    # Make predictions
    predictions = model.predict(X)
    print("Predictions:", predictions.flatten())

    # Output the model summary to check the structure
    model.summary()
    model.get_weights()

然而,我的问题是,即使在1000个纪元之后,它仍然没有收敛到明显的1,1权重,大约是1.15 / 0.85

Adam对于这个例子来说不是一个很好的优化者,或者我做错了什么-我记得前段时间玩过Singapore,我记得当时在lineg问题上交谈得非常快.这对我来说有点令人担忧,因为我需要在一个超过1,000,000 x 100的矩阵上运行它,并且在那里运行1000个纪元将永远需要.

推荐答案

问题在于您对训练数据的 Select .您的数据的形式为(x1, x2),但在所有训练示例中为x2 == x1 + 1.所以你实际上只有一个输入,但有两个权重加上一个偏差,从而导致无限多个解决方案.你要学习的功能基本上是2 * x1 + 1.但由于您有两个权重,因此有不同的方法来分割它,例如

  • w=(0.9, 0.1), b=0.1
  • w=(0.7, 0.3), b=0.3
  • 等等,也许你已经可以看到模式了.

由于一种解决方案并不比另一种解决方案"更好",因此它没有理由收敛到"显而易见"的解决方案.可能的修复:

  • 使用更好的 Select 不存在此问题的训练数据.
  • 在密集层中设置use_bias=False,强制偏差为0,在这种情况下,w只有一个解决方案.

如果您想阅读更多信息--您的数据显示Multicolinearity.

Python相关问答推荐

在Arrow上迭代的快速方法.Julia中包含3000万行和25列的表

模型序列化器中未调用现场验证器

跟踪我已从数组中 Select 的样本的最有效方法

使用polars .滤镜进行切片速度比pandas .loc慢

Pandas 填充条件是另一列

如何使用matplotlib在Python中使用规范化数据和原始t测试值创建组合热图?

追溯(最近最后一次调用):文件C:\Users\Diplom/PycharmProject\Yolo01\Roboflow-4.py,第4行,在模块导入roboflow中

Pandas - groupby字符串字段并按时间范围 Select

计算组中唯一值的数量

无法定位元素错误404

如何在Python数据框架中加速序列的符号化

梯度下降:简化要素集的运行时间比原始要素集长

如何从需要点击/切换的网页中提取表格?

在嵌套span下的span中擦除信息

并行编程:同步进程

Pandas—堆栈多索引头,但不包括第一列

从嵌套极轴列的列表中删除元素

使用np.fft.fft2和cv2.dft重现相位谱.为什么结果并不相似呢?

替换包含Python DataFrame中的值的<;

如何在Pandas中用迭代器求一个序列的平均值?