我在这里看到了同样的问题,它帮助我走到了这一步,但我没有得到正确的结果.

我用数据点x和y以及模型ypred=a*x+b进行了线性回归.我需要设置a=10并计算MSE,这很好用.但是,通过将a递减0.1到0,并判断可能的最低MSE,我在遍历代码时遇到了麻烦.我也必须对b重复同样的事情,这是我有点迷茫的.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

data = pd.read_csv('dataset.csv')

#x = [0., 0.05263158, 0.10526316, 0.15789474, 0.21052632,
      #0.26315789, 0.31578947, 0.36842105, 0.42105263, 0.47368421,
      #0.52631579, 0.57894737, 0.63157895, 0.68421053, 0.73684211,
      #0.78947368, 0.84210526, 0.89473684, 0.94736842, 1.]
#y = [0.49671415, 0.01963044, 0.96347801, 1.99671407, 0.39742557,
      #0.55533673, 2.52658124, 1.87269789, 0.79368351, 1.96361268,
      #1.11552968, 1.27111235, 2.13669911, 0.13935133, 0.48560848,
      #1.80613352, 1.51348467, 2.99845786, 1.93408119, 1.5876963]

x = data.x
y = data.y


plt.scatter(data.x, data.y)
plt.show()


a = 10 
b = 0

for y in x:
   ypred = a*x+b

#print(ypred)

ytrue = data.y

MSE = np.square(np.subtract(ytrue,ypred)).mean()

print (MSE)
#21.3
a = 10
ytrue = data.y           
tmp_MSE = np.infty 
tmp_a = a            
for i in range(100):
   ytrue = a-0.1*(i+1)
   MSE = np.square(np.subtract(ypred,ytrue)).mean()
   if MSE < tmp_MSE: 
       tmp_MSE = MSE 
       tmp_a = ytrue

print(tmp_a,tmp_MSE)

没有错误,但我没有得到正确的结果,我哪里错了?

推荐答案

I see you're iterating through all possible combination of a and b to get the minimum MSE.
Here's a possible solution:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

x = [0., 0.05263158, 0.10526316, 0.15789474, 0.21052632,
      0.26315789, 0.31578947, 0.36842105, 0.42105263, 0.47368421,
      0.52631579, 0.57894737, 0.63157895, 0.68421053, 0.73684211,
      0.78947368, 0.84210526, 0.89473684, 0.94736842, 1.]
y = [0.49671415, 0.01963044, 0.96347801, 1.99671407, 0.39742557,
      0.55533673, 2.52658124, 1.87269789, 0.79368351, 1.96361268,
      1.11552968, 1.27111235, 2.13669911, 0.13935133, 0.48560848,
      1.80613352, 1.51348467, 2.99845786, 1.93408119, 1.5876963]

# data = pd.read_csv('dataset.csv')
data = pd.DataFrame({'x': x, 'y': y})
x = data.x
y = data.y

plt.scatter(data.x, data.y)
plt.show()

a = 10 
b = 0
ypred = a*x + b    #this is a series
ytrue = data.y
MSE = np.square(np.subtract(ytrue,ypred)).mean()
print (MSE)
#21.3

ytrue = data.y
min_MSE = np.infty

for a in np.arange(10, 0, -0.1):
    for b in np.arange(10, 0, -0.1):
        ypred = a*x + b    #this is a series
        MSE = np.square(np.subtract(ypred,ytrue)).mean()
        if MSE < min_MSE: 
            min_MSE = MSE 
            min_a = a
            min_b = b

print('min_a =', round(min_a, 3))
print('min_b =', round(min_b, 3))
print('min_MSE =', round(min_MSE, 3))

输出:

enter image description here

21.306499412264095
min_a = 1.1
min_b = 0.8
min_MSE = 0.546

Edit:如果您想要更高级别的精确度,可以运行以下代码:

def find_min(a_range, b_range):
    min_MSE = np.infty
    for a in a_range:
        for b in b_range:
            ypred = a*x + b    #this is a series
            MSE = np.square(np.subtract(ypred,ytrue)).mean()
            if MSE < min_MSE: 
                min_MSE = MSE 
                min_a = a
                min_b = b
    return min_a, min_b, min_MSE
            
min_a, min_b, min_MSE = find_min(np.arange(10, 0, -0.1), np.arange(10, 0, -0.1))
min_a, min_b, min_MSE = find_min(np.arange(min_a+0.1, min_a-0.1, -0.001), np.arange(min_b+0.1, min_b-0.1, -0.001))   
print('min_a =', round(min_a, 3))
print('min_b =', round(min_b, 3))
print('min_MSE =', round(min_MSE, 3))

输出:

min_a = 1.109
min_b = 0.774
min_MSE = 0.546

Python相关问答推荐

Python tkinter关闭第一个窗口,同时打开第二个窗口

使用Python和PRNG(不是梅森龙卷风)有效地生成伪随机浮点数在[0,1)中均匀?

如何使用Selenium访问svg对象内部的元素

当值是一个integer时,在Python中使用JMESPath来验证字典中的值(例如:1)

从今天起的future 12个月内使用Python迭代

从包含数字和单词的文件中读取和获取数据集

Python中MongoDB的BSON时间戳

根据条件将新值添加到下面的行或下面新创建的行中

在Python Attrs包中,如何在field_Transformer函数中添加字段?

用Python解密Java加密文件

Julia CSV for Python中的等效性Pandas index_col参数

无法定位元素错误404

如何在Python脚本中附加一个Google tab(已经打开)

Pandas—合并数据帧,在公共列上保留非空值,在另一列上保留平均值

numpy卷积与有效

在Python中动态计算范围

当递归函数的返回值未绑定到变量时,非局部变量不更新:

提取相关行的最快方法—pandas

如何从需要点击/切换的网页中提取表格?

python—telegraph—bot send_voice发送空文件