我从 d2l (Dive into Deep Learning) 网站学习 PyTorch 并编写了一个简单的线性回归模型。这是我的优化器:

class SGD():
    """Naive SGD optimizer (the BUGGY version from the question).

    BUG: `params` is stored as-is.  When a one-shot iterator/generator is
    passed in (e.g. model.parameters()), the first full loop over it — the
    first zero_grad() call — exhausts it, so every later step()/zero_grad()
    iterates over an empty sequence and the parameters are never updated.
    See the corrected version in the accepted answer below.
    """
    def __init__(self, params, lr):
        self.params = params  # kept as the raw iterable — this is the bug
        self.lr = lr
    def step(self):
        # Once the iterator is exhausted, this loop body never executes.
        for param in self.params:
            param -= self.lr * param.grad
    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

然而,它无法在训练期间更新模型参数.我不知道出了什么问题.

值得注意的是,模型参数可以通过内置优化器更新:optim.SGD(model.parameters(), lr=self.learning_rate)。因此,我怀疑问题出在我自己这个简单的 SGD 实现上。

以下是一个可复制的示例:

import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

import warnings
warnings.filterwarnings("ignore")


class SyntheticRegressionData():
    """Synthetic tensor dataset for linear regression (from S02).

    Draws `num_trains + num_vals` rows of standard-normal features X and
    builds targets y = X @ w + b + noise * eps, then serves train /
    validation splits through shuffling DataLoaders.
    """
    def __init__(self, w, b, noise=0.01, num_trains=1000, num_vals=1000, batch_size=32):
        self.w = w
        self.b = b
        self.noise = noise
        self.num_trains = num_trains
        self.num_vals = num_vals
        self.batch_size = batch_size
        total = num_trains + num_vals
        # Keep the two randn calls in this order so results are
        # reproducible under a fixed seed.
        self.X = torch.randn(total, len(w))
        clean = torch.matmul(self.X, w.reshape(-1, 1)) + b
        self.y = clean + noise * torch.randn(total, 1)

    def get_tensorloader(self, tensors, train, indices=slice(0, None)):
        """Slice each tensor with `indices` and wrap them in a DataLoader."""
        sliced = tuple(t[indices] for t in tensors)
        return DataLoader(TensorDataset(*sliced), self.batch_size, shuffle=train)

    def get_dataloader(self, train=True):
        # First num_trains rows are the training split; the rest validate.
        if train:
            indices = slice(0, self.num_trains)
        else:
            indices = slice(self.num_trains, None)
        return self.get_tensorloader((self.X, self.y), train, indices)

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
    

class LinearNetwork(nn.Module):
    """Minimal affine layer: forward(x) = x @ weight + bias.

    Parameters are initialised from a standard normal (weight first, then
    bias, so seeded runs match the original implementation).
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        return x @ self.weight + self.bias


class SGD():
    """Naive SGD optimizer (buggy reproducer — identical to the snippet above).

    BUG: `params` is kept as the raw iterable.  model.parameters() returns
    a one-shot generator, so the first zero_grad() call consumes it and all
    subsequent step()/zero_grad() calls loop over nothing — the model's
    parameters are never updated (confirmed in the accepted answer below).
    """
    def __init__(self, params, lr):
        self.params = params  # not materialised into a list — the bug
        self.lr = lr
    def step(self):
        # With an exhausted generator this loop body never runs.
        for param in self.params:
            param -= self.lr * param.grad
    def zero_grad(self):
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()


class MyTrainer():
    """Hand-rolled training loop for the linear-regression example.

    fit() wires up the model, data loaders, optimizer and loss, then runs
    `max_epochs` passes, printing the mean train (and optionally val) loss.
    """
    def __init__(self, max_epochs=10, learning_rate=1e-3):
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate

    def fit(self, model, train_dataloader, val_dataloader=None):
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        # self.optim = MySGD([self.model.weight, self.model.bias], lr=self.learning_rate)
        self.optim = SGD(self.model.parameters(), lr=self.learning_rate)
        self.loss = nn.MSELoss()
        self.num_train_batches = len(train_dataloader)
        self.num_val_batches = len(val_dataloader) if val_dataloader is not None else 0

        self.epoch = 0
        for _ in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        """One optimisation pass over the train set, then a no-grad val pass."""
        self.model.train()
        avg_loss = 0
        for features, targets in self.train_dataloader:
            self.optim.zero_grad()
            batch_loss = self.loss(self.model(features), targets)
            batch_loss.backward()
            self.optim.step()
            avg_loss += batch_loss.item()
        avg_loss /= self.num_train_batches
        print(f'epoch {self.epoch}: train_loss={avg_loss:>8f}')
        if self.val_dataloader is not None:
            self.model.eval()
            val_loss = 0
            with torch.no_grad():
                for features, targets in self.val_dataloader:
                    val_loss += self.loss(self.model(features), targets).item()
            val_loss /= self.num_val_batches
            print(f'epoch {self.epoch}: val_loss={val_loss:>8f}')
        self.epoch += 1


torch.manual_seed(2024) 

# Trainer and a 2-feature -> 1-output linear model.
trainer = MyTrainer(max_epochs=10, learning_rate=0.01)
model = LinearNetwork(2, 1)

# Re-seed so the synthetic data is reproducible regardless of how many
# random draws the model initialisation above consumed.
torch.manual_seed(2024)
w = torch.tensor([2., -3.])  # ground-truth weights
b = torch.Tensor([1.])       # ground-truth bias
noise = 0.01
num_trains = 1000
num_vals = 1000
batch_size = 64
data = SyntheticRegressionData(w, b, noise, num_trains, num_vals, batch_size)
train_data = data.train_dataloader()
val_data = data.val_dataloader()

trainer.fit(model, train_data, val_data)

以下是输出:

epoch 0: train_loss=29.762345
epoch 0: val_loss=29.574341
epoch 1: train_loss=29.547140
epoch 1: val_loss=29.574341
epoch 2: train_loss=29.559777
epoch 2: val_loss=29.574341
epoch 3: train_loss=29.340937
epoch 3: val_loss=29.574341
epoch 4: train_loss=29.371171
epoch 4: val_loss=29.574341
epoch 5: train_loss=29.649407
epoch 5: val_loss=29.574341
epoch 6: train_loss=29.717251
epoch 6: val_loss=29.574341
epoch 7: train_loss=29.545675
epoch 7: val_loss=29.574341
epoch 8: train_loss=29.456314
epoch 8: val_loss=29.574341
epoch 9: train_loss=29.537769
epoch 9: val_loss=29.574341

推荐答案

问题解决了。model.parameters() 的返回值是一个迭代器/生成器,当我第一次调用自定义 SGD 的 zero_grad 方法时,它在迭代中被耗尽了。因此,后续对 step 和 zero_grad 的调用没有任何效果。更正后的版本如下:

class SGD():
    """Minimal re-implementation of SGD: param <- param - lr * grad.

    Args:
        params: iterable of tensors to optimise.  It is materialised into a
            list immediately, because model.parameters() is a one-shot
            generator that would otherwise be exhausted by the first loop.
        lr: learning rate.
    """
    def __init__(self, params, lr):
        self.params = list(params)  # convert to list here
        self.lr = lr

    def step(self):
        """Apply one in-place gradient-descent update to every parameter."""
        # Update under no_grad() rather than mutating .data: writing to
        # .data skips autograd's version tracking and is discouraged.
        with torch.no_grad():
            for param in self.params:
                if param.grad is not None:  # skip params untouched by backward()
                    param -= self.lr * param.grad

    def zero_grad(self):
        """Reset every parameter's gradient to zero in place."""
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

Python相关问答推荐

tempfile.mkstemp(text=.)参数实际上是什么?

从收件箱获取特定列中的重复行

在Python中添加期货之间的延迟

Django序列化器没有验证或保存数据

使用Python进行网页抓取,没有页面

如何修复使用turtle和tkinter制作的绘画应用程序的撤销功能

如何让我的Tkinter应用程序适合整个窗口,无论大小如何?

优化在numpy数组中非零值周围创建缓冲区的函数的性能

分组数据并删除重复数据

如何使用scipy从频谱图中回归多个高斯峰?

如何使用sympy打印方程?

Odoo 14 hr. emergency.public内的二进制字段

ModuleNotFound错误:没有名为Crypto Windows 11、Python 3.11.6的模块

Python虚拟环境的轻量级使用

pyscript中的缩进问题

所有列的滚动标准差,忽略NaN

当点击tkinter菜单而不是菜单选项时,如何执行命令?

为什么\b在这个正则表达式中不解释为反斜杠

lityter不让我输入左边的方括号,'

matplotlib + python foor loop