当样本大小设置为10时,直到收敛的平均迭代次数应该在15左右.然而,当在我的代码中实现该算法时,大约需要225次(或更多)迭代才能达到收敛.这使我怀疑代码中的While循环可能有问题,但我无法识别它.

def gen_data(N=10):
    size = (N, 2)
    data = np.random.uniform(-1, 1, size)
    point1, point2 = data[np.random.choice(data.shape[0], 2, replace=False), :]
    m = (point2[1] - point1[1]) / (point2[0] - point1[0])
    c = point1[1] - m * point1[0]
    labels = np.array([+1 if y >= m * x + c else -1 for x, y in data])
    data = np.column_stack((data, labels))
    return data, point1, point2


class PLA:
    def __init__(self, data):
        m, n = data.shape
        self.X = np.hstack((np.ones((m, 1)), data[:, :2]))
        self.w = np.zeros(n)
        self.y = data[:, -1]
        self.count = 0

    def fit(self):
        while True:
            self.count += 1
            y_pred = self.predict(self.X)
            misclassified = np.where(y_pred != self.y)[0]
            if len(misclassified) == 0:
                break

            idx = np.random.choice(misclassified)
            self.update_weight(idx)

    def update_weight(self, idx):
        self.w +=  self.y[idx] * self.X[idx]

    def sign(self, z):
        return np.where(z > 0, 1, np.where(z < 0, -1, 0))

    def predict(self, x):
        z = np.dot(x, self.w)
        return self.sign(z)

推荐答案

问题不在于您的右循环,而在于您的数据生成函数.

你从N个随机点中 Select 两个点来定义你的决策线:

point1, point2 = data[np.random.choice(data.shape[0], 2, replace=False), :]

然而,它们会留在您的数据集中,因此它们被标记为1,并且恰好在您的决策线上.

如果您 Select 了两个不在您的数据集中的随机点,那么该算法应该会在与我测试的结果大致相同的10个步骤中收敛(只需采样N + 2个点,然后 Select 前两个来定义您的决策线,其他两个来定义您的数据集).

So why is this small difference slowing that much the number of steps needed to converge ?

我要说的是,由于数据集中的两个点在决策线上,学习零误差模型可能是最困难的,特别是如果其他点离它很近,因为一次模型更新可能会导致仍然不完美的模型.

Easy case

Hard case

Is it relevent to define your decision line such that no point in your dataset is on it ?

我会说是的,因为域空间是连续的.

Python相关问答推荐

DuckDB将蜂巢分区插入拼花文件

如何防止Plotly在输出到PDF时减少行中的点数?

Python 3.12中的通用[T]类方法隐式类型检索

我在使用fill_between()将最大和最小带应用到我的图表中时遇到问题

Pandas实际上如何对基于自定义的索引(integer和非integer)执行索引

Pystata:从Python并行运行stata实例

Deliveryter Notebook -无法在for循环中更新matplotlib情节(保留之前的情节),也无法使用动画子功能对情节进行动画

数据抓取失败:寻求帮助

无法定位元素错误404

如何请求使用Python将文件下载到带有登录名的门户网站?

如何在Python数据框架中加速序列的符号化

连接一个rabrame和另一个1d rabrame不是问题,但当使用[...]'运算符会产生不同的结果

与命令行相比,相同的Python代码在Companyter Notebook中运行速度慢20倍

dask无groupby(ddf. agg([min,max])?''''

递归函数修饰器

如何将相同组的值添加到嵌套的Pandas Maprame的倒数第二个索引级别

来自Airflow Connection的额外参数

如何在SQLAlchemy + Alembic中定义一个"Index()",在基表中的列上

随机森林n_估计器的计算

BeatuifulSoup从欧洲志愿者服务中获取数据和解析:一个从EU-Site收集机会的小铲子