[我知道这个问题经常出现,但我找不到任何与我的用例匹配的答案]

[编辑:如果你知道 torch ,我知道矩阵乘法,请回答这个问题,这不是这里的问题]

I将由8个浮点组成的N=10个输入馈送到8个输入层

但我明白这个错误

Mat1 and mat2 shapes cannot be multiplied (8x10 and 8x8) 

我错过了什么?

我在数据集中生成随机数据

class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        self.input_size = 0
        for i in range(0,10):
            label = random.randint(0, 1)
            data = [random.uniform(0.0, 1.0) for _ in range(8)]
            self.data.append((data , label))
            self.input_size = len(data ) if self.input_size < len(encoded_text) else self.input_size
  
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text, label = self.data[idx]
        return text, label

这是我的模型

class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.input = nn.Linear(input_size,hidden_size)
        self.hidden = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # padding is there as original dataset does not have full 8 floats inputs
        x_padded = pad_sequence(x, batch_first=True, padding_value=0).float()  
        output = self.input(x_padded) >>>>>>>> ERROR
        return torch.sigmoid(output)

def train_model(model, train_loader, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        for inputs, labels in train_loader:
            outputs = model(inputs)

初始化部分

if __name__ == "__main__":
    hidden_size = 8  # hidden size 
    output_size = 1  # binary classification 
    learning_rate = 0.001
    num_epochs = 10
    
    dataset = MyDataset()
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    
    input_size = dataset.input_size
    
    model = MyModel(input_size, hidden_size, output_size)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    train_model(model, train_loader, criterion, optimizer, num_epochs)`

这是一个非常简单的初学者用例,没什么特别的

谢谢你的帮忙

[编辑:x输入由10个值的8个张量组成,但它应该是8个值的10个张量]

enter image description here

enter image description here

推荐答案

问题是这句话:

x_flattened = x_padded.view(x_padded.size(0), -1) 

我不得不用这个:

x_flattened = x_padded.view(x_padded.size(1), -1) 

Python相关问答推荐

替换字符串中的多个重叠子字符串

如何使用Python将工作表从一个Excel工作簿复制粘贴到另一个工作簿?

在Pandas DataFrame操作中用链接替换'方法的更有效方法

当从Docker的--env-file参数读取Python中的环境变量时,每个\n都会添加一个\'.如何没有额外的?

Python虚拟环境的轻量级使用

Pandas:将多级列名改为一级

无论输入分辨率如何,稳定扩散管道始终输出512 * 512张图像

旋转多边形而不改变内部空间关系

循环浏览每个客户记录,以获取他们来自的第一个/最后一个渠道

在用于Python的Bokeh包中设置按钮的样式

根据Pandas中带条件的两个列的值创建新列

仅使用预先计算的排序获取排序元素

使用SQLAlchemy从多线程Python应用程序在postgr中插入多行的最佳方法是什么?'

每次查询的流通股数量

合并相似列表

Django在一个不是ForeignKey的字段上加入'

Scipy差分进化:如何传递矩阵作为参数进行优化?

替换包含Python DataFrame中的值的<;

为什么在更新Pandas 2.x中的列时,数据类型不会更改,而在Pandas 1.x中会更改?

了解如何让库认识到我具有所需的依赖项